Monitoring Training
This guide explains how to monitor and visualize training progress for DMS models.
Basic Usage
The simplest way to visualize training progress is to use the plotting utilities provided in use-examples/visualize_train_curves.py.
Run Configuration
Use RunConfig to specify which training runs to visualize:
from directmultistep.analysis.training import RunConfig
run = RunConfig(
    run_name="flash_10M",      # Folder name of the run
    trace_name="Flash Model",  # Display name for the traces
    include_val=True,          # Whether to include validation curve
)
Training Curves
The plot_training_curves function creates a figure showing:
- Training loss curves (solid lines)
- Validation loss curves (dotted lines with markers)
- X-axis shows the number of processed tokens by default (configurable via x_axis)
- Hovering over validation points shows epoch information
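A minimal sketch of calling plot_training_curves with the signature documented in the Source Code section. The wrapper name build_training_figure is hypothetical, as is the early x_axis check; the package is imported lazily so that the check runs before directmultistep is needed.

```python
from pathlib import Path

# x-axis columns listed in the plot_training_curves parameter table
ALLOWED_X_AXES = ("processed_tokens", "epoch", "step")


def build_training_figure(train_path: Path, x_axis: str = "processed_tokens"):
    """Hypothetical wrapper: plot the flash_10M run with a validated x-axis."""
    if x_axis not in ALLOWED_X_AXES:
        raise ValueError(f"x_axis must be one of {ALLOWED_X_AXES}, got {x_axis!r}")

    # Lazy import: only needed once the arguments are known to be valid
    from directmultistep.analysis.training import RunConfig, plot_training_curves

    runs = [
        RunConfig(run_name="flash_10M", trace_name="Flash Model", include_val=True),
    ]
    return plot_training_curves(train_path, runs, x_axis=x_axis)
```

Assuming the runs live under a hypothetical data/training directory, build_training_figure(Path("data/training"), x_axis="epoch").show() would display the returned Plotly figure.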
Learning Rate Curves
The plot_learning_rates function visualizes the learning rate schedule:
- Shows learning rate vs. training step
- Useful for verifying learning rate schedules
- Multiple runs can be compared on the same plot
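One way to verify a schedule numerically rather than by eye is sketched below. It assumes a warmup-then-decay shape, which this guide does not state, and the helper name is_warmup_then_decay is hypothetical. The input is any sequence of learning rates ordered by training step, for example the train_lr column converted to a list.

```python
from typing import Sequence


def is_warmup_then_decay(lrs: Sequence[float], tol: float = 1e-12) -> bool:
    """Return True if lrs never decrease before their peak and never increase after it."""
    if not lrs:
        return False
    # Index of the first maximum (max returns the first item on ties)
    peak = max(range(len(lrs)), key=lambda i: lrs[i])
    # Warmup phase: each value is at most the next one, up to the peak
    rising = all(lrs[i] <= lrs[i + 1] + tol for i in range(peak))
    # Decay phase: each value is at least the next one, from the peak onward
    falling = all(lrs[i] + tol >= lrs[i + 1] for i in range(peak, len(lrs) - 1))
    return rising and falling
```

A schedule with a different intended shape (for example, one with restarts) would need a different check; this one only flags departures from a single rise followed by a single fall.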
Advanced Usage
For more control over visualization, you can load the training data directly:
from directmultistep.analysis.training import load_training_df
# Load training data (train_path is the Path to the training data directory)
df = load_training_df(train_path, "flash_10M")
# Ignore specific training runs by ID
df = load_training_df(train_path, "flash_10M", ignore_ids=[0, 1])
The returned DataFrame contains the following columns:
- processed_tokens: Number of tokens processed
- train_loss: Training loss
- val_loss: Validation loss (if available)
- train_lr: Learning rate
- epoch: Current epoch
- Additional metrics depending on the training configuration
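As an illustration of working with these columns, the sketch below picks the row with the lowest validation loss. To stay free of third-party imports it operates on a list of dictionaries, one per row. The helper name best_validation_row is hypothetical, and treating a missing val_loss as None or NaN is an assumption about how gaps are represented.

```python
import math
from typing import Any, Iterable, Optional


def best_validation_row(rows: Iterable[dict[str, Any]]) -> Optional[dict[str, Any]]:
    """Return the row with the smallest val_loss, skipping rows without one."""
    best = None
    for row in rows:
        loss = row.get("val_loss")
        # Skip rows where no validation loss was recorded (None or NaN)
        if loss is None or math.isnan(loss):
            continue
        if best is None or loss < best["val_loss"]:
            best = row
    return best
```

Assuming load_training_df returns a pandas DataFrame, best_validation_row(df.to_dict("records")) would return the matching row as a dictionary, including its epoch and processed_tokens values.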
Source Code
directmultistep.analysis.training
RunConfig (dataclass)
Configuration for a training run visualization.
Source code in src/directmultistep/analysis/training.py
plot_training_curves(train_path, runs, x_axis='processed_tokens')
Create a figure showing training and validation curves for multiple runs.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
train_path | Path | Path to training data directory | required |
runs | list[RunConfig] | List of run configurations specifying what and how to plot | required |
x_axis | str | Column to use for x-axis values ("processed_tokens", "epoch", or "step") | 'processed_tokens' |
Returns:
Type | Description |
---|---|
Figure | Plotly figure with training and validation curves |
plot_learning_rates(train_path, runs)
Create a figure showing learning rate curves for multiple runs.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
train_path | Path | Path to training data directory | required |
runs | list[RunConfig] | List of run configurations specifying what and how to plot | required |
Returns:
Type | Description |
---|---|
Figure | Plotly figure with learning rate curves |