# Training

## Example Use
Training a model involves three main steps:

- Create a model configuration and instance using `ModelFactory`
- Configure the training parameters using `TrainingConfig`
- Initialize the `ModelTrainer` and start training

See `use-examples/train_model.py` for a full example.
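A minimal sketch of these steps follows. The config path, the `ModelFactory` call, and the dataset construction are illustrative assumptions; `use-examples/train_model.py` shows the actual calls.

```python
from pathlib import Path

from directmultistep.training.config import TrainingConfig
from directmultistep.training.trainer import ModelTrainer

config = TrainingConfig.load(Path("configs/training.yaml"))  # assumed path

# ModelFactory usage and dataset construction are sketched, not exact:
model = ModelFactory(config).create_model()  # hypothetical factory call
train_dataset = RoutesDataset(...)           # placeholder construction
val_dataset = RoutesDataset(...)             # placeholder construction

trainer = ModelTrainer(config)
trainer.train(model, train_dataset, val_dataset)
```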
## Source Code

### `directmultistep.training.config`

#### `TrainingConfig` (dataclass)

Source code in `src/directmultistep/training/config.py`
##### `load(path)` (classmethod)

Load config from YAML file.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `path` | `Path` | Path to config file | *required* |

Returns:

| Type | Description |
| --- | --- |
| `TrainingConfig` | Loaded config object |

Source code in `src/directmultistep/training/config.py`
##### `save(path)`

Save config to YAML file.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `path` | `Path` | Path to save config file | *required* |

Source code in `src/directmultistep/training/config.py`
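A round-trip sketch using only the two methods documented above (the file paths are illustrative):

```python
from pathlib import Path

from directmultistep.training.config import TrainingConfig

cfg = TrainingConfig.load(Path("configs/train.yaml"))  # illustrative path
cfg.save(Path("configs/train_copy.yaml"))              # write a copy to YAML

restored = TrainingConfig.load(Path("configs/train_copy.yaml"))
assert restored == cfg  # dataclasses compare field-by-field
```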
### `directmultistep.training.trainer`

#### `ModelTrainer`

High-level trainer class that orchestrates the training process.

Source code in `src/directmultistep/training/trainer.py`
##### `__init__(config)`

Initialize trainer with configuration.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `TrainingConfig` | Training configuration | *required* |
##### `train(model, train_dataset, val_dataset)`

Train the model.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `model` | `Module` | Model to train | *required* |
| `train_dataset` | `RoutesDataset` | Training dataset | *required* |
| `val_dataset` | `RoutesDataset` | Validation dataset | *required* |

Source code in `src/directmultistep/training/trainer.py`
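In code, `train` pairs with the types from the table above. A hedged sketch (the model and dataset construction are elided placeholders):

```python
from pathlib import Path

from directmultistep.training.config import TrainingConfig
from directmultistep.training.trainer import ModelTrainer

config = TrainingConfig.load(Path("configs/train.yaml"))  # illustrative path
trainer = ModelTrainer(config)

# `model` is an nn.Module; the datasets are RoutesDataset instances.
# Their construction depends on your data pipeline and is not shown here.
trainer.train(model, train_dataset, val_dataset)
```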
### `directmultistep.training.lightning`

#### `LTraining`

Bases: `LightningModule`

A PyTorch Lightning module for training sequence-to-sequence models.

Source code in `src/directmultistep/training/lightning.py`
##### `__init__(pad_idx, mask_idx, lr, batch_size, warmup_steps=4000, decay_steps=24000, decay_factor=0.1, model=None, criterion=None, processed_tokens=0, start_idx=0)`

Initializes the LTraining module.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `pad_idx` | `int` | The index of the padding token. | *required* |
| `mask_idx` | `int` | The index of the mask token. | *required* |
| `lr` | `float` | The initial learning rate. | *required* |
| `batch_size` | `int` | The batch size. | *required* |
| `warmup_steps` | `int` | The number of warmup steps for the learning rate scheduler. | `4000` |
| `decay_steps` | `int` | The number of decay steps for the learning rate scheduler. | `24000` |
| `decay_factor` | `float` | The decay factor for the learning rate scheduler. | `0.1` |
| `model` | `Module \| None` | The sequence-to-sequence model. | `None` |
| `criterion` | `Module \| None` | The loss function. | `None` |
| `processed_tokens` | `int` | The number of tokens processed so far. | `0` |
| `start_idx` | `int` | The index of the start token. | `0` |

Source code in `src/directmultistep/training/lightning.py`
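A hedged construction sketch; the token indices and the model/criterion stand-ins are assumptions, not the project's actual values:

```python
import torch.nn as nn

from directmultistep.training.lightning import LTraining

module = LTraining(
    pad_idx=0,        # illustrative indices; use your tokenizer's values
    mask_idx=1,
    start_idx=2,
    lr=3e-4,
    batch_size=64,
    warmup_steps=4000,    # defaults shown explicitly for clarity
    decay_steps=24000,
    decay_factor=0.1,
    model=my_seq2seq_model,                         # any nn.Module
    criterion=nn.CrossEntropyLoss(ignore_index=0),  # stand-in loss
)
```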
##### `compute_loss_full(batch, batch_idx)`

Computes the loss for the full sequence training.

This method calculates the loss for all tokens in the sequence.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch` | `Tensor` | The input batch tensor. | *required* |
| `batch_idx` | `int` | The index of the batch. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The computed loss tensor. |

Source code in `src/directmultistep/training/lightning.py`
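As a generic illustration of what "loss for all tokens" usually means in a seq2seq setting (teacher forcing with targets shifted by one; a common pattern, not necessarily the project's exact code):

```python
import torch.nn as nn
from torch import Tensor

def full_sequence_loss(model: nn.Module, criterion: nn.Module,
                       src: Tensor, tgt: Tensor) -> Tensor:
    """Generic teacher-forcing loss over every target position."""
    logits = model(src, tgt[:, :-1])   # predict token t+1 from tokens <= t
    targets = tgt[:, 1:]               # targets shifted left by one
    # Flatten [B, L, V] logits and [B, L] targets for the criterion.
    return criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
```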
##### `configure_optimizers()`

Configures the optimizer and learning rate scheduler.

Returns:

| Type | Description |
| --- | --- |
| `list[Optimizer]` | The list of optimizers. |
| `list[dict[str, Any]]` | The list of learning rate scheduler configurations. |

Source code in `src/directmultistep/training/lightning.py`
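A generic sketch of the Lightning pattern this return type implies; the optimizer choice (Adam) and the attribute names are assumptions:

```python
import torch

def configure_optimizers_sketch(module):
    optimizer = torch.optim.Adam(module.parameters(), lr=module.lr)  # assumed Adam
    scheduler = {
        "scheduler": torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=module.warmup_and_cosine_decay(
                module.warmup_steps, module.decay_steps, module.decay_factor
            ),
        ),
        "interval": "step",  # step the LR schedule every optimizer step
    }
    return [optimizer], [scheduler]
```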
##### `log_step_info(loss, mode, prog_bar)`

Logs the loss and other training information.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `loss` | `Tensor` | The loss tensor. | *required* |
| `mode` | `str` | The mode of training (`'train'` or `'val'`). | *required* |
| `prog_bar` | `bool` | Whether to display the loss in the progress bar. | *required* |

Source code in `src/directmultistep/training/lightning.py`
##### `mask_src(src_BC, masking_prob)`

Masks the source sequence with a given probability.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `src_BC` | `Tensor` | The source sequence tensor of shape `[B, C]`. | *required* |
| `masking_prob` | `float` | The probability of masking a token. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The masked source sequence tensor of shape `[B, C]`. |

Source code in `src/directmultistep/training/lightning.py`
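A generic sketch of probabilistic token masking (excluding padding positions is an assumption about the implementation):

```python
import torch
from torch import Tensor

def mask_tokens(src_BC: Tensor, masking_prob: float,
                mask_idx: int, pad_idx: int) -> Tensor:
    """Replace non-padding tokens with mask_idx with probability masking_prob."""
    mask = torch.rand(src_BC.shape, device=src_BC.device) < masking_prob
    mask &= src_BC != pad_idx                  # never mask padding positions
    return src_BC.masked_fill(mask, mask_idx)
```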
##### `on_load_checkpoint(checkpoint)`

Loads the processed tokens from the checkpoint.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `checkpoint` | `dict[str, Any]` | The checkpoint dictionary. | *required* |

Source code in `src/directmultistep/training/lightning.py`
##### `on_save_checkpoint(checkpoint)`

Adds the processed tokens to the checkpoint.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `checkpoint` | `dict[str, Any]` | The checkpoint dictionary. | *required* |

Source code in `src/directmultistep/training/lightning.py`
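Together, these two hooks persist the token counter across restarts. A generic sketch of the pattern (the checkpoint key name is hypothetical):

```python
from typing import Any

class TokenCounterHooks:
    """Illustrative Lightning-style hooks; the key name is an assumption."""
    processed_tokens: int = 0

    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        checkpoint["processed_tokens"] = self.processed_tokens

    def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        self.processed_tokens = checkpoint.get("processed_tokens", 0)
```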
##### `training_step(batch, batch_idx)`

Performs a single training step.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch` | `Tensor` | The input batch tensor. | *required* |
| `batch_idx` | `int` | The index of the batch. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The computed loss tensor. |

Source code in `src/directmultistep/training/lightning.py`
##### `validation_step(batch, batch_idx)`

Performs a single validation step.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `batch` | `Tensor` | The input batch tensor. | *required* |
| `batch_idx` | `int` | The index of the batch. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Tensor` | The computed loss tensor. |

Source code in `src/directmultistep/training/lightning.py`
##### `warmup_and_cosine_decay(warmup_steps, decay_steps, decay_factor)`

Creates a learning rate schedule with warmup and cosine decay.

The learning rate increases linearly during the warmup phase, then decreases following a cosine function during the decay phase, and finally remains constant at the decay factor.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `warmup_steps` | `int` | The number of steps for the warmup phase. | *required* |
| `decay_steps` | `int` | The number of steps for the decay phase. | *required* |
| `decay_factor` | `float` | The final learning rate factor after decay. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `Callable[[int], float]` | A function that takes the current step as input and returns the corresponding learning rate factor. |

Source code in `src/directmultistep/training/lightning.py`
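A sketch of a factor schedule matching that description (linear warmup to 1.0, cosine decay down to `decay_factor`, then constant; the project's exact formula may differ):

```python
import math
from typing import Callable

def warmup_and_cosine_decay(warmup_steps: int, decay_steps: int,
                            decay_factor: float) -> Callable[[int], float]:
    def lr_factor(step: int) -> float:
        if step < warmup_steps:
            return step / max(1, warmup_steps)            # linear warmup to 1.0
        if step < warmup_steps + decay_steps:
            progress = (step - warmup_steps) / decay_steps
            cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
            return decay_factor + (1.0 - decay_factor) * cosine  # 1.0 -> decay_factor
        return decay_factor                               # constant tail
    return lr_factor
```

Used with `LambdaLR`, the returned factor multiplies the base learning rate at each step: `lr(step) = lr * lr_factor(step)`.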