tatk.policy package

Subpackages

Submodules

tatk.policy.policy module

Policy Interface

class tatk.policy.policy.Policy

Bases: tatk.util.module.Module

Base class for policy models.

predict(state)

Predict the next agent action given the dialog state, and update state['system_action'] with the predicted system action.

Args:
state (tuple or dict):

When the DST and Policy modules are separate, state is a tuple; when they are aggregated into one module, state is a dict (dialog act).

Returns:
action (list of list):

The next dialog action.
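The contract of predict can be illustrated with a minimal, framework-free sketch. PolicyBase and GreetingPolicy are hypothetical stand-ins (the real base class is tatk.policy.policy.Policy, which is not imported here), and the [intent, domain, slot, value] layout of each inner list is an assumption about the dialog-act format. The sketch returns a list of lists and, when the state is a dict, also stores the result in state['system_action'].

```python
from typing import List, Union


class PolicyBase:
    """Hypothetical stand-in for tatk.policy.policy.Policy (not imported here)."""

    def predict(self, state: Union[tuple, dict]) -> List[list]:
        raise NotImplementedError


class GreetingPolicy(PolicyBase):
    """Hypothetical rule-based policy that always replies with a greeting."""

    def predict(self, state: Union[tuple, dict]) -> List[list]:
        # One inner list per dialog act; the [intent, domain, slot, value]
        # layout is an assumption made for this example.
        action = [["greet", "general", "none", "none"]]
        if isinstance(state, dict):
            # DST and policy combined: store the predicted system action in the state.
            state["system_action"] = action
        return action


state = {}
print(GreetingPolicy().predict(state))  # [['greet', 'general', 'none', 'none']]
print(state["system_action"])           # [['greet', 'general', 'none', 'none']]
```

A real policy would compute the action from the state; the constant greeting only keeps the example self-contained.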

tatk.policy.rlmodule module

class tatk.policy.rlmodule.ContinuousPolicy(s_dim, h_dim, a_dim)

Bases: torch.nn.modules.module.Module

__init__(s_dim, h_dim, a_dim)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(s)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass is defined within this function, call the Module instance rather than forward() directly: the instance call runs the registered hooks, while a direct forward() call silently ignores them.

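get_log_prob(s, a)

Return the log-probability density of action a given state s, for a batch of b samples.

Parameters

  • s – [b, s_dim]

  • a – [b, a_dim]

Returns

[b, 1]

select_action(s, sample=True)

Select an action for a single, unbatched state s.

Parameters

s – [s_dim]

Returns

[a_dim]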
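A minimal sketch of the shape contract of select_action and get_log_prob for a continuous policy, using plain Python lists instead of torch tensors. It assumes, without confirmation from this page, that the network produces a per-dimension mean and log standard deviation for a diagonal Gaussian, that sample=False returns the mean, and that the log-probability is the Gaussian log-density summed over a_dim. The mean and log_std arguments stand in for the network output and are hypothetical.

```python
import math
import random
from typing import List


def gaussian_log_prob(a: List[float], mean: List[float], log_std: List[float]) -> float:
    """Diagonal-Gaussian log-density of one [a_dim] action, summed over a_dim."""
    total = 0.0
    for x, mu, ls in zip(a, mean, log_std):
        var = math.exp(2.0 * ls)  # std**2 = exp(log_std)**2
        total += -((x - mu) ** 2) / (2.0 * var) - 0.5 * math.log(2.0 * math.pi) - ls
    return total


def select_action(mean: List[float], log_std: List[float], sample: bool = True) -> List[float]:
    """Return an [a_dim] action for one state: a Gaussian draw, or the mean if not sampling."""
    if sample:
        return [random.gauss(mu, math.exp(ls)) for mu, ls in zip(mean, log_std)]
    return list(mean)


def get_log_prob(batch_a, batch_mean, batch_log_std) -> List[List[float]]:
    """Map [b, a_dim] actions, means and log-stds to [b, 1] log-densities."""
    return [[gaussian_log_prob(a, m, ls)]
            for a, m, ls in zip(batch_a, batch_mean, batch_log_std)]


# Hypothetical network output for one state with a_dim = 2.
mean, log_std = [0.0, 1.0], [0.0, 0.0]
action = select_action(mean, log_std)
print(len(action), get_log_prob([action], [mean], [log_std]))
```

Working in log-space with log_std keeps the standard deviation positive without any explicit constraint.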

class tatk.policy.rlmodule.DiscretePolicy(s_dim, h_dim, a_dim)

Bases: torch.nn.modules.module.Module

__init__(s_dim, h_dim, a_dim)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(s)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass is defined within this function, call the Module instance rather than forward() directly: the instance call runs the registered hooks, while a direct forward() call silently ignores them.

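get_log_prob(s, a)

Return the log-probability of action a given state s, for a batch of b samples.

Parameters

  • s – [b, s_dim]

  • a – [b, 1]

Returns

[b, 1]

select_action(s, sample=True)

Select an action for a single, unbatched state s.

Parameters

s – [s_dim]

Returns

[1]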
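For the discrete policy, the shapes suggest a single categorical choice among a_dim actions. The following sketch assumes, as an illustration rather than the library's confirmed implementation, that the a_dim outputs are logits turned into probabilities with a softmax, that sample=False picks the most probable action, and that get_log_prob looks up the log-probability of the stored index. Plain Python lists stand in for torch tensors.

```python
import math
import random
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Turn a_dim logits into probabilities (max subtracted for numerical stability)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def select_action(logits: List[float], sample: bool = True) -> List[int]:
    """Return a [1] action index for one state."""
    probs = softmax(logits)
    if sample:
        return random.choices(range(len(probs)), weights=probs, k=1)
    return [max(range(len(probs)), key=probs.__getitem__)]  # most probable action


def get_log_prob(batch_logits: List[List[float]], batch_a: List[List[int]]) -> List[List[float]]:
    """Map [b, a_dim] logits and [b, 1] action indices to [b, 1] log-probabilities."""
    return [[math.log(softmax(logits)[a[0]])]
            for logits, a in zip(batch_logits, batch_a)]


print(select_action([1.0, 3.0, 2.0], sample=False))  # [1]
print(get_log_prob([[0.0, 0.0]], [[1]]))
```

Keeping the action as a one-element list mirrors the documented [1] and [b, 1] shapes, so a batch of stored actions can be used directly as indices.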

class tatk.policy.rlmodule.Memory

Bases: object

__init__()

Initialize self. See help(type(self)) for accurate signature.

append(new_memory)

get_batch(batch_size=None)

push(*args)

Saves a transition.
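A minimal sketch of a transition buffer with the same method names. SketchMemory is hypothetical; the field order of Transition follows the tatk.policy.rlmodule.Transition entries on this page, while the behaviour of get_batch (all transitions when batch_size is None, otherwise a random sample, transposed into one Transition of tuples) and of append (merging another memory's transitions) are assumptions.

```python
import random
from collections import namedtuple

# Field order matches the documented Transition namedtuple.
Transition = namedtuple("Transition", ("state", "action", "reward", "next_state", "mask"))


class SketchMemory:
    """Hypothetical stand-in for tatk.policy.rlmodule.Memory."""

    def __init__(self):
        self.memory = []

    def push(self, *args):
        """Save one transition built from (state, action, reward, next_state, mask)."""
        self.memory.append(Transition(*args))

    def append(self, new_memory):
        """Add every transition stored in another memory."""
        self.memory += new_memory.memory

    def get_batch(self, batch_size=None):
        """Return one Transition whose fields are tuples over the selected transitions.

        Assumes the memory is non-empty.
        """
        batch = self.memory if batch_size is None else random.sample(self.memory, batch_size)
        return Transition(*zip(*batch))

    def __len__(self):
        return len(self.memory)


m = SketchMemory()
m.push("s0", "a0", 1.0, "s1", 1)
m.push("s1", "a1", 0.0, "s2", 0)
print(m.get_batch().reward)  # (1.0, 0.0)
```

Transposing with zip(*batch) turns a list of transitions into per-field tuples, which is convenient when each field is later stacked into a batch tensor.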

class tatk.policy.rlmodule.MultiDiscretePolicy(s_dim, h_dim, a_dim)

Bases: torch.nn.modules.module.Module

__init__(s_dim, h_dim, a_dim)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(s)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass is defined within this function, call the Module instance rather than forward() directly: the instance call runs the registered hooks, while a direct forward() call silently ignores them.

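get_log_prob(s, a)

Return the log-probability of action a given state s, for a batch of b samples.

Parameters

  • s – [b, s_dim]

  • a – [b, a_dim]

Returns

[b, 1]

select_action(s, sample=True)

Select an action for a single, unbatched state s.

Parameters

s – [s_dim]

Returns

[a_dim]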
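The multi-discrete policy returns an [a_dim] action but a single [b, 1] log-probability per sample, which fits a model that makes one binary decision per action dimension. This sketch assumes independent Bernoulli outputs obtained with a sigmoid, a 0.5 threshold when sample=False, and a joint log-probability computed as the sum of per-dimension log-probabilities; none of this is stated on this page. Plain lists stand in for tensors.

```python
import math
import random
from typing import List

EPS = 1e-10  # keeps log() away from 0


def sigmoid(z: float) -> float:
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def select_action(logits: List[float], sample: bool = True) -> List[int]:
    """Return an [a_dim] multi-hot action: one independent 0/1 decision per dimension."""
    probs = [sigmoid(z) for z in logits]
    if sample:
        return [1 if random.random() < p else 0 for p in probs]
    return [1 if p > 0.5 else 0 for p in probs]


def get_log_prob(batch_logits: List[List[float]], batch_a: List[List[int]]) -> List[List[float]]:
    """Map [b, a_dim] logits and [b, a_dim] 0/1 actions to [b, 1] joint log-probabilities."""
    out = []
    for logits, a in zip(batch_logits, batch_a):
        total = 0.0
        for z, bit in zip(logits, a):
            p = min(max(sigmoid(z), EPS), 1.0 - EPS)
            total += math.log(p if bit else 1.0 - p)  # sum of per-dimension Bernoulli terms
        out.append([total])
    return out


print(select_action([2.0, -2.0, 0.5], sample=False))  # [1, 0, 1]
```

Summing per-dimension log-probabilities is what collapses the [b, a_dim] action into the documented [b, 1] result under the independence assumption.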

class tatk.policy.rlmodule.Transition(state, action, reward, next_state, mask)

Bases: tuple

property action

Alias for field number 1

property mask

Alias for field number 4

property next_state

Alias for field number 3

property reward

Alias for field number 2

property state

Alias for field number 0
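The Bases: tuple line and the "Alias for field number N" entries are what Python's collections.namedtuple generates, so the fields can be rebuilt in index order as below. This reconstruction and the example values are illustrative, not copied from the library source.

```python
from collections import namedtuple

# Field order taken from the documented aliases:
# state=0, action=1, reward=2, next_state=3, mask=4.
Transition = namedtuple("Transition", ("state", "action", "reward", "next_state", "mask"))

t = Transition(state=[0.1, 0.2], action=[1, 0], reward=-1.0, next_state=[0.3, 0.4], mask=1)
print(t.reward == t[2])    # True: each property aliases a tuple index
print(Transition._fields)  # ('state', 'action', 'reward', 'next_state', 'mask')

# Being a tuple, a Transition also unpacks positionally.
state, action, reward, next_state, mask = t
```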

class tatk.policy.rlmodule.Value(s_dim, hv_dim)

Bases: torch.nn.modules.module.Module

__init__(s_dim, hv_dim)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(s)

Estimate the value of each state in a batch of b states.

Parameters

s – [b, s_dim]

Returns

[b, 1]
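A framework-free sketch of the Value shape contract, mapping a [b, s_dim] batch to [b, 1] estimates. SketchValue is hypothetical, and its architecture (one ReLU hidden layer of width hv_dim followed by a linear output, with small random weights) is an assumption; the page documents only the constructor arguments and tensor shapes.

```python
import random
from typing import List


class SketchValue:
    """Hypothetical value network: s_dim -> hv_dim (ReLU) -> 1."""

    def __init__(self, s_dim: int, hv_dim: int, seed: int = 0):
        rng = random.Random(seed)
        self.w1 = [[rng.uniform(-0.1, 0.1) for _ in range(s_dim)] for _ in range(hv_dim)]
        self.b1 = [0.0] * hv_dim
        self.w2 = [rng.uniform(-0.1, 0.1) for _ in range(hv_dim)]
        self.b2 = 0.0

    def forward(self, s: List[List[float]]) -> List[List[float]]:
        """Map a [b, s_dim] batch of states to [b, 1] value estimates."""
        out = []
        for state in s:
            hidden = [max(0.0, sum(w * x for w, x in zip(row, state)) + b)
                      for row, b in zip(self.w1, self.b1)]
            out.append([sum(w * h for w, h in zip(self.w2, hidden)) + self.b2])
        return out


v = SketchValue(s_dim=4, hv_dim=8)
print(len(v.forward([[0.0] * 4, [1.0, -1.0, 0.5, 2.0]])))  # 2
```

The trailing dimension of size 1 keeps the output aligned with the [b, 1] log-probabilities returned by the policies.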

tatk.policy.vec module

Vector Interface

class tatk.policy.vec.Vector

Bases: object

__init__()

Initialize self. See help(type(self)) for accurate signature.

action_devectorize(action_vec)

Recover a dialog act from its vector representation.

Args:
action_vec (np.ndarray):

Dialog act vector

Returns:
action (tuple):

Dialog act

generate_dict()

Initialize the dictionaries that map states and actions to vectors.

state_vectorize(state)

Convert a dialog state into a vector.

Args:
state (tuple):

Dialog state

Returns:
state_vec (np.ndarray):

Dialog state vector
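A concrete Vector subclass has to fill in generate_dict, state_vectorize and action_devectorize. The sketch below is a hypothetical toy: ToyVector, the (intent, domain, slot) act triples, and the choice to treat the state as a tuple of acts are all assumptions made for illustration, and plain lists replace the NumPy vectors used by the real interface.

```python
from typing import List, Tuple

DialogAct = Tuple[str, str, str]  # hypothetical (intent, domain, slot) triple


class ToyVector:
    """Hypothetical Vector subclass with a multi-hot encoding over a fixed act list."""

    def __init__(self, acts: List[DialogAct]):
        self.acts = acts
        self.generate_dict()

    def generate_dict(self):
        # Map every known act to a vector index, and back.
        self.act2idx = {act: i for i, act in enumerate(self.acts)}
        self.idx2act = dict(enumerate(self.acts))

    def state_vectorize(self, state: Tuple[DialogAct, ...]) -> List[float]:
        # Toy assumption: the state is just the tuple of acts observed so far.
        vec = [0.0] * len(self.acts)
        for act in state:
            if act in self.act2idx:
                vec[self.act2idx[act]] = 1.0
        return vec

    def action_devectorize(self, action_vec: List[float]) -> Tuple[DialogAct, ...]:
        # Recover every act whose entry is switched on.
        return tuple(self.idx2act[i] for i, v in enumerate(action_vec) if v > 0.5)


acts = [("inform", "hotel", "area"), ("request", "hotel", "price"), ("bye", "general", "none")]
vec = ToyVector(acts)
print(vec.state_vectorize((("request", "hotel", "price"),)))  # [0.0, 1.0, 0.0]
```

Building both directions of the mapping once in generate_dict keeps vectorizing and devectorizing consistent with each other.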