tatk.policy.mle.crosswoz package

Submodules

tatk.policy.mle.crosswoz.evaluate module

tatk.policy.mle.crosswoz.evaluate.calculateF1(predict_golden)
tatk.policy.mle.crosswoz.evaluate.da_evaluate_simulation(policy)
tatk.policy.mle.crosswoz.evaluate.end2end_evaluate_simulation(policy)
tatk.policy.mle.crosswoz.evaluate.evaluate_corpus_f1(policy, data, goal_type=None)
tatk.policy.mle.crosswoz.evaluate.read_zipped_json(filepath, filename)

tatk.policy.mle.crosswoz.loader module

class tatk.policy.mle.crosswoz.loader.Dataset(s_s, a_s)

Bases: torch.utils.data.dataset.Dataset

__init__(s_s, a_s)

Initialize self. See help(type(self)) for accurate signature.

class tatk.policy.mle.crosswoz.loader.PolicyDataLoaderCrossWoz

Bases: object

__init__()

Initialize self. See help(type(self)) for accurate signature.

create_dataset(part, batchsz)

tatk.policy.mle.crosswoz.mle module

class tatk.policy.mle.crosswoz.mle.MLE(archive_file='/home/travis/build/thu-coai/tatk/tatk/policy/mle/crosswoz/models/mle_policy_crosswoz.zip', model_file='https://tatk-data.s3-ap-northeast-1.amazonaws.com/mle_policy_multiwoz.zip')

Bases: tatk.policy.mle.mle.MLEAbstract

__init__(archive_file='/home/travis/build/thu-coai/tatk/tatk/policy/mle/crosswoz/models/mle_policy_crosswoz.zip', model_file='https://tatk-data.s3-ap-northeast-1.amazonaws.com/mle_policy_multiwoz.zip')

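Initialize self. See help(type(self)) for accurate signature.

Note: the default model_file URL points to mle_policy_multiwoz.zip (the MultiWOZ model), while archive_file names mle_policy_crosswoz.zip. Confirm that this default is intended for the CrossWOZ policy. The archive_file default is an absolute path from the documentation build machine (/home/travis/build/...), so in a local installation it will resolve to a different location.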

tatk.policy.mle.crosswoz.train module

class tatk.policy.mle.crosswoz.train.MLE_Trainer(manager, cfg)

Bases: object

__init__(manager, cfg)

Initialize self. See help(type(self)) for accurate signature.

imit_test(epoch, best)

Provide an unbiased evaluation of the policy fit on the training dataset.

imitating(epoch)

Pretrain the policy by simple imitation learning (behavioral cloning).

load(filename='save/best')
policy_loop(data)
save(directory, epoch)
test()