Creating a model instance
There are several ways to create a DMS model instance: from a preset configuration, from a custom configuration, or from a configuration saved to a YAML file.
Using Preset Configurations
The simplest way to create a model is to use one of the preset configurations:
from directmultistep.model import ModelFactory
factory = ModelFactory.from_preset("flash_10M", compile_model=True)
model = factory.create_model()
Available presets include: deep_40M, explorer_xl_50M, flash_10M, flash_20M, flex_20M, and wide_40M.
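According to the API reference, ModelFactory.from_preset raises ValueError when the preset name is not found. The sketch below performs the same check locally and adds a "did you mean" hint for typos. It is a minimal sketch, not part of directmultistep: `AVAILABLE_PRESETS` and `check_preset_name` are hypothetical names, and the tuple is copied from the list above, so it may drift from the library's actual preset set.

```python
import difflib

# Hypothetical copy of the preset names listed in this guide; the library's
# real preset registry may contain more (or different) entries.
AVAILABLE_PRESETS = (
    "deep_40M",
    "explorer_xl_50M",
    "flash_10M",
    "flash_20M",
    "flex_20M",
    "wide_40M",
)


def check_preset_name(name: str) -> str:
    """Return `name` unchanged if it is a known preset, else raise ValueError.

    Mirrors the documented behaviour of ModelFactory.from_preset (ValueError
    for unknown presets) and adds a closest-match suggestion.
    """
    if name in AVAILABLE_PRESETS:
        return name
    hints = difflib.get_close_matches(name, AVAILABLE_PRESETS, n=1)
    suggestion = f" Did you mean {hints[0]!r}?" if hints else ""
    raise ValueError(f"Unknown preset {name!r}.{suggestion}")


print(check_preset_name("flash_10M"))  # flash_10M
```

Calling `check_preset_name("flash10M")` raises a ValueError whose message suggests `'flash_10M'`.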
Custom Configuration
For more control, you can create a custom configuration:
from directmultistep.model import ModelFactory
from directmultistep.model.config import Seq2SeqConfig, EncoderAConfig, MoEDecoderConfig
config = Seq2SeqConfig(
    encoder=EncoderAConfig(
        vocab_dim=53,
        hid_dim=256,
        n_layers=6,
        n_heads=8,
        ff_mult=3,
        ff_activation="gelu",
        dropout=0.1,
        attn_bias=False,
        context_window=280,
        start_idx=0,
        mask_idx=51,
        pad_idx=52,
        initiate_steps=True,
        include_steps=True,
    ),
    decoder=MoEDecoderConfig(
        vocab_dim=53,
        hid_dim=256,
        n_layers=6,
        n_heads=8,
        ff_mult=3,
        ff_activation="gelu",
        dropout=0.1,
        attn_bias=False,
        context_window=1075,
        start_idx=0,
        mask_idx=51,
        pad_idx=52,
        n_experts=3,
        top_k=2,
        capacity_factor=1.0,
    ),
)
# device=None lets the factory choose the best available device
factory = ModelFactory(config, device=None, compile_model=True)
model = factory.create_model()
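Because the encoder and decoder are configured independently, it can be worth sanity-checking a hand-written configuration before building the model. The sketch below uses plain dictionaries holding a subset of the values from the example above; `check_component` is a hypothetical helper, not part of directmultistep. It encodes a few assumptions: special-token indices must lie inside the vocabulary (as any embedding lookup requires) and be distinct, hid_dim must split evenly across n_heads (as in standard multi-head attention), and top_k cannot exceed n_experts.

```python
def check_component(name: str, cfg: dict) -> list:
    """Return a list of problems found in one encoder/decoder config dict."""
    problems = []
    vocab = cfg["vocab_dim"]
    special = {k: cfg[k] for k in ("start_idx", "mask_idx", "pad_idx")}
    for key, idx in special.items():
        # Token ids index an embedding table with vocab_dim rows.
        if not 0 <= idx < vocab:
            problems.append(f"{name}.{key}={idx} is outside [0, {vocab})")
    if len(set(special.values())) != len(special):
        problems.append(f"{name}: start/mask/pad indices are not distinct")
    # Assumes standard multi-head attention: each head gets hid_dim // n_heads.
    if cfg["hid_dim"] % cfg["n_heads"] != 0:
        problems.append(
            f"{name}: hid_dim={cfg['hid_dim']} is not divisible by n_heads={cfg['n_heads']}"
        )
    # Only MoE components carry top_k / n_experts.
    if "top_k" in cfg and not 1 <= cfg["top_k"] <= cfg["n_experts"]:
        problems.append(f"{name}: top_k must be between 1 and n_experts")
    return problems


# Values taken from the custom configuration example.
encoder = dict(vocab_dim=53, hid_dim=256, n_heads=8, start_idx=0, mask_idx=51, pad_idx=52)
decoder = dict(vocab_dim=53, hid_dim=256, n_heads=8, start_idx=0, mask_idx=51, pad_idx=52,
               n_experts=3, top_k=2)

issues = check_component("encoder", encoder) + check_component("decoder", decoder)
print(issues or "no problems found")  # no problems found
```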
Configuration Types
The model supports different types of encoders and decoders:
- Encoders:
  - EncoderAConfig: EncoderA type (the one used in the custom configuration example)
  - MoEEncoderConfig: Mixture of Experts encoder
- Decoders:
  - TransformerConfig: Standard transformer decoder
  - MoEDecoderConfig: Mixture of Experts decoder
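The configuration classes form a small inheritance hierarchy, as the "Bases" entries in the API reference show: EncoderAConfig and MoEDecoderConfig extend TransformerConfig, and MoEEncoderConfig extends EncoderAConfig. The stdlib-only sketch below mirrors the documented fields to make that structure concrete. These are illustrative stand-ins, not the library classes (the real ones may define defaults, validation, or save/load methods); `own_fields` is a hypothetical helper.

```python
from dataclasses import dataclass, fields
from typing import Literal


# Illustrative mirrors of the documented fields; NOT the directmultistep classes.
@dataclass
class TransformerConfig:
    vocab_dim: int
    hid_dim: int
    n_layers: int
    n_heads: int
    ff_mult: int
    ff_activation: Literal["gelu", "relu"]
    dropout: float
    attn_bias: bool
    context_window: int
    start_idx: int
    mask_idx: int
    pad_idx: int


@dataclass
class MoEDecoderConfig(TransformerConfig):
    n_experts: int
    top_k: int
    capacity_factor: float


@dataclass
class EncoderAConfig(TransformerConfig):
    initiate_steps: bool
    include_steps: bool


@dataclass
class MoEEncoderConfig(EncoderAConfig):
    n_experts: int
    top_k: int
    capacity_factor: float


@dataclass
class Seq2SeqConfig:
    # Typed as the base class, so any subclass above is accepted.
    encoder: TransformerConfig
    decoder: TransformerConfig


def own_fields(cls) -> list:
    """Field names a dataclass adds on top of its direct parent."""
    parent = cls.__mro__[1]
    inherited = {f.name for f in fields(parent)} if parent is not object else set()
    return [f.name for f in fields(cls) if f.name not in inherited]


print(own_fields(MoEEncoderConfig))  # ['n_experts', 'top_k', 'capacity_factor']
```

Because every component config is a TransformerConfig, Seq2SeqConfig can pair any encoder type with any decoder type.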
Saving and Loading Configurations
Configurations can be saved to and loaded from YAML files:
from pathlib import Path
# Save configuration (save() is documented as taking a Path)
config.save(Path("model_config.yaml"))
# Load configuration and create model
factory = ModelFactory.from_config_file("model_config.yaml")
model = factory.create_model()
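config.save and ModelFactory.from_config_file take care of serialization for you. The sketch below only illustrates the general idea of round-tripping a nested configuration whose components can be different subclasses. It is a hypothetical design, not the library's file format: it uses stand-in classes, stores a type tag next to each component so the loader knows which class to rebuild, and writes JSON from the standard library instead of YAML so it runs without PyYAML.

```python
import json
import tempfile
from dataclasses import asdict, dataclass
from pathlib import Path


@dataclass
class Component:  # stand-in for TransformerConfig
    hid_dim: int
    n_layers: int


@dataclass
class MoEComponent(Component):  # stand-in for MoEDecoderConfig
    n_experts: int
    top_k: int


# Hypothetical registry mapping a stored type tag back to a class.
REGISTRY = {cls.__name__: cls for cls in (Component, MoEComponent)}


def save(cfg: dict, path: Path) -> None:
    """Write each component's fields plus a "type" tag."""
    data = {part: {"type": type(c).__name__, **asdict(c)} for part, c in cfg.items()}
    path.write_text(json.dumps(data, indent=2))


def load(path: Path) -> dict:
    """Rebuild each component from its tag and fields."""
    data = json.loads(path.read_text())
    out = {}
    for part, values in data.items():
        cls = REGISTRY[values.pop("type")]
        out[part] = cls(**values)
    return out


config = {"encoder": Component(256, 6), "decoder": MoEComponent(256, 6, n_experts=3, top_k=2)}
with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "model_config.json"
    save(config, path)
    restored = load(path)
print(restored == config)  # True
```

Without the type tag, a loader could not tell an MoE decoder from a plain one when both share most of their fields.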
Source Code
directmultistep.model.config
TransformerConfig (dataclass)
Configuration for transformer components.
Attributes:
Name | Type | Description |
---|---|---|
vocab_dim | int | Vocabulary dimension. |
hid_dim | int | Hidden dimension. |
n_layers | int | Number of layers. |
n_heads | int | Number of attention heads. |
ff_mult | int | Feedforward multiplier. |
ff_activation | Literal['gelu', 'relu'] | Feedforward activation function ('gelu' or 'relu'). |
dropout | float | Dropout probability. |
attn_bias | bool | Whether to use attention bias. |
context_window | int | Context window size. |
start_idx | int | Start token index. |
mask_idx | int | Mask token index. |
pad_idx | int | Padding token index. |
Source code in src/directmultistep/model/config.py
load(path) (classmethod)
Load config from yaml file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | Path | Path to load the config from. | required |
Returns:
Type | Description |
---|---|
T | Loaded config. |
Source code in src/directmultistep/model/config.py
save(path)
Save config to yaml file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | Path | Path to save the config to. | required |
Source code in src/directmultistep/model/config.py
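The start_idx, pad_idx and context_window attributes describe how token sequences are framed. As a hypothetical illustration (the helper names, right-padding, and counting the start token toward the window are assumptions, not taken from directmultistep), a tokenized sequence could be prefixed with the start token, padded to the context window, and given a padding mask like this:

```python
def encode_to_window(token_ids: list, *, start_idx: int, pad_idx: int,
                     context_window: int) -> list:
    """Prepend the start token and right-pad to exactly context_window tokens.

    Hypothetical helper; assumes right-padding and that the start token
    counts toward the window length.
    """
    seq = [start_idx, *token_ids]
    if len(seq) > context_window:
        raise ValueError(
            f"sequence of length {len(seq)} exceeds context_window={context_window}"
        )
    return seq + [pad_idx] * (context_window - len(seq))


def padding_mask(seq: list, pad_idx: int) -> list:
    """True where a position holds real content, False where it is padding."""
    return [tok != pad_idx for tok in seq]


window = encode_to_window([5, 17, 9], start_idx=0, pad_idx=52, context_window=8)
print(window)                    # [0, 5, 17, 9, 52, 52, 52, 52]
print(padding_mask(window, 52))  # [True, True, True, True, False, False, False, False]
```

With the encoder settings from the custom configuration example (context_window=280), every encoded input would have length 280.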
MoEDecoderConfig (dataclass)
Bases: TransformerConfig
Configuration for Mixture of Experts decoder components.
Attributes:
Name | Type | Description |
---|---|---|
n_experts | int | Number of experts. |
top_k | int | Number of experts to use in forward pass. |
capacity_factor | float | Capacity factor for experts. |
Source code in src/directmultistep/model/config.py
EncoderAConfig (dataclass)
Bases: TransformerConfig
Configuration for EncoderA components.
Attributes:
Name | Type | Description |
---|---|---|
initiate_steps | bool | Whether to initiate steps. |
include_steps | bool | Whether to include steps. |
Source code in src/directmultistep/model/config.py
MoEEncoderConfig (dataclass)
Bases: EncoderAConfig
Configuration for Mixture of Experts encoder components.
Attributes:
Name | Type | Description |
---|---|---|
n_experts | int | Number of experts. |
top_k | int | Number of experts to use in forward pass. |
capacity_factor | float | Capacity factor for experts. |
Source code in src/directmultistep/model/config.py
Seq2SeqConfig (dataclass)
Complete model configuration.
Attributes:
Name | Type | Description |
---|---|---|
encoder | TransformerConfig | Encoder configuration. |
decoder | TransformerConfig | Decoder configuration. |
Source code in src/directmultistep/model/config.py
load(path) (classmethod)
Load config from yaml file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | Path | Path to load the config from. | required |
Returns:
Type | Description |
---|---|
Seq2SeqConfig | Loaded Seq2SeqConfig. |
Source code in src/directmultistep/model/config.py
save(path)
Save config to yaml file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
path | Path | Path to save the config to. | required |
Source code in src/directmultistep/model/config.py
directmultistep.model.factory
ModelFactory
Factory class for creating and configuring models.
Source code in src/directmultistep/model/factory.py
__init__(config, device=None, compile_model=True, allow_mps=False)
Initializes the ModelFactory.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config | Seq2SeqConfig | The complete model configuration. | required |
device | str \| None | Optional device specification. If None, the best available device is used. | None |
compile_model | bool | Whether to compile the model using torch.compile. | True |
allow_mps | bool | Whether to allow MPS device usage. | False |
Source code in src/directmultistep/model/factory.py
create_model()
Creates and configures a Seq2Seq model based on the provided configuration.
Returns:
Type | Description |
---|---|
Seq2Seq | The configured Seq2Seq model. |
Source code in src/directmultistep/model/factory.py
determine_device(device=None, allow_mps=False) (staticmethod)
Determines the appropriate device for model placement.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
device | str \| None | Optional device specification. | None |
allow_mps | bool | Whether to allow MPS device usage. | False |
Returns:
Type | Description |
---|---|
torch.device | The determined torch.device. |
Source code in src/directmultistep/model/factory.py
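The exact selection order used by determine_device is not spelled out here. The sketch below shows one plausible policy consistent with the documented parameters: an explicit device always wins, otherwise CUDA, then MPS only when allow_mps is set, then CPU. The priority order and the name `pick_device` are assumptions. To stay runnable without PyTorch, availability is passed in as flags; with PyTorch installed you would pass `torch.cuda.is_available()` and `torch.backends.mps.is_available()`, and wrap the result in `torch.device(...)`.

```python
from typing import Optional


def pick_device(
    device: Optional[str] = None,
    allow_mps: bool = False,
    *,
    cuda_available: bool = False,
    mps_available: bool = False,
) -> str:
    """Hypothetical order: explicit device, then CUDA, then MPS (if allowed), then CPU."""
    if device is not None:
        return device  # an explicit request always wins
    if cuda_available:
        return "cuda"
    if allow_mps and mps_available:
        return "mps"  # MPS is opt-in, matching allow_mps=False by default
    return "cpu"


print(pick_device(mps_available=True))                  # cpu
print(pick_device(allow_mps=True, mps_available=True))  # mps
```

Keeping MPS opt-in means a default call on Apple hardware falls back to CPU unless the caller explicitly allows MPS.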
from_config_file(config_path, device=None, compile_model=True) (classmethod)
Creates a ModelFactory instance from a configuration file.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
config_path | str \| Path | Path to the configuration file. | required |
device | str \| None | Optional device specification. | None |
compile_model | bool | Whether to compile the model. | True |
Returns:
Type | Description |
---|---|
ModelFactory | The configured ModelFactory instance. |
Source code in src/directmultistep/model/factory.py
from_preset(preset_name, device=None, compile_model=True) (classmethod)
Loads a preset configuration by name.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
preset_name | str | The name of the preset configuration. | required |
device | str \| None | Optional device specification. | None |
compile_model | bool | Whether to compile the model. | True |
Returns:
Type | Description |
---|---|
ModelFactory | The configured ModelFactory instance. |
Raises:
Type | Description |
---|---|
ValueError | If the preset is not found. |