tatk.e2e.rnn_rollout package

Submodules

tatk.e2e.rnn_rollout.agent module

class tatk.e2e.rnn_rollout.agent.RnnRolloutAgent(model, args, name='Alice', allow_no_agreement=True, train=False, diverse=False, max_dec_len=20)

Bases: tatk.dialog_agent.agent.Agent

__init__(model, args, name='Alice', allow_no_agreement=True, train=False, diverse=False, max_dec_len=20)

Initialize self. See help(type(self)) for accurate signature.

choose()
feed_context(context)
feed_partner_context(partner_context)
init_session()

Reset the class variables to prepare for a new session.

read(inpt)
response(observation)

Generate agent response given user input.

The input and the response can each be either a str or a list of tuples, depending on the form of the agent.

Example:

If the agent is a pipeline agent with NLU, DST and Policy, then type(input) == str and type(response) == list of tuples.

Args:
observation (str or list of tuples):

The input to the agent.

Returns:
response (str or list of tuples):

The response generated by the agent.

update(agree, reward, choice=None, partner_choice=None, partner_input=None, partner_reward=None)
write(max_words=20)

tatk.e2e.rnn_rollout.avg_rank module

tatk.e2e.rnn_rollout.chat module

tatk.e2e.rnn_rollout.config module

Configuration script. Stores variables and settings used across the application.

tatk.e2e.rnn_rollout.data module

class tatk.e2e.rnn_rollout.data.CountDictionary(init=True)

Bases: tatk.e2e.rnn_rollout.data.Dictionary

__init__(init=True)

Initialize self. See help(type(self)) for accurate signature.

get_idx(words)
get_key()
read_tag(file_name, tag, init_dict=False)
class tatk.e2e.rnn_rollout.data.Dictionary(init=True)

Bases: object

__init__(init=True)

Initialize self. See help(type(self)) for accurate signature.

add_word(word)
get_idx(word)
get_word(idx)
i2w(idx)
read_tag(file_name, tag, freq_cutoff=-1, init_dict=True)
w2i(words)
class tatk.e2e.rnn_rollout.data.ItemDictionary(selection_size, init=True)

Bases: tatk.e2e.rnn_rollout.data.Dictionary

__init__(selection_size, init=True)

Initialize self. See help(type(self)) for accurate signature.

read_tag(file_name, tag, init_dict=False)
w2i(words, inv=False)
class tatk.e2e.rnn_rollout.data.PhraseCorpus(domain, path, freq_cutoff=2, train='train.txt', valid='val.txt', test='test.txt', verbose=False, sep_sel=False)

Bases: tatk.e2e.rnn_rollout.data.WordCorpus

tokenize(file_name)
class tatk.e2e.rnn_rollout.data.SentenceCorpus(domain, path, freq_cutoff=2, train='train.txt', valid='val.txt', test='test.txt', verbose=False, sep_sel=False)

Bases: tatk.e2e.rnn_rollout.data.WordCorpus

class tatk.e2e.rnn_rollout.data.WordCorpus(domain, path, freq_cutoff=2, train='train.txt', valid='val.txt', test='test.txt', verbose=False, sep_sel=False)

Bases: object

__init__(domain, path, freq_cutoff=2, train='train.txt', valid='val.txt', test='test.txt', verbose=False, sep_sel=False)

Initialize self. See help(type(self)) for accurate signature.

test_dataset(bsz, shuffle=True)
tokenize(file_name)
train_dataset(bsz, shuffle=True)
valid_dataset(bsz, shuffle=True)
tatk.e2e.rnn_rollout.data.create_dicts_from_file(domain, file_name, freq_cutoff)
tatk.e2e.rnn_rollout.data.get_tag(tokens, tag)
tatk.e2e.rnn_rollout.data.read_lines(file_name)

tatk.e2e.rnn_rollout.dialog module

class tatk.e2e.rnn_rollout.dialog.Dialog(agents, args)

Bases: object

__init__(agents, args)

Initialize self. See help(type(self)) for accurate signature.

run(ctxs, logger, max_words=5000)
show_metrics()
class tatk.e2e.rnn_rollout.dialog.DialogLogger(verbose=False, log_file=None, append=False)

Bases: object

CODE2ITEM = [('item0', 'book'), ('item1', 'hat'), ('item2', 'ball')]
__init__(verbose=False, log_file=None, append=False)

Initialize self. See help(type(self)) for accurate signature.

dump(s, forced=False)
dump_agreement(agree)
dump_choice(name, choice)
dump_ctx(name, ctx)
dump_reward(name, agree, reward)
dump_sent(name, sent)
class tatk.e2e.rnn_rollout.dialog.DialogSelfTrainLogger(verbose=False, log_file=None)

Bases: tatk.e2e.rnn_rollout.dialog.DialogLogger

__init__(verbose=False, log_file=None)

Initialize self. See help(type(self)) for accurate signature.

dump_agreement(agree)
dump_choice(name, choice)
dump_ctx(name, ctx)
dump_reward(name, agree, reward)

tatk.e2e.rnn_rollout.domain module

class tatk.e2e.rnn_rollout.domain.Domain

Bases: object

Domain interface.

generate_choices(input)
input_length()
parse_choice(choice)
parse_context(ctx)
parse_human_choice(input, output)
score(context, choice)
score_choices(choices, ctxs)
selection_length()
class tatk.e2e.rnn_rollout.domain.ObjectDivisionDomain

Bases: tatk.e2e.rnn_rollout.domain.Domain

__init__()

Initialize self. See help(type(self)) for accurate signature.

generate_choices(input, with_disagreement=True)
input_length()
num_choices()
parse_choice(choice)
parse_context(ctx)
parse_human_choice(input, output)
score(context, choice)
score_choices(choices, ctxs)
selection_length()
class tatk.e2e.rnn_rollout.domain.ObjectTradeDomain(max_items=1)

Bases: tatk.e2e.rnn_rollout.domain.ObjectDivisionDomain

__init__(max_items=1)

Initialize self. See help(type(self)) for accurate signature.

generate_choices(input)
input_length()
parse_human_choice(input, output)
score(context, choice)
score_choices(choices, ctxs)
selection_length()
tatk.e2e.rnn_rollout.domain.get_domain(name)

tatk.e2e.rnn_rollout.eval_selfplay module

Script to evaluate selfplay. It computes agreement rate, average score and Pareto optimality.

tatk.e2e.rnn_rollout.eval_selfplay.compute_score(vals, picks)

Compute the score of the selection.

tatk.e2e.rnn_rollout.eval_selfplay.gen_choices(cnts, idx=0, choice=[])

Generate all valid choices, covering both your choices and your opponent's choices.

tatk.e2e.rnn_rollout.eval_selfplay.main()
tatk.e2e.rnn_rollout.eval_selfplay.parse_line(line, domain)
tatk.e2e.rnn_rollout.eval_selfplay.parse_log(file_name, domain)

Parse the log file produced by selfplay. See that log file's format for more details.

tatk.e2e.rnn_rollout.metric module

class tatk.e2e.rnn_rollout.metric.AverageMetric

Bases: tatk.e2e.rnn_rollout.metric.NumericMetric

show()
class tatk.e2e.rnn_rollout.metric.MetricsContainer

Bases: object

__init__()

Initialize self. See help(type(self)) for accurate signature.

dict()
record(name, *args, **kwargs)
register_average(name, *args, **kwargs)
register_moving_average(name, *args, **kwargs)
register_moving_percentage(name, *args, **kwargs)
register_ngram(name, *args, **kwargs)
register_percentage(name, *args, **kwargs)
register_similarity(name, *args, **kwargs)
register_time(name, *args, **kwargs)
register_uniqueness(name, *args, **kwargs)
reset()
show()
value(name)
class tatk.e2e.rnn_rollout.metric.MovingAverageMetric(window=100)

Bases: tatk.e2e.rnn_rollout.metric.MovingNumericMetric

show()
class tatk.e2e.rnn_rollout.metric.MovingNumericMetric(window=100)

Bases: object

__init__(window=100)

Initialize self. See help(type(self)) for accurate signature.

record(k)
reset()
value()
class tatk.e2e.rnn_rollout.metric.MovingPercentageMetric(window=100)

Bases: tatk.e2e.rnn_rollout.metric.MovingNumericMetric

show()
class tatk.e2e.rnn_rollout.metric.NGramMetric(text, ngram=-1)

Bases: tatk.e2e.rnn_rollout.metric.TextMetric

__init__(text, ngram=-1)

Initialize self. See help(type(self)) for accurate signature.

record(sen)
class tatk.e2e.rnn_rollout.metric.NumericMetric

Bases: object

__init__()

Initialize self. See help(type(self)) for accurate signature.

record(k, n=1)
reset()
value()
class tatk.e2e.rnn_rollout.metric.PercentageMetric

Bases: tatk.e2e.rnn_rollout.metric.NumericMetric

show()
class tatk.e2e.rnn_rollout.metric.SimilarityMetric

Bases: object

__init__()

Initialize self. See help(type(self)) for accurate signature.

record(sen)
reset()
show()
value()
class tatk.e2e.rnn_rollout.metric.TextMetric(text)

Bases: object

__init__(text)

Initialize self. See help(type(self)) for accurate signature.

reset()
show()
value()
class tatk.e2e.rnn_rollout.metric.TimeMetric

Bases: object

__init__()

Initialize self. See help(type(self)) for accurate signature.

record(n=1)
reset()
show()
value()
class tatk.e2e.rnn_rollout.metric.UniquenessMetric

Bases: object

__init__()

Initialize self. See help(type(self)) for accurate signature.

record(sen)
reset()
show()
value()

tatk.e2e.rnn_rollout.reinforce module

tatk.e2e.rnn_rollout.rnn_model module

class tatk.e2e.rnn_rollout.rnn_model.RnnModel(word_dict, item_dict, context_dict, count_dict, args)

Bases: torch.nn.modules.module.Module

__init__(word_dict, item_dict, context_dict, count_dict, args)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

corpus_ty

alias of tatk.e2e.rnn_rollout.data.WordCorpus

engine_ty

alias of tatk.e2e.rnn_rollout.engines.rnn_engine.RnnEngine

flatten_parameters()
forward(inpt, ctx)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

forward_context(ctx)
forward_lm(inpt_emb, lang_h, ctx_h)
forward_selection(inpt_emb, lang_h, ctx_h)
generate_choice_logits(inpt, lang_h, ctx_h)
init_weights()
read(inpt, lang_h, ctx_h, prefix_token='THEM:')
score_sent(sent, lang_h, ctx_h, temperature)
word2var(word)
write(lang_h, ctx_h, max_words, temperature, stop_tokens=['<eos>', '<selection>'], resume=False)

Generate a sentence word by word and feed the output of the previous timestep as input to the next.

write_batch(bsz, lang_h, ctx_h, temperature, max_words=100)
zero_h(bsz, nhid=None, copies=None)

tatk.e2e.rnn_rollout.rnnrollout module

class tatk.e2e.rnn_rollout.rnnrollout.RNNRolloutAgent(model, sel_model, args, name='Alice', train=False, diverse=False, max_total_len=100)

Bases: tatk.dialog_agent.agent.Agent

RNN dialog agent with rollout decoding.

__init__(model, sel_model, args, name='Alice', train=False, diverse=False, max_total_len=100)

Constructor of the RNNRollout agent.

choose()
feed_context(context)
feed_partner_context(partner_context)
get_reward()
init_session()

Reset the class variables to prepare for a new session.

is_terminated()
load_model()
read(inpt)
response(observation, max_words=20)

Generate agent response given user input.

The input and the response can each be either a str or a list of tuples, depending on the form of the agent.

Example:

If the agent is a pipeline agent with NLU, DST and Policy, then type(input) == str and type(response) == list of tuples.

Args:
observation (str or list of tuples):

The input to the agent.

Returns:
response (str or list of tuples):

The response generated by the agent.

update(agree, reward, choice=None, partner_choice=None, partner_input=None, partner_reward=None)
write(max_words=20)

tatk.e2e.rnn_rollout.split module

tatk.e2e.rnn_rollout.split.conv(line)
tatk.e2e.rnn_rollout.split.dialog_len(line)
tatk.e2e.rnn_rollout.split.find(tokens, tag)
tatk.e2e.rnn_rollout.split.invert(cnts, sel)
tatk.e2e.rnn_rollout.split.main()
tatk.e2e.rnn_rollout.split.select(line)

tatk.e2e.rnn_rollout.utils module

Various helpers.

class tatk.e2e.rnn_rollout.utils.ContextGenerator(context_file)

Bases: object

Dialogue context generator. Generates contexts from the file.

__init__(context_file)

Initialize self. See help(type(self)) for accurate signature.

iter(nepoch=1)
sample()
class tatk.e2e.rnn_rollout.utils.ManualContextGenerator(num_types=3, num_objects=10, max_score=10)

Bases: object

Dialogue context generator. Takes contexts from stdin.

__init__(num_types=3, num_objects=10, max_score=10)

Initialize self. See help(type(self)) for accurate signature.

sample()
tatk.e2e.rnn_rollout.utils.backward_hook(grad)

Hook for backward pass.

tatk.e2e.rnn_rollout.utils.is_selection(out)

Return whether the output marks the end of the dialog.

tatk.e2e.rnn_rollout.utils.load_model(file_name, map_location=None)

Reads model from a file.

tatk.e2e.rnn_rollout.utils.prob_random()

Prints out the states of various RNGs.

tatk.e2e.rnn_rollout.utils.save_model(model, file_name)

Serializes model to a file.

tatk.e2e.rnn_rollout.utils.set_seed(seed)

Sets random seed everywhere.

tatk.e2e.rnn_rollout.utils.use_cuda(enabled, device_id=0)

Verifies if CUDA is available and sets default device to be device_id.

tatk.e2e.rnn_rollout.vis module

A visualization library. Relies on visdom.

class tatk.e2e.rnn_rollout.vis.ModulePlot(module, plot_weight=False, plot_grad=False, running_n=100)

Bases: object

A helper class that plots norms of weights and gradients for a given module.

__init__(module, plot_weight=False, plot_grad=False, running_n=100)

Initialize self. See help(type(self)) for accurate signature.

update(x)
class tatk.e2e.rnn_rollout.vis.Plot(metrics, title, ylabel, xlabel='t', running_n=100)

Bases: object

A class for plotting and updating the plot in real time.

__init__(metrics, title, ylabel, xlabel='t', running_n=100)

Initialize self. See help(type(self)) for accurate signature.

update(metric, x, y)