convlab2.e2e.rnn_rollout.models package
Submodules
convlab2.e2e.rnn_rollout.models.attn module
class convlab2.e2e.rnn_rollout.models.attn.Attention(query_size, value_size, hid_size, init_range)
    Bases: torch.nn.modules.module.Module

    forward(q, v, mask=None)
        Defines the computation performed at every call. Should be overridden by all subclasses.

        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

    forward_attn(h)
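The note on `forward` applies to every module on this page. The following is a minimal sketch of the difference it describes, using a hypothetical toy module (`Doubler`) that is not part of convlab2. Calling the instance runs registered forward hooks; calling `forward` directly skips them.

```python
import torch
from torch import nn


class Doubler(nn.Module):
    """Hypothetical toy module used only for this illustration."""

    def forward(self, x):
        return 2 * x


calls = []  # records every time the forward hook fires

module = Doubler()
# The hook receives (module, inputs, output) after forward() has run.
module.register_forward_hook(lambda mod, inputs, output: calls.append(output))

x = torch.ones(3)
y_call = module(x)            # goes through __call__, so the hook runs
y_direct = module.forward(x)  # bypasses __call__, so the hook is skipped

print(len(calls))  # the hook fired once, for the module(x) call only
```

Both calls return the same tensor; only the hook bookkeeping differs.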
class convlab2.e2e.rnn_rollout.models.attn.BiRnnAttention(query_size, value_size, hid_size, dropout, init_range)
    Bases: convlab2.e2e.rnn_rollout.models.attn.ChunkedAttention

    flatten_parameters()

    forward(query, fwd_inpts, fwd_lens, rev_idxs, hid_idxs)
        Defines the computation performed at every call. Should be overridden by all subclasses.

        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

    forward_attn(query, fwd_hs, bwd_hs, lens)

    forward_rnn(rnn, inpts, lens, hid_idxs)
class convlab2.e2e.rnn_rollout.models.attn.ChunkedAttention(query_size, value_size, hid_size, init_range)
    Bases: torch.nn.modules.module.Module

    make_mask(fwd_hs, lens)

    reverse(fwd_inpts, fwd_lens, rev_idxs)

    zero_h(bsz, n=1)
class convlab2.e2e.rnn_rollout.models.attn.HierarchicalAttention(query_size, value_size, hid_size, dropout, init_range)
    Bases: convlab2.e2e.rnn_rollout.models.attn.ChunkedAttention

    flatten_parameters()

    forward(query, fwd_inpts, fwd_lens, rev_idxs, hid_idxs)
        Defines the computation performed at every call. Should be overridden by all subclasses.

        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

    forward_sent_attn(query, word_hs)

    forward_word_attn(query, fwd_word_hs, bwd_word_hs, ln, rev_idx, hid_idx)

    forward_word_rnn(rnn, bsz, inpts, lens, rev_idxs, hid_idxs)
class convlab2.e2e.rnn_rollout.models.attn.KeyValueAttention(query_size, key_size, value_size, hid_size, init_range)
    Bases: torch.nn.modules.module.Module

    forward(q, k, v, mask=None)
        Defines the computation performed at every call. Should be overridden by all subclasses.

        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

    forward_attn(h)
class convlab2.e2e.rnn_rollout.models.attn.MaskedAttention(query_size, value_size, hid_size, init_range)
    Bases: torch.nn.modules.module.Module

    forward(q, v, ln=None)
        Defines the computation performed at every call. Should be overridden by all subclasses.

        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

    make_mask(v, ln)
class convlab2.e2e.rnn_rollout.models.attn.SentenceAttention(query_size, value_size, hid_size, dropout, init_range)
    Bases: convlab2.e2e.rnn_rollout.models.attn.ChunkedAttention

    flatten_parameters()

    forward(query, fwd_inpt, fwd_len, rev_idx, hid_idx)
        Defines the computation performed at every call. Should be overridden by all subclasses.

        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

    forward_word_attn(query, fwd_word_hs, bwd_word_hs, ln, rev_idx, hid_idx)

    forward_word_rnn(rnn, bsz, inpts, lens, rev_idxs, hid_idxs)
convlab2.e2e.rnn_rollout.models.ctx_encoder module
Set of context encoders.
class convlab2.e2e.rnn_rollout.models.ctx_encoder.MlpContextEncoder(n, k, nembed, nhid, dropout, init_range, skip_values=False)
    Bases: torch.nn.modules.module.Module

    Simple encoder for the dialogue context. Encodes counts and values via an MLP.

    forward(ctx)
        Defines the computation performed at every call. Should be overridden by all subclasses.

        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
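To make "encodes counts and values via an MLP" concrete, here is a rough sketch under stated assumptions. It is not the class's actual implementation: the name `ToyMlpContextEncoder`, the layer layout, and the input layout (a LongTensor of shape `(2 * k, batch)` with count and value indices alternating along the first dimension) are all assumptions made for illustration. The `dropout`, `init_range`, and `skip_values` arguments are left out.

```python
import torch
from torch import nn


class ToyMlpContextEncoder(nn.Module):
    """Hypothetical stand-in; the real layer layout is not shown on this page."""

    def __init__(self, n, k, nembed, nhid):
        super().__init__()
        self.cnt_emb = nn.Embedding(n, nembed)  # embeds item counts
        self.val_emb = nn.Embedding(n, nembed)  # embeds item values
        self.mlp = nn.Sequential(nn.Tanh(), nn.Linear(k * nembed, nhid))

    def forward(self, ctx):
        # Assumed layout: rows 0, 2, 4, ... hold counts; rows 1, 3, 5, ... hold values.
        cnt = self.cnt_emb(ctx[0::2])  # (k, batch, nembed)
        val = self.val_emb(ctx[1::2])  # (k, batch, nembed)
        h = cnt * val                  # combine each count with its value
        h = h.transpose(0, 1).reshape(ctx.size(1), -1)  # (batch, k * nembed)
        return self.mlp(h)             # (batch, nhid)


encoder = ToyMlpContextEncoder(n=11, k=3, nembed=8, nhid=16)
ctx = torch.randint(0, 11, (6, 4))  # 3 (count, value) pairs for a batch of 4
ctx_h = encoder(ctx)                # called as a module, per the note on forward
```

The result `ctx_h` has one `nhid`-sized context vector per dialogue in the batch.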
convlab2.e2e.rnn_rollout.models.latent_clustering_model module
class convlab2.e2e.rnn_rollout.models.latent_clustering_model.BaselineClusteringModel(word_dict, item_dict, context_dict, count_dict, args)
    Bases: torch.nn.modules.module.Module

    corpus_ty

    decode_sentence(inpt_emb, cond_h)

    embed_sentence(inpt)

    engine_ty
        alias of convlab2.e2e.rnn_rollout.engines.latent_clustering_engine.BaselineClusteringEngine

    flatten_parameters()

    forward(inpts, tgts, hid_idxs, ctx, cnt)
        Defines the computation performed at every call. Should be overridden by all subclasses.

        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

    forward_decoder(inpt, cond_h)

    forward_encoder(inpt, hid_idx)

    forward_marginal_loss(p_z, inpt, tgt, cnt)

    forward_memory(mem_h, inpt, hid_idx)

    init_memory(ctx_h)

    read(inpt, mem_h)

    unembed_sentence(dec_hs)

    word2var(word)

    write(cond_h, max_words, temperature, start_token='YOU:', stop_tokens=['<eos>', '<selection>'])
class convlab2.e2e.rnn_rollout.models.latent_clustering_model.LatentClusteringLanguageModel(word_dict, item_dict, context_dict, count_dict, args)
    Bases: torch.nn.modules.module.Module

    corpus_ty

    decode_sentence(inpt_emb, cond_h)

    embed_sentence(inpt)

    encode_sentence(inpt_emb, hid_idx, enc_h)

    engine_ty
        alias of convlab2.e2e.rnn_rollout.engines.latent_clustering_engine.LatentClusteringLanguageEngine

    flatten_parameters()

    forward(inpts, tgts, sel_tgt_probs, hid_idxs, ctx, cnt)
        Defines the computation performed at every call. Should be overridden by all subclasses.

        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

    forward_decoder(inpt, cond_h)

    forward_encoder(inpt, hid_idx, enc_h)

    unembed_sentence(dec_hs)

    word2var(word)

    write(cond_h, max_words, temperature, start_token='YOU:', stop_tokens=['<eos>', '<selection>'])

    zero_grad()
        Sets gradients of all model parameters to zero.
class convlab2.e2e.rnn_rollout.models.latent_clustering_model.LatentClusteringModel(word_dict, item_dict, context_dict, count_dict, args)
    Bases: torch.nn.modules.module.Module

    corpus_ty

    decode_sentence(inpt_emb, mem_h)

    embed_sentence(inpt)

    encode_sentence(inpt_emb, hid_idx)

    engine_ty
        alias of convlab2.e2e.rnn_rollout.engines.latent_clustering_engine.LatentClusteringEngine

    flatten_parameters()

    forward(inpts, tgts, sel_tgt_probs, hid_idxs, ctx, cnt)
        Defines the computation performed at every call. Should be overridden by all subclasses.

        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

    forward_decoder(inpt, mem_h)

    forward_e_step(z_prob, mem_h, inpt, tgt, sel_tgt_prob, cnt)

    forward_encoder(ctx_h, inpt, hid_idx)

    unembed_sentence(dec_hs)
class convlab2.e2e.rnn_rollout.models.latent_clustering_model.LatentClusteringPredictionModel(word_dict, item_dict, context_dict, count_dict, args)
    Bases: torch.nn.modules.module.Module

    clear_weights()

    corpus_ty

    embed_sentence(inpt)

    encode_sentence(inpt_emb, hid_idx)

    engine_ty
        alias of convlab2.e2e.rnn_rollout.engines.latent_clustering_engine.LatentClusteringPredictionEngine

    flatten_parameters()

    forward(inpts, tgts, hid_idxs, ctx, cnt)
        Defines the computation performed at every call. Should be overridden by all subclasses.

        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

    forward_encoder(ctx_h, inpt=None, hid_idx=None)

    forward_kldiv(inpts, tgts, sel_tgt_probs, hid_idxs, ctx, cnt)

    forward_marginal_loss(q_z, lang_enc_h, inpt, tgt, cnt)

    forward_memory(ctx_h, mem_h=None, inpt=None, hid_idx=None)

    forward_prediction(cnt, mem_h, sample=False)

    forward_prediction_multi(cnt, mem_h, num_samples, sample=False)

    forward_validation(inpts, tgts, hid_idxs, ctx, cnt)

    forward_validation_marginal(inpts, tgts, hid_idxs, ctx, cnt)

    read(inpt, lang_enc_h, mem_h, ctx_h)

    write(lang_enc_h, lat_h, max_words, temperature, start_token='YOU:', stop_tokens=['<eos>', '<selection>'])
class convlab2.e2e.rnn_rollout.models.latent_clustering_model.RecurrentUnit(input_size, hidden_size, args)
    Bases: torch.nn.modules.module.Module

    forward(x, h)
        Defines the computation performed at every call. Should be overridden by all subclasses.

        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
class convlab2.e2e.rnn_rollout.models.latent_clustering_model.ShardedLatentBottleneckModule(num_shards, num_clusters, input_size, output_size, args)
    Bases: torch.nn.modules.module.Module

    forward(shard, key)
        Defines the computation performed at every call. Should be overridden by all subclasses.

        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

    select(shard, idx)

    select_shard(shard)

    zero_grad()
        Sets gradients of all model parameters to zero.
class convlab2.e2e.rnn_rollout.models.latent_clustering_model.SimpleSeparateSelectionModule(input_size, hidden_size, output_size, args)
    Bases: torch.nn.modules.module.Module

    flatten_parameters()

    forward(h)
        Defines the computation performed at every call. Should be overridden by all subclasses.

        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
convlab2.e2e.rnn_rollout.models.modules module
Helper functions for module initialization.
class convlab2.e2e.rnn_rollout.models.modules.CudaModule(device_id)
    Bases: torch.nn.modules.module.Module

    A helper to run a module on a particular device using CUDA.

    to_device(m)
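One plausible shape for such a device helper is sketched below. This is an assumption, not the documented behaviour of `CudaModule`: the class name `ToyCudaModule` is hypothetical, and treating `device_id=None` as "stay on the CPU" is a guess made so the example runs without a GPU.

```python
import torch
from torch import nn


class ToyCudaModule(nn.Module):
    """Hypothetical sketch: device_id=None is assumed to mean 'stay on the CPU'."""

    def __init__(self, device_id=None):
        super().__init__()
        self.device_id = device_id

    def to_device(self, m):
        # Move a tensor or module to the chosen GPU only when one was requested.
        if self.device_id is not None:
            return m.cuda(self.device_id)
        return m


helper = ToyCudaModule(device_id=None)
t = helper.to_device(torch.zeros(2))  # unchanged, still a CPU tensor
```

Subclasses could then route every tensor and submodule they create through `to_device`, keeping device handling in one place.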
class convlab2.e2e.rnn_rollout.models.modules.MlpContextEncoder(n, k, nembed, nhid, init_range, device_id)
    Bases: convlab2.e2e.rnn_rollout.models.modules.CudaModule

    A module that encodes the dialogue context using an MLP.

    forward(ctx)
        Defines the computation performed at every call. Should be overridden by all subclasses.

        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
class convlab2.e2e.rnn_rollout.models.modules.RnnContextEncoder(n, k, nembed, nhid, init_range, device_id)
    Bases: convlab2.e2e.rnn_rollout.models.modules.CudaModule

    A module that encodes the dialogue context using an RNN.

    forward(ctx)
        Defines the computation performed at every call. Should be overridden by all subclasses.

        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
convlab2.e2e.rnn_rollout.models.modules.init_cont(cont, init_range)
    Initializes a container uniformly.

convlab2.e2e.rnn_rollout.models.modules.init_rnn(rnn, init_range, weights=None, biases=None)
    Initializes RNN uniformly.

convlab2.e2e.rnn_rollout.models.modules.init_rnn_cell(rnn, init_range)
    Initializes RNNCell uniformly.
convlab2.e2e.rnn_rollout.models.rnn_model module
class convlab2.e2e.rnn_rollout.models.rnn_model.RnnModel(word_dict, item_dict, context_dict, count_dict, args)
    Bases: torch.nn.modules.module.Module

    corpus_ty

    engine_ty
        alias of convlab2.e2e.rnn_rollout.engines.rnn_engine.RnnEngine

    flatten_parameters()

    forward(inpt, ctx)
        Defines the computation performed at every call. Should be overridden by all subclasses.

        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

    forward_context(ctx)

    forward_lm(inpt_emb, lang_h, ctx_h)

    forward_selection(inpt_emb, lang_h, ctx_h)

    generate_choice_logits(inpt, lang_h, ctx_h)

    init_weights()

    read(inpt, lang_h, ctx_h, prefix_token='THEM:')

    score_sent(sent, lang_h, ctx_h, temperature)

    word2var(word)

    write(lang_h, ctx_h, max_words, temperature, stop_tokens=['<eos>', '<selection>'], resume=False)
        Generates a sentence word by word, feeding the output of each timestep as the input to the next.

    write_batch(bsz, lang_h, ctx_h, temperature, max_words=100)

    zero_h(bsz, nhid=None, copies=None)
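The word-by-word generation that `write` describes can be sketched without any model at all. The example below is an illustration only: `toy_next_token_logits` is a hypothetical lookup table standing in for one decoder step, and `toy_write`, its `start_token` argument, and its `seed` argument are assumptions rather than part of `RnnModel`. It shows the three ingredients named in the signature: a `max_words` cap, temperature-scaled sampling, and early exit on `stop_tokens`.

```python
import math
import random


def toy_next_token_logits(prev_token):
    """Hypothetical stand-in for one decoder step: a logit per candidate word."""
    table = {
        "YOU:": {"i": 2.0, "you": 1.0, "<eos>": -1.0},
        "i": {"want": 2.0, "<eos>": 0.0},
        "you": {"want": 2.0, "<eos>": 0.0},
        "want": {"books": 1.5, "hats": 1.5, "<eos>": 0.5},
    }
    return table.get(prev_token, {"<eos>": 0.0})


def toy_write(max_words, temperature, start_token="YOU:",
              stop_tokens=("<eos>", "<selection>"), seed=0):
    rng = random.Random(seed)
    prev, words = start_token, []
    for _ in range(max_words):
        logits = toy_next_token_logits(prev)
        vocab = list(logits)
        # Dividing by the temperature sharpens (< 1) or flattens (> 1) the distribution.
        weights = [math.exp(logits[w] / temperature) for w in vocab]
        word = rng.choices(vocab, weights=weights, k=1)[0]
        words.append(word)
        if word in stop_tokens:
            break
        prev = word  # the sampled word becomes the next step's input
    return words


sentence = toy_write(max_words=10, temperature=0.5)
print(sentence)
```

In a real model the lookup table would be replaced by a recurrent step that also updates a hidden state such as `lang_h`.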
convlab2.e2e.rnn_rollout.models.selection_model module
class convlab2.e2e.rnn_rollout.models.selection_model.SelectionModel(word_dict, item_dict, context_dict, count_dict, args)
    Bases: torch.nn.modules.module.Module

    corpus_ty

    engine_ty
        alias of convlab2.e2e.rnn_rollout.engines.selection_engine.SelectionEngine

    flatten_parameters()

    forward(inpts, lens, rev_idxs, hid_idxs, ctx)
        Defines the computation performed at every call. Should be overridden by all subclasses.

        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

    forward_each_timestamp(inpts, lens, rev_idxs, hid_idxs, ctx)

    forward_inpts(inpts, ctx_h)
class convlab2.e2e.rnn_rollout.models.selection_model.SelectionModule(query_size, value_size, hidden_size, selection_size, num_heads, output_size, args)
    Bases: torch.nn.modules.module.Module

    flatten_parameters()

    forward(q, hs, lens, rev_idxs, hid_idxs)
        Defines the computation performed at every call. Should be overridden by all subclasses.

        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
convlab2.e2e.rnn_rollout.models.utils module
A set of useful tools.
convlab2.e2e.rnn_rollout.models.utils.init_cont(cont, init_range)
    Uniform initialization of a container.

convlab2.e2e.rnn_rollout.models.utils.init_linear(linear, init_range)
    Uniform initialization of Linear.
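A minimal sketch of what uniform initialization over `[-init_range, init_range]` could look like follows. The functions `toy_init_linear` and `toy_init_cont` are hypothetical stand-ins, and zeroing the bias and only touching `Linear` layers are assumptions; this page does not confirm either detail.

```python
import torch
from torch import nn


def toy_init_linear(linear, init_range):
    """Hypothetical sketch; zeroing the bias is an assumption."""
    nn.init.uniform_(linear.weight, -init_range, init_range)
    nn.init.zeros_(linear.bias)


def toy_init_cont(cont, init_range):
    """Apply the same initialization to every Linear layer in a container."""
    for m in cont:
        if isinstance(m, nn.Linear):
            toy_init_linear(m, init_range)


mlp = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 2))
toy_init_cont(mlp, init_range=0.1)  # all Linear weights now lie in [-0.1, 0.1]
```

A small `init_range` keeps the initial activations of `Tanh` layers away from saturation.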
convlab2.e2e.rnn_rollout.models.utils.init_rnn(rnn, init_range, weights=None, biases=None)
    Orthogonal initialization of RNN.

convlab2.e2e.rnn_rollout.models.utils.init_rnn_cell(rnn, init_range)
    Orthogonal initialization of RNNCell.

convlab2.e2e.rnn_rollout.models.utils.make_mask(n, marked, value=-1000)
    Create a masked tensor.
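The signature of `make_mask` suggests an additive mask: a length-`n` tensor that is zero everywhere except at the `marked` positions, which hold a large negative `value`. The sketch below is written under that assumption; `toy_make_mask` is a hypothetical stand-in, not the library function.

```python
import torch


def toy_make_mask(n, marked, value=-1000):
    """Hypothetical sketch: zeros everywhere except `value` at the marked positions."""
    mask = torch.zeros(n)
    for i in marked:
        mask[i] = value
    return mask


logits = torch.tensor([1.0, 2.0, 3.0, 4.0])
mask = toy_make_mask(4, marked=[1, 3])
# Adding the mask before softmax drives the marked entries to ~0 probability.
probs = torch.softmax(logits + mask, dim=0)
```

Used this way, the mask rules out words or choices that should never be sampled while leaving the relative scores of the remaining entries untouched.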