Skip to content

Visualizing Routes

Example use

To visualize a path string, you can use the following snippet:

from pathlib import Path

from directmultistep.utils.web_visualize import draw_tree_from_path_string

path = "{'smiles':'O=C(c1ccc(NS(=O)(=O)c2cccc3cccnc23)cc1)N1CCN(CC2CC2)CC1','children':[{'smiles':'O=C(O)c1ccc(NS(=O)(=O)c2cccc3cccnc23)cc1','children':[{'smiles':'CCOC(=O)c1ccc(NS(=O)(=O)c2cccc3cccnc23)cc1','children':[{'smiles':'CCOC(=O)c1ccc(N)cc1'},{'smiles':'O=S(=O)(Cl)c1cccc2cccnc12'}]}]},{'smiles':'C1CN(CC2CC2)CCN1'}]}"

# Writes data/figures/desired_file_name.pdf (the directory must already exist)
# and returns the SVG markup as a string.
svg_str = draw_tree_from_path_string(
    path_string=path,
    save_path=Path("data/figures/desired_file_name"),
    width=400,
    height=400,
    x_margin=50,
    y_margin=100,
    theme="light",
)

Source Code

directmultistep.utils.web_visualize

ThemeType = Literal['light', 'dark'] module-attribute

FilteredDict

Bases: TypedDict

A dictionary format for multistep routes, used in DirectMultiStep models.

This dictionary is designed to represent a node in a synthetic route tree. It contains the SMILES string of a molecule and a list of its child nodes. To get its string format, use stringify_dict.

Attributes:

Name Type Description
smiles str

SMILES string of the molecule.

children list[FilteredDict]

List of child nodes, each a FilteredDict.

Source code in src/directmultistep/utils/pre_process.py
class FilteredDict(TypedDict, total=False):
    """One node of a multistep route, in the dictionary form used by DirectMultiStep models.

    A route is a tree: every node carries the SMILES string of a single molecule
    and, optionally, a list of child nodes of the same shape. Both keys are
    optional (``total=False``), so leaf nodes may simply omit ``children``.
    The string form of a route is produced by ``stringify_dict``.

    Attributes:
        smiles: SMILES string of the molecule.
        children: List of child nodes, each a FilteredDict.
    """

    # SMILES string of this node's molecule.
    smiles: str
    # Nested child nodes; written as a string forward reference because the type refers to itself.
    children: list["FilteredDict"]

ColorPalette

Bases: NamedTuple

Defines a color palette for drawing molecules.

Attributes:

Name Type Description
atom_colors dict[int, tuple[float, float, float]]

A dictionary mapping atomic numbers to RGB color tuples.

annotation tuple[float, float, float, float]

An RGBA color tuple for annotations.

border tuple[float, float, float]

An RGB color tuple for borders.

text tuple[float, float, float]

An RGB color tuple for text.

background tuple[float, float, float, float]

An RGBA color tuple for background.

Source code in src/directmultistep/utils/web_visualize.py
class ColorPalette(NamedTuple):
    """Immutable bundle of the colors used to draw molecules and tree decorations.

    Channels are stored as floats; the SVG drawing code multiplies border and
    text channels by 255, so they are presumably in the 0-1 range.

    Attributes:
        atom_colors: Mapping from atomic number to the RGB color for that element.
        annotation: RGBA color for annotations.
        border: RGB color for borders.
        text: RGB color for text.
        background: RGBA color for the background.
    """

    # Atomic number -> (r, g, b)
    atom_colors: dict[int, tuple[float, float, float]]
    # (r, g, b, a)
    annotation: tuple[float, float, float, float]
    # (r, g, b)
    border: tuple[float, float, float]
    # (r, g, b)
    text: tuple[float, float, float]
    # (r, g, b, a)
    background: tuple[float, float, float, float]

RetroSynthesisTree

Basic tree structure for retrosynthesis visualization.

Source code in src/directmultistep/utils/web_visualize.py
class RetroSynthesisTree:
    """Basic tree structure for retrosynthesis visualization.

    Each node stores an integer ``node_id``, the ``smiles`` string of its
    molecule, and a list of ``children`` nodes.
    """

    def __init__(self, idx: int = 0) -> None:
        """
        Create an empty retrosynthesis tree node.

        Args:
            idx: The unique identifier for the node.
        """
        # SMILES and children start empty; build_tree fills them in.
        self.node_id, self.smiles = idx, ""
        self.children: list[RetroSynthesisTree] = []

    def build_tree(self, path_dict: FilteredDict) -> int:
        """Populate this node and all of its descendants from a nested route dictionary.

        IDs are assigned depth-first in pre-order: this node keeps its own
        ``node_id`` and descendants receive consecutive IDs after it.

        Args:
            path_dict: A dictionary representing the tree structure.

        Returns:
            The next available node ID (one past the largest ID assigned).
        """
        self.smiles = path_dict["smiles"]
        next_id = self.node_id + 1
        # Nodes without a "children" key are leaves.
        for child_dict in path_dict.get("children", []):
            subtree = RetroSynthesisTree(idx=next_id)
            next_id = subtree.build_tree(path_dict=child_dict)
            self.children.append(subtree)
        return next_id

    def __str__(self) -> str:
        """Describe this node on one line, followed by all of its descendants in pre-order."""
        child_ids = [str(kid.node_id) for kid in self.children]
        header = f"Node ID: {self.node_id}, Children: {child_ids}, SMILES: {self.smiles}\n"
        return header + "".join(map(str, self.children))

__init__(idx=0)

Initializes a new node in the retrosynthesis tree.

Parameters:

Name Type Description Default
idx int

The unique identifier for the node.

0
Source code in src/directmultistep/utils/web_visualize.py
def __init__(self, idx: int = 0) -> None:
    """
    Create an empty retrosynthesis tree node.

    Args:
        idx: The unique identifier for the node.
    """
    # SMILES and children start empty; build_tree fills them in.
    self.node_id, self.smiles = idx, ""
    self.children: list[RetroSynthesisTree] = []

__str__()

Returns a string representation of the tree node and its children.

Source code in src/directmultistep/utils/web_visualize.py
def __str__(self) -> str:
    """Describe this node on one line, followed by all of its descendants in pre-order."""
    child_ids = [str(kid.node_id) for kid in self.children]
    header = f"Node ID: {self.node_id}, Children: {child_ids}, SMILES: {self.smiles}\n"
    return header + "".join(map(str, self.children))

build_tree(path_dict)

Recursively builds the retrosynthesis tree from a dictionary.

Parameters:

Name Type Description Default
path_dict FilteredDict

A dictionary representing the tree structure.

required

Returns:

Type Description
int

The next available node ID.

Source code in src/directmultistep/utils/web_visualize.py
def build_tree(self, path_dict: FilteredDict) -> int:
    """Populate this node and all of its descendants from a nested route dictionary.

    IDs are assigned depth-first in pre-order: this node keeps its own
    ``node_id`` and descendants receive consecutive IDs after it.

    Args:
        path_dict: A dictionary representing the tree structure.

    Returns:
        The next available node ID (one past the largest ID assigned).
    """
    self.smiles = path_dict["smiles"]
    next_id = self.node_id + 1
    # Nodes without a "children" key are leaves.
    for child_dict in path_dict.get("children", []):
        subtree = RetroSynthesisTree(idx=next_id)
        next_id = subtree.build_tree(path_dict=child_dict)
        self.children.append(subtree)
    return next_id

TreeDimensions

Bases: NamedTuple

Represents the dimensions of a tree or subtree.

Source code in src/directmultistep/utils/web_visualize.py
class TreeDimensions(NamedTuple):
    """Width/height pair describing the space taken by a tree or subtree."""

    # Horizontal extent.
    width: int
    # Vertical extent.
    height: int

compute_subtree_dimensions(tree, img_width, img_height, y_offset)

Compute dimensions of a subtree for layout.

Parameters:

Name Type Description Default
tree RetroSynthesisTree

The subtree to compute dimensions for.

required
img_width int

The width of the molecule image.

required
img_height int

The height of the molecule image.

required
y_offset int

The vertical offset between nodes.

required

Returns:

Type Description
TreeDimensions

The dimensions of the subtree.

Source code in src/directmultistep/utils/web_visualize.py
def compute_subtree_dimensions(
    tree: RetroSynthesisTree, img_width: int, img_height: int, y_offset: int
) -> TreeDimensions:
    """Estimate the space needed to lay out a subtree.

    A leaf takes one image width and one row (image height plus the vertical
    offset). An inner node adds its own image width to the summed widths of
    its child subtrees. Its height is one row more than its tallest child
    subtree. Horizontal margins between nodes are not included.

    Args:
        tree: The subtree to compute dimensions for.
        img_width: The width of the molecule image.
        img_height: The height of the molecule image.
        y_offset: The vertical offset between nodes.

    Returns:
        The dimensions of the subtree.
    """
    row_height = img_height + y_offset
    if not tree.children:
        return TreeDimensions(img_width, row_height)

    sub_dims = [compute_subtree_dimensions(child, img_width, img_height, y_offset) for child in tree.children]

    total_width = img_width
    tallest = row_height
    for dims in sub_dims:
        total_width += dims.width
        tallest = max(tallest, dims.height + img_height + y_offset)
    return TreeDimensions(total_width, tallest)

compute_canvas_dimensions(tree, img_width, img_height, y_offset)

Compute overall canvas dimensions.

Parameters:

Name Type Description Default
tree RetroSynthesisTree

The retrosynthesis tree.

required
img_width int

The width of the molecule image.

required
img_height int

The height of the molecule image.

required
y_offset int

The vertical offset between nodes.

required

Returns:

Type Description
TreeDimensions

The dimensions of the canvas.

Source code in src/directmultistep/utils/web_visualize.py
def compute_canvas_dimensions(
    tree: RetroSynthesisTree, img_width: int, img_height: int, y_offset: int
) -> TreeDimensions:
    """Estimate the overall canvas size for drawing the whole tree.

    The width is the sum of the root's child-subtree widths; the root's own
    image is not counted, so a single-node tree yields width 0. The height is
    the tallest child subtree plus one row for the root plus a fixed 100 units
    of padding.

    Args:
        tree: The retrosynthesis tree.
        img_width: The width of the molecule image.
        img_height: The height of the molecule image.
        y_offset: The vertical offset between nodes.

    Returns:
        The dimensions of the canvas.
    """
    subtree_dims = [
        compute_subtree_dimensions(child, img_width, img_height, y_offset) for child in tree.children
    ]
    total_width = sum(dims.width for dims in subtree_dims)
    tallest_child = max((dims.height for dims in subtree_dims), default=0)
    padded_height = tallest_child + img_height + y_offset + 100
    return TreeDimensions(total_width, padded_height)

check_overlap(new_x, new_y, existing_boxes, img_width, img_height)

Check if a new node overlaps with existing nodes.

Parameters:

Name Type Description Default
new_x int

The x-coordinate of the new node.

required
new_y int

The y-coordinate of the new node.

required
existing_boxes list[tuple[int, int]]

A list of tuples representing the coordinates of existing nodes.

required
img_width int

The width of the molecule image.

required
img_height int

The height of the molecule image.

required

Returns:

Type Description
bool

True if there is an overlap, False otherwise.

Source code in src/directmultistep/utils/web_visualize.py
def check_overlap(
    new_x: int,
    new_y: int,
    existing_boxes: list[tuple[int, int]],
    img_width: int,
    img_height: int,
) -> bool:
    """Report whether a box anchored at ``(new_x, new_y)`` collides with any existing box.

    All boxes share the same size. Two boxes collide when their anchors are
    strictly closer than one image width horizontally AND strictly closer than
    one image height vertically, so boxes that merely touch do not count.

    Args:
        new_x: The x-coordinate of the new node.
        new_y: The y-coordinate of the new node.
        existing_boxes: A list of tuples representing the coordinates of existing nodes.
        img_width: The width of the molecule image.
        img_height: The height of the molecule image.

    Returns:
        True if there is an overlap, False otherwise.
    """
    for box_x, box_y in existing_boxes:
        horizontal_hit = box_x - img_width < new_x < box_x + img_width
        if horizontal_hit and box_y - img_height < new_y < box_y + img_height:
            return True
    return False

draw_molecule(smiles, size, theme)

Render a SMILES string as a base64-encoded PNG image.

Parameters:

Name Type Description Default
smiles str

The SMILES string of the molecule.

required
size tuple[int, int]

The desired size (width, height) of the image.

required
theme ThemeType

The color theme ("light" or "dark").

required

Returns:

Type Description
str

The base64-encoded PNG image data.

Raises:

Type Description
ValueError

If the SMILES string is invalid.

Source code in src/directmultistep/utils/web_visualize.py
def draw_molecule(smiles: str, size: tuple[int, int], theme: ThemeType) -> str:
    """Render a SMILES string as a base64-encoded PNG.

    The molecule is drawn with RDKit's Cairo backend, using the atom palette and
    annotation color of the theme's palette. When ``__support_pdf__`` is falsy,
    the palette background keeps its RGB channels but gets alpha 0.

    Args:
        smiles: The SMILES string of the molecule.
        size: The desired size (width, height) of the image.
        theme: The color theme ("light" or "dark").

    Returns:
        The base64-encoded PNG image data.

    Raises:
        ValueError: If the SMILES string is invalid.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Invalid SMILES string: {smiles}")

    img_w, img_h = size
    drawer = rdMolDraw2D.MolDraw2DCairo(img_w, img_h)
    options = drawer.drawOptions()

    palette = LIGHT_PALETTE if theme != "dark" else DARK_PALETTE
    # Without PDF support, zero out the alpha channel so the background is transparent.
    background = palette.background if __support_pdf__ else (*palette.background[:3], 0)

    options.setBackgroundColour(background)
    options.setAtomPalette(palette.atom_colors)
    options.setAnnotationColour(palette.annotation)

    drawer.DrawMolecule(mol)
    drawer.FinishDrawing()
    encoded = base64.b64encode(drawer.GetDrawingText())
    return encoded.decode("utf-8")

draw_tree_svg(tree, width, height, x_margin, y_margin, theme, force_canvas_width=None)

Create SVG visualization of the retrosynthesis tree.

Parameters:

Name Type Description Default
tree RetroSynthesisTree

The retrosynthesis tree to visualize.

required
width int

The width of each molecule image.

required
height int

The height of each molecule image.

required
x_margin int

The horizontal margin between nodes.

required
y_margin int

The vertical margin between nodes.

required
theme ThemeType

The color theme ("light" or "dark").

required
force_canvas_width int | None

An optional width to force for the canvas.

None

Returns:

Type Description
str

The SVG content as a string.

Source code in src/directmultistep/utils/web_visualize.py
def draw_tree_svg(
    tree: RetroSynthesisTree,
    width: int,
    height: int,
    x_margin: int,
    y_margin: int,
    theme: ThemeType,
    force_canvas_width: int | None = None,
) -> str:
    """Create SVG visualization of the retrosynthesis tree.

    Each node is drawn as an embedded PNG of its molecule (from ``draw_molecule``)
    with a rounded border and an ``ID: <node_id>`` label below it. Each parent is
    joined to its children by straight lines. A node that would collide with an
    already-placed node is nudged to the right until it fits.

    If the laid-out tree turns out wider than the canvas and no width was
    forced, the tree is drawn again with ``force_canvas_width`` set to the
    measured width.

    Args:
        tree: The retrosynthesis tree to visualize.
        width: The width of each molecule image.
        height: The height of each molecule image.
        x_margin: The horizontal margin between nodes.
        y_margin: The vertical margin between nodes.
        theme: The color theme ("light" or "dark").
        force_canvas_width: An optional width to force for the canvas. A falsy
            value (``None`` or ``0``) means the width from
            ``compute_canvas_dimensions`` is used.

    Returns:
        The SVG content as a string.
    """
    initial_dims = compute_canvas_dimensions(tree, width, height, y_margin)
    canvas_width = force_canvas_width if force_canvas_width else initial_dims.width
    drawing = svgwrite.Drawing(size=(canvas_width, initial_dims.height))

    # Colors depend only on the theme, so resolve them once instead of per node.
    palette = DARK_PALETTE if theme == "dark" else LIGHT_PALETTE
    border_color = svgwrite.rgb(*[c * 255 for c in palette.border])
    text_color = svgwrite.rgb(*[c * 255 for c in palette.text])

    # Rightward nudge applied on collision. Clamped to >= 1 so the overlap loop
    # always makes progress (width // 2 is 0 when width < 2).
    shift_step = max(width // 2, 1)

    # Top-left anchors of every node placed so far, used for collision checks.
    existing_boxes: list[tuple[int, int]] = []
    # Horizontal extent of all placed nodes; decides whether a wider redraw is needed.
    memo = {"left_x": float("inf"), "right_x": float("-inf")}

    def place_node(nx: int, ny: int) -> int:
        """Reserve a collision-free slot for a node, starting from ``(nx, ny)``.

        Args:
            nx: The preferred x-coordinate of the node.
            ny: The y-coordinate of the node (never changed).

        Returns:
            The final x-coordinate, shifted right as far as needed to avoid overlap.
        """
        # NOTE(review): the second check (one y_margin higher) seems intended to
        # also keep clear of boxes in the row gap above; confirm intent.
        while check_overlap(nx, ny, existing_boxes, width, height) or check_overlap(
            nx, ny - y_margin, existing_boxes, width, height
        ):
            nx += shift_step

        existing_boxes.append((nx, ny))
        memo["left_x"] = min(memo["left_x"], nx - width // 2)
        memo["right_x"] = max(memo["right_x"], nx + width // 2)
        return nx

    def draw_node(node: RetroSynthesisTree, nx: int, ny: int) -> None:
        """Draw an already-placed node, then place and draw its children recursively.

        Args:
            node: The tree node to draw.
            nx: The x-coordinate returned by ``place_node`` for this node.
            ny: The y-coordinate of the node.
        """
        # Draw molecule
        b64_img = draw_molecule(node.smiles, (width, height), theme)
        drawing.add(
            drawing.image(
                href=f"data:image/png;base64,{b64_img}",
                insert=(nx, ny),
                size=(width, height),
            )
        )

        # Draw border
        drawing.add(
            drawing.rect(
                insert=(nx, ny),
                size=(width, height),
                rx=20,
                ry=20,
                fill="none",
                stroke=border_color,
                stroke_width=4,
            )
        )

        # Draw node ID
        drawing.add(
            drawing.text(
                f"ID: {node.node_id}",
                insert=(nx, ny + height + 35),
                fill=text_color,
                font_size=20,
                font_family="Arial",
            )
        )

        # Draw children one row down, starting left of this node for multiple children.
        child_count = len(node.children)
        if child_count == 0:
            return
        next_x = nx if child_count == 1 else nx - (child_count - 1) * width // 2
        next_y = ny + y_margin + height

        for child in node.children:
            # Resolve the child's final position *before* drawing the connector,
            # so the line ends where the child is actually drawn even if it was
            # nudged right to avoid an overlap.
            child_x = place_node(next_x, next_y)
            drawing.add(
                drawing.line(
                    start=(nx + width / 2, ny + height),
                    end=(child_x + width / 2, next_y),
                    stroke=border_color,
                    stroke_width=4,
                )
            )
            draw_node(child, child_x, next_y)
            next_x += x_margin + width

    # Draw the root
    root_y = 50
    root_x = place_node((canvas_width - width) // 2, root_y)
    draw_node(tree, root_x, root_y)

    # Adjust canvas if needed
    final_width = int(memo["right_x"] - memo["left_x"] + width * 2 + x_margin * 2)
    if final_width > canvas_width and force_canvas_width is None:
        return draw_tree_svg(
            tree,
            width,
            height,
            x_margin,
            y_margin,
            theme,
            force_canvas_width=final_width,
        )

    return cast(str, drawing.tostring())

create_tree_from_path_string(path_string)

Parse a dictionary-like string into a RetroSynthesisTree.

Parameters:

Name Type Description Default
path_string str

A string representing the tree structure as a dictionary.

required

Returns:

Type Description
RetroSynthesisTree

A RetroSynthesisTree object.

Source code in src/directmultistep/utils/web_visualize.py
def create_tree_from_path_string(path_string: str) -> RetroSynthesisTree:
    """Parse a dictionary-like string into a RetroSynthesisTree.

    The string is parsed with ``ast.literal_eval``, which accepts only Python
    literals (dicts, lists, strings, numbers, ...). Unlike ``eval``, it cannot
    execute arbitrary code embedded in the path string.

    Args:
        path_string: A string representing the tree structure as a dictionary,
            e.g. ``"{'smiles':'CCO','children':[{'smiles':'CC=O'}]}"``.

    Returns:
        A RetroSynthesisTree object whose root has node ID 0.

    Raises:
        ValueError, SyntaxError: If ``path_string`` is not a valid Python
            literal (``ast.literal_eval`` may also raise TypeError, MemoryError
            or RecursionError on some malformed inputs).
    """
    import ast

    path_dict: FilteredDict = ast.literal_eval(path_string)
    retro_tree = RetroSynthesisTree()
    retro_tree.build_tree(path_dict=path_dict)
    return retro_tree

draw_tree_from_path_string(path_string, save_path, width=400, height=400, x_margin=50, y_margin=100, theme='light')

Generate a PDF visualization from a path string and return the SVG markup (the intermediate SVG file is deleted after conversion).

Parameters:

Name Type Description Default
path_string str

A string representing the tree structure as a dictionary.

required
save_path Path

Output path. Its suffix is replaced with .pdf for the saved PDF; a temporary .svg file is written next to it and then removed.

required
width int

The width of each molecule image.

400
height int

The height of each molecule image.

400
x_margin int

The horizontal margin between nodes.

50
y_margin int

The vertical margin between nodes.

100
theme str

The color theme ("light" or "dark").

'light'

Returns:

Type Description
str

The SVG content as a string.

Source code in src/directmultistep/utils/web_visualize.py
def draw_tree_from_path_string(
    path_string: str,
    save_path: Path,
    width: int = 400,
    height: int = 400,
    x_margin: int = 50,
    y_margin: int = 100,
    theme: str = "light",
) -> str:
    """Render a path string to a PDF file and return the SVG markup.

    The tree is drawn as SVG and written to ``save_path`` with a ``.svg``
    suffix. That file is converted to ``save_path`` with a ``.pdf`` suffix.
    The intermediate ``.svg`` file is then deleted, even if the conversion
    fails, so only the PDF stays on disk.

    Args:
        path_string: A string representing the tree structure as a dictionary.
        save_path: Output path; its suffix is replaced by ``.svg`` (temporary)
            and ``.pdf`` (kept). The parent directory must already exist.
        width: The width of each molecule image.
        height: The height of each molecule image.
        x_margin: The horizontal margin between nodes.
        y_margin: The vertical margin between nodes.
        theme: The color theme ("light" or "dark").

    Returns:
        The SVG content as a string.

    Raises:
        AssertionError: If ``theme`` is not "light" or "dark" (only when
            assertions are enabled).
    """
    assert theme in ["light", "dark"], f"theme must be 'light' or 'dark', got {theme!r}"
    theme = cast(ThemeType, theme)

    retro_tree = create_tree_from_path_string(path_string)
    svg_content = draw_tree_svg(
        retro_tree,
        width=width,
        height=height,
        x_margin=x_margin,
        y_margin=y_margin,
        theme=theme,
    )

    # svg2rlg reads from a file path, so the SVG has to be written to disk first.
    svg_path = save_path.with_suffix(".svg")
    svg_path.write_text(svg_content, encoding="utf-8")
    try:
        # Convert to PDF
        drawing = svg2rlg(str(svg_path))
        renderPDF.drawToFile(drawing, str(save_path.with_suffix(".pdf")))
    finally:
        # The SVG file is only an intermediate; remove it whether or not conversion succeeded.
        svg_path.unlink(missing_ok=True)
    return svg_content