tatk.policy.gdpl package

Submodules

tatk.policy.gdpl.estimator module

@author: truthless
class tatk.policy.gdpl.estimator.AIRL(gamma, h_dim, s_dim, a_dim)
   Bases: torch.nn.modules.module.Module

   label: 1 for real, 0 for generated

   __init__(gamma, h_dim, s_dim, a_dim)
      Initializes internal Module state, shared by both nn.Module and ScriptModule.

   forward(s, a, next_s)
      Parameters:
         s – [b, s_dim]
         a – [b, a_dim]
         next_s – [b, s_dim]
      Returns:
         [b, 1]

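   A minimal usage sketch for this estimator; the dimensions, gamma, and batch size below are illustrative values, not defaults taken from the library:

      import torch
      from tatk.policy.gdpl.estimator import AIRL

      s_dim, a_dim, h_dim, gamma = 340, 209, 100, 0.99   # illustrative sizes
      irl = AIRL(gamma, h_dim, s_dim, a_dim)

      b = 32
      s = torch.randn(b, s_dim)         # states,     [b, s_dim]
      a = torch.randn(b, a_dim)         # actions,    [b, a_dim]
      next_s = torch.randn(b, s_dim)    # successors, [b, s_dim]

      score = irl(s, a, next_s)         # [b, 1] score; trained toward 1 for real and 0 for generated transitions
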
class tatk.policy.gdpl.estimator.ActEstimatorDataLoaderMultiWoz
   Bases: tatk.policy.mle.multiwoz.loader.ActMLEPolicyDataLoaderMultiWoz

   __init__()
      Initialize self. See help(type(self)) for accurate signature.

   create_dataset_irl(part, batchsz)

class tatk.policy.gdpl.estimator.RewardEstimator(vector, pretrain=False)
   Bases: object

   __init__(vector, pretrain=False)
      Initialize self. See help(type(self)) for accurate signature.

   estimate(s, a, next_s, log_pi)
      Infer the reward of a state-action pair with the estimator.

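      Usage sketch: only the signature above comes from this page; the vector argument (the tatk state/action vectorizer, built elsewhere) and the tensor shapes are stand-ins for illustration:

         import torch
         from tatk.policy.gdpl.estimator import RewardEstimator

         rewarder = RewardEstimator(vector, pretrain=False)   # `vector`: the policy's vectorizer (assumed, constructed elsewhere)

         s = torch.randn(32, 340)        # assumed state dimension
         a = torch.randn(32, 209)        # assumed action dimension
         next_s = torch.randn(32, 340)
         log_pi = torch.randn(32)        # log-probability of each action under the current policy

         r = rewarder.estimate(s, a, next_s, log_pi)   # estimated per-transition reward
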
   irl_loop(data_real, data_gen)

   kl_divergence(mu, logvar, istrain)

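      The signature suggests the usual closed-form KL divergence between a diagonal Gaussian N(mu, exp(logvar)) and a standard normal; a reference computation of that quantity (a sketch, not necessarily this method's exact return value):

         import torch

         def standard_gaussian_kl(mu, logvar):
             # KL( N(mu, diag(exp(logvar))) || N(0, I) ), summed over the latent dimensions
             return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)
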
   load_irl(filename)

   save_irl(directory, epoch)

   test_irl(batch, epoch, best)

   train_irl(batch, epoch)

   update_irl(inputs, batchsz, epoch)
      Train the reward estimator (together with the encoder) using a cross entropy loss over real, mixed, and generated samples.

      Parameters:
         inputs – (s, a, next_s)

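      A reference sketch of the discriminator-style objective this docstring describes: real transitions are labeled 1, generated ones 0, and the AIRL score is trained with binary cross entropy. Whether the [b, 1] output is a logit or a probability is not documented here, so the sketch assumes a logit; the shapes and optimizer settings are illustrative:

         import torch
         import torch.nn.functional as F
         from tatk.policy.gdpl.estimator import AIRL

         irl = AIRL(0.99, 100, 340, 209)      # illustrative (gamma, h_dim, s_dim, a_dim)
         opt = torch.optim.Adam(irl.parameters(), lr=1e-4)

         def irl_step(real, gen):
             # real, gen: (s, a, next_s) tuples of tensors with matching batch sizes
             real_score = irl(*real)          # [b, 1]
             gen_score = irl(*gen)            # [b, 1]
             loss = F.binary_cross_entropy_with_logits(real_score, torch.ones_like(real_score)) \
                  + F.binary_cross_entropy_with_logits(gen_score, torch.zeros_like(gen_score))
             opt.zero_grad()
             loss.backward()
             opt.step()
             return loss.item()
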
tatk.policy.gdpl.gdpl module

class tatk.policy.gdpl.gdpl.GDPL(is_train=False, dataset='Multiwoz')
   Bases: tatk.policy.policy.Policy

   __init__(is_train=False, dataset='Multiwoz')
      Initialize self. See help(type(self)) for accurate signature.

   est_adv(r, v, mask)
      A trajectory is saved in continuous space; it reaches the end of the current trajectory when mask = 0.

      Parameters:
         r – reward, Tensor, [b]
         v – estimated value, Tensor, [b]
         mask – 0 indicates the end of a trajectory, otherwise 1; Tensor, [b]

      Returns:
         A(s, a) and V-target(s), both Tensor

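      A reference sketch of mask-aware advantage and value-target estimation of the kind described above (generalized advantage estimation); the tau parameter and the exact recursion are assumptions, and only the (r, v, mask) inputs and the A / V-target outputs come from the docstring:

         import torch

         def estimate_advantage(r, v, mask, gamma=0.99, tau=0.95):
             # r, v, mask: 1-D tensors of length T (flattened trajectories);
             # mask[t] == 0 marks the last step of a trajectory.
             T = r.shape[0]
             adv = torch.zeros(T)
             v_target = torch.zeros(T)
             prev_v, prev_adv = 0.0, 0.0
             for t in reversed(range(T)):
                 # bootstrap only within a trajectory; mask[t] == 0 cuts the recursion
                 delta = r[t] + gamma * prev_v * mask[t] - v[t]
                 prev_adv = delta + gamma * tau * prev_adv * mask[t]
                 v_target[t] = v[t] + prev_adv
                 adv[t] = prev_adv
                 prev_v = v[t]
             return adv, v_target
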
   init_session()
      Restore after one session.

   load(filename)

   predict(state)
      Predict a system action given the dialog state.

      Parameters:
         state (dict) – dialog state; please refer to util/state.py

      Returns:
         action – the system act, in the form (act_type, {slot_name_1: value_1, slot_name_2: value_2, …})

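      Usage sketch, assuming the MultiWOZ data files shipped with tatk are available; the default_state import is an assumption about where util/state.py lives, and a real dialog would fill the state via a dialog state tracker rather than using the empty default:

         from tatk.policy.gdpl.gdpl import GDPL
         from tatk.util.multiwoz.state import default_state   # assumed location of util/state.py

         policy = GDPL(is_train=False, dataset='Multiwoz')
         # a trained checkpoint can be restored first with policy.load(filename)
         policy.init_session()

         state = default_state()           # normally produced and updated by a state tracker
         action = policy.predict(state)    # system act: (act_type, {slot_name_1: value_1, ...})
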
   save(directory, epoch)

   update(epoch, batchsz, s, a, next_s, mask, rewarder)

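      Sketch of one policy update using the signatures documented on this page. How the (s, a, next_s, mask) batch is collected from dialog sessions is elided, the tensor shapes are illustrative, and the assumption that the policy exposes its vectorizer as policy.vector is not confirmed by this page:

         import torch
         from tatk.policy.gdpl.gdpl import GDPL
         from tatk.policy.gdpl.estimator import RewardEstimator

         policy = GDPL(is_train=True, dataset='Multiwoz')
         rewarder = RewardEstimator(policy.vector, pretrain=False)   # assumption: the policy exposes its vectorizer

         epoch, batchsz = 0, 32
         s = torch.randn(batchsz, 340)                 # assumed state dimension
         a = torch.randn(batchsz, 209)                 # assumed action dimension
         next_s = torch.randn(batchsz, 340)
         mask = torch.ones(batchsz); mask[-1] = 0      # 0 marks the end of a trajectory

         policy.update(epoch, batchsz, s, a, next_s, mask, rewarder)
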