tatk.e2e.rnn_rollout.models package

Submodules

tatk.e2e.rnn_rollout.models.attn module

class tatk.e2e.rnn_rollout.models.attn.Attention(query_size, value_size, hid_size, init_range)

Bases: torch.nn.modules.module.Module

__init__(query_size, value_size, hid_size, init_range)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(q, v, mask=None)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
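The note above applies to every forward() entry in this package. The sketch below is not PyTorch code: TinyModule and Doubler are hypothetical stand-ins that model only the split between calling an instance and calling forward() directly, with a simplified hook signature. It illustrates why a hook registered on a module fires in one case and not the other.

```python
class TinyModule:
    """Toy stand-in for torch.nn.Module; only the call/forward split is modeled."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        # Hooks are stored on the instance and run by __call__, not by forward().
        self._forward_hooks.append(hook)

    def forward(self, x):
        raise NotImplementedError

    def __call__(self, x):
        out = self.forward(x)
        for hook in self._forward_hooks:
            hook(self, x, out)
        return out


class Doubler(TinyModule):
    """Hypothetical subclass that only overrides forward()."""

    def forward(self, x):
        return 2 * x


seen = []
m = Doubler()
m.register_forward_hook(lambda module, inp, out: seen.append(out))

called = m(3)          # goes through __call__, so the hook fires
direct = m.forward(3)  # bypasses __call__, so the hook is skipped
```

Both calls return the same value, but only the first one is recorded in seen. With a real model from this package, the same reasoning suggests writing model(...) rather than model.forward(...).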

forward_attn(h)

class tatk.e2e.rnn_rollout.models.attn.BiRnnAttention(query_size, value_size, hid_size, dropout, init_range)

Bases: tatk.e2e.rnn_rollout.models.attn.ChunkedAttention

__init__(query_size, value_size, hid_size, dropout, init_range)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

flatten_parameters()
forward(query, fwd_inpts, fwd_lens, rev_idxs, hid_idxs)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

forward_attn(query, fwd_hs, bwd_hs, lens)
forward_rnn(rnn, inpts, lens, hid_idxs)

class tatk.e2e.rnn_rollout.models.attn.ChunkedAttention(query_size, value_size, hid_size, init_range)

Bases: torch.nn.modules.module.Module

__init__(query_size, value_size, hid_size, init_range)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

make_mask(fwd_hs, lens)
reverse(fwd_inpts, fwd_lens, rev_idxs)
zero_h(bsz, n=1)

class tatk.e2e.rnn_rollout.models.attn.HierarchicalAttention(query_size, value_size, hid_size, dropout, init_range)

Bases: tatk.e2e.rnn_rollout.models.attn.ChunkedAttention

__init__(query_size, value_size, hid_size, dropout, init_range)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

flatten_parameters()
forward(query, fwd_inpts, fwd_lens, rev_idxs, hid_idxs)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

forward_sent_attn(query, word_hs)
forward_word_attn(query, fwd_word_hs, bwd_word_hs, ln, rev_idx, hid_idx)
forward_word_rnn(rnn, bsz, inpts, lens, rev_idxs, hid_idxs)

class tatk.e2e.rnn_rollout.models.attn.KeyValueAttention(query_size, key_size, value_size, hid_size, init_range)

Bases: torch.nn.modules.module.Module

__init__(query_size, key_size, value_size, hid_size, init_range)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(q, k, v, mask=None)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

forward_attn(h)

class tatk.e2e.rnn_rollout.models.attn.MaskedAttention(query_size, value_size, hid_size, init_range)

Bases: torch.nn.modules.module.Module

__init__(query_size, value_size, hid_size, init_range)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(q, v, ln=None)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

make_mask(v, ln)

class tatk.e2e.rnn_rollout.models.attn.SentenceAttention(query_size, value_size, hid_size, dropout, init_range)

Bases: tatk.e2e.rnn_rollout.models.attn.ChunkedAttention

__init__(query_size, value_size, hid_size, dropout, init_range)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

flatten_parameters()
forward(query, fwd_inpt, fwd_len, rev_idx, hid_idx)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

forward_word_attn(query, fwd_word_hs, bwd_word_hs, ln, rev_idx, hid_idx)
forward_word_rnn(rnn, bsz, inpts, lens, rev_idxs, hid_idxs)

tatk.e2e.rnn_rollout.models.ctx_encoder module

Set of context encoders.

class tatk.e2e.rnn_rollout.models.ctx_encoder.MlpContextEncoder(n, k, nembed, nhid, dropout, init_range, skip_values=False)

Bases: torch.nn.modules.module.Module

Simple encoder for the dialogue context. Encodes counts and values via MLP.

__init__(n, k, nembed, nhid, dropout, init_range, skip_values=False)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(ctx)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

tatk.e2e.rnn_rollout.models.latent_clustering_model module

class tatk.e2e.rnn_rollout.models.latent_clustering_model.BaselineClusteringModel(word_dict, item_dict, context_dict, count_dict, args)

Bases: torch.nn.modules.module.Module

__init__(word_dict, item_dict, context_dict, count_dict, args)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

corpus_ty

alias of tatk.e2e.rnn_rollout.data.SentenceCorpus

decode_sentence(inpt_emb, cond_h)
embed_sentence(inpt)
engine_ty

alias of tatk.e2e.rnn_rollout.engines.latent_clustering_engine.BaselineClusteringEngine

flatten_parameters()
forward(inpts, tgts, hid_idxs, ctx, cnt)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

forward_decoder(inpt, cond_h)
forward_encoder(inpt, hid_idx)
forward_marginal_loss(p_z, inpt, tgt, cnt)
forward_memory(mem_h, inpt, hid_idx)
init_memory(ctx_h)
read(inpt, mem_h)
unembed_sentence(dec_hs)
word2var(word)
write(cond_h, max_words, temperature, start_token='YOU:', stop_tokens=['<eos>', '<selection>'])

class tatk.e2e.rnn_rollout.models.latent_clustering_model.LatentClusteringLanguageModel(word_dict, item_dict, context_dict, count_dict, args)

Bases: torch.nn.modules.module.Module

__init__(word_dict, item_dict, context_dict, count_dict, args)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

corpus_ty

alias of tatk.e2e.rnn_rollout.data.SentenceCorpus

decode_sentence(inpt_emb, cond_h)
embed_sentence(inpt)
encode_sentence(inpt_emb, hid_idx, enc_h)
engine_ty

alias of tatk.e2e.rnn_rollout.engines.latent_clustering_engine.LatentClusteringLanguageEngine

flatten_parameters()
forward(inpts, tgts, sel_tgt_probs, hid_idxs, ctx, cnt)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

forward_decoder(inpt, cond_h)
forward_encoder(inpt, hid_idx, enc_h)
unembed_sentence(dec_hs)
word2var(word)
write(cond_h, max_words, temperature, start_token='YOU:', stop_tokens=['<eos>', '<selection>'])
zero_grad()

Sets gradients of all model parameters to zero.

class tatk.e2e.rnn_rollout.models.latent_clustering_model.LatentClusteringModel(word_dict, item_dict, context_dict, count_dict, args)

Bases: torch.nn.modules.module.Module

__init__(word_dict, item_dict, context_dict, count_dict, args)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

corpus_ty

alias of tatk.e2e.rnn_rollout.data.SentenceCorpus

decode_sentence(inpt_emb, mem_h)
embed_sentence(inpt)
encode_sentence(inpt_emb, hid_idx)
engine_ty

alias of tatk.e2e.rnn_rollout.engines.latent_clustering_engine.LatentClusteringEngine

flatten_parameters()
forward(inpts, tgts, sel_tgt_probs, hid_idxs, ctx, cnt)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

forward_decoder(inpt, mem_h)
forward_e_step(z_prob, mem_h, inpt, tgt, sel_tgt_prob, cnt)
forward_encoder(ctx_h, inpt, hid_idx)
unembed_sentence(dec_hs)

class tatk.e2e.rnn_rollout.models.latent_clustering_model.LatentClusteringPredictionModel(word_dict, item_dict, context_dict, count_dict, args)

Bases: torch.nn.modules.module.Module

__init__(word_dict, item_dict, context_dict, count_dict, args)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

clear_weights()
corpus_ty

alias of tatk.e2e.rnn_rollout.data.SentenceCorpus

embed_sentence(inpt)
encode_sentence(inpt_emb, hid_idx)
engine_ty

alias of tatk.e2e.rnn_rollout.engines.latent_clustering_engine.LatentClusteringPredictionEngine

flatten_parameters()
forward(inpts, tgts, hid_idxs, ctx, cnt)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

forward_encoder(ctx_h, inpt=None, hid_idx=None)
forward_kldiv(inpts, tgts, sel_tgt_probs, hid_idxs, ctx, cnt)
forward_marginal_loss(q_z, lang_enc_h, inpt, tgt, cnt)
forward_memory(ctx_h, mem_h=None, inpt=None, hid_idx=None)
forward_prediction(cnt, mem_h, sample=False)
forward_prediction_multi(cnt, mem_h, num_samples, sample=False)
forward_validation(inpts, tgts, hid_idxs, ctx, cnt)
forward_validation_marginal(inpts, tgts, hid_idxs, ctx, cnt)
read(inpt, lang_enc_h, mem_h, ctx_h)
write(lang_enc_h, lat_h, max_words, temperature, start_token='YOU:', stop_tokens=['<eos>', '<selection>'])

class tatk.e2e.rnn_rollout.models.latent_clustering_model.RecurrentUnit(input_size, hidden_size, args)

Bases: torch.nn.modules.module.Module

__init__(input_size, hidden_size, args)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(x, h)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class tatk.e2e.rnn_rollout.models.latent_clustering_model.ShardedLatentBottleneckModule(num_shards, num_clusters, input_size, output_size, args)

Bases: torch.nn.modules.module.Module

__init__(num_shards, num_clusters, input_size, output_size, args)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(shard, key)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

select(shard, idx)
select_shard(shard)
zero_grad()

Sets gradients of all model parameters to zero.

class tatk.e2e.rnn_rollout.models.latent_clustering_model.SimpleSeparateSelectionModule(input_size, hidden_size, output_size, args)

Bases: torch.nn.modules.module.Module

__init__(input_size, hidden_size, output_size, args)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

flatten_parameters()
forward(h)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

tatk.e2e.rnn_rollout.models.modules module

Helper functions for module initialization.

class tatk.e2e.rnn_rollout.models.modules.CudaModule(device_id)

Bases: torch.nn.modules.module.Module

A helper to run a module on a particular device using CUDA.

__init__(device_id)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

to_device(m)

class tatk.e2e.rnn_rollout.models.modules.MlpContextEncoder(n, k, nembed, nhid, init_range, device_id)

Bases: tatk.e2e.rnn_rollout.models.modules.CudaModule

A module that encodes dialogue context using an MLP.

__init__(n, k, nembed, nhid, init_range, device_id)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(ctx)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class tatk.e2e.rnn_rollout.models.modules.RnnContextEncoder(n, k, nembed, nhid, init_range, device_id)

Bases: tatk.e2e.rnn_rollout.models.modules.CudaModule

A module that encodes dialogue context using an RNN.

__init__(n, k, nembed, nhid, init_range, device_id)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(ctx)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

tatk.e2e.rnn_rollout.models.modules.init_cont(cont, init_range)

Initializes a container uniformly.

tatk.e2e.rnn_rollout.models.modules.init_rnn(rnn, init_range, weights=None, biases=None)

Initializes RNN uniformly.

tatk.e2e.rnn_rollout.models.modules.init_rnn_cell(rnn, init_range)

Initializes RNNCell uniformly.

tatk.e2e.rnn_rollout.models.rnn_model module

class tatk.e2e.rnn_rollout.models.rnn_model.RnnModel(word_dict, item_dict, context_dict, count_dict, args)

Bases: torch.nn.modules.module.Module

__init__(word_dict, item_dict, context_dict, count_dict, args)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

corpus_ty

alias of tatk.e2e.rnn_rollout.data.WordCorpus

engine_ty

alias of tatk.e2e.rnn_rollout.engines.rnn_engine.RnnEngine

flatten_parameters()
forward(inpt, ctx)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

forward_context(ctx)
forward_lm(inpt_emb, lang_h, ctx_h)
forward_selection(inpt_emb, lang_h, ctx_h)
generate_choice_logits(inpt, lang_h, ctx_h)
init_weights()
read(inpt, lang_h, ctx_h, prefix_token='THEM:')
score_sent(sent, lang_h, ctx_h, temperature)
word2var(word)
write(lang_h, ctx_h, max_words, temperature, stop_tokens=['<eos>', '<selection>'], resume=False)

Generate a sentence word by word and feed the output of the previous timestep as input to the next.
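The loop described above (sample a word, then feed it back in as the next input) can be sketched without the model itself. Everything in this block is an assumption made for illustration: write_sketch, sample_next, toy_scores and the TOY_SCORES table are hypothetical names, a lookup table replaces the hidden states lang_h and ctx_h, and including the stop token in the returned list is a choice of this sketch rather than documented behavior.

```python
import math
import random

# Hypothetical next-word scores standing in for the model's output layer.
TOY_SCORES = {
    "YOU:": {"i": 2.0, "you": 1.0},
    "i": {"want": 2.0, "need": 1.0},
    "you": {"want": 1.0, "<eos>": 1.0},
    "want": {"books": 1.5, "<eos>": 0.5},
    "need": {"hats": 1.0, "<eos>": 1.0},
    "books": {"<eos>": 2.0, "<selection>": 1.0},
    "hats": {"<eos>": 1.0},
}


def toy_scores(prev_word):
    """Return unnormalized scores for the word following prev_word."""
    return TOY_SCORES.get(prev_word, {"<eos>": 0.0})


def sample_next(scores, temperature, rng):
    """Sample one word from softmax(scores / temperature)."""
    words = list(scores)
    scaled = [scores[w] / temperature for w in words]
    top = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(s - top) for s in scaled]
    return rng.choices(words, weights=weights, k=1)[0]


def write_sketch(next_scores, first_word, max_words, temperature,
                 stop_tokens=("<eos>", "<selection>"), seed=0):
    """Generate words one at a time, feeding each output back as input."""
    rng = random.Random(seed)
    prev, out = first_word, []
    for _ in range(max_words):
        word = sample_next(next_scores(prev), temperature, rng)
        out.append(word)
        if word in stop_tokens:
            break
        prev = word  # the previous output becomes the next input
    return out


sentence = write_sketch(toy_scores, "YOU:", max_words=6, temperature=0.5)
```

A lower temperature sharpens the distribution toward the highest-scoring word, and a higher one flattens it. The loop ends either at a stop token or after max_words words, whichever comes first.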

write_batch(bsz, lang_h, ctx_h, temperature, max_words=100)
zero_h(bsz, nhid=None, copies=None)

tatk.e2e.rnn_rollout.models.selection_model module

class tatk.e2e.rnn_rollout.models.selection_model.SelectionModel(word_dict, item_dict, context_dict, count_dict, args)

Bases: torch.nn.modules.module.Module

__init__(word_dict, item_dict, context_dict, count_dict, args)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

corpus_ty

alias of tatk.e2e.rnn_rollout.data.SentenceCorpus

engine_ty

alias of tatk.e2e.rnn_rollout.engines.selection_engine.SelectionEngine

flatten_parameters()
forward(inpts, lens, rev_idxs, hid_idxs, ctx)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

forward_each_timestamp(inpts, lens, rev_idxs, hid_idxs, ctx)
forward_inpts(inpts, ctx_h)

class tatk.e2e.rnn_rollout.models.selection_model.SelectionModule(query_size, value_size, hidden_size, selection_size, num_heads, output_size, args)

Bases: torch.nn.modules.module.Module

__init__(query_size, value_size, hidden_size, selection_size, num_heads, output_size, args)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

flatten_parameters()
forward(q, hs, lens, rev_idxs, hid_idxs)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

tatk.e2e.rnn_rollout.models.utils module

A set of useful tools.

tatk.e2e.rnn_rollout.models.utils.init_cont(cont, init_range)

Uniform initialization of a container.

tatk.e2e.rnn_rollout.models.utils.init_linear(linear, init_range)

Uniform initialization of Linear.

tatk.e2e.rnn_rollout.models.utils.init_rnn(rnn, init_range, weights=None, biases=None)

Orthogonal initialization of RNN.

tatk.e2e.rnn_rollout.models.utils.init_rnn_cell(rnn, init_range)

Orthogonal initialization of RNNCell.

tatk.e2e.rnn_rollout.models.utils.make_mask(n, marked, value=-1000)

Create a masked tensor.
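The one-line description leaves the shape and use of the mask open. The sketch below assumes a common pattern: the mask holds zeros except at the marked positions, which hold value, and it is added to a score vector before a softmax so that the marked entries receive negligible probability. It uses plain lists instead of tensors, and make_mask_sketch and softmax are hypothetical helpers, not functions from this package.

```python
import math


def make_mask_sketch(n, marked, value=-1000):
    """Return a length-n list of zeros with `value` at each marked index."""
    mask = [0.0] * n
    for i in marked:
        mask[i] = value
    return mask


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    top = max(xs)
    exps = [math.exp(x - top) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


logits = [1.0, 2.0, 3.0, 4.0]
mask = make_mask_sketch(len(logits), marked=[1, 3])

# Adding the mask pushes the marked scores far below the others,
# so their softmax probability underflows to zero.
probs = softmax([score + m for score, m in zip(logits, mask)])
```

Under this assumption the default of -1000 works as a finite stand-in for negative infinity: it is large enough to zero out the marked entries while avoiding NaN results if every position happens to be masked.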