tatk.dst.sumbt package

Submodules

tatk.dst.sumbt.sumbt module

class tatk.dst.sumbt.sumbt.BeliefTracker

Bases: torch.nn.modules.module.Module

__init__()

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(input_ids, input_len, labels, n_gpu=1, target_slot=None)

Defines the computation performed at every call. This overrides torch.nn.Module.forward.

Note

Although the forward pass is defined in this method, call the module instance (e.g. model(input_ids, input_len, labels)) rather than forward() directly: calling the instance runs any registered hooks, whereas calling forward() directly silently skips them.
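The difference between calling a module instance and calling forward() directly holds for any torch.nn.Module. The sketch below uses a hypothetical toy module, Doubler (not part of tatk), together with PyTorch's register_forward_hook to show that the hook fires only when the instance is called.

```python
import torch
from torch import nn


class Doubler(nn.Module):
    """Hypothetical toy module, used only to show hook behavior."""

    def forward(self, x):
        return x * 2


calls = []  # records the output each time the forward hook fires

model = Doubler()
# The hook is called as hook(module, inputs, output) after forward runs.
# list.append returns None, so the hook does not replace the output.
model.register_forward_hook(lambda module, inputs, output: calls.append(output))

x = torch.ones(3)
y_instance = model(x)        # goes through nn.Module.__call__: hooks run
y_direct = model.forward(x)  # bypasses __call__: hooks are skipped

print(len(calls))  # 1
```

Both calls return the same tensor; only the instance call is visible to hooks, which is why the instance form is preferred.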

static init_parameter(module)
init_session(num_labels)
initialize_slot_value_lookup(label_ids, slot_ids)
class tatk.dst.sumbt.sumbt.BertForUtteranceEncoding(config)

Bases: pytorch_pretrained_bert.modeling.BertPreTrainedModel

__init__(config)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

forward(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)

Defines the computation performed at every call. This overrides torch.nn.Module.forward.

Note

Although the forward pass is defined in this method, call the module instance (e.g. model(input_ids, token_type_ids, attention_mask)) rather than forward() directly: calling the instance runs any registered hooks, whereas calling forward() directly silently skips them.
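The three inputs to forward follow the usual BERT convention. The pure-Python sketch below builds them for a sentence pair with a toy vocabulary; build_bert_inputs and the vocabulary are hypothetical, and it is an assumption that tatk prepares its inputs in exactly this layout ([CLS] A [SEP] B [SEP], segment ids 0/1, padding id 0).

```python
from typing import Dict, List, Tuple


def build_bert_inputs(
    tokens_a: List[str],
    tokens_b: List[str],
    vocab: Dict[str, int],
    max_seq_length: int,
) -> Tuple[List[int], List[int], List[int]]:
    """Build (input_ids, token_type_ids, attention_mask) for a sentence pair.

    Layout: [CLS] A [SEP] B [SEP], then padding with id 0.
    This sketch does not truncate; it rejects pairs that are too long.
    """
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
    token_type_ids = [0] * len(tokens)           # segment A, incl. [CLS] and first [SEP]
    tokens += tokens_b + ["[SEP]"]
    token_type_ids += [1] * (len(tokens_b) + 1)  # segment B, incl. final [SEP]
    if len(tokens) > max_seq_length:
        raise ValueError("sequence pair longer than max_seq_length")

    input_ids = [vocab[t] for t in tokens]
    attention_mask = [1] * len(input_ids)        # 1 marks a real token

    padding = max_seq_length - len(input_ids)
    input_ids += [0] * padding                   # pad id 0, as in reference BERT preprocessing
    token_type_ids += [0] * padding
    attention_mask += [0] * padding              # 0 marks padding the encoder should ignore
    return input_ids, token_type_ids, attention_mask


vocab = {"[PAD]": 0, "[CLS]": 1, "[SEP]": 2, "hi": 3, "there": 4, "hello": 5}
ids, types, mask = build_bert_inputs(["hi", "there"], ["hello"], vocab, max_seq_length=8)
```

Wrapping each list as torch.tensor([ids]) gives the (batch_size, sequence_length) LongTensor shape that pytorch_pretrained_bert encoders take.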

class tatk.dst.sumbt.sumbt.DataProcessor

Bases: object

Base class for data converters for sequence classification data sets.

get_dev_examples(data_dir)

Gets a collection of InputExample objects for the dev set.

get_labels()

Gets the list of labels for this data set.

get_train_examples(data_dir)

Gets a collection of InputExample objects for the train set.
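A concrete processor subclasses DataProcessor and returns InputExample objects from each getter. In the sketch below, InputExample and DataProcessor are minimal stand-ins mirroring the documented signatures so the block runs without tatk installed; TsvProcessor, its file names (train.tsv, dev.tsv), the "text<TAB>label" row format, the helpers _read_tsv and _create_examples, and the label set are all hypothetical.

```python
import csv
import os
from typing import List, Optional


class InputExample:
    """Stand-in with the documented InputExample(guid, text_a, text_b, label) fields."""

    def __init__(self, guid: str, text_a: str, text_b: Optional[str] = None,
                 label: Optional[str] = None) -> None:
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.label = label


class DataProcessor:
    """Stand-in for the documented base-class interface."""

    def get_train_examples(self, data_dir: str) -> List[InputExample]:
        raise NotImplementedError()

    def get_dev_examples(self, data_dir: str) -> List[InputExample]:
        raise NotImplementedError()

    def get_labels(self) -> List[str]:
        raise NotImplementedError()


class TsvProcessor(DataProcessor):
    """Hypothetical processor for data_dir/train.tsv and data_dir/dev.tsv."""

    def get_train_examples(self, data_dir: str) -> List[InputExample]:
        rows = self._read_tsv(os.path.join(data_dir, "train.tsv"))
        return self._create_examples(rows, "train")

    def get_dev_examples(self, data_dir: str) -> List[InputExample]:
        rows = self._read_tsv(os.path.join(data_dir, "dev.tsv"))
        return self._create_examples(rows, "dev")

    def get_labels(self) -> List[str]:
        return ["negative", "positive"]  # hypothetical label set

    @staticmethod
    def _read_tsv(path: str) -> List[List[str]]:
        with open(path, newline="", encoding="utf-8") as f:
            return list(csv.reader(f, delimiter="\t"))

    @staticmethod
    def _create_examples(rows: List[List[str]], set_type: str) -> List[InputExample]:
        # guid combines the split name and row index, e.g. "train-0"
        return [InputExample(guid=f"{set_type}-{i}", text_a=text, label=label)
                for i, (text, label) in enumerate(rows)]
```

Keeping file reading separate from example creation lets the row-to-example logic be checked without touching the file system.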

class tatk.dst.sumbt.sumbt.InputExample(guid, text_a, text_b=None, label=None)

Bases: object

A single training/test example for simple sequence classification.

__init__(guid, text_a, text_b=None, label=None)

Constructs an InputExample.

Args:

guid: Unique id for the example.
text_a: string. The untokenized text of the first sequence. For single-sequence tasks, only this sequence must be specified.
text_b: (Optional) string. The untokenized text of the second sequence. Must be specified only for sequence-pair tasks.
label: (Optional) string. The label of the example. This should be specified for train and dev examples, but not for test examples.

class tatk.dst.sumbt.sumbt.InputFeatures(input_ids, input_len, label_id)

Bases: object

A single set of features of data.

__init__(input_ids, input_len, label_id)

Initialize self. See help(type(self)) for accurate signature.

class tatk.dst.sumbt.sumbt.MultiHeadAttention(heads, d_model, dropout=0.1)

Bases: torch.nn.modules.module.Module

__init__(heads, d_model, dropout=0.1)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

attention(q, k, v, d_k, mask=None, dropout=None)
forward(q, k, v, mask=None)

Defines the computation performed at every call. This overrides torch.nn.Module.forward.

Note

Although the forward pass is defined in this method, call the module instance (e.g. attn(q, k, v, mask)) rather than forward() directly: calling the instance runs any registered hooks, whereas calling forward() directly silently skips them.

get_scores()
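The signature attention(q, k, v, d_k, mask=None, dropout=None) matches standard scaled dot-product attention, softmax(q k^T / sqrt(d_k)) v. The standalone function below is a sketch of that technique, not tatk's method; the mask convention (0 marks a blocked position, filled with -1e9 before the softmax) is an assumption, and returning the weights alongside the output is a choice made here for inspection.

```python
import math
from typing import Optional, Tuple

import torch
import torch.nn.functional as F
from torch import nn


def attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, d_k: int,
              mask: Optional[torch.Tensor] = None,
              dropout: Optional[nn.Dropout] = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """Scaled dot-product attention over tensors shaped (..., seq_len, d_k).

    mask must broadcast to the (..., seq_len, seq_len) score matrix;
    positions where mask == 0 receive (numerically) zero weight.
    Returns (attended values, attention weights).
    """
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, -1e9)  # block masked positions
    weights = F.softmax(scores, dim=-1)               # each row sums to 1
    if dropout is not None:
        weights = dropout(weights)
    return torch.matmul(weights, v), weights


torch.manual_seed(0)
q = torch.randn(2, 4, 8)                 # (batch, seq_len, d_k)
out, w = attention(q, q, q, d_k=8)       # self-attention
mask = torch.tensor([1, 1, 0, 0])        # only the first two keys may be attended
out_m, w_m = attention(q, q, q, d_k=8, mask=mask)
```

Dividing by sqrt(d_k) keeps the dot products from growing with the key dimension, which would otherwise push the softmax toward near one-hot outputs.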
class tatk.dst.sumbt.sumbt.Processor

Bases: tatk.dst.sumbt.sumbt.DataProcessor

Processor for the belief tracking dataset (GLUE version).

__init__()

Initialize self. See help(type(self)) for accurate signature.

get_dev_examples(data_dir, accumulation=False)

See base class.

get_labels()

See base class.

get_test_examples(data_dir, accumulation=False)

Gets a collection of InputExample objects for the test set.

get_train_examples(data_dir, accumulation=False)

See base class.

class tatk.dst.sumbt.sumbt.SUMBTTracker

Bases: tatk.dst.dst.DST

__init__()

Initialize self. See help(type(self)) for accurate signature.

add_track(dialog_history)
load_weights()
test()

Model testing entry point.

train()

Model training entry point.

update(action)

Update the internal dialog state variable: update state['user_action'] with the input action.

Args:
action (str or list of tuples):

The type is str when the DST is word-level (such as NBT) and a list of tuples when it is dialog-act (DA) level.

Returns:
new_state (dict):

The updated dialog state, in the same form as the previous state.
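The update contract can be illustrated with a toy class. ToyTracker, its "belief_state" key, and the example dialog-act tuple are hypothetical; the sketch only records the action and returns the state, whereas a real tracker such as SUMBTTracker also updates the belief state.

```python
import copy
from typing import Any, Dict, List, Tuple, Union

# str for word-level input, list of tuples for dialog-act (DA) level input
Action = Union[str, List[Tuple[str, ...]]]


class ToyTracker:
    """Hypothetical tracker illustrating the documented update(action) contract."""

    def __init__(self) -> None:
        self.state: Dict[str, Any] = {"user_action": [], "belief_state": {}}

    def update(self, action: Action) -> Dict[str, Any]:
        self.state["user_action"] = action  # record the latest user action
        # Return a copy so callers cannot mutate the tracker's internal state;
        # it has the same keys (same form) as the previous state.
        return copy.deepcopy(self.state)


tracker = ToyTracker()
da_state = tracker.update([("Inform", "Restaurant", "Food", "italian")])
word_state = tracker.update("i want italian food")
```

Returning a copy is a design choice of this sketch; the documentation only specifies that the returned dict has the same form as the previous state.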

update_batch(batch_action=None)
tatk.dst.sumbt.sumbt.convert_examples_to_features(examples, label_list, max_seq_length, tokenizer, max_turn_length)

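Converts examples (a list of InputExample objects) into model input features, using tokenizer for tokenization and label_list for the label vocabulary; max_seq_length and max_turn_length bound the utterance and dialogue lengths.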
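The parameter names suggest two levels of fixed sizing: max_seq_length tokens per utterance and max_turn_length turns per dialogue. The sketch below illustrates that padding scheme under that assumption only; pad_dialogue is a hypothetical helper, not tatk's implementation, it works on already-tokenized id lists, and it truncates naively (real BERT preprocessing would keep the closing [SEP]).

```python
from typing import List, Tuple


def pad_dialogue(turn_ids: List[List[int]], max_seq_length: int,
                 max_turn_length: int, pad_id: int = 0) -> Tuple[List[List[int]], List[int]]:
    """Hypothetical helper: fix one dialogue to max_turn_length x max_seq_length.

    Each turn is truncated or right-padded to max_seq_length ids; the dialogue is
    truncated or padded with all-padding turns to max_turn_length turns.
    Also returns each turn's real length (0 for padding turns).
    """
    padded, lengths = [], []
    for ids in turn_ids[:max_turn_length]:
        ids = ids[:max_seq_length]
        lengths.append(len(ids))
        padded.append(ids + [pad_id] * (max_seq_length - len(ids)))
    while len(padded) < max_turn_length:
        padded.append([pad_id] * max_seq_length)  # empty turn
        lengths.append(0)
    return padded, lengths


turns, lens = pad_dialogue([[1, 2, 3], [4, 5, 6, 7, 8]], max_seq_length=4, max_turn_length=3)
```

A fixed (turns x tokens) grid per dialogue lets a batch of dialogues be stacked into a single tensor.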