Training

Example Use

Training a model involves three main steps:

  1. Create a model configuration and instance using ModelFactory
  2. Configure the training parameters using TrainingConfig
  3. Initialize the ModelTrainer and start training

See use-examples/train_model.py for a full example; a minimal sketch is shown below.
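
A minimal sketch of these three steps. The TrainingConfig fields and import paths come from this reference; the ModelFactory usage and all file names are illustrative assumptions, so check use-examples/train_model.py for the canonical version.

from pathlib import Path

from directmultistep.training.config import TrainingConfig
from directmultistep.training.trainer import ModelTrainer

# 1. Build the model via ModelFactory (exact import path and call are assumptions):
# from directmultistep import ModelFactory        # hypothetical import path
# model = ModelFactory(...).create_model()        # hypothetical call

# 2. Configure training (values are illustrative).
config = TrainingConfig(
    data_path=Path("data"),
    run_name="demo-run",
    train_fname="train_routes.pkl",
    val_fname="val_routes.pkl",
    metadata_fname="metadata.yaml",
    batch_size=32,
    learning_rate=3e-4,
    max_epochs=10,
    warmup_steps=4000,
    decay_steps=24000,
    decay_factor=0.1,
    pad_idx=0,
    mask_idx=1,
)

# 3. Initialize the trainer and start training
# (train_dataset and val_dataset are RoutesDataset instances you prepare).
trainer = ModelTrainer(config)
# trainer.train(model, train_dataset, val_dataset)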

Source Code

directmultistep.training.config

TrainingConfig dataclass

Source code in src/directmultistep/training/config.py
@dataclass
class TrainingConfig:
    # Data configs
    data_path: Path

    # Training setup
    run_name: str
    train_fname: str
    val_fname: str
    metadata_fname: str

    # Training hyperparameters
    batch_size: int
    learning_rate: float
    max_epochs: int

    # Scheduler configs
    warmup_steps: int
    decay_steps: int
    decay_factor: float

    pad_idx: int
    mask_idx: int

    # Checkpointing
    save_top_k: int = -1
    checkpoint_every_n_epochs: int = 2

    num_workers: int = 1
    n_devices: int = 1
    seed: int = 42

    accelerator: str = "auto"
    matmul_precision: str = "high"
    summary_depth: int = 2
    dist_strategy: str = "ddp_find_unused_parameters_true"

    gradient_clip_val: float = 1.0
    gradient_clip_algorithm: str = "value"

    def __post_init__(self) -> None:
        self.data_path.mkdir(parents=True, exist_ok=True)
        self.run_name = f"{self.run_name}_seed={self.seed}"

        if self.matmul_precision not in ["high", "medium", "low"]:
            raise ValueError(f"{self.matmul_precision=} must be one of 'high', 'medium', or 'low'")

        if self.dist_strategy not in ["auto", "fsdp", "ddp", "ddp_spawn", "ddp_find_unused_parameters_true"]:
            raise ValueError(
                f"{self.dist_strategy=} must be one of 'fsdp', 'ddp', 'ddp_spawn', or 'ddp_find_unused_parameters_true'"
            )

        if self.gradient_clip_algorithm not in ["norm", "value"]:
            raise ValueError(f"{self.gradient_clip_algorithm=} must be one of 'norm' or 'value'")

    def save(self, path: Path) -> None:
        """Save config to YAML file.

        Args:
            path: Path to save config file
        """
        config_dict = asdict(self)
        config_dict["data_path"] = str(config_dict["data_path"])

        with open(path, "w") as f:
            yaml.safe_dump(config_dict, f, default_flow_style=False, sort_keys=False)

    @classmethod
    def load(cls, path: Path) -> "TrainingConfig":
        """Load config from YAML file.

        Args:
            path: Path to config file

        Returns:
            Loaded config object
        """
        with open(path) as f:
            config_dict = yaml.safe_load(f)

        config_dict["data_path"] = Path(config_dict["data_path"])
        instance = cls.__new__(cls)
        for key, value in config_dict.items():
            setattr(instance, key, value)
        return instance

load(path) classmethod

Load config from YAML file.

Parameters:

    path (Path): Path to config file. (required)

Returns:

    TrainingConfig: Loaded config object.

Source code in src/directmultistep/training/config.py
@classmethod
def load(cls, path: Path) -> "TrainingConfig":
    """Load config from YAML file.

    Args:
        path: Path to config file

    Returns:
        Loaded config object
    """
    with open(path) as f:
        config_dict = yaml.safe_load(f)

    config_dict["data_path"] = Path(config_dict["data_path"])
    instance = cls.__new__(cls)
    for key, value in config_dict.items():
        setattr(instance, key, value)
    return instance

save(path)

Save config to YAML file.

Parameters:

    path (Path): Path to save config file. (required)

Source code in src/directmultistep/training/config.py
def save(self, path: Path) -> None:
    """Save config to YAML file.

    Args:
        path: Path to save config file
    """
    config_dict = asdict(self)
    config_dict["data_path"] = str(config_dict["data_path"])

    with open(path, "w") as f:
        yaml.safe_dump(config_dict, f, default_flow_style=False, sort_keys=False)
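
A round-trip sketch of save and load, assuming config is a TrainingConfig built as in the example above and the target path is writable:

config.save(Path("training_config.yaml"))
restored = TrainingConfig.load(Path("training_config.yaml"))
assert restored.data_path == config.data_path
assert restored.run_name == config.run_name

Note that load constructs the instance with cls.__new__, so __post_init__ is not re-run and the seed suffix already stored in run_name is not appended a second time.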

directmultistep.training.trainer

ModelTrainer

High-level trainer class that orchestrates the training process.

Source code in src/directmultistep/training/trainer.py
class ModelTrainer:
    """High-level trainer class that orchestrates the training process."""

    def __init__(self, config: TrainingConfig):
        """Initialize trainer with configuration.

        Args:
            config: Training configuration
        """
        self.config = config
        self._setup_environment()

    def _setup_environment(self) -> None:
        """Configure training environment."""
        L.seed_everything(self.config.seed)
        torch.set_float32_matmul_precision(self.config.matmul_precision)

    def _create_lightning_module(self, model: torch.nn.Module) -> LTraining:
        """Create the Lightning training module.

        Args:
            model: The model to train

        Returns:
            Configured LTraining module
        """
        criterion = torch.nn.CrossEntropyLoss(ignore_index=self.config.pad_idx, reduction="mean")

        return LTraining(
            model=model,
            pad_idx=self.config.pad_idx,
            mask_idx=self.config.mask_idx,
            criterion=criterion,
            lr=self.config.learning_rate,
            batch_size=self.config.batch_size,
            warmup_steps=self.config.warmup_steps,
            decay_steps=self.config.decay_steps,
            decay_factor=self.config.decay_factor,
        )

    def _setup_callbacks(self) -> list[Any]:
        """Configure training callbacks.

        Returns:
            List of Lightning callbacks
        """
        checkpoint_callback = ModelCheckpoint(
            monitor="val_loss",
            dirpath=self.config.data_path / "training" / self.config.run_name,
            save_last=True,
            save_top_k=self.config.save_top_k,
            every_n_epochs=self.config.checkpoint_every_n_epochs,
        )

        return [checkpoint_callback, RichModelSummary(max_depth=self.config.summary_depth)]

    def _create_trainer(self) -> L.Trainer:
        """Create Lightning trainer.

        Returns:
            Configured Lightning trainer
        """
        return L.Trainer(
            default_root_dir=self.config.data_path / "training" / self.config.run_name,
            max_epochs=self.config.max_epochs,
            accelerator=self.config.accelerator,
            devices=self.config.n_devices,
            num_nodes=1,
            strategy=self.config.dist_strategy,
            callbacks=self._setup_callbacks(),
            gradient_clip_val=self.config.gradient_clip_val,
            gradient_clip_algorithm=self.config.gradient_clip_algorithm,
        )

    def _create_dataloaders(
        self,
        train_dataset: RoutesDataset,
        val_dataset: RoutesDataset,
    ) -> tuple[DataLoader[tuple[Tensor, ...]], DataLoader[tuple[Tensor, ...]]]:
        """Create training and validation dataloaders.

        Args:
            train_dataset: Training dataset
            val_dataset: Validation dataset

        Returns:
            Tuple of (train_dataloader, val_dataloader)
        """
        train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset,
            batch_size=self.config.batch_size,
            shuffle=True,
            num_workers=self.config.num_workers,
            persistent_workers=True,
            pin_memory=True,
        )

        val_loader = torch.utils.data.DataLoader(
            dataset=val_dataset,
            batch_size=self.config.batch_size,
            shuffle=False,
            num_workers=self.config.num_workers,
            persistent_workers=True,
            pin_memory=True,
        )

        return train_loader, val_loader

    def train(
        self,
        model: torch.nn.Module,
        train_dataset: RoutesDataset,
        val_dataset: RoutesDataset,
    ) -> None:
        """Train the model.

        Args:
            model: Model to train
            train_dataset: Training dataset
            val_dataset: Validation dataset
        """
        lightning_model = self._create_lightning_module(model)
        trainer = self._create_trainer()
        dl_train, dl_val = self._create_dataloaders(train_dataset, val_dataset)
        latest_ckpt = helpers.find_checkpoint(self.config.data_path / "training", self.config.run_name)

        if latest_ckpt is not None:
            print(f"Loading model from {latest_ckpt}")
            trainer.fit(lightning_model, dl_train, dl_val, ckpt_path=latest_ckpt)
        else:
            trainer.fit(lightning_model, dl_train, dl_val)

__init__(config)

Initialize trainer with configuration.

Parameters:

    config (TrainingConfig): Training configuration. (required)

Source code in src/directmultistep/training/trainer.py
def __init__(self, config: TrainingConfig):
    """Initialize trainer with configuration.

    Args:
        config: Training configuration
    """
    self.config = config
    self._setup_environment()

train(model, train_dataset, val_dataset)

Train the model.

Parameters:

    model (Module): Model to train. (required)
    train_dataset (RoutesDataset): Training dataset. (required)
    val_dataset (RoutesDataset): Validation dataset. (required)

Source code in src/directmultistep/training/trainer.py
def train(
    self,
    model: torch.nn.Module,
    train_dataset: RoutesDataset,
    val_dataset: RoutesDataset,
) -> None:
    """Train the model.

    Args:
        model: Model to train
        train_dataset: Training dataset
        val_dataset: Validation dataset
    """
    lightning_model = self._create_lightning_module(model)
    trainer = self._create_trainer()
    dl_train, dl_val = self._create_dataloaders(train_dataset, val_dataset)
    latest_ckpt = helpers.find_checkpoint(self.config.data_path / "training", self.config.run_name)

    if latest_ckpt is not None:
        print(f"Loading model from {latest_ckpt}")
        trainer.fit(lightning_model, dl_train, dl_val, ckpt_path=latest_ckpt)
    else:
        trainer.fit(lightning_model, dl_train, dl_val)
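
Usage sketch, assuming config, model, and the two RoutesDataset objects already exist:

trainer = ModelTrainer(config)
trainer.train(model, train_dataset, val_dataset)

If a checkpoint already exists under data_path / "training" / run_name, training resumes from it automatically; otherwise it starts from scratch.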

directmultistep.training.lightning

LTraining

Bases: LightningModule

A PyTorch Lightning module for training sequence-to-sequence models.

Source code in src/directmultistep/training/lightning.py
class LTraining(pl.LightningModule):
    """A PyTorch Lightning module for training sequence-to-sequence models."""

    def __init__(
        self,
        pad_idx: int,
        mask_idx: int,
        lr: float,
        batch_size: int,
        warmup_steps: int = 4000,
        decay_steps: int = 24000,
        decay_factor: float = 0.1,
        model: nn.Module | None = None,
        criterion: nn.Module | None = None,
        processed_tokens: int = 0,
        start_idx: int = 0,
    ):
        """Initializes the PLTraining module.

        Args:
            pad_idx: The index of the padding token.
            mask_idx: The index of the mask token.
            lr: The initial learning rate.
            batch_size: The batch size.
            warmup_steps: The number of warmup steps for the learning rate scheduler.
            decay_steps: The number of decay steps for the learning rate scheduler.
            decay_factor: The decay factor for the learning rate scheduler.
            model: The sequence-to-sequence model.
            criterion: The loss function.
            processed_tokens: The number of tokens processed so far.
            start_idx: The index of the start token.
        """
        super().__init__()
        if model is not None:
            self.model = model
        if criterion is not None:
            self.criterion = criterion
        self.start_idx = start_idx
        self.pad_idx = pad_idx
        self.mask_idx = mask_idx
        self.learning_rate = lr
        self.batch_size = batch_size
        self.warmup_steps = warmup_steps
        self.decay_steps = decay_steps
        self.decay_factor = decay_factor
        self.processed_tokens = processed_tokens
        self.save_hyperparameters(ignore=["criterion", "model"])
        self.compute_loss = self.compute_loss_full

    def mask_src(self, src_BC: Tensor, masking_prob: float) -> Tensor:
        """Masks the source sequence with a given probability.

        Args:
            src_BC: The source sequence tensor of shape [B, C].
            masking_prob: The probability of masking a token.

        Returns:
            The masked source sequence tensor of shape [B, C].
        """
        mask_idx_BC = torch.rand(src_BC.shape).to(src_BC.device) < masking_prob
        not_pad_BC = src_BC != self.pad_idx
        final_mask_BC = mask_idx_BC & not_pad_BC
        masked_src_BC = src_BC.clone()
        masked_src_BC[final_mask_BC] = self.mask_idx
        return masked_src_BC

    def compute_loss_full(self, batch: Tensor, batch_idx: int) -> Tensor:
        """Computes the loss for the full sequence training.

        This method calculates the loss for all tokens in the sequence.

        Args:
            batch: The input batch tensor.
            batch_idx: The index of the batch.

        Returns:
            The computed loss tensor.
        """
        src_item_BC = batch[0]
        tgt_item_BL = batch[1].long()
        steps_B1 = batch[2].view(-1, 1)
        masked_src_BC = self.mask_src(src_item_BC, masking_prob=0.05)
        # the output actually is [B, L-1, V] given slicing of tgt_item_BL
        output_BLV = self.model(masked_src_BC, tgt_item_BL[:, :-1], steps_B1)
        output_blV = output_BLV.view(-1, output_BLV.shape[-1])  # [B*(L-1), V]
        tgt_bl = tgt_item_BL[:, 1:].reshape(-1)  # [B*(L-1)]
        loss = self.criterion(output_blV, tgt_bl)
        self.processed_tokens += tgt_item_BL.shape[0] * tgt_item_BL.shape[1]
        return cast(Tensor, loss)

    def log_step_info(self, loss: Tensor, mode: str, prog_bar: bool) -> None:
        """Logs the loss and other training information.

        Args:
            loss: The loss tensor.
            mode: The mode of training ('train' or 'val').
            prog_bar: Whether to display the loss in the progress bar.
        """
        self.log(
            f"{mode}_loss",
            loss,
            batch_size=self.batch_size,
            prog_bar=prog_bar,
            sync_dist=True,
        )
        self.log("processed_tokens", self.processed_tokens, sync_dist=True)
        if mode == "train":
            current_lr = self.trainer.optimizers[0].param_groups[0]["lr"]
            self.log(f"{mode}_lr", current_lr, batch_size=self.batch_size, sync_dist=True)

    def training_step(self, batch: Tensor, batch_idx: int) -> Tensor:
        """Performs a single training step.

        Args:
            batch: The input batch tensor.
            batch_idx: The index of the batch.

        Returns:
            The computed loss tensor.
        """
        loss = self.compute_loss(batch, batch_idx)
        self.log_step_info(loss, "train", prog_bar=True)
        return loss

    def validation_step(self, batch: Tensor, batch_idx: int) -> Tensor:
        """Performs a single validation step.

        Args:
            batch: The input batch tensor.
            batch_idx: The index of the batch.

        Returns:
            The computed loss tensor.
        """
        loss = self.compute_loss(batch, batch_idx)
        self.log_step_info(loss, "val", prog_bar=True)
        return loss

    def configure_optimizers(
        self,
    ) -> tuple[list[torch.optim.Optimizer], list[dict[str, Any]]]:
        """Configures the optimizer and learning rate scheduler.

        Returns:
            A tuple containing the list of optimizers and the list of
            learning rate schedulers.
        """
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
        # return optimizer
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=warmup_and_cosine_decay(
                warmup_steps=self.warmup_steps,
                decay_steps=self.decay_steps,
                decay_factor=self.decay_factor,
            ),
            verbose=False,
        )
        lr_scheduler = {
            "scheduler": scheduler,  # The LR scheduler instance (required)
            "interval": "step",  # The unit of the scheduler's step size
            "frequency": 1,  # The frequency of the scheduler
        }
        return [optimizer], [lr_scheduler]

    def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Adds the processed tokens to the checkpoint.

        Args:
            checkpoint: The checkpoint dictionary.
        """
        # Add processed_tokens to the checkpoint dictionary
        checkpoint["processed_tokens"] = self.processed_tokens

    def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
        """Loads the processed tokens from the checkpoint.

        Args:
            checkpoint: The checkpoint dictionary.
        """
        # Load processed_tokens from the checkpoint dictionary
        self.processed_tokens = checkpoint.get("processed_tokens", 0)

__init__(pad_idx, mask_idx, lr, batch_size, warmup_steps=4000, decay_steps=24000, decay_factor=0.1, model=None, criterion=None, processed_tokens=0, start_idx=0)

Initializes the LTraining module.

Parameters:

    pad_idx (int): The index of the padding token. (required)
    mask_idx (int): The index of the mask token. (required)
    lr (float): The initial learning rate. (required)
    batch_size (int): The batch size. (required)
    warmup_steps (int): The number of warmup steps for the learning rate scheduler. Default: 4000
    decay_steps (int): The number of decay steps for the learning rate scheduler. Default: 24000
    decay_factor (float): The decay factor for the learning rate scheduler. Default: 0.1
    model (Module | None): The sequence-to-sequence model. Default: None
    criterion (Module | None): The loss function. Default: None
    processed_tokens (int): The number of tokens processed so far. Default: 0
    start_idx (int): The index of the start token. Default: 0

Source code in src/directmultistep/training/lightning.py
def __init__(
    self,
    pad_idx: int,
    mask_idx: int,
    lr: float,
    batch_size: int,
    warmup_steps: int = 4000,
    decay_steps: int = 24000,
    decay_factor: float = 0.1,
    model: nn.Module | None = None,
    criterion: nn.Module | None = None,
    processed_tokens: int = 0,
    start_idx: int = 0,
):
    """Initializes the PLTraining module.

    Args:
        pad_idx: The index of the padding token.
        mask_idx: The index of the mask token.
        lr: The initial learning rate.
        batch_size: The batch size.
        warmup_steps: The number of warmup steps for the learning rate scheduler.
        decay_steps: The number of decay steps for the learning rate scheduler.
        decay_factor: The decay factor for the learning rate scheduler.
        model: The sequence-to-sequence model.
        criterion: The loss function.
        processed_tokens: The number of tokens processed so far.
        start_idx: The index of the start token.
    """
    super().__init__()
    if model is not None:
        self.model = model
    if criterion is not None:
        self.criterion = criterion
    self.start_idx = start_idx
    self.pad_idx = pad_idx
    self.mask_idx = mask_idx
    self.learning_rate = lr
    self.batch_size = batch_size
    self.warmup_steps = warmup_steps
    self.decay_steps = decay_steps
    self.decay_factor = decay_factor
    self.processed_tokens = processed_tokens
    self.save_hyperparameters(ignore=["criterion", "model"])
    self.compute_loss = self.compute_loss_full

compute_loss_full(batch, batch_idx)

Computes the loss for the full sequence training.

This method calculates the loss for all tokens in the sequence.

Parameters:

    batch (Tensor): The input batch tensor. (required)
    batch_idx (int): The index of the batch. (required)

Returns:

    Tensor: The computed loss tensor.

Source code in src/directmultistep/training/lightning.py
def compute_loss_full(self, batch: Tensor, batch_idx: int) -> Tensor:
    """Computes the loss for the full sequence training.

    This method calculates the loss for all tokens in the sequence.

    Args:
        batch: The input batch tensor.
        batch_idx: The index of the batch.

    Returns:
        The computed loss tensor.
    """
    src_item_BC = batch[0]
    tgt_item_BL = batch[1].long()
    steps_B1 = batch[2].view(-1, 1)
    masked_src_BC = self.mask_src(src_item_BC, masking_prob=0.05)
    # the output actually is [B, L-1, V] given slicing of tgt_item_BL
    output_BLV = self.model(masked_src_BC, tgt_item_BL[:, :-1], steps_B1)
    output_blV = output_BLV.view(-1, output_BLV.shape[-1])  # [B*(L-1), V]
    tgt_bl = tgt_item_BL[:, 1:].reshape(-1)  # [B*(L-1)]
    loss = self.criterion(output_blV, tgt_bl)
    self.processed_tokens += tgt_item_BL.shape[0] * tgt_item_BL.shape[1]
    return cast(Tensor, loss)
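
A shapes-only sketch of the teacher-forcing shift used above (tensor values are arbitrary):

import torch

B, L, V = 2, 6, 10                         # batch size, target length, vocab size
tgt_item_BL = torch.randint(0, V, (B, L))
decoder_input = tgt_item_BL[:, :-1]        # [B, L-1], fed to the model
labels = tgt_item_BL[:, 1:].reshape(-1)    # [B*(L-1)], aligned with the flattened logits
print(decoder_input.shape, labels.shape)   # torch.Size([2, 5]) torch.Size([10])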

configure_optimizers()

Configures the optimizer and learning rate scheduler.

Returns:

    tuple[list[Optimizer], list[dict[str, Any]]]: A tuple containing the list of optimizers and the list of learning rate schedulers.

Source code in src/directmultistep/training/lightning.py
def configure_optimizers(
    self,
) -> tuple[list[torch.optim.Optimizer], list[dict[str, Any]]]:
    """Configures the optimizer and learning rate scheduler.

    Returns:
        A tuple containing the list of optimizers and the list of
        learning rate schedulers.
    """
    optimizer = torch.optim.AdamW(self.parameters(), lr=self.learning_rate)
    # return optimizer
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer,
        lr_lambda=warmup_and_cosine_decay(
            warmup_steps=self.warmup_steps,
            decay_steps=self.decay_steps,
            decay_factor=self.decay_factor,
        ),
        verbose=False,
    )
    lr_scheduler = {
        "scheduler": scheduler,  # The LR scheduler instance (required)
        "interval": "step",  # The unit of the scheduler's step size
        "frequency": 1,  # The frequency of the scheduler
    }
    return [optimizer], [lr_scheduler]

log_step_info(loss, mode, prog_bar)

Logs the loss and other training information.

Parameters:

    loss (Tensor): The loss tensor. (required)
    mode (str): The mode of training ('train' or 'val'). (required)
    prog_bar (bool): Whether to display the loss in the progress bar. (required)

Source code in src/directmultistep/training/lightning.py
def log_step_info(self, loss: Tensor, mode: str, prog_bar: bool) -> None:
    """Logs the loss and other training information.

    Args:
        loss: The loss tensor.
        mode: The mode of training ('train' or 'val').
        prog_bar: Whether to display the loss in the progress bar.
    """
    self.log(
        f"{mode}_loss",
        loss,
        batch_size=self.batch_size,
        prog_bar=prog_bar,
        sync_dist=True,
    )
    self.log("processed_tokens", self.processed_tokens, sync_dist=True)
    if mode == "train":
        current_lr = self.trainer.optimizers[0].param_groups[0]["lr"]
        self.log(f"{mode}_lr", current_lr, batch_size=self.batch_size, sync_dist=True)

mask_src(src_BC, masking_prob)

Masks the source sequence with a given probability.

Parameters:

    src_BC (Tensor): The source sequence tensor of shape [B, C]. (required)
    masking_prob (float): The probability of masking a token. (required)

Returns:

    Tensor: The masked source sequence tensor of shape [B, C].

Source code in src/directmultistep/training/lightning.py
def mask_src(self, src_BC: Tensor, masking_prob: float) -> Tensor:
    """Masks the source sequence with a given probability.

    Args:
        src_BC: The source sequence tensor of shape [B, C].
        masking_prob: The probability of masking a token.

    Returns:
        The masked source sequence tensor of shape [B, C].
    """
    mask_idx_BC = torch.rand(src_BC.shape).to(src_BC.device) < masking_prob
    not_pad_BC = src_BC != self.pad_idx
    final_mask_BC = mask_idx_BC & not_pad_BC
    masked_src_BC = src_BC.clone()
    masked_src_BC[final_mask_BC] = self.mask_idx
    return masked_src_BC
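
A small sketch of the masking behaviour; pad_idx=0 and mask_idx=1 are illustrative choices, not the package's actual token indices:

import torch
from directmultistep.training.lightning import LTraining

module = LTraining(pad_idx=0, mask_idx=1, lr=1e-4, batch_size=2)
src_BC = torch.tensor([[5, 6, 7, 0, 0],
                       [8, 9, 0, 0, 0]])
masked_BC = module.mask_src(src_BC, masking_prob=0.5)
# Padding positions (value 0) are never masked; each remaining token is
# independently replaced with mask_idx (value 1) with probability 0.5.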

on_load_checkpoint(checkpoint)

Loads the processed tokens from the checkpoint.

Parameters:

    checkpoint (dict[str, Any]): The checkpoint dictionary. (required)

Source code in src/directmultistep/training/lightning.py
def on_load_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Loads the processed tokens from the checkpoint.

    Args:
        checkpoint: The checkpoint dictionary.
    """
    # Load processed_tokens from the checkpoint dictionary
    self.processed_tokens = checkpoint.get("processed_tokens", 0)

on_save_checkpoint(checkpoint)

Adds the processed tokens to the checkpoint.

Parameters:

    checkpoint (dict[str, Any]): The checkpoint dictionary. (required)

Source code in src/directmultistep/training/lightning.py
def on_save_checkpoint(self, checkpoint: dict[str, Any]) -> None:
    """Adds the processed tokens to the checkpoint.

    Args:
        checkpoint: The checkpoint dictionary.
    """
    # Add processed_tokens to the checkpoint dictionary
    checkpoint["processed_tokens"] = self.processed_tokens

training_step(batch, batch_idx)

Performs a single training step.

Parameters:

    batch (Tensor): The input batch tensor. (required)
    batch_idx (int): The index of the batch. (required)

Returns:

    Tensor: The computed loss tensor.

Source code in src/directmultistep/training/lightning.py
def training_step(self, batch: Tensor, batch_idx: int) -> Tensor:
    """Performs a single training step.

    Args:
        batch: The input batch tensor.
        batch_idx: The index of the batch.

    Returns:
        The computed loss tensor.
    """
    loss = self.compute_loss(batch, batch_idx)
    self.log_step_info(loss, "train", prog_bar=True)
    return loss

validation_step(batch, batch_idx)

Performs a single validation step.

Parameters:

    batch (Tensor): The input batch tensor. (required)
    batch_idx (int): The index of the batch. (required)

Returns:

    Tensor: The computed loss tensor.

Source code in src/directmultistep/training/lightning.py
def validation_step(self, batch: Tensor, batch_idx: int) -> Tensor:
    """Performs a single validation step.

    Args:
        batch: The input batch tensor.
        batch_idx: The index of the batch.

    Returns:
        The computed loss tensor.
    """
    loss = self.compute_loss(batch, batch_idx)
    self.log_step_info(loss, "val", prog_bar=True)
    return loss

warmup_and_cosine_decay(warmup_steps, decay_steps, decay_factor)

Creates a learning rate schedule with warmup and cosine decay.

The learning rate increases linearly during the warmup phase, then decreases following a cosine function during the decay phase, and finally remains constant at the decay factor.

Parameters:

    warmup_steps (int): The number of steps for the warmup phase. (required)
    decay_steps (int): The number of steps for the decay phase. (required)
    decay_factor (float): The final learning rate factor after decay. (required)

Returns:

    Callable[[int], float]: A function that takes the current step as input and returns the corresponding learning rate factor.

Source code in src/directmultistep/training/lightning.py
def warmup_and_cosine_decay(warmup_steps: int, decay_steps: int, decay_factor: float) -> Callable[[int], float]:
    """Creates a learning rate schedule with warmup and cosine decay.

    The learning rate increases linearly during the warmup phase, then
    decreases following a cosine function during the decay phase, and
    finally remains constant at the decay factor.

    Args:
        warmup_steps: The number of steps for the warmup phase.
        decay_steps: The number of steps for the decay phase.
        decay_factor: The final learning rate factor after decay.

    Returns:
        A function that takes the current step as input and returns the
        corresponding learning rate factor.
    """

    def _get_new_lr(step: int) -> float:
        if step < warmup_steps:
            return step / warmup_steps
        elif step >= warmup_steps and step < warmup_steps + decay_steps:
            factor = 0.5 * (1 + np.cos(np.pi * (step - warmup_steps) / decay_steps))
            return cast(float, max(factor, decay_factor))
        else:
            return decay_factor

    return _get_new_lr
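
A quick sanity check of the schedule using the LTraining defaults (warmup_steps=4000, decay_steps=24000, decay_factor=0.1); the expected factors are shown in the comments:

from directmultistep.training.lightning import warmup_and_cosine_decay

lr_lambda = warmup_and_cosine_decay(warmup_steps=4000, decay_steps=24000, decay_factor=0.1)
print(lr_lambda(0))      # 0.0  (start of warmup)
print(lr_lambda(2000))   # 0.5  (halfway through warmup)
print(lr_lambda(4000))   # 1.0  (warmup done, cosine decay begins)
print(lr_lambda(16000))  # 0.5  (halfway through decay)
print(lr_lambda(40000))  # 0.1  (held at decay_factor after warmup + decay)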