convlab2.e2e.rnn_rollout.models package

Submodules

convlab2.e2e.rnn_rollout.models.attn module

class convlab2.e2e.rnn_rollout.models.attn.Attention(query_size, value_size, hid_size, init_range)

Bases: torch.nn.modules.module.Module

forward(q, v, mask=None)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

forward_attn(h)
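
A minimal sketch of the query–value attention such a module presumably implements. The projection layout, tensor shapes, and scoring scheme below are assumptions for illustration, not taken from this page:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # Hypothetical sketch of query-value attention; shapes and scoring are assumptions.
    class ToyAttention(nn.Module):
        def __init__(self, query_size, value_size, hid_size, init_range):
            super().__init__()
            self.q_proj = nn.Linear(query_size, hid_size)
            self.v_proj = nn.Linear(value_size, hid_size)
            self.score = nn.Linear(hid_size, 1)
            for p in self.parameters():
                p.data.uniform_(-init_range, init_range)

        def forward(self, q, v, mask=None):
            # q: (batch, query_size); v: (seq_len, batch, value_size)
            h = torch.tanh(self.q_proj(q).unsqueeze(0) + self.v_proj(v))
            logit = self.score(h).squeeze(-1)           # (seq_len, batch)
            if mask is not None:
                logit = logit.masked_fill(~mask, -1e9)  # drop padded positions
            p = F.softmax(logit, dim=0)                 # weights over timesteps
            out = (p.unsqueeze(-1) * v).sum(0)          # (batch, value_size)
            return out, p

    q = torch.randn(4, 8)
    v = torch.randn(5, 4, 8)
    out, p = ToyAttention(8, 8, 16, 0.1)(q, v)
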
class convlab2.e2e.rnn_rollout.models.attn.BiRnnAttention(query_size, value_size, hid_size, dropout, init_range)

Bases: convlab2.e2e.rnn_rollout.models.attn.ChunkedAttention

flatten_parameters()
forward(query, fwd_inpts, fwd_lens, rev_idxs, hid_idxs)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

forward_attn(query, fwd_hs, bwd_hs, lens)
forward_rnn(rnn, inpts, lens, hid_idxs)
class convlab2.e2e.rnn_rollout.models.attn.ChunkedAttention(query_size, value_size, hid_size, init_range)

Bases: torch.nn.modules.module.Module

make_mask(fwd_hs, lens)
reverse(fwd_inpts, fwd_lens, rev_idxs)
zero_h(bsz, n=1)
class convlab2.e2e.rnn_rollout.models.attn.HierarchicalAttention(query_size, value_size, hid_size, dropout, init_range)

Bases: convlab2.e2e.rnn_rollout.models.attn.ChunkedAttention

flatten_parameters()
forward(query, fwd_inpts, fwd_lens, rev_idxs, hid_idxs)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

forward_sent_attn(query, word_hs)
forward_word_attn(query, fwd_word_hs, bwd_word_hs, ln, rev_idx, hid_idx)
forward_word_rnn(rnn, bsz, inpts, lens, rev_idxs, hid_idxs)
class convlab2.e2e.rnn_rollout.models.attn.KeyValueAttention(query_size, key_size, value_size, hid_size, init_range)

Bases: torch.nn.modules.module.Module

forward(q, k, v, mask=None)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

forward_attn(h)
class convlab2.e2e.rnn_rollout.models.attn.MaskedAttention(query_size, value_size, hid_size, init_range)

Bases: torch.nn.modules.module.Module

forward(q, v, ln=None)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

make_mask(v, ln)
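
make_mask(v, ln) presumably builds a mask from the per-example lengths ln so that padded timesteps are excluded from the attention. A generic sketch of that idea; the additive -1000 value and the (seq_len, batch) shape convention are assumptions:

    import torch

    # Hypothetical sketch: build an additive mask from per-example lengths.
    # Positions beyond each example's length get a large negative value so
    # that softmax assigns them ~zero weight.
    def length_mask(seq_len, lens, value=-1000.0):
        # lens: (batch,) lengths; returns a (seq_len, batch) additive mask
        pos = torch.arange(seq_len).unsqueeze(1)        # (seq_len, 1)
        return (pos >= lens.unsqueeze(0)).float() * value

    mask = length_mask(5, torch.tensor([3, 5, 2]))
    # logits = logits + mask  # applied before softmax over the time dimension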
class convlab2.e2e.rnn_rollout.models.attn.SentenceAttention(query_size, value_size, hid_size, dropout, init_range)

Bases: convlab2.e2e.rnn_rollout.models.attn.ChunkedAttention

flatten_parameters()
forward(query, fwd_inpt, fwd_len, rev_idx, hid_idx)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

forward_word_attn(query, fwd_word_hs, bwd_word_hs, ln, rev_idx, hid_idx)
forward_word_rnn(rnn, bsz, inpts, lens, rev_idxs, hid_idxs)

convlab2.e2e.rnn_rollout.models.ctx_encoder module

Set of context encoders.

class convlab2.e2e.rnn_rollout.models.ctx_encoder.MlpContextEncoder(n, k, nembed, nhid, dropout, init_range, skip_values=False)

Bases: torch.nn.modules.module.Module

Simple encoder for the dialogue context. Encodes counts and values via an MLP.

forward(ctx)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
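
A hedged sketch of what "encodes counts and values via an MLP" could look like: counts and values are embedded separately, concatenated, and passed through an MLP. The reading of n and k (context vocabulary size and number of context tokens), the interleaved count/value layout, and the tensor shapes are assumptions:

    import torch
    import torch.nn as nn

    # Hypothetical sketch of a count/value context encoder; not the actual implementation.
    class ToyCtxEncoder(nn.Module):
        def __init__(self, n, k, nembed, nhid, dropout, init_range):
            super().__init__()
            self.cnt_emb = nn.Embedding(n, nembed)   # embeds item counts
            self.val_emb = nn.Embedding(n, nembed)   # embeds item values
            self.mlp = nn.Sequential(
                nn.Dropout(dropout),
                nn.Linear(k // 2 * nembed * 2, nhid),
                nn.Tanh(),
            )
            for p in self.parameters():
                p.data.uniform_(-init_range, init_range)

        def forward(self, ctx):
            # ctx: (k, batch) of interleaved count/value ids (assumed layout)
            cnt, val = ctx[0::2], ctx[1::2]                       # (k/2, batch) each
            h = torch.cat([self.cnt_emb(cnt), self.val_emb(val)], dim=-1)
            return self.mlp(h.transpose(0, 1).reshape(ctx.size(1), -1))

    ctx = torch.randint(0, 11, (6, 2))                 # 3 item types, batch of 2
    ctx_h = ToyCtxEncoder(11, 6, 16, 32, 0.5, 0.1)(ctx)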

convlab2.e2e.rnn_rollout.models.latent_clustering_model module

class convlab2.e2e.rnn_rollout.models.latent_clustering_model.BaselineClusteringModel(word_dict, item_dict, context_dict, count_dict, args)

Bases: torch.nn.modules.module.Module

corpus_ty

alias of convlab2.e2e.rnn_rollout.data.SentenceCorpus

decode_sentence(inpt_emb, cond_h)
embed_sentence(inpt)
engine_ty

alias of convlab2.e2e.rnn_rollout.engines.latent_clustering_engine.BaselineClusteringEngine

flatten_parameters()
forward(inpts, tgts, hid_idxs, ctx, cnt)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

forward_decoder(inpt, cond_h)
forward_encoder(inpt, hid_idx)
forward_marginal_loss(p_z, inpt, tgt, cnt)
forward_memory(mem_h, inpt, hid_idx)
init_memory(ctx_h)
read(inpt, mem_h)
unembed_sentence(dec_hs)
word2var(word)
write(cond_h, max_words, temperature, start_token='YOU:', stop_tokens=['<eos>', '<selection>'])
class convlab2.e2e.rnn_rollout.models.latent_clustering_model.LatentClusteringLanguageModel(word_dict, item_dict, context_dict, count_dict, args)

Bases: torch.nn.modules.module.Module

corpus_ty

alias of convlab2.e2e.rnn_rollout.data.SentenceCorpus

decode_sentence(inpt_emb, cond_h)
embed_sentence(inpt)
encode_sentence(inpt_emb, hid_idx, enc_h)
engine_ty

alias of convlab2.e2e.rnn_rollout.engines.latent_clustering_engine.LatentClusteringLanguageEngine

flatten_parameters()
forward(inpts, tgts, sel_tgt_probs, hid_idxs, ctx, cnt)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

forward_decoder(inpt, cond_h)
forward_encoder(inpt, hid_idx, enc_h)
unembed_sentence(dec_hs)
word2var(word)
write(cond_h, max_words, temperature, start_token='YOU:', stop_tokens=['<eos>', '<selection>'])
zero_grad()

Sets gradients of all model parameters to zero.

class convlab2.e2e.rnn_rollout.models.latent_clustering_model.LatentClusteringModel(word_dict, item_dict, context_dict, count_dict, args)

Bases: torch.nn.modules.module.Module

corpus_ty

alias of convlab2.e2e.rnn_rollout.data.SentenceCorpus

decode_sentence(inpt_emb, mem_h)
embed_sentence(inpt)
encode_sentence(inpt_emb, hid_idx)
engine_ty

alias of convlab2.e2e.rnn_rollout.engines.latent_clustering_engine.LatentClusteringEngine

flatten_parameters()
forward(inpts, tgts, sel_tgt_probs, hid_idxs, ctx, cnt)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

forward_decoder(inpt, mem_h)
forward_e_step(z_prob, mem_h, inpt, tgt, sel_tgt_prob, cnt)
forward_encoder(ctx_h, inpt, hid_idx)
unembed_sentence(dec_hs)
class convlab2.e2e.rnn_rollout.models.latent_clustering_model.LatentClusteringPredictionModel(word_dict, item_dict, context_dict, count_dict, args)

Bases: torch.nn.modules.module.Module

clear_weights()
corpus_ty

alias of convlab2.e2e.rnn_rollout.data.SentenceCorpus

embed_sentence(inpt)
encode_sentence(inpt_emb, hid_idx)
engine_ty

alias of convlab2.e2e.rnn_rollout.engines.latent_clustering_engine.LatentClusteringPredictionEngine

flatten_parameters()
forward(inpts, tgts, hid_idxs, ctx, cnt)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

forward_encoder(ctx_h, inpt=None, hid_idx=None)
forward_kldiv(inpts, tgts, sel_tgt_probs, hid_idxs, ctx, cnt)
forward_marginal_loss(q_z, lang_enc_h, inpt, tgt, cnt)
forward_memory(ctx_h, mem_h=None, inpt=None, hid_idx=None)
forward_prediction(cnt, mem_h, sample=False)
forward_prediction_multi(cnt, mem_h, num_samples, sample=False)
forward_validation(inpts, tgts, hid_idxs, ctx, cnt)
forward_validation_marginal(inpts, tgts, hid_idxs, ctx, cnt)
read(inpt, lang_enc_h, mem_h, ctx_h)
write(lang_enc_h, lat_h, max_words, temperature, start_token='YOU:', stop_tokens=['<eos>', '<selection>'])
class convlab2.e2e.rnn_rollout.models.latent_clustering_model.RecurrentUnit(input_size, hidden_size, args)

Bases: torch.nn.modules.module.Module

forward(x, h)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class convlab2.e2e.rnn_rollout.models.latent_clustering_model.ShardedLatentBottleneckModule(num_shards, num_clusters, input_size, output_size, args)

Bases: torch.nn.modules.module.Module

forward(shard, key)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

select(shard, idx)
select_shard(shard)
zero_grad()

Sets gradients of all model parameters to zero.
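
Judging from the constructor (num_shards, num_clusters, input_size, output_size) and the select/select_shard methods, this module presumably keeps a table of num_clusters latent vectors per shard and scores them against a key. A purely illustrative sketch of that idea; everything below is an assumption, not the actual module:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    # Hypothetical sketch of a sharded latent codebook; not the documented class.
    class ToyShardedBottleneck(nn.Module):
        def __init__(self, num_shards, num_clusters, input_size, output_size):
            super().__init__()
            # one (num_clusters, output_size) table of latent vectors per shard
            self.latent = nn.Parameter(torch.randn(num_shards, num_clusters, output_size))
            self.key_proj = nn.Linear(input_size, output_size)

        def select_shard(self, shard):
            return self.latent[shard]            # (batch, num_clusters, output_size)

        def forward(self, shard, key):
            table = self.select_shard(shard)
            scores = torch.bmm(table, self.key_proj(key).unsqueeze(-1)).squeeze(-1)
            return F.softmax(scores, dim=-1)     # distribution over clusters

    m = ToyShardedBottleneck(num_shards=4, num_clusters=8, input_size=16, output_size=32)
    p_z = m(torch.tensor([0, 3]), torch.randn(2, 16))   # (2, 8)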

class convlab2.e2e.rnn_rollout.models.latent_clustering_model.SimpleSeparateSelectionModule(input_size, hidden_size, output_size, args)

Bases: torch.nn.modules.module.Module

flatten_parameters()
forward(h)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

convlab2.e2e.rnn_rollout.models.modules module

Helper functions for module initialization.

class convlab2.e2e.rnn_rollout.models.modules.CudaModule(device_id)

Bases: torch.nn.modules.module.Module

A helper to run a module on a particular device using CUDA.

to_device(m)
class convlab2.e2e.rnn_rollout.models.modules.MlpContextEncoder(n, k, nembed, nhid, init_range, device_id)

Bases: convlab2.e2e.rnn_rollout.models.modules.CudaModule

A module that encodes the dialogue context using an MLP.

forward(ctx)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class convlab2.e2e.rnn_rollout.models.modules.RnnContextEncoder(n, k, nembed, nhid, init_range, device_id)

Bases: convlab2.e2e.rnn_rollout.models.modules.CudaModule

A module that encodes the dialogue context using an RNN.

forward(ctx)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

convlab2.e2e.rnn_rollout.models.modules.init_cont(cont, init_range)

Initializes a container uniformly.

convlab2.e2e.rnn_rollout.models.modules.init_rnn(rnn, init_range, weights=None, biases=None)

Initializes RNN uniformly.

convlab2.e2e.rnn_rollout.models.modules.init_rnn_cell(rnn, init_range)

Initializes RNNCell uniformly.
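
A hedged sketch of the uniform initialization these helpers describe; the loop over named RNN parameters is an assumption about their behavior, not the documented functions themselves:

    import torch.nn as nn

    # Hypothetical sketch of uniform RNN initialization.
    def toy_init_rnn(rnn, init_range, weights=None, biases=None):
        weights = weights or ['weight_ih_l0', 'weight_hh_l0']
        biases = biases or ['bias_ih_l0', 'bias_hh_l0']
        for name in weights:
            getattr(rnn, name).data.uniform_(-init_range, init_range)
        for name in biases:
            getattr(rnn, name).data.fill_(0)

    rnn = nn.GRU(input_size=8, hidden_size=16)
    toy_init_rnn(rnn, init_range=0.1)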

convlab2.e2e.rnn_rollout.models.rnn_model module

class convlab2.e2e.rnn_rollout.models.rnn_model.RnnModel(word_dict, item_dict, context_dict, count_dict, args)

Bases: torch.nn.modules.module.Module

corpus_ty

alias of convlab2.e2e.rnn_rollout.data.WordCorpus

engine_ty

alias of convlab2.e2e.rnn_rollout.engines.rnn_engine.RnnEngine

flatten_parameters()
forward(inpt, ctx)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

forward_context(ctx)
forward_lm(inpt_emb, lang_h, ctx_h)
forward_selection(inpt_emb, lang_h, ctx_h)
generate_choice_logits(inpt, lang_h, ctx_h)
init_weights()
read(inpt, lang_h, ctx_h, prefix_token='THEM:')
score_sent(sent, lang_h, ctx_h, temperature)
word2var(word)
write(lang_h, ctx_h, max_words, temperature, stop_tokens=['<eos>', '<selection>'], resume=False)

Generate a sentence word by word, feeding the output of the previous timestep back in as the input at the next step (see the decoding sketch after this class listing).

write_batch(bsz, lang_h, ctx_h, temperature, max_words=100)
zero_h(bsz, nhid=None, copies=None)
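
The write() docstring above describes autoregressive decoding with temperature sampling and stop tokens. A minimal sketch of such a loop; the step function, vocabulary, and interfaces are hypothetical, not the RnnModel API:

    import torch
    import torch.nn.functional as F

    # Hypothetical autoregressive sampling loop mirroring the idea described in write().
    def sample_sentence(step_fn, lang_h, start_id, stop_ids, max_words, temperature):
        # step_fn(token_id, lang_h) -> (logits over vocab, new lang_h); assumed signature
        words, tok = [], start_id
        for _ in range(max_words):
            logits, lang_h = step_fn(tok, lang_h)
            probs = F.softmax(logits / temperature, dim=-1)
            tok = torch.multinomial(probs, 1).item()   # feed the sample back in
            words.append(tok)
            if tok in stop_ids:
                break
        return words, lang_h

    # toy step function: a fixed random "language model" over a 10-word vocabulary
    W = torch.randn(10, 10)
    toy_step = lambda tok, h: (W[tok], h)
    print(sample_sentence(toy_step, None, start_id=0, stop_ids={9}, max_words=20, temperature=0.5))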

convlab2.e2e.rnn_rollout.models.selection_model module

class convlab2.e2e.rnn_rollout.models.selection_model.SelectionModel(word_dict, item_dict, context_dict, count_dict, args)

Bases: torch.nn.modules.module.Module

corpus_ty

alias of convlab2.e2e.rnn_rollout.data.SentenceCorpus

engine_ty

alias of convlab2.e2e.rnn_rollout.engines.selection_engine.SelectionEngine

flatten_parameters()
forward(inpts, lens, rev_idxs, hid_idxs, ctx)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

forward_each_timestamp(inpts, lens, rev_idxs, hid_idxs, ctx)
forward_inpts(inpts, ctx_h)
class convlab2.e2e.rnn_rollout.models.selection_model.SelectionModule(query_size, value_size, hidden_size, selection_size, num_heads, output_size, args)

Bases: torch.nn.modules.module.Module

flatten_parameters()
forward(q, hs, lens, rev_idxs, hid_idxs)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

convlab2.e2e.rnn_rollout.models.utils module

A set of useful tools.

convlab2.e2e.rnn_rollout.models.utils.init_cont(cont, init_range)

Uniform initialization of a container.

convlab2.e2e.rnn_rollout.models.utils.init_linear(linear, init_range)

Uniform initialization of Linear.

convlab2.e2e.rnn_rollout.models.utils.init_rnn(rnn, init_range, weights=None, biases=None)

Orthogonal initialization of RNN.

convlab2.e2e.rnn_rollout.models.utils.init_rnn_cell(rnn, init_range)

Orthogonal initialization of RNNCell.

convlab2.e2e.rnn_rollout.models.utils.make_mask(n, marked, value=-1000)

Create a masked tensor.
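
A sketch consistent with this signature: a length-n additive mask holding `value` at the marked indices and 0 elsewhere. The exact orientation is an assumption:

    import torch

    # Hypothetical sketch matching make_mask(n, marked, value=-1000).
    def toy_make_mask(n, marked, value=-1000):
        mask = torch.zeros(n)
        mask[list(marked)] = value
        return mask

    mask = toy_make_mask(5, [0, 3])   # tensor([-1000., 0., 0., -1000., 0.])
    # logits = logits + mask          # e.g. to forbid certain output words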

Module contents

convlab2.e2e.rnn_rollout.models.get_model_names()
convlab2.e2e.rnn_rollout.models.get_model_type(name)
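
A hedged usage sketch, assuming get_model_names() returns the registered model names and get_model_type(name) returns the corresponding model class; the 'rnn_model' key is an assumption:

    from convlab2.e2e.rnn_rollout import models

    print(models.get_model_names())                   # list the registered model names
    model_cls = models.get_model_type('rnn_model')    # assumed registry key
    # model = model_cls(word_dict, item_dict, context_dict, count_dict, args)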