tatk.e2e.rnn_rollout.models package
Submodules
tatk.e2e.rnn_rollout.models.attn module
class tatk.e2e.rnn_rollout.models.attn.Attention(query_size, value_size, hid_size, init_range)
Bases: torch.nn.modules.module.Module
- __init__(query_size, value_size, hid_size, init_range): Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(q, v, mask=None): Defines the computation performed at every call. Should be overridden by all subclasses.
  Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- forward_attn(h)
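The Note attached to every forward method in this package is inherited from torch.nn.Module. The snippet below is a minimal illustration of the behaviour it describes, using a hypothetical ToyModule rather than any tatk class: calling the module instance runs registered forward hooks, while calling forward() directly skips them.

```python
import torch
import torch.nn as nn


class ToyModule(nn.Module):
    """Hypothetical stand-in for any module documented on this page."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 2)

    def forward(self, x):
        return self.linear(x)


calls = []


def record_call(module, inputs, output):
    # Forward hooks receive the module, its positional inputs, and its output.
    calls.append(type(module).__name__)


model = ToyModule()
model.register_forward_hook(record_call)
x = torch.randn(3, 4)

model(x)          # __call__ runs forward() and then the registered hook
model.forward(x)  # direct call: forward() runs, but the hook is skipped

print(calls)  # ['ToyModule']
```

The same reasoning applies to every class here: prefer model(...) over model.forward(...).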
class tatk.e2e.rnn_rollout.models.attn.BiRnnAttention(query_size, value_size, hid_size, dropout, init_range)
Bases: tatk.e2e.rnn_rollout.models.attn.ChunkedAttention
- __init__(query_size, value_size, hid_size, dropout, init_range): Initializes internal Module state, shared by both nn.Module and ScriptModule.
- flatten_parameters()
- forward(query, fwd_inpts, fwd_lens, rev_idxs, hid_idxs): Defines the computation performed at every call. Should be overridden by all subclasses.
  Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- forward_attn(query, fwd_hs, bwd_hs, lens)
- forward_rnn(rnn, inpts, lens, hid_idxs)
class tatk.e2e.rnn_rollout.models.attn.ChunkedAttention(query_size, value_size, hid_size, init_range)
Bases: torch.nn.modules.module.Module
- __init__(query_size, value_size, hid_size, init_range): Initializes internal Module state, shared by both nn.Module and ScriptModule.
- make_mask(fwd_hs, lens)
- reverse(fwd_inpts, fwd_lens, rev_idxs)
- zero_h(bsz, n=1)
class tatk.e2e.rnn_rollout.models.attn.HierarchicalAttention(query_size, value_size, hid_size, dropout, init_range)
Bases: tatk.e2e.rnn_rollout.models.attn.ChunkedAttention
- __init__(query_size, value_size, hid_size, dropout, init_range): Initializes internal Module state, shared by both nn.Module and ScriptModule.
- flatten_parameters()
- forward(query, fwd_inpts, fwd_lens, rev_idxs, hid_idxs): Defines the computation performed at every call. Should be overridden by all subclasses.
  Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- forward_sent_attn(query, word_hs)
- forward_word_attn(query, fwd_word_hs, bwd_word_hs, ln, rev_idx, hid_idx)
- forward_word_rnn(rnn, bsz, inpts, lens, rev_idxs, hid_idxs)
class tatk.e2e.rnn_rollout.models.attn.KeyValueAttention(query_size, key_size, value_size, hid_size, init_range)
Bases: torch.nn.modules.module.Module
- __init__(query_size, key_size, value_size, hid_size, init_range): Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(q, k, v, mask=None): Defines the computation performed at every call. Should be overridden by all subclasses.
  Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- forward_attn(h)
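KeyValueAttention exposes forward(q, k, v, mask=None) plus a forward_attn(h) helper, which suggests that keys are scored against the query and the resulting weights are used to average the values. A minimal sketch of that pattern, under stated assumptions: additive (tanh-MLP) scoring, batch-first shapes, and an additive mask holding large negative entries at hidden positions (the -1000 default of utils.make_mask hints at this convention, but the real mask semantics are not documented here). KeyValueAttentionSketch is a hypothetical name; this is not the tatk implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class KeyValueAttentionSketch(nn.Module):
    """Hypothetical additive key-value attention; not tatk's implementation."""

    def __init__(self, query_size, key_size, value_size, hid_size, init_range):
        super().__init__()
        # value_size is not needed here: values are averaged, not projected (assumption).
        self.key2hid = nn.Linear(key_size, hid_size)
        self.query2hid = nn.Linear(query_size, hid_size)
        self.hid2score = nn.Linear(hid_size, 1)
        for p in self.parameters():
            nn.init.uniform_(p, -init_range, init_range)

    def forward_attn(self, h):
        # h: (bsz, seq_len, hid_size) -> unnormalized scores: (bsz, seq_len)
        return self.hid2score(torch.tanh(h)).squeeze(-1)

    def forward(self, q, k, v, mask=None):
        # q: (bsz, query_size), k: (bsz, seq_len, key_size), v: (bsz, seq_len, value_size)
        h = self.key2hid(k) + self.query2hid(q).unsqueeze(1)
        logits = self.forward_attn(h)
        if mask is not None:
            # Assumed additive mask: 0 keeps a position, a large negative value hides it.
            logits = logits + mask
        p = F.softmax(logits, dim=1)
        # Attention-weighted sum of the values: (bsz, value_size)
        ctx = torch.bmm(p.unsqueeze(1), v).squeeze(1)
        return ctx, p


# Usage example: the second batch element only attends to its first position.
torch.manual_seed(0)
attn = KeyValueAttentionSketch(query_size=8, key_size=6, value_size=5, hid_size=16, init_range=0.1)
q, k, v = torch.randn(2, 8), torch.randn(2, 3, 6), torch.randn(2, 3, 5)
mask = torch.tensor([[0.0, 0.0, -1000.0], [0.0, -1000.0, -1000.0]])
ctx, p = attn(q, k, v, mask)
```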
class tatk.e2e.rnn_rollout.models.attn.MaskedAttention(query_size, value_size, hid_size, init_range)
Bases: torch.nn.modules.module.Module
- __init__(query_size, value_size, hid_size, init_range): Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(q, v, ln=None): Defines the computation performed at every call. Should be overridden by all subclasses.
  Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- make_mask(v, ln)
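MaskedAttention.forward takes sequence lengths ln instead of an explicit mask, and make_mask(v, ln) presumably turns those lengths into a mask over padded positions. A minimal sketch of such a conversion, assuming right-padded batch-first values and an additive mask; the name length_mask is hypothetical, and the -1000 fill value is borrowed from the default of utils.make_mask rather than taken from MaskedAttention itself.

```python
import torch


def length_mask(v, ln, value=-1000.0):
    """Hypothetical sketch: additive mask that hides padded positions.

    v:  (bsz, seq_len, dim) right-padded value vectors
    ln: (bsz,) number of valid (unpadded) positions per example
    Returns (bsz, seq_len): 0 at valid positions, `value` at padding.
    """
    bsz, seq_len = v.size(0), v.size(1)
    # Position j of example i is padding when j >= ln[i] (broadcast over the batch).
    padded = torch.arange(seq_len, device=v.device).unsqueeze(0) >= ln.to(v.device).unsqueeze(1)
    mask = torch.zeros(bsz, seq_len, dtype=v.dtype, device=v.device)
    return mask.masked_fill(padded, value)


# Usage example: the second sequence has only two valid positions.
v = torch.zeros(2, 4, 3)
ln = torch.tensor([4, 2])
mask = length_mask(v, ln)
```

An additive mask of this kind can be added directly to attention logits before the softmax; a boolean mask used with masked_fill is an equally common alternative.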
class tatk.e2e.rnn_rollout.models.attn.SentenceAttention(query_size, value_size, hid_size, dropout, init_range)
Bases: tatk.e2e.rnn_rollout.models.attn.ChunkedAttention
- __init__(query_size, value_size, hid_size, dropout, init_range): Initializes internal Module state, shared by both nn.Module and ScriptModule.
- flatten_parameters()
- forward(query, fwd_inpt, fwd_len, rev_idx, hid_idx): Defines the computation performed at every call. Should be overridden by all subclasses.
  Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- forward_word_attn(query, fwd_word_hs, bwd_word_hs, ln, rev_idx, hid_idx)
- forward_word_rnn(rnn, bsz, inpts, lens, rev_idxs, hid_idxs)
tatk.e2e.rnn_rollout.models.ctx_encoder module
Set of context encoders.
class tatk.e2e.rnn_rollout.models.ctx_encoder.MlpContextEncoder(n, k, nembed, nhid, dropout, init_range, skip_values=False)
Bases: torch.nn.modules.module.Module
Simple encoder for the dialogue context. Encodes counts and values via an MLP.
- __init__(n, k, nembed, nhid, dropout, init_range, skip_values=False): Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(ctx): Defines the computation performed at every call. Should be overridden by all subclasses.
  Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
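MlpContextEncoder is described only as encoding counts and values via an MLP. One plausible design, sketched here under explicit assumptions: ctx is a LongTensor of shape (2 * k, bsz) holding interleaved (count, value) token ids for k items, both drawn from a vocabulary of size n; each item's count and value embeddings are combined multiplicatively; and a Tanh plus Linear layer maps the concatenated item vectors to nhid. MlpContextEncoderSketch is hypothetical, does not model the documented skip_values flag, and is not the tatk source.

```python
import torch
import torch.nn as nn


class MlpContextEncoderSketch(nn.Module):
    """Hypothetical MLP context encoder; the input layout is an assumption."""

    def __init__(self, n, k, nembed, nhid, dropout, init_range):
        super().__init__()
        self.cnt_enc = nn.Embedding(n, nembed)
        self.val_enc = nn.Embedding(n, nembed)
        self.dropout = nn.Dropout(dropout)
        self.encoder = nn.Sequential(nn.Tanh(), nn.Linear(k * nembed, nhid))
        for p in self.parameters():
            nn.init.uniform_(p, -init_range, init_range)

    def forward(self, ctx):
        # ctx: (2 * k, bsz), interleaved as count_1, value_1, count_2, value_2, ...
        cnt = ctx[0::2]  # (k, bsz) item counts
        val = ctx[1::2]  # (k, bsz) item values
        # Combine each item's count and value embeddings elementwise: (k, bsz, nembed)
        h = self.dropout(self.cnt_enc(cnt)) * self.dropout(self.val_enc(val))
        # Concatenate the k item vectors per example: (bsz, k * nembed)
        h = h.transpose(0, 1).contiguous().view(ctx.size(1), -1)
        return self.encoder(h)  # (bsz, nhid)


# Usage example: a batch of two contexts, each describing k = 3 items.
torch.manual_seed(0)
enc = MlpContextEncoderSketch(n=11, k=3, nembed=4, nhid=7, dropout=0.1, init_range=0.1)
enc.eval()  # disable dropout for a deterministic forward pass
ctx = torch.randint(0, 11, (6, 2))
ctx_h = enc(ctx)
```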
tatk.e2e.rnn_rollout.models.latent_clustering_model module
class tatk.e2e.rnn_rollout.models.latent_clustering_model.BaselineClusteringModel(word_dict, item_dict, context_dict, count_dict, args)
Bases: torch.nn.modules.module.Module
- __init__(word_dict, item_dict, context_dict, count_dict, args): Initializes internal Module state, shared by both nn.Module and ScriptModule.
- corpus_ty
- decode_sentence(inpt_emb, cond_h)
- embed_sentence(inpt)
- flatten_parameters()
- forward(inpts, tgts, hid_idxs, ctx, cnt): Defines the computation performed at every call. Should be overridden by all subclasses.
  Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- forward_decoder(inpt, cond_h)
- forward_encoder(inpt, hid_idx)
- forward_marginal_loss(p_z, inpt, tgt, cnt)
- forward_memory(mem_h, inpt, hid_idx)
- init_memory(ctx_h)
- read(inpt, mem_h)
- unembed_sentence(dec_hs)
- word2var(word)
- write(cond_h, max_words, temperature, start_token='YOU:', stop_tokens=['<eos>', '<selection>'])
class tatk.e2e.rnn_rollout.models.latent_clustering_model.LatentClusteringLanguageModel(word_dict, item_dict, context_dict, count_dict, args)
Bases: torch.nn.modules.module.Module
- __init__(word_dict, item_dict, context_dict, count_dict, args): Initializes internal Module state, shared by both nn.Module and ScriptModule.
- corpus_ty
- decode_sentence(inpt_emb, cond_h)
- embed_sentence(inpt)
- encode_sentence(inpt_emb, hid_idx, enc_h)
- engine_ty: alias of tatk.e2e.rnn_rollout.engines.latent_clustering_engine.LatentClusteringLanguageEngine
- flatten_parameters()
- forward(inpts, tgts, sel_tgt_probs, hid_idxs, ctx, cnt): Defines the computation performed at every call. Should be overridden by all subclasses.
  Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- forward_decoder(inpt, cond_h)
- forward_encoder(inpt, hid_idx, enc_h)
- unembed_sentence(dec_hs)
- word2var(word)
- write(cond_h, max_words, temperature, start_token='YOU:', stop_tokens=['<eos>', '<selection>'])
- zero_grad(): Sets gradients of all model parameters to zero.
class tatk.e2e.rnn_rollout.models.latent_clustering_model.LatentClusteringModel(word_dict, item_dict, context_dict, count_dict, args)
Bases: torch.nn.modules.module.Module
- __init__(word_dict, item_dict, context_dict, count_dict, args): Initializes internal Module state, shared by both nn.Module and ScriptModule.
- corpus_ty
- decode_sentence(inpt_emb, mem_h)
- embed_sentence(inpt)
- encode_sentence(inpt_emb, hid_idx)
- flatten_parameters()
- forward(inpts, tgts, sel_tgt_probs, hid_idxs, ctx, cnt): Defines the computation performed at every call. Should be overridden by all subclasses.
  Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- forward_decoder(inpt, mem_h)
- forward_e_step(z_prob, mem_h, inpt, tgt, sel_tgt_prob, cnt)
- forward_encoder(ctx_h, inpt, hid_idx)
- unembed_sentence(dec_hs)
class tatk.e2e.rnn_rollout.models.latent_clustering_model.LatentClusteringPredictionModel(word_dict, item_dict, context_dict, count_dict, args)
Bases: torch.nn.modules.module.Module
- __init__(word_dict, item_dict, context_dict, count_dict, args): Initializes internal Module state, shared by both nn.Module and ScriptModule.
- clear_weights()
- corpus_ty
- embed_sentence(inpt)
- encode_sentence(inpt_emb, hid_idx)
- engine_ty: alias of tatk.e2e.rnn_rollout.engines.latent_clustering_engine.LatentClusteringPredictionEngine
- flatten_parameters()
- forward(inpts, tgts, hid_idxs, ctx, cnt): Defines the computation performed at every call. Should be overridden by all subclasses.
  Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- forward_encoder(ctx_h, inpt=None, hid_idx=None)
- forward_kldiv(inpts, tgts, sel_tgt_probs, hid_idxs, ctx, cnt)
- forward_marginal_loss(q_z, lang_enc_h, inpt, tgt, cnt)
- forward_memory(ctx_h, mem_h=None, inpt=None, hid_idx=None)
- forward_prediction(cnt, mem_h, sample=False)
- forward_prediction_multi(cnt, mem_h, num_samples, sample=False)
- forward_validation(inpts, tgts, hid_idxs, ctx, cnt)
- forward_validation_marginal(inpts, tgts, hid_idxs, ctx, cnt)
- read(inpt, lang_enc_h, mem_h, ctx_h)
- write(lang_enc_h, lat_h, max_words, temperature, start_token='YOU:', stop_tokens=['<eos>', '<selection>'])
class tatk.e2e.rnn_rollout.models.latent_clustering_model.RecurrentUnit(input_size, hidden_size, args)
Bases: torch.nn.modules.module.Module
- __init__(input_size, hidden_size, args): Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(x, h): Defines the computation performed at every call. Should be overridden by all subclasses.
  Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
class tatk.e2e.rnn_rollout.models.latent_clustering_model.ShardedLatentBottleneckModule(num_shards, num_clusters, input_size, output_size, args)
Bases: torch.nn.modules.module.Module
- __init__(num_shards, num_clusters, input_size, output_size, args): Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(shard, key): Defines the computation performed at every call. Should be overridden by all subclasses.
  Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- select(shard, idx)
- select_shard(shard)
- zero_grad(): Sets gradients of all model parameters to zero.
class tatk.e2e.rnn_rollout.models.latent_clustering_model.SimpleSeparateSelectionModule(input_size, hidden_size, output_size, args)
Bases: torch.nn.modules.module.Module
- __init__(input_size, hidden_size, output_size, args): Initializes internal Module state, shared by both nn.Module and ScriptModule.
- flatten_parameters()
- forward(h): Defines the computation performed at every call. Should be overridden by all subclasses.
  Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
tatk.e2e.rnn_rollout.models.modules module
Helper functions for module initialization.
class tatk.e2e.rnn_rollout.models.modules.CudaModule(device_id)
Bases: torch.nn.modules.module.Module
A helper to run a module on a particular device using CUDA.
- __init__(device_id): Initializes internal Module state, shared by both nn.Module and ScriptModule.
- to_device(m)
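CudaModule stores a device_id and offers to_device(m). A minimal sketch of what such a helper might look like, assuming that device_id=None means "stay on the CPU" and that the helper should also fall back to the CPU when CUDA is unavailable; both behaviours are assumptions, and CudaModuleSketch is a hypothetical name rather than the tatk class.

```python
import torch
import torch.nn as nn


class CudaModuleSketch(nn.Module):
    """Hypothetical helper that moves tensors or submodules to one CUDA device."""

    def __init__(self, device_id):
        super().__init__()
        self.device_id = device_id

    def to_device(self, m):
        # Assumption: device_id=None, or no CUDA at all, leaves m on the CPU unchanged.
        if self.device_id is not None and torch.cuda.is_available():
            # Works for both Tensor.cuda(device) and Module.cuda(device).
            return m.cuda(self.device_id)
        return m


# Usage example on a CPU-only setup.
helper = CudaModuleSketch(device_id=None)
t = torch.ones(3)
layer = nn.Linear(2, 2)
```

Note the asymmetry in PyTorch itself: Module.cuda() moves parameters in place and returns the same module, whereas Tensor.cuda() returns a new tensor, so callers should always use the return value.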
class tatk.e2e.rnn_rollout.models.modules.MlpContextEncoder(n, k, nembed, nhid, init_range, device_id)
Bases: tatk.e2e.rnn_rollout.models.modules.CudaModule
A module that encodes dialogue context using an MLP.
- __init__(n, k, nembed, nhid, init_range, device_id): Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(ctx): Defines the computation performed at every call. Should be overridden by all subclasses.
  Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
class tatk.e2e.rnn_rollout.models.modules.RnnContextEncoder(n, k, nembed, nhid, init_range, device_id)
Bases: tatk.e2e.rnn_rollout.models.modules.CudaModule
A module that encodes dialogue context using an RNN.
- __init__(n, k, nembed, nhid, init_range, device_id): Initializes internal Module state, shared by both nn.Module and ScriptModule.
- forward(ctx): Defines the computation performed at every call. Should be overridden by all subclasses.
  Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
tatk.e2e.rnn_rollout.models.modules.init_cont(cont, init_range): Initializes a container uniformly.
tatk.e2e.rnn_rollout.models.modules.init_rnn(rnn, init_range, weights=None, biases=None): Initializes an RNN uniformly.
tatk.e2e.rnn_rollout.models.modules.init_rnn_cell(rnn, init_range): Initializes an RNNCell uniformly.
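The three helpers above all draw parameters uniformly from [-init_range, init_range]. A minimal sketch of that idea using torch.nn.init.uniform_; the function names are hypothetical, and the interpretation of weights/biases as lists of parameter-name prefixes (such as 'weight_ih' or 'bias_hh') is an assumption, since the documented signature does not say what they contain.

```python
import torch
import torch.nn as nn


def init_cont_uniform(cont, init_range):
    """Sketch: uniformly initialize the parameters of every module in a container."""
    for m in cont:
        for p in m.parameters():
            nn.init.uniform_(p, -init_range, init_range)


def init_rnn_uniform(rnn, init_range, weights=None, biases=None):
    """Sketch: uniformly initialize an nn.RNN/GRU/LSTM.

    weights/biases are assumed to be parameter-name prefixes; by default every
    input-hidden and hidden-hidden weight and bias is initialized, including
    the '_reverse' parameters of bidirectional RNNs.
    """
    if weights is None:
        weights = ['weight_ih', 'weight_hh']
    if biases is None:
        biases = ['bias_ih', 'bias_hh']
    prefixes = tuple(weights) + tuple(biases)
    for name, p in rnn.named_parameters():
        if name.startswith(prefixes):
            nn.init.uniform_(p, -init_range, init_range)


def init_rnn_cell_uniform(rnn, init_range):
    """Sketch: uniformly initialize an nn.RNNCell/GRUCell/LSTMCell."""
    for p in rnn.parameters():
        nn.init.uniform_(p, -init_range, init_range)


# Usage examples.
torch.manual_seed(0)
cont = nn.Sequential(nn.Linear(3, 3), nn.Tanh(), nn.Linear(3, 2))
init_cont_uniform(cont, 0.05)
gru = nn.GRU(4, 5, num_layers=2, bidirectional=True)
init_rnn_uniform(gru, 0.01)
cell = nn.LSTMCell(4, 5)
init_rnn_cell_uniform(cell, 0.02)
```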
tatk.e2e.rnn_rollout.models.rnn_model module
class tatk.e2e.rnn_rollout.models.rnn_model.RnnModel(word_dict, item_dict, context_dict, count_dict, args)
Bases: torch.nn.modules.module.Module
- __init__(word_dict, item_dict, context_dict, count_dict, args): Initializes internal Module state, shared by both nn.Module and ScriptModule.
- corpus_ty: alias of tatk.e2e.rnn_rollout.data.WordCorpus
- engine_ty
- flatten_parameters()
- forward(inpt, ctx): Defines the computation performed at every call. Should be overridden by all subclasses.
  Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- forward_context(ctx)
- forward_lm(inpt_emb, lang_h, ctx_h)
- forward_selection(inpt_emb, lang_h, ctx_h)
- generate_choice_logits(inpt, lang_h, ctx_h)
- init_weights()
- read(inpt, lang_h, ctx_h, prefix_token='THEM:')
- score_sent(sent, lang_h, ctx_h, temperature)
- word2var(word)
- write(lang_h, ctx_h, max_words, temperature, stop_tokens=['<eos>', '<selection>'], resume=False): Generates a sentence word by word, feeding the output of the previous timestep as input to the next.
- write_batch(bsz, lang_h, ctx_h, temperature, max_words=100)
- zero_h(bsz, nhid=None, copies=None)
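RnnModel.write is documented as generating a sentence word by word, feeding each timestep's output back in as the next input, with temperature-scaled sampling and stop tokens. A minimal sketch of that loop under stated assumptions: the embedding, nn.GRUCell and output projection are passed in explicitly, the cell input is the word embedding concatenated with the context encoding, and tokens are already mapped to integer ids. write_sketch and its parameters (start_id, stop_ids) are hypothetical and do not reproduce RnnModel.write's signature.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def write_sketch(embed, cell, proj, lang_h, ctx_h, start_id, max_words, temperature, stop_ids):
    """Hypothetical word-by-word sampler; not RnnModel.write itself.

    embed:    nn.Embedding over the word vocabulary
    cell:     nn.GRUCell whose input is [word embedding; context encoding]
    proj:     nn.Linear mapping the hidden state to vocabulary logits
    lang_h:   (1, nhid) language state; ctx_h: (1, nctx) context encoding
    stop_ids: ids of tokens such as '<eos>' or '<selection>' that end the sentence
    """
    inpt = torch.tensor([start_id])
    words = []
    with torch.no_grad():
        for _ in range(max_words):
            x = torch.cat([embed(inpt), ctx_h], dim=1)
            lang_h = cell(x, lang_h)
            # Temperature < 1 sharpens the distribution; > 1 flattens it.
            probs = F.softmax(proj(lang_h) / temperature, dim=1)
            # Sample the next word; it becomes the input on the next iteration.
            inpt = torch.multinomial(probs, 1).squeeze(1)
            words.append(inpt.item())
            if words[-1] in stop_ids:
                break
    return words, lang_h


# Usage example with a tiny randomly initialized model.
torch.manual_seed(0)
vocab, nembed, nhid, nctx = 10, 4, 6, 3
embed = nn.Embedding(vocab, nembed)
cell = nn.GRUCell(nembed + nctx, nhid)
proj = nn.Linear(nhid, vocab)
words, h = write_sketch(embed, cell, proj, torch.zeros(1, nhid), torch.randn(1, nctx),
                        start_id=0, max_words=5, temperature=1.0, stop_ids={1, 2})
```

The returned hidden state can be kept so that later reading or writing resumes from where generation stopped.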
tatk.e2e.rnn_rollout.models.selection_model module
class tatk.e2e.rnn_rollout.models.selection_model.SelectionModel(word_dict, item_dict, context_dict, count_dict, args)
Bases: torch.nn.modules.module.Module
- __init__(word_dict, item_dict, context_dict, count_dict, args): Initializes internal Module state, shared by both nn.Module and ScriptModule.
- corpus_ty
- engine_ty: alias of tatk.e2e.rnn_rollout.engines.selection_engine.SelectionEngine
- flatten_parameters()
- forward(inpts, lens, rev_idxs, hid_idxs, ctx): Defines the computation performed at every call. Should be overridden by all subclasses.
  Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- forward_each_timestamp(inpts, lens, rev_idxs, hid_idxs, ctx)
- forward_inpts(inpts, ctx_h)
class tatk.e2e.rnn_rollout.models.selection_model.SelectionModule(query_size, value_size, hidden_size, selection_size, num_heads, output_size, args)
Bases: torch.nn.modules.module.Module
- __init__(query_size, value_size, hidden_size, selection_size, num_heads, output_size, args): Initializes internal Module state, shared by both nn.Module and ScriptModule.
- flatten_parameters()
- forward(q, hs, lens, rev_idxs, hid_idxs): Defines the computation performed at every call. Should be overridden by all subclasses.
  Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
tatk.e2e.rnn_rollout.models.utils module
A set of useful tools.
tatk.e2e.rnn_rollout.models.utils.init_cont(cont, init_range): Uniform initialization of a container.
tatk.e2e.rnn_rollout.models.utils.init_linear(linear, init_range): Uniform initialization of a Linear layer.
tatk.e2e.rnn_rollout.models.utils.init_rnn(rnn, init_range, weights=None, biases=None): Orthogonal initialization of an RNN.
tatk.e2e.rnn_rollout.models.utils.init_rnn_cell(rnn, init_range): Orthogonal initialization of an RNNCell.
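Unlike their counterparts in the modules module, utils.init_rnn and utils.init_rnn_cell are documented as orthogonal initializers, while init_linear is uniform. A minimal sketch using torch.nn.init.orthogonal_ and torch.nn.init.uniform_; the function names are hypothetical, and drawing the biases uniformly from [-init_range, init_range] is an assumption about how init_range is used.

```python
import torch
import torch.nn as nn


def init_linear_uniform(linear, init_range):
    """Sketch: uniform initialization of an nn.Linear."""
    nn.init.uniform_(linear.weight, -init_range, init_range)
    if linear.bias is not None:
        nn.init.uniform_(linear.bias, -init_range, init_range)


def init_rnn_orthogonal(rnn, init_range):
    """Sketch: orthogonal weight matrices and uniform biases.

    Works for nn.RNN/GRU/LSTM and their *Cell variants, whose parameters
    are named weight_ih*, weight_hh*, bias_ih* and bias_hh*.
    """
    for name, p in rnn.named_parameters():
        if name.startswith('weight'):
            nn.init.orthogonal_(p)
        else:
            # Assumption: biases are drawn uniformly from [-init_range, init_range].
            nn.init.uniform_(p, -init_range, init_range)


# Usage examples.
torch.manual_seed(0)
lin = nn.Linear(3, 4)
init_linear_uniform(lin, 0.1)
cell = nn.GRUCell(4, 5)
init_rnn_orthogonal(cell, 0.1)
w = cell.weight_hh.detach()  # (15, 5): the three GRU gate matrices stacked along dim 0
```

orthogonal_ treats the stacked gate matrix as a single matrix, so its columns are orthonormal as a whole (w.T @ w is the identity) rather than each gate block being orthogonal on its own; initializing each block separately would be an alternative design.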
tatk.e2e.rnn_rollout.models.utils.make_mask(n, marked, value=-1000): Create a masked tensor.
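make_mask(n, marked, value=-1000) is documented only as creating a masked tensor. One plausible reading, sketched under that assumption: a length-n float vector that is 0 everywhere except at the marked indices, which hold value, so that adding it to logits before a softmax drives the marked options' probabilities to effectively zero. make_mask_sketch is a hypothetical re-implementation, not the tatk source.

```python
import torch
import torch.nn.functional as F


def make_mask_sketch(n, marked, value=-1000):
    """Sketch: length-n tensor with `value` at the marked indices, 0 elsewhere."""
    mask = torch.zeros(n)
    for i in marked:
        mask[i] = value
    return mask


# Usage example: hide options 1 and 3 from a uniform distribution over 5 options.
mask = make_mask_sketch(5, [1, 3])
probs = F.softmax(torch.zeros(5) + mask, dim=0)
```

Here the two marked entries receive probability 0 (exp(-1000) underflows in float32) and the remaining three share the mass equally.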