convlab2.policy.pg package

Submodules

convlab2.policy.pg.pg module

class convlab2.policy.pg.pg.PG(is_train=False, dataset='Multiwoz')

Bases: convlab2.policy.policy.Policy

est_return(r, mask)

Estimate the return (V-target) of every step in a batch. Trajectories are stored back to back in one flat batch, and a trajectory ends at the step where mask = 0.

Parameters

r: reward, Tensor, [b]

mask: 0 at the last step of a trajectory, otherwise 1, Tensor, [b]

Returns

V-target(s), Tensor
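The masked return computation can be sketched in plain Python. This is a minimal sketch, not the ConvLab-2 code: it assumes the usual backward Bellman recursion V_t = r_t + gamma * mask_t * V_{t+1}, where gamma is a hypothetical discount factor that does not appear in the documented signature, and plain lists stand in for the [b]-shaped Tensors.

```python
from typing import List


def est_return(r: List[float], mask: List[int], gamma: float = 0.99) -> List[float]:
    """Discounted return for each step of trajectories stored back to back.

    Sketch only: `gamma` is a hypothetical discount factor, and plain lists
    stand in for the [b]-shaped Tensors of the documented method.
    """
    if len(r) != len(mask):
        raise ValueError("r and mask must have the same length")
    v_target = [0.0] * len(r)
    next_value = 0.0
    # Walk backwards so each step can reuse the value of the step after it.
    for t in reversed(range(len(r))):
        # mask[t] == 0 marks the last step of a trajectory, which cuts off
        # the bootstrap from the unrelated trajectory that follows it.
        v_target[t] = r[t] + gamma * next_value * mask[t]
        next_value = v_target[t]
    return v_target


# Two trajectories packed into one batch: steps 0-2 and steps 3-4.
print(est_return([1.0, 1.0, 1.0, 2.0, 2.0], [1, 1, 0, 1, 0], gamma=0.9))
```

Scanning backwards means a single pass handles any number of packed trajectories; the mask alone decides where one return stops accumulating.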

classmethod from_pretrained(archive_file='', model_file='https://convlab.blob.core.windows.net/convlab-2/pg_policy_multiwoz.zip')
init_session()

Restore the policy's initial state after one session, before the next session begins.

load(filename)
load_from_pretrained(archive_file, model_file, filename)
predict(state)

Predict a system action given the dialog state.

Args:

state (dict): Dialog state. Please refer to util/state.py

Returns:

action: System act, with the form of (act_type, {slot_name_1: value_1, slot_name_2: value_2, …})
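The calling pattern implied by init_session and predict can be sketched with a stand-in policy. StubPolicy, run_session, and the "turn" slot are hypothetical; the sketch only assumes that a caller resets the policy once per session and then requests one act per turn, and it does not load the real PG model.

```python
from typing import Any, Dict, List, Tuple

# Documented return form of predict: (act_type, {slot_name: value, ...}).
SystemAct = Tuple[str, Dict[str, Any]]


class StubPolicy:
    """Hypothetical stand-in exposing the same init_session/predict calls as PG."""

    def __init__(self) -> None:
        self.turns = 0

    def init_session(self) -> None:
        # Reset per-session state before a new dialog starts.
        self.turns = 0

    def predict(self, state: Dict[str, Any]) -> SystemAct:
        # A real policy would choose an act from `state`; the stub just
        # reports which turn of the session it is answering.
        self.turns += 1
        return ("inform", {"turn": self.turns})


def run_session(policy: StubPolicy, states: List[Dict[str, Any]]) -> List[SystemAct]:
    """Reset the policy once, then ask for one system act per dialog state."""
    policy.init_session()
    return [policy.predict(state) for state in states]


acts = run_session(StubPolicy(), [{}, {}])
print(acts)  # [('inform', {'turn': 1}), ('inform', {'turn': 2})]
```

Because init_session runs at the start of every session, reusing the same policy object for a second dialog starts again from turn 1.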

save(directory, epoch)
update(epoch, batchsz, s, a, r, mask)
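update has no description on this page. As a hedged illustration only: a vanilla policy-gradient (REINFORCE) step typically minimizes the negative mean of log pi(a|s) weighted by the return, which is where per-step targets such as those from est_return would enter. The helper below is hypothetical, computes just that scalar, and is not confirmed to be what PG.update does.

```python
import math
from typing import List


def pg_surrogate_loss(log_probs: List[float], returns: List[float]) -> float:
    """REINFORCE surrogate loss: -mean(log pi(a_t|s_t) * G_t).

    Hypothetical helper: `log_probs` are log-probabilities of the actions
    actually taken, and `returns` are per-step targets such as V-targets.
    """
    if not log_probs or len(log_probs) != len(returns):
        raise ValueError("need equally long, non-empty inputs")
    total = sum(lp * g for lp, g in zip(log_probs, returns))
    return -total / len(log_probs)


# Actions taken with probabilities 0.5 and 0.25, earning returns 2.0 and 1.0.
loss = pg_surrogate_loss([math.log(0.5), math.log(0.25)], [2.0, 1.0])
print(round(loss, 4))  # 1.3863
```

In an autograd framework the same expression, built from tensors, is what an optimizer would differentiate; lowering it raises the log-probability of each taken action in proportion to its return.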

convlab2.policy.pg.train module

Created on Sun Jul 14 16:14:07 2019 @author: truthless

convlab2.policy.pg.train.sample(env, policy, batchsz, process_num)

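Given batchsz tasks, split them equally across the processes; once all processes return, merge their data and return it.

Parameters

env: environment instance

policy: policy network, used to generate actions from the current policy

batchsz: total number of tasks to sample

process_num: number of sampling processes

Returns

batch: the merged sampled data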
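The split-then-merge behavior described for sample can be sketched without processes. split_batch and merge are hypothetical helpers: this sketch splits with divmod so the per-process counts add up exactly to batchsz, which may differ from how the real function rounds, and it merges results in process-id order because a shared queue delivers them in whatever order the workers finish.

```python
from typing import Dict, List


def split_batch(batchsz: int, process_num: int) -> List[int]:
    """Split `batchsz` tasks across `process_num` workers as evenly as possible."""
    if process_num <= 0:
        raise ValueError("process_num must be positive")
    base, extra = divmod(batchsz, process_num)
    # The first `extra` workers take one additional task each.
    return [base + 1 if i < extra else base for i in range(process_num)]


def merge(results: Dict[int, List[dict]]) -> List[dict]:
    """Concatenate per-process results in process-id order."""
    merged: List[dict] = []
    for pid in sorted(results):
        merged.extend(results[pid])
    return merged


print(split_batch(10, 3))  # [4, 3, 3]
```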

convlab2.policy.pg.train.sampler(pid, queue, evt, env, policy, batchsz)

This sampler function is run by multiprocessing.Process so that several processes can sample data from the environment in parallel.

Parameters

pid: process id

queue: multiprocessing.Queue, used to collect the sampled data

evt: multiprocessing.Event, used to keep the process alive

env: environment instance

policy: policy network, used to generate actions from the current policy

batchsz: total number of items to sample

Returns

None; the sampled data is sent back through queue.
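The queue/event handshake described for sampler can be sketched with threads. This swaps threading for multiprocessing.Process so the example runs anywhere without a __main__ guard; the dummy items and run_samplers are hypothetical, and no environment or policy is involved. Each worker puts its data on the queue and then blocks on the Event, so it stays alive until the parent has collected every result.

```python
import queue
import threading
from typing import List


def sampler(pid: int, out_q: "queue.Queue", evt: threading.Event, batchsz: int) -> None:
    """Produce `batchsz` dummy items, hand them to the parent, then wait on `evt`."""
    # A real sampler would step an environment with the current policy here.
    buff = [{"pid": pid, "step": i} for i in range(batchsz)]
    out_q.put((pid, buff))
    # Stay alive until the parent signals that it has collected everything.
    evt.wait()


def run_samplers(counts: List[int]) -> List[dict]:
    """Start one worker per count, gather their results, then release them."""
    out_q: "queue.Queue" = queue.Queue()
    evt = threading.Event()
    workers = [
        threading.Thread(target=sampler, args=(pid, out_q, evt, n))
        for pid, n in enumerate(counts)
    ]
    for w in workers:
        w.start()
    # Collect exactly one (pid, data) pair per worker, in arrival order.
    results = dict(out_q.get() for _ in workers)
    evt.set()  # release the workers
    for w in workers:
        w.join()
    return [item for pid in sorted(results) for item in results[pid]]


print(len(run_samplers([4, 3, 3])))  # 10
```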

convlab2.policy.pg.train.update(env, policy, batchsz, epoch, process_num)

Module contents