convlab2.policy.gdpl package¶
Subpackages¶
Submodules¶
convlab2.policy.gdpl.estimator module¶
@author: truthless
-
class convlab2.policy.gdpl.estimator.AIRL(gamma, h_dim, s_dim, a_dim)¶
Bases: torch.nn.modules.module.Module
label: 1 for real, 0 for generated
-
forward(s, a, next_s)¶
- Parameters
s – [b, s_dim]
a – [b, a_dim]
next_s – [b, s_dim]
- Returns
[b, 1]
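For illustration, a minimal sketch of constructing the module and calling forward with the documented shapes; the batch size and dimensions below are arbitrary placeholders, and it is assumed the module accepts plain float tensors.

```python
import torch
from convlab2.policy.gdpl.estimator import AIRL

# Placeholder dimensions; real values depend on the state/action vectorization.
b, s_dim, a_dim, h_dim = 32, 340, 209, 64

irl = AIRL(gamma=0.99, h_dim=h_dim, s_dim=s_dim, a_dim=a_dim)

s = torch.rand(b, s_dim)       # [b, s_dim] current states
a = torch.rand(b, a_dim)       # [b, a_dim] (multi-hot) system actions
next_s = torch.rand(b, s_dim)  # [b, s_dim] successor states

score = irl(s, a, next_s)      # [b, 1] score; label 1 for real, 0 for generated
```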
-
-
class convlab2.policy.gdpl.estimator.ActEstimatorDataLoaderMultiWoz¶
Bases: convlab2.policy.mle.multiwoz.loader.ActMLEPolicyDataLoaderMultiWoz
-
create_dataset_irl(part, batchsz)¶
-
-
class convlab2.policy.gdpl.estimator.RewardEstimator(vector, pretrain=False)¶
Bases: object
-
estimate(s, a, next_s, log_pi)¶
Infer the reward of a state-action pair with the estimator.
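The exact computation lives in the source; in a standard AIRL-style formulation the estimated reward is the discriminator score for (s, a, next_s) minus the policy's log-probability of the action. A hedged sketch of that idea, not necessarily the actual implementation:

```python
import torch

def airl_style_reward(irl, s, a, next_s, log_pi):
    """Sketch: reward = f(s, a, s') - log pi(a|s), per batch element.

    irl is assumed to be the AIRL module above; shapes follow the docs:
    s [b, s_dim], a [b, a_dim], next_s [b, s_dim], log_pi [b].
    """
    with torch.no_grad():
        f = irl(s, a, next_s).squeeze(-1)  # [b]
    return f - log_pi                      # [b] per-step reward signal
```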
-
irl_loop(data_real, data_gen)¶
-
kl_divergence(mu, logvar, istrain)¶
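Given the (mu, logvar) signature, this presumably computes the closed-form KL divergence between a diagonal Gaussian N(mu, exp(logvar)) and the standard normal; the sketch below shows that standard formula, as an assumption about what the method does:

```python
import torch

def gaussian_kl(mu, logvar):
    """KL( N(mu, diag(exp(logvar))) || N(0, I) ), averaged over the batch."""
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)
    return kl.mean()
```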
-
load_irl(filename)¶
-
save_irl(directory, epoch)¶
-
test_irl(batch, epoch, best)¶
-
train_irl(batch, epoch)¶
-
update_irl(inputs, batchsz, epoch)¶
Train the reward estimator (together with the encoder) using a cross-entropy loss over real, mixed, and generated samples.
Args:
inputs: (s, a, next_s)
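As described, the estimator is trained like a discriminator: samples from human dialogs get label 1 and samples generated by the current policy get label 0. A simplified sketch of one such update step; irl, optimizer, and the batch tensors are hypothetical placeholders, not the module's actual API.

```python
import torch
import torch.nn.functional as F

def irl_update_step(irl, optimizer, real_batch, gen_batch):
    """One discriminator-style update: real samples -> label 1, generated -> label 0."""
    s_r, a_r, ns_r = real_batch   # tensors taken from the human corpus
    s_g, a_g, ns_g = gen_batch    # tensors sampled from the current policy

    real_score = irl(s_r, a_r, ns_r)  # [b, 1]
    gen_score = irl(s_g, a_g, ns_g)   # [b, 1]

    loss = F.binary_cross_entropy_with_logits(real_score, torch.ones_like(real_score)) \
         + F.binary_cross_entropy_with_logits(gen_score, torch.zeros_like(gen_score))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```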
-
convlab2.policy.gdpl.gdpl module¶
-
class convlab2.policy.gdpl.gdpl.GDPL(is_train=False, dataset='Multiwoz')¶
Bases: convlab2.policy.policy.Policy
-
est_adv(r, v, mask)¶
Trajectories are stored contiguously in one flat batch; the current trajectory ends wherever mask=0.
- Parameters
r – reward, Tensor, [b]
v – estimated value, Tensor, [b]
mask – 0 indicates the end of a trajectory, otherwise 1, Tensor, [b]
- Returns
A(s, a) and V-target(s), both Tensor
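The mask lets a single flat batch hold many trajectories: the recursion that estimates advantages simply resets wherever mask=0. A sketch of a masked GAE-style backward pass illustrating this; gamma and tau are illustrative hyperparameters, not necessarily the class's attribute names.

```python
import torch

def masked_advantage(r, v, mask, gamma=0.99, tau=0.95):
    """GAE-style advantages and value targets over a flat batch.

    r, v, mask: Tensors of shape [b]; mask[t] == 0 marks the end of a trajectory.
    """
    b = r.size(0)
    A, v_target = torch.zeros(b), torch.zeros(b)
    prev_v_target = prev_v = prev_A = 0.0
    for t in reversed(range(b)):
        # mask[t] == 0 zeroes every bootstrap term, so nothing leaks across trajectories
        v_target[t] = r[t] + gamma * prev_v_target * mask[t]
        delta = r[t] + gamma * prev_v * mask[t] - v[t]
        A[t] = delta + gamma * tau * prev_A * mask[t]
        prev_v_target, prev_v, prev_A = v_target[t], v[t], A[t]
    return A, v_target
```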
-
classmethod from_pretrained(archive_file='', model_file='https://convlab.blob.core.windows.net/convlab-2/gdpl_policy_multiwoz.zip', is_train=False, dataset='Multiwoz')¶
-
init_session()¶
Restore after one session
-
load(filename)¶
-
load_from_pretrained(archive_file, model_file, filename)¶
-
predict(state)¶
Predict a system action given the dialog state.
Args:
state (dict): Dialog state. Please refer to util/state.py
- Returns:
action: System act, in the form of (act_type, {slot_name_1: value_1, slot_name_2: value_2, …})
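A minimal usage sketch: load the released MultiWOZ checkpoint and predict an action. The default_state import is an assumption about where an empty dialog state comes from; in practice the state is produced by a dialog state tracker.

```python
from convlab2.policy.gdpl.gdpl import GDPL
from convlab2.util.multiwoz.state import default_state  # assumed helper for an empty state dict

policy = GDPL.from_pretrained()  # downloads the default model_file checkpoint
policy.init_session()

state = default_state()          # normally filled in by a dialog state tracker
action = policy.predict(state)   # system act in the form documented above
```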
-
save(directory, epoch)¶
-
update(epoch, batchsz, s, a, next_s, mask, rewarder)¶
-
convlab2.policy.gdpl.train module¶
Created on Tue Dec 31 10:57:51 2019 @author: truthless
-
convlab2.policy.gdpl.train.sample(env, policy, batchsz, process_num)¶
Given batchsz tasks, batchsz is split equally across the processes; when the processes return, all sampled data are merged and returned.
- Parameters
env –
policy –
batchsz –
process_num –
- Returns
batch
-
convlab2.policy.gdpl.train.sampler(pid, queue, evt, env, policy, batchsz)¶
This is a sampler function; it is called by multiprocessing.Process to sample data from the environment in multiple processes.
- Parameters
pid – process id
queue – multiprocessing.Queue, to collect sampled data
evt – multiprocessing.Event, to keep the process alive
env – environment instance
policy – policy network, to generate actions from the current policy
batchsz – total number of sampled items
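A condensed sketch of the pattern the two functions above describe: batchsz is divided equally across process_num workers, each worker pushes its rollouts onto a shared queue, and the parent merges them. The names, the buffer layout, and the env.step return values are illustrative, not the module's actual internals.

```python
import multiprocessing as mp

def sampler_sketch(pid, queue, evt, env, policy, batchsz):
    """Worker: roll out the policy until batchsz items are collected, then report back."""
    buffer, s = [], env.reset()
    while len(buffer) < batchsz:
        a = policy.predict(s)
        next_s, r, done = env.step(a)
        buffer.append((s, a, r, next_s, done))
        s = env.reset() if done else next_s
    queue.put((pid, buffer))
    evt.wait()  # keep the process alive until the parent has fetched the data

def sample_sketch(env, policy, batchsz, process_num):
    """Parent: split batchsz equally, spawn workers, then merge their buffers."""
    per_proc = batchsz // process_num
    queue, evt = mp.Queue(), mp.Event()
    procs = [mp.Process(target=sampler_sketch, args=(i, queue, evt, env, policy, per_proc))
             for i in range(process_num)]
    for p in procs:
        p.start()
    batch = [item for _ in procs for item in queue.get()[1]]
    evt.set()  # release the workers
    for p in procs:
        p.join()
    return batch
```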
-
convlab2.policy.gdpl.train.update(env, policy, batchsz, epoch, process_num, rewarder)¶