Creating a model instance

There are several ways to create a DMS model instance, ranging from ready-made preset configurations to fully custom ones.

Using Preset Configurations

The simplest way to create a model is to use one of the preset configurations:

from directmultistep.model import ModelFactory

factory = ModelFactory.from_preset("flash_10M", compile_model=True)
model = factory.create_model()

Available presets include: deep_40M, explorer_xl_50M, flash_10M, flash_20M, flex_20M, and wide_40M.
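
from_preset also accepts an explicit device string and a compile toggle (both documented below); for instance, to build an uncompiled CPU model for quick debugging:

from directmultistep.model import ModelFactory

factory = ModelFactory.from_preset("flash_10M", device="cpu", compile_model=False)
model = factory.create_model()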

Custom Configuration

For more control, you can create a custom configuration:

from directmultistep.model import ModelFactory
from directmultistep.model.config import Seq2SeqConfig, EncoderAConfig, MoEDecoderConfig

config = Seq2SeqConfig(
    encoder=EncoderAConfig(
        vocab_dim=53,
        hid_dim=256,
        n_layers=6,
        n_heads=8,
        ff_mult=3,
        ff_activation="gelu",
        dropout=0.1,
        attn_bias=False,
        context_window=280,
        start_idx=0,
        mask_idx=51,
        pad_idx=52,
        initiate_steps=True,
        include_steps=True
    ),
    decoder=MoEDecoderConfig(
        vocab_dim=53,
        hid_dim=256,
        n_layers=6,
        n_heads=8,
        ff_mult=3,
        ff_activation="gelu",
        dropout=0.1,
        attn_bias=False,
        context_window=1075,
        start_idx=0,
        mask_idx=51,
        pad_idx=52,
        n_experts=3,
        top_k=2,
        capacity_factor=1.0,
    ),
)

factory = ModelFactory(config, device=None, compile_model=True)
model = factory.create_model()

Configuration Types

The model supports different types of encoders and decoders, and any encoder type can be paired with any decoder type (see the sketch after this list):

  • Encoders:
      • EncoderAConfig: the standard EncoderA type (used in the examples above)
      • MoEEncoderConfig: Mixture of Experts encoder
  • Decoders:
      • TransformerConfig: standard transformer decoder
      • MoEDecoderConfig: Mixture of Experts decoder
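
As a sketch of such a pairing, here is an MoE encoder combined with a standard transformer decoder, reusing the dimensions from the custom configuration above:

from directmultistep.model.config import Seq2SeqConfig, MoEEncoderConfig, TransformerConfig

config = Seq2SeqConfig(
    encoder=MoEEncoderConfig(
        vocab_dim=53, hid_dim=256, n_layers=6, n_heads=8,
        ff_mult=3, ff_activation="gelu", dropout=0.1, attn_bias=False,
        context_window=280, start_idx=0, mask_idx=51, pad_idx=52,
        initiate_steps=True, include_steps=True,
        n_experts=3, top_k=2, capacity_factor=1.0,
    ),
    decoder=TransformerConfig(
        vocab_dim=53, hid_dim=256, n_layers=6, n_heads=8,
        ff_mult=3, ff_activation="gelu", dropout=0.1, attn_bias=False,
        context_window=1075, start_idx=0, mask_idx=51, pad_idx=52,
    ),
)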

Saving and Loading Configurations

Configurations can be saved to and loaded from YAML files:

# Save configuration
config.save("model_config.yaml")

# Load configuration and create model
factory = ModelFactory.from_config_file("model_config.yaml")
model = factory.create_model()
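
Because save records each component's class under a model_type key, the saved file is self-describing and can be inspected directly (output abridged; values follow the custom configuration above):

from pathlib import Path

config.save(Path("model_config.yaml"))
print(Path("model_config.yaml").read_text())
# encoder:
#   vocab_dim: 53
#   ...
#   model_type: EncoderAConfig
# decoder:
#   ...
#   model_type: MoEDecoderConfig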

Source Code

directmultistep.model.config

TransformerConfig dataclass

Configuration for transformer components.

Attributes:

Name            Type                     Description
vocab_dim       int                      Vocabulary dimension.
hid_dim         int                      Hidden dimension.
n_layers        int                      Number of layers.
n_heads         int                      Number of attention heads.
ff_mult         int                      Feedforward multiplier.
ff_activation   Literal['gelu', 'relu']  Feedforward activation function ('gelu' or 'relu').
dropout         float                    Dropout probability.
attn_bias       bool                     Whether to use attention bias.
context_window  int                      Context window size.
start_idx       int                      Start token index.
mask_idx        int                      Mask token index.
pad_idx         int                      Padding token index.

Source code in src/directmultistep/model/config.py
@dataclass
class TransformerConfig:
    """Configuration for transformer components.

    Attributes:
        vocab_dim: Vocabulary dimension.
        hid_dim: Hidden dimension.
        n_layers: Number of layers.
        n_heads: Number of attention heads.
        ff_mult: Feedforward multiplier.
        ff_activation: Feedforward activation function ('gelu' or 'relu').
        dropout: Dropout probability.
        attn_bias: Whether to use attention bias.
        context_window: Context window size.
        start_idx: Start token index.
        mask_idx: Mask token index.
        pad_idx: Padding token index.
    """

    vocab_dim: int
    hid_dim: int
    n_layers: int
    n_heads: int
    ff_mult: int
    ff_activation: Literal["gelu", "relu"]
    dropout: float
    attn_bias: bool
    context_window: int
    start_idx: int
    mask_idx: int
    pad_idx: int

    def __post_init__(self) -> None:
        if self.hid_dim % self.n_heads != 0:
            raise ValueError(f"{self.hid_dim=} must be divisible by {self.n_heads=}")
        if self.ff_activation not in ["gelu", "relu"]:
            raise ValueError(f"{self.ff_activation=} must be either 'gelu' or 'relu'")

    def save(self, path: Path) -> None:
        """Save config to yaml file.

        Args:
            path: Path to save the config to.
        """
        data = asdict(self)
        data["model_type"] = self.__class__.__name__
        with open(path, "w") as f:
            yaml.dump(data, f, sort_keys=False, default_flow_style=False)

    @classmethod
    def load(cls: Type[T], path: Path) -> T:
        """Load config from yaml file.

        Args:
            path: Path to load the config from.

        Returns:
            Loaded config.
        """
        with open(path) as f:
            data = yaml.safe_load(f)
        data.pop("model_type", None)  # save() writes this key; drop it before instantiating
        return cls(**data)
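
Validation happens at construction time in __post_init__, so an inconsistent configuration fails immediately. For example:

from directmultistep.model.config import TransformerConfig

# 256 is not divisible by 7, so __post_init__ raises ValueError
try:
    TransformerConfig(
        vocab_dim=53, hid_dim=256, n_layers=6, n_heads=7,
        ff_mult=3, ff_activation="gelu", dropout=0.1, attn_bias=False,
        context_window=280, start_idx=0, mask_idx=51, pad_idx=52,
    )
except ValueError as err:
    print(err)  # self.hid_dim=256 must be divisible by self.n_heads=7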

load(path) classmethod

Load config from yaml file.

Parameters:

Name  Type  Description                    Default
path  Path  Path to load the config from.  required

Returns:

Type  Description
T     Loaded config.

Source code in src/directmultistep/model/config.py
@classmethod
def load(cls: Type[T], path: Path) -> T:
    """Load config from yaml file.

    Args:
        path: Path to load the config from.

    Returns:
        Loaded config.
    """
    with open(path) as f:
        data = yaml.safe_load(f)
    data.pop("model_type", None)  # save() writes this key; drop it before instantiating
    return cls(**data)

save(path)

Save config to yaml file.

Parameters:

Name  Type  Description                  Default
path  Path  Path to save the config to.  required

Source code in src/directmultistep/model/config.py
def save(self, path: Path) -> None:
    """Save config to yaml file.

    Args:
        path: Path to save the config to.
    """
    data = asdict(self)
    data["model_type"] = self.__class__.__name__
    with open(path, "w") as f:
        yaml.dump(data, f, sort_keys=False, default_flow_style=False)

MoEDecoderConfig dataclass

Bases: TransformerConfig

Configuration for Mixture of Experts decoder components.

Attributes:

Name             Type   Description
n_experts        int    Number of experts.
top_k            int    Number of experts to use in forward pass.
capacity_factor  float  Capacity factor for experts.

Source code in src/directmultistep/model/config.py
@dataclass
class MoEDecoderConfig(TransformerConfig):
    """Configuration for Mixture of Experts decoder components.

    Attributes:
        n_experts: Number of experts.
        top_k: Number of experts to use in forward pass.
        capacity_factor: Capacity factor for experts.
    """

    n_experts: int
    top_k: int
    capacity_factor: float
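
For intuition, these three fields typically interact as follows in MoE layers: each token is routed to top_k of the n_experts, and capacity_factor scales how many token slots each expert reserves. The arithmetic below is a generic illustration of that idea, not the library's actual routing code:

# Generic MoE capacity bookkeeping (assumed formula, for intuition only)
n_tokens, n_experts, top_k, capacity_factor = 1024, 3, 2, 1.0
slots_per_expert = int(capacity_factor * n_tokens * top_k / n_experts)
print(slots_per_expert)  # 682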

EncoderAConfig dataclass

Bases: TransformerConfig

Configuration for EncoderA components.

Attributes:

Name            Type  Description
initiate_steps  bool  Whether to initiate steps.
include_steps   bool  Whether to include steps.

Source code in src/directmultistep/model/config.py
@dataclass
class EncoderAConfig(TransformerConfig):
    """Configuration for EncoderA components.

    Attributes:
        initiate_steps: Whether to initiate steps.
        include_steps: Whether to include steps.
    """

    initiate_steps: bool
    include_steps: bool

MoEEncoderConfig dataclass

Bases: EncoderAConfig

Configuration for Mixture of Experts encoder components.

Attributes:

Name             Type   Description
n_experts        int    Number of experts.
top_k            int    Number of experts to use in forward pass.
capacity_factor  float  Capacity factor for experts.

Source code in src/directmultistep/model/config.py
@dataclass
class MoEEncoderConfig(EncoderAConfig):
    """Configuration for Mixture of Experts encoder components.

    Attributes:
        n_experts: Number of experts.
        top_k: Number of experts to use in forward pass.
        capacity_factor: Capacity factor for experts.
    """

    n_experts: int
    top_k: int
    capacity_factor: float

Seq2SeqConfig dataclass

Complete model configuration.

Attributes:

Name     Type               Description
encoder  TransformerConfig  Encoder configuration.
decoder  TransformerConfig  Decoder configuration.

Source code in src/directmultistep/model/config.py
@dataclass
class Seq2SeqConfig:
    """Complete model configuration.

    Attributes:
        encoder: Encoder configuration.
        decoder: Decoder configuration.
    """

    encoder: TransformerConfig
    decoder: TransformerConfig

    def save(self, path: Path) -> None:
        """Save config to yaml file.

        Args:
            path: Path to save the config to.
        """
        config_dict = {
            "encoder": asdict(self.encoder) | {"model_type": self.encoder.__class__.__name__},
            "decoder": asdict(self.decoder) | {"model_type": self.decoder.__class__.__name__},
        }
        with open(path, "w") as f:
            yaml.dump(config_dict, f, sort_keys=False)

    @classmethod
    def load(cls, path: Path) -> "Seq2SeqConfig":
        """Load config from yaml file.

        Args:
            path: Path to load the config from.

        Returns:
            Loaded Seq2SeqConfig.
        """
        with open(path) as f:
            data = yaml.safe_load(f)

        # Determine correct encoder/decoder types based on model_type
        encoder_data = data.pop("encoder")
        decoder_data = data.pop("decoder")

        model_type_to_config = {
            "TransformerConfig": TransformerConfig,
            "MoEDecoderConfig": MoEDecoderConfig,
            "EncoderAConfig": EncoderAConfig,
            "MoEEncoderConfig": MoEEncoderConfig,
        }

        encoder_model_type = encoder_data.pop("model_type")
        decoder_model_type = decoder_data.pop("model_type")

        encoder_type = model_type_to_config[encoder_model_type]
        decoder_type = model_type_to_config[decoder_model_type]

        encoder = encoder_type(**encoder_data)
        decoder = decoder_type(**decoder_data)

        return cls(encoder=encoder, decoder=decoder, **data)

load(path) classmethod

Load config from yaml file.

Parameters:

Name  Type  Description                    Default
path  Path  Path to load the config from.  required

Returns:

Type           Description
Seq2SeqConfig  Loaded Seq2SeqConfig.

Source code in src/directmultistep/model/config.py
@classmethod
def load(cls, path: Path) -> "Seq2SeqConfig":
    """Load config from yaml file.

    Args:
        path: Path to load the config from.

    Returns:
        Loaded Seq2SeqConfig.
    """
    with open(path) as f:
        data = yaml.safe_load(f)

    # Determine correct encoder/decoder types based on model_type
    encoder_data = data.pop("encoder")
    decoder_data = data.pop("decoder")

    model_type_to_config = {
        "TransformerConfig": TransformerConfig,
        "MoEDecoderConfig": MoEDecoderConfig,
        "EncoderAConfig": EncoderAConfig,
        "MoEEncoderConfig": MoEEncoderConfig,
    }

    encoder_model_type = encoder_data.pop("model_type")
    decoder_model_type = decoder_data.pop("model_type")

    encoder_type = model_type_to_config[encoder_model_type]
    decoder_type = model_type_to_config[decoder_model_type]

    encoder = encoder_type(**encoder_data)
    decoder = decoder_type(**decoder_data)

    return cls(encoder=encoder, decoder=decoder, **data)
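
The model_type key written by save drives the dispatch above, so a round trip restores the concrete config classes (continuing the file saved earlier):

from pathlib import Path
from directmultistep.model.config import Seq2SeqConfig

cfg = Seq2SeqConfig.load(Path("model_config.yaml"))
print(type(cfg.encoder).__name__, type(cfg.decoder).__name__)
# EncoderAConfig MoEDecoderConfig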

save(path)

Save config to yaml file.

Parameters:

Name  Type  Description                  Default
path  Path  Path to save the config to.  required

Source code in src/directmultistep/model/config.py
def save(self, path: Path) -> None:
    """Save config to yaml file.

    Args:
        path: Path to save the config to.
    """
    config_dict = {
        "encoder": asdict(self.encoder) | {"model_type": self.encoder.__class__.__name__},
        "decoder": asdict(self.decoder) | {"model_type": self.decoder.__class__.__name__},
    }
    with open(path, "w") as f:
        yaml.dump(config_dict, f, sort_keys=False)

directmultistep.model.factory

ModelFactory

Factory class for creating and configuring models.

Source code in src/directmultistep/model/factory.py
class ModelFactory:
    """Factory class for creating and configuring models."""

    def __init__(
        self,
        config: Seq2SeqConfig,
        device: str | None = None,
        compile_model: bool = True,
        allow_mps: bool = False,
    ) -> None:
        """Initializes the ModelFactory.

        Args:
            config: The complete model configuration.
            device: Optional device specification. If None, the best available device is used.
            compile_model: Whether to compile the model using torch.compile.
            allow_mps: Whether to allow MPS device usage.
        """
        self.config = config
        self.device = self.determine_device(device, allow_mps)
        self.compile_model = compile_model

    def check_for_eval_config_updates(self, ec: EvalConfig) -> None:
        if isinstance(self.config.encoder, MoEEncoderConfig):
            if ec.enc_active_experts is None:
                raise ValueError("Encoder active experts must be set in eval config")
            self.config.encoder.top_k = ec.enc_active_experts
        if isinstance(self.config.decoder, MoEDecoderConfig):
            if ec.dec_active_experts is None:
                raise ValueError("Decoder active experts must be set in eval config")
            self.config.decoder.top_k = ec.dec_active_experts

    @staticmethod
    def determine_device(device: str | None = None, allow_mps: bool = False) -> torch_device:
        """Determines the appropriate device for model placement.

        Args:
            device: Optional device specification.
            allow_mps: Whether to allow MPS device usage.

        Returns:
            The determined torch.device.
        """
        if device is None:
            if torch.cuda.is_available():
                device = "cuda"
            elif allow_mps and torch.backends.mps.is_available():
                device = "mps"
            else:
                device = "cpu"
        return torch.device(device)

    @staticmethod
    def _count_parameters(model: nn.Module) -> int:
        """Counts the trainable parameters in a model.

        Args:
            model: The PyTorch model.

        Returns:
            The number of trainable parameters.
        """
        return sum(p.numel() for p in model.parameters() if p.requires_grad)

    def create_model(self) -> Seq2Seq:
        """Creates and configures a Seq2Seq model based on the provided configuration.

        Returns:
            The configured Seq2Seq model.
        """
        # Create encoder based on configuration type
        if not isinstance(self.config.encoder, (EncoderAConfig, MoEEncoderConfig)):
            raise TypeError("Encoder config must be either EncoderAConfig or MoEEncoderConfig")
        if not isinstance(self.config.decoder, (TransformerConfig, MoEDecoderConfig)):
            raise TypeError("Decoder config must be either TransformerConfig or MoEDecoderConfig")

        encoder: Encoder | MoEEncoder
        if isinstance(self.config.encoder, MoEEncoderConfig):
            encoder = MoEEncoder(
                vocab_dim=self.config.encoder.vocab_dim,
                hid_dim=self.config.encoder.hid_dim,
                context_window=self.config.encoder.context_window,
                n_layers=self.config.encoder.n_layers,
                n_heads=self.config.encoder.n_heads,
                ff_mult=self.config.encoder.ff_mult,
                ff_activation=self.config.encoder.ff_activation,
                dropout=self.config.encoder.dropout,
                attn_bias=self.config.encoder.attn_bias,
                initiate_steps=self.config.encoder.initiate_steps,
                include_steps=self.config.encoder.include_steps,
                n_experts=self.config.encoder.n_experts,
                top_k=self.config.encoder.top_k,
                capacity_factor=self.config.encoder.capacity_factor,
            )
        else:
            encoder = Encoder(
                vocab_dim=self.config.encoder.vocab_dim,
                hid_dim=self.config.encoder.hid_dim,
                context_window=self.config.encoder.context_window,
                n_layers=self.config.encoder.n_layers,
                n_heads=self.config.encoder.n_heads,
                ff_mult=self.config.encoder.ff_mult,
                ff_activation=self.config.encoder.ff_activation,
                dropout=self.config.encoder.dropout,
                attn_bias=self.config.encoder.attn_bias,
                initiate_steps=self.config.encoder.initiate_steps,
                include_steps=self.config.encoder.include_steps,
            )

        decoder: Decoder | MoEDecoder
        if isinstance(self.config.decoder, MoEDecoderConfig):
            decoder = MoEDecoder(
                vocab_dim=self.config.decoder.vocab_dim,
                hid_dim=self.config.decoder.hid_dim,
                context_window=self.config.decoder.context_window,
                n_layers=self.config.decoder.n_layers,
                n_heads=self.config.decoder.n_heads,
                dropout=self.config.decoder.dropout,
                attn_bias=self.config.decoder.attn_bias,
                ff_mult=self.config.decoder.ff_mult,
                ff_activation=self.config.decoder.ff_activation,
                n_experts=self.config.decoder.n_experts,
                top_k=self.config.decoder.top_k,
                capacity_factor=self.config.decoder.capacity_factor,
            )
        else:
            decoder = Decoder(
                vocab_dim=self.config.decoder.vocab_dim,
                hid_dim=self.config.decoder.hid_dim,
                context_window=self.config.decoder.context_window,
                n_layers=self.config.decoder.n_layers,
                n_heads=self.config.decoder.n_heads,
                dropout=self.config.decoder.dropout,
                attn_bias=self.config.decoder.attn_bias,
                ff_mult=self.config.decoder.ff_mult,
                ff_activation=self.config.decoder.ff_activation,
            )

        model = Seq2Seq(
            encoder=encoder,
            decoder=decoder,
            src_pad_idx=self.config.encoder.pad_idx,
            trg_pad_idx=self.config.decoder.pad_idx,
        )

        model.to(self.device)

        if self.compile_model:
            model = torch.compile(model)  # type: ignore

        print(f"The model has {self._count_parameters(model):,} trainable parameters")
        return model

    @classmethod
    def from_config_file(
        cls,
        config_path: str | Path,
        device: str | None = None,
        compile_model: bool = True,
    ) -> "ModelFactory":
        """Creates a ModelFactory instance from a configuration file.

        Args:
            config_path: Path to the configuration file.
            device: Optional device specification.
            compile_model: Whether to compile the model.

        Returns:
            The configured ModelFactory instance.
        """
        config = Seq2SeqConfig.load(Path(config_path))
        return cls(config=config, device=device, compile_model=compile_model)

    @classmethod
    def from_preset(cls, preset_name: str, device: str | None = None, compile_model: bool = True) -> "ModelFactory":
        """Loads a preset configuration by name.

        Args:
            preset_name: The name of the preset configuration.
            device: Optional device specification.
            compile_model: Whether to compile the model.

        Returns:
            The configured ModelFactory instance.

        Raises:
            ValueError: If the preset is not found.
        """
        try:
            with resources.path("directmultistep.model.default_configs", f"{preset_name}.yaml") as config_path:
                return cls.from_config_file(config_path, device, compile_model)
        except FileNotFoundError:
            raise ValueError(
                f"Preset '{preset_name}' not found. Available presets: deep_40M, explorer_xl_50M, flash_10M, flash_20M, flex_20M, wide_40M"
            )

    @staticmethod
    def load_checkpoint(model: Seq2Seq, ckpt_path: Path, device: torch.device) -> Seq2Seq:
        ckpt_torch = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(ckpt_torch)
        model.to(device)
        model.eval()
        return model

    @staticmethod
    def load_lightning_checkpoint(model: Seq2Seq, ckpt_path: Path, device: torch.device) -> Seq2Seq:
        ckpt_lightning = torch.load(ckpt_path, map_location=device)
        ckpt_torch = {k.replace("model.", ""): v for k, v in ckpt_lightning["state_dict"].items()}
        model.load_state_dict(ckpt_torch)
        model.to(device)
        model.eval()
        return model
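
load_checkpoint and load_lightning_checkpoint restore trained weights into a freshly created model; the Lightning variant additionally strips the "model." prefix that PyTorch Lightning adds to state-dict keys. A sketch (the checkpoint path is hypothetical):

from pathlib import Path
from directmultistep.model import ModelFactory

factory = ModelFactory.from_preset("flash_10M", compile_model=False)
model = factory.create_model()

ckpt_path = Path("checkpoints/flash_10M.ckpt")  # hypothetical path
model = ModelFactory.load_lightning_checkpoint(model, ckpt_path, factory.device)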

__init__(config, device=None, compile_model=True, allow_mps=False)

Initializes the ModelFactory.

Parameters:

Name           Type           Description                                                                  Default
config         Seq2SeqConfig  The complete model configuration.                                            required
device         str | None     Optional device specification. If None, the best available device is used.  None
compile_model  bool           Whether to compile the model using torch.compile.                            True
allow_mps      bool           Whether to allow MPS device usage.                                           False
Source code in src/directmultistep/model/factory.py
def __init__(
    self,
    config: Seq2SeqConfig,
    device: str | None = None,
    compile_model: bool = True,
    allow_mps: bool = False,
) -> None:
    """Initializes the ModelFactory.

    Args:
        config: The complete model configuration.
        device: Optional device specification. If None, the best available device is used.
        compile_model: Whether to compile the model using torch.compile.
        allow_mps: Whether to allow MPS device usage.
    """
    self.config = config
    self.device = self.determine_device(device, allow_mps)
    self.compile_model = compile_model

create_model()

Creates and configures a Seq2Seq model based on the provided configuration.

Returns:

Type     Description
Seq2Seq  The configured Seq2Seq model.

Source code in src/directmultistep/model/factory.py
def create_model(self) -> Seq2Seq:
    """Creates and configures a Seq2Seq model based on the provided configuration.

    Returns:
        The configured Seq2Seq model.
    """
    # Create encoder based on configuration type
    if not isinstance(self.config.encoder, (EncoderAConfig, MoEEncoderConfig)):
        raise TypeError("Encoder config must be either EncoderAConfig or MoEEncoderConfig")
    if not isinstance(self.config.decoder, (TransformerConfig, MoEDecoderConfig)):
        raise TypeError("Decoder config must be either TransformerConfig or MoEDecoderConfig")

    encoder: Encoder | MoEEncoder
    if isinstance(self.config.encoder, MoEEncoderConfig):
        encoder = MoEEncoder(
            vocab_dim=self.config.encoder.vocab_dim,
            hid_dim=self.config.encoder.hid_dim,
            context_window=self.config.encoder.context_window,
            n_layers=self.config.encoder.n_layers,
            n_heads=self.config.encoder.n_heads,
            ff_mult=self.config.encoder.ff_mult,
            ff_activation=self.config.encoder.ff_activation,
            dropout=self.config.encoder.dropout,
            attn_bias=self.config.encoder.attn_bias,
            initiate_steps=self.config.encoder.initiate_steps,
            include_steps=self.config.encoder.include_steps,
            n_experts=self.config.encoder.n_experts,
            top_k=self.config.encoder.top_k,
            capacity_factor=self.config.encoder.capacity_factor,
        )
    else:
        encoder = Encoder(
            vocab_dim=self.config.encoder.vocab_dim,
            hid_dim=self.config.encoder.hid_dim,
            context_window=self.config.encoder.context_window,
            n_layers=self.config.encoder.n_layers,
            n_heads=self.config.encoder.n_heads,
            ff_mult=self.config.encoder.ff_mult,
            ff_activation=self.config.encoder.ff_activation,
            dropout=self.config.encoder.dropout,
            attn_bias=self.config.encoder.attn_bias,
            initiate_steps=self.config.encoder.initiate_steps,
            include_steps=self.config.encoder.include_steps,
        )

    decoder: Decoder | MoEDecoder
    if isinstance(self.config.decoder, MoEDecoderConfig):
        decoder = MoEDecoder(
            vocab_dim=self.config.decoder.vocab_dim,
            hid_dim=self.config.decoder.hid_dim,
            context_window=self.config.decoder.context_window,
            n_layers=self.config.decoder.n_layers,
            n_heads=self.config.decoder.n_heads,
            dropout=self.config.decoder.dropout,
            attn_bias=self.config.decoder.attn_bias,
            ff_mult=self.config.decoder.ff_mult,
            ff_activation=self.config.decoder.ff_activation,
            n_experts=self.config.decoder.n_experts,
            top_k=self.config.decoder.top_k,
            capacity_factor=self.config.decoder.capacity_factor,
        )
    else:
        decoder = Decoder(
            vocab_dim=self.config.decoder.vocab_dim,
            hid_dim=self.config.decoder.hid_dim,
            context_window=self.config.decoder.context_window,
            n_layers=self.config.decoder.n_layers,
            n_heads=self.config.decoder.n_heads,
            dropout=self.config.decoder.dropout,
            attn_bias=self.config.decoder.attn_bias,
            ff_mult=self.config.decoder.ff_mult,
            ff_activation=self.config.decoder.ff_activation,
        )

    model = Seq2Seq(
        encoder=encoder,
        decoder=decoder,
        src_pad_idx=self.config.encoder.pad_idx,
        trg_pad_idx=self.config.decoder.pad_idx,
    )

    model.to(self.device)

    if self.compile_model:
        model = torch.compile(model)  # type: ignore

    print(f"The model has {self._count_parameters(model):,} trainable parameters")
    return model

determine_device(device=None, allow_mps=False) staticmethod

Determines the appropriate device for model placement.

Parameters:

Name       Type        Description                         Default
device     str | None  Optional device specification.      None
allow_mps  bool        Whether to allow MPS device usage.  False

Returns:

Type          Description
torch.device  The determined torch.device.

Source code in src/directmultistep/model/factory.py
@staticmethod
def determine_device(device: str | None = None, allow_mps: bool = False) -> torch_device:
    """Determines the appropriate device for model placement.

    Args:
        device: Optional device specification.
        allow_mps: Whether to allow MPS device usage.

    Returns:
        The determined torch.device.
    """
    if device is None:
        if torch.cuda.is_available():
            device = "cuda"
        elif allow_mps and torch.backends.mps.is_available():
            device = "mps"
        else:
            device = "cpu"
    return torch.device(device)
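
For example, the fallback order can be observed directly:

from directmultistep.model import ModelFactory

print(ModelFactory.determine_device(None))                  # cuda if available, otherwise cpu
print(ModelFactory.determine_device(None, allow_mps=True))  # may resolve to mps on Apple hardware
print(ModelFactory.determine_device("cpu"))                 # explicit strings pass through unchanged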

from_config_file(config_path, device=None, compile_model=True) classmethod

Creates a ModelFactory instance from a configuration file.

Parameters:

Name           Type        Description                      Default
config_path    str | Path  Path to the configuration file.  required
device         str | None  Optional device specification.   None
compile_model  bool        Whether to compile the model.    True

Returns:

Type          Description
ModelFactory  The configured ModelFactory instance.

Source code in src/directmultistep/model/factory.py
@classmethod
def from_config_file(
    cls,
    config_path: str | Path,
    device: str | None = None,
    compile_model: bool = True,
) -> "ModelFactory":
    """Creates a ModelFactory instance from a configuration file.

    Args:
        config_path: Path to the configuration file.
        device: Optional device specification.
        compile_model: Whether to compile the model.

    Returns:
        The configured ModelFactory instance.
    """
    config = Seq2SeqConfig.load(Path(config_path))
    return cls(config=config, device=device, compile_model=compile_model)

from_preset(preset_name, device=None, compile_model=True) classmethod

Loads a preset configuration by name.

Parameters:

Name           Type        Description                             Default
preset_name    str         The name of the preset configuration.   required
device         str | None  Optional device specification.          None
compile_model  bool        Whether to compile the model.           True

Returns:

Type          Description
ModelFactory  The configured ModelFactory instance.

Raises:

Type        Description
ValueError  If the preset is not found.

Source code in src/directmultistep/model/factory.py
@classmethod
def from_preset(cls, preset_name: str, device: str | None = None, compile_model: bool = True) -> "ModelFactory":
    """Loads a preset configuration by name.

    Args:
        preset_name: The name of the preset configuration.
        device: Optional device specification.
        compile_model: Whether to compile the model.

    Returns:
        The configured ModelFactory instance.

    Raises:
        ValueError: If the preset is not found.
    """
    try:
        with resources.path("directmultistep.model.default_configs", f"{preset_name}.yaml") as config_path:
            return cls.from_config_file(config_path, device, compile_model)
    except FileNotFoundError:
        raise ValueError(
            f"Preset '{preset_name}' not found. Available presets: deep_40M, explorer_xl_50M, flash_10M, flash_20M, flex_20M, wide_40M"
        )
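
Requesting an unknown preset surfaces the list of available options:

from directmultistep.model import ModelFactory

try:
    ModelFactory.from_preset("no_such_preset")
except ValueError as err:
    print(err)  # Preset 'no_such_preset' not found. Available presets: ...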