convlab2.policy package
Subpackages
- convlab2.policy.gdpl package
- convlab2.policy.mdrg package
- convlab2.policy.mle package
- convlab2.policy.pg package
- convlab2.policy.ppo package
- convlab2.policy.rule package
- convlab2.policy.vector package
- convlab2.policy.vhus package
Submodules
convlab2.policy.evaluate module
- convlab2.policy.evaluate.evaluate(dataset_name, model_name, load_path, calculate_reward=True)
- convlab2.policy.evaluate.init_logging(log_dir_path, path_suffix=None)
- convlab2.policy.evaluate.sample(env, policy, batchsz, process_num)
  Given batchsz tasks, the work is split equally across the worker processes; when the processes return, all sampled data is merged and returned as a single batch.
  - Parameters
    - env – environment instance
    - policy – policy network
    - batchsz – total number of sampled items
    - process_num – number of worker processes
  - Returns
    batch
- convlab2.policy.evaluate.sampler(pid, queue, evt, env, policy, batchsz)
  A sampler function, called by multiprocessing.Process to sample data from the environment in multiple processes.
  - Parameters
    - pid – process id
    - queue – multiprocessing.Queue, to collect sampled data
    - evt – multiprocessing.Event, to keep the process alive
    - env – environment instance
    - policy – policy network, to generate actions from the current policy
    - batchsz – total number of sampled items
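  A minimal sketch of the queue-and-event pattern these two functions describe (the rollout helper and all details below are hypothetical stand-ins, not ConvLab-2 code):

      import torch.multiprocessing as mp

      def run_one_episode(env, policy):
          # Hypothetical stub standing in for a real environment rollout.
          return {"state": None, "action": None, "reward": 0.0}

      def sampler(pid, queue, evt, env, policy, batchsz):
          # Each worker collects batchsz items, puts its buffer on the shared
          # queue, then waits on evt so the parent can fetch the data before
          # the process exits.
          buffer = []
          while len(buffer) < batchsz:
              buffer.append(run_one_episode(env, policy))
          queue.put([pid, buffer])
          evt.wait()

      def sample(env, policy, batchsz, process_num):
          queue, evt = mp.Queue(), mp.Event()
          procs = []
          for pid in range(process_num):
              # Split batchsz equally across the worker processes.
              p = mp.Process(target=sampler,
                             args=(pid, queue, evt, env, policy,
                                   batchsz // process_num))
              p.start()
              procs.append(p)
          batch = []
          for _ in range(process_num):
              batch.extend(queue.get()[1])  # merge per-process buffers
          evt.set()                         # release the waiting workers
          for p in procs:
              p.join()
          return batch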
convlab2.policy.policy module
Policy Interface
- class convlab2.policy.policy.Policy
  Bases: convlab2.util.module.Module
  Policy module interface.
  - predict(state)
    Predict the next agent action given the dialog state.
    - Args:
      - state (dict or list of list): when the policy takes a dialogue state as input, the type is dict; when it takes dialogue acts as input, the type is list of list.
    - Returns:
      - action (list of list or str): when the policy outputs dialogue acts, the type is list of list; when it outputs an utterance directly, the type is str.
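    A minimal sketch of a subclass satisfying this interface (the class and its trivial rule are illustrative, not ConvLab-2's rule policy):

        from convlab2.policy.policy import Policy

        class GreetPolicy(Policy):
            # Hypothetical dialogue-act-level policy: maps a dialogue state
            # (dict) to a list of [intent, domain, slot, value] acts.
            def predict(self, state):
                if state.get("user_action"):
                    return [["Inform", "Hotel", "none", "none"]]  # illustrative act
                return [["Welcome", "General", "none", "none"]]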
convlab2.policy.rlmodule module
- class convlab2.policy.rlmodule.ContinuousPolicy(s_dim, h_dim, a_dim)
  Bases: torch.nn.modules.module.Module
  - forward(s)
    Defines the computation performed at every call. Should be overridden by all subclasses.
    Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
  - get_log_prob(s, a)
    - Parameters
      - s – [b, s_dim]
      - a – [b, a_dim]
    - Returns
      [b, 1]
  - select_action(s, sample=True)
    - Parameters
      - s – [s_dim]
    - Returns
      [a_dim]
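    The shape annotations suggest the following usage (a sketch, assuming a stochastic policy over continuous actions; all sizes are illustrative):

        import torch
        from convlab2.policy.rlmodule import ContinuousPolicy

        policy = ContinuousPolicy(s_dim=340, h_dim=100, a_dim=9)

        s = torch.randn(340)                    # single state: [s_dim]
        a = policy.select_action(s)             # sampled action: [a_dim]

        batch_s = torch.randn(32, 340)          # batched states: [b, s_dim]
        batch_a = torch.randn(32, 9)            # batched actions: [b, a_dim]
        log_p = policy.get_log_prob(batch_s, batch_a)  # log-probs: [b, 1]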
- class convlab2.policy.rlmodule.DiscretePolicy(s_dim, h_dim, a_dim)
  Bases: torch.nn.modules.module.Module
  - forward(s)
    Defines the computation performed at every call. Should be overridden by all subclasses.
    Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
  - get_log_prob(s, a)
    - Parameters
      - s – [b, s_dim]
      - a – [b, 1]
    - Returns
      [b, 1]
  - select_action(s, sample=True)
    - Parameters
      - s – [s_dim]
    - Returns
      [1]
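    Usage mirrors ContinuousPolicy, except the action is a single categorical index (a sketch; sizes are illustrative):

        import torch
        from convlab2.policy.rlmodule import DiscretePolicy

        policy = DiscretePolicy(s_dim=340, h_dim=100, a_dim=300)
        a = policy.select_action(torch.randn(340))     # action index: [1]
        batch_s = torch.randn(32, 340)
        batch_a = torch.randint(0, 300, (32, 1))       # indices: [b, 1]
        log_p = policy.get_log_prob(batch_s, batch_a)  # [b, 1]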
- class convlab2.policy.rlmodule.EpsilonGreedyPolicy(s_dim, h_dim, a_dim, epsilon_spec={'end': 0.0, 'end_epoch': 200, 'start': 0.1})
  Bases: torch.nn.modules.module.Module
  - forward(s)
    Defines the computation performed at every call. Should be overridden by all subclasses.
    Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
  - select_action(s, is_train=True)
    - Parameters
      - s – [s_dim]
    - Returns
      [1]
  - update_epsilon(epoch)
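    Given the epsilon_spec default, update_epsilon presumably anneals epsilon linearly from start to end over end_epoch epochs; a sketch of that schedule (the formula is an assumption, not taken from the source):

        def annealed_epsilon(epoch, start=0.1, end=0.0, end_epoch=200):
            # Linear interpolation from start down to end, clamped once
            # epoch reaches end_epoch.
            frac = min(epoch / end_epoch, 1.0)
            return start + frac * (end - start)

        assert annealed_epsilon(0) == 0.1
        assert annealed_epsilon(200) == 0.0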
- class convlab2.policy.rlmodule.Memory
  Bases: object
  - append(new_memory)
  - get_batch(batch_size=None)
  - push(*args)
    Saves a transition.
- class convlab2.policy.rlmodule.MemoryReplay(max_size)
  Bases: object
  Unlike Memory, MemoryReplay has a limited size. It is mainly used for off-policy algorithms.
  - append(new_memory)
  - get_batch(batch_size=None)
  - push(*args)
    Saves a transition.
  - reset()
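    A sketch of how these buffers pair with Transition below (the states, actions, and the form of get_batch's return value are illustrative assumptions):

        import torch
        from convlab2.policy.rlmodule import MemoryReplay

        buf = MemoryReplay(max_size=10000)
        s0, s1, s2 = (torch.randn(340) for _ in range(3))  # illustrative states
        a0, a1 = torch.randint(0, 300, (1,)), torch.randint(0, 300, (1,))
        # Each push stores one Transition(state, action, reward, next_state, mask).
        buf.push(s0, a0, 1.0, s1, 1)    # mask=1: episode continues
        buf.push(s1, a1, -1.0, s2, 0)   # mask=0: episode ended
        batch = buf.get_batch(batch_size=2)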
- class convlab2.policy.rlmodule.MultiDiscretePolicy(s_dim, h_dim, a_dim)
  Bases: torch.nn.modules.module.Module
  - forward(s)
    Defines the computation performed at every call. Should be overridden by all subclasses.
    Note: although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
  - get_log_prob(s, a)
    - Parameters
      - s – [b, s_dim]
      - a – [b, a_dim]
    - Returns
      [b, 1]
  - select_action(s, sample=True)
    - Parameters
      - s – [s_dim]
    - Returns
      [a_dim]
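    The shapes match ContinuousPolicy, but here the [a_dim] action is presumably a binary on/off vector over dialogue-act dimensions rather than a real-valued one (a sketch; sizes are illustrative):

        import torch
        from convlab2.policy.rlmodule import MultiDiscretePolicy

        policy = MultiDiscretePolicy(s_dim=340, h_dim=100, a_dim=209)
        a = policy.select_action(torch.randn(340))  # [a_dim] multi-hot action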
- class convlab2.policy.rlmodule.Transition(state, action, reward, next_state, mask)
  Bases: tuple
  - property action
    Alias for field number 1
  - property mask
    Alias for field number 4
  - property next_state
    Alias for field number 3
  - property reward
    Alias for field number 2
  - property state
    Alias for field number 0
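  Transition behaves like a collections.namedtuple; an equivalent definition, reconstructed from the field aliases above (the reconstruction itself is an assumption):

      from collections import namedtuple

      Transition = namedtuple(
          'Transition', ('state', 'action', 'reward', 'next_state', 'mask'))

      t = Transition(state=[0.0], action=1, reward=0.5, next_state=[1.0], mask=1)
      assert t.action == t[1] and t.mask == t[4]  # fields alias tuple positions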
convlab2.policy.vec module
Vector Interface
- class convlab2.policy.vec.Vector
  Bases: object
  - action_devectorize(action_vec)
    Recover an action from its vector form.
    - Args:
      - action_vec (np.array): dialog act vector
    - Returns:
      - action (tuple): dialog act
  - generate_dict()
    Init the dicts for mapping state/action into vectors.
  - state_vectorize(state)
    Vectorize a state.
    - Args:
      - state (tuple): dialog state
    - Returns:
      - state_vec (np.array): dialog state vector
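  A toy Vector subclass showing how the three methods fit together (the one-bit-per-act scheme is a hypothetical illustration, far simpler than the real vectorizers in convlab2.policy.vector):

      import numpy as np
      from convlab2.policy.vec import Vector

      class BinaryVector(Vector):
          # Hypothetical vectorizer: one bit per known dialogue act.
          def __init__(self, acts):
              self.acts = acts
              self.generate_dict()

          def generate_dict(self):
              # Init the dicts for mapping state/action into vectors.
              self.act2id = {act: i for i, act in enumerate(self.acts)}
              self.id2act = {i: act for act, i in self.act2id.items()}

          def state_vectorize(self, state):
              vec = np.zeros(len(self.acts))
              for act in state:              # state: tuple of act names
                  vec[self.act2id[act]] = 1.0
              return vec

          def action_devectorize(self, action_vec):
              return tuple(self.id2act[i] for i in np.flatnonzero(action_vec))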