convlab2.e2e.rnn_rollout package
Subpackages
- convlab2.e2e.rnn_rollout.deal_or_not package
- convlab2.e2e.rnn_rollout.engines package
- convlab2.e2e.rnn_rollout.models package
  - Submodules
    - convlab2.e2e.rnn_rollout.models.attn module
    - convlab2.e2e.rnn_rollout.models.ctx_encoder module
    - convlab2.e2e.rnn_rollout.models.latent_clustering_model module
    - convlab2.e2e.rnn_rollout.models.modules module
    - convlab2.e2e.rnn_rollout.models.rnn_model module
    - convlab2.e2e.rnn_rollout.models.selection_model module
    - convlab2.e2e.rnn_rollout.models.utils module
  - Module contents
Submodules
convlab2.e2e.rnn_rollout.agent module
- class convlab2.e2e.rnn_rollout.agent.RnnRolloutAgent(model, args, name='Alice', allow_no_agreement=True, train=False, diverse=False, max_dec_len=20)
  Bases: convlab2.dialog_agent.agent.Agent
  - choose()
  - feed_context(context)
  - feed_partner_context(partner_context)
  - init_session()
    Reset the class variables to prepare for a new session.
  - read(inpt)
  - response(observation)
    Generate the agent's response to the given user input.
    The data type of the input and the response can be either str or list of tuples, depending on the form of the agent.
    - Example:
      If the agent is a pipeline agent with NLU, DST and Policy, then type(input) == str and type(response) == list of tuples.
    - Args:
      - observation (str or list of tuples): The input to the agent.
    - Returns:
      - response (str or list of tuples): The response generated by the agent.
  - update(agree, reward, choice=None, partner_choice=None, partner_input=None, partner_reward=None)
  - write(max_words=20)
convlab2.e2e.rnn_rollout.avg_rank module
convlab2.e2e.rnn_rollout.chat module
convlab2.e2e.rnn_rollout.config module
Configuration script. Stores variables and settings used across the application.
convlab2.e2e.rnn_rollout.data module
- class convlab2.e2e.rnn_rollout.data.CountDictionary(init=True)
  Bases: convlab2.e2e.rnn_rollout.data.Dictionary
  - get_idx(words)
  - get_key()
  - read_tag(file_name, tag, init_dict=False)
- class convlab2.e2e.rnn_rollout.data.Dictionary(init=True)
  Bases: object
  - add_word(word)
  - get_idx(word)
  - get_word(idx)
  - i2w(idx)
  - read_tag(file_name, tag, freq_cutoff=-1, init_dict=True)
  - w2i(words)
- class convlab2.e2e.rnn_rollout.data.ItemDictionary(selection_size, init=True)
  Bases: convlab2.e2e.rnn_rollout.data.Dictionary
  - read_tag(file_name, tag, init_dict=False)
  - w2i(words, inv=False)
- class convlab2.e2e.rnn_rollout.data.PhraseCorpus(domain, path, freq_cutoff=2, train='train.txt', valid='val.txt', test='test.txt', verbose=False, sep_sel=False)
  Bases: convlab2.e2e.rnn_rollout.data.WordCorpus
  - tokenize(file_name)
- class convlab2.e2e.rnn_rollout.data.SentenceCorpus(domain, path, freq_cutoff=2, train='train.txt', valid='val.txt', test='test.txt', verbose=False, sep_sel=False)
- class convlab2.e2e.rnn_rollout.data.WordCorpus(domain, path, freq_cutoff=2, train='train.txt', valid='val.txt', test='test.txt', verbose=False, sep_sel=False)
  Bases: object
  - test_dataset(bsz, shuffle=True)
  - tokenize(file_name)
  - train_dataset(bsz, shuffle=True)
  - valid_dataset(bsz, shuffle=True)
- convlab2.e2e.rnn_rollout.data.create_dicts_from_file(domain, file_name, freq_cutoff)
- convlab2.e2e.rnn_rollout.data.get_tag(tokens, tag)
- convlab2.e2e.rnn_rollout.data.read_lines(file_name)
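Dictionary maps words to integer indices and back (add_word, get_idx/w2i, get_word/i2w). The sketch below is a hypothetical, stripped-down stand-in that illustrates this bidirectional mapping; the class name SimpleDictionary, the idea that init=True reserves an "<unk>" entry, and returning the "<unk>" index for unseen words are assumptions, not taken from the convlab2 source.

```python
from collections import OrderedDict


class SimpleDictionary:
    """Hypothetical stand-in for Dictionary: word <-> index mapping."""

    def __init__(self, init=True):
        self.word2idx = OrderedDict()
        self.idx2word = []
        if init:
            # Assumption: init=True reserves an index for unknown words up front.
            self.add_word("<unk>")

    def add_word(self, word):
        """Add the word if it is new; return its index either way."""
        if word not in self.word2idx:
            self.word2idx[word] = len(self.idx2word)
            self.idx2word.append(word)
        return self.word2idx[word]

    def get_idx(self, word):
        """Index of one word, falling back to the <unk> index (None if absent)."""
        return self.word2idx.get(word, self.word2idx.get("<unk>"))

    def get_word(self, idx):
        return self.idx2word[idx]

    def w2i(self, words):
        """Map a list of words to indices."""
        return [self.get_idx(w) for w in words]

    def i2w(self, idx):
        """Map a list of indices back to words."""
        return [self.idx2word[i] for i in idx]


d = SimpleDictionary()
for w in "i want the hat".split():
    d.add_word(w)
print(d.w2i(["i", "want", "ball"]))  # [1, 2, 0]
print(d.i2w([1, 2, 0]))  # ['i', 'want', '<unk>']
```

Keeping both a dict and a list gives O(1) lookups in each direction, which is why word-level vocabularies are commonly stored this way.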
convlab2.e2e.rnn_rollout.dialog module
- class convlab2.e2e.rnn_rollout.dialog.Dialog(agents, args)
  Bases: object
  - run(ctxs, logger, max_words=5000)
  - show_metrics()
- class convlab2.e2e.rnn_rollout.dialog.DialogLogger(verbose=False, log_file=None, append=False)
  Bases: object
  - CODE2ITEM = [('item0', 'book'), ('item1', 'hat'), ('item2', 'ball')]
  - dump(s, forced=False)
  - dump_agreement(agree)
  - dump_choice(name, choice)
  - dump_ctx(name, ctx)
  - dump_reward(name, agree, reward)
  - dump_sent(name, sent)
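Dialog.run(ctxs, logger, max_words) drives two agents through one negotiation. One plausible reading of the agent methods listed for this package is: each agent receives its context, the agents alternate write() and read() until one of them emits a selection, and then each agent is asked to choose(). The sketch below mimics that loop with hypothetical toy agents; the ToyAgent class, run_dialog, the scripted utterances, and the lone "<selection>" token as the end-of-dialog signal are assumptions for illustration, not convlab2's implementation.

```python
# Hypothetical toy agent exposing the same method names as the convlab2 agents.
class ToyAgent:
    def __init__(self, name, script, final_choice):
        self.name = name
        self.script = list(script)  # scripted utterances, spoken in order
        self.final_choice = final_choice
        self.ctx = None
        self.heard = []

    def feed_context(self, ctx):
        self.ctx = ctx

    def write(self):
        # Emit the next scripted utterance as a list of tokens.
        return self.script.pop(0).split()

    def read(self, inpt):
        self.heard.append(inpt)

    def choose(self):
        return self.final_choice


def is_selection(out):
    # Assumed convention: a lone "<selection>" token ends the conversation.
    return len(out) == 1 and out[0] == "<selection>"


def run_dialog(agents, ctxs, max_turns=20):
    """Alternate writer and reader until a selection is made, then collect choices."""
    for agent, ctx in zip(agents, ctxs):
        agent.feed_context(ctx)
    writer, reader = agents
    transcript = []
    for _ in range(max_turns):
        out = writer.write()
        reader.read(out)
        transcript.append((writer.name, " ".join(out)))
        if is_selection(out):
            break
        writer, reader = reader, writer
    return transcript, [agent.choose() for agent in agents]


alice = ToyAgent("Alice", ["i want the book", "<selection>"], ["item0=1", "item1=0", "item2=0"])
bob = ToyAgent("Bob", ["ok , i take the rest"], ["item0=0", "item1=2", "item2=3"])
transcript, choices = run_dialog(
    [alice, bob], [["1", "4", "2", "0", "3", "2"], ["1", "0", "2", "2", "3", "2"]]
)
for speaker, text in transcript:
    print(f"{speaker}: {text}")
```

The max_turns cap plays a role similar to max_words in run(): it guarantees termination even if neither agent ever selects.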
convlab2.e2e.rnn_rollout.domain module
- class convlab2.e2e.rnn_rollout.domain.Domain
  Bases: object
  Domain interface.
  - generate_choices(input)
  - input_length()
  - parse_choice(choice)
  - parse_context(ctx)
  - parse_human_choice(input, output)
  - score(context, choice)
  - score_choices(choices, ctxs)
  - selection_length()
- class convlab2.e2e.rnn_rollout.domain.ObjectDivisionDomain
  Bases: convlab2.e2e.rnn_rollout.domain.Domain
  - generate_choices(input, with_disagreement=True)
  - input_length()
  - num_choices()
  - parse_choice(choice)
  - parse_context(ctx)
  - parse_human_choice(input, output)
  - score(context, choice)
  - score_choices(choices, ctxs)
  - selection_length()
- class convlab2.e2e.rnn_rollout.domain.ObjectTradeDomain(max_items=1)
  Bases: convlab2.e2e.rnn_rollout.domain.ObjectDivisionDomain
  - generate_choices(input)
  - input_length()
  - parse_human_choice(input, output)
  - score(context, choice)
  - score_choices(choices, ctxs)
  - selection_length()
- convlab2.e2e.rnn_rollout.domain.get_domain(name)
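In the object-division domain, parse_context, parse_choice and score work together to value an agreed split. The sketch below assumes a context is a flat list of strings alternating each item's count and this agent's per-item value (item0, item1, item2 being book, hat and ball, as in DialogLogger.CODE2ITEM), that a choice lists one item<i>=<count> token per item for this agent only, and that "<no_agreement>" scores 0. These formats and the standalone function names are assumptions, not the convlab2 code.

```python
def parse_context(ctx):
    """Split a flat [count0, value0, count1, value1, ...] context into counts and values."""
    cnts = [int(n) for n in ctx[0::2]]
    vals = [int(v) for v in ctx[1::2]]
    return cnts, vals


def parse_choice(choice):
    """Parse a token such as 'item1=2' into (item index, count)."""
    name, cnt = choice.split("=")
    return int(name[len("item"):]), int(cnt)


def score(context, choice):
    """Total value of the items this agent takes (assumed: 0 without agreement)."""
    if choice[0] == "<no_agreement>":
        return 0
    cnts, vals = parse_context(context)
    total = 0
    for i, token in enumerate(choice):
        idx, cnt = parse_choice(token)
        assert idx == i and 0 <= cnt <= cnts[i], "malformed choice"
        total += cnt * vals[i]
    return total


ctx = ["1", "4", "2", "1", "3", "1"]  # 1 book worth 4, 2 hats worth 1, 3 balls worth 1
print(score(ctx, ["item0=1", "item1=0", "item2=2"]))  # 1*4 + 0*1 + 2*1 = 6
```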
convlab2.e2e.rnn_rollout.eval_selfplay module
Script to evaluate self-play. It computes the agreement rate, the average score and Pareto optimality.
- convlab2.e2e.rnn_rollout.eval_selfplay.compute_score(vals, picks)
  Compute the score of the selection.
- convlab2.e2e.rnn_rollout.eval_selfplay.gen_choices(cnts, idx=0, choice=[])
  Generate all the valid choices, both yours and your opponent's.
- convlab2.e2e.rnn_rollout.eval_selfplay.main()
- convlab2.e2e.rnn_rollout.eval_selfplay.parse_line(line, domain)
- convlab2.e2e.rnn_rollout.eval_selfplay.parse_log(file_name, domain)
  Parse the log file produced by self-play. See the format of that log file for more details.
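The Pareto-optimality metric needs two ingredients described above: enumerating every valid split (gen_choices) and valuing a split (compute_score). The sketch below reimplements both ideas from their one-line descriptions and adds a hypothetical is_pareto_optimal helper. It assumes a split is a tuple of per-item counts for "me", with the opponent receiving the remainder; the documented signature gen_choices(cnts, idx=0, choice=[]) suggests the original is recursive, whereas itertools.product is swapped in here.

```python
from itertools import product


def compute_score(vals, picks):
    """Score of a selection: sum over item types of value * count taken."""
    return sum(v * p for v, p in zip(vals, picks))


def gen_choices(cnts):
    """Every valid split as (my picks, opponent picks); the opponent gets the rest."""
    splits = []
    for mine in product(*(range(c + 1) for c in cnts)):
        theirs = tuple(c - m for c, m in zip(cnts, mine))
        splits.append((mine, theirs))
    return splits


def is_pareto_optimal(cnts, my_vals, their_vals, mine):
    """True if no other split helps one side without hurting the other (hypothetical helper)."""
    theirs = [c - m for c, m in zip(cnts, mine)]
    a, b = compute_score(my_vals, mine), compute_score(their_vals, theirs)
    for m2, t2 in gen_choices(cnts):
        a2, b2 = compute_score(my_vals, m2), compute_score(their_vals, t2)
        if a2 >= a and b2 >= b and (a2 > a or b2 > b):
            return False
    return True


cnts, my_vals, their_vals = [1, 2, 3], [4, 1, 1], [0, 2, 2]
print(len(gen_choices(cnts)))  # 24 = 2 * 3 * 4 splits
print(is_pareto_optimal(cnts, my_vals, their_vals, (1, 0, 0)))  # True
print(is_pareto_optimal(cnts, my_vals, their_vals, (0, 2, 3)))  # False: taking the book too helps me, costs them nothing
```

Brute-force enumeration is affordable here because the item pool is tiny; the number of splits is the product of (count + 1) over item types.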
convlab2.e2e.rnn_rollout.metric module
- class convlab2.e2e.rnn_rollout.metric.AverageMetric
  Bases: convlab2.e2e.rnn_rollout.metric.NumericMetric
  - show()
- class convlab2.e2e.rnn_rollout.metric.MetricsContainer
  Bases: object
  - dict()
  - record(name, *args, **kwargs)
  - register_average(name, *args, **kwargs)
  - register_moving_average(name, *args, **kwargs)
  - register_moving_percentage(name, *args, **kwargs)
  - register_ngram(name, *args, **kwargs)
  - register_percentage(name, *args, **kwargs)
  - register_similarity(name, *args, **kwargs)
  - register_time(name, *args, **kwargs)
  - register_uniqueness(name, *args, **kwargs)
  - reset()
  - show()
  - value(name)
- class convlab2.e2e.rnn_rollout.metric.MovingAverageMetric(window=100)
  Bases: convlab2.e2e.rnn_rollout.metric.MovingNumericMetric
  - show()
- class convlab2.e2e.rnn_rollout.metric.MovingNumericMetric(window=100)
  Bases: object
  - record(k)
  - reset()
  - value()
- class convlab2.e2e.rnn_rollout.metric.MovingPercentageMetric(window=100)
  Bases: convlab2.e2e.rnn_rollout.metric.MovingNumericMetric
  - show()
- class convlab2.e2e.rnn_rollout.metric.NGramMetric(text, ngram=-1)
  Bases: convlab2.e2e.rnn_rollout.metric.TextMetric
  - record(sen)
- class convlab2.e2e.rnn_rollout.metric.NumericMetric
  Bases: object
  - record(k, n=1)
  - reset()
  - value()
- class convlab2.e2e.rnn_rollout.metric.PercentageMetric
  Bases: convlab2.e2e.rnn_rollout.metric.NumericMetric
  - show()
- class convlab2.e2e.rnn_rollout.metric.SimilarityMetric
  Bases: object
  - record(sen)
  - reset()
  - show()
  - value()
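The numeric metrics come in two families: NumericMetric (and its AverageMetric/PercentageMetric subclasses) accumulates every record() call, while MovingNumericMetric keeps only the last window values (default 100). The sketch below illustrates that difference with a plain running average and a deque-backed moving average. The class names are hypothetical, and reading record(k, n=1) as "add k to a total that covers n observations" is an assumption about the convlab2 semantics.

```python
from collections import deque


class RunningAverage:
    """Average over every recorded value (cf. NumericMetric / AverageMetric)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.k, self.n = 0, 0

    def record(self, k, n=1):
        # Assumed semantics: k is a total accumulated over n observations.
        self.k += k
        self.n += n

    def value(self):
        return self.k / self.n if self.n else 0.0


class MovingAverage:
    """Average over the last `window` values only (cf. MovingAverageMetric)."""

    def __init__(self, window=100):
        self.window = window
        self.reset()

    def reset(self):
        # A bounded deque drops the oldest value automatically once full.
        self.values = deque(maxlen=self.window)

    def record(self, k):
        self.values.append(k)

    def value(self):
        return sum(self.values) / len(self.values) if self.values else 0.0


avg, mov = RunningAverage(), MovingAverage(window=2)
for x in [1, 2, 3, 10]:
    avg.record(x)
    mov.record(x)
print(avg.value(), mov.value())  # 4.0 6.5
```

A moving window reacts quickly to recent changes (useful for tracking reward during RL training), at the cost of a noisier estimate than the all-time average.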
convlab2.e2e.rnn_rollout.reinforce module
convlab2.e2e.rnn_rollout.rnn_model module
- class convlab2.e2e.rnn_rollout.rnn_model.RnnModel(word_dict, item_dict, context_dict, count_dict, args)
  Bases: torch.nn.modules.module.Module
  - corpus_ty
  - engine_ty
    alias of convlab2.e2e.rnn_rollout.engines.rnn_engine.RnnEngine
  - flatten_parameters()
  - forward(inpt, ctx)
    Defines the computation performed at every call. Should be overridden by all subclasses.
    Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
  - forward_context(ctx)
  - forward_lm(inpt_emb, lang_h, ctx_h)
  - forward_selection(inpt_emb, lang_h, ctx_h)
  - generate_choice_logits(inpt, lang_h, ctx_h)
  - init_weights()
  - read(inpt, lang_h, ctx_h, prefix_token='THEM:')
  - score_sent(sent, lang_h, ctx_h, temperature)
  - word2var(word)
  - write(lang_h, ctx_h, max_words, temperature, stop_tokens=['<eos>', '<selection>'], resume=False)
    Generate a sentence word by word and feed the output of the previous timestep as input to the next.
  - write_batch(bsz, lang_h, ctx_h, temperature, max_words=100)
  - zero_h(bsz, nhid=None, copies=None)
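RnnModel.write generates a sentence word by word, feeding each output back as the next input and stopping at a stop token or after max_words. The sketch below shows that autoregressive loop with temperature sampling in plain Python. The step(word, h) callable is a hypothetical stand-in for one RNN timestep, and the "YOU:" prefix token is an assumption (by analogy with read's prefix_token='THEM:'); the real method operates on tensors and hidden states lang_h/ctx_h.

```python
import math
import random


def sample_with_temperature(logits, temperature, rng):
    """Sample an index from softmax(logits / temperature)."""
    scaled = [x / temperature for x in logits]
    top = max(scaled)  # subtract the max before exp for numerical stability
    weights = [math.exp(s - top) for s in scaled]
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


def write(step, vocab, h, max_words, temperature,
          stop_tokens=("<eos>", "<selection>"), prefix="YOU:", rng=None):
    """Autoregressive decoding: each sampled word is fed back as the next input.

    step(word, h) -> (logits, new_h) is a hypothetical stand-in for one RNN timestep.
    """
    rng = rng or random.Random(0)
    words, word = [], prefix
    for _ in range(max_words):
        logits, h = step(word, h)
        word = vocab[sample_with_temperature(logits, temperature, rng)]
        words.append(word)
        if word in stop_tokens:
            break
    return words, h


# Toy "model": a fixed next-word table; the hidden state just counts steps.
vocab = ["i", "want", "hats", "<eos>"]
next_word = {"YOU:": "i", "i": "want", "want": "hats", "hats": "<eos>"}


def toy_step(word, h):
    logits = [10.0 if v == next_word[word] else 0.0 for v in vocab]
    return logits, h + 1


words, h = write(toy_step, vocab, 0, max_words=10, temperature=0.1)
print(words, h)  # ['i', 'want', 'hats', '<eos>'] 4
```

Lower temperatures sharpen the distribution toward greedy decoding (as in the toy run above); higher temperatures flatten it and yield more diverse utterances.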
convlab2.e2e.rnn_rollout.rnnrollout module
- class convlab2.e2e.rnn_rollout.rnnrollout.RNNRolloutAgent(model, sel_model, args, name='Alice', train=False, diverse=False, max_total_len=100)
  Bases: convlab2.dialog_agent.agent.Agent
  RNN dialog agent with rollout decoding.
  - choose()
  - feed_context(context)
  - feed_partner_context(partner_context)
  - get_reward()
  - init_session()
    Reset the class variables to prepare for a new session.
  - is_terminated()
  - load_model()
  - read(inpt)
  - response(observation, max_words=20)
    Generate the agent's response to the given user input.
    The data type of the input and the response can be either str or list of tuples, depending on the form of the agent.
    - Example:
      If the agent is a pipeline agent with NLU, DST and Policy, then type(input) == str and type(response) == list of tuples.
    - Args:
      - observation (str or list of tuples): The input to the agent.
    - Returns:
      - response (str or list of tuples): The response generated by the agent.
  - update(agree, reward, choice=None, partner_choice=None, partner_input=None, partner_reward=None)
  - write(max_words=20)
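RNNRolloutAgent is described as an RNN dialog agent with rollout decoding. The general dialogue-rollout idea from the Deal or No Deal negotiation work is: sample several candidate utterances, simulate the rest of the conversation from each candidate several times, and say the candidate with the highest average final reward. The sketch below shows only that selection logic. generate_candidate, simulate_to_end, num_candidates and num_rollouts are hypothetical stand-ins for the language model, the selection model (sel_model) and the agent's settings; they are not convlab2's API.

```python
import random
from statistics import mean


def choose_by_rollout(generate_candidate, simulate_to_end,
                      num_candidates=5, num_rollouts=10, rng=None):
    """Return the candidate utterance with the highest average rollout reward.

    generate_candidate(rng) -> utterance             (hypothetical stand-in for the RNN)
    simulate_to_end(utterance, rng) -> final reward  (hypothetical: finish the dialogue
                                                      by sampling, then score the deal)
    """
    rng = rng or random.Random(0)
    best, best_score = None, float("-inf")
    for _ in range(num_candidates):
        cand = generate_candidate(rng)
        # Monte Carlo estimate of the expected final reward after saying `cand`.
        score = mean(simulate_to_end(cand, rng) for _ in range(num_rollouts))
        if score > best_score:
            best, best_score = cand, score
    return best, best_score


# Toy usage: three fixed candidates whose rollouts return a noisy reward.
candidates = iter(["i want everything", "i want the book", "you take the hats"])
true_reward = {"i want everything": 2, "i want the book": 6, "you take the hats": 4}

best, score = choose_by_rollout(
    generate_candidate=lambda rng: next(candidates),
    simulate_to_end=lambda u, rng: true_reward[u] + rng.uniform(-1, 1),
    num_candidates=3,
    num_rollouts=20,
)
print(best, round(score, 2))
```

The cost grows as num_candidates * num_rollouts simulated dialogues per turn, which is why such agents typically cap the total decoding length (compare max_total_len in the constructor).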
convlab2.e2e.rnn_rollout.split module
- convlab2.e2e.rnn_rollout.split.conv(line)
- convlab2.e2e.rnn_rollout.split.dialog_len(line)
- convlab2.e2e.rnn_rollout.split.find(tokens, tag)
- convlab2.e2e.rnn_rollout.split.invert(cnts, sel)
- convlab2.e2e.rnn_rollout.split.main()
- convlab2.e2e.rnn_rollout.split.select(line)
convlab2.e2e.rnn_rollout.utils module
Various helpers.
- class convlab2.e2e.rnn_rollout.utils.ContextGenerator(context_file)
  Bases: object
  Dialogue context generator. Generates contexts from a file.
  - iter(nepoch=1)
  - sample()
- class convlab2.e2e.rnn_rollout.utils.ManualContextGenerator(num_types=3, num_objects=10, max_score=10)
  Bases: object
  Dialogue context generator. Reads contexts from stdin.
  - sample()
- convlab2.e2e.rnn_rollout.utils.backward_hook(grad)
  Hook for the backward pass.
- convlab2.e2e.rnn_rollout.utils.is_selection(out)
  Checks whether the dialog has ended, i.e. whether the output is a selection.
- convlab2.e2e.rnn_rollout.utils.load_model(file_name, map_location=None)
  Reads a model from a file.
- convlab2.e2e.rnn_rollout.utils.prob_random()
  Prints out the states of various RNGs.
- convlab2.e2e.rnn_rollout.utils.save_model(model, file_name)
  Serializes a model to a file.
- convlab2.e2e.rnn_rollout.utils.set_seed(seed)
  Sets the random seed everywhere.
- convlab2.e2e.rnn_rollout.utils.use_cuda(enabled, device_id=0)
  Verifies that CUDA is available and sets the default device to device_id.
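ContextGenerator reads dialogue contexts from a file and exposes sample() (one random context) and iter(nepoch) (one pass over all contexts per epoch). The sketch below is a hypothetical stand-in that assumes the file holds one agent's context per line, with consecutive lines paired into the (agent, partner) contexts of one dialogue. That layout, the class name PairedContextGenerator, and accepting an iterable of lines instead of a file path are all assumptions made to keep the example self-contained.

```python
import io
import random


class PairedContextGenerator:
    """Hypothetical stand-in for ContextGenerator (assumed two-lines-per-dialogue layout)."""

    def __init__(self, lines, rng=None):
        self.rng = rng or random.Random(0)
        self.ctxs = []
        pair = []
        for line in lines:
            tokens = line.strip().split()
            if not tokens:
                continue  # skip blank lines
            pair.append(tokens)
            if len(pair) == 2:
                self.ctxs.append(pair)
                pair = []

    def sample(self):
        """Return one random (agent, partner) context pair."""
        return self.rng.choice(self.ctxs)

    def iter(self, nepoch=1):
        """Yield every context pair once per epoch, reshuffled at the start of each epoch."""
        for _ in range(nepoch):
            self.rng.shuffle(self.ctxs)
            yield from self.ctxs


data = io.StringIO("1 4 2 0 3 2\n1 0 2 2 3 2\n2 3 1 2 2 1\n2 1 1 4 2 2\n")
gen = PairedContextGenerator(data)
print(len(list(gen.iter(nepoch=2))))  # 4 (2 dialogues x 2 epochs)
```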
convlab2.e2e.rnn_rollout.vis module
A visualization library. Relies on visdom.