tatk.e2e.rnn_rollout.engines package

Submodules

tatk.e2e.rnn_rollout.engines.engine module

Training utilities.

class tatk.e2e.rnn_rollout.engines.engine.Criterion(dictionary, device_id=None, bad_toks=[], reduction='mean')

Bases: object

Weighted CrossEntropyLoss.

__init__(dictionary, device_id=None, bad_toks=[], reduction='mean')

Initialize self. See help(type(self)) for accurate signature.

class tatk.e2e.rnn_rollout.engines.engine.Engine(model, args, device_id=None, verbose=False)

Bases: object

The training engine.

Performs training and evaluation.

__init__(model, args, device_id=None, verbose=False)

Initialize self. See help(type(self)) for accurate signature.

forward(batch, requires_grad=False)

A helper function to perform a forward pass on a batch.

get_model()

Extracts the model.

iter(N, epoch, lr, traindata, validdata)

Performs one iteration of the training. Runs one epoch on the training and validation datasets.

train(corpus)

Entry point for training on the given corpus.

train_pass(N, trainset)

Training pass.

train_single(N, trainset)

A helper function to train on a random batch.

valid_pass(N, validset, validset_stats)

Validation pass.

tatk.e2e.rnn_rollout.engines.latent_clustering_engine module

class tatk.e2e.rnn_rollout.engines.latent_clustering_engine.BaselineClusteringEngine(model, args, verbose=False)

Bases: tatk.e2e.rnn_rollout.engines.EngineBase

__init__(model, args, verbose=False)

Initialize self. See help(type(self)) for accurate signature.

train_batch(batch)
train_pass(trainset)
valid_batch(batch)
valid_pass(validset, validset_stats)
class tatk.e2e.rnn_rollout.engines.latent_clustering_engine.LatentClusteringEngine(model, args, verbose=False)

Bases: tatk.e2e.rnn_rollout.engines.EngineBase

__init__(model, args, verbose=False)

Initialize self. See help(type(self)) for accurate signature.

combine_loss(lang_loss, select_loss)
train_batch(batch)
train_pass(trainset)
valid_batch(batch)
valid_pass(validset, validset_stats)
class tatk.e2e.rnn_rollout.engines.latent_clustering_engine.LatentClusteringLanguageEngine(model, args, verbose=False)

Bases: tatk.e2e.rnn_rollout.engines.EngineBase

__init__(model, args, verbose=False)

Initialize self. See help(type(self)) for accurate signature.

train_batch(batch)
valid_batch(batch)
class tatk.e2e.rnn_rollout.engines.latent_clustering_engine.LatentClusteringPredictionEngine(model, args, verbose=False)

Bases: tatk.e2e.rnn_rollout.engines.EngineBase

__init__(model, args, verbose=False)

Initialize self. See help(type(self)) for accurate signature.

train_batch(batch)
train_pass(trainset)
valid_batch(batch)
valid_pass(validset, validset_stats)

tatk.e2e.rnn_rollout.engines.rnn_engine module

class tatk.e2e.rnn_rollout.engines.rnn_engine.RnnEngine(model, args, verbose=False)

Bases: tatk.e2e.rnn_rollout.engines.EngineBase

__init__(model, args, verbose=False)

Initialize self. See help(type(self)) for accurate signature.

train_batch(batch)
valid_batch(batch)

tatk.e2e.rnn_rollout.engines.selection_engine module

class tatk.e2e.rnn_rollout.engines.selection_engine.SelectionEngine(model, args, verbose=False)

Bases: tatk.e2e.rnn_rollout.engines.EngineBase

__init__(model, args, verbose=False)

Initialize self. See help(type(self)) for accurate signature.

train_batch(batch)
valid_batch(batch)