tatk.policy.gdpl package

Submodules

tatk.policy.gdpl.estimator module

@author: truthless

class tatk.policy.gdpl.estimator.AIRL(gamma, h_dim, s_dim, a_dim)

Bases: torch.nn.modules.module.Module

label: 1 for real, 0 for generated

__init__(gamma, h_dim, s_dim, a_dim)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(s, a, next_s)
Parameters
  • s – [b, s_dim]

  • a – [b, a_dim]

  • next_s – [b, s_dim]

Returns

[b, 1]
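
The documentation above fixes only the constructor arguments and the forward signature. A minimal sketch of an AIRL-style reward network with that interface is shown below; the g/h decomposition, hidden sizes, and activations are assumptions, not the actual tatk implementation.

import torch
import torch.nn as nn

class AIRLSketch(nn.Module):
    """Discriminator-style reward network with the documented interface;
    label 1 for real transitions, 0 for generated ones."""

    def __init__(self, gamma, h_dim, s_dim, a_dim):
        super().__init__()
        self.gamma = gamma  # discount used to combine the reward and shaping terms
        # g(s, a): reward term over the state-action pair (internals assumed)
        self.g = nn.Sequential(nn.Linear(s_dim + a_dim, h_dim), nn.ReLU(), nn.Linear(h_dim, 1))
        # h(s): potential-based shaping term over a single state (internals assumed)
        self.h = nn.Sequential(nn.Linear(s_dim, h_dim), nn.ReLU(), nn.Linear(h_dim, 1))

    def forward(self, s, a, next_s):
        # s: [b, s_dim], a: [b, a_dim], next_s: [b, s_dim] -> score of shape [b, 1]
        return self.g(torch.cat([s, a], dim=-1)) + self.gamma * self.h(next_s) - self.h(s)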

class tatk.policy.gdpl.estimator.ActEstimatorDataLoaderMultiWoz

Bases: tatk.policy.mle.multiwoz.loader.ActMLEPolicyDataLoaderMultiWoz

__init__()

Initialize self. See help(type(self)) for accurate signature.

create_dataset_irl(part, batchsz)
class tatk.policy.gdpl.estimator.RewardEstimator(vector, pretrain=False)

Bases: object

__init__(vector, pretrain=False)

Initialize self. See help(type(self)) for accurate signature.

estimate(s, a, next_s, log_pi)

Infer the reward of a state-action pair with the estimator.

irl_loop(data_real, data_gen)
kl_divergence(mu, logvar, istrain)
load_irl(filename)
save_irl(directory, epoch)
test_irl(batch, epoch, best)
train_irl(batch, epoch)
update_irl(inputs, batchsz, epoch)

Train the reward estimator (together with the encoder) using a cross-entropy loss over real, mixed, and generated samples. Args:

inputs: (s, a, next_s)
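
Only the method signatures above are documented; a hypothetical usage sketch of the estimator at inference time might look like the following. The checkpoint filename and the exact meaning of the tensors beyond the documented signatures are assumptions.

from tatk.policy.gdpl.estimator import RewardEstimator

def learned_rewards(vector, s, a, next_s, log_pi):
    """Load a pretrained estimator and score one batch of transitions."""
    rewarder = RewardEstimator(vector, pretrain=False)
    rewarder.load_irl('save/best_irl.mdl')   # hypothetical checkpoint path
    # estimate() infers the reward of each (s, a, next_s) transition, given the
    # policy's action log-probabilities log_pi.
    return rewarder.estimate(s, a, next_s, log_pi)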

tatk.policy.gdpl.gdpl module

class tatk.policy.gdpl.gdpl.GDPL(is_train=False, dataset='Multiwoz')

Bases: tatk.policy.policy.Policy

__init__(is_train=False, dataset='Multiwoz')

Initialize self. See help(type(self)) for accurate signature.

est_adv(r, v, mask)

We save trajectories in continuous space; the current trajectory ends when mask=0.

Parameters
  • r – reward, Tensor, [b]

  • v – estimated value, Tensor, [b]

  • mask – 0 indicates the end of a trajectory, 1 otherwise, Tensor, [b]

Returns

A(s, a) and V-target(s), both Tensor
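
est_adv walks backwards over the flattened batch and resets the bootstrap whenever mask=0. The sketch below shows a generic GAE-style implementation of that idea; the gamma/tau values and the choice of GAE itself are assumptions, not necessarily what GDPL uses.

import torch

def est_adv_sketch(r, v, mask, gamma=0.99, tau=0.95):
    """Generic GAE-style advantage estimation over concatenated trajectories.
    r, v, mask are Tensors of shape [b]; mask[t] == 0 marks the last step of a trajectory.
    Returns A(s, a) and the value targets, both of shape [b]."""
    b = r.size(0)
    adv = torch.zeros(b)
    v_target = torch.zeros(b)
    prev_v, prev_adv = 0.0, 0.0
    # Walk backwards; mask=0 cuts the bootstrap so trajectories do not leak into each other.
    for t in reversed(range(b)):
        delta = r[t] + gamma * prev_v * mask[t] - v[t]
        adv[t] = delta + gamma * tau * prev_adv * mask[t]
        v_target[t] = adv[t] + v[t]
        prev_v, prev_adv = v[t], adv[t]
    return adv, v_target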

init_session()

Restore the initial state after one dialog session.

load(filename)
predict(state)

Predict a system action given the dialog state. Args:

state (dict): Dialog state. Please refer to util/state.py

Returns:

action : System act, in the form (act_type, {slot_name_1: value_1, slot_name_2: value_2, …})
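
A minimal call sketch for one turn is given below; the import path follows the module names above, while the checkpoint path is a placeholder. In practice the policy would be constructed and loaded once per session rather than per turn.

# Minimal call sketch; the checkpoint path is hypothetical.
from tatk.policy.gdpl.gdpl import GDPL

def respond(state):
    """Map a tracked dialog state (format of util/state.py) to a system act."""
    policy = GDPL(is_train=False, dataset='Multiwoz')
    policy.load('save/best_gdpl')   # hypothetical checkpoint
    policy.init_session()
    # Returns (act_type, {slot_name_1: value_1, slot_name_2: value_2, ...})
    return policy.predict(state)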

save(directory, epoch)
update(epoch, batchsz, s, a, next_s, mask, rewarder)
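
update() consumes the sampled transitions together with a RewardEstimator, so rewards come from the learned estimator rather than from the environment. A schematic outer loop is sketched below; sample_transitions() is a hypothetical rollout helper, and all hyperparameter values are assumptions.

from tatk.policy.gdpl.estimator import RewardEstimator
from tatk.policy.gdpl.gdpl import GDPL

def train_gdpl(vector, sample_transitions, epochs=200, batchsz=1024):
    """Schematic GDPL training loop; sample_transitions is a hypothetical helper
    that rolls out the current policy and returns flat tensors with mask=0 at
    episode ends."""
    policy = GDPL(is_train=True, dataset='Multiwoz')
    rewarder = RewardEstimator(vector, pretrain=False)
    for epoch in range(epochs):
        s, a, next_s, mask = sample_transitions(policy, batchsz)
        # Refit the reward estimator on the freshly generated data ...
        rewarder.update_irl((s, a, next_s), batchsz, epoch)
        # ... then improve the policy against the estimated rewards.
        policy.update(epoch, batchsz, s, a, next_s, mask, rewarder)
        policy.save('save', epoch)
        rewarder.save_irl('save', epoch)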