Mixture of Experts

This document describes the Mixture of Experts (MoE) components used in the DMS model. MoE increases model capacity while keeping per-token computation modest by routing each input to a small subset of specialized sub-networks (experts).

Position-wise Feed-forward Layer

The standard feed-forward network serves as our baseline for comparison with MoE layers. It processes each position in the sequence independently through a simple two-layer network with expansion and projection. This is the traditional architecture used in transformer models.
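
Below is a minimal usage sketch of this baseline layer (the class itself is documented in full further down this page). The tensor sizes and the GELU activation are illustrative choices, not values prescribed by the model.

    import torch
    import torch.nn as nn

    from directmultistep.model.components.moe import PositionwiseFeedforwardLayer

    B, L, D = 2, 16, 256  # illustrative batch size, sequence length, model dimension
    ffn = PositionwiseFeedforwardLayer(
        hid_dim=D,
        ff_mult=4,                # hidden size expands to 4 * D
        ff_activation=nn.GELU(),  # this class takes an activation module directly
        dropout=0.1,
    )

    x_BLD = torch.randn(B, L, D)
    out_BLD = ffn(x_BLD)  # each position is transformed independently
    assert out_BLD.shape == (B, L, D)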

Noisy Top-k Router

The router is the brain of the MoE system - it decides which experts should process each token. Key features:

  • Uses learned routing weights to match tokens with relevant experts
  • Adds input-dependent Gaussian noise with a learned scale to encourage exploration and prevent expert collapse
  • Selects top-k experts per token to enable specialization while maintaining redundancy
  • Produces sparse routing probabilities to enable efficient computation

The noise mechanism is particularly important as it:

  1. Prevents tokens from always taking the same path
  2. Helps balance load across experts
  3. Improves training stability
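
To make the routing behavior concrete, the sketch below runs the NoisyTopkRouter documented later on this page on random inputs. The sizes (8 experts, top-2 routing) are illustrative. The router returns a gating distribution that is zero outside each token's top-k experts, plus the indices of those experts.

    import torch

    from directmultistep.model.components.moe import NoisyTopkRouter

    B, L, D = 2, 16, 256  # illustrative sizes
    n_experts, top_k = 8, 2

    router = NoisyTopkRouter(hid_dim=D, n_experts=n_experts, top_k=top_k)
    x_BLD = torch.randn(B, L, D)

    gates_BLE, indices_BLK = router(x_BLD)
    # gates_BLE: (B, L, E) routing weights; only the top-k entries per token are
    # non-zero, and those entries sum to 1 after the masked softmax.
    # indices_BLK: (B, L, K) ids of the experts chosen for each token.
    assert gates_BLE.shape == (B, L, n_experts)
    assert indices_BLK.shape == (B, L, top_k)
    assert torch.allclose(gates_BLE.sum(dim=-1), torch.ones(B, L))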

Expert Network

Each expert is a specialized feed-forward network that becomes tuned to handle specific types of tokens or patterns. The expert architecture mirrors the standard feed-forward layer, but each expert can learn different specializations. For example:

  • Some experts might focus on syntax
  • Others on specific vocabulary domains
  • Others on particular transformation patterns
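
As a quick sketch, an Expert is constructed with a string key for its activation (resolved through the module's activation_dict) and maps (B, L, D) inputs back to (B, L, D). The "gelu" key and the sizes below are illustrative assumptions; valid keys depend on activation_dict.

    import torch

    from directmultistep.model.components.moe import Expert

    D = 256
    expert = Expert(
        hid_dim=D,
        ff_mult=3,
        ff_activation="gelu",  # assumed key; must exist in the module's activation_dict
        dropout=0.1,
    )

    x_BLD = torch.randn(2, 16, D)
    out_BLD = expert(x_BLD)  # same (B, L, D) shape as the baseline feed-forward layer
    assert out_BLD.shape == x_BLD.shape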

Sparse MoE Layer

This is where everything comes together into an efficient, scalable system:

  1. Token Routing: The router examines each token and decides which experts should process it
  2. Load Balancing:
    • Uses capacity factors to prevent expert overload
    • Ensures even utilization of experts
    • Truncates assignments that exceed an expert's capacity, so no single expert is overloaded when too many tokens select it
  3. Parallel Processing:
    • Tokens are grouped by assigned expert
    • Each expert processes its assigned group
    • Results are combined based on routing weights

Because each token activates only its top-k experts, the sparse computation pattern adds capacity far more cheaply than running several full-size feed-forward layers on every token.
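
A minimal usage sketch follows; the sizes, the capacity factor, and the "gelu" activation key are illustrative assumptions. It also works through the per-expert capacity formula used in SparseMoE.forward below.

    import torch

    from directmultistep.model.components.moe import SparseMoE

    B, L, D = 2, 16, 256  # illustrative sizes
    moe = SparseMoE(
        hid_dim=D,
        n_experts=8,
        top_k=2,
        ff_mult=3,
        ff_activation="gelu",  # assumed key into the module's activation_dict
        dropout=0.1,
        capacity_factor=1.25,
    )

    # Per-expert capacity, as computed in SparseMoE.forward:
    #   capacity = int((B * L * top_k / n_experts) * capacity_factor)
    #            = int((2 * 16 * 2 / 8) * 1.25) = 10
    # so each expert processes at most 10 routed token slots per forward pass;
    # assignments beyond that limit receive no contribution from that expert.
    x_BLD = torch.randn(B, L, D)
    out_BLD = moe(x_BLD)  # (B, L, D): gated, summed outputs of the selected experts
    assert out_BLD.shape == (B, L, D)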

Intuition Behind MoE

Think of MoE like a team of specialists:

  • Instead of every token passing through the same general-purpose network, tokens are routed to the experts best suited to process them
  • Each expert becomes specialized in handling certain types of patterns
  • The router learns to match tokens with the right experts

This specialization allows the model to:

  • Handle a wider range of patterns effectively
  • Scale capacity without scaling computation for every token
  • Develop focused expertise in different aspects of the task

Source Code

directmultistep.model.components.moe

Expert

Bases: Module

A single expert in the MoE layer.

Applies a two-layer feedforward network to the input.

Shape suffixes

B: batch size
L: sequence length
D: model dimension
F: feed-forward subnetwork hidden size

Source code in src/directmultistep/model/components/moe.py
class Expert(nn.Module):
    """A single expert in the MoE layer.

    Applies a two-layer feedforward network to the input.

    Shape suffixes:
        B: batch size
        L: sequence length
        D: model dimension
        F: feed-forward subnetwork hidden size
    """

    def __init__(
        self,
        hid_dim: int,
        ff_mult: int,
        ff_activation: str,
        dropout: float,
    ):
        """Initializes the Expert.

        Args:
            hid_dim: The hidden dimension size (D).
            ff_mult: The feed-forward expansion factor.
            ff_activation: The activation function type.
            dropout: The dropout rate.
        """
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(hid_dim, ff_mult * hid_dim),
            activation_dict[ff_activation],
            nn.Linear(ff_mult * hid_dim, hid_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x_BLD: Tensor) -> Tensor:
        """Forward pass of the Expert.

        Args:
            x_BLD: The input tensor of shape (B, L, D).

        Returns:
            The output tensor of shape (B, L, D).
        """
        return self.net(x_BLD)  # type: ignore

__init__(hid_dim, ff_mult, ff_activation, dropout)

Initializes the Expert.

Parameters:

  Name           Type   Description                         Default
  hid_dim        int    The hidden dimension size (D).      required
  ff_mult        int    The feed-forward expansion factor.  required
  ff_activation  str    The activation function type.       required
  dropout        float  The dropout rate.                   required

Source code in src/directmultistep/model/components/moe.py
def __init__(
    self,
    hid_dim: int,
    ff_mult: int,
    ff_activation: str,
    dropout: float,
):
    """Initializes the Expert.

    Args:
        hid_dim: The hidden dimension size (D).
        ff_mult: The feed-forward expansion factor.
        ff_activation: The activation function type.
        dropout: The dropout rate.
    """
    super().__init__()
    self.net = nn.Sequential(
        nn.Linear(hid_dim, ff_mult * hid_dim),
        activation_dict[ff_activation],
        nn.Linear(ff_mult * hid_dim, hid_dim),
        nn.Dropout(dropout),
    )

forward(x_BLD)

Forward pass of the Expert.

Parameters:

  Name   Type    Description                            Default
  x_BLD  Tensor  The input tensor of shape (B, L, D).   required

Returns:

  Type    Description
  Tensor  The output tensor of shape (B, L, D).

Source code in src/directmultistep/model/components/moe.py
def forward(self, x_BLD: Tensor) -> Tensor:
    """Forward pass of the Expert.

    Args:
        x_BLD: The input tensor of shape (B, L, D).

    Returns:
        The output tensor of shape (B, L, D).
    """
    return self.net(x_BLD)  # type: ignore

NoisyTopkRouter

Bases: Module

Noisy top-k router for MoE.

Routes inputs to the top-k experts based on noisy logits.

Shape suffixes

B: batch size
L: sequence length
D: model dimension
E: number of experts
K: top_k

Source code in src/directmultistep/model/components/moe.py
class NoisyTopkRouter(nn.Module):
    """Noisy top-k router for MoE.

    Routes inputs to the top-k experts based on noisy logits.

    Shape suffixes:
        B: batch size
        L: sequence length
        D: model dimension
        E: number of experts
        K: top_k
    """

    def __init__(self, hid_dim: int, n_experts: int, top_k: int):
        """Initializes the NoisyTopkRouter.

        Args:
            hid_dim: The hidden dimension size (D).
            n_experts: The number of experts (E).
            top_k: The number of top experts to route to (K).
        """
        super().__init__()
        self.top_k = top_k
        self.topkroute_linear = nn.Linear(hid_dim, n_experts)
        self.noise_linear = nn.Linear(hid_dim, n_experts)

    def forward(self, x_BLD: Tensor) -> tuple[Tensor, Tensor]:
        """Forward pass of the NoisyTopkRouter.

        Args:
            x_BLD: The input tensor of shape (B, L, D).

        Returns:
            A tuple containing:
                - The router output tensor of shape (B, L, E).
                - The indices of the top-k experts of shape (B, L, K).
        """
        logits_BLE = self.topkroute_linear(x_BLD)
        noise_logits_BLE = self.noise_linear(x_BLD)
        # Adding scaled unit gaussian noise to the logits
        noise_BLE = torch.randn_like(logits_BLE) * F.softplus(noise_logits_BLE)
        noisy_logits_BLE = logits_BLE + noise_BLE

        top_k_logits_BLE, indices_BLK = noisy_logits_BLE.topk(self.top_k, dim=-1)
        zeros_BLE = torch.full_like(noisy_logits_BLE, float("-inf"))
        # creating a sparse tensor with top-k logits
        sparse_logits_BLE = zeros_BLE.scatter(-1, indices_BLK, top_k_logits_BLE)
        router_output_BLE = F.softmax(sparse_logits_BLE, dim=-1)
        return router_output_BLE, indices_BLK

__init__(hid_dim, n_experts, top_k)

Initializes the NoisyTopkRouter.

Parameters:

  Name       Type  Description                                  Default
  hid_dim    int   The hidden dimension size (D).               required
  n_experts  int   The number of experts (E).                   required
  top_k      int   The number of top experts to route to (K).   required

Source code in src/directmultistep/model/components/moe.py
def __init__(self, hid_dim: int, n_experts: int, top_k: int):
    """Initializes the NoisyTopkRouter.

    Args:
        hid_dim: The hidden dimension size (D).
        n_experts: The number of experts (E).
        top_k: The number of top experts to route to (K).
    """
    super().__init__()
    self.top_k = top_k
    self.topkroute_linear = nn.Linear(hid_dim, n_experts)
    self.noise_linear = nn.Linear(hid_dim, n_experts)

forward(x_BLD)

Forward pass of the NoisyTopkRouter.

Parameters:

  Name   Type    Description                            Default
  x_BLD  Tensor  The input tensor of shape (B, L, D).   required

Returns:

  Type                   Description
  tuple[Tensor, Tensor]  A tuple containing:
                         - The router output tensor of shape (B, L, E).
                         - The indices of the top-k experts of shape (B, L, K).

Source code in src/directmultistep/model/components/moe.py
def forward(self, x_BLD: Tensor) -> tuple[Tensor, Tensor]:
    """Forward pass of the NoisyTopkRouter.

    Args:
        x_BLD: The input tensor of shape (B, L, D).

    Returns:
        A tuple containing:
            - The router output tensor of shape (B, L, E).
            - The indices of the top-k experts of shape (B, L, K).
    """
    logits_BLE = self.topkroute_linear(x_BLD)
    noise_logits_BLE = self.noise_linear(x_BLD)
    # Adding scaled unit gaussian noise to the logits
    noise_BLE = torch.randn_like(logits_BLE) * F.softplus(noise_logits_BLE)
    noisy_logits_BLE = logits_BLE + noise_BLE

    top_k_logits_BLE, indices_BLK = noisy_logits_BLE.topk(self.top_k, dim=-1)
    zeros_BLE = torch.full_like(noisy_logits_BLE, float("-inf"))
    # creating a sparse tensor with top-k logits
    sparse_logits_BLE = zeros_BLE.scatter(-1, indices_BLK, top_k_logits_BLE)
    router_output_BLE = F.softmax(sparse_logits_BLE, dim=-1)
    return router_output_BLE, indices_BLK

PositionwiseFeedforwardLayer

Bases: Module

Positionwise feedforward layer.

Applies a two-layer feedforward network to the input.

Shape suffixes

B: batch size
L: sequence length
D: model dimension
F: feed-forward subnetwork hidden size

Source code in src/directmultistep/model/components/moe.py
class PositionwiseFeedforwardLayer(nn.Module):
    """Positionwise feedforward layer.

    Applies a two-layer feedforward network to the input.

    Shape suffixes:
        B: batch size
        L: sequence length
        D: model dimension
        F: feed-forward subnetwork hidden size
    """

    def __init__(
        self,
        hid_dim: int,
        ff_mult: int,
        ff_activation: nn.Module,
        dropout: float,
    ):
        """Initializes the PositionwiseFeedforwardLayer.

        Args:
            hid_dim: The hidden dimension size (D).
            ff_mult: The feed-forward expansion factor.
            ff_activation: The activation function.
            dropout: The dropout rate.
        """
        super().__init__()

        self.fc_1 = nn.Linear(hid_dim, ff_mult * hid_dim)
        self.activ = ff_activation
        self.fc_2 = nn.Linear(hid_dim * ff_mult, hid_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x_BLD: Tensor) -> Tensor:
        """Forward pass of the PositionwiseFeedforwardLayer.

        Args:
            x_BLD: The input tensor of shape (B, L, D).

        Returns:
            The output tensor of shape (B, L, D).
        """
        x_BLF = self.dropout(self.activ(self.fc_1(x_BLD)))
        x_BLD = self.fc_2(x_BLF)
        return x_BLD

__init__(hid_dim, ff_mult, ff_activation, dropout)

Initializes the PositionwiseFeedforwardLayer.

Parameters:

  Name           Type    Description                         Default
  hid_dim        int     The hidden dimension size (D).      required
  ff_mult        int     The feed-forward expansion factor.  required
  ff_activation  Module  The activation function.            required
  dropout        float   The dropout rate.                   required

Source code in src/directmultistep/model/components/moe.py
def __init__(
    self,
    hid_dim: int,
    ff_mult: int,
    ff_activation: nn.Module,
    dropout: float,
):
    """Initializes the PositionwiseFeedforwardLayer.

    Args:
        hid_dim: The hidden dimension size (D).
        ff_mult: The feed-forward expansion factor.
        ff_activation: The activation function.
        dropout: The dropout rate.
    """
    super().__init__()

    self.fc_1 = nn.Linear(hid_dim, ff_mult * hid_dim)
    self.activ = ff_activation
    self.fc_2 = nn.Linear(hid_dim * ff_mult, hid_dim)

    self.dropout = nn.Dropout(dropout)

forward(x_BLD)

Forward pass of the PositionwiseFeedforwardLayer.

Parameters:

  Name   Type    Description                            Default
  x_BLD  Tensor  The input tensor of shape (B, L, D).   required

Returns:

  Type    Description
  Tensor  The output tensor of shape (B, L, D).

Source code in src/directmultistep/model/components/moe.py
def forward(self, x_BLD: Tensor) -> Tensor:
    """Forward pass of the PositionwiseFeedforwardLayer.

    Args:
        x_BLD: The input tensor of shape (B, L, D).

    Returns:
        The output tensor of shape (B, L, D).
    """
    x_BLF = self.dropout(self.activ(self.fc_1(x_BLD)))
    x_BLD = self.fc_2(x_BLF)
    return x_BLD

SparseMoE

Bases: Module

Sparse Mixture of Experts layer.

Routes inputs to a subset of experts and combines their outputs.

Shape suffixes

B: batch size
L: sequence length
D: model dimension
E: number of experts
K: top_k
S: number of selected tokens for an expert

Source code in src/directmultistep/model/components/moe.py
class SparseMoE(nn.Module):
    """Sparse Mixture of Experts layer.

    Routes inputs to a subset of experts and combines their outputs.

    Shape suffixes:
        B: batch size
        L: sequence length
        D: model dimension
        E: number of experts
        K: top_k
        S: number of selected tokens for an expert
    """

    def __init__(
        self,
        hid_dim: int,
        n_experts: int,
        top_k: int,
        ff_mult: int,
        ff_activation: str,
        dropout: float,
        capacity_factor: float,
    ):
        """Initializes the SparseMoE layer.

        Args:
            hid_dim: The hidden dimension size (D).
            n_experts: The number of experts (E).
            top_k: The number of top experts to route to (K).
            ff_mult: The feed-forward expansion factor.
            ff_activation: The activation function type.
            dropout: The dropout rate.
            capacity_factor: The capacity factor for each expert.
        """
        super(SparseMoE, self).__init__()
        self.router = NoisyTopkRouter(hid_dim, n_experts, top_k)
        self.experts = nn.ModuleList([Expert(hid_dim, ff_mult, ff_activation, dropout) for _ in range(n_experts)])
        self.n_experts = n_experts
        self.top_k = top_k
        self.capacity_factor = capacity_factor

    def forward(self, x_BLD: Tensor) -> Tensor:
        """Forward pass of the SparseMoE layer.

        Args:
            x_BLD: The input tensor of shape (B, L, D).

        Returns:
            The output tensor of shape (B, L, D).
        """
        B, L, _ = x_BLD.shape
        gating_output_BLE, indices_BLK = self.router(x_BLD)
        final_output_BLD = torch.zeros_like(x_BLD)

        flat_x_FD = x_BLD.view(-1, x_BLD.size(-1))  # [B*L, D], define B*L=F
        flat_gating_output_FE = gating_output_BLE.view(-1, gating_output_BLE.size(-1))
        n_tkns = B * L * self.top_k
        capacity = int((n_tkns / self.n_experts) * self.capacity_factor)

        updates_FD = torch.zeros_like(flat_x_FD)
        for i, expert in enumerate(self.experts):
            # Create a mask for the inputs where the current expert is in top-k
            expert_mask_BL = (indices_BLK == i).any(dim=-1)
            flat_mask_F = expert_mask_BL.view(-1)
            selected_idxs_F = torch.nonzero(flat_mask_F).squeeze(-1)

            if selected_idxs_F.numel() > capacity:
                limited_idxs_F = selected_idxs_F[:capacity]
            else:
                limited_idxs_F = selected_idxs_F

            if limited_idxs_F.numel() > 0:
                expert_input_SD = flat_x_FD[limited_idxs_F]  # S = sum(flat_mask_F)
                expert_output_SD = expert(expert_input_SD)

                # Extract and apply gating scores, [S] -> [S, 1]
                gating_scores_S1 = flat_gating_output_FE[limited_idxs_F, i].unsqueeze(1)
                weighted_output_SD = expert_output_SD * gating_scores_S1

                updates_FD.index_add_(0, limited_idxs_F, weighted_output_SD)

        final_output_BLD += updates_FD.view(B, L, -1)

        return final_output_BLD

__init__(hid_dim, n_experts, top_k, ff_mult, ff_activation, dropout, capacity_factor)

Initializes the SparseMoE layer.

Parameters:

  Name             Type   Description                                  Default
  hid_dim          int    The hidden dimension size (D).               required
  n_experts        int    The number of experts (E).                   required
  top_k            int    The number of top experts to route to (K).   required
  ff_mult          int    The feed-forward expansion factor.           required
  ff_activation    str    The activation function type.                required
  dropout          float  The dropout rate.                            required
  capacity_factor  float  The capacity factor for each expert.         required

Source code in src/directmultistep/model/components/moe.py
def __init__(
    self,
    hid_dim: int,
    n_experts: int,
    top_k: int,
    ff_mult: int,
    ff_activation: str,
    dropout: float,
    capacity_factor: float,
):
    """Initializes the SparseMoE layer.

    Args:
        hid_dim: The hidden dimension size (D).
        n_experts: The number of experts (E).
        top_k: The number of top experts to route to (K).
        ff_mult: The feed-forward expansion factor.
        ff_activation: The activation function type.
        dropout: The dropout rate.
        capacity_factor: The capacity factor for each expert.
    """
    super(SparseMoE, self).__init__()
    self.router = NoisyTopkRouter(hid_dim, n_experts, top_k)
    self.experts = nn.ModuleList([Expert(hid_dim, ff_mult, ff_activation, dropout) for _ in range(n_experts)])
    self.n_experts = n_experts
    self.top_k = top_k
    self.capacity_factor = capacity_factor

forward(x_BLD)

Forward pass of the SparseMoE layer.

Parameters:

  Name   Type    Description                            Default
  x_BLD  Tensor  The input tensor of shape (B, L, D).   required

Returns:

  Type    Description
  Tensor  The output tensor of shape (B, L, D).

Source code in src/directmultistep/model/components/moe.py
def forward(self, x_BLD: Tensor) -> Tensor:
    """Forward pass of the SparseMoE layer.

    Args:
        x_BLD: The input tensor of shape (B, L, D).

    Returns:
        The output tensor of shape (B, L, D).
    """
    B, L, _ = x_BLD.shape
    gating_output_BLE, indices_BLK = self.router(x_BLD)
    final_output_BLD = torch.zeros_like(x_BLD)

    flat_x_FD = x_BLD.view(-1, x_BLD.size(-1))  # [B*L, D], define B*L=F
    flat_gating_output_FE = gating_output_BLE.view(-1, gating_output_BLE.size(-1))
    n_tkns = B * L * self.top_k
    capacity = int((n_tkns / self.n_experts) * self.capacity_factor)

    updates_FD = torch.zeros_like(flat_x_FD)
    for i, expert in enumerate(self.experts):
        # Create a mask for the inputs where the current expert is in top-k
        expert_mask_BL = (indices_BLK == i).any(dim=-1)
        flat_mask_F = expert_mask_BL.view(-1)
        selected_idxs_F = torch.nonzero(flat_mask_F).squeeze(-1)

        if selected_idxs_F.numel() > capacity:
            limited_idxs_F = selected_idxs_F[:capacity]
        else:
            limited_idxs_F = selected_idxs_F

        if limited_idxs_F.numel() > 0:
            expert_input_SD = flat_x_FD[limited_idxs_F]  # S = sum(flat_mask_F)
            expert_output_SD = expert(expert_input_SD)

            # Extract and apply gating scores, [S] -> [S, 1]
            gating_scores_S1 = flat_gating_output_FE[limited_idxs_F, i].unsqueeze(1)
            weighted_output_SD = expert_output_SD * gating_scores_S1

            updates_FD.index_add_(0, limited_idxs_F, weighted_output_SD)

    final_output_BLD += updates_FD.view(B, L, -1)

    return final_output_BLD