convlab2.e2e.rnn_rollout package

Submodules

convlab2.e2e.rnn_rollout.agent module

class convlab2.e2e.rnn_rollout.agent.RnnRolloutAgent(model, args, name='Alice', allow_no_agreement=True, train=False, diverse=False, max_dec_len=20)

Bases: convlab2.dialog_agent.agent.Agent

choose()
feed_context(context)
feed_partner_context(partner_context)
init_session()

Reset the class variables to prepare for a new session.

read(inpt)
response(observation)

Generate the agent's response to the user input.

The data type of the input and the response can be either str or a list of tuples, depending on the form of the agent.

Example:

If the agent is a pipeline agent with NLU, DST and Policy, then type(input) == str and type(response) == list of tuples.

Args:
observation (str or list of tuples):

The input to the agent.

Returns:
response (str or list of tuples):

The response generated by the agent.

update(agree, reward, choice=None, partner_choice=None, partner_input=None, partner_reward=None)
write(max_words=20)

convlab2.e2e.rnn_rollout.avg_rank module

convlab2.e2e.rnn_rollout.chat module

convlab2.e2e.rnn_rollout.config module

Configuration script. Stores variables and settings used across the application.

convlab2.e2e.rnn_rollout.data module

class convlab2.e2e.rnn_rollout.data.CountDictionary(init=True)

Bases: convlab2.e2e.rnn_rollout.data.Dictionary

get_idx(words)
get_key()
read_tag(file_name, tag, init_dict=False)
class convlab2.e2e.rnn_rollout.data.Dictionary(init=True)

Bases: object

add_word(word)
get_idx(word)
get_word(idx)
i2w(idx)
read_tag(file_name, tag, freq_cutoff=-1, init_dict=True)
w2i(words)
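The Dictionary methods above describe a two-way word/index vocabulary. The sketch below illustrates that idea with a hypothetical ToyDictionary and build_dictionary helper. It assumes that unknown words map to a reserved <unk> index and that freq_cutoff keeps only words seen more than freq_cutoff times (-1 keeps everything). Neither detail is confirmed by the documentation above, so treat this as an illustration rather than the ConvLab-2 implementation.

```python
from collections import Counter


class ToyDictionary:
    """Hypothetical word <-> index vocabulary (not the ConvLab-2 class)."""

    def __init__(self, unk='<unk>'):
        self.word2idx = {}
        self.idx2word = []
        self.unk = unk
        self.add_word(unk)  # reserve index 0 for unknown words (assumption)

    def add_word(self, word):
        # Assign the next free index to an unseen word and return its index.
        if word not in self.word2idx:
            self.word2idx[word] = len(self.idx2word)
            self.idx2word.append(word)
        return self.word2idx[word]

    def get_idx(self, word):
        # Unknown words fall back to the <unk> index.
        return self.word2idx.get(word, self.word2idx[self.unk])

    def get_word(self, idx):
        return self.idx2word[idx]

    def w2i(self, words):
        return [self.get_idx(w) for w in words]

    def i2w(self, idx):
        return [self.get_word(i) for i in idx]


def build_dictionary(tokens, freq_cutoff=-1):
    """Keep only words seen more than freq_cutoff times (-1 keeps all)."""
    d = ToyDictionary()
    for word, count in Counter(tokens).items():
        if count > freq_cutoff:
            d.add_word(word)
    return d


d = build_dictionary("i want the hat the ball".split(), freq_cutoff=1)
print(d.w2i(["the", "hat"]))  # [1, 0]: "hat" was cut off, so it maps to <unk>
```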
class convlab2.e2e.rnn_rollout.data.ItemDictionary(selection_size, init=True)

Bases: convlab2.e2e.rnn_rollout.data.Dictionary

read_tag(file_name, tag, init_dict=False)
w2i(words, inv=False)
class convlab2.e2e.rnn_rollout.data.PhraseCorpus(domain, path, freq_cutoff=2, train='train.txt', valid='val.txt', test='test.txt', verbose=False, sep_sel=False)

Bases: convlab2.e2e.rnn_rollout.data.WordCorpus

tokenize(file_name)
class convlab2.e2e.rnn_rollout.data.SentenceCorpus(domain, path, freq_cutoff=2, train='train.txt', valid='val.txt', test='test.txt', verbose=False, sep_sel=False)

Bases: convlab2.e2e.rnn_rollout.data.WordCorpus

class convlab2.e2e.rnn_rollout.data.WordCorpus(domain, path, freq_cutoff=2, train='train.txt', valid='val.txt', test='test.txt', verbose=False, sep_sel=False)

Bases: object

test_dataset(bsz, shuffle=True)
tokenize(file_name)
train_dataset(bsz, shuffle=True)
valid_dataset(bsz, shuffle=True)
convlab2.e2e.rnn_rollout.data.create_dicts_from_file(domain, file_name, freq_cutoff)
convlab2.e2e.rnn_rollout.data.get_tag(tokens, tag)
convlab2.e2e.rnn_rollout.data.read_lines(file_name)
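get_tag(tokens, tag) takes a token list and a tag name. A plausible reading is that it returns the tokens enclosed by <tag> ... </tag> in a tokenized data line. The sketch below assumes data lines of the form `<input> ... </input> <dialogue> ... </dialogue>`. That line format and the example values are illustrative assumptions, not taken from the dataset.

```python
def get_tag(tokens, tag):
    """Return the tokens between '<tag>' and '</tag>' (hypothetical sketch)."""
    start = tokens.index('<' + tag + '>') + 1  # first token after the opening tag
    end = tokens.index('</' + tag + '>')       # position of the closing tag
    return tokens[start:end]


line = '<input> 1 4 4 1 1 2 </input> <dialogue> THEM: i want the hat <eos> </dialogue>'
tokens = line.split()
print(get_tag(tokens, 'input'))  # ['1', '4', '4', '1', '1', '2']
```

list.index raises ValueError if a tag is missing, which surfaces malformed lines early.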

convlab2.e2e.rnn_rollout.dialog module

class convlab2.e2e.rnn_rollout.dialog.Dialog(agents, args)

Bases: object

run(ctxs, logger, max_words=5000)
show_metrics()
class convlab2.e2e.rnn_rollout.dialog.DialogLogger(verbose=False, log_file=None, append=False)

Bases: object

CODE2ITEM = [('item0', 'book'), ('item1', 'hat'), ('item2', 'ball')]
dump(s, forced=False)
dump_agreement(agree)
dump_choice(name, choice)
dump_ctx(name, ctx)
dump_reward(name, agree, reward)
dump_sent(name, sent)
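CODE2ITEM maps internal item codes to human-readable item names, which suggests a logger uses it to print choices readably. The sketch below shows one way to do that with a hypothetical render_choice helper. It assumes choice tokens look like 'item0=1' (code, then count). That token format is an assumption, not documented above.

```python
CODE2ITEM = [('item0', 'book'), ('item1', 'hat'), ('item2', 'ball')]


def render_choice(choice):
    """Replace item codes with item names, e.g. 'item0=1' -> 'book=1'.

    Hypothetical helper; the 'itemN=count' token format is an assumption.
    """
    code2item = dict(CODE2ITEM)

    def rep(token):
        parts = token.split('=')
        if len(parts) == 2 and parts[0] in code2item:
            return '%s=%s' % (code2item[parts[0]], parts[1])
        return token  # leave other tokens (e.g. '<no_agreement>') unchanged

    return ' '.join(rep(t) for t in choice)


print(render_choice(['item0=1', 'item1=0', 'item2=3']))  # book=1 hat=0 ball=3
```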
class convlab2.e2e.rnn_rollout.dialog.DialogSelfTrainLogger(verbose=False, log_file=None)

Bases: convlab2.e2e.rnn_rollout.dialog.DialogLogger

dump_agreement(agree)
dump_choice(name, choice)
dump_ctx(name, ctx)
dump_reward(name, agree, reward)

convlab2.e2e.rnn_rollout.domain module

class convlab2.e2e.rnn_rollout.domain.Domain

Bases: object

Domain interface.

generate_choices(input)
input_length()
parse_choice(choice)
parse_context(ctx)
parse_human_choice(input, output)
score(context, choice)
score_choices(choices, ctxs)
selection_length()
class convlab2.e2e.rnn_rollout.domain.ObjectDivisionDomain

Bases: convlab2.e2e.rnn_rollout.domain.Domain

generate_choices(input, with_disagreement=True)
input_length()
num_choices()
parse_choice(choice)
parse_context(ctx)
parse_human_choice(input, output)
score(context, choice)
score_choices(choices, ctxs)
selection_length()
class convlab2.e2e.rnn_rollout.domain.ObjectTradeDomain(max_items=1)

Bases: convlab2.e2e.rnn_rollout.domain.ObjectDivisionDomain

generate_choices(input)
input_length()
parse_human_choice(input, output)
score(context, choice)
score_choices(choices, ctxs)
selection_length()
convlab2.e2e.rnn_rollout.domain.get_domain(name)

convlab2.e2e.rnn_rollout.eval_selfplay module

Script to evaluate self-play. It computes the agreement rate, the average score and Pareto optimality.
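The agreement rate and average score are plain averages over dialogues. Pareto optimality needs a search. A split is Pareto optimal when no other split makes one agent strictly better off without making the other worse off. The sketch below checks this by brute force. It assumes every item is divided (the partner gets cnts[i] - my_picks[i]) and that scores are value-weighted counts. is_pareto_optimal and score are hypothetical names, not this module's functions.

```python
from itertools import product


def score(vals, picks):
    # Value-weighted count of the items an agent receives.
    return sum(v * p for v, p in zip(vals, picks))


def is_pareto_optimal(cnts, my_vals, their_vals, my_picks):
    """True if no alternative full split dominates the given one."""
    my_score = score(my_vals, my_picks)
    their_score = score(their_vals, [c - p for c, p in zip(cnts, my_picks)])
    # Enumerate every way to take 0..cnts[i] of each item type.
    for alt in product(*(range(c + 1) for c in cnts)):
        alt_mine = score(my_vals, alt)
        alt_theirs = score(their_vals, [c - p for c, p in zip(cnts, alt)])
        no_worse = alt_mine >= my_score and alt_theirs >= their_score
        strictly_better = alt_mine > my_score or alt_theirs > their_score
        if no_worse and strictly_better:
            return False
    return True


# 1 book, 2 hats, 3 balls; I value them 4/3/0, my partner 0/2/2.
print(is_pareto_optimal([1, 2, 3], [4, 3, 0], [0, 2, 2], [1, 2, 0]))  # True
print(is_pareto_optimal([1, 2, 3], [4, 3, 0], [0, 2, 2], [0, 0, 3]))  # False
```

Brute force is affordable here because the number of splits is the product of (count + 1) over item types, which stays small for a few item types.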

convlab2.e2e.rnn_rollout.eval_selfplay.compute_score(vals, picks)

Compute the score of the selection.

convlab2.e2e.rnn_rollout.eval_selfplay.gen_choices(cnts, idx=0, choice=[])

Generate all the valid choices, for both you and your opponent.
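The signature gen_choices(cnts, idx=0, choice=[]) suggests a recursion over item types. The sketch below fixes how many units of item idx you take, recurses on the next type, and pairs each complete pick with the opponent's complement. It is a hypothetical re-implementation, not the module's code. It uses a None default instead of a list literal to avoid Python's shared-mutable-default pitfall.

```python
def gen_choices(cnts, idx=0, choice=None):
    """Recursively enumerate every split of the items.

    Returns a list of (yours, theirs) pairs, where theirs[i] = cnts[i] - yours[i].
    """
    if choice is None:
        choice = []
    if idx >= len(cnts):
        # A complete pick: the opponent receives whatever is left.
        yours = list(choice)
        theirs = [c - y for c, y in zip(cnts, yours)]
        return [(yours, theirs)]
    choices = []
    for take in range(cnts[idx] + 1):
        # Take `take` units of item type idx, then recurse on the next type.
        choices.extend(gen_choices(cnts, idx + 1, choice + [take]))
    return choices


print(gen_choices([1, 2])[:2])  # [([0, 0], [1, 2]), ([0, 1], [1, 1])]
```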

convlab2.e2e.rnn_rollout.eval_selfplay.main()
convlab2.e2e.rnn_rollout.eval_selfplay.parse_line(line, domain)
convlab2.e2e.rnn_rollout.eval_selfplay.parse_log(file_name, domain)

Parse the log file produced by self-play. See the format of that log file for details.

convlab2.e2e.rnn_rollout.metric module

class convlab2.e2e.rnn_rollout.metric.AverageMetric

Bases: convlab2.e2e.rnn_rollout.metric.NumericMetric

show()
class convlab2.e2e.rnn_rollout.metric.MetricsContainer

Bases: object

dict()
record(name, *args, **kwargs)
register_average(name, *args, **kwargs)
register_moving_average(name, *args, **kwargs)
register_moving_percentage(name, *args, **kwargs)
register_ngram(name, *args, **kwargs)
register_percentage(name, *args, **kwargs)
register_similarity(name, *args, **kwargs)
register_time(name, *args, **kwargs)
register_uniqueness(name, *args, **kwargs)
reset()
show()
value(name)
class convlab2.e2e.rnn_rollout.metric.MovingAverageMetric(window=100)

Bases: convlab2.e2e.rnn_rollout.metric.MovingNumericMetric

show()
class convlab2.e2e.rnn_rollout.metric.MovingNumericMetric(window=100)

Bases: object

record(k)
reset()
value()
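MovingNumericMetric takes a window argument and exposes record(k), reset() and value(). The sketch below shows what a windowed moving average with that interface typically looks like, using a bounded deque so old records drop out automatically. It assumes value() averages the last window records and returns 0.0 when empty. Both behaviours are assumptions, and MovingAverage is a hypothetical name.

```python
from collections import deque


class MovingAverage:
    """Hypothetical windowed average: only the last `window` records count."""

    def __init__(self, window=100):
        self.window = window
        self.reset()

    def reset(self):
        # A bounded deque discards the oldest value once it is full.
        self.values = deque(maxlen=self.window)

    def record(self, k):
        self.values.append(k)

    def value(self):
        return sum(self.values) / len(self.values) if self.values else 0.0


m = MovingAverage(window=3)
for k in [1, 2, 3, 4]:
    m.record(k)
print(m.value())  # 3.0, the mean of the last three records (2, 3, 4)
```

Under the same assumption, a moving percentage could record 1/0 outcomes and report 100 * value().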
class convlab2.e2e.rnn_rollout.metric.MovingPercentageMetric(window=100)

Bases: convlab2.e2e.rnn_rollout.metric.MovingNumericMetric

show()
class convlab2.e2e.rnn_rollout.metric.NGramMetric(text, ngram=-1)

Bases: convlab2.e2e.rnn_rollout.metric.TextMetric

record(sen)
class convlab2.e2e.rnn_rollout.metric.NumericMetric

Bases: object

record(k, n=1)
reset()
value()
class convlab2.e2e.rnn_rollout.metric.PercentageMetric

Bases: convlab2.e2e.rnn_rollout.metric.NumericMetric

show()
class convlab2.e2e.rnn_rollout.metric.SimilarityMetric

Bases: object

record(sen)
reset()
show()
value()
class convlab2.e2e.rnn_rollout.metric.TextMetric(text)

Bases: object

reset()
show()
value()
class convlab2.e2e.rnn_rollout.metric.TimeMetric

Bases: object

record(n=1)
reset()
show()
value()
class convlab2.e2e.rnn_rollout.metric.UniquenessMetric

Bases: object

record(sen)
reset()
show()
value()

convlab2.e2e.rnn_rollout.reinforce module

convlab2.e2e.rnn_rollout.rnn_model module

class convlab2.e2e.rnn_rollout.rnn_model.RnnModel(word_dict, item_dict, context_dict, count_dict, args)

Bases: torch.nn.modules.module.Module

corpus_ty

alias of convlab2.e2e.rnn_rollout.data.WordCorpus

engine_ty

alias of convlab2.e2e.rnn_rollout.engines.rnn_engine.RnnEngine

flatten_parameters()
forward(inpt, ctx)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

forward_context(ctx)
forward_lm(inpt_emb, lang_h, ctx_h)
forward_selection(inpt_emb, lang_h, ctx_h)
generate_choice_logits(inpt, lang_h, ctx_h)
init_weights()
read(inpt, lang_h, ctx_h, prefix_token='THEM:')
score_sent(sent, lang_h, ctx_h, temperature)
word2var(word)
write(lang_h, ctx_h, max_words, temperature, stop_tokens=['<eos>', '<selection>'], resume=False)

Generate a sentence word by word and feed the output of the previous timestep as input to the next.
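The word-by-word generation described above is autoregressive sampling. Each sampled word is fed back as the next input, sampling uses a temperature-scaled softmax, and generation stops at a stop token or after max_words. The sketch below shows that loop in plain Python. step(h, word) is a hypothetical stand-in for one RNN step that returns (logits, new_h). The real method works on lang_h/ctx_h tensors, which this sketch does not model.

```python
import math
import random


def write(step, h, first_word, max_words, temperature,
          stop_tokens=('<eos>', '<selection>'), rng=random):
    """Sample up to max_words tokens, feeding each output back as the next input."""
    words, word = [], first_word
    for _ in range(max_words):
        logits, h = step(h, word)  # hypothetical one-step model call
        vocab = list(logits)
        # Temperature-scaled softmax; subtract the max for numerical stability.
        scaled = [logits[w] / temperature for w in vocab]
        top = max(scaled)
        weights = [math.exp(s - top) for s in scaled]
        word = rng.choices(vocab, weights=weights, k=1)[0]
        words.append(word)
        if word in stop_tokens:
            break
    return words, h


# A scripted toy "model" that makes one word overwhelmingly likely per step.
script = ['i', 'want', 'the', 'hat', '<eos>', 'extra']


def toy_step(h, word):
    return {script[h]: 0.0, 'noise': -1000.0}, h + 1


print(write(toy_step, 0, '<start>', 20, 1.0)[0])  # ['i', 'want', 'the', 'hat', '<eos>']
```

Lower temperatures sharpen the distribution toward the highest-scoring word. Higher ones flatten it and make outputs more diverse.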

write_batch(bsz, lang_h, ctx_h, temperature, max_words=100)
zero_h(bsz, nhid=None, copies=None)

convlab2.e2e.rnn_rollout.rnnrollout module

class convlab2.e2e.rnn_rollout.rnnrollout.RNNRolloutAgent(model, sel_model, args, name='Alice', train=False, diverse=False, max_total_len=100)

Bases: convlab2.dialog_agent.agent.Agent

RNN dialog agent with rollout decoding.

choose()
feed_context(context)
feed_partner_context(partner_context)
get_reward()
init_session()

Reset the class variables to prepare for a new session.

is_terminated()
load_model()
read(inpt)
response(observation, max_words=20)

Generate the agent's response to the user input.

The data type of the input and the response can be either str or a list of tuples, depending on the form of the agent.

Example:

If the agent is a pipeline agent with NLU, DST and Policy, then type(input) == str and type(response) == list of tuples.

Args:
observation (str or list of tuples):

The input to the agent.

Returns:
response (str or list of tuples):

The response generated by the agent.

update(agree, reward, choice=None, partner_choice=None, partner_input=None, partner_reward=None)
write(max_words=20)
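RNNRolloutAgent is described as an RNN dialog agent with rollout decoding. The general rollout idea is to score each candidate utterance by simulating the rest of the dialogue several times from it, averaging the final rewards, and picking the candidate with the highest expected reward. The sketch below shows only that selection step. choose_by_rollout and the simulate callback are hypothetical. How the real agent generates candidates and simulates continuations is not described here.

```python
from statistics import mean


def choose_by_rollout(candidates, simulate, n_rollouts=5):
    """Pick the candidate utterance with the highest average rollout reward.

    `simulate(candidate)` is a hypothetical callback that plays the dialogue
    to the end starting from `candidate` and returns the final reward.
    """
    best, best_score = None, float('-inf')
    for cand in candidates:
        # Monte Carlo estimate of the expected reward of saying `cand`.
        expected = mean(simulate(cand) for _ in range(n_rollouts))
        if expected > best_score:
            best, best_score = cand, expected
    return best, best_score


rewards = {'i want the hat': 3, 'you take the balls and i take the rest': 7}
print(choose_by_rollout(list(rewards), lambda c: rewards[c]))
# ('you take the balls and i take the rest', 7)
```

With a stochastic simulator, more rollouts reduce the variance of each estimate at the cost of more decoding time per turn.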

convlab2.e2e.rnn_rollout.split module

convlab2.e2e.rnn_rollout.split.conv(line)
convlab2.e2e.rnn_rollout.split.dialog_len(line)
convlab2.e2e.rnn_rollout.split.find(tokens, tag)
convlab2.e2e.rnn_rollout.split.invert(cnts, sel)
convlab2.e2e.rnn_rollout.split.main()
convlab2.e2e.rnn_rollout.split.select(line)

convlab2.e2e.rnn_rollout.utils module

Various helpers.

class convlab2.e2e.rnn_rollout.utils.ContextGenerator(context_file)

Bases: object

Dialogue context generator. Generates contexts from the file.

iter(nepoch=1)
sample()
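ContextGenerator reads contexts from a file and offers sample() and iter(nepoch). The sketch below assumes the file holds one context per line, that consecutive lines pair up as the two agents' contexts, and that each line is whitespace-separated counts and values. It takes a list of lines instead of a file name to stay self-contained. ToyContextGenerator is a hypothetical name. It shuffles once per epoch, but the documentation does not say whether the real iter does.

```python
import random


class ToyContextGenerator:
    """Hypothetical sketch: pair consecutive lines as (agent A, agent B) contexts."""

    def __init__(self, lines):
        self.ctxs = []
        for i in range(0, len(lines) - 1, 2):
            self.ctxs.append([lines[i].split(), lines[i + 1].split()])

    def sample(self):
        # One random context pair.
        return random.choice(self.ctxs)

    def iter(self, nepoch=1):
        # Yield every context pair once per epoch, in a fresh random order.
        for _ in range(nepoch):
            order = list(self.ctxs)
            random.shuffle(order)
            yield from order


gen = ToyContextGenerator(['1 4 2 3 3 0', '1 0 2 2 3 2'])
print(gen.sample())  # [['1', '4', '2', '3', '3', '0'], ['1', '0', '2', '2', '3', '2']]
```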
class convlab2.e2e.rnn_rollout.utils.ManualContextGenerator(num_types=3, num_objects=10, max_score=10)

Bases: object

Dialogue context generator. Takes contexts from stdin.

sample()
convlab2.e2e.rnn_rollout.utils.backward_hook(grad)

Hook for the backward pass.

convlab2.e2e.rnn_rollout.utils.is_selection(out)

Checks whether the dialogue has ended.

convlab2.e2e.rnn_rollout.utils.load_model(file_name, map_location=None)

Reads a model from a file.

convlab2.e2e.rnn_rollout.utils.prob_random()

Prints out the states of various RNGs.

convlab2.e2e.rnn_rollout.utils.save_model(model, file_name)

Serializes a model to a file.

convlab2.e2e.rnn_rollout.utils.set_seed(seed)

Sets random seed everywhere.
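"Everywhere" is not spelled out above. In a PyTorch project it usually means seeding Python's random module, NumPy and PyTorch (including all CUDA devices). The sketch below does exactly that and treats NumPy and PyTorch as optional so it runs without them. Which generators the real set_seed touches is an assumption.

```python
import random

try:
    import numpy as np
except ImportError:  # NumPy is optional in this sketch
    np = None
try:
    import torch
except ImportError:  # so is PyTorch
    torch = None


def set_seed(seed):
    """Seed every RNG this sketch knows about."""
    random.seed(seed)
    if np is not None:
        np.random.seed(seed)
    if torch is not None:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)  # seed every visible GPU


set_seed(123)
first = [random.random() for _ in range(3)]
set_seed(123)
assert first == [random.random() for _ in range(3)]  # same seed, same stream
```

Seeding alone does not guarantee bit-identical GPU results, because some CUDA kernels are nondeterministic.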

convlab2.e2e.rnn_rollout.utils.use_cuda(enabled, device_id=0)

Checks whether CUDA is available and sets the default device to device_id.

convlab2.e2e.rnn_rollout.vis module

A visualization library. Relies on visdom.

class convlab2.e2e.rnn_rollout.vis.ModulePlot(module, plot_weight=False, plot_grad=False, running_n=100)

Bases: object

A helper class that plots norms of weights and gradients for a given module.

update(x)
class convlab2.e2e.rnn_rollout.vis.Plot(metrics, title, ylabel, xlabel='t', running_n=100)

Bases: object

A class for plotting and updating the plot in real time.

update(metric, x, y)

Module contents