Paper Figures

This document describes the figures that can be generated using the paper-figures.py script.

Available Figures

1. Route Length Distribution

  • File: route_length_distribution.{pdf,html}
  • Description: Visualizes the distribution of route lengths across the training, n1, and n5 datasets.
  • Generated by: plot_route_length_distribution()

2. Leaf Distribution

  • File: leaf_distribution.{pdf,html}
  • Description: Shows the distribution of leaf nodes (end states) across different datasets.
  • Generated by: plot_leaf_distribution()

3. Convergent Route Analysis

Two figures are generated for convergent route analysis:

  • Files:
    • convergent_fraction_by_length.{pdf,html}
    • convergent_fraction_overall.{pdf,html}
  • Description: Analyzes the fraction of convergent routes by length and overall convergent fraction across datasets.
  • Generated by: plot_convergent_fraction_by_length() and plot_convergent_fraction_overall()

4. Top-K Accuracy Analysis

  • File: {dataset_name}_topk_accuracy_subplots.{pdf,html}
  • Description: Comparative bar plots showing top-k accuracy metrics for different models and configurations.
  • Features: Shows accuracy for k values [1, 2, 3, 4, 5, 10]
  • Generated separately for n1 and n5 datasets

5. Route Processing Stages

  • File: {dataset_name}_route_processing_stages_{config}.{pdf,html}
  • Description: Visualizes different stages of route processing, comparing:
    • Valid routes
    • Processed routes without stock
    • Processed routes with stock
    • True routes

6. Accuracy by Route Length

  • File: accuracy_by_length_subplots_{config}.{pdf,html}
  • Description: Shows top-k accuracy metrics broken down by route length
  • Features:
    • Compares performance across different datasets (n1, n5)
    • Shows accuracy for k=1 and k=10

Usage

To generate these figures, modify the rerun dictionary in paper-figures.py to specify which figures you want to generate:

rerun = {
    "route-distribution": False,
    "leaf-distribution": False,
    "convergent-fraction": False,
    "topk-accuracy": False,
    "extraction-distribution": True,
    "accuracy-by-length": False,
}

Set the corresponding flag to True for the figures you want to generate. All figures will be saved in both PDF and HTML formats in the data/figures/paper directory.

Source Code

directmultistep.analysis.paper.dataset_analysis

create_convergent_fraction_trace(path_strings, route_lengths, label, color)

Create a bar trace showing fraction of convergent routes by length.

Parameters:

  • path_strings (list[str], required): List of path strings to analyze
  • route_lengths (list[int], required): List of corresponding route lengths
  • label (str, required): Label for the trace
  • color (str, required): Color for the trace

Returns:

  • Bar: Bar trace showing convergent fraction by length

Source code in src/directmultistep/analysis/paper/dataset_analysis.py
def create_convergent_fraction_trace(
    path_strings: list[str], route_lengths: list[int], label: str, color: str
) -> go.Bar:
    """Create a bar trace showing fraction of convergent routes by length.

    Args:
        path_strings: List of path strings to analyze
        route_lengths: List of corresponding route lengths
        label: Label for the trace
        color: Color for the trace

    Returns:
        Bar trace showing convergent fraction by length
    """
    # Group paths by length and compute convergent fraction
    max_length = 10
    fractions = []
    lengths = []

    for length in tqdm(range(1, max_length + 1)):
        # Get paths of this length
        mask = np.array(route_lengths) == length
        paths_at_length = np.array(path_strings)[mask]

        if len(paths_at_length) == 0:
            continue

        # Compute fraction of convergent paths
        n_convergent = sum(1 for path in paths_at_length if is_convergent(eval(path)))
        fraction = n_convergent / len(paths_at_length)

        fractions.append(fraction)
        lengths.append(length)

    return go.Bar(
        x=lengths,
        y=fractions,
        name=label,
        marker=dict(color=color),
        hovertemplate="Route Length: %{x}<br>Convergent Fraction: %{y:.2%}<extra></extra>",
        # text=[f"{v:.1%}" for v in fractions],
        # textposition="auto",
        # textfont=dict(size=12),
    )
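Stripped of the plotting, the per-length computation above amounts to a grouped mean of a boolean predicate. A pure-Python sketch (is_convergent is replaced here by precomputed flags, an assumption made purely for illustration):

```python
from collections import defaultdict

def convergent_fraction_by_length(
    route_lengths: list[int], is_convergent_flags: list[bool]
) -> dict[int, float]:
    """Fraction of convergent routes at each route length."""
    totals: dict[int, int] = defaultdict(int)
    convergent: dict[int, int] = defaultdict(int)
    for length, flag in zip(route_lengths, is_convergent_flags):
        totals[length] += 1
        convergent[length] += bool(flag)
    return {length: convergent[length] / totals[length] for length in sorted(totals)}

print(convergent_fraction_by_length([2, 2, 3, 3, 3], [True, False, True, True, False]))
# {2: 0.5, 3: 0.6666666666666666}
```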

create_leaf_bar_trace(path_strings, label, color)

Create a bar trace showing distribution of number of leaves at root node.

Parameters:

  • path_strings (list[str], required): List of path strings to analyze
  • label (str, required): Label for the trace
  • color (str, required): Color for the trace

Returns:

  • Bar: Bar trace showing leaf distribution

Source code in src/directmultistep/analysis/paper/dataset_analysis.py
def create_leaf_bar_trace(path_strings: list[str], label: str, color: str) -> go.Bar:
    """Create a bar trace showing distribution of number of leaves at root node.

    Args:
        path_strings: List of path strings to analyze
        label: Label for the trace
        color: Color for the trace

    Returns:
        Bar trace showing leaf distribution
    """
    n_leaves = []
    for path in tqdm(path_strings):
        path_dict = eval(path)
        root_leaves = sum(
            1 for child in path_dict["children"] if "children" not in child or len(child["children"]) == 0
        )
        n_leaves.append(root_leaves)

    unique_lengths, counts = np.unique(n_leaves, return_counts=True)

    unique_lengths = unique_lengths[:4]
    counts = counts[:4]
    relative_abundance = counts / len(path_strings)

    return go.Bar(
        x=unique_lengths,
        y=relative_abundance,
        name=label,
        marker=dict(color=color),
        hovertemplate="Number of Leaves: %{x}<br>Relative Frequency: %{y:.2%}<extra></extra>",
        text=[f"{v:.1%}" for v in relative_abundance],
        textposition="auto",
        textfont=dict(size=12),
    )
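The root-leaf count above treats a child as a leaf when it has no "children" key, or when that list is empty. A minimal illustration with a hypothetical route dict (the "smiles"/"children" shape is inferred from the counting logic, not a package specification):

```python
# hypothetical route dict; only the keys used by the counting logic matter
path_dict = {
    "smiles": "target",
    "children": [
        {"smiles": "A"},                                 # no "children" key -> leaf
        {"smiles": "B", "children": []},                 # empty children -> leaf
        {"smiles": "C", "children": [{"smiles": "D"}]},  # has a child -> not a leaf
    ],
}

# same expression as in create_leaf_bar_trace above
root_leaves = sum(
    1 for child in path_dict["children"]
    if "children" not in child or len(child["children"]) == 0
)
print(root_leaves)  # 2
```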

create_split_bar_trace(route_lengths, label, sep_threshold, color)

Create two bar traces split by a threshold value.

Parameters:

  • route_lengths (list[int], required): List of route lengths to plot
  • label (str, required): Label for the traces
  • sep_threshold (int, required): Threshold value to split traces
  • color (str, required): Color for both traces

Returns:

  • tuple[Bar, Bar]: Tuple of two bar traces - one for values <= threshold, one for values > threshold

Source code in src/directmultistep/analysis/paper/dataset_analysis.py
def create_split_bar_trace(
    route_lengths: list[int], label: str, sep_threshold: int, color: str
) -> tuple[go.Bar, go.Bar]:
    """Create two bar traces split by a threshold value.

    Args:
        route_lengths: List of route lengths to plot
        label: Label for the traces
        sep_threshold: Threshold value to split traces
        color: Color for both traces

    Returns:
        Tuple of two bar traces - one for values <= threshold, one for values > threshold
    """
    unique_lengths, counts = np.unique(route_lengths, return_counts=True)
    relative_abundance = counts / len(route_lengths)

    trace_settings = dict(
        name=label,
        marker=dict(color=color),
        hovertemplate="Route Length: %{x}<br>Relative Abundance: %{y:.2%}<extra></extra>",
        textposition="auto",
    )

    # Split data by threshold
    mask_short = unique_lengths <= sep_threshold
    mask_long = unique_lengths > sep_threshold

    trace1 = go.Bar(x=unique_lengths[mask_short], y=relative_abundance[mask_short], **trace_settings)

    trace2 = go.Bar(x=unique_lengths[mask_long], y=relative_abundance[mask_long], showlegend=False, **trace_settings)

    return trace1, trace2
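The split itself is just a pair of boolean masks over the unique lengths. A small sketch of that step in isolation (illustrative values, not package data):

```python
import numpy as np

def split_by_threshold(
    unique_lengths: list[int], relative_abundance: list[float], sep_threshold: int
) -> tuple[tuple[np.ndarray, np.ndarray], tuple[np.ndarray, np.ndarray]]:
    """Split (length, abundance) pairs into <= threshold and > threshold groups."""
    lengths = np.asarray(unique_lengths)
    abundance = np.asarray(relative_abundance)
    short = lengths <= sep_threshold  # boolean mask; ~short is the complement
    return (lengths[short], abundance[short]), (lengths[~short], abundance[~short])

(xs, ys), (xl, yl) = split_by_threshold([1, 2, 3, 7, 8], [0.4, 0.3, 0.2, 0.06, 0.04], 6)
print(xs.tolist(), xl.tolist())  # [1, 2, 3] [7, 8]
```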

plot_convergent_fraction_by_length(train_paths, train_lengths, n1_paths, n1_lengths, n5_paths, n5_lengths)

Create a plot showing fraction of convergent routes by length for different datasets.

Parameters:

  • train_paths (list[str], required): List of path strings from training set
  • train_lengths (list[int], required): List of route lengths from training set
  • n1_paths (list[str], required): List of path strings from n1 dataset
  • n1_lengths (list[int], required): List of route lengths from n1 dataset
  • n5_paths (list[str], required): List of path strings from n5 dataset
  • n5_lengths (list[int], required): List of route lengths from n5 dataset

Returns:

  • Figure: Plotly figure object containing the visualization

Source code in src/directmultistep/analysis/paper/dataset_analysis.py
def plot_convergent_fraction_by_length(
    train_paths: list[str],
    train_lengths: list[int],
    n1_paths: list[str],
    n1_lengths: list[int],
    n5_paths: list[str],
    n5_lengths: list[int],
) -> go.Figure:
    """Create a plot showing fraction of convergent routes by length for different datasets.

    Args:
        train_paths: List of path strings from training set
        train_lengths: List of route lengths from training set
        n1_paths: List of path strings from n1 dataset
        n1_lengths: List of route lengths from n1 dataset
        n5_paths: List of path strings from n5 dataset
        n5_lengths: List of route lengths from n5 dataset

    Returns:
        Plotly figure object containing the visualization
    """
    colors = [FONT_COLOR, publication_colors["dark_blue"], publication_colors["dark_purple"]]
    datasets = [
        (train_paths, train_lengths, "Training Routes", colors[0]),
        (n1_paths, n1_lengths, "n1", colors[1]),
        (n5_paths, n5_lengths, "n5", colors[2]),
    ]

    fig = go.Figure()

    for paths, lengths, label, color in datasets:
        fig.add_trace(create_convergent_fraction_trace(paths, lengths, label, color))

    style.AXIS_STYLE["linecolor"] = None
    apply_publication_style(fig)
    # fmt:off
    fig.update_layout(width=1000, height=250, bargap=0.15,
        margin=dict(l=100, r=50, t=20, b=50),
        legend=dict( orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=0.99),
        yaxis=dict(range=[0, 0.31], dtick=0.05),
    )
    # fmt:on
    fig.update_xaxes(title_text="<b>Route Length</b>", dtick=1, showgrid=False)
    fig.update_yaxes(title_text="<b>Fraction Convergent</b>", tickformat=",.0%")

    return fig

plot_convergent_fraction_overall(train_paths, n1_paths, n5_paths)

Create a plot showing overall fraction of convergent routes for different datasets.

Parameters:

  • train_paths (list[str], required): List of path strings from training set
  • n1_paths (list[str], required): List of path strings from n1 dataset
  • n5_paths (list[str], required): List of path strings from n5 dataset

Returns:

  • Figure: Plotly figure object containing the visualization

Source code in src/directmultistep/analysis/paper/dataset_analysis.py
def plot_convergent_fraction_overall(
    train_paths: list[str],
    n1_paths: list[str],
    n5_paths: list[str],
) -> go.Figure:
    """Create a plot showing overall fraction of convergent routes for different datasets.

    Args:
        train_paths: List of path strings from training set
        n1_paths: List of path strings from n1 dataset
        n5_paths: List of path strings from n5 dataset

    Returns:
        Plotly figure object containing the visualization
    """
    colors = [FONT_COLOR, publication_colors["dark_blue"], publication_colors["dark_purple"]]
    datasets = [
        (train_paths, "Training Routes", colors[0]),
        (n1_paths, "n1", colors[1]),
        (n5_paths, "n5", colors[2]),
    ]

    fractions = []
    labels = []
    colors_used = []

    for paths, label, color in tqdm(datasets):
        n_convergent = sum(1 for path in paths if is_convergent(eval(path)))
        fraction = n_convergent / len(paths)
        fractions.append(fraction)
        labels.append(label)
        colors_used.append(color)

    fig = go.Figure()
    # fmt:off
    fig.add_trace(
        go.Bar(x=labels, y=fractions,
            marker=dict(color=colors_used), text=[f"{v:.1%}" for v in fractions], textposition="auto"))

    style.AXIS_STYLE["linecolor"] = None
    apply_publication_style(fig)
    fig.update_layout(width=600, height=250, margin=dict(l=100, r=50, t=20, b=50), showlegend=False)
    # fmt:on
    fig.update_xaxes(title_text="<b>Dataset</b>", showgrid=False)
    fig.update_yaxes(title_text="<b>Fraction Convergent</b>", tickformat=",.0%", dtick=0.05, range=[0, 0.31])

    return fig

plot_leaf_distribution(train_paths, n1_paths, n5_paths)

Create a plot showing the distribution of number of leaves for different datasets.

Parameters:

  • train_paths (list[str], required): List of path strings from training set
  • n1_paths (list[str], required): List of path strings from n1 dataset
  • n5_paths (list[str], required): List of path strings from n5 dataset

Returns:

  • Figure: Plotly figure object containing the visualization

Source code in src/directmultistep/analysis/paper/dataset_analysis.py
def plot_leaf_distribution(
    train_paths: list[str],
    n1_paths: list[str],
    n5_paths: list[str],
) -> go.Figure:
    """Create a plot showing the distribution of number of leaves for different datasets.

    Args:
        train_paths: List of path strings from training set
        n1_paths: List of path strings from n1 dataset
        n5_paths: List of path strings from n5 dataset

    Returns:
        Plotly figure object containing the visualization
    """
    colors = [FONT_COLOR, publication_colors["dark_blue"], publication_colors["dark_purple"]]
    datasets = [(train_paths, "Training Routes", colors[0]), (n1_paths, "n1", colors[1]), (n5_paths, "n5", colors[2])]

    fig = go.Figure()

    for paths, label, color in datasets:
        fig.add_trace(create_leaf_bar_trace(paths, label, color))

    style.AXIS_STYLE["linecolor"] = None
    apply_publication_style(fig)
    # fmt:off
    fig.update_layout(width=700, height=250, bargap=0.08, yaxis_range=[0, 0.81],
        margin=dict(l=100, r=50, t=20, b=50),
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=0.99))

    fig.update_xaxes(title_text="<b>Number of Leaves at Root Node</b>", dtick=1, showgrid=False)
    fig.update_yaxes(title_text="<b>Relative Frequency</b>", tickformat=",.0%")

    return fig

plot_route_length_distribution(train_steps, n1_steps, n5_steps)

Create a split plot showing the distribution of route lengths for different datasets.

Parameters:

  • train_steps (list[int], required): List of route lengths from training set
  • n1_steps (list[int], required): List of route lengths from n1 dataset
  • n5_steps (list[int], required): List of route lengths from n5 dataset

Returns:

  • Figure: Plotly figure object containing the visualization

Source code in src/directmultistep/analysis/paper/dataset_analysis.py
def plot_route_length_distribution(
    train_steps: list[int],
    n1_steps: list[int],
    n5_steps: list[int],
) -> go.Figure:
    """Create a split plot showing the distribution of route lengths for different datasets.

    Args:
        train_steps: List of route lengths from training set
        n1_steps: List of route lengths from n1 dataset
        n5_steps: List of route lengths from n5 dataset

    Returns:
        Plotly figure object containing the visualization
    """
    # Plot settings
    colors = [FONT_COLOR, publication_colors["dark_blue"], publication_colors["dark_purple"]]
    sep_threshold = 6

    fig = make_subplots(rows=1, cols=2)

    datasets = [(train_steps, "Training Routes", colors[0]), (n1_steps, "n1", colors[1]), (n5_steps, "n5", colors[2])]

    for steps, label, color in datasets:
        trace1, trace2 = create_split_bar_trace(steps, label, sep_threshold, color)
        fig.add_trace(trace1, row=1, col=1)
        fig.add_trace(trace2, row=1, col=2)

    style.AXIS_STYLE["linecolor"] = None
    apply_publication_style(fig)

    # fmt:off
    fig.update_layout(width=1000, height=300, margin=dict(l=100, r=50, t=20, b=50))

    for col in [1, 2]:
        fig.update_xaxes(title_text="<b>Route Length</b>", showgrid=False, row=1, col=col, dtick=1)

    fig.update_yaxes(title_text="<b>Relative Abundance</b>", tickformat=",.0%", dtick=0.1, row=1, col=1)
    fig.update_yaxes(tickformat=",.2%", dtick=0.003, row=1, col=2)
    fig.update_layout(legend=dict( orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=0.99))

    return fig

directmultistep.analysis.paper.linear_vs_convergent

ModelPlotConfig dataclass

Configuration for model plotting.

Attributes:

  • model_name (str): Name of the model (e.g. 'flex_20M', 'flash_10M').
  • epoch (str): Epoch number as string (e.g. 'epoch=20').
  • variant_base (str): Base variant string (e.g. 'b50_sm_st_ea=1_da=1').
  • true_reacs (bool): Whether to use true reactions.
  • stock (bool): Whether to use stock compounds.
  • ds_name (str): Dataset name (e.g. 'n1', 'n5').

Source code in src/directmultistep/analysis/paper/linear_vs_convergent.py
@dataclass
class ModelPlotConfig:
    """Configuration for model plotting.

    Attributes:
        model_name: Name of the model (e.g. 'flex_20M', 'flash_10M').
        epoch: Epoch number as string (e.g. 'epoch=20').
        variant_base: Base variant string (e.g. 'b50_sm_st_ea=1_da=1').
        true_reacs: Whether to use true reactions.
        stock: Whether to use stock compounds.
        ds_name: Dataset name (e.g. 'n1', 'n5').
    """

    model_name: str
    epoch: str
    variant_base: str
    true_reacs: bool = True
    stock: bool = True
    ds_name: str = "n1"

    def __post_init__(self) -> None:
        if "nosm" in self.variant_base:
            self.true_reacs = False

    @property
    def display_name(self) -> str:
        """Generate display name from model name.

        Returns:
            str: Display name of the model.
        """
        base = self.model_name.replace("_", " ").title()

        if "nosm" in self.variant_base:
            base += " (no SM)"
        elif "sm" in self.variant_base:
            base += " (SM)"

        if "ea=1" in self.variant_base and "da=1" in self.variant_base:
            base = base.replace("(", "(Mono, ")
            base += ")"
        elif "ea=2" in self.variant_base and "da=2" in self.variant_base:
            base = base.replace("(", "(Duo, ")
            base += ")"

        return base

    @property
    def variant(self) -> str:
        """Get the full variant string.

        Returns:
            str: Full variant string.
        """
        return f"{self.ds_name}_{self.variant_base}"

    @property
    def save_suffix(self) -> str:
        """Get the name of the save file.

        Returns:
            str: Save file suffix.
        """
        return f"{self.model_name}_{self.variant}"

    @property
    def processed_paths_name(self) -> str:
        """Get the name of the processed paths file.

        Returns:
            str: Processed paths file name.
        """
        return f"processed_paths_NS2n_true_reacs={self.true_reacs}_stock={self.stock}.pkl"

    @property
    def correct_paths_name(self) -> str:
        """Get the name of the correct paths file.

        Returns:
            str: Correct paths file name.
        """
        return "correct_paths_NS2n.pkl"

    def with_dataset(self, ds_name: str) -> "ModelPlotConfig":
        """Create a new config with dataset information.

        Args:
            ds_name: Dataset name.

        Returns:
            ModelPlotConfig: New config with dataset information.
        """
        return replace(self, ds_name=ds_name)

    def get_result_path(self, eval_path: Path) -> Path:
        """Get the path to the results directory for this config.

        Args:
            eval_path: Path to the evaluation directory.

        Returns:
            Path: Path to the results directory.
        """
        return eval_path / self.model_name / self.epoch / self.variant
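To make the path plumbing concrete, here is a condensed, hypothetical sketch of how the variant string and result directory compose. It reproduces only the path-related parts of the class above (ModelPlotConfigSketch is an illustrative name, not the package class):

```python
from dataclasses import dataclass, replace
from pathlib import Path

@dataclass
class ModelPlotConfigSketch:
    # condensed from ModelPlotConfig above: path-composition parts only
    model_name: str
    epoch: str
    variant_base: str
    ds_name: str = "n1"

    @property
    def variant(self) -> str:
        # the dataset name is prefixed onto the base variant
        return f"{self.ds_name}_{self.variant_base}"

    def get_result_path(self, eval_path: Path) -> Path:
        return eval_path / self.model_name / self.epoch / self.variant

cfg = ModelPlotConfigSketch("flash_10M", "epoch=20", "b50_sm_st_ea=1_da=1")
print(cfg.get_result_path(Path("data/eval")))
# data/eval/flash_10M/epoch=20/n1_b50_sm_st_ea=1_da=1

# with_dataset is dataclasses.replace under the hood
print(replace(cfg, ds_name="n5").variant)  # n5_b50_sm_st_ea=1_da=1
```

The real class additionally derives display names and pickle file names (display_name, processed_paths_name, etc.), as documented below.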

correct_paths_name: str property

Get the name of the correct paths file.

Returns:

  • str: Correct paths file name.

display_name: str property

Generate display name from model name.

Returns:

  • str: Display name of the model.

processed_paths_name: str property

Get the name of the processed paths file.

Returns:

  • str: Processed paths file name.

save_suffix: str property

Get the name of the save file.

Returns:

  • str: Save file suffix.

variant: str property

Get the full variant string.

Returns:

  • str: Full variant string.

get_result_path(eval_path)

Get the path to the results directory for this config.

Parameters:

  • eval_path (Path, required): Path to the evaluation directory.

Returns:

  • Path: Path to the results directory.

Source code in src/directmultistep/analysis/paper/linear_vs_convergent.py
def get_result_path(self, eval_path: Path) -> Path:
    """Get the path to the results directory for this config.

    Args:
        eval_path: Path to the evaluation directory.

    Returns:
        Path: Path to the results directory.
    """
    return eval_path / self.model_name / self.epoch / self.variant

with_dataset(ds_name)

Create a new config with dataset information.

Parameters:

  • ds_name (str, required): Dataset name.

Returns:

  • ModelPlotConfig: New config with dataset information.

Source code in src/directmultistep/analysis/paper/linear_vs_convergent.py
def with_dataset(self, ds_name: str) -> "ModelPlotConfig":
    """Create a new config with dataset information.

    Args:
        ds_name: Dataset name.

    Returns:
        ModelPlotConfig: New config with dataset information.
    """
    return replace(self, ds_name=ds_name)

RouteAnalyzer

Analyzes predicted routes and calculates various statistics.

Source code in src/directmultistep/analysis/paper/linear_vs_convergent.py
class RouteAnalyzer:
    """Analyzes predicted routes and calculates various statistics."""

    def __init__(self, predicted_routes: PathsProcessedType, true_routes: list[str], k_vals: list[int] | None = None):
        """Initializes the RouteAnalyzer.

        Args:
            predicted_routes: Predicted routes.
            true_routes: True routes.
            k_vals: List of k values for top-k accuracy calculation.
        """
        self.predicted_routes = predicted_routes
        self.true_routes = true_routes
        self.k_vals = k_vals if k_vals is not None else [1, 2, 3, 4, 5, 10, 20, 50]
        self.convergent_idxs = get_convergent_indices(true_routes)
        self.non_convergent_idxs = set(range(len(true_routes))) - self.convergent_idxs

    def analyze_convergence_stats(self) -> None:
        """Analyze and log basic convergence statistics."""
        n_convergent = len(self.convergent_idxs)
        total = len(self.true_routes)
        logger.info(f"Found {n_convergent} convergent routes out of {total} total routes")
        logger.info(f"Percentage convergent: {100 * n_convergent / total:.1f}%")

    def calculate_top_k_accuracies(self, save_path: Path | None = None) -> dict[str, dict[str, str]]:
        """Calculate top-k accuracies for different route subsets and optionally save results.

        Args:
            save_path: Optional path to save detailed accuracies to YAML file.

        Returns:
            dict[str, dict[str, str]]: Dictionary of top-k accuracies.
        """
        results = {}
        route_types = {"all": None, "convergent": self.non_convergent_idxs, "non_convergent": self.convergent_idxs}

        with tqdm(total=len(route_types), desc="Analyzing top-k accuracy") as pbar:
            for route_type, ignore_ids in route_types.items():
                pbar.set_description(f"{route_type} routes")
                _, perm_matches = find_matching_paths(self.predicted_routes, self.true_routes, ignore_ids=ignore_ids)
                results[route_type] = find_top_n_accuracy(perm_matches, self.k_vals)
                pbar.update(1)

        if save_path is not None:
            save_path = save_path / "top_k_accuracy_detailed.yaml"
            with open(save_path, "w") as f:
                yaml.dump(results, f, default_flow_style=False)
            logger.info(f"Saved detailed accuracies to {save_path}")

        return results

    def analyze_and_log_results(self) -> dict[str, dict[str, str]]:
        """Run full analysis and log results.

        Returns:
            dict[str, dict[str, str]]: Dictionary of top-k accuracies.
        """
        self.analyze_convergence_stats()
        results = self.calculate_top_k_accuracies()

        for route_type, accuracies in results.items():
            logger.info(f"\nTop-k accuracy for {route_type} routes:")
            logger.info(accuracies)

        return results

    def visualize_route_distributions(self, dataset_name: str = "") -> go.Figure:
        """Create a publication-quality figure showing the distribution of predicted routes.

        Args:
            dataset_name: Name of the dataset being analyzed, used in plot title.

        Returns:
            go.Figure: Plotly figure object.
        """
        n_predictions = [len(routes) for routes in self.predicted_routes]

        conv_predictions = [n_predictions[i] for i in self.convergent_idxs]
        nonconv_predictions = [n_predictions[i] for i in self.non_convergent_idxs]

        mean_all, median_all, mean_all_filtered, median_all_filtered = calculate_prediction_stats(n_predictions)
        mean_conv, median_conv, mean_conv_filtered, median_conv_filtered = calculate_prediction_stats(conv_predictions)
        mean_nonconv, median_nonconv, mean_nonconv_filtered, median_nonconv_filtered = calculate_prediction_stats(
            nonconv_predictions
        )

        # fmt: off
        fig = make_subplots(rows=1, cols=3,
            subplot_titles=(
                f'All Routes<br><span style="font-size:{FONT_SIZES["subplot_title"]}px">mean: {mean_all:.1f}, median: {median_all:.1f} (mean*: {mean_all_filtered:.1f}, median*: {median_all_filtered:.1f})</span>',
                f'Convergent Routes<br><span style="font-size:{FONT_SIZES["subplot_title"]}px">mean: {mean_conv:.1f}, median: {median_conv:.1f} (mean*: {mean_conv_filtered:.1f}, median*: {median_conv_filtered:.1f})</span>',
                f'Non-convergent Routes<br><span style="font-size:{FONT_SIZES["subplot_title"]}px">mean: {mean_nonconv:.1f}, median: {median_nonconv:.1f} (mean*: {mean_nonconv_filtered:.1f}, median*: {median_nonconv_filtered:.1f})</span>'
            ), horizontal_spacing=0.1)

        histogram_style = dict(opacity=0.75, nbinsx=30, histnorm='percent', marker_color=style.publication_colors["dark_blue"])

        data = [(n_predictions, "All"), (conv_predictions, "Convergent"), (nonconv_predictions, "Non-convergent")]
        for i, (predictions, name) in enumerate(data, start=1):
            fig.add_trace(go.Histogram(x=predictions, name=name, **histogram_style), row=1, col=i)

        title = "Distribution of Predicted Routes per Target"
        if dataset_name:
            title = f"{title} - {dataset_name}"

        fig.update_layout(title=dict(text=title, x=0.5, xanchor='center'), showlegend=False, height=400, width=1200,)

        apply_publication_style(fig)

        for i in range(1, 4):
            fig.update_xaxes(title=dict(text="Number of Predicted Routes", font=get_font_dict(FONT_SIZES["axis_title"]), standoff=15), row=1, col=i)
            fig.update_yaxes(title=dict(text="Percentage (%)", font=get_font_dict(FONT_SIZES["axis_title"]), standoff=15), row=1, col=i)
        # fmt: on
        return fig

    @staticmethod
    def create_comparative_bar_plots(
        result_paths: list[Path], trace_names: list[str], k_vals: list[int] | None = None, title: str = ""
    ) -> go.Figure:
        """Create comparative bar plots showing top-k accuracy for different configurations.

        Args:
            result_paths: List of paths to top_k_accuracy_detailed.yaml files.
            trace_names: List of names for each trace (must match length of result_paths).
            k_vals: Optional list of k values to show. If None, shows all k values.
            title: Title for the plot.

        Returns:
            go.Figure: Plotly figure object.
        """
        if len(result_paths) != len(trace_names):
            raise ValueError("Number of result paths must match number of trace names")

        results = []
        for path in result_paths:
            with open(path / "top_k_accuracy_detailed.yaml") as f:
                results.append(yaml.safe_load(f))

        # fmt: off
        fig = make_subplots(rows=3, cols=1, horizontal_spacing=0.07, vertical_spacing=0.12,
            subplot_titles=[f"<b>{t}</b>" for t in ('(a) all routes', '(b) convergent routes', '(c) non-convergent routes')])

        categories = ['all', 'convergent', 'non_convergent']
        positions = [1, 2, 3]

        colors = style.colors_blue + style.colors_purple + style.colors_red

        for cat, pos in zip(categories, positions):
            x = list(results[0][cat].keys())
            x.sort(key=lambda k: int(k.split()[-1]))

            if k_vals is not None:
                k_vals_str = [f"Top {k}" for k in k_vals]
                x = [k for k in x if k in k_vals_str]

            for i, (result, name) in enumerate(zip(results, trace_names)):
                y = [float(result[cat][k].strip('%')) for k in x]

                fig.add_trace(
                    go.Bar(name=name, x=x, y=y, showlegend=pos == 1, marker_color=colors[i % len(colors)],
                        legendgroup=name,), row=pos, col=1)

        fig.update_layout(
            title=dict(text=title, x=0.5, xanchor='center'),
            barmode='group', height=600, width=1000,
            legend=dict(font=get_font_dict(FONT_SIZES["legend"]), orientation="h", yanchor="bottom",
                y=-0.20, xanchor="center", x=0.5, entrywidth=140, tracegroupgap=0))

        style.AXIS_STYLE["linecolor"] = None
        apply_publication_style(fig)

        for i in range(1, 4):
            fig.update_yaxes(dtick=10, title=dict(text="Accuracy (%)", font=get_font_dict(FONT_SIZES["axis_title"])), row=i, col=1)
            fig.update_xaxes(showgrid=False, row=i, col=1)
        # fmt: on
        return fig

    @staticmethod
    def _calculate_accuracy_by_length_data(
        predicted_routes: PathsProcessedType,
        dataset: DatasetDict,
        k_vals: list[int],
        ignore_ids: set[int] | None = None,
    ) -> tuple[list[int], dict[int, dict[str, int]]]:
        """Helper function to calculate accuracy by length data.

        Args:
            predicted_routes: List of predicted routes.
            dataset: Dataset dictionary.
            k_vals: List of k values to calculate accuracy for.
            ignore_ids: Optional set of indices to ignore.

        Returns:
            Tuple of (lengths, step_stats) where step_stats maps length to accuracy stats.
        """
        _, perm_matches = find_matching_paths(predicted_routes, dataset["path_strings"], ignore_ids=ignore_ids)
        step_stats = calculate_top_k_counts_by_step_length(perm_matches, dataset["n_steps_list"], k_vals)
        lengths = list(step_stats.keys())
        return lengths, step_stats

    @staticmethod
    def create_accuracy_by_length_plot(
        result_paths: list[Path],
        datasets: list[DatasetDict],
        configs: list[ModelPlotConfig],
        k_vals: list[int],
        title: str = "",
    ) -> go.Figure:
        """Create plot showing accuracy by route length.

        Args:
            result_paths: List of paths to result directories.
            datasets: List of datasets to analyze.
            configs: List of model configurations.
            k_vals: List of k values to calculate accuracy for.
            title: Title for the plot.

        Returns:
            go.Figure: Plotly figure object.
        """
        fig = go.Figure()

        cset = style.publication_colors
        colors = [cset["primary_blue"], cset["dark_blue"], cset["purple"], cset["dark_purple"]]

        for i, (path, dataset, config) in enumerate(zip(result_paths, datasets, configs)):
            paths_name = config.processed_paths_name

            with open(path / paths_name, "rb") as f:
                predicted_routes = pickle.load(f)

            lengths, step_stats = RouteAnalyzer._calculate_accuracy_by_length_data(predicted_routes, dataset, k_vals)

            for k_idx, k in enumerate(k_vals):
                accuracies = [
                    step_stats[length].get(f"Top {k}", 0) / step_stats[length]["Total"] * 100 for length in lengths
                ]
                # fmt:off
                fig.add_trace(go.Bar(name=f"{dataset['ds_name']} (Top-{k})", x=lengths, y=accuracies, marker_color=colors[i * len(k_vals) + k_idx]))

        fig.update_layout(
            barmode="group",
            height=300,
            width=1000,
            title=dict(text=title, x=0.5, xanchor="center"),
            xaxis=dict(title="<b>Route Length</b>", dtick=1),
            yaxis=dict(title="<b>Accuracy (%)</b>", dtick=10, range=[0, 82]),
            legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0),
        )
        # fmt: on
        style.AXIS_STYLE["linecolor"] = None
        apply_publication_style(fig)
        fig.update_xaxes(showgrid=False)
        return fig

    @staticmethod
    def create_accuracy_by_length_subplots(
        result_paths: list[Path],
        datasets: list[DatasetDict],
        configs: list[ModelPlotConfig],
        k_vals: list[int],
        title: str = "",
    ) -> go.Figure:
        """Create plot showing accuracy by route length with subplots for all/convergent/non-convergent routes.

        Args:
            result_paths: List of paths to result directories.
            datasets: List of datasets to analyze.
            configs: List of model configurations.
            k_vals: List of k values to calculate accuracy for.
            title: Title for the plot.

        Returns:
            go.Figure: Plotly figure object.
        """
        fig = make_subplots(
            rows=3,
            cols=1,
            subplot_titles=[
                f"<b>{t}</b>" for t in ("(a) all routes", "(b) convergent routes", "(c) non-convergent routes")
            ],
            vertical_spacing=0.12,
        )

        cset = style.publication_colors
        colors = [cset["primary_blue"], cset["dark_blue"], cset["purple"], cset["dark_purple"]]

        for i, (path, dataset, config) in enumerate(zip(result_paths, datasets, configs)):
            paths_name = config.processed_paths_name

            with open(path / paths_name, "rb") as f:
                predicted_routes = pickle.load(f)

            analyzer = RouteAnalyzer(predicted_routes, dataset["path_strings"])
            # ignore_ids holds the complement set: ignoring the non-convergent
            # indices scores only the convergent routes, and vice versa.
            route_types = {
                "all": (None, 1),
                "convergent": (analyzer.non_convergent_idxs, 2),
                "non_convergent": (analyzer.convergent_idxs, 3),
            }

            for route_type, (ignore_ids, row) in route_types.items():
                lengths, step_stats = RouteAnalyzer._calculate_accuracy_by_length_data(
                    predicted_routes, dataset, k_vals, ignore_ids=ignore_ids
                )

                for k_idx, k in enumerate(k_vals):
                    accuracies = [
                        step_stats[length].get(f"Top {k}", 0) / step_stats[length]["Total"] * 100 for length in lengths
                    ]
                    # fmt:off
                    fig.add_trace(go.Bar(name=f"{dataset['ds_name']} (Top-{k})",
                            x=lengths,y=accuracies, marker_color=colors[i * len(k_vals) + k_idx],
                            showlegend=row == 1, legendgroup=f"{dataset['ds_name']} (Top-{k})"
                        ), row=row, col=1)

        fig.update_layout(
            barmode="group",
            height=900,
            width=1000,
            title=dict(text=title, x=0.5, xanchor="center"),
            legend=dict(orientation="h", yanchor="bottom", y=-0.15, xanchor="center", x=0.5),
        )

        for i in range(1, 4):
            fig.update_xaxes(title="<b>Route Length</b>", dtick=1, row=i, col=1, showgrid=False)
            fig.update_yaxes(title="<b>Accuracy (%)</b>", dtick=10, range=[0, 82], row=i, col=1)
        # fmt: on
        style.AXIS_STYLE["linecolor"] = None
        apply_publication_style(fig)
        fig.update_xaxes(showgrid=False)
        return fig

    @staticmethod
    def visualize_route_processing_stages(
        valid_routes: PathsProcessedType,
        processed_routes_no_stock: PathsProcessedType,
        processed_routes_with_stock: PathsProcessedType,
        true_routes: list[str],
        dataset_name: str = "",
        show_filtered_stats: bool = False,
    ) -> go.Figure:
        """Create a publication-quality figure showing the distribution of routes at different processing stages.

        Args:
            valid_routes: Valid routes from beam search.
            processed_routes_no_stock: Routes after canonicalization/removing repetitions.
            processed_routes_with_stock: Routes after applying stock filter.
            true_routes: True routes for convergence analysis.
            dataset_name: Name of the dataset being analyzed.
            show_filtered_stats: Whether to show filtered statistics (mean* and median*).

        Returns:
            go.Figure: Plotly figure object.
        """
        # Get convergent indices
        convergent_idxs = get_convergent_indices(true_routes)
        non_convergent_idxs = set(range(len(true_routes))) - convergent_idxs

        def get_predictions_by_type(routes: PathsProcessedType) -> tuple[list[int], list[int], list[int]]:
            all_predictions = [len(r) for r in routes]
            conv_predictions = [all_predictions[i] for i in convergent_idxs]
            nonconv_predictions = [all_predictions[i] for i in non_convergent_idxs]
            return all_predictions, conv_predictions, nonconv_predictions

        valid_all, valid_conv, valid_nonconv = get_predictions_by_type(valid_routes)
        no_stock_all, no_stock_conv, no_stock_nonconv = get_predictions_by_type(processed_routes_no_stock)
        with_stock_all, with_stock_conv, with_stock_nonconv = get_predictions_by_type(processed_routes_with_stock)

        # Create subplot titles
        def create_subtitle(stage: str, predictions: list[int]) -> str:
            mean, median, mean_f, median_f = calculate_prediction_stats(predictions)
            base = f"{stage}<br><span style=\"font-size:{FONT_SIZES['subplot_title']-4}px\">"
            stats = f"mean={mean:.1f}, median={median:.1f}"
            if show_filtered_stats:
                stats += f" (μ*={mean_f:.1f}, m*={median_f:.1f})"
            return base + stats + "</span>"

        # fmt:off
        fig = make_subplots(rows=3, cols=3,
            subplot_titles=[
                create_subtitle("<b>(a) valid routes (all)</b>", valid_all),
                create_subtitle("<b>(b) valid routes (convergent)</b>", valid_conv),
                create_subtitle("<b>(c) valid routes (non-convergent)</b>", valid_nonconv),
                create_subtitle("<b>(d) after canonicalization (all)</b>", no_stock_all),
                create_subtitle("<b>(e) after canonicalization (convergent)</b>", no_stock_conv),
                create_subtitle("<b>(f) after canonicalization (non-convergent)</b>", no_stock_nonconv),
                create_subtitle("<b>(g) after stock filter (all)</b>", with_stock_all),
                create_subtitle("<b>(h) after stock filter (convergent)</b>", with_stock_conv),
                create_subtitle("<b>(i) after stock filter (non-convergent)</b>", with_stock_nonconv),
            ], vertical_spacing=0.10, horizontal_spacing=0.05)

        histogram_style = dict(histnorm='percent', marker_color=style.publication_colors["dark_blue"], marker_line_width=0)

        data = [
            (valid_all, valid_conv, valid_nonconv),
            (no_stock_all, no_stock_conv, no_stock_nonconv),
            (with_stock_all, with_stock_conv, with_stock_nonconv)
        ]

        for row, (all_pred, conv_pred, nonconv_pred) in enumerate(data, start=1):
            for col, predictions in enumerate([all_pred, conv_pred, nonconv_pred], start=1):
                fig.add_trace(go.Histogram(x=predictions, xbins=dict(start=0, end=50, size=2), **histogram_style), row=row, col=col)

        apply_publication_style(fig)
        fig.update_layout(showlegend=False, height=900, width=1200, margin_t=60, bargap=0.03)

        for row in range(1, 4):
            for col in range(1, 4):
                fig.update_xaxes(title=None, dtick=5, range=[0, 50], row=row, col=col)
                if row == 3:
                    fig.update_xaxes(title=dict(text="<b>Number of Routes</b>", font=get_font_dict(FONT_SIZES["axis_title"]), standoff=15), row=row, col=col)

                if col == 1:
                    fig.update_yaxes(title=dict(text="<b>Percentage (%)</b>", font=get_font_dict(FONT_SIZES["axis_title"]), standoff=15), row=row, col=col)
                else:
                    fig.update_yaxes(title=None, row=row, col=col)

        return fig

__init__(predicted_routes, true_routes, k_vals=None)

Initializes the RouteAnalyzer.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `predicted_routes` | `PathsProcessedType` | Predicted routes. | *required* |
| `true_routes` | `list[str]` | True routes. | *required* |
| `k_vals` | `list[int] \| None` | List of k values for top-k accuracy calculation. | `None` |
Source code in src/directmultistep/analysis/paper/linear_vs_convergent.py
def __init__(self, predicted_routes: PathsProcessedType, true_routes: list[str], k_vals: list[int] | None = None):
    """Initializes the RouteAnalyzer.

    Args:
        predicted_routes: Predicted routes.
        true_routes: True routes.
        k_vals: List of k values for top-k accuracy calculation.
    """
    self.predicted_routes = predicted_routes
    self.true_routes = true_routes
    self.k_vals = k_vals if k_vals is not None else [1, 2, 3, 4, 5, 10, 20, 50]
    self.convergent_idxs = get_convergent_indices(true_routes)
    self.non_convergent_idxs = set(range(len(true_routes))) - self.convergent_idxs
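The convergent/non-convergent split above is plain set arithmetic over route indices. A minimal self-contained sketch with stand-in data (in the real code, `get_convergent_indices` derives the indices from the route path strings):

```python
true_routes = ["r0", "r1", "r2", "r3", "r4"]
convergent_idxs = {1, 3}  # hypothetical result of get_convergent_indices(true_routes)

# Complement set, exactly as computed in __init__.
non_convergent_idxs = set(range(len(true_routes))) - convergent_idxs
print(sorted(non_convergent_idxs))  # [0, 2, 4]
```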

analyze_and_log_results()

Run full analysis and log results.

Returns:

| Type | Description |
| --- | --- |
| `dict[str, dict[str, str]]` | Dictionary of top-k accuracies. |

Source code in src/directmultistep/analysis/paper/linear_vs_convergent.py
def analyze_and_log_results(self) -> dict[str, dict[str, str]]:
    """Run full analysis and log results.

    Returns:
        dict[str, dict[str, str]]: Dictionary of top-k accuracies.
    """
    self.analyze_convergence_stats()
    results = self.calculate_top_k_accuracies()

    for route_type, accuracies in results.items():
        logger.info(f"\nTop-k accuracy for {route_type} routes:")
        logger.info(accuracies)

    return results

analyze_convergence_stats()

Analyze and log basic convergence statistics.

Source code in src/directmultistep/analysis/paper/linear_vs_convergent.py
def analyze_convergence_stats(self) -> None:
    """Analyze and log basic convergence statistics."""
    n_convergent = len(self.convergent_idxs)
    total = len(self.true_routes)
    logger.info(f"Found {n_convergent} convergent routes out of {total} total routes")
    logger.info(f"Percentage convergent: {100 * n_convergent / total:.1f}%")

calculate_top_k_accuracies(save_path=None)

Calculate top-k accuracies for different route subsets and optionally save results.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `save_path` | `Path \| None` | Optional path to save detailed accuracies to YAML file. | `None` |

Returns:

| Type | Description |
| --- | --- |
| `dict[str, dict[str, str]]` | Dictionary of top-k accuracies. |

Source code in src/directmultistep/analysis/paper/linear_vs_convergent.py
def calculate_top_k_accuracies(self, save_path: Path | None = None) -> dict[str, dict[str, str]]:
    """Calculate top-k accuracies for different route subsets and optionally save results.

    Args:
        save_path: Optional path to save detailed accuracies to YAML file.

    Returns:
        dict[str, dict[str, str]]: Dictionary of top-k accuracies.
    """
    results = {}
    # ignore_ids holds the complement set: ignoring the non-convergent
    # indices scores only the convergent routes, and vice versa.
    route_types = {"all": None, "convergent": self.non_convergent_idxs, "non_convergent": self.convergent_idxs}

    with tqdm(total=len(route_types), desc="Analyzing top-k accuracy") as pbar:
        for route_type, ignore_ids in route_types.items():
            pbar.set_description(f"{route_type} routes")
            _, perm_matches = find_matching_paths(self.predicted_routes, self.true_routes, ignore_ids=ignore_ids)
            results[route_type] = find_top_n_accuracy(perm_matches, self.k_vals)
            pbar.update(1)

    if save_path is not None:
        save_path = save_path / "top_k_accuracy_detailed.yaml"
        with open(save_path, "w") as f:
            yaml.dump(results, f, default_flow_style=False)
        logger.info(f"Saved detailed accuracies to {save_path}")

    return results

create_accuracy_by_length_plot(result_paths, datasets, configs, k_vals, title='') staticmethod

Create plot showing accuracy by route length.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `result_paths` | `list[Path]` | List of paths to result directories. | *required* |
| `datasets` | `list[DatasetDict]` | List of datasets to analyze. | *required* |
| `configs` | `list[ModelPlotConfig]` | List of model configurations. | *required* |
| `k_vals` | `list[int]` | List of k values to calculate accuracy for. | *required* |
| `title` | `str` | Title for the plot. | `''` |

Returns:

| Type | Description |
| --- | --- |
| `Figure` | go.Figure: Plotly figure object. |

Source code in src/directmultistep/analysis/paper/linear_vs_convergent.py
@staticmethod
def create_accuracy_by_length_plot(
    result_paths: list[Path],
    datasets: list[DatasetDict],
    configs: list[ModelPlotConfig],
    k_vals: list[int],
    title: str = "",
) -> go.Figure:
    """Create plot showing accuracy by route length.

    Args:
        result_paths: List of paths to result directories.
        datasets: List of datasets to analyze.
        configs: List of model configurations.
        k_vals: List of k values to calculate accuracy for.
        title: Title for the plot.

    Returns:
        go.Figure: Plotly figure object.
    """
    fig = go.Figure()

    cset = style.publication_colors
    colors = [cset["primary_blue"], cset["dark_blue"], cset["purple"], cset["dark_purple"]]

    for i, (path, dataset, config) in enumerate(zip(result_paths, datasets, configs)):
        paths_name = config.processed_paths_name

        with open(path / paths_name, "rb") as f:
            predicted_routes = pickle.load(f)

        lengths, step_stats = RouteAnalyzer._calculate_accuracy_by_length_data(predicted_routes, dataset, k_vals)

        for k_idx, k in enumerate(k_vals):
            accuracies = [
                step_stats[length].get(f"Top {k}", 0) / step_stats[length]["Total"] * 100 for length in lengths
            ]
            # fmt:off
            fig.add_trace(go.Bar(name=f"{dataset['ds_name']} (Top-{k})", x=lengths, y=accuracies, marker_color=colors[i * len(k_vals) + k_idx]))

    fig.update_layout(
        barmode="group",
        height=300,
        width=1000,
        title=dict(text=title, x=0.5, xanchor="center"),
        xaxis=dict(title="<b>Route Length</b>", dtick=1),
        yaxis=dict(title="<b>Accuracy (%)</b>", dtick=10, range=[0, 82]),
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1.0),
    )
    # fmt: on
    style.AXIS_STYLE["linecolor"] = None
    apply_publication_style(fig)
    fig.update_xaxes(showgrid=False)
    return fig
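The per-length accuracy in the function above divides top-k hit counts by per-length totals. A minimal sketch with a hypothetical `step_stats` dict (the real one comes from `calculate_top_k_counts_by_step_length`):

```python
# Hypothetical step_stats: route length -> {"Top k": hit count, "Total": n routes}
step_stats = {
    2: {"Top 1": 30, "Top 10": 45, "Total": 50},
    3: {"Top 1": 12, "Total": 40},  # no "Top 10" entry recorded for length 3
}
k = 10
lengths = sorted(step_stats)
# .get(..., 0) guards against lengths where a given "Top k" key is absent.
accuracies = [step_stats[n].get(f"Top {k}", 0) / step_stats[n]["Total"] * 100 for n in lengths]
print(accuracies)  # [90.0, 0.0]
```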

create_accuracy_by_length_subplots(result_paths, datasets, configs, k_vals, title='') staticmethod

Create plot showing accuracy by route length with subplots for all/convergent/non-convergent routes.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `result_paths` | `list[Path]` | List of paths to result directories. | *required* |
| `datasets` | `list[DatasetDict]` | List of datasets to analyze. | *required* |
| `configs` | `list[ModelPlotConfig]` | List of model configurations. | *required* |
| `k_vals` | `list[int]` | List of k values to calculate accuracy for. | *required* |
| `title` | `str` | Title for the plot. | `''` |

Returns:

| Type | Description |
| --- | --- |
| `Figure` | go.Figure: Plotly figure object. |

Source code in src/directmultistep/analysis/paper/linear_vs_convergent.py
@staticmethod
def create_accuracy_by_length_subplots(
    result_paths: list[Path],
    datasets: list[DatasetDict],
    configs: list[ModelPlotConfig],
    k_vals: list[int],
    title: str = "",
) -> go.Figure:
    """Create plot showing accuracy by route length with subplots for all/convergent/non-convergent routes.

    Args:
        result_paths: List of paths to result directories.
        datasets: List of datasets to analyze.
        configs: List of model configurations.
        k_vals: List of k values to calculate accuracy for.
        title: Title for the plot.

    Returns:
        go.Figure: Plotly figure object.
    """
    fig = make_subplots(
        rows=3,
        cols=1,
        subplot_titles=[
            f"<b>{t}</b>" for t in ("(a) all routes", "(b) convergent routes", "(c) non-convergent routes")
        ],
        vertical_spacing=0.12,
    )

    cset = style.publication_colors
    colors = [cset["primary_blue"], cset["dark_blue"], cset["purple"], cset["dark_purple"]]

    for i, (path, dataset, config) in enumerate(zip(result_paths, datasets, configs)):
        paths_name = config.processed_paths_name

        with open(path / paths_name, "rb") as f:
            predicted_routes = pickle.load(f)

        analyzer = RouteAnalyzer(predicted_routes, dataset["path_strings"])
        # ignore_ids holds the complement set: ignoring the non-convergent
        # indices scores only the convergent routes, and vice versa.
        route_types = {
            "all": (None, 1),
            "convergent": (analyzer.non_convergent_idxs, 2),
            "non_convergent": (analyzer.convergent_idxs, 3),
        }

        for route_type, (ignore_ids, row) in route_types.items():
            lengths, step_stats = RouteAnalyzer._calculate_accuracy_by_length_data(
                predicted_routes, dataset, k_vals, ignore_ids=ignore_ids
            )

            for k_idx, k in enumerate(k_vals):
                accuracies = [
                    step_stats[length].get(f"Top {k}", 0) / step_stats[length]["Total"] * 100 for length in lengths
                ]
                # fmt:off
                fig.add_trace(go.Bar(name=f"{dataset['ds_name']} (Top-{k})",
                        x=lengths,y=accuracies, marker_color=colors[i * len(k_vals) + k_idx],
                        showlegend=row == 1, legendgroup=f"{dataset['ds_name']} (Top-{k})"
                    ), row=row, col=1)

    fig.update_layout(
        barmode="group",
        height=900,
        width=1000,
        title=dict(text=title, x=0.5, xanchor="center"),
        legend=dict(orientation="h", yanchor="bottom", y=-0.15, xanchor="center", x=0.5),
    )

    for i in range(1, 4):
        fig.update_xaxes(title="<b>Route Length</b>", dtick=1, row=i, col=1, showgrid=False)
        fig.update_yaxes(title="<b>Accuracy (%)</b>", dtick=10, range=[0, 82], row=i, col=1)
    # fmt: on
    style.AXIS_STYLE["linecolor"] = None
    apply_publication_style(fig)
    fig.update_xaxes(showgrid=False)
    return fig

create_comparative_bar_plots(result_paths, trace_names, k_vals=None, title='') staticmethod

Create comparative bar plots showing top-k accuracy for different configurations.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `result_paths` | `list[Path]` | List of paths to top_k_accuracy_detailed.yaml files. | *required* |
| `trace_names` | `list[str]` | List of names for each trace (must match length of result_paths). | *required* |
| `k_vals` | `list[int] \| None` | Optional list of k values to show. If None, shows all k values. | `None` |
| `title` | `str` | Title for the plot. | `''` |

Returns:

| Type | Description |
| --- | --- |
| `Figure` | go.Figure: Plotly figure object. |

Source code in src/directmultistep/analysis/paper/linear_vs_convergent.py
@staticmethod
def create_comparative_bar_plots(
    result_paths: list[Path], trace_names: list[str], k_vals: list[int] | None = None, title: str = ""
) -> go.Figure:
    """Create comparative bar plots showing top-k accuracy for different configurations.

    Args:
        result_paths: List of paths to top_k_accuracy_detailed.yaml files.
        trace_names: List of names for each trace (must match length of result_paths).
        k_vals: Optional list of k values to show. If None, shows all k values.
        title: Title for the plot.

    Returns:
        go.Figure: Plotly figure object.
    """
    if len(result_paths) != len(trace_names):
        raise ValueError("Number of result paths must match number of trace names")

    results = []
    for path in result_paths:
        with open(path / "top_k_accuracy_detailed.yaml") as f:
            results.append(yaml.safe_load(f))

    # fmt: off
    fig = make_subplots(rows=3, cols=1, horizontal_spacing=0.07, vertical_spacing=0.12,
        subplot_titles=[f"<b>{t}</b>" for t in ('(a) all routes', '(b) convergent routes', '(c) non-convergent routes')])

    categories = ['all', 'convergent', 'non_convergent']
    positions = [1, 2, 3]

    colors = style.colors_blue + style.colors_purple + style.colors_red

    for cat, pos in zip(categories, positions):
        x = list(results[0][cat].keys())
        x.sort(key=lambda k: int(k.split()[-1]))

        if k_vals is not None:
            k_vals_str = [f"Top {k}" for k in k_vals]
            x = [k for k in x if k in k_vals_str]

        for i, (result, name) in enumerate(zip(results, trace_names)):
            y = [float(result[cat][k].strip('%')) for k in x]

            fig.add_trace(
                go.Bar(name=name, x=x, y=y, showlegend=pos == 1, marker_color=colors[i % len(colors)],
                    legendgroup=name,), row=pos, col=1)

    fig.update_layout(
        title=dict(text=title, x=0.5, xanchor='center'),
        barmode='group', height=600, width=1000,
        legend=dict(font=get_font_dict(FONT_SIZES["legend"]), orientation="h", yanchor="bottom",
            y=-0.20, xanchor="center", x=0.5, entrywidth=140, tracegroupgap=0))

    style.AXIS_STYLE["linecolor"] = None
    apply_publication_style(fig)

    for i in range(1, 4):
        fig.update_yaxes(dtick=10, title=dict(text="Accuracy (%)", font=get_font_dict(FONT_SIZES["axis_title"])), row=i, col=1)
        fig.update_xaxes(showgrid=False, row=i, col=1)
    # fmt: on
    return fig
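Two small details in the function above are worth noting: the YAML keys are strings like `"Top 10"`, so a plain lexicographic sort would misorder them, and the stored accuracies are percent strings. A self-contained sketch with hypothetical YAML content:

```python
results_cat = {"Top 1": "62.0%", "Top 10": "85.5%", "Top 2": "70.0%"}  # hypothetical YAML content

x = list(results_cat.keys())
x.sort(key=lambda k: int(k.split()[-1]))  # numeric sort on the trailing k
print(x)  # ['Top 1', 'Top 2', 'Top 10']

y = [float(results_cat[k].strip('%')) for k in x]  # "85.5%" -> 85.5
print(y)  # [62.0, 70.0, 85.5]
```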

visualize_route_distributions(dataset_name='')

Create a publication-quality figure showing the distribution of predicted routes.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `dataset_name` | `str` | Name of the dataset being analyzed, used in plot title. | `''` |

Returns:

| Type | Description |
| --- | --- |
| `Figure` | go.Figure: Plotly figure object. |

Source code in src/directmultistep/analysis/paper/linear_vs_convergent.py
def visualize_route_distributions(self, dataset_name: str = "") -> go.Figure:
    """Create a publication-quality figure showing the distribution of predicted routes.

    Args:
        dataset_name: Name of the dataset being analyzed, used in plot title.

    Returns:
        go.Figure: Plotly figure object.
    """
    n_predictions = [len(routes) for routes in self.predicted_routes]

    conv_predictions = [n_predictions[i] for i in self.convergent_idxs]
    nonconv_predictions = [n_predictions[i] for i in self.non_convergent_idxs]

    mean_all, median_all, mean_all_filtered, median_all_filtered = calculate_prediction_stats(n_predictions)
    mean_conv, median_conv, mean_conv_filtered, median_conv_filtered = calculate_prediction_stats(conv_predictions)
    mean_nonconv, median_nonconv, mean_nonconv_filtered, median_nonconv_filtered = calculate_prediction_stats(
        nonconv_predictions
    )

    # fmt: off
    fig = make_subplots(rows=1, cols=3,
        subplot_titles=(
            f'All Routes<br><span style="font-size:{FONT_SIZES["subplot_title"]}px">mean: {mean_all:.1f}, median: {median_all:.1f} (mean*: {mean_all_filtered:.1f}, median*: {median_all_filtered:.1f})</span>',
            f'Convergent Routes<br><span style="font-size:{FONT_SIZES["subplot_title"]}px">mean: {mean_conv:.1f}, median: {median_conv:.1f} (mean*: {mean_conv_filtered:.1f}, median*: {median_conv_filtered:.1f})</span>',
            f'Non-convergent Routes<br><span style="font-size:{FONT_SIZES["subplot_title"]}px">mean: {mean_nonconv:.1f}, median: {median_nonconv:.1f} (mean*: {mean_nonconv_filtered:.1f}, median*: {median_nonconv_filtered:.1f})</span>'
        ), horizontal_spacing=0.1)

    histogram_style = dict(opacity=0.75, nbinsx=30, histnorm='percent', marker_color=style.publication_colors["dark_blue"])

    data = [(n_predictions, "All"), (conv_predictions, "Convergent"), (nonconv_predictions, "Non-convergent")]
    for i, (predictions, name) in enumerate(data, start=1):
        fig.add_trace(go.Histogram(x=predictions, name=name, **histogram_style), row=1, col=i)

    title = "Distribution of Predicted Routes per Target"
    if dataset_name:
        title = f"{title} - {dataset_name}"

    fig.update_layout(title=dict(text=title, x=0.5, xanchor='center'), showlegend=False, height=400, width=1200,)

    apply_publication_style(fig)

    for i in range(1, 4):
        fig.update_xaxes(title=dict(text="Number of Predicted Routes", font=get_font_dict(FONT_SIZES["axis_title"]), standoff=15), row=1, col=i)
        fig.update_yaxes(title=dict(text="Percentage (%)", font=get_font_dict(FONT_SIZES["axis_title"]), standoff=15), row=1, col=i)
    # fmt: on
    return fig

visualize_route_processing_stages(valid_routes, processed_routes_no_stock, processed_routes_with_stock, true_routes, dataset_name='', show_filtered_stats=False) staticmethod

Create a publication-quality figure showing the distribution of routes at different processing stages.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `valid_routes` | `PathsProcessedType` | Valid routes from beam search. | *required* |
| `processed_routes_no_stock` | `PathsProcessedType` | Routes after canonicalization/removing repetitions. | *required* |
| `processed_routes_with_stock` | `PathsProcessedType` | Routes after applying stock filter. | *required* |
| `true_routes` | `list[str]` | True routes for convergence analysis. | *required* |
| `dataset_name` | `str` | Name of the dataset being analyzed. | `''` |
| `show_filtered_stats` | `bool` | Whether to show filtered statistics (mean* and median*). | `False` |

Returns:

| Type | Description |
| --- | --- |
| `Figure` | go.Figure: Plotly figure object. |

Source code in src/directmultistep/analysis/paper/linear_vs_convergent.py
@staticmethod
def visualize_route_processing_stages(
    valid_routes: PathsProcessedType,
    processed_routes_no_stock: PathsProcessedType,
    processed_routes_with_stock: PathsProcessedType,
    true_routes: list[str],
    dataset_name: str = "",
    show_filtered_stats: bool = False,
) -> go.Figure:
    """Create a publication-quality figure showing the distribution of routes at different processing stages.

    Args:
        valid_routes: Valid routes from beam search.
        processed_routes_no_stock: Routes after canonicalization/removing repetitions.
        processed_routes_with_stock: Routes after applying stock filter.
        true_routes: True routes for convergence analysis.
        dataset_name: Name of the dataset being analyzed.
        show_filtered_stats: Whether to show filtered statistics (mean* and median*).

    Returns:
        go.Figure: Plotly figure object.
    """
    # Get convergent indices
    convergent_idxs = get_convergent_indices(true_routes)
    non_convergent_idxs = set(range(len(true_routes))) - convergent_idxs

    def get_predictions_by_type(routes: PathsProcessedType) -> tuple[list[int], list[int], list[int]]:
        all_predictions = [len(r) for r in routes]
        conv_predictions = [all_predictions[i] for i in convergent_idxs]
        nonconv_predictions = [all_predictions[i] for i in non_convergent_idxs]
        return all_predictions, conv_predictions, nonconv_predictions

    valid_all, valid_conv, valid_nonconv = get_predictions_by_type(valid_routes)
    no_stock_all, no_stock_conv, no_stock_nonconv = get_predictions_by_type(processed_routes_no_stock)
    with_stock_all, with_stock_conv, with_stock_nonconv = get_predictions_by_type(processed_routes_with_stock)

    # Create subplot titles
    def create_subtitle(stage: str, predictions: list[int]) -> str:
        mean, median, mean_f, median_f = calculate_prediction_stats(predictions)
        base = f"{stage}<br><span style=\"font-size:{FONT_SIZES['subplot_title']-4}px\">"
        stats = f"mean={mean:.1f}, median={median:.1f}"
        if show_filtered_stats:
            stats += f" (μ*={mean_f:.1f}, m*={median_f:.1f})"
        return base + stats + "</span>"

    # fmt:off
    fig = make_subplots(rows=3, cols=3,
        subplot_titles=[
            create_subtitle("<b>(a) valid routes (all)</b>", valid_all),
            create_subtitle("<b>(b) valid routes (convergent)</b>", valid_conv),
            create_subtitle("<b>(c) valid routes (non-convergent)</b>", valid_nonconv),
            create_subtitle("<b>(d) after canonicalization (all)</b>", no_stock_all),
            create_subtitle("<b>(e) after canonicalization (convergent)</b>", no_stock_conv),
            create_subtitle("<b>(f) after canonicalization (non-convergent)</b>", no_stock_nonconv),
            create_subtitle("<b>(g) after stock filter (all)</b>", with_stock_all),
            create_subtitle("<b>(h) after stock filter (convergent)</b>", with_stock_conv),
            create_subtitle("<b>(i) after stock filter (non-convergent)</b>", with_stock_nonconv),
        ], vertical_spacing=0.10, horizontal_spacing=0.05)

    histogram_style = dict(histnorm='percent', marker_color=style.publication_colors["dark_blue"], marker_line_width=0)

    data = [
        (valid_all, valid_conv, valid_nonconv),
        (no_stock_all, no_stock_conv, no_stock_nonconv),
        (with_stock_all, with_stock_conv, with_stock_nonconv)
    ]

    for row, (all_pred, conv_pred, nonconv_pred) in enumerate(data, start=1):
        for col, predictions in enumerate([all_pred, conv_pred, nonconv_pred], start=1):
            fig.add_trace(go.Histogram(x=predictions, xbins=dict(start=0, end=50, size=2), **histogram_style), row=row, col=col)

    apply_publication_style(fig)
    fig.update_layout(showlegend=False, height=900, width=1200, margin_t=60, bargap=0.03)

    for row in range(1, 4):
        for col in range(1, 4):
            fig.update_xaxes(title=None, dtick=5, range=[0, 50], row=row, col=col)
            if row == 3:
                fig.update_xaxes(title=dict(text="<b>Number of Routes</b>", font=get_font_dict(FONT_SIZES["axis_title"]), standoff=15), row=row, col=col)

            if col == 1:
                fig.update_yaxes(title=dict(text="<b>Percentage (%)</b>", font=get_font_dict(FONT_SIZES["axis_title"]), standoff=15), row=row, col=col)
            else:
                fig.update_yaxes(title=None, row=row, col=col)

    return fig

calculate_prediction_stats(predictions)

Calculate mean and median statistics for a list of predictions.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `predictions` | `list[int]` | List of prediction counts. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `tuple[float, float, float, float]` | Tuple of (mean, median, filtered_mean, filtered_median) where the filtered versions only consider predictions with count > 0. |

Source code in src/directmultistep/analysis/paper/linear_vs_convergent.py
def calculate_prediction_stats(predictions: list[int]) -> tuple[float, float, float, float]:
    """Calculate mean and median statistics for a list of predictions.

    Args:
        predictions: List of prediction counts.

    Returns:
        Tuple of (mean, median, filtered_mean, filtered_median) where filtered
        versions only consider predictions with count > 0.
    """
    mean = np.float64(np.mean(predictions)).item()
    median = np.float64(np.median(predictions)).item()

    filtered = [x for x in predictions if x > 0]
    filtered_mean = np.float64(np.mean(filtered)).item() if filtered else 0.0
    filtered_median = np.float64(np.median(filtered)).item() if filtered else 0.0

    return mean, median, filtered_mean, filtered_median
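As a quick illustration of the filtering behavior, here is a self-contained sketch using the stdlib `statistics` module in place of numpy; `prediction_stats` is a hypothetical stand-in that mirrors the logic of the source above:

```python
from statistics import mean, median

def prediction_stats(predictions):
    # Mirrors calculate_prediction_stats: overall stats plus
    # "filtered" stats that ignore zero-count entries.
    m, md = mean(predictions), median(predictions)
    filtered = [x for x in predictions if x > 0]
    fm = mean(filtered) if filtered else 0.0
    fmd = median(filtered) if filtered else 0.0
    return m, md, fm, fmd

# Zeros drag the plain mean/median down but are excluded
# from the filtered statistics.
print(prediction_stats([0, 0, 2, 4, 6]))
```

This makes the intent of the filtered variants explicit: they summarize only the targets for which at least one route was predicted.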

get_convergent_indices(path_strings)

Identify indices of convergent routes in a dataset.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `path_strings` | `list[str]` | List of path strings. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `set[int]` | Set of indices of convergent routes. |

Source code in src/directmultistep/analysis/paper/linear_vs_convergent.py
def get_convergent_indices(path_strings: list[str]) -> set[int]:
    """Identify indices of convergent routes in dataset.

    Args:
        path_strings: List of path strings.

    Returns:
        set[int]: Set of indices of convergent routes.
    """
    convergent_idxs = set()
    logger.info("Finding convergent routes")
    for i, path_str in enumerate(tqdm(path_strings)):
        path_dict = eval(path_str)
        if is_convergent(path_dict):
            convergent_idxs.add(i)
    return convergent_idxs
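The loop above parses each serialized route dict and tests it with `is_convergent`. A minimal self-contained sketch of the same pattern, using `ast.literal_eval` (a safer alternative to `eval` for literal dicts) and a hypothetical stand-in for `is_convergent` (here: a route counts as convergent if some node merges two or more non-leaf branches):

```python
import ast

def _is_convergent(node: dict) -> bool:
    # Hypothetical stand-in for is_convergent: convergent if any node
    # has two or more children that themselves have children.
    children = node.get("children", [])
    if sum(1 for c in children if c.get("children")) >= 2:
        return True
    return any(_is_convergent(c) for c in children)

def convergent_indices(path_strings: list[str]) -> set[int]:
    # Parse each serialized route dict and keep indices that pass the test.
    return {i for i, s in enumerate(path_strings)
            if _is_convergent(ast.literal_eval(s))}

routes = [
    "{'smiles': 'A', 'children': [{'smiles': 'B', 'children': []}]}",
    ("{'smiles': 'A', 'children': ["
     "{'smiles': 'B', 'children': [{'smiles': 'D', 'children': []}]}, "
     "{'smiles': 'C', 'children': [{'smiles': 'E', 'children': []}]}]}"),
]
print(convergent_indices(routes))  # {1}: only the branching route
```

The actual convergence criterion lives in `is_convergent`; the stub above only illustrates the index-collection pattern.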

load_predicted_routes(path)

Load predicted routes from a pickle file.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `path` | `Path` | Path to the pickle file. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `PathsProcessedType` | Loaded predicted routes. |

Source code in src/directmultistep/analysis/paper/linear_vs_convergent.py
def load_predicted_routes(path: Path) -> PathsProcessedType:
    """Load predicted routes from a pickle file.

    Args:
        path: Path to the pickle file.

    Returns:
        PathsProcessedType: Loaded predicted routes.
    """
    with open(path, "rb") as f:
        routes: PathsProcessedType = pickle.load(f)
    logger.info(f"Loaded {len(routes)} predicted routes")
    return routes
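A round-trip sketch of the same pattern, minus the logger dependency, using a temporary directory and a toy routes object:

```python
import pickle
import tempfile
from pathlib import Path

def load_routes(path: Path):
    # Same pattern as load_predicted_routes above, without the logger.
    with open(path, "rb") as f:
        routes = pickle.load(f)
    print(f"Loaded {len(routes)} predicted routes")
    return routes

# Write a toy routes list to a temporary pickle file and load it back.
with tempfile.TemporaryDirectory() as tmp:
    p = Path(tmp) / "processed_paths.pkl"
    with open(p, "wb") as f:
        pickle.dump([{"smiles": "A", "children": []}], f)
    routes = load_routes(p)
```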

process_model_configs(eval_path, configs, dataset)

Process model configurations and ensure top-k accuracies are calculated.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `eval_path` | `Path` | Path to evaluation directory. | *required* |
| `configs` | `list[ModelPlotConfig]` | List of model configurations. | *required* |
| `dataset` | `DatasetDict` | Dataset to process. | *required* |

Returns:

| Type | Description |
| --- | --- |
| `tuple[list[Path], list[str]]` | Tuple of (result_paths, trace_names) for plotting. |

Source code in src/directmultistep/analysis/paper/linear_vs_convergent.py
def process_model_configs(
    eval_path: Path, configs: list[ModelPlotConfig], dataset: DatasetDict
) -> tuple[list[Path], list[str]]:
    """Process model configurations and ensure top-k accuracies are calculated.

    Args:
        eval_path: Path to evaluation directory.
        configs: List of model configurations.
        dataset: Dataset to process.

    Returns:
        Tuple of (result_paths, trace_names) for plotting.
    """
    result_paths = []
    trace_names = []

    for config in configs:
        res_path = config.get_result_path(eval_path)
        accuracy_file = res_path / "top_k_accuracy_detailed.yaml"

        if not accuracy_file.exists():
            logger.info(f"Calculating accuracies for {config.display_name}...")
            predicted_routes = load_predicted_routes(res_path / config.processed_paths_name)
            analyzer = RouteAnalyzer(predicted_routes, dataset["path_strings"])
            analyzer.calculate_top_k_accuracies(save_path=res_path)

        result_paths.append(res_path)
        trace_names.append(config.display_name)

    return result_paths, trace_names
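The key design choice here is the compute-if-missing check on `top_k_accuracy_detailed.yaml`: expensive accuracy calculations run only when the cached file is absent. A self-contained sketch of that caching pattern, with a toy `compute` callable standing in for `RouteAnalyzer.calculate_top_k_accuracies`:

```python
import tempfile
from pathlib import Path

def ensure_accuracies(res_path: Path, compute) -> Path:
    # Compute-if-missing: only run the expensive step when the
    # cached accuracy file does not exist yet.
    accuracy_file = res_path / "top_k_accuracy_detailed.yaml"
    if not accuracy_file.exists():
        accuracy_file.write_text(compute())
    return accuracy_file

with tempfile.TemporaryDirectory() as tmp:
    res = Path(tmp)
    calls = []
    def compute():
        calls.append(1)        # track how often the expensive step runs
        return "top_1: 0.5\n"  # toy stand-in for the YAML contents
    ensure_accuracies(res, compute)
    content = ensure_accuracies(res, compute).read_text()  # second call hits the cache
    print(len(calls))  # 1
```

Because the cache key is a file on disk, re-running the figure script skips models whose accuracies were already computed.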