convlab2.e2e.rnn_rollout.engines package
Submodules
convlab2.e2e.rnn_rollout.engines.engine module
Training utilities.

class convlab2.e2e.rnn_rollout.engines.engine.Criterion(dictionary, device_id=None, bad_toks=[], reduction='mean')
Bases: object
Weighted CrossEntropyLoss.
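Criterion is documented only as a weighted CrossEntropyLoss. The pure-Python sketch below shows what such a loss computes, under assumptions that this page does not state: every token listed in bad_toks gets weight 0 and every other token weight 1, and reduction='mean' divides by the total weight of the targets (the torch.nn.CrossEntropyLoss convention). The function name weighted_cross_entropy and its list-based inputs are hypothetical.

```python
import math
from typing import Sequence, Set


def weighted_cross_entropy(
    logits: Sequence[Sequence[float]],
    targets: Sequence[int],
    vocab_size: int,
    bad_tok_ids: Set[int],
    reduction: str = "mean",
) -> float:
    """Hypothetical sketch of a weighted cross-entropy loss.

    Assumptions: tokens in bad_tok_ids have weight 0, all others weight 1;
    'mean' divides the weighted loss sum by the sum of target weights.
    """
    # Per-class weights: 0 for ignored ("bad") tokens, 1 otherwise.
    weights = [0.0 if i in bad_tok_ids else 1.0 for i in range(vocab_size)]

    total_loss = 0.0
    total_weight = 0.0
    for row, target in zip(logits, targets):
        # Numerically stable log-softmax: -log p(t) = logsumexp(x) - x_t.
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in row))
        nll = log_sum_exp - row[target]
        w = weights[target]
        total_loss += w * nll
        total_weight += w

    if reduction == "sum":
        return total_loss
    if reduction == "mean":
        # Guard against a batch whose targets are all ignored (a choice made for this sketch).
        return total_loss / total_weight if total_weight > 0 else 0.0
    raise ValueError(f"unsupported reduction: {reduction!r}")
```

For example, with uniform logits over a 3-token vocabulary at the first position and a second position whose target is a "bad" token, only the first position counts: weighted_cross_entropy([[0.0, 0.0, 0.0], [1.0, 2.0, 3.0]], [0, 2], 3, {2}) equals math.log(3).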
class convlab2.e2e.rnn_rollout.engines.engine.Engine(model, args, device_id=None, verbose=False)
Bases: object
The training engine. Performs training and evaluation.

forward(batch, requires_grad=False)
A helper function to perform a forward pass on a batch.

get_model()
Extracts the model.

iter(N, epoch, lr, traindata, validdata)
Performs one iteration of training: runs one epoch on the training and validation datasets.

train(corpus)
Entry point.

train_pass(N, trainset)
Training pass.

train_single(N, trainset)
A helper function to train on a random batch.

valid_pass(N, validset, validset_stats)
Validation pass.
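The Engine entries describe an epoch-level control flow: iter runs one epoch over the training and validation data, train_pass and valid_pass do the per-split work, and train_single trains on a random batch. The sketch below mimics that flow in plain Python. Everything in it is an assumption for illustration: the ToyEngine class, its one-weight "model" (y = w * x fitted with squared error and plain SGD), and the rule that train keeps the copy of the model with the lowest validation loss. The undocumented N argument is omitted.

```python
import copy
import random
from typing import Dict, List, Optional, Tuple

Batch = List[Tuple[float, float]]  # hypothetical batch: (input, target) pairs


class ToyEngine:
    """Hypothetical engine mirroring the train/iter/train_pass/valid_pass layout."""

    def __init__(self, lr: float = 0.05, max_epoch: int = 20, seed: int = 0):
        self.model: Dict[str, float] = {"w": 0.0}  # stand-in for a neural model
        self.lr = lr
        self.max_epoch = max_epoch
        self.rng = random.Random(seed)

    def forward(self, batch: Batch) -> Tuple[float, float]:
        """Return (mean squared error, gradient w.r.t. w) for one batch."""
        w = self.model["w"]
        n = len(batch)
        loss = sum((w * x - y) ** 2 for x, y in batch) / n
        grad = sum(2 * (w * x - y) * x for x, y in batch) / n
        return loss, grad

    def get_model(self) -> Dict[str, float]:
        return self.model

    def train_single(self, trainset: List[Batch]) -> float:
        """Take one SGD step on a randomly chosen batch."""
        loss, grad = self.forward(self.rng.choice(trainset))
        self.model["w"] -= self.lr * grad
        return loss

    def train_pass(self, trainset: List[Batch]) -> float:
        """One epoch over shuffled training batches; returns the mean batch loss."""
        order = list(trainset)
        self.rng.shuffle(order)
        total = 0.0
        for batch in order:
            loss, grad = self.forward(batch)
            self.model["w"] -= self.lr * grad
            total += loss
        return total / len(order)

    def valid_pass(self, validset: List[Batch]) -> float:
        """Evaluate without updating the model; returns the mean batch loss."""
        return sum(self.forward(b)[0] for b in validset) / len(validset)

    def iter(self, epoch: int, lr: float, traindata, validdata) -> Tuple[float, float]:
        """One iteration = one training epoch followed by a validation pass."""
        self.lr = lr
        return self.train_pass(traindata), self.valid_pass(validdata)

    def train(self, traindata, validdata) -> Tuple[Optional[Dict[str, float]], float]:
        """Run all epochs; return a copy of the best model by validation loss."""
        best_model, best_valid = None, float("inf")
        for epoch in range(1, self.max_epoch + 1):
            _, valid_loss = self.iter(epoch, self.lr, traindata, validdata)
            if valid_loss < best_valid:
                best_valid = valid_loss
                best_model = copy.deepcopy(self.model)
        return best_model, best_valid
```

Returning a deep copy keeps the best snapshot fixed even if later epochs keep updating self.model.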
-
convlab2.e2e.rnn_rollout.engines.latent_clustering_engine module¶
-
class
convlab2.e2e.rnn_rollout.engines.latent_clustering_engine.
BaselineClusteringEngine
(model, args, verbose=False)¶ Bases:
convlab2.e2e.rnn_rollout.engines.EngineBase
-
train_batch
(batch)¶
-
train_pass
(trainset)¶
-
valid_batch
(batch)¶
-
valid_pass
(validset, validset_stats)¶
-
-
class
convlab2.e2e.rnn_rollout.engines.latent_clustering_engine.
LatentClusteringEngine
(model, args, verbose=False)¶ Bases:
convlab2.e2e.rnn_rollout.engines.EngineBase
-
combine_loss
(lang_loss, select_loss)¶
-
train_batch
(batch)¶
-
train_pass
(trainset)¶
-
valid_batch
(batch)¶
-
valid_pass
(validset, validset_stats)¶
-
-
class
convlab2.e2e.rnn_rollout.engines.latent_clustering_engine.
LatentClusteringLanguageEngine
(model, args, verbose=False)¶ Bases:
convlab2.e2e.rnn_rollout.engines.EngineBase
-
train_batch
(batch)¶
-
valid_batch
(batch)¶
-
convlab2.e2e.rnn_rollout.engines.rnn_engine module

class convlab2.e2e.rnn_rollout.engines.rnn_engine.RnnEngine(model, args, verbose=False)
Bases: convlab2.e2e.rnn_rollout.engines.EngineBase

train_batch(batch)
valid_batch(batch)
convlab2.e2e.rnn_rollout.engines.selection_engine module

class convlab2.e2e.rnn_rollout.engines.selection_engine.SelectionEngine(model, args, verbose=False)
Bases: convlab2.e2e.rnn_rollout.engines.EngineBase

train_batch(batch)
valid_batch(batch)
Module contents

class convlab2.e2e.rnn_rollout.engines.Criterion(dictionary, device_id=None, bad_toks=[], reduction='mean')
Bases: object
Weighted CrossEntropyLoss.

class convlab2.e2e.rnn_rollout.engines.EngineBase(model, args, verbose=False)
Bases: object
Base class for training engines.

combine_loss(lang_loss, select_loss)
get_model()
iter(epoch, lr, traindata, validdata)
make_opt(lr)
train(corpus)
train_batch(batch)
train_pass(trainset)
valid_batch(batch)
valid_pass(validset, validset_stats)
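EngineBase lists train_batch and valid_batch next to train_pass and valid_pass, while RnnEngine, SelectionEngine and LatentClusteringLanguageEngine document only their own train_batch and valid_batch. One plausible reading is a template-method design: the base class drives each pass and subclasses supply the per-batch logic. The sketch below illustrates that reading only; BaseEngineSketch and MeanPredictorEngine are hypothetical names, and the averaging of batch losses is an assumption, not the convlab2 code.

```python
from abc import ABC, abstractmethod
from typing import Iterable, Sequence


class BaseEngineSketch(ABC):
    """Hypothetical base: drives whole passes, delegates per-batch work."""

    @abstractmethod
    def train_batch(self, batch) -> float:
        """Update the model on one batch and return its loss."""

    @abstractmethod
    def valid_batch(self, batch) -> float:
        """Evaluate one batch without updating the model; return its loss."""

    def train_pass(self, trainset: Iterable) -> float:
        # Shared loop: subclasses never reimplement the iteration itself.
        losses = [self.train_batch(batch) for batch in trainset]
        return sum(losses) / len(losses)

    def valid_pass(self, validset: Iterable) -> float:
        losses = [self.valid_batch(batch) for batch in validset]
        return sum(losses) / len(losses)


class MeanPredictorEngine(BaseEngineSketch):
    """Hypothetical subclass: predicts the running mean; loss is absolute error."""

    def __init__(self) -> None:
        self.total = 0.0
        self.count = 0

    def _predict(self) -> float:
        return self.total / self.count if self.count else 0.0

    def train_batch(self, batch: Sequence[float]) -> float:
        pred = self._predict()
        loss = sum(abs(y - pred) for y in batch) / len(batch)
        # "Training" here just folds the batch into the running mean.
        self.total += sum(batch)
        self.count += len(batch)
        return loss

    def valid_batch(self, batch: Sequence[float]) -> float:
        pred = self._predict()
        return sum(abs(y - pred) for y in batch) / len(batch)
```

With this split, MeanPredictorEngine().train_pass([[1.0, 3.0], [2.0]]) averages a first-batch loss of 2.0 (prediction 0.0) and a second-batch loss of 0.0 (prediction 2.0), giving 1.0, and adding a new engine means writing only the two per-batch methods.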