
Torch Dataset for Routes

This module provides a custom PyTorch Dataset class for handling reaction routes. It includes utilities for tokenizing SMILES strings, reaction path strings, and context information, as well as for preparing data for training and generation.

Example Use

tokenize_path_string is the central function of this module. It tokenizes a reaction path string by splitting it with a regular expression, and it can optionally add start-of-sequence (<SOS>) and end-of-sequence ("?") tokens.

from directmultistep.utils.dataset import tokenize_path_string

path_string = "{'smiles':'CC','children':[{'smiles':'CC(=O)O'}]}"
tokens = tokenize_path_string(path_string)
print(tokens)
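# Expected output (wrapped for readability):
# ['<SOS>', '{', "'smiles':", "'", 'C', 'C', "'", ',', "'children':", '[', '{',
#  "'smiles':", "'", 'C', 'C', '(', '=', 'O', ')', 'O', "'", '}', ']', '}', '?']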

Notes on Path Start

In the RoutesDataset class, the get_generation_with_sm and get_generation_no_sm methods return an initial path tensor. This tensor is created from a path_start string, a partial path string that the model will continue generating from. For a given product, path_start has the form "{'smiles':'<product SMILES>','children':[{'smiles':'" (no spaces, ending just inside the first child's SMILES). The model generates the rest of the path string from this starting point.

This design is important because a trained model always generates this path_start at the beginning of the sequence. By providing this as the initial input, we avoid wasting time generating this part and can focus on generating the rest of the reaction path.
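The construction of path_start mirrors the source below: serialize a dictionary with an empty first child, strip spaces, and drop the trailing "'}]}" (the product SMILES here is an arbitrary example):

from directmultistep.utils.dataset import tokenize_path_string

smile_dict = {"smiles": "CC(C)CO", "children": [{"smiles": ""}]}  # example product
path_start = str(smile_dict).replace(" ", "")[:-4]  # drop the trailing '}]}
print(path_start)  # {'smiles':'CC(C)CO','children':[{'smiles':'
tokens = tokenize_path_string(path_start, add_eos=False)  # no EOS: the path is unfinished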

The prepare_input_tensors function in directmultistep.generate builds this default path_start internally from the target SMILES. To initiate the generation process from a specific point in the reaction path instead of the default starting point, you can tokenize a custom path_start string yourself (with path_string_to_tokens, as the dataset does) and pass the resulting tensor to the decoder. This lets you control the initial state of the generation and explore reaction pathways with user-defined intermediates.
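A minimal sketch of such a custom start, assuming rds is a RoutesProcessing instance loaded from the training metadata (the intermediate SMILES here is purely illustrative):

# Resume generation inside the first child, with a chosen intermediate already placed:
custom_start = "{'smiles':'CC(C)CO','children':[{'smiles':'CC(C)C=O'"
path_tens = rds.path_string_to_tokens(custom_start, max_length=None, add_eos=False).unsqueeze(0)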

Source Code

directmultistep.utils.dataset

RoutesDataset

Bases: RoutesProcessing

Dataset for multi-step reaction routes.

Source code in src/directmultistep/utils/dataset.py
class RoutesDataset(RoutesProcessing):
    """Dataset for multi-step reaction routes."""

    def __init__(
        self,
        metadata_path: Path,
        products: list[str],
        path_strings: list[str],
        n_steps_list: list[int],
        starting_materials: list[str] | None = None,
        mode: str = "training",
        name_idx: dict[str, list[int]] | None = None,
    ) -> None:
        """Initializes the RoutesDataset.

        Args:
            metadata_path: Path to the metadata file (YAML).
            products: List of product SMILES strings.
            path_strings: List of reaction path strings.
            n_steps_list: List of integers representing the number of steps in each path.
            starting_materials: List of starting material SMILES strings.
            mode: Either "training" or "generation".
            name_idx: A dictionary mapping names to lists of indices.
        """
        super().__init__(metadata_path)
        self.products = products
        self.path_strings = path_strings
        self.step_lengths = n_steps_list
        self.sms = starting_materials
        # name_idx is an optional attribute that shows labels for items in the dataset
        # currently used for evals on pharma compounds
        self.name_idx = name_idx
        assert mode in ["training", "generation"], "mode must be either 'training' or 'generation'"
        self.mode = mode

    def __repr__(self) -> str:
        """Returns a string representation of the dataset."""
        sms_str = "SM (enabled)" if self.sms is not None else "SM (disabled)"
        return f"RoutesDataset(mode={self.mode}, len={len(self)}, {sms_str})"

    def __getitem__(self, index: int) -> tuple[Tensor, ...]:
        """Retrieves an item from the dataset.

        Args:
            index: The index of the item to retrieve.

        Returns:
            A tuple of tensors representing the input and output data.
        """
        if self.mode == "training":
            if self.sms is not None:
                return self.get_training_with_sm(index)
            else:
                return self.get_training_no_sm(index)
        elif self.mode == "generation":
            if self.sms is not None:
                return self.get_generation_with_sm(index)
            else:
                return self.get_generation_no_sm(index)
        else:
            raise ValueError(f"Invalid mode: {self.mode}")

    def __len__(self) -> int:
        """Returns the number of items in the dataset."""
        return len(self.products)

    def get_training_with_sm(self, index: int) -> tuple[Tensor, ...]:
        """Retrieves a training item with starting materials.

        Args:
            index: The index of the item to retrieve.

        Returns:
            A tuple of tensors: encoder input, decoder input, and step length.
        """
        assert self.sms is not None, "starting materials are not provided"
        product_item = self.smile_to_tokens(self.products[index], self.product_max_length)
        one_sm_item = self.smile_to_tokens(self.sms[index], self.sm_max_length)
        seq_encoder_item = torch.cat((product_item, one_sm_item), dim=0)
        seq_decoder_item = self.path_string_to_tokens(self.path_strings[index], self.seq_out_max_length)

        step_item = Tensor([self.step_lengths[index]])
        return seq_encoder_item, seq_decoder_item, step_item

    def get_generation_with_sm(self, index: int) -> tuple[Tensor, ...]:
        """Retrieves a generation item with starting materials.

        Args:
            index: The index of the item to retrieve.

        Returns:
            A tuple of tensors: encoder input, step length, and initial path tensor.
        """
        assert self.sms is not None, "starting materials are not provided"
        product_item = self.smile_to_tokens(self.products[index], self.product_max_length)
        one_sm_item = self.smile_to_tokens(self.sms[index], self.sm_max_length)
        seq_encoder_item = torch.cat((product_item, one_sm_item), dim=0)

        step_item = Tensor([self.step_lengths[index]])
        smile_dict = {"smiles": self.products[index], "children": [{"smiles": ""}]}
        path_start = str(smile_dict).replace(" ", "")[:-4]
        path_tens = self.path_string_to_tokens(path_start, max_length=None, add_eos=False)
        return seq_encoder_item, step_item, path_tens

    def get_training_no_sm(self, index: int) -> tuple[Tensor, ...]:
        """Retrieves a training item without starting materials.

        Args:
            index: The index of the item to retrieve.

        Returns:
            A tuple of tensors: encoder input, decoder input, and step length.
        """
        seq_encoder_item = self.smile_to_tokens(self.products[index], self.product_max_length)
        seq_decoder_item = self.path_string_to_tokens(self.path_strings[index], self.seq_out_max_length)

        step_item = Tensor([self.step_lengths[index]])
        # shapes: [product_max_length], [seq_out_max_length], [1]
        return seq_encoder_item, seq_decoder_item, step_item

    def get_generation_no_sm(self, index: int) -> tuple[Tensor, ...]:
        """Retrieves a generation item without starting materials.

        Args:
            index: The index of the item to retrieve.

        Returns:
            A tuple of tensors: encoder input, step length, and initial path tensor.
        """
        seq_encoder_item = self.smile_to_tokens(self.products[index], self.product_max_length)

        step_item = Tensor([self.step_lengths[index]])
        # shapes: [product_max_length], [1], [len(path_start tokens)]
        smile_dict = {"smiles": self.products[index], "children": [{"smiles": ""}]}
        # drop the trailing "'}]}" so the string ends just inside the first child's SMILES
        path_start = str(smile_dict).replace(" ", "")[:-4]
        # path_start == "{'smiles':'" + self.products[index] + "','children':[{'smiles':'"
        path_tens = self.path_string_to_tokens(path_start, max_length=None, add_eos=False)
        return seq_encoder_item, step_item, path_tens
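A minimal construction sketch, assuming a tokenizer metadata YAML at data/metadata.yaml (a hypothetical path) and a single one-step route:

from pathlib import Path
from directmultistep.utils.dataset import RoutesDataset

ds = RoutesDataset(
    metadata_path=Path("data/metadata.yaml"),  # hypothetical metadata file
    products=["CC(C)CO"],
    path_strings=["{'smiles':'CC(C)CO','children':[{'smiles':'CC(C)C=O'}]}"],
    n_steps_list=[1],
    mode="generation",
)
encoder_inp, step_tens, path_tens = ds[0]  # dispatches to get_generation_no_sm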

__getitem__(index)

Retrieves an item from the dataset.

Parameters:

    index (int, required): The index of the item to retrieve.

Returns:

    tuple[Tensor, ...]: A tuple of tensors representing the input and output data.

Source code in src/directmultistep/utils/dataset.py
def __getitem__(self, index: int) -> tuple[Tensor, ...]:
    """Retrieves an item from the dataset.

    Args:
        index: The index of the item to retrieve.

    Returns:
        A tuple of tensors representing the input and output data.
    """
    if self.mode == "training":
        if self.sms is not None:
            return self.get_training_with_sm(index)
        else:
            return self.get_training_no_sm(index)
    elif self.mode == "generation":
        if self.sms is not None:
            return self.get_generation_with_sm(index)
        else:
            return self.get_generation_no_sm(index)
    else:
        raise ValueError(f"Invalid mode: {self.mode}")

__init__(metadata_path, products, path_strings, n_steps_list, starting_materials=None, mode='training', name_idx=None)

Initializes the RoutesDataset.

Parameters:

    metadata_path (Path, required): Path to the metadata file (YAML).
    products (list[str], required): List of product SMILES strings.
    path_strings (list[str], required): List of reaction path strings.
    n_steps_list (list[int], required): List of integers representing the number of steps in each path.
    starting_materials (list[str] | None, default None): List of starting material SMILES strings.
    mode (str, default 'training'): Either "training" or "generation".
    name_idx (dict[str, list[int]] | None, default None): A dictionary mapping names to lists of indices.
Source code in src/directmultistep/utils/dataset.py
def __init__(
    self,
    metadata_path: Path,
    products: list[str],
    path_strings: list[str],
    n_steps_list: list[int],
    starting_materials: list[str] | None = None,
    mode: str = "training",
    name_idx: dict[str, list[int]] | None = None,
) -> None:
    """Initializes the RoutesDataset.

    Args:
        metadata_path: Path to the metadata file (YAML).
        products: List of product SMILES strings.
        path_strings: List of reaction path strings.
        n_steps_list: List of integers representing the number of steps in each path.
        starting_materials: List of starting material SMILES strings.
        mode: Either "training" or "generation".
        name_idx: A dictionary mapping names to lists of indices.
    """
    super().__init__(metadata_path)
    self.products = products
    self.path_strings = path_strings
    self.step_lengths = n_steps_list
    self.sms = starting_materials
    # name_idx is an optional attribute that shows labels for items in the dataset
    # currently used for evals on pharma compounds
    self.name_idx = name_idx
    assert mode in ["training", "generation"], "mode must be either 'training' or 'generation'"
    self.mode = mode

__len__()

Returns the number of items in the dataset.

Source code in src/directmultistep/utils/dataset.py
def __len__(self) -> int:
    """Returns the number of items in the dataset."""
    return len(self.products)

__repr__()

Returns a string representation of the dataset.

Source code in src/directmultistep/utils/dataset.py
def __repr__(self) -> str:
    """Returns a string representation of the dataset."""
    sms_str = "SM (enabled)" if self.sms is not None else "SM (disabled)"
    return f"RoutesDataset(mode={self.mode}, len={len(self)}, {sms_str})"

get_generation_no_sm(index)

Retrieves a generation item without starting materials.

Parameters:

    index (int, required): The index of the item to retrieve.

Returns:

    tuple[Tensor, ...]: A tuple of tensors: encoder input, step length, and initial path tensor.

Source code in src/directmultistep/utils/dataset.py
def get_generation_no_sm(self, index: int) -> tuple[Tensor, ...]:
    """Retrieves a generation item without starting materials.

    Args:
        index: The index of the item to retrieve.

    Returns:
        A tuple of tensors: encoder input, step length, and initial path tensor.
    """
    seq_encoder_item = self.smile_to_tokens(self.products[index], self.product_max_length)

    step_item = Tensor([self.step_lengths[index]])
    # shapes: [product_max_length], [1], [len(path_start tokens)]
    smile_dict = {"smiles": self.products[index], "children": [{"smiles": ""}]}
    # drop the trailing "'}]}" so the string ends just inside the first child's SMILES
    path_start = str(smile_dict).replace(" ", "")[:-4]
    # path_start == "{'smiles':'" + self.products[index] + "','children':[{'smiles':'"
    path_tens = self.path_string_to_tokens(path_start, max_length=None, add_eos=False)
    return seq_encoder_item, step_item, path_tens

get_generation_with_sm(index)

Retrieves a generation item with starting materials.

Parameters:

    index (int, required): The index of the item to retrieve.

Returns:

    tuple[Tensor, ...]: A tuple of tensors: encoder input, step length, and initial path tensor.

Source code in src/directmultistep/utils/dataset.py
def get_generation_with_sm(self, index: int) -> tuple[Tensor, ...]:
    """Retrieves a generation item with starting materials.

    Args:
        index: The index of the item to retrieve.

    Returns:
        A tuple of tensors: encoder input, step length, and initial path tensor.
    """
    assert self.sms is not None, "starting materials are not provided"
    product_item = self.smile_to_tokens(self.products[index], self.product_max_length)
    one_sm_item = self.smile_to_tokens(self.sms[index], self.sm_max_length)
    seq_encoder_item = torch.cat((product_item, one_sm_item), dim=0)

    step_item = Tensor([self.step_lengths[index]])
    smile_dict = {"smiles": self.products[index], "children": [{"smiles": ""}]}
    path_start = str(smile_dict).replace(" ", "")[:-4]
    path_tens = self.path_string_to_tokens(path_start, max_length=None, add_eos=False)
    return seq_encoder_item, step_item, path_tens

get_training_no_sm(index)

Retrieves a training item without starting materials.

Parameters:

    index (int, required): The index of the item to retrieve.

Returns:

    tuple[Tensor, ...]: A tuple of tensors: encoder input, decoder input, and step length.

Source code in src/directmultistep/utils/dataset.py
def get_training_no_sm(self, index: int) -> tuple[Tensor, ...]:
    """Retrieves a training item without starting materials.

    Args:
        index: The index of the item to retrieve.

    Returns:
        A tuple of tensors: encoder input, decoder input, and step length.
    """
    seq_encoder_item = self.smile_to_tokens(self.products[index], self.product_max_length)
    seq_decoder_item = self.path_string_to_tokens(self.path_strings[index], self.seq_out_max_length)

    step_item = Tensor([self.step_lengths[index]])
    # shapes: [product_max_length], [seq_out_max_length], [1]
    return seq_encoder_item, seq_decoder_item, step_item

get_training_with_sm(index)

Retrieves a training item with starting materials.

Parameters:

    index (int, required): The index of the item to retrieve.

Returns:

    tuple[Tensor, ...]: A tuple of tensors: encoder input, decoder input, and step length.

Source code in src/directmultistep/utils/dataset.py
def get_training_with_sm(self, index: int) -> tuple[Tensor, ...]:
    """Retrieves a training item with starting materials.

    Args:
        index: The index of the item to retrieve.

    Returns:
        A tuple of tensors: encoder input, decoder input, and step length.
    """
    assert self.sms is not None, "starting materials are not provided"
    product_item = self.smile_to_tokens(self.products[index], self.product_max_length)
    one_sm_item = self.smile_to_tokens(self.sms[index], self.sm_max_length)
    seq_encoder_item = torch.cat((product_item, one_sm_item), dim=0)
    seq_decoder_item = self.path_string_to_tokens(self.path_strings[index], self.seq_out_max_length)

    step_item = Tensor([self.step_lengths[index]])
    return seq_encoder_item, seq_decoder_item, step_item

tokenize_smile(smile)

Tokenizes a SMILES string by character.

Parameters:

    smile (str, required): The SMILES string to tokenize.

Returns:

    list[str]: A list of tokens, including start and end of sequence tokens.

Source code in src/directmultistep/utils/dataset.py
def tokenize_smile(smile: str) -> list[str]:
    """Tokenizes a SMILES string by character.

    Args:
        smile: The SMILES string to tokenize.

    Returns:
        A list of tokens, including start and end of sequence tokens.
    """
    return ["<SOS>"] + list(smile) + ["?"]

tokenize_smile_atom(smile, has_atom_types, mask=False)

Tokenizes a SMILES string, considering atom types of up to two characters.

Parameters:

    smile (str, required): The SMILES string to tokenize.
    has_atom_types (list[str], required): A list of atom types to consider (e.g., ["Cl", "Br"]).
    mask (bool, default False): If True, replaces all atom tokens with "J".

Returns:

    list[str]: A list of tokens, including start and end of sequence tokens.

Source code in src/directmultistep/utils/dataset.py
def tokenize_smile_atom(smile: str, has_atom_types: list[str], mask: bool = False) -> list[str]:
    """Tokenizes a SMILES string, considering atom types of up to two characters.

    Args:
        smile: The SMILES string to tokenize.
        has_atom_types: A list of atom types to consider (e.g., ["Cl", "Br"]).
        mask: If True, replaces all atom tokens with "J".

    Returns:
        A list of tokens, including start and end of sequence tokens.
    """
    tokens = []
    i = 0
    while i < len(smile):
        if i < len(smile) - 1 and smile[i : i + 2] in has_atom_types:
            tokens.append("J" if mask else smile[i : i + 2])
            i += 2
        else:
            tokens.append("J" if mask else smile[i])
            i += 1
    return ["<SOS>"] + tokens + ["?"]

tokenize_context(context_list)

Tokenizes a list of context strings.

Parameters:

    context_list (list[str], required): A list of context strings to tokenize.

Returns:

    list[str]: A list of tokens, including context start, separator, and end tokens.

Source code in src/directmultistep/utils/dataset.py
def tokenize_context(context_list: list[str]) -> list[str]:
    """Tokenizes a list of context strings.

    Args:
        context_list: A list of context strings to tokenize.

    Returns:
        A list of tokens, including context start, separator, and end tokens.
    """
    tokens = ["<context>"]
    for context in context_list:
        tokens.extend(tokenize_path_string(context, add_sos=False, add_eos=False))
        tokens.append("<sep>")
    tokens.append("</context>")
    return tokens
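A quick usage example (the context string here is an arbitrary placeholder):

from directmultistep.utils.dataset import tokenize_context

print(tokenize_context(["abc"]))
# ['<context>', 'a', 'b', 'c', '<sep>', '</context>']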

tokenize_path_string(path_string, add_sos=True, add_eos=True)

Tokenizes a path string based on a regular expression.

Parameters:

    path_string (str, required): The path string to tokenize.
    add_sos (bool, default True): If True, adds a start of sequence token.
    add_eos (bool, default True): If True, adds an end of sequence token.

Returns:

    list[str]: A list of tokens.

Source code in src/directmultistep/utils/dataset.py
def tokenize_path_string(path_string: str, add_sos: bool = True, add_eos: bool = True) -> list[str]:
    """Tokenizes a path string based on a regular expression.

    Args:
        path_string: The path string to tokenize.
        add_sos: If True, adds a start of sequence token.
        add_eos: If True, adds an end of sequence token.

    Returns:
        A list of tokens.
    """
    pattern = re.compile(r"('smiles':|'children':|\[|\]|{|}|.)")
    tokens = ["<SOS>"] if add_sos else []
    tokens.extend(pattern.findall(path_string))
    if add_eos:
        tokens.append("?")
    return tokens
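The SOS/EOS flags can be disabled when tokenizing partial paths, as the dataset does for path_start:

from directmultistep.utils.dataset import tokenize_path_string

print(tokenize_path_string("{'smiles':'CC'}", add_sos=False, add_eos=False))
# ['{', "'smiles':", "'", 'C', 'C', "'", '}']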

directmultistep.generate

prepare_input_tensors(target, n_steps, starting_material, rds, product_max_length, sm_max_length)

Prepare input tensors for the model.

Parameters:

    target (str): SMILES string of the target molecule.
    n_steps (int | None): Number of synthesis steps, or None.
    starting_material (str | None): SMILES string of the starting material, if any.
    rds (RoutesProcessing): RoutesProcessing object for tokenization.
    product_max_length (int): Maximum length of the product SMILES sequence.
    sm_max_length (int): Maximum length of the starting material SMILES sequence.

Returns:

    A tuple containing:
    - encoder_inp: Input tensor for the encoder.
    - steps_tens: Tensor of the number of steps, or None if not provided.
    - path_tens: Initial path tensor for the decoder.

Source code in src/directmultistep/generate.py
def prepare_input_tensors(
    target: str,
    n_steps: int | None,
    starting_material: str | None,
    rds: RoutesProcessing,
    product_max_length: int,
    sm_max_length: int,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor]:
    """Prepare input tensors for the model.
    Args:
        target: SMILES string of the target molecule.
        n_steps: Number of synthesis steps.
        starting_material: SMILES string of the starting material, if any.
        rds: RoutesProcessing object for tokenization.
        product_max_length: Maximum length of the product SMILES sequence.
        sm_max_length: Maximum length of the starting material SMILES sequence.
    Returns:
        A tuple containing:
            - encoder_inp: Input tensor for the encoder.
            - steps_tens: Tensor of the number of steps, or None if not provided.
            - path_tens: Initial path tensor for the decoder.
    """
    prod_tens = rds.smile_to_tokens(target, product_max_length)
    if starting_material:
        sm_tens = rds.smile_to_tokens(starting_material, sm_max_length)
        encoder_inp = torch.cat([prod_tens, sm_tens], dim=0).unsqueeze(0)
    else:
        encoder_inp = torch.cat([prod_tens], dim=0).unsqueeze(0)

    steps_tens = torch.tensor([n_steps]).unsqueeze(0) if n_steps is not None else None
    path_start = "{'smiles':'" + target + "','children':[{'smiles':'"
    path_tens = rds.path_string_to_tokens(path_start, max_length=None, add_eos=False).unsqueeze(0)

    return encoder_inp, steps_tens, path_tens
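A hedged usage sketch, assuming rds is a RoutesProcessing instance whose product_max_length and sm_max_length attributes come from the training metadata:

encoder_inp, steps_tens, path_tens = prepare_input_tensors(
    target="CC(C)CO",          # example target SMILES
    n_steps=2,
    starting_material=None,    # encoder input is the product alone
    rds=rds,
    product_max_length=rds.product_max_length,
    sm_max_length=rds.sm_max_length,
)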