convlab2.policy.gdpl package

Submodules

convlab2.policy.gdpl.estimator module

@author: truthless

class convlab2.policy.gdpl.estimator.AIRL(gamma, h_dim, s_dim, a_dim)

Bases: torch.nn.modules.module.Module

label: 1 for real, 0 for generated

forward(s, a, next_s)
Parameters
  • s – [b, s_dim]

  • a – [b, a_dim]

  • next_s – [b, s_dim]

Returns

[b, 1]
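A minimal shape-check sketch of the discriminator follows; it assumes AIRL can be instantiated directly from the documented signature, and all dimensions are illustrative placeholders.

    import torch
    from convlab2.policy.gdpl.estimator import AIRL

    # illustrative dimensions only; the real values come from the state/action vectoriser
    b, s_dim, a_dim, h_dim = 32, 340, 209, 100
    irl = AIRL(gamma=0.99, h_dim=h_dim, s_dim=s_dim, a_dim=a_dim)

    s = torch.randn(b, s_dim)       # states, [b, s_dim]
    a = torch.randn(b, a_dim)       # actions, [b, a_dim]
    next_s = torch.randn(b, s_dim)  # next states, [b, s_dim]

    score = irl(s, a, next_s)       # discriminator score per transition, [b, 1]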

class convlab2.policy.gdpl.estimator.ActEstimatorDataLoaderMultiWoz

Bases: convlab2.policy.mle.multiwoz.loader.ActMLEPolicyDataLoaderMultiWoz

create_dataset_irl(part, batchsz)
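A hedged usage sketch, assuming the loader can be constructed without arguments as documented and that part names a corpus split; the split names 'train' and 'valid' are assumptions carried over from the MLE loader it extends.

    from convlab2.policy.gdpl.estimator import ActEstimatorDataLoaderMultiWoz

    loader = ActEstimatorDataLoaderMultiWoz()
    data_train = loader.create_dataset_irl('train', 64)   # training split, batch size 64
    data_valid = loader.create_dataset_irl('valid', 64)   # validation split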
class convlab2.policy.gdpl.estimator.RewardEstimator(vector, pretrain=False)

Bases: object

estimate(s, a, next_s, log_pi)

Infer the reward of a state-action pair with the estimator.
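A hedged inference sketch follows; it assumes the GDPL policy exposes its MultiWOZ vectoriser as .vector with state_dim and da_dim attributes, and uses random tensors in place of real rollout data.

    import torch
    from convlab2.policy.gdpl.gdpl import GDPL
    from convlab2.policy.gdpl.estimator import RewardEstimator

    policy = GDPL(is_train=True)                                  # supplies the MultiWOZ vectoriser
    rewarder = RewardEstimator(policy.vector, pretrain=False)     # .vector is an assumed attribute

    b = 32                                                        # illustrative batch size
    s_dim, a_dim = policy.vector.state_dim, policy.vector.da_dim  # assumed attribute names
    s, next_s = torch.randn(b, s_dim), torch.randn(b, s_dim)
    a = torch.randn(b, a_dim)
    log_pi = torch.randn(b)                                       # log-probabilities of the sampled actions

    r = rewarder.estimate(s, a, next_s, log_pi)                   # learned reward per transition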

irl_loop(data_real, data_gen)
kl_divergence(mu, logvar, istrain)
load_irl(filename)
save_irl(directory, epoch)
test_irl(batch, epoch, best)
train_irl(batch, epoch)
update_irl(inputs, batchsz, epoch)

Train the reward estimator (together with the encoder) using cross-entropy loss over real, mixed, and generated samples. Args:

inputs: (s, a, next_s)
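Continuing the sketch under estimate() above, one outer update of the estimator might look like the following; the batch size and epoch index are placeholders.

    inputs = (s, a, next_s)                          # transitions sampled from the current policy
    rewarder.update_irl(inputs, batchsz=b, epoch=0)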

convlab2.policy.gdpl.gdpl module

class convlab2.policy.gdpl.gdpl.GDPL(is_train=False, dataset='Multiwoz')

Bases: convlab2.policy.policy.Policy

est_adv(r, v, mask)

A trajectory is saved in continuous space; the current trajectory ends when mask=0.

Parameters
  • r – reward, Tensor, [b]

  • v – estimated value, Tensor, [b]

  • mask – 0 marks the end of a trajectory, 1 otherwise, Tensor, [b]

Returns

A(s, a) and V-target(s), both Tensor
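As a worked illustration of the masking convention, here is a minimal sketch of GAE-style advantage estimation over a flattened batch of trajectories; the hyperparameters and the exact recursion used inside GDPL may differ.

    import torch

    def masked_advantage(r, v, mask, gamma=0.99, tau=0.95):
        # Sketch only: scan the batch backwards, computing A(s, a) and V-target(s);
        # mask[t] == 0 marks the end of a trajectory and cuts the bootstrapped terms.
        b = r.size(0)
        A = torch.zeros(b)
        v_target = torch.zeros(b)
        prev_v = prev_v_target = prev_A = 0.0
        for t in reversed(range(b)):
            v_target[t] = r[t] + gamma * prev_v_target * mask[t]
            delta = r[t] + gamma * prev_v * mask[t] - v[t]
            A[t] = delta + gamma * tau * prev_A * mask[t]
            prev_v, prev_v_target, prev_A = v[t], v_target[t], A[t]
        return A, v_target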

classmethod from_pretrained(archive_file='', model_file='https://convlab.blob.core.windows.net/convlab-2/gdpl_policy_multiwoz.zip', is_train=False, dataset='Multiwoz')
init_session()

Restore the initial state after one dialog session

load(filename)
load_from_pretrained(archive_file, model_file, filename)
predict(state)

Predict a system action given the dialog state. Args:

state (dict): Dialog state. Please refer to util/state.py

Returns:

action : System act, in the form of (act_type, {slot_name_1: value_1, slot_name_2: value_2, …})
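Putting these together, a hedged inference sketch; default_state is assumed to provide an empty MultiWOZ dialog state, and in a full pipeline the state would come from the dialog state tracker.

    from convlab2.policy.gdpl.gdpl import GDPL
    from convlab2.util.multiwoz.state import default_state   # assumed helper for an empty dialog state

    policy = GDPL.from_pretrained()   # downloads gdpl_policy_multiwoz.zip by default
    policy.init_session()             # reset internal state for a new dialog

    state = default_state()           # in practice this comes from the dialog state tracker
    action = policy.predict(state)    # system dialog act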

save(directory, epoch)
update(epoch, batchsz, s, a, next_s, mask, rewarder)

convlab2.policy.gdpl.train module

Created on Tue Dec 31 10:57:51 2019 @author: truthless

convlab2.policy.gdpl.train.sample(env, policy, batchsz, process_num)

Given batchsz tasks, the batch size is split equally across the processes; when the processes return, all sampled data is merged and returned.

Parameters
  • env – environment instance

  • policy – policy network

  • batchsz – total number of sampled items

  • process_num – number of sampling processes

Returns

batch

convlab2.policy.gdpl.train.sampler(pid, queue, evt, env, policy, batchsz)

This is a sampler function, called by multiprocessing.Process to sample data from the environment in multiple processes.

Parameters
  • pid – process id

  • queue – multiprocessing.Queue, to collect sampled data

  • evt – multiprocessing.Event, to keep the process alive

  • env – environment instance

  • policy – policy network, to generate actions from the current policy

  • batchsz – total number of sampled items

convlab2.policy.gdpl.train.update(env, policy, batchsz, epoch, process_num, rewarder)
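A high-level sketch of the training driver built from the functions above; the environment construction mirrors the standard ConvLab-2 rule-simulator setup, and the component choices, batch size, worker count, and epoch count are illustrative assumptions rather than the reference configuration.

    from convlab2.dialog_agent.agent import PipelineAgent
    from convlab2.dialog_agent.env import Environment
    from convlab2.dst.rule.multiwoz import RuleDST
    from convlab2.policy.rule.multiwoz import RulePolicy
    from convlab2.policy.gdpl.gdpl import GDPL
    from convlab2.policy.gdpl.estimator import RewardEstimator
    from convlab2.policy.gdpl.train import update

    # system side: rule DST plus GDPL policy; user side: agenda-based rule policy over dialog acts
    dst_sys = RuleDST()
    policy_sys = GDPL(is_train=True)
    policy_usr = RulePolicy(character='usr')
    simulator = PipelineAgent(None, None, policy_usr, None, 'user')
    env = Environment(None, simulator, None, dst_sys)

    # the reward estimator shares the policy's state/action vectoriser (assumed .vector attribute)
    rewarder = RewardEstimator(policy_sys.vector, pretrain=False)

    for epoch in range(50):                       # epoch count is illustrative
        update(env, policy_sys, batchsz=1024, epoch=epoch,
               process_num=4, rewarder=rewarder)  # sample with 4 workers, then update policy and rewarder
        policy_sys.save('save', epoch)            # checkpoint directory is a placeholder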

Module contents