tatk.e2e.rnn_rollout.engines package
Submodules
tatk.e2e.rnn_rollout.engines.engine module
Training utilities.
- class tatk.e2e.rnn_rollout.engines.engine.Criterion(dictionary, device_id=None, bad_toks=[], reduction='mean')
  Bases: object
  Weighted CrossEntropyLoss.
  - __init__(dictionary, device_id=None, bad_toks=[], reduction='mean')
    Initialize self. See help(type(self)) for accurate signature.
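To illustrate what a weighted cross-entropy criterion with `bad_toks` might compute, here is a minimal pure-Python sketch. It assumes (the documentation above does not confirm this) that tokens listed in `bad_toks` receive weight 0 and all other tokens weight 1, and that `reduction='mean'` divides by the summed weights of the targets, which is how PyTorch's `CrossEntropyLoss` normalises a weighted mean. The function name `weighted_cross_entropy` and the toy vocabulary are hypothetical.

```python
import math


def weighted_cross_entropy(logits, targets, weights, reduction="mean"):
    """Weighted cross-entropy over a batch of logit rows.

    logits:  list of rows, one list of floats per example
    targets: list of target class indices, one per row
    weights: per-class weights (0.0 masks a class out of the loss)
    """
    total_loss = 0.0
    total_weight = 0.0
    for row, target in zip(logits, targets):
        # Numerically stable log-sum-exp: subtract the row maximum first.
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        nll = log_z - row[target]  # negative log-likelihood of the target
        total_loss += weights[target] * nll
        total_weight += weights[target]
    if reduction == "sum":
        return total_loss
    if reduction == "mean":
        # Weighted mean; this sketch returns 0.0 if every target is masked.
        return total_loss / total_weight if total_weight > 0 else 0.0
    raise ValueError(f"unsupported reduction: {reduction}")


# Hypothetical toy vocabulary; '<pad>' plays the role of a bad token.
vocab = ["<pad>", "hello", "world"]
bad_toks = ["<pad>"]
weights = [0.0 if tok in bad_toks else 1.0 for tok in vocab]

logits = [[0.0, 2.0, 0.0], [5.0, 0.0, 0.0]]
targets = [1, 0]  # the second target is '<pad>' and contributes nothing
loss = weighted_cross_entropy(logits, targets, weights)
```

With these inputs only the first example contributes to the loss, because the second target carries zero weight.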
- class tatk.e2e.rnn_rollout.engines.engine.Engine(model, args, device_id=None, verbose=False)
  Bases: object
  The training engine. Performs training and evaluation.
  - __init__(model, args, device_id=None, verbose=False)
    Initialize self. See help(type(self)) for accurate signature.
  - forward(batch, requires_grad=False)
    A helper function to perform a forward pass on a batch.
  - get_model()
    Extracts the model.
  - iter(N, epoch, lr, traindata, validdata)
    Performs one iteration of the training. Runs one epoch on the training and validation datasets.
  - train(corpus)
    Entry point.
  - train_pass(N, trainset)
    Training pass.
  - train_single(N, trainset)
    A helper function to train on a random batch.
  - valid_pass(N, validset, validset_stats)
    Validation pass.
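The method names above suggest a layered control flow: `train` drives the epochs, `iter` runs one epoch, and `train_pass` / `valid_pass` loop over batches. The sketch below mirrors only that call structure with a hypothetical `ToyEngine`. Its signatures are simplified, the scalar "model" and the squared-error loss are invented for illustration, and none of it is taken from the actual `Engine` implementation.

```python
class ToyEngine:
    """Hypothetical stand-in that mimics the train -> iter -> pass layering."""

    def __init__(self, weight=0.0, lr=0.1, max_epoch=3):
        self.weight = weight        # a single scalar plays the role of the model
        self.lr = lr
        self.max_epoch = max_epoch

    def forward(self, batch):
        # Mean squared error of predicting y = weight * x, plus its gradient.
        n = len(batch)
        loss = sum((self.weight * x - y) ** 2 for x, y in batch) / n
        grad = sum(2 * (self.weight * x - y) * x for x, y in batch) / n
        return loss, grad

    def train_pass(self, trainset, lr):
        total = 0.0
        for batch in trainset:
            loss, grad = self.forward(batch)
            self.weight -= lr * grad  # plain SGD step
            total += loss
        return total / len(trainset)

    def valid_pass(self, validset):
        # Validation only measures the loss; no parameter updates.
        return sum(self.forward(b)[0] for b in validset) / len(validset)

    def iter(self, epoch, lr, traindata, validdata):
        # One epoch: a full training pass followed by a validation pass.
        train_loss = self.train_pass(traindata, lr)
        valid_loss = self.valid_pass(validdata)
        return train_loss, valid_loss

    def train(self, traindata, validdata):
        # Entry point: run every epoch and collect (train, valid) losses.
        history = []
        for epoch in range(1, self.max_epoch + 1):
            history.append(self.iter(epoch, self.lr, traindata, validdata))
        return history


# Toy data for y = 2x, split into batches of (x, y) pairs.
train_batches = [[(1.0, 2.0), (2.0, 4.0)], [(3.0, 6.0)]]
valid_batches = [[(4.0, 8.0)]]
engine = ToyEngine()
history = engine.train(train_batches, valid_batches)
```

Keeping the per-epoch logic in `iter` separate from the outer loop in `train` makes it easy to add scheduling or early stopping around the epoch call without touching the batch loops.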
tatk.e2e.rnn_rollout.engines.latent_clustering_engine module
- class tatk.e2e.rnn_rollout.engines.latent_clustering_engine.BaselineClusteringEngine(model, args, verbose=False)
  Bases: tatk.e2e.rnn_rollout.engines.EngineBase
  - __init__(model, args, verbose=False)
    Initialize self. See help(type(self)) for accurate signature.
  - train_batch(batch)
  - train_pass(trainset)
  - valid_batch(batch)
  - valid_pass(validset, validset_stats)
- class tatk.e2e.rnn_rollout.engines.latent_clustering_engine.LatentClusteringEngine(model, args, verbose=False)
  Bases: tatk.e2e.rnn_rollout.engines.EngineBase
  - __init__(model, args, verbose=False)
    Initialize self. See help(type(self)) for accurate signature.
  - combine_loss(lang_loss, select_loss)
  - train_batch(batch)
  - train_pass(trainset)
  - valid_batch(batch)
  - valid_pass(validset, validset_stats)
- class tatk.e2e.rnn_rollout.engines.latent_clustering_engine.LatentClusteringLanguageEngine(model, args, verbose=False)
  Bases: tatk.e2e.rnn_rollout.engines.EngineBase
  - __init__(model, args, verbose=False)
    Initialize self. See help(type(self)) for accurate signature.
  - train_batch(batch)
  - valid_batch(batch)
- class tatk.e2e.rnn_rollout.engines.latent_clustering_engine.LatentClusteringPredictionEngine(model, args, verbose=False)
  Bases: tatk.e2e.rnn_rollout.engines.EngineBase
  - __init__(model, args, verbose=False)
    Initialize self. See help(type(self)) for accurate signature.
  - train_batch(batch)
  - train_pass(trainset)
  - valid_batch(batch)
  - valid_pass(validset, validset_stats)