convlab2.policy.vhus package

Submodules

convlab2.policy.vhus.train module

@author: truthless

class convlab2.policy.vhus.train.VHUS_Trainer(config, manager, goal_gen)

Bases: object

imit_test(epoch, best)

Provide an unbiased evaluation of the user simulator fit on the training dataset, i.e. evaluate it on data held out from training.
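The signature imit_test(epoch, best) and its docstring suggest the usual validation pattern: score the simulator on held-out data and keep the checkpoint whenever it beats the best score so far. The sketch below shows only that selection logic. It assumes that lower loss is better and that the checkpoint is saved under a fixed tag. The helper update_best and its save_fn callback are hypothetical and are not part of the convlab2 API.

```python
from typing import Callable, List


def update_best(val_losses: List[float], best: float,
                save_fn: Callable[[str], None]) -> float:
    """Hypothetical helper: average held-out losses and keep the best model.

    Returns the (possibly updated) best loss. Assumes lower is better.
    """
    avg_loss = sum(val_losses) / len(val_losses)
    if avg_loss < best:
        # New best on held-out data: persist it under a fixed tag.
        save_fn("best")
        best = avg_loss
    return best


saved: List[str] = []
best = float("inf")
best = update_best([1.0, 0.5], best, saved.append)  # 0.75 < inf, so it saves
best = update_best([2.0, 1.0], best, saved.append)  # 1.5 is worse, so no save
print(best, saved)  # 0.75 ['best']
```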

imitating(epoch)

Train the user simulator by simple imitation learning (behavioral cloning).
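Behavioral cloning treats the expert's (state, action) pairs as a supervised dataset and fits the policy by maximum likelihood. VHUS uses neural encoders and a decoder for this. The sketch below swaps in the simplest possible model, a lookup table, for which the maximum-likelihood policy is the most frequent expert action in each state. The state and action labels are made up for illustration.

```python
from collections import Counter, defaultdict
from typing import DefaultDict, Dict, Hashable, Iterable, Tuple


def behavioral_cloning(
    pairs: Iterable[Tuple[Hashable, Hashable]]
) -> Dict[Hashable, Hashable]:
    """Fit a tabular policy: for each state, the most frequent expert action."""
    counts: DefaultDict[Hashable, Counter] = defaultdict(Counter)
    for state, action in pairs:
        counts[state][action] += 1
    # argmax of the empirical conditional distribution p(action | state)
    return {state: c.most_common(1)[0][0] for state, c in counts.items()}


demos = [
    ("sys_greet", "usr_inform"),
    ("sys_greet", "usr_inform"),
    ("sys_greet", "usr_request"),
    ("sys_offer", "usr_accept"),
]
policy = behavioral_cloning(demos)
print(policy["sys_greet"])  # usr_inform
```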

save(directory, epoch)
test()
user_loop(data)
convlab2.policy.vhus.train.batch_iter(x, y, z, batch_size=64)
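The signature batch_iter(x, y, z, batch_size=64) describes a mini-batch generator over three aligned inputs. Below is a minimal pure-Python sketch. It assumes one shuffle per pass and a smaller final batch, but this documentation does not say whether the convlab2 version shuffles. The name batch_iter_sketch and the seed parameter are hypothetical.

```python
import random
from typing import Iterator, List, Optional, Sequence, Tuple


def batch_iter_sketch(x: Sequence, y: Sequence, z: Sequence,
                      batch_size: int = 64,
                      seed: Optional[int] = None
                      ) -> Iterator[Tuple[List, List, List]]:
    """Yield aligned, shuffled mini-batches of x, y and z."""
    if not (len(x) == len(y) == len(z)):
        raise ValueError("x, y and z must have the same length")
    indices = list(range(len(x)))
    # One shared permutation keeps the three inputs row-aligned.
    random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        idx = indices[start:start + batch_size]  # last batch may be smaller
        yield [x[i] for i in idx], [y[i] for i in idx], [z[i] for i in idx]


xs = list(range(10))
ys = [i * 10 for i in xs]
zs = [str(i) for i in xs]
batches = list(batch_iter_sketch(xs, ys, zs, batch_size=4, seed=0))
print([len(b[0]) for b in batches])  # [4, 4, 2]
```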

convlab2.policy.vhus.usermodule module

@author: truthless

class convlab2.policy.vhus.usermodule.Decoder(vocab_size, max_len, embed_size, hidden_size, sos_id=2, eos_id=3, n_layers=1, rnn_cell='GRU', input_dropout_p=0, dropout_p=0, use_attention=False)

Bases: torch.nn.modules.module.Module

KEY_ATTN_SCORE = 'attention_score'
KEY_LENGTH = 'length'
KEY_SEQUENCE = 'sequence'
forward(inputs=None, encoder_hidden=None, encoder_outputs=None, function=torch.log_softmax, teacher_forcing_ratio=0)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

forward_step(input_var, hidden, encoder_outputs, function)
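The Decoder's sos_id, eos_id, max_len and teacher_forcing_ratio parameters describe a standard autoregressive decoding loop. The plain-Python sketch below assumes two things: teacher forcing is decided once per sequence (one common choice; some implementations decide per step), and the model emits greedy token ids. The step callable is a hypothetical stand-in for forward_step followed by an argmax. The defaults mirror the documented sos_id=2 and eos_id=3.

```python
import random
from typing import Callable, List, Optional, Sequence


def decode_sketch(step: Callable[[int], int], max_len: int,
                  targets: Optional[Sequence[int]] = None,
                  teacher_forcing_ratio: float = 0.0,
                  sos_id: int = 2, eos_id: int = 3,
                  rng: Optional[random.Random] = None) -> List[int]:
    """Greedy decoding with optional teacher forcing.

    step(prev_token) returns the next predicted token id.
    """
    rng = rng or random.Random()
    # Decide once per sequence whether to feed ground-truth tokens.
    use_tf = targets is not None and rng.random() < teacher_forcing_ratio
    steps = min(max_len, len(targets)) if use_tf else max_len
    prev, outputs = sos_id, []
    for t in range(steps):
        pred = step(prev)
        outputs.append(pred)
        if use_tf:
            prev = targets[t]   # feed the ground-truth token for step t
        elif pred == eos_id:
            break               # the model emitted end-of-sequence
        else:
            prev = pred         # feed the model's own prediction


    return outputs


# Toy "model": a fixed next-token table; unknown tokens map to eos (3).
table = {2: 7, 7: 8, 8: 3}
print(decode_sketch(lambda p: table.get(p, 3), max_len=10))  # [7, 8, 3]
```

With teacher_forcing_ratio=0, which is the documented default, the loop always consumes its own predictions and stops at eos_id.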
class convlab2.policy.vhus.usermodule.Encoder(vocab_size, embed_size, hidden_size, input_dropout_p=0, dropout_p=0, n_layers=1, rnn_cell='GRU', variable_lengths=False, embedding=None, update_embedding=True)

Bases: torch.nn.modules.module.Module

forward(input_var, input_lengths=None)

Applies a multi-layer RNN to an input sequence.

Args:
  • input_var (batch, seq_len): tensor containing the features of the input sequence.

  • input_lengths (list of int, optional): a list containing the lengths of the sequences in the mini-batch.

Returns: output, hidden
  • output (batch, seq_len, hidden_size): variable containing the encoded features of the input sequence

  • hidden (num_layers * num_directions, batch, hidden_size): variable containing the features in the hidden state h

class convlab2.policy.vhus.usermodule.VHUS(cfg, voc_goal_size, voc_usr_size, voc_sys_size)

Bases: torch.nn.modules.module.Module

forward(goals, goals_length, posts, posts_length, origin_responses=None)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.

select_action(goal, goal_length, post, post_length)
Parameters (bracketed values are tensor shapes)
  • goal – [goal_len]

  • goal_length – [] (a scalar)

  • post – [sen_len, word_len]

  • post_length – [sen_len]

Returns

Two values, of shapes [act_len] and [1].

convlab2.policy.vhus.usermodule.batch_gather_3_1(inputs, dim)
Args:
  • inputs (batchsz, sen_len, embed_dim)

  • dim (batchsz)

Returns:
  • output (batchsz, embed_dim)
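batch_gather_3_1 selects one sen_len position per batch element, turning a (batchsz, sen_len, embed_dim) input into (batchsz, embed_dim). The documentation does not say whether dim holds positions or sequence lengths. The sketch below assumes positions; if dim holds lengths (for example, to pick each sequence's last valid hidden state), pass length - 1 instead. It uses nested lists to keep the indexing explicit. On tensors, the same gather can be written with advanced indexing as inputs[torch.arange(batchsz), index]. The name batch_gather_3_1_sketch is hypothetical.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def batch_gather_3_1_sketch(inputs: Sequence[Matrix],
                            index: Sequence[int]) -> Matrix:
    """For each batch element b, return inputs[b][index[b]].

    inputs: (batchsz, sen_len, embed_dim) as nested lists
    index:  (batchsz,) one sen_len position per batch element
    output: (batchsz, embed_dim)
    """
    if len(inputs) != len(index):
        raise ValueError("inputs and index must share the batch dimension")
    return [list(seq[i]) for seq, i in zip(inputs, index)]


inputs = [[[1, 1], [2, 2], [3, 3]],
          [[4, 4], [5, 5], [6, 6]]]
print(batch_gather_3_1_sketch(inputs, [2, 0]))  # [[3, 3], [4, 4]]
```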

convlab2.policy.vhus.usermodule.batch_gather_4_2(inputs, dim)
Args:
  • inputs (batchsz, sen_len, word_len, embed_dim)

  • dim (batchsz, sen_len)

Returns:
  • output (batchsz, sen_len, embed_dim)
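batch_gather_4_2 applies the same kind of gather one level deeper. For every batch element and every sentence, it selects one word position, mapping (batchsz, sen_len, word_len, embed_dim) to (batchsz, sen_len, embed_dim). As with batch_gather_3_1, the documentation does not say whether dim holds positions or lengths. This sketch, with the hypothetical name batch_gather_4_2_sketch, assumes positions.

```python
from typing import List, Sequence

Tensor3 = List[List[List[float]]]  # (sen_len, word_len, embed_dim)


def batch_gather_4_2_sketch(inputs: Sequence[Tensor3],
                            index: Sequence[Sequence[int]]
                            ) -> List[List[List[float]]]:
    """output[b][s] = inputs[b][s][index[b][s]].

    inputs: (batchsz, sen_len, word_len, embed_dim)
    index:  (batchsz, sen_len) one word position per sentence
    output: (batchsz, sen_len, embed_dim)
    """
    return [[list(words[i]) for words, i in zip(sentences, idx_row)]
            for sentences, idx_row in zip(inputs, index)]


# One dialogue with two sentences of two one-dimensional word vectors each.
inputs = [[[[1], [2]], [[3], [4]]]]
print(batch_gather_4_2_sketch(inputs, [[1, 0]]))  # [[[2], [3]]]
```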

convlab2.policy.vhus.usermodule.reparameterize(mu, logvar)
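reparameterize(mu, logvar) is undocumented. Its name and arguments match the standard VAE reparameterization trick, z = mu + exp(0.5 * logvar) * eps with eps ~ N(0, 1). Writing the sample this way lets gradients flow through mu and logvar, because the randomness sits in eps. The pure-Python sketch below assumes that formula and works on plain lists; the convlab2 version operates on tensors. The name reparameterize_sketch and the rng parameter are hypothetical.

```python
import math
import random
from typing import List, Optional, Sequence


def reparameterize_sketch(mu: Sequence[float], logvar: Sequence[float],
                          rng: Optional[random.Random] = None) -> List[float]:
    """Draw z ~ N(mu, exp(logvar)) as a deterministic function of (mu, logvar, eps)."""
    rng = rng or random.Random()
    std = [math.exp(0.5 * lv) for lv in logvar]  # sigma = exp(log(sigma^2) / 2)
    eps = [rng.gauss(0.0, 1.0) for _ in mu]      # noise independent of the parameters
    return [m + s * e for m, s, e in zip(mu, std, eps)]


z = reparameterize_sketch([0.0, 1.0], [0.0, math.log(4.0)], random.Random(0))
```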

convlab2.policy.vhus.util module

convlab2.policy.vhus.util.capital(da)
convlab2.policy.vhus.util.kl_gaussian(argu)
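kl_gaussian(argu) has no docstring. Its name points to a KL divergence between Gaussians. In a variational model (the package also provides reparameterize), this term typically compares a recognition distribution q with a prior p, both diagonal. The sketch below assumes that argu packs (mu_q, logvar_q, mu_p, logvar_p) for a single example; that packing is an assumption, not documented here. It evaluates the closed form KL(q || p) = 1/2 * sum_i [logvar_p_i - logvar_q_i + (exp(logvar_q_i) + (mu_q_i - mu_p_i)^2) / exp(logvar_p_i) - 1]. The name kl_gaussian_sketch is hypothetical.

```python
import math
from typing import Sequence, Tuple

Gaussians = Tuple[Sequence[float], Sequence[float],
                  Sequence[float], Sequence[float]]


def kl_gaussian_sketch(argu: Gaussians) -> float:
    """KL(q || p) for diagonal Gaussians given as (mu_q, logvar_q, mu_p, logvar_p)."""
    mu_q, logvar_q, mu_p, logvar_p = argu
    total = 0.0
    for mq, lq, mp, lp in zip(mu_q, logvar_q, mu_p, logvar_p):
        # Per-dimension term of the closed-form Gaussian KL divergence.
        total += lp - lq + (math.exp(lq) + (mq - mp) ** 2) / math.exp(lp) - 1.0
    return 0.5 * total


print(kl_gaussian_sketch(([1.0], [0.0], [0.0], [0.0])))  # 0.5
```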
convlab2.policy.vhus.util.padding(old, l)

Pad a list of sequences of different lengths, “old”, to the same length “l”.
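padding(old, l) equalizes sequence lengths so that the sequences can be batched. Below is a minimal sketch. It assumes that the pad value is 0 and that sequences longer than l are truncated; this documentation states neither. The function name padding_sketch, the length argument (playing the role of l), and the pad parameter are hypothetical.

```python
from typing import List, Sequence


def padding_sketch(old: Sequence[Sequence[int]], length: int,
                   pad: int = 0) -> List[List[int]]:
    """Return new lists, each padded with pad (or truncated) to exactly length."""
    return [list(seq[:length]) + [pad] * max(0, length - len(seq)) for seq in old]


print(padding_sketch([[1, 2], [3], [4, 5, 6, 7]], 3))
# [[1, 2, 0], [3, 0, 0], [4, 5, 6]]
```

Building new lists leaves the caller's “old” untouched, which avoids surprising side effects when the same data is padded to different lengths.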

convlab2.policy.vhus.util.padding_data(data)

convlab2.policy.vhus.vhus module

class convlab2.policy.vhus.vhus.UserPolicyVHUSAbstract(archive_file, model_file)

Bases: convlab2.policy.policy.Policy

get_goal()
init_session()

Init the class variables for a new session.

is_terminated()
load(archive_file, model_file, filename)
load_from_local_path(path)
predict(state)

Predict a user act based on the dialog state and the preceding system action.

Args:
  • state (tuple): dialog state.

Returns:
  • usr_action (tuple): user act.

  • session_over (boolean): True to terminate the session; otherwise the session continues.

Module contents