tatk.policy.vhus package


tatk.policy.vhus.train module

@author: truthless

class tatk.policy.vhus.train.VHUS_Trainer(config, manager, goal_gen)

Bases: object

__init__(config, manager, goal_gen)

imit_test(epoch, best)

Provide an unbiased evaluation of the user simulator fit on the training dataset.


Train the user simulator by simple imitation learning (behavioral cloning).

save(directory, epoch)
tatk.policy.vhus.train.batch_iter(x, y, z, batch_size=64)
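batch_iter(x, y, z, batch_size=64) has no docstring. Below is a minimal sketch of what a helper with this signature typically does, assuming x, y, and z are three parallel, equal-length sequences that are cut into aligned mini-batches in their original order. The name batch_iter_sketch is hypothetical, and whether the real function shuffles the data or drops a final partial batch is not documented here.

```python
from typing import Iterator, List, Sequence, Tuple, TypeVar

T = TypeVar("T")
U = TypeVar("U")
V = TypeVar("V")


def batch_iter_sketch(
    x: Sequence[T], y: Sequence[U], z: Sequence[V], batch_size: int = 64
) -> Iterator[Tuple[List[T], List[U], List[V]]]:
    """Yield aligned mini-batches from three parallel sequences (hypothetical sketch)."""
    if not (len(x) == len(y) == len(z)):
        raise ValueError("x, y and z must have the same length")
    # Walk the data in order; the final batch may be smaller than batch_size.
    for start in range(0, len(x), batch_size):
        end = start + batch_size
        yield list(x[start:end]), list(y[start:end]), list(z[start:end])
```

For example, five items with batch_size=2 produce batches of sizes 2, 2, and 1, with the i-th elements of x, y, and z always kept in the same batch.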

tatk.policy.vhus.usermodule module

@author: truthless

class tatk.policy.vhus.usermodule.Decoder(vocab_size, max_len, embed_size, hidden_size, sos_id=2, eos_id=3, n_layers=1, rnn_cell='GRU', input_dropout_p=0, dropout_p=0, use_attention=False)

Bases: torch.nn.modules.module.Module

KEY_ATTN_SCORE = 'attention_score'
KEY_LENGTH = 'length'
KEY_SEQUENCE = 'sequence'
__init__(vocab_size, max_len, embed_size, hidden_size, sos_id=2, eos_id=3, n_layers=1, rnn_cell='GRU', input_dropout_p=0, dropout_p=0, use_attention=False)

forward(inputs=None, encoder_hidden=None, encoder_outputs=None, function=torch.log_softmax, teacher_forcing_ratio=0)

forward_step(input_var, hidden, encoder_outputs, function)
class tatk.policy.vhus.usermodule.Encoder(vocab_size, embed_size, hidden_size, input_dropout_p=0, dropout_p=0, n_layers=1, rnn_cell='GRU', variable_lengths=False, embedding=None, update_embedding=True)

Bases: torch.nn.modules.module.Module

__init__(vocab_size, embed_size, hidden_size, input_dropout_p=0, dropout_p=0, n_layers=1, rnn_cell='GRU', variable_lengths=False, embedding=None, update_embedding=True)

forward(input_var, input_lengths=None)

Applies a multi-layer RNN to an input sequence.

Args:
  • input_var (batch, seq_len): tensor containing the features of the input sequence.

  • input_lengths (list of int, optional): a list that contains the lengths of the sequences in the mini-batch.

Returns: output, hidden
  • output (batch, seq_len, hidden_size): variable containing the encoded features of the input sequence.

  • hidden (num_layers * num_directions, batch, hidden_size): variable containing the features in the hidden state h.

class tatk.policy.vhus.usermodule.VHUS(cfg, voc_goal_size, voc_usr_size, voc_sys_size)

Bases: torch.nn.modules.module.Module

__init__(cfg, voc_goal_size, voc_usr_size, voc_sys_size)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(goals, goals_length, posts, posts_length, origin_responses=None)

select_action(goal, goal_length, post, post_length)
Parameters:
  • goal – [goal_len]

  • goal_length – []

  • post – [sen_len, word_len]

  • post_length – [sen_len]

Returns: [act_len], [1]

tatk.policy.vhus.usermodule.batch_gather_3_1(inputs, dim)

Parameters:
  • inputs – (batchsz, sen_len, embed_dim)

  • dim – (batchsz)

Returns: output (batch, embed_dim)
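batch_gather_3_1 picks one sen_len position per batch element, reducing (batchsz, sen_len, embed_dim) to (batch, embed_dim). The sketch below uses nested Python lists instead of tensors so that it runs without PyTorch. It assumes that dim holds sequence lengths, so the selected position is the last valid one, dim[b] - 1; if dim already holds indices, drop the - 1. The name batch_gather_3_1_sketch is hypothetical. With tensors, the same selection is usually written with advanced indexing as inputs[torch.arange(batchsz), dim - 1].

```python
from typing import List, Sequence

Matrix = List[List[float]]


def batch_gather_3_1_sketch(inputs: Sequence[Matrix], dim: Sequence[int]) -> Matrix:
    """For each batch element b, return inputs[b][dim[b] - 1] (hypothetical sketch).

    inputs: (batchsz, sen_len, embed_dim) as nested lists
    dim:    (batchsz,), assumed to hold lengths, so position dim[b] - 1 is selected
    returns (batchsz, embed_dim)
    """
    if len(inputs) != len(dim):
        raise ValueError("inputs and dim must share the batch dimension")
    # Pick the last valid time step of every sequence in the batch.
    return [list(seq[length - 1]) for seq, length in zip(inputs, dim)]
```

For a padded batch [[[1, 1], [2, 2], [0, 0]], [[3, 3], [0, 0], [0, 0]]] with lengths [2, 1], the sketch returns [[2, 2], [3, 3]], the last real step of each sequence.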

tatk.policy.vhus.usermodule.batch_gather_4_2(inputs, dim)

Parameters:
  • inputs – (batchsz, sen_len, word_len, embed_dim)

  • dim – (batchsz, sen_len)

Returns: output (batch, sen_len, embed_dim)
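batch_gather_4_2 applies the same idea one level deeper: for every sentence s of every batch element b it keeps a single word position, turning (batchsz, sen_len, word_len, embed_dim) into (batch, sen_len, embed_dim). The plain-list sketch below assumes that dim holds the word count of each sentence, so it selects word dim[b][s] - 1; the name batch_gather_4_2_sketch is hypothetical. A tensor version can broadcast index tensors, e.g. inputs[torch.arange(batchsz)[:, None], torch.arange(sen_len), dim - 1].

```python
from typing import List, Sequence

Vector = List[float]


def batch_gather_4_2_sketch(
    inputs: Sequence[Sequence[Sequence[Vector]]], dim: Sequence[Sequence[int]]
) -> List[List[Vector]]:
    """For every (b, s), return inputs[b][s][dim[b][s] - 1] (hypothetical sketch).

    inputs: (batchsz, sen_len, word_len, embed_dim) as nested lists
    dim:    (batchsz, sen_len), assumed to hold the word count of each sentence
    returns (batchsz, sen_len, embed_dim)
    """
    output = []
    for sentences, lengths in zip(inputs, dim):
        # Keep the encoding of the last valid word of each sentence.
        output.append([list(words[n - 1]) for words, n in zip(sentences, lengths)])
    return output
```

With one dialog of two sentences, [[[1.0], [2.0], [0.0]], [[5.0], [6.0], [7.0]]], and word counts [[2, 3]], the sketch returns [[[2.0], [7.0]]].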

tatk.policy.vhus.usermodule.reparameterize(mu, logvar)
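reparameterize(mu, logvar) has no docstring. In variational models the name conventionally refers to the reparameterization trick: sample z = mu + eps * exp(0.5 * logvar) with eps drawn from N(0, 1), so that the randomness lives in eps and z stays differentiable with respect to mu and logvar. The plain-Python sketch below assumes that convention; the real function operates on tensors, where the same steps are commonly written as std = torch.exp(0.5 * logvar); eps = torch.randn_like(std); mu + eps * std. The name reparameterize_sketch and the rng parameter are hypothetical.

```python
import math
import random
from typing import List, Optional, Sequence


def reparameterize_sketch(
    mu: Sequence[float], logvar: Sequence[float], rng: Optional[random.Random] = None
) -> List[float]:
    """Draw z ~ N(mu, exp(logvar)) elementwise as mu + eps * std (hypothetical sketch)."""
    if len(mu) != len(logvar):
        raise ValueError("mu and logvar must have the same length")
    rng = rng or random.Random()
    z = []
    for m, lv in zip(mu, logvar):
        std = math.exp(0.5 * lv)   # logvar stores log(sigma^2), so sigma = exp(logvar / 2)
        eps = rng.gauss(0.0, 1.0)  # noise is drawn independently of mu and logvar
        z.append(m + eps * std)
    return z
```

Storing the log-variance rather than sigma keeps the network output unconstrained while exp guarantees a positive standard deviation; as logvar becomes very negative, the sample collapses onto mu.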

tatk.policy.vhus.util module

tatk.policy.vhus.util.padding(old, l)

Pad a list of sequences of different lengths, “old”, to the same length “l”.
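A minimal sketch of this kind of padding, assuming the pad value is 0 and that sequences longer than l are truncated; neither detail is stated in the docstring. This hypothetical padding_sketch returns new lists instead of modifying old in place.

```python
from typing import List, Sequence


def padding_sketch(old: Sequence[Sequence[int]], l: int, pad_value: int = 0) -> List[List[int]]:
    """Return copies of the sequences in old, padded or truncated to length l (hypothetical sketch)."""
    new = []
    for seq in old:
        seq = list(seq[:l])                       # truncate anything longer than l
        seq.extend([pad_value] * (l - len(seq)))  # right-pad shorter sequences
        new.append(seq)
    return new
```

For example, padding [[1, 2], [3], [4, 5, 6, 7]] to length 3 gives [[1, 2, 0], [3, 0, 0], [4, 5, 6]], a rectangular list that can be turned directly into a tensor.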


tatk.policy.vhus.vhus module

class tatk.policy.vhus.vhus.UserPolicyVHUSAbstract(archive_file, model_file)

Bases: tatk.policy.policy.Policy

__init__(archive_file, model_file)

Initialize the class variables for a new session.

load(archive_file, model_file, filename)

Predict a user act based on the dialog state and the preceding system action.

Args:
  • state (tuple): Dialog state.

Returns: usr_action, session_over
  • usr_action (tuple): User act.

  • session_over (boolean): True to terminate the session; otherwise the session continues.