tatk.dst.sumbt.multiwoz package

Submodules

tatk.dst.sumbt.multiwoz.convert_to_glue_format module

tatk.dst.sumbt.multiwoz.convert_to_glue_format.convert_to_glue_format()
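GLUE-style data files are tab-separated text. The sketch below shows one plausible conversion under stated assumptions. The column layout is assumed, not taken from tatk: dialogue ID, turn index, user utterance, system response, then one column per domain-slot. The helper name `turns_to_glue_tsv` and the `SLOTS` list are hypothetical.

```python
import csv
import io

# Hypothetical slot inventory; a real converter would derive this from the MultiWOZ ontology.
SLOTS = ["hotel-area", "hotel-pricerange", "restaurant-food"]


def turns_to_glue_tsv(dialogue_id, turns, slots=SLOTS):
    """Write one tab-separated row per dialogue turn (assumed layout, hypothetical helper).

    Each turn is a dict with 'user', 'system', and 'state' (slot -> value).
    Slots absent from the state are written as 'none'.
    """
    buf = io.StringIO()
    writer = csv.writer(buf, delimiter="\t", lineterminator="\n")
    writer.writerow(["# Dialogue ID", "Turn Index", "User Utterance", "System Response"] + slots)
    for idx, turn in enumerate(turns):
        values = [turn["state"].get(slot, "none") for slot in slots]
        writer.writerow([dialogue_id, idx, turn["user"], turn["system"]] + values)
    return buf.getvalue()
```

Keeping one column per slot lets a classifier-style loader read every slot label for a turn from a single row.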

tatk.dst.sumbt.multiwoz.sumbt module

class tatk.dst.sumbt.multiwoz.sumbt.MultiWozSUMBT

Bases: tatk.dst.sumbt.sumbt.SUMBTTracker

__init__()

Initialize the tracker. Takes no arguments.

test()

Model testing entry point

train(load_model=False, start_epoch=0, start_step=0)

Model training entry point

tatk.dst.sumbt.multiwoz.sumbt.eval_all_accs(pred_slot, labels, accuracies)
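Dialogue state trackers are usually scored with joint goal accuracy (a turn counts only if every slot is predicted correctly) and slot accuracy (each slot prediction counts separately). The sketch below computes both on plain lists. It assumes padded turns are marked with label `-1`, and it ignores the `accuracies` accumulator that the real function takes. The name `joint_and_slot_accuracy` is hypothetical.

```python
def joint_and_slot_accuracy(pred, labels, pad=-1):
    """Return (joint_acc, slot_acc) over non-padded turns (hypothetical helper).

    pred, labels: lists of turns; each turn is a list of per-slot label ids.
    A turn whose first label equals `pad` is treated as padding and skipped.
    """
    joint_correct = slot_correct = num_turns = num_slots = 0
    for p_turn, l_turn in zip(pred, labels):
        if l_turn[0] == pad:
            continue  # padded turn, not a real dialogue turn
        hits = [p == l for p, l in zip(p_turn, l_turn)]
        num_turns += 1
        num_slots += len(hits)
        slot_correct += sum(hits)
        joint_correct += all(hits)  # counts only if every slot in the turn matches
    if num_turns == 0:
        return 0.0, 0.0
    return joint_correct / num_turns, slot_correct / num_slots
```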

tatk.dst.sumbt.multiwoz.sumbt.get_label_embedding(labels, max_seq_length, tokenizer, device)
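Judging from its parameters, this function turns each slot-value label into a fixed-length sequence of token ids for a BERT-style encoder. The sketch below assumes BERT conventions: `[CLS]`/`[SEP]` wrapping and zero padding up to `max_seq_length`. It returns plain lists and omits the move to `device`; with PyTorch one would wrap them in `torch.tensor(..., dtype=torch.long).to(device)`. `ToyTokenizer` and `encode_labels` are hypothetical stand-ins.

```python
class ToyTokenizer:
    """Hypothetical stand-in for a BERT tokenizer: whitespace split plus a fixed vocabulary."""

    def __init__(self, vocab):
        self.vocab = vocab

    def tokenize(self, text):
        return text.lower().split()

    def convert_tokens_to_ids(self, tokens):
        return [self.vocab.get(tok, self.vocab["[UNK]"]) for tok in tokens]


def encode_labels(labels, max_seq_length, tokenizer, pad_id=0):
    """Return (padded token-id lists, unpadded lengths), one entry per label string."""
    all_ids, all_lens = [], []
    for label in labels:
        tokens = ["[CLS]"] + tokenizer.tokenize(label) + ["[SEP]"]
        ids = tokenizer.convert_tokens_to_ids(tokens)
        if len(ids) > max_seq_length:
            raise ValueError(f"label {label!r} needs {len(ids)} tokens > {max_seq_length}")
        all_lens.append(len(ids))  # true length, so the encoder can mask the padding
        all_ids.append(ids + [pad_id] * (max_seq_length - len(ids)))
    return all_ids, all_lens
```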

tatk.dst.sumbt.multiwoz.sumbt.warmup_linear(x, warmup=0.002)
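The name and the default `warmup=0.002` suggest the linear warmup schedule used in early BERT fine-tuning code. There, `x` is the fraction of training completed (step / total steps) and the return value scales the base learning rate. The sketch below assumes that reading, including the `1 - x` decay after warmup. `warmup_linear_sketch` and `lr_at_step` are hypothetical names.

```python
def warmup_linear_sketch(x, warmup=0.002):
    """Learning-rate multiplier at training progress x in [0, 1] (assumed schedule).

    Rises linearly from 0 over the first `warmup` fraction of training,
    then decays linearly as 1 - x.
    """
    if x < warmup:
        return x / warmup
    return 1.0 - x


def lr_at_step(step, total_steps, base_lr, warmup=0.002):
    """Hypothetical helper: effective learning rate at a given optimizer step."""
    return base_lr * warmup_linear_sketch(step / total_steps, warmup)
```

With `warmup=0.1` and 1000 total steps, the rate climbs over the first 100 steps and then shrinks toward zero at the last step.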