convlab2.policy.vhus package
Submodules
convlab2.policy.vhus.train module
@author: truthless
class convlab2.policy.vhus.train.VHUS_Trainer(config, manager, goal_gen)
Bases: object
imit_test(epoch, best)
Provide an unbiased evaluation of the user simulator fit on the training dataset, i.e., evaluate it on held-out data.
imitating(epoch)
Train the user simulator by simple imitation learning (behavioral cloning).
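Behavioral cloning treats the logged real-user turns as supervised targets. Assuming the usual maximum-likelihood objective, the quantity minimized is the mean negative log-likelihood of the demonstrated tokens under the model. The plain-Python sketch below only illustrates that objective; `behavioral_cloning_nll` is a hypothetical helper, not part of ConvLab-2, which works on PyTorch tensors.

```python
import math

def behavioral_cloning_nll(step_probs, targets):
    """Mean negative log-likelihood of the demonstrated (real-user) tokens.

    step_probs: one probability distribution (list of floats) per decoding
                step, as produced by the model under teacher forcing.
    targets:    the token id the real user produced at each step.
    """
    if len(step_probs) != len(targets) or not targets:
        raise ValueError("need one distribution per non-empty target step")
    total = 0.0
    for probs, token in zip(step_probs, targets):
        total -= math.log(probs[token])  # penalize low probability on the gold token
    return total / len(targets)

# Two decoding steps over a 2-token vocabulary.
loss = behavioral_cloning_nll([[0.5, 0.5], [0.25, 0.75]], [0, 1])
```

A perfect imitator (probability 1.0 on every gold token) reaches a loss of 0.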
save(directory, epoch)
test()
user_loop(data)
convlab2.policy.vhus.train.batch_iter(x, y, z, batch_size=64)
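The signature suggests a generator over three parallel inputs. A minimal sketch, assuming `x`, `y` and `z` are equal-length sequences sliced in lockstep, that the final short batch is kept, and that no shuffling happens; `batch_iter_sketch` is a hypothetical name, and the real `batch_iter` may differ on any of these points.

```python
def batch_iter_sketch(x, y, z, batch_size=64):
    """Yield aligned mini-batches (x_batch, y_batch, z_batch).

    Hypothetical stand-in for batch_iter: assumes x, y, z are parallel
    sequences of equal length, keeps the final short batch, and does not shuffle.
    """
    if not (len(x) == len(y) == len(z)):
        raise ValueError("x, y and z must have the same length")
    for start in range(0, len(x), batch_size):
        end = start + batch_size
        yield x[start:end], y[start:end], z[start:end]

batches = list(batch_iter_sketch([1, 2, 3], [4, 5, 6], [7, 8, 9], batch_size=2))
# batches == [([1, 2], [4, 5], [7, 8]), ([3], [6], [9])]
```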
convlab2.policy.vhus.usermodule module
@author: truthless
class convlab2.policy.vhus.usermodule.Decoder(vocab_size, max_len, embed_size, hidden_size, sos_id=2, eos_id=3, n_layers=1, rnn_cell='GRU', input_dropout_p=0, dropout_p=0, use_attention=False)
Bases: torch.nn.modules.module.Module
KEY_ATTN_SCORE = 'attention_score'
KEY_LENGTH = 'length'
KEY_SEQUENCE = 'sequence'
forward(inputs=None, encoder_hidden=None, encoder_outputs=None, function=<built-in method log_softmax of type object>, teacher_forcing_ratio=0)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
forward_step(input_var, hidden, encoder_outputs, function)
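The `teacher_forcing_ratio` parameter of `Decoder.forward` controls how often gold tokens, rather than the decoder's own predictions, are fed back as inputs. The toy loop below assumes one common scheme (a single coin flip per call) and uses the signature's defaults `sos_id=2` and `eos_id=3`. `decode_sketch` and `toy_step` are hypothetical; `toy_step` stands in for `forward_step` plus an argmax and is not ConvLab-2 code.

```python
import random

def decode_sketch(step, targets, teacher_forcing_ratio=0.0,
                  sos_id=2, eos_id=3, max_len=10, rng=random):
    """Greedy decoding with optional teacher forcing (illustrative only).

    step(prev_token) -> list of scores over the vocabulary (hypothetical).
    One coin flip per call decides whether the gold targets or the model's
    own previous prediction are fed back as the next input.
    """
    use_teacher_forcing = rng.random() < teacher_forcing_ratio
    prev, output = sos_id, []
    for t in range(max_len):
        scores = step(prev)
        token = max(range(len(scores)), key=scores.__getitem__)  # argmax
        output.append(token)
        if token == eos_id:
            break
        if use_teacher_forcing and t < len(targets):
            prev = targets[t]  # feed the gold token
        else:
            prev = token       # feed the model's own prediction
    return output

def toy_step(prev):
    # Hypothetical "model": deterministic next-token table over a 6-token vocab.
    nxt = {2: 4, 4: 5, 5: 3}.get(prev, 3)
    return [1.0 if i == nxt else 0.0 for i in range(6)]

free_running = decode_sketch(toy_step, targets=[1], teacher_forcing_ratio=0.0)
forced = decode_sketch(toy_step, targets=[1], teacher_forcing_ratio=1.0)
```

With ratio 0 the decoder follows its own predictions (`[4, 5, 3]`); with ratio 1 the gold token `1` is fed after the first step, which changes the continuation (`[4, 3]`). This sketch stops at `eos_id` in both modes; a training loop would usually decode for the full target length.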
class convlab2.policy.vhus.usermodule.Encoder(vocab_size, embed_size, hidden_size, input_dropout_p=0, dropout_p=0, n_layers=1, rnn_cell='GRU', variable_lengths=False, embedding=None, update_embedding=True)
Bases: torch.nn.modules.module.Module
forward(input_var, input_lengths=None)
Applies a multi-layer RNN to an input sequence.
Args:
  input_var (batch, seq_len): tensor containing the features of the input sequence.
  input_lengths (list of int, optional): a list that contains the lengths of sequences in the mini-batch.
Returns: output, hidden
  output (batch, seq_len, hidden_size): variable containing the encoded features of the input sequence.
  hidden (num_layers * num_directions, batch, hidden_size): variable containing the features in the hidden state h.
class convlab2.policy.vhus.usermodule.VHUS(cfg, voc_goal_size, voc_usr_size, voc_sys_size)
Bases: torch.nn.modules.module.Module
forward(goals, goals_length, posts, posts_length, origin_responses=None)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
select_action(goal, goal_length, post, post_length)
Parameters:
  goal – [goal_len]
  goal_length – [] (scalar)
  post – [sen_len, word_len]
  post_length – [sen_len]
Returns:
  [act_len], [1]
convlab2.policy.vhus.usermodule.batch_gather_3_1(inputs, dim)
Args:
  inputs (batchsz, sen_len, embed_dim)
  dim (batchsz)
Returns:
  output (batch, embed_dim)
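Read from the shapes, `batch_gather_3_1` appears to select one `embed_dim` vector per batch item, at a position given by `dim`. A list-based sketch of that assumed contract (for example, `dim` could hold the index of each item's last real token); `batch_gather_3_1_sketch` is hypothetical, and the real function presumably works on PyTorch tensors.

```python
def batch_gather_3_1_sketch(inputs, dim):
    """Pick inputs[b][dim[b]] for every batch item b.

    inputs: nested lists shaped (batchsz, sen_len, embed_dim)
    dim:    list of ints shaped (batchsz,), one position per item (assumed)
    returns nested lists shaped (batchsz, embed_dim)
    """
    if len(inputs) != len(dim):
        raise ValueError("inputs and dim must share the batch dimension")
    return [seq[i] for seq, i in zip(inputs, dim)]

inputs = [[[1, 1], [2, 2], [3, 3]],
          [[4, 4], [5, 5], [6, 6]]]
picked = batch_gather_3_1_sketch(inputs, [2, 0])  # [[3, 3], [4, 4]]
```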
convlab2.policy.vhus.usermodule.batch_gather_4_2(inputs, dim)
Args:
  inputs (batchsz, sen_len, word_len, embed_dim)
  dim (batchsz, sen_len)
Returns:
  output (batch, sen_len, embed_dim)
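`batch_gather_4_2` has the same shape pattern one level deeper: judging by the shapes, it picks one word vector per sentence per batch item. A sketch of that assumed behavior on nested lists; `batch_gather_4_2_sketch` is a hypothetical name, not the library's code.

```python
def batch_gather_4_2_sketch(inputs, dim):
    """Pick inputs[b][s][dim[b][s]] for every batch item b and sentence s.

    inputs: nested lists shaped (batchsz, sen_len, word_len, embed_dim)
    dim:    nested ints shaped (batchsz, sen_len), one word position per sentence (assumed)
    returns nested lists shaped (batchsz, sen_len, embed_dim)
    """
    return [[words[i] for words, i in zip(sentences, idx)]
            for sentences, idx in zip(inputs, dim)]

inputs = [[[[10], [11]],    # sentence 0: two words, embed_dim = 1
           [[20], [21]]]]   # sentence 1
picked = batch_gather_4_2_sketch(inputs, [[1, 0]])  # [[[11], [20]]]
```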
convlab2.policy.vhus.usermodule.reparameterize(mu, logvar)
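The `(mu, logvar)` signature matches the standard VAE reparameterization trick, which is assumed here: sample `z = mu + sigma * eps` with `eps ~ N(0, 1)` and `sigma = exp(0.5 * logvar)`, so the randomness is isolated in `eps` and gradients can flow through `mu` and `logvar`. The sketch below uses plain floats; a PyTorch version would typically be written as `mu + torch.randn_like(std) * std` with `std = torch.exp(0.5 * logvar)`. `reparameterize_sketch` is hypothetical.

```python
import math
import random

def reparameterize_sketch(mu, logvar, rng=random):
    """Sample z = mu + sigma * eps with eps ~ N(0, 1) and sigma = exp(0.5 * logvar).

    mu, logvar: equal-length lists of floats describing a diagonal Gaussian.
    The randomness lives only in eps, so z stays a differentiable function
    of mu and logvar (the point of the trick in autograd frameworks).
    """
    return [m + math.exp(0.5 * lv) * rng.gauss(0.0, 1.0)
            for m, lv in zip(mu, logvar)]

z = reparameterize_sketch([0.0, 1.0], [0.0, -2.0], rng=random.Random(0))
```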
convlab2.policy.vhus.util module
convlab2.policy.vhus.util.capital(da)
convlab2.policy.vhus.util.kl_gaussian(argu)
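For a variational user simulator, `kl_gaussian` most likely computes the KL divergence between a recognition distribution q and a prior p, both diagonal Gaussians given as means and log-variances: KL(q ‖ p) = ½ Σ [log σp² − log σq² + (σq² + (μq − μp)²) / σp² − 1]. The sketch assumes `argu` unpacks to `(mu_q, logvar_q, mu_p, logvar_p)`; that ordering, and the absence of any batch averaging, are assumptions, and `kl_gaussian_sketch` is hypothetical.

```python
import math

def kl_gaussian_sketch(argu):
    """KL(q || p) between two diagonal Gaussians, summed over dimensions.

    argu is assumed to be (mu_q, logvar_q, mu_p, logvar_p), each a list of floats.
    """
    mu_q, logvar_q, mu_p, logvar_p = argu
    kl = 0.0
    for mq, lq, mp, lp in zip(mu_q, logvar_q, mu_p, logvar_p):
        # 0.5 * [log var_p - log var_q + (var_q + (mu_q - mu_p)^2) / var_p - 1]
        kl += 0.5 * (lp - lq + (math.exp(lq) + (mq - mp) ** 2) / math.exp(lp) - 1.0)
    return kl

# q = N(1, 1), p = N(0, 1): KL = 0.5 * (1 + 1 - 1) = 0.5
kl = kl_gaussian_sketch(([1.0], [0.0], [0.0], [0.0]))
```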
convlab2.policy.vhus.util.padding(old, l)
Pad a list of sequences of different lengths, “old”, to the same length “l”.
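A minimal sketch of the padding step, assuming right-padding with 0 and truncation of lists longer than `l`; both choices are assumptions rather than documented behavior, and `padding_sketch` (with `length` in the role of `l`) is hypothetical.

```python
def padding_sketch(old, length, pad_id=0):
    """Right-pad every list in old to exactly `length` items.

    Assumptions (not taken from ConvLab-2): the pad value is 0 and lists
    longer than `length` are truncated.
    """
    padded = []
    for seq in old:
        head = list(seq[:length])                              # truncate if too long
        padded.append(head + [pad_id] * (length - len(head)))  # then right-pad
    return padded

batch = padding_sketch([[5, 6], [7], [8, 9, 10]], 3)  # [[5, 6, 0], [7, 0, 0], [8, 9, 10]]
```

Padding yields a rectangular batch; the original, unpadded lengths are what an RNN encoder's `input_lengths` argument would typically receive.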
convlab2.policy.vhus.util.padding_data(data)
convlab2.policy.vhus.vhus module
class convlab2.policy.vhus.vhus.UserPolicyVHUSAbstract(archive_file, model_file)
Bases: convlab2.policy.policy.Policy
get_goal()
init_session()
Initialize the class variables for a new session.
is_terminated()
load(archive_file, model_file, filename)
load_from_local_path(path)
predict(state)
Predict a user act based on the state and the preceding system action.
Args:
  state (tuple): Dialog state.
Returns:
  usr_action (tuple): User act.
  session_over (boolean): True to terminate the session; otherwise the session continues.