tatk.policy package
Subpackages
- tatk.policy.gdpl package
- tatk.policy.mdrg package
- Subpackages
- tatk.policy.mdrg.multiwoz package
- Subpackages
- Submodules
- tatk.policy.mdrg.multiwoz.auto_download module
- tatk.policy.mdrg.multiwoz.create_delex_data module
- tatk.policy.mdrg.multiwoz.default_policy module
- tatk.policy.mdrg.multiwoz.evaluator module
- tatk.policy.mdrg.multiwoz.mdrg_model module
- tatk.policy.mdrg.multiwoz.model module
- tatk.policy.mdrg.multiwoz.policy module
- tatk.policy.mdrg.multiwoz.test module
- tatk.policy.mdrg.multiwoz.train module
- tatk.policy.mdrg.multiwoz package
- Module contents
- Subpackages
- tatk.policy.mle package
- tatk.policy.pg package
- tatk.policy.ppo package
- tatk.policy.rule package
- tatk.policy.vector package
- tatk.policy.vhus package
Submodules
tatk.policy.policy module
Policy Interface
class tatk.policy.policy.Policy
Bases: tatk.util.module.Module
Base class for policy models.
- predict(state)
  Predict the next agent action given the dialog state, and update state['system_action'] with the predicted system action.
  - Args:
    - state (tuple or dict): when the DST and Policy modules are separate, state is a tuple; when they are aggregated together, state is a dict (dialog act).
  - Returns:
    - action (list of list): the next dialog action.
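To illustrate the predict(state) contract, here is a minimal sketch of a subclass for the dict-state case. It uses a hypothetical stand-in for Policy (the real class derives from tatk.util.module.Module, which is not reproduced), and the "user_action" key, the [intent, domain, slot, value] act layout, and the fallback act are all assumptions chosen for illustration.

```python
from typing import Dict, List


class Policy:
    """Hypothetical stand-in for tatk.policy.policy.Policy.

    Only the documented predict(state) contract is reproduced here.
    """

    def predict(self, state):
        raise NotImplementedError


class RequestInformPolicy(Policy):
    """Toy policy: answer each user Request act with an Inform act.

    Assumes a dict state whose (hypothetical) "user_action" key holds
    [intent, domain, slot, value] lists.
    """

    def predict(self, state: Dict) -> List[List[str]]:
        action = []
        for intent, domain, slot, _value in state.get("user_action", []):
            if intent == "Request":
                # Placeholder value; a real policy would query a database.
                action.append(["Inform", domain, slot, "unknown"])
        if not action:
            # Illustrative fallback act when there is nothing to answer.
            action.append(["Reqmore", "general", "none", "none"])
        # Mirror the documented side effect on the state.
        state["system_action"] = action
        return action


state = {"user_action": [["Request", "Hotel", "Phone", "?"]]}
print(RequestInformPolicy().predict(state))
# prints [['Inform', 'Hotel', 'Phone', 'unknown']]
```

The returned list of lists is also written back to state["system_action"], matching the side effect described for predict.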
tatk.policy.rlmodule module
class tatk.policy.rlmodule.ContinuousPolicy(s_dim, h_dim, a_dim)
Bases: torch.nn.modules.module.Module
- __init__(s_dim, h_dim, a_dim)
  Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(s)
  Defines the computation performed at every call. Should be overridden by all subclasses.
  Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- get_log_prob(s, a)
  - Parameters:
    - s – [b, s_dim]
    - a – [b, a_dim]
  - Returns: [b, 1], the log-probability of each action in the batch (b is the batch size).
- select_action(s, sample=True)
  - Parameters:
    - s – [s_dim]
  - Returns: [a_dim], the selected action.
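The shapes above fit a diagonal Gaussian policy head, sketched below in plain Python so that it runs without PyTorch. The Gaussian form, and the idea that the network maps s to a per-dimension mean and log_std, are assumptions: this reference fixes only the shapes. Per-dimension log-densities are summed, which is what turns an [b, a_dim] action into a [b, 1] log-probability.

```python
import math
import random
from typing import List


def gaussian_log_prob(mean: List[List[float]], log_std: List[List[float]],
                      a: List[List[float]]) -> List[List[float]]:
    """Log-density of actions a ([b, a_dim]) under a diagonal Gaussian.

    Summing over action dimensions gives one value per row: [b, 1].
    """
    out = []
    for mu_row, ls_row, a_row in zip(mean, log_std, a):
        total = 0.0
        for mu, ls, x in zip(mu_row, ls_row, a_row):
            var = math.exp(2 * ls)  # sigma^2, with ls = log(sigma)
            total += -((x - mu) ** 2) / (2 * var) - ls - 0.5 * math.log(2 * math.pi)
        out.append([total])
    return out


def select_gaussian_action(mean: List[float], log_std: List[float],
                           sample: bool = True) -> List[float]:
    """Return an [a_dim] action: a random draw if sample, else the mean."""
    if sample:
        return [random.gauss(mu, math.exp(ls)) for mu, ls in zip(mean, log_std)]
    return list(mean)
```

Returning the mean when sample=False is one common deterministic choice for evaluation; whether ContinuousPolicy does the same is not stated here.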
class tatk.policy.rlmodule.DiscretePolicy(s_dim, h_dim, a_dim)
Bases: torch.nn.modules.module.Module
- __init__(s_dim, h_dim, a_dim)
  Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(s)
  Defines the computation performed at every call. Should be overridden by all subclasses.
  Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- get_log_prob(s, a)
  - Parameters:
    - s – [b, s_dim]
    - a – [b, 1]
  - Returns: [b, 1], the log-probability of each action in the batch (b is the batch size).
- select_action(s, sample=True)
  - Parameters:
    - s – [s_dim]
  - Returns: [1], the selected action.
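Here, a has shape [b, 1] rather than [b, a_dim], which suggests one action index per row chosen from a_dim alternatives. The plain-Python sketch below assumes a categorical distribution over softmax(logits). It is not taken from the DiscretePolicy source, and logits stands in for the network output.

```python
import math
import random
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax over one row of logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def categorical_log_prob(logits: List[List[float]],
                         a: List[List[int]]) -> List[List[float]]:
    """logits: [b, a_dim]; a: [b, 1] action indices -> [b, 1] log-probabilities."""
    return [[log_softmax(row)[idx[0]]] for row, idx in zip(logits, a)]


def select_categorical_action(logits: List[float], sample: bool = True) -> List[int]:
    """Return a one-element [1] action: sampled, or the argmax if not sample."""
    if sample:
        probs = [math.exp(lp) for lp in log_softmax(logits)]
        return [random.choices(range(len(logits)), weights=probs, k=1)[0]]
    return [max(range(len(logits)), key=lambda i: logits[i])]
```

Subtracting the row maximum before exponentiating prevents overflow without changing the result.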
class tatk.policy.rlmodule.Memory
Bases: object
- __init__()
  Initialize self. See help(type(self)) for accurate signature.
- append(new_memory)
- get_batch(batch_size=None)
- push(*args)
  Saves a transition.
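The Memory methods are listed without descriptions. The sketch below shows one plausible reading of them as a replay buffer, with hypothetical behavior: push wraps its arguments in a Transition, append concatenates another Memory, and get_batch returns either every stored transition or a random sample, transposed into a Transition of per-field tuples. The field order follows the Transition aliases documented for tatk.policy.rlmodule.Transition.

```python
import random
from collections import namedtuple

# Field order matches the documented aliases (state=0 ... mask=4).
Transition = namedtuple("Transition", ("state", "action", "reward", "next_state", "mask"))


class Memory:
    """Hypothetical replay buffer following the documented method names."""

    def __init__(self):
        self.memory = []

    def push(self, *args):
        """Save one transition built from positional fields."""
        self.memory.append(Transition(*args))

    def append(self, new_memory):
        """Concatenate the transitions stored in another Memory."""
        self.memory += new_memory.memory

    def get_batch(self, batch_size=None):
        """Return a Transition of tuples: all items, or a random sample of batch_size."""
        items = self.memory if batch_size is None else random.sample(self.memory, batch_size)
        # zip(*items) transposes rows of transitions into columns of fields.
        return Transition(*zip(*items))

    def __len__(self):
        return len(self.memory)
```

Transposing the batch means that, for example, batch.reward is a tuple holding every reward. Each field can then be converted to a tensor in one call.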
class tatk.policy.rlmodule.MultiDiscretePolicy(s_dim, h_dim, a_dim)
Bases: torch.nn.modules.module.Module
- __init__(s_dim, h_dim, a_dim)
  Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(s)
  Defines the computation performed at every call. Should be overridden by all subclasses.
  Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
- get_log_prob(s, a)
  - Parameters:
    - s – [b, s_dim]
    - a – [b, a_dim]
  - Returns: [b, 1], the log-probability of each action in the batch (b is the batch size).
- select_action(s, sample=True)
  - Parameters:
    - s – [s_dim]
  - Returns: [a_dim], the selected action.
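A multi-discrete action of shape [a_dim] is naturally a multi-hot vector: each dialog-act dimension is independently on or off. The plain-Python sketch below assumes one independent Bernoulli per dimension with probability sigmoid(logit). This is a common design for such policies, but it is an assumption about MultiDiscretePolicy. The per-dimension log-probabilities are summed to give [b, 1].

```python
import math
import random
from typing import List


def softplus(x: float) -> float:
    """Stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def sigmoid(x: float) -> float:
    """Stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def multi_discrete_log_prob(logits: List[List[float]],
                            a: List[List[int]]) -> List[List[float]]:
    """logits, a: [b, a_dim] with a multi-hot -> [b, 1] summed log-probabilities."""
    out = []
    for row, acts in zip(logits, a):
        # log sigmoid(z) = -softplus(-z); log(1 - sigmoid(z)) = -softplus(z)
        total = sum(-softplus(-z) if bit else -softplus(z) for z, bit in zip(row, acts))
        out.append([total])
    return out


def select_multi_discrete_action(logits: List[float], sample: bool = True) -> List[int]:
    """Return an [a_dim] 0/1 vector: Bernoulli draws, or thresholding at 0.5."""
    probs = [sigmoid(z) for z in logits]
    if sample:
        return [1 if random.random() < p else 0 for p in probs]
    return [1 if p > 0.5 else 0 for p in probs]
```

Working with softplus instead of log(sigmoid(z)) avoids taking log(0) when a logit is large in magnitude.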
class tatk.policy.rlmodule.Transition(state, action, reward, next_state, mask)
Bases: tuple
- property action: Alias for field number 1
- property mask: Alias for field number 4
- property next_state: Alias for field number 3
- property reward: Alias for field number 2
- property state: Alias for field number 0
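The "Alias for field number N" properties are what collections.namedtuple generates, so Transition behaves like a named tuple with fields in the order state, action, reward, next_state, mask. The sketch below also shows one common use of mask in policy-gradient training. It assumes (this is not documented here) that mask is 0 on the last turn of a dialog and 1 otherwise, so that discounted returns do not leak across dialog boundaries. The discounted_returns helper is hypothetical.

```python
from collections import namedtuple
from typing import List

Transition = namedtuple("Transition", ("state", "action", "reward", "next_state", "mask"))


def discounted_returns(rewards: List[float], masks: List[int],
                       gamma: float = 0.99) -> List[float]:
    """Compute returns backwards in time; mask == 0 resets the running return."""
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * masks[t] * running
        returns[t] = running
    return returns


# Two dialogs of two turns each; the second and fourth turns end a dialog.
print(discounted_returns([1.0, 1.0, 1.0, 1.0], [1, 0, 1, 0], gamma=0.5))
# prints [1.5, 1.0, 1.5, 1.0]
```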
tatk.policy.vec module
Vector Interface
class tatk.policy.vec.Vector
Bases: object
- __init__()
  Initialize self. See help(type(self)) for accurate signature.
- action_devectorize(action_vec)
  Recover an action from its vector form.
  - Args:
    - action_vec (np.array): dialog act vector
  - Returns:
    - action (tuple): dialog act
- generate_dict()
  Initialize the dictionaries that map states and actions to vectors.
- state_vectorize(state)
  Vectorize a state.
  - Args:
    - state (tuple): dialog state
  - Returns:
    - state_vec (np.array): dialog state vector
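The three Vector methods work together. generate_dict builds the lookup tables, and state_vectorize and action_devectorize use those tables to convert between dialog structures and fixed-size vectors. The sketch below is a toy subclass that uses a multi-hot encoding over a small, hypothetical act vocabulary. To avoid a NumPy dependency it returns plain lists rather than np.array, and for brevity states and actions share one vocabulary here, which a real vectorizer would not.

```python
from typing import List, Sequence, Tuple


class Vector:
    """Hypothetical stand-in for tatk.policy.vec.Vector (interface only)."""

    def generate_dict(self):
        raise NotImplementedError

    def state_vectorize(self, state):
        raise NotImplementedError

    def action_devectorize(self, action_vec):
        raise NotImplementedError


class ToyActVector(Vector):
    """Multi-hot encoding over a small, hypothetical act vocabulary."""

    def __init__(self, acts: Sequence[str]):
        self.acts = list(acts)
        self.generate_dict()

    def generate_dict(self):
        # Map each act string to a vector index and back.
        self.act2id = {act: i for i, act in enumerate(self.acts)}
        self.id2act = {i: act for act, i in self.act2id.items()}

    def state_vectorize(self, state: Tuple[str, ...]) -> List[int]:
        # Toy state: the tuple of acts the user just performed.
        vec = [0] * len(self.acts)
        for act in state:
            vec[self.act2id[act]] = 1
        return vec

    def action_devectorize(self, action_vec: Sequence[int]) -> Tuple[str, ...]:
        # Every active index becomes one act in the recovered tuple.
        return tuple(self.id2act[i] for i, bit in enumerate(action_vec) if bit)


v = ToyActVector(["Hotel-Inform-Area", "Hotel-Request-Phone", "general-bye"])
print(v.state_vectorize(("Hotel-Request-Phone",)))  # prints [0, 1, 0]
print(v.action_devectorize([1, 0, 1]))  # prints ('Hotel-Inform-Area', 'general-bye')
```

Building both act2id and id2act in generate_dict keeps encoding and decoding consistent, because both directions derive from the same enumeration.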