tatk.e2e.rnn_rollout package
Subpackages
- tatk.e2e.rnn_rollout.deal_or_not package
- tatk.e2e.rnn_rollout.engines package
- tatk.e2e.rnn_rollout.models package
  - Submodules
    - tatk.e2e.rnn_rollout.models.attn module
    - tatk.e2e.rnn_rollout.models.ctx_encoder module
    - tatk.e2e.rnn_rollout.models.latent_clustering_model module
    - tatk.e2e.rnn_rollout.models.modules module
    - tatk.e2e.rnn_rollout.models.rnn_model module
    - tatk.e2e.rnn_rollout.models.selection_model module
    - tatk.e2e.rnn_rollout.models.utils module
Submodules
tatk.e2e.rnn_rollout.agent module
class tatk.e2e.rnn_rollout.agent.RnnRolloutAgent(model, args, name='Alice', allow_no_agreement=True, train=False, diverse=False, max_dec_len=20)
Bases: tatk.dialog_agent.agent.Agent
- __init__(model, args, name='Alice', allow_no_agreement=True, train=False, diverse=False, max_dec_len=20)
  Initialize self. See help(type(self)) for accurate signature.
- choose()
- feed_context(context)
- feed_partner_context(partner_context)
- init_session()
  Reset the class variables to prepare for a new session.
- read(inpt)
- response(observation)
  Generate the agent's response to the user input.
  The input and the response can each be either a str or a list of tuples, depending on the form of the agent.
  Example:
    If the agent is a pipeline agent with NLU, DST and Policy, then type(input) == str and type(response) == list of tuples.
  Args:
    observation (str or list of tuples): The input to the agent.
  Returns:
    response (str or list of tuples): The response generated by the agent.
- update(agree, reward, choice=None, partner_choice=None, partner_input=None, partner_reward=None)
- write(max_words=20)
tatk.e2e.rnn_rollout.avg_rank module
tatk.e2e.rnn_rollout.chat module
tatk.e2e.rnn_rollout.config module
Configuration script. Stores variables and settings used across the application.
tatk.e2e.rnn_rollout.data module
class tatk.e2e.rnn_rollout.data.CountDictionary(init=True)
Bases: tatk.e2e.rnn_rollout.data.Dictionary
- __init__(init=True)
  Initialize self. See help(type(self)) for accurate signature.
- get_idx(words)
- get_key()
- read_tag(file_name, tag, init_dict=False)
class tatk.e2e.rnn_rollout.data.Dictionary(init=True)
Bases: object
- __init__(init=True)
  Initialize self. See help(type(self)) for accurate signature.
- add_word(word)
- get_idx(word)
- get_word(idx)
- i2w(idx)
- read_tag(file_name, tag, freq_cutoff=-1, init_dict=True)
- w2i(words)
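The Dictionary entries above give method names only. As a minimal sketch, assuming the class is an ordinary word-to-index vocabulary, the same method names could be backed by a dict and a list. The class name SimpleDictionary, the `<unk>` fallback and the internal attributes are assumptions for illustration, not the package's actual code; frequency filtering (freq_cutoff in read_tag) is left out.

```python
class SimpleDictionary:
    """Hypothetical word <-> index vocabulary reusing the documented method names."""

    def __init__(self, unk_token="<unk>"):
        self.word2idx = {}
        self.idx2word = []
        self.unk_token = unk_token
        self.add_word(unk_token)  # reserve index 0 for out-of-vocabulary words

    def add_word(self, word):
        # Give unseen words the next free index; always return the word's index.
        if word not in self.word2idx:
            self.word2idx[word] = len(self.idx2word)
            self.idx2word.append(word)
        return self.word2idx[word]

    def get_idx(self, word):
        # Unknown words map to the <unk> index instead of raising KeyError.
        return self.word2idx.get(word, self.word2idx[self.unk_token])

    def get_word(self, idx):
        return self.idx2word[idx]

    def w2i(self, words):
        return [self.get_idx(w) for w in words]

    def i2w(self, idx):
        return [self.get_word(i) for i in idx]

    def __len__(self):
        return len(self.idx2word)


vocab = SimpleDictionary()
for w in "i want the hat".split():
    vocab.add_word(w)
print(vocab.w2i(["want", "ball"]))  # [2, 0]  ("ball" is unknown, so it maps to <unk>)
```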
class tatk.e2e.rnn_rollout.data.ItemDictionary(selection_size, init=True)
Bases: tatk.e2e.rnn_rollout.data.Dictionary
- __init__(selection_size, init=True)
  Initialize self. See help(type(self)) for accurate signature.
- read_tag(file_name, tag, init_dict=False)
- w2i(words, inv=False)
class tatk.e2e.rnn_rollout.data.PhraseCorpus(domain, path, freq_cutoff=2, train='train.txt', valid='val.txt', test='test.txt', verbose=False, sep_sel=False)
Bases: tatk.e2e.rnn_rollout.data.WordCorpus
- tokenize(file_name)
class tatk.e2e.rnn_rollout.data.SentenceCorpus(domain, path, freq_cutoff=2, train='train.txt', valid='val.txt', test='test.txt', verbose=False, sep_sel=False)
class tatk.e2e.rnn_rollout.data.WordCorpus(domain, path, freq_cutoff=2, train='train.txt', valid='val.txt', test='test.txt', verbose=False, sep_sel=False)
Bases: object
- __init__(domain, path, freq_cutoff=2, train='train.txt', valid='val.txt', test='test.txt', verbose=False, sep_sel=False)
  Initialize self. See help(type(self)) for accurate signature.
- test_dataset(bsz, shuffle=True)
- tokenize(file_name)
- train_dataset(bsz, shuffle=True)
- valid_dataset(bsz, shuffle=True)
tatk.e2e.rnn_rollout.data.create_dicts_from_file(domain, file_name, freq_cutoff)
tatk.e2e.rnn_rollout.data.get_tag(tokens, tag)
tatk.e2e.rnn_rollout.data.read_lines(file_name)
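get_tag(tokens, tag) is listed without a description. Assuming each line of the data files is a whitespace-tokenized record whose fields are wrapped in markers such as `<input> ... </input>` and `<dialogue> ... </dialogue>`, a plausible reading is that it returns the tokens between the opening and closing markers. The sketch below implements that reading; both the record layout and the behavior are assumptions.

```python
def get_tag(tokens, tag):
    """Return the tokens enclosed by <tag> ... </tag> (sketch; record format assumed)."""
    start = tokens.index("<" + tag + ">") + 1  # first token after the opening marker
    end = tokens.index("</" + tag + ">")       # position of the closing marker
    return tokens[start:end]


line = "<input> 1 4 4 1 1 2 </input> <dialogue> THEM: i want the book <eos> </dialogue>"
tokens = line.split()
print(get_tag(tokens, "input"))  # ['1', '4', '4', '1', '1', '2']
```

list.index raises ValueError when a marker is missing, so a record without the requested field fails loudly rather than returning an empty slice.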
tatk.e2e.rnn_rollout.dialog module
class tatk.e2e.rnn_rollout.dialog.Dialog(agents, args)
Bases: object
- __init__(agents, args)
  Initialize self. See help(type(self)) for accurate signature.
- run(ctxs, logger, max_words=5000)
- show_metrics()
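Dialog.run(ctxs, logger, max_words=5000) is documented without a description. The sketch below illustrates the kind of turn-taking loop such a method typically implements: each agent receives its context, a randomly chosen agent speaks first, turns alternate until a selection token or the word budget ends the conversation, and then each agent makes its choice. ScriptedAgent, run_dialog and the rule that a lone `<selection>` token ends the dialogue are hypothetical; logging, partner contexts and scoring are omitted.

```python
import random


class ScriptedAgent:
    """Hypothetical stand-in for an agent: replays fixed utterances."""

    def __init__(self, name, lines, choice):
        self.name = name
        self.lines = list(lines)
        self.choice = choice
        self.heard = []

    def feed_context(self, ctx):
        self.ctx = ctx

    def write(self, max_words=20):
        # Emit the next scripted utterance, truncated to the remaining word budget.
        return self.lines.pop(0).split()[:max_words]

    def read(self, inpt):
        self.heard.append(inpt)

    def choose(self):
        return self.choice


def is_selection(out):
    # Assumption: a turn consisting solely of '<selection>' ends the dialogue.
    return len(out) == 1 and out[0] == "<selection>"


def run_dialog(agents, ctxs, max_words=5000, rng=random):
    """Sketch of a two-agent negotiation loop (not the real Dialog.run)."""
    for agent, ctx in zip(agents, ctxs):
        agent.feed_context(ctx)
    # Pick who speaks first at random.
    writer, reader = agents if rng.random() < 0.5 else agents[::-1]
    conv, words_left = [], max_words
    while True:
        out = writer.write(max_words=words_left)
        words_left -= len(out)
        conv.append((writer.name, out))
        reader.read(out)
        if is_selection(out) or words_left <= 1:
            break
        writer, reader = reader, writer  # alternate turns
    choices = [agent.choose() for agent in agents]
    return conv, choices
```

Whichever agent opens, two scripts of the form ["...", "<selection>"] produce three turns, with the third being the selection.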
class tatk.e2e.rnn_rollout.dialog.DialogLogger(verbose=False, log_file=None, append=False)
Bases: object
- CODE2ITEM = [('item0', 'book'), ('item1', 'hat'), ('item2', 'ball')]
- __init__(verbose=False, log_file=None, append=False)
  Initialize self. See help(type(self)) for accurate signature.
- dump(s, forced=False)
- dump_agreement(agree)
- dump_choice(name, choice)
- dump_ctx(name, ctx)
- dump_reward(name, agree, reward)
- dump_sent(name, sent)
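DialogLogger.CODE2ITEM maps the item codes used in the data (item0, item1, item2) to readable names. Below is a minimal sketch of how a logger could use it to print a context, assuming a context is a flat list of (count, value) pairs in item order; format_ctx is a hypothetical helper, and the exact output of dump_ctx may differ.

```python
# CODE2ITEM as documented for DialogLogger.
CODE2ITEM = [('item0', 'book'), ('item1', 'hat'), ('item2', 'ball')]


def format_ctx(name, ctx):
    """Render a context as readable text.

    Assumes ctx is [count0, value0, count1, value1, count2, value2].
    """
    assert len(ctx) == 2 * len(CODE2ITEM), "expected one (count, value) pair per item"
    parts = [
        '%s=(count:%s value:%s)' % (CODE2ITEM[i][1], ctx[2 * i], ctx[2 * i + 1])
        for i in range(len(CODE2ITEM))
    ]
    return '%s : %s' % (name, ' '.join(parts))


print(format_ctx('Alice', ['1', '4', '4', '1', '1', '2']))
# Alice : book=(count:1 value:4) hat=(count:4 value:1) ball=(count:1 value:2)
```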
class tatk.e2e.rnn_rollout.dialog.DialogSelfTrainLogger(verbose=False, log_file=None)
Bases: tatk.e2e.rnn_rollout.dialog.DialogLogger
- __init__(verbose=False, log_file=None)
  Initialize self. See help(type(self)) for accurate signature.
- dump_agreement(agree)
- dump_choice(name, choice)
- dump_ctx(name, ctx)
- dump_reward(name, agree, reward)
tatk.e2e.rnn_rollout.domain module
class tatk.e2e.rnn_rollout.domain.Domain
Bases: object
Domain interface.
- generate_choices(input)
- input_length()
- parse_choice(choice)
- parse_context(ctx)
- parse_human_choice(input, output)
- score(context, choice)
- score_choices(choices, ctxs)
- selection_length()
class tatk.e2e.rnn_rollout.domain.ObjectDivisionDomain
Bases: tatk.e2e.rnn_rollout.domain.Domain
- __init__()
  Initialize self. See help(type(self)) for accurate signature.
- generate_choices(input, with_disagreement=True)
- input_length()
- num_choices()
- parse_choice(choice)
- parse_context(ctx)
- parse_human_choice(input, output)
- score(context, choice)
- score_choices(choices, ctxs)
- selection_length()
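For ObjectDivisionDomain, score(context, choice) and score_choices(choices, ctxs) are listed without descriptions. The sketch below makes three assumptions: a context is a flat list of (count, value) pairs, a choice is a list of `itemN=k` strings for the items this agent takes, and `<no_agreement>` or `<disconnect>` marks a failed negotiation. Under those assumptions, an agent's score is the sum of value times count over what it takes, and the two agents agree only when their picks add up to the available counts. All helper names and formats here are assumptions, not the module's code.

```python
def parse_context(ctx):
    """Split [c0, v0, c1, v1, ...] into counts and values (assumed layout)."""
    cnts = [int(n) for n in ctx[0::2]]
    vals = [int(v) for v in ctx[1::2]]
    return cnts, vals


def parse_choice(choice):
    """Turn ['item0=1', 'item1=0', ...] into integer picks (assumed format)."""
    return [int(c.split('=')[1]) for c in choice]


def score(context, choice):
    """Value of the items this agent takes; 0 if there was no agreement."""
    if choice[0] in ('<no_agreement>', '<disconnect>'):
        return 0
    _, vals = parse_context(context)
    picks = parse_choice(choice)
    return sum(v * p for v, p in zip(vals, picks))


def score_choices(choices, ctxs):
    """Return (agree, rewards); agreement means the picks add up to the counts."""
    if any(c[0] in ('<no_agreement>', '<disconnect>') for c in choices):
        return False, [0] * len(choices)
    cnts, _ = parse_context(ctxs[0])
    totals = [sum(p) for p in zip(*(parse_choice(c) for c in choices))]
    agree = totals == cnts
    rewards = [score(ctx, c) if agree else 0 for ctx, c in zip(ctxs, choices)]
    return agree, rewards


ctx = ['1', '4', '4', '1', '1', '2']  # 1 book worth 4, 4 hats worth 1 each, 1 ball worth 2
print(score(ctx, ['item0=1', 'item1=2', 'item2=0']))  # 4*1 + 1*2 + 2*0 = 6
```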
class tatk.e2e.rnn_rollout.domain.ObjectTradeDomain(max_items=1)
Bases: tatk.e2e.rnn_rollout.domain.ObjectDivisionDomain
- __init__(max_items=1)
  Initialize self. See help(type(self)) for accurate signature.
- generate_choices(input)
- input_length()
- parse_human_choice(input, output)
- score(context, choice)
- score_choices(choices, ctxs)
- selection_length()
tatk.e2e.rnn_rollout.domain.get_domain(name)
tatk.e2e.rnn_rollout.eval_selfplay module
Script to evaluate self-play. It computes the agreement rate, the average score and Pareto optimality.
tatk.e2e.rnn_rollout.eval_selfplay.compute_score(vals, picks)
  Compute the score of the selection.
tatk.e2e.rnn_rollout.eval_selfplay.gen_choices(cnts, idx=0, choice=[])
  Generate all valid choices, for both you and your opponent.
tatk.e2e.rnn_rollout.eval_selfplay.main()
tatk.e2e.rnn_rollout.eval_selfplay.parse_line(line, domain)
tatk.e2e.rnn_rollout.eval_selfplay.parse_log(file_name, domain)
  Parse the log file produced by self-play. See the format of that log file for more details.
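The module description mentions agreement rate, average score and Pareto optimality, and compute_score and gen_choices are documented only briefly. The sketch below assumes values and picks are per-item integer lists and shows one way these pieces could fit together. compute_score is a dot product of values and picks. gen_choices is a recursive enumeration of every (mine, theirs) split. A hypothetical is_pareto_optimal helper checks whether any other split makes one agent better off without making the other worse off. This is an illustration, not the script's implementation; it also uses None rather than a mutable list as the default for choice.

```python
def compute_score(vals, picks):
    """Total value of the picked items: sum of value * count."""
    return sum(v * p for v, p in zip(vals, picks))


def gen_choices(cnts, idx=0, choice=None):
    """Yield every split as (mine, theirs), where theirs = cnts - mine."""
    choice = [] if choice is None else choice
    if idx == len(cnts):
        mine = list(choice)
        yield mine, [c - m for c, m in zip(cnts, mine)]
        return
    # Take 0..cnts[idx] copies of item idx, then recurse on the next item.
    for n in range(cnts[idx] + 1):
        yield from gen_choices(cnts, idx + 1, choice + [n])


def is_pareto_optimal(cnts, my_vals, their_vals, my_picks):
    """True if no split makes one agent better off without hurting the other."""
    their_picks = [c - m for c, m in zip(cnts, my_picks)]
    my_score = compute_score(my_vals, my_picks)
    their_score = compute_score(their_vals, their_picks)
    for mine, theirs in gen_choices(cnts):
        a = compute_score(my_vals, mine)
        b = compute_score(their_vals, theirs)
        if a >= my_score and b >= their_score and (a > my_score or b > their_score):
            return False  # found a split that dominates the actual outcome
    return True


cnts, my_vals, their_vals = [1, 4, 1], [4, 1, 2], [0, 2, 2]
print(is_pareto_optimal(cnts, my_vals, their_vals, [1, 0, 1]))  # True
print(is_pareto_optimal(cnts, my_vals, their_vals, [0, 2, 1]))  # False: giving away the book helps nobody
```

Enumerating every split is exponential in the number of item types, which stays cheap for the three-item setting that CODE2ITEM describes.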
tatk.e2e.rnn_rollout.metric module
class tatk.e2e.rnn_rollout.metric.AverageMetric
Bases: tatk.e2e.rnn_rollout.metric.NumericMetric
- show()
class tatk.e2e.rnn_rollout.metric.MetricsContainer
Bases: object
- __init__()
  Initialize self. See help(type(self)) for accurate signature.
- dict()
- record(name, *args, **kwargs)
- register_average(name, *args, **kwargs)
- register_moving_average(name, *args, **kwargs)
- register_moving_percentage(name, *args, **kwargs)
- register_ngram(name, *args, **kwargs)
- register_percentage(name, *args, **kwargs)
- register_similarity(name, *args, **kwargs)
- register_time(name, *args, **kwargs)
- register_uniqueness(name, *args, **kwargs)
- reset()
- show()
- value(name)
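MetricsContainer exposes register_* methods together with record(name, ...) and value(name), which suggests a registry that forwards calls to named metric objects. Below is a minimal sketch under that assumption. A single hypothetical SimpleAverage metric stands in for AverageMetric, with record(k, n=1) accumulating a sum and a count.

```python
class SimpleAverage:
    """Hypothetical running average: record(k, n=1) adds k spread over n samples."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.k, self.n = 0.0, 0

    def record(self, k, n=1):
        self.k += k
        self.n += n

    def value(self):
        return self.k / self.n if self.n else 0.0


class MetricsContainer:
    """Sketch of a registry that routes record(name, ...) to named metrics."""

    def __init__(self):
        self.metrics = {}

    def register_average(self, name, *args, **kwargs):
        self.metrics[name] = SimpleAverage(*args, **kwargs)

    def record(self, name, *args, **kwargs):
        self.metrics[name].record(*args, **kwargs)

    def reset(self):
        for m in self.metrics.values():
            m.reset()

    def value(self, name):
        return self.metrics[name].value()

    def dict(self):
        return {name: m.value() for name, m in self.metrics.items()}


metrics = MetricsContainer()
metrics.register_average('sent_len')
metrics.record('sent_len', 4)
metrics.record('sent_len', 6)
print(metrics.dict())  # {'sent_len': 5.0}
```

Registering every metric under a name keeps the call sites uniform: a dialogue loop only needs record(name, value), whatever kind of metric sits behind the name.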
class tatk.e2e.rnn_rollout.metric.MovingAverageMetric(window=100)
Bases: tatk.e2e.rnn_rollout.metric.MovingNumericMetric
- show()
class tatk.e2e.rnn_rollout.metric.MovingNumericMetric(window=100)
Bases: object
- __init__(window=100)
  Initialize self. See help(type(self)) for accurate signature.
- record(k)
- reset()
- value()
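MovingNumericMetric(window=100) presumably tracks only the most recent `window` values. Assuming that, a bounded collections.deque gives a compact sketch; MovingAverage is a hypothetical name, not the documented class.

```python
from collections import deque


class MovingAverage:
    """Sketch of a windowed metric: only the last `window` recorded values count."""

    def __init__(self, window=100):
        self.window = window
        self.reset()

    def reset(self):
        # deque(maxlen=...) silently drops the oldest value once it is full.
        self.values = deque(maxlen=self.window)

    def record(self, k):
        self.values.append(k)

    def value(self):
        return sum(self.values) / len(self.values) if self.values else 0.0


m = MovingAverage(window=3)
for k in [1, 2, 3, 4]:
    m.record(k)
print(m.value())  # 3.0, the mean of the last three values (2, 3, 4)
```

A moving percentage in the same style could record 1 or 0 per event and report 100 times value(); that is an assumption about MovingPercentageMetric, not its documented behavior.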
class tatk.e2e.rnn_rollout.metric.MovingPercentageMetric(window=100)
Bases: tatk.e2e.rnn_rollout.metric.MovingNumericMetric
- show()
class tatk.e2e.rnn_rollout.metric.NGramMetric(text, ngram=-1)
Bases: tatk.e2e.rnn_rollout.metric.TextMetric
- __init__(text, ngram=-1)
  Initialize self. See help(type(self)) for accurate signature.
- record(sen)
class tatk.e2e.rnn_rollout.metric.NumericMetric
Bases: object
- __init__()
  Initialize self. See help(type(self)) for accurate signature.
- record(k, n=1)
- reset()
- value()
class tatk.e2e.rnn_rollout.metric.PercentageMetric
Bases: tatk.e2e.rnn_rollout.metric.NumericMetric
- show()
class tatk.e2e.rnn_rollout.metric.SimilarityMetric
Bases: object
- __init__()
  Initialize self. See help(type(self)) for accurate signature.
- record(sen)
- reset()
- show()
- value()
class tatk.e2e.rnn_rollout.metric.TextMetric(text)
Bases: object
- __init__(text)
  Initialize self. See help(type(self)) for accurate signature.
- reset()
- show()
- value()
tatk.e2e.rnn_rollout.reinforce module
tatk.e2e.rnn_rollout.rnn_model module
class tatk.e2e.rnn_rollout.rnn_model.RnnModel(word_dict, item_dict, context_dict, count_dict, args)
Bases: torch.nn.modules.module.Module
- __init__(word_dict, item_dict, context_dict, count_dict, args)
  Initializes internal Module state, shared by both nn.Module and ScriptModule.
- corpus_ty
  alias of tatk.e2e.rnn_rollout.data.WordCorpus
- engine_ty
- flatten_parameters()
- forward(inpt, ctx)
  Defines the computation performed at every call.
  Should be overridden by all subclasses.
  Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- forward_context(ctx)
- forward_lm(inpt_emb, lang_h, ctx_h)
- forward_selection(inpt_emb, lang_h, ctx_h)
- generate_choice_logits(inpt, lang_h, ctx_h)
- init_weights()
- read(inpt, lang_h, ctx_h, prefix_token='THEM:')
- score_sent(sent, lang_h, ctx_h, temperature)
- word2var(word)
- write(lang_h, ctx_h, max_words, temperature, stop_tokens=['<eos>', '<selection>'], resume=False)
  Generate a sentence word by word and feed the output of the previous timestep as input to the next.
- write_batch(bsz, lang_h, ctx_h, temperature, max_words=100)
- zero_h(bsz, nhid=None, copies=None)
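RnnModel.write is described as generating a sentence word by word and feeding the output of the previous timestep back in as the next input, which is ordinary autoregressive sampling. The sketch below shows that loop in plain Python, using temperature-scaled softmax sampling and the documented default stop tokens. The `step` callable is a hypothetical stand-in for one RNN step; the real method works with lang_h/ctx_h hidden states and PyTorch tensors. Starting the turn from a `YOU:` token is also an assumption.

```python
import math
import random


def sample_with_temperature(logits, temperature, rng):
    """Sample an index from softmax(logits / temperature)."""
    scaled = [l / temperature for l in logits]
    m = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


def write(step, vocab, state, max_words, temperature,
          stop_tokens=('<eos>', '<selection>'), rng=random):
    """Autoregressive decoding sketch.

    `step(word, state) -> (logits, new_state)` stands in for one RNN time step;
    each sampled word is fed back in as the input of the next step.
    """
    words, word = [], 'YOU:'  # assumed prefix token that opens our turn
    for _ in range(max_words):
        logits, state = step(word, state)
        word = vocab[sample_with_temperature(logits, temperature, rng)]
        words.append(word)
        if word in stop_tokens:
            break
    return words, state


# Toy "model": strongly prefers the next word of one fixed sentence.
vocab = ['i', 'want', 'the', 'hat', '<eos>']
next_word = {'YOU:': 'i', 'i': 'want', 'want': 'the', 'the': 'hat', 'hat': '<eos>'}


def toy_step(word, state):
    logits = [10.0 if w == next_word.get(word) else 0.0 for w in vocab]
    return logits, state + 1  # here "state" just counts steps


print(write(toy_step, vocab, 0, max_words=20, temperature=0.1, rng=random.Random(0)))
```

Lower temperatures sharpen the distribution toward the highest-scoring word; at temperature 0.1 the toy model reproduces its sentence and stops at `<eos>`.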
tatk.e2e.rnn_rollout.rnnrollout module
class tatk.e2e.rnn_rollout.rnnrollout.RNNRolloutAgent(model, sel_model, args, name='Alice', train=False, diverse=False, max_total_len=100)
Bases: tatk.dialog_agent.agent.Agent
RNN dialog agent with rollout decoding.
- __init__(model, sel_model, args, name='Alice', train=False, diverse=False, max_total_len=100)
  Constructor of the RNNRollout model.
- choose()
- feed_context(context)
- feed_partner_context(partner_context)
- get_reward()
- init_session()
  Reset the class variables to prepare for a new session.
- is_terminated()
- load_model()
- read(inpt)
- response(observation, max_words=20)
  Generate the agent's response to the user input.
  The input and the response can each be either a str or a list of tuples, depending on the form of the agent.
  Example:
    If the agent is a pipeline agent with NLU, DST and Policy, then type(input) == str and type(response) == list of tuples.
  Args:
    observation (str or list of tuples): The input to the agent.
  Returns:
    response (str or list of tuples): The response generated by the agent.
- update(agree, reward, choice=None, partner_choice=None, partner_input=None, partner_reward=None)
- write(max_words=20)
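RNNRolloutAgent is described as using rollout decoding. In general terms, rollout decoding estimates how good each candidate utterance is by simulating the rest of the dialogue several times and averaging the final reward. The agent then says the candidate with the highest estimate. Below is a minimal Monte Carlo sketch of that selection step. The candidate list and the `simulate` callable are hypothetical stand-ins for sampled utterances and model-driven continuations; this is not RNNRolloutAgent's implementation.

```python
import random
import statistics


def choose_by_rollouts(candidates, simulate, num_rollouts=10, rng=random):
    """Pick the candidate utterance with the highest mean simulated reward.

    `simulate(candidate, rng) -> float` plays the rest of the dialogue once
    after `candidate` is said and returns this agent's final score.
    """
    best, best_value = None, float('-inf')
    for cand in candidates:
        # Monte Carlo estimate of the expected final reward after saying `cand`.
        value = statistics.mean(simulate(cand, rng) for _ in range(num_rollouts))
        if value > best_value:
            best, best_value = cand, value
    return best, best_value


def toy_simulate(cand, rng):
    if cand == 'i want everything':
        return rng.random() * 2   # partner usually walks away: low payoff
    return 5 + rng.random()       # modest ask: reliably accepted


print(choose_by_rollouts(['i want everything', 'i want the hats'], toy_simulate,
                         num_rollouts=20, rng=random.Random(0))[0])  # i want the hats
```

More rollouts per candidate reduce the variance of each estimate, at the cost of one simulated dialogue per rollout.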
tatk.e2e.rnn_rollout.split module
tatk.e2e.rnn_rollout.split.conv(line)
tatk.e2e.rnn_rollout.split.dialog_len(line)
tatk.e2e.rnn_rollout.split.find(tokens, tag)
tatk.e2e.rnn_rollout.split.invert(cnts, sel)
tatk.e2e.rnn_rollout.split.main()
tatk.e2e.rnn_rollout.split.select(line)
tatk.e2e.rnn_rollout.utils module
Various helpers.
class tatk.e2e.rnn_rollout.utils.ContextGenerator(context_file)
Bases: object
Dialogue context generator. Generates contexts from the file.
- __init__(context_file)
  Initialize self. See help(type(self)) for accurate signature.
- iter(nepoch=1)
- sample()
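ContextGenerator is documented as generating contexts from a file, with iter(nepoch=1) and sample() methods. Below is a minimal sketch. It assumes the file holds one whitespace-separated context per line and that consecutive lines form the two agents' contexts for one dialogue. The class name ContextGeneratorSketch and the pairing rule are assumptions.

```python
import random


class ContextGeneratorSketch:
    """Read contexts from a file, two lines (one per agent) per dialogue."""

    def __init__(self, context_file):
        self.ctxs = []
        with open(context_file) as f:
            pair = []
            for line in f:
                tokens = line.split()
                if not tokens:
                    continue  # skip blank lines
                pair.append(tokens)
                if len(pair) == 2:  # assumed: consecutive lines form one dialogue
                    self.ctxs.append(pair)
                    pair = []

    def sample(self):
        # One random pair of contexts, e.g. for self-play.
        return random.choice(self.ctxs)

    def iter(self, nepoch=1):
        # Visit every pair once per epoch, in a fresh random order each time.
        for _ in range(nepoch):
            random.shuffle(self.ctxs)
            yield from self.ctxs
```

Note that iter shuffles self.ctxs in place, so the stored order changes after the first epoch.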
class tatk.e2e.rnn_rollout.utils.ManualContextGenerator(num_types=3, num_objects=10, max_score=10)
Bases: object
Dialogue context generator. Takes contexts from stdin.
- __init__(num_types=3, num_objects=10, max_score=10)
  Initialize self. See help(type(self)) for accurate signature.
- sample()
tatk.e2e.rnn_rollout.utils.backward_hook(grad)
  Hook for the backward pass.
tatk.e2e.rnn_rollout.utils.is_selection(out)
  Checks whether the dialogue has ended.
tatk.e2e.rnn_rollout.utils.load_model(file_name, map_location=None)
  Reads a model from a file.
tatk.e2e.rnn_rollout.utils.prob_random()
  Prints out the states of various RNGs.
tatk.e2e.rnn_rollout.utils.save_model(model, file_name)
  Serializes a model to a file.
tatk.e2e.rnn_rollout.utils.set_seed(seed)
  Sets the random seed everywhere.
tatk.e2e.rnn_rollout.utils.use_cuda(enabled, device_id=0)
  Verifies that CUDA is available and sets the default device to device_id.
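set_seed(seed) is documented as setting the random seed everywhere. Below is a minimal sketch of that idea. It assumes "everywhere" means Python's random module, plus NumPy and PyTorch when they are installed; the optional imports are an assumption about the environment.

```python
import random

try:
    import numpy as np
except ImportError:  # NumPy is optional in this sketch
    np = None

try:
    import torch
except ImportError:  # so is PyTorch
    torch = None


def set_seed(seed):
    """Seed every RNG that is available so runs can be reproduced."""
    random.seed(seed)
    if np is not None:
        np.random.seed(seed)
    if torch is not None:
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)


set_seed(42)
first = random.random()
set_seed(42)
assert random.random() == first  # same seed, same sequence
```

Seeding alone does not guarantee bit-identical GPU results, because some CUDA kernels are nondeterministic.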
tatk.e2e.rnn_rollout.vis module
A visualization library. Relies on visdom.
class tatk.e2e.rnn_rollout.vis.ModulePlot(module, plot_weight=False, plot_grad=False, running_n=100)
Bases: object
A helper class that plots norms of weights and gradients for a given module.
- __init__(module, plot_weight=False, plot_grad=False, running_n=100)
  Initialize self. See help(type(self)) for accurate signature.
- update(x)
class tatk.e2e.rnn_rollout.vis.Plot(metrics, title, ylabel, xlabel='t', running_n=100)
Bases: object
A class for plotting and updating the plot in real time.
- __init__(metrics, title, ylabel, xlabel='t', running_n=100)
  Initialize self. See help(type(self)) for accurate signature.
- update(metric, x, y)