Monitoring Training

This guide explains how to monitor and visualize training progress for DMS models.

Basic Usage

The simplest way to visualize training progress is to use the plotting utilities provided in use-examples/visualize_train_curves.py.

Run Configuration

Use RunConfig to specify which training runs to visualize:

from directmultistep.analysis.training import RunConfig

run = RunConfig(
    run_name="flash_10M",      # Folder name of the run
    trace_name="Flash Model",  # Display name for the traces
    include_val=True          # Whether to include validation curve
)

Training Curves

The plot_training_curves function creates a figure showing (a usage sketch follows the list):

  • Training loss curves (solid lines)
  • Validation loss curves (dotted lines with markers)
  • The x-axis shows the number of processed tokens by default
  • Hovering over a validation point shows its epoch
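
A minimal sketch of putting this together (the train_path value is a placeholder for wherever your run folders live):

from pathlib import Path

from directmultistep.analysis.training import RunConfig, plot_training_curves

train_path = Path("data/training")  # placeholder: directory containing run folders
runs = [
    RunConfig(run_name="flash_10M", trace_name="Flash Model"),
]

fig = plot_training_curves(train_path, runs, x_axis="processed_tokens")
fig.show()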

Learning Rate Curves

The plot_learning_rates function visualizes the learning rate schedule (a sketch follows the list):

  • Plots learning rate against training step
  • Useful for verifying that a schedule behaves as intended
  • Multiple runs can be compared on the same plot
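
Continuing the sketch above with the same train_path and runs:

from directmultistep.analysis.training import plot_learning_rates

fig = plot_learning_rates(train_path, runs)
fig.show()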

Advanced Usage

For more control over visualization, you can load the training data directly:

from directmultistep.analysis.training import load_training_df

# Load training data
df = load_training_df(train_path, "flash_10M")

# Ignore specific training runs by ID
df = load_training_df(train_path, "flash_10M", ignore_ids=[0, 1])

The returned DataFrame contains the following columns:

  • processed_tokens: Number of tokens processed
  • train_loss: Training loss
  • val_loss: Validation loss (if available)
  • train_lr: Learning rate
  • epoch: Current epoch
  • Additional metrics depending on the training configuration
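
As a sketch of working with the DataFrame directly, the snippet below plots the validation curve by hand. It assumes validation metrics are logged only on some rows (the usual Lightning CSV layout), so rows without val_loss are dropped first:

import plotly.graph_objects as go

val = df.dropna(subset=["val_loss"])  # keep only rows that logged a validation loss
fig = go.Figure(
    go.Scatter(x=val["processed_tokens"], y=val["val_loss"], mode="lines+markers")
)
fig.update_layout(xaxis_title="processed_tokens", yaxis_title="val_loss")
fig.show()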

Source Code

directmultistep.analysis.training

RunConfig dataclass

Configuration for a training run visualization.

Source code in src/directmultistep/analysis/training.py
@dataclass
class RunConfig:
    """Configuration for a training run visualization."""

    run_name: str  # Folder name of the run
    trace_name: str  # Display name for the traces
    include_val: bool = True  # Whether to include validation curve
    ignore_ids: list[int] | None = None  # Version IDs to ignore when loading data
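
For example, a run whose first Lightning version should be excluded could be configured like this (the version ID below is hypothetical):

run = RunConfig(
    run_name="flash_10M",
    trace_name="Flash Model",
    include_val=True,
    ignore_ids=[0],  # skip lightning_logs/version_0 when loading metrics
)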

plot_training_curves(train_path, runs, x_axis='processed_tokens')

Create a figure showing training and validation curves for multiple runs.

Parameters:

  • train_path (Path): Path to training data directory. Required.
  • runs (list[RunConfig]): List of run configurations specifying what and how to plot. Required.
  • x_axis (str): Column to use for x-axis values ("processed_tokens", "epoch", or "step"). Default: 'processed_tokens'.

Returns:

  • Figure: Plotly figure with training and validation curves.

Source code in src/directmultistep/analysis/training.py
def plot_training_curves(
    train_path: Path,
    runs: list[RunConfig],
    x_axis: str = "processed_tokens",
) -> go.Figure:
    """Create a figure showing training and validation curves for multiple runs.

    Args:
        train_path: Path to training data directory
        runs: List of run configurations specifying what and how to plot
        x_axis: Column to use for x-axis values ("processed_tokens", "epoch", or "step")

    Returns:
        Plotly figure with training and validation curves
    """
    traces = []
    for i, run in enumerate(runs):
        df = load_training_df(train_path, run.run_name, run.ignore_ids)
        color_idx = i % len(style.colors_light)
        traces.append(
            create_train_trace(df, run.trace_name, style.colors_light[color_idx % len(style.colors_light)], x_axis)
        )
        if run.include_val:
            traces.append(
                create_val_trace(df, run.trace_name, style.colors_dark[color_idx % len(style.colors_dark)], x_axis)
            )

    fig = go.Figure(data=traces)

    fig.update_layout(
        title="Training Loss",
        xaxis_title=x_axis,
        yaxis_title="Loss",
        width=1000,
    )
    style.apply_development_style(fig)

    return fig
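
Note the color pairing: each run draws its training trace in a light color and its validation trace in the matching dark color, so the train/val curves of a run read as one visual group.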

plot_learning_rates(train_path, runs)

Create a figure showing learning rate curves for multiple runs.

Parameters:

  • train_path (Path): Path to training data directory. Required.
  • runs (list[RunConfig]): List of run configurations specifying what and how to plot. Required.

Returns:

  • Figure: Plotly figure with learning rate curves.

Source code in src/directmultistep/analysis/training.py
def plot_learning_rates(
    train_path: Path,
    runs: list[RunConfig],
) -> go.Figure:
    """Create a figure showing learning rate curves for multiple runs.

    Args:
        train_path: Path to training data directory
        runs: List of run configurations specifying what and how to plot

    Returns:
        Plotly figure with learning rate curves
    """
    traces = []
    for run in runs:
        df = load_training_df(train_path, run.run_name, run.ignore_ids)
        traces.append(get_lr_trace(df, run.trace_name))

    fig = go.Figure(data=traces)

    fig.update_layout(
        title="Learning Rate",
        xaxis_title="Step",
        yaxis_title="Learning Rate",
        width=800,
    )
    style.apply_development_style(fig)

    return fig

load_training_df(train_path, run_name, ignore_ids=None)
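
Loads metrics.csv from every lightning_logs/version_* folder of a run (in version order, skipping any IDs listed in ignore_ids) and concatenates them into a single DataFrame.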

Source code in src/directmultistep/analysis/training.py
def load_training_df(train_path: Path, run_name: str, ignore_ids: list[int] | None = None) -> pd.DataFrame:
    logger.debug(f"Loading {run_name=}")
    log_path = train_path / run_name / "lightning_logs"
    dfs = []
    versions = [log.name for log in log_path.glob("version_*")]
    logger.debug(f"Found versions: {versions} for {run_name}")
    if ignore_ids is not None:
        ignored_folders = {f"version_{i}" for i in ignore_ids}
    else:
        ignored_folders = set()
    for version in sorted(versions, key=lambda x: int(x.split("_")[1])):
        if version in ignored_folders:
            continue
        temp_df = pd.read_csv(log_path / version / "metrics.csv")
        logger.debug(f"Loaded df with shape {temp_df.shape}")
        dfs.append(temp_df)
    df = pd.concat(dfs)
    df = df.reset_index(drop=True)
    return df
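
For reference, the function expects the standard Lightning CSV-logger layout under train_path (run and version names below are illustrative):

train_path/
    flash_10M/
        lightning_logs/
            version_0/
                metrics.csv
            version_1/
                metrics.csv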