tatk.dst.sumbt package
Subpackages
Submodules
tatk.dst.sumbt.sumbt module
-
class tatk.dst.sumbt.sumbt.BeliefTracker
Bases: torch.nn.modules.module.Module
-
__init__()
Initializes internal Module state, shared by both nn.Module and ScriptModule.
-
forward(input_ids, input_len, labels, n_gpu=1, target_slot=None)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
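The note above concerns hook dispatch. The following is a simplified, pure-Python mimic (not PyTorch's actual implementation; `MiniModule` and `Doubler` are hypothetical names) showing why calling the instance runs registered forward hooks while calling `forward()` directly skips them. Real PyTorch hooks can also modify the output, which this sketch leaves out.

```python
# Simplified, hypothetical mimic of how a module's __call__ wraps forward().
# This is NOT PyTorch's implementation; it only illustrates why calling the
# instance (which runs hooks) differs from calling forward() directly.
from typing import Any, Callable, List

Hook = Callable[["MiniModule", tuple, Any], None]


class MiniModule:
    def __init__(self) -> None:
        # Hooks receive (module, inputs, output), loosely modelled on
        # torch.nn.Module.register_forward_hook.
        self._forward_hooks: List[Hook] = []

    def register_forward_hook(self, hook: Hook) -> None:
        self._forward_hooks.append(hook)

    def forward(self, *inputs: Any) -> Any:
        raise NotImplementedError  # subclasses define the computation

    def __call__(self, *inputs: Any) -> Any:
        output = self.forward(*inputs)
        # Hooks only run on this path, never when forward() is called directly.
        for hook in self._forward_hooks:
            hook(self, inputs, output)
        return output


class Doubler(MiniModule):
    def forward(self, x: int) -> int:
        return 2 * x


calls = []
m = Doubler()
m.register_forward_hook(lambda mod, inp, out: calls.append(out))
m(3)          # runs forward and the hook
m.forward(4)  # runs forward only; the hook is skipped
print(calls)  # → [6]
```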
-
static init_parameter(module)
-
init_session(num_labels)
-
initialize_slot_value_lookup(label_ids, slot_ids)
-
-
class tatk.dst.sumbt.sumbt.BertForUtteranceEncoding(config)
Bases: pytorch_pretrained_bert.modeling.BertPreTrainedModel
-
__init__(config)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
-
forward(input_ids, token_type_ids, attention_mask, output_all_encoded_layers=False)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
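`forward` takes the usual BERT triple of `input_ids`, `token_type_ids` and `attention_mask`. The sketch below shows how such inputs are conventionally laid out for a sentence pair (`[CLS] A [SEP] B [SEP]`, segment ids 0 then 1, mask 1 for real tokens and 0 for padding). `build_bert_inputs` is a hypothetical helper, and whether SUMBT lays out its utterances exactly this way is an assumption; the vocabulary lookup that turns tokens into `input_ids` is omitted.

```python
from typing import List, Optional, Tuple


def build_bert_inputs(
    tokens_a: List[str],
    tokens_b: Optional[List[str]],
    max_len: int,
) -> Tuple[List[str], List[int], List[int]]:
    """Hypothetical helper: lay out one or two token sequences BERT-style.

    Returns padded tokens, token_type_ids (segment ids) and attention_mask.
    A real pipeline would map the tokens through a vocabulary to get input_ids.
    """
    tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
    token_type_ids = [0] * len(tokens)               # segment 0: [CLS] A [SEP]
    if tokens_b:
        tokens += tokens_b + ["[SEP]"]
        token_type_ids += [1] * (len(tokens_b) + 1)  # segment 1: B [SEP]
    if len(tokens) > max_len:
        raise ValueError("sequence longer than max_len; truncate first")
    attention_mask = [1] * len(tokens)               # 1 = real token
    pad = max_len - len(tokens)
    tokens += ["[PAD]"] * pad
    token_type_ids += [0] * pad
    attention_mask += [0] * pad                      # 0 = padding, masked out
    return tokens, token_type_ids, attention_mask


toks, seg, mask = build_bert_inputs(["hi"], ["cheap", "hotel"], max_len=8)
print(toks)  # ['[CLS]', 'hi', '[SEP]', 'cheap', 'hotel', '[SEP]', '[PAD]', '[PAD]']
print(seg)   # [0, 0, 0, 1, 1, 1, 0, 0]
print(mask)  # [1, 1, 1, 1, 1, 1, 0, 0]
```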
-
-
class tatk.dst.sumbt.sumbt.DataProcessor
Bases: object
Base class for data converters for sequence classification data sets.
-
get_labels()
Gets the list of labels for this data set.
-
-
class tatk.dst.sumbt.sumbt.InputExample(guid, text_a, text_b=None, label=None)
Bases: object
A single training/test example for simple sequence classification.
-
__init__(guid, text_a, text_b=None, label=None)
Constructs an InputExample.
- Args:
- guid: Unique id for the example.
- text_a: string. The untokenized text of the first sequence. For single sequence tasks, only this sequence must be specified.
- text_b: (Optional) string. The untokenized text of the second sequence. Must be specified only for sequence pair tasks.
- label: (Optional) string. The label of the example. This should be specified for train and dev examples, but not for test examples.
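The argument rules above can be sketched as a small dataclass. `ExampleSketch` is a hypothetical stand-in that mirrors only the documented fields, not the actual `InputExample` class, and the guid strings and texts are illustrative.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ExampleSketch:
    """Hypothetical stand-in mirroring InputExample's documented fields."""
    guid: str
    text_a: str
    text_b: Optional[str] = None  # only needed for sequence-pair tasks
    label: Optional[str] = None   # set for train/dev, left as None for test

    @property
    def is_pair(self) -> bool:
        return self.text_b is not None


# A labelled single-sequence training example and an unlabelled test example.
train_ex = ExampleSketch(guid="train-0", text_a="i need a cheap hotel", label="cheap")
test_ex = ExampleSketch(guid="test-0", text_a="book it for two nights")
```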
-
-
class tatk.dst.sumbt.sumbt.InputFeatures(input_ids, input_len, label_id)
Bases: object
A single set of features of data.
-
__init__(input_ids, input_len, label_id)
Initialize self. See help(type(self)) for accurate signature.
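`InputFeatures` pairs `input_ids` with an `input_len`. A common pattern, assumed here rather than confirmed by this page, is to pad ids to a fixed length and keep the unpadded length alongside them. `pad_ids` and the pad id `0` are hypothetical choices for illustration.

```python
from typing import List, Tuple


def pad_ids(token_ids: List[int], max_len: int, pad_id: int = 0) -> Tuple[List[int], int]:
    """Hypothetical helper: pad/truncate ids to max_len and return the true length.

    The returned length plays the role an input_len field would: it tells the
    model how many leading positions hold real tokens.
    """
    ids = token_ids[:max_len]          # truncate if too long
    input_len = len(ids)               # length before padding
    ids = ids + [pad_id] * (max_len - input_len)
    return ids, input_len


print(pad_ids([101, 7592, 102], max_len=5))  # ([101, 7592, 102, 0, 0], 3)
```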
-
-
class tatk.dst.sumbt.sumbt.MultiHeadAttention(heads, d_model, dropout=0.1)
Bases: torch.nn.modules.module.Module
-
__init__(heads, d_model, dropout=0.1)
Initializes internal Module state, shared by both nn.Module and ScriptModule.
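The constructor takes `heads` and `d_model`. In the conventional Transformer scheme, which this class is assumed (not confirmed here) to follow, each head works on a slice of size `d_k = d_model // heads`. The hypothetical `split_heads` helper below shows that split for a single vector.

```python
from typing import List


def split_heads(vector: List[float], heads: int) -> List[List[float]]:
    """Split one d_model-sized vector into `heads` chunks of size d_k = d_model // heads."""
    d_model = len(vector)
    if d_model % heads != 0:
        raise ValueError("d_model must be divisible by heads")
    d_k = d_model // heads
    return [vector[h * d_k:(h + 1) * d_k] for h in range(heads)]


print(split_heads([0.1, 0.2, 0.3, 0.4, 0.5, 0.6], heads=3))
# [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
```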
-
attention(q, k, v, d_k, mask=None, dropout=None)
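The signature `attention(q, k, v, d_k, mask=None, dropout=None)` matches standard scaled dot-product attention, softmax(q kᵀ / √d_k) v. Assuming that is what the method computes, here is a single-head, pure-Python sketch. Masked positions get a large negative score (a common convention), and the `dropout` argument is omitted.

```python
import math
from typing import List, Optional

Matrix = List[List[float]]


def attention(q: Matrix, k: Matrix, v: Matrix, d_k: int,
              mask: Optional[List[List[int]]] = None) -> Matrix:
    """Scaled dot-product attention for one head: softmax(q k^T / sqrt(d_k)) v.

    q: n_q x d_k, k: n_k x d_k, v: n_k x d_v.
    mask[i][j] == 0 hides key j from query i.
    """
    out = []
    for i, qi in enumerate(q):
        scores = [sum(a * b for a, b in zip(qi, kj)) / math.sqrt(d_k) for kj in k]
        if mask is not None:
            scores = [s if mask[i][j] else -1e9 for j, s in enumerate(scores)]
        m = max(scores)                      # subtract max for numerical stability
        exps = [math.exp(s - m) for s in scores]
        total = sum(exps)
        weights = [e / total for e in exps]  # softmax over keys
        out.append([sum(w * vj[c] for w, vj in zip(weights, v)) for c in range(len(v[0]))])
    return out


k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0, 2.0], [3.0, 4.0]]
print(attention([[0.0, 0.0]], k, v, d_k=2))                # [[2.0, 3.0]] (equal weights)
print(attention([[1.0, 0.0]], k, v, d_k=2, mask=[[1, 0]]))  # [[1.0, 2.0]] (key 1 masked)
```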
-
forward(q, k, v, mask=None)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
-
get_scores()
-
-
class tatk.dst.sumbt.sumbt.Processor
Bases: tatk.dst.sumbt.sumbt.DataProcessor
Processor for the belief tracking dataset (GLUE version).
-
__init__()
Initialize self. See help(type(self)) for accurate signature.
-
get_dev_examples(data_dir, accumulation=False)
See base class.
-
get_labels()
See base class.
-
get_test_examples(data_dir, accumulation=False)
See base class.
-
get_train_examples(data_dir, accumulation=False)
See base class.
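`Processor` specializes `DataProcessor` with per-split loaders. The sketch below shows that base-class/subclass pattern with a hypothetical `TsvProcessor` that reads `<data_dir>/<split>.tsv` files of `text<TAB>label` rows. The file layout, label set and guid format are illustrative assumptions, and the `accumulation` flag is left out because its meaning is not documented here.

```python
import csv
import os
from typing import List, NamedTuple, Optional


class Example(NamedTuple):  # hypothetical stand-in for InputExample
    guid: str
    text_a: str
    label: Optional[str]


class BaseProcessor:
    """Mirrors the documented DataProcessor interface: subclasses supply labels."""

    def get_labels(self) -> List[str]:
        raise NotImplementedError


class TsvProcessor(BaseProcessor):
    """Hypothetical processor reading <data_dir>/<split>.tsv rows of (text, label)."""

    def get_labels(self) -> List[str]:
        return ["none", "cheap", "expensive"]  # illustrative label set

    def get_train_examples(self, data_dir: str) -> List[Example]:
        return self._read(data_dir, "train")

    def get_dev_examples(self, data_dir: str) -> List[Example]:
        return self._read(data_dir, "dev")

    def _read(self, data_dir: str, split: str) -> List[Example]:
        path = os.path.join(data_dir, f"{split}.tsv")
        with open(path, newline="", encoding="utf-8") as f:
            rows = csv.reader(f, delimiter="\t")
            return [Example(f"{split}-{i}", text, label)
                    for i, (text, label) in enumerate(rows)]
```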
-
-
class tatk.dst.sumbt.sumbt.SUMBTTracker
Bases: tatk.dst.dst.DST
-
__init__()
Initialize self. See help(type(self)) for accurate signature.
-
add_track(dialog_history)
-
load_weights()
-
test()
Model testing entry point.
-
train()
Model training entry point.
-
update(action)
Update the internal dialog state variable: set state['user_action'] to the input action.
- Args:
- action (str or list of tuples):
The type is str when the DST is word-level (such as NBT), and a list of tuples when it is DA-level.
- Returns:
- new_state (dict):
The updated dialog state, in the same form as the previous state.
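The `update` contract above can be sketched with a hypothetical `ToyTracker`. It only stores the action under `state['user_action']` and returns a state of the same form. The `belief_state` key and the `(intent, domain, slot, value)` tuple layout are illustrative assumptions, and a real tracker would also update its belief state from the action. Returning a deep copy is a design choice that keeps callers from mutating the tracker's internal state.

```python
import copy
from typing import Any, Dict, List, Tuple, Union

# str for word-level trackers, list of tuples for dialog-act-level trackers.
Action = Union[str, List[Tuple[str, ...]]]


class ToyTracker:
    """Hypothetical tracker illustrating only the documented update() contract."""

    def __init__(self) -> None:
        self.state: Dict[str, Any] = {"user_action": [], "belief_state": {}}

    def update(self, action: Action) -> Dict[str, Any]:
        # Store the incoming user action; a real tracker would also
        # update belief_state from it.
        self.state["user_action"] = action
        return copy.deepcopy(self.state)  # same form as the previous state


tracker = ToyTracker()
new_state = tracker.update([("Inform", "Hotel", "Price", "cheap")])
```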
-
update_batch(batch_action=None)
-