
Attention

This document describes the attention mechanisms used in the DMS model.

Summary

Attention lets a model focus selectively on the most relevant information while processing a sequence. To capture context, each position must consider its relationship with every other position. Attention computes a similarity score between each query position and every key position, essentially asking "how relevant is this key to my current query?" In the scaled dot-product form used below, the score is the dot product of the query and key vectors divided by √d_k, where d_k is the per-head dimension. A softmax turns each query's raw scores into attention weights that sum to 1, i.e. a probability distribution over the keys. The output for each query is the sum of the value vectors weighted by these attention weights, so the model can combine information from several positions with varying degrees of influence.
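The computation described above can be sketched in NumPy for a single head and a single sequence. This is an illustrative sketch, not code from the DMS repository: the function names are hypothetical, and the scores are divided by √d, which is the default scaling PyTorch's scaled_dot_product_attention applies.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the row maximum before exponentiating for numerical stability.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention(q_LD: np.ndarray, k_MD: np.ndarray, v_MD: np.ndarray):
    """Single-head scaled dot-product attention for one sequence (hypothetical helper)."""
    d = q_LD.shape[-1]
    # Similarity of every query with every key: shape (L, M).
    scores_LM = q_LD @ k_MD.T / np.sqrt(d)
    # Each row becomes a probability distribution over the M keys.
    weights_LM = softmax(scores_LM, axis=-1)
    # Each output row is a weighted average of the value rows: shape (L, D).
    return weights_LM @ v_MD, weights_LM


rng = np.random.default_rng(0)
q = rng.standard_normal((3, 4))  # L=3 queries, D=4
k = rng.standard_normal((5, 4))  # M=5 keys
v = rng.standard_normal((5, 4))  # one value vector per key
out, w = attention(q, k, v)
```

Each row of w is non-negative and sums to 1, and out has one row per query.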

Flash Attention

Flash Attention reorganizes the attention computation to make the most of fast on-chip SRAM while minimizing reads and writes to the GPU's slower high-bandwidth memory (HBM). Rather than computing and storing the full L × M attention matrix at once, it splits the computation into blocks small enough to fit in SRAM, computes partial attention scores for each block, and aggregates them incrementally. The key insight is to maintain running softmax statistics (the row-wise maximum and the normalizing sum) across blocks and to rescale earlier partial results as each new block is processed. The result is mathematically equivalent to standard attention (not an approximation), yet the full attention matrix is never materialized, so memory use grows linearly rather than quadratically with sequence length. The trade-off is some extra computation (rescaling partial results, and recomputing attention scores in the backward pass instead of storing them) in exchange for far less memory traffic, which pays off on modern GPUs where memory bandwidth, rather than arithmetic throughput, often limits performance. PyTorch's scaled_dot_product_attention, which the layer below calls, can dispatch to a FlashAttention kernel when the hardware, dtype, and arguments allow it.
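The rolling-statistics idea can be illustrated with a minimal NumPy sketch, assuming a single head and no masking. It walks over keys and values in blocks, keeps a running row maximum and softmax denominator, and rescales the partial output whenever the maximum grows. It models only the arithmetic of FlashAttention (the online softmax); it does not manage SRAM, fuse kernels, or implement a backward pass, and tiled_attention and naive_attention are hypothetical names.

```python
import numpy as np


def tiled_attention(q_LD, k_MD, v_MD, block_size=2):
    """Attention over key/value blocks with an online softmax (illustrative sketch)."""
    L, d = q_LD.shape
    scale = 1.0 / np.sqrt(d)
    row_max = np.full((L, 1), -np.inf)      # running max of scores per query
    row_sum = np.zeros((L, 1))              # running softmax denominator
    acc = np.zeros((L, v_MD.shape[1]))      # running unnormalized output

    for start in range(0, k_MD.shape[0], block_size):
        k_blk = k_MD[start:start + block_size]
        v_blk = v_MD[start:start + block_size]
        s = (q_LD @ k_blk.T) * scale        # only an (L, block) tile of scores
        new_max = np.maximum(row_max, s.max(axis=1, keepdims=True))
        correction = np.exp(row_max - new_max)  # rescale earlier partial results
        p = np.exp(s - new_max)
        row_sum = row_sum * correction + p.sum(axis=1, keepdims=True)
        acc = acc * correction + p @ v_blk
        row_max = new_max

    return acc / row_sum


def naive_attention(q_LD, k_MD, v_MD):
    """Reference: materializes the full (L, M) score matrix."""
    s = q_LD @ k_MD.T / np.sqrt(q_LD.shape[1])
    w = np.exp(s - s.max(axis=1, keepdims=True))
    return (w / w.sum(axis=1, keepdims=True)) @ v_MD


rng = np.random.default_rng(0)
q = rng.standard_normal((4, 8))
k = rng.standard_normal((7, 8))
v = rng.standard_normal((7, 8))
out_tiled = tiled_attention(q, k, v, block_size=3)  # last block is partial
out_naive = naive_attention(q, k, v)
```

On the first block the running maximum is -inf, so the correction factor is exp(-inf) = 0 and the empty accumulators are simply replaced. The two results agree up to floating-point rounding for any block size.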

Shape Convention

The shape suffixes follow a consistent convention:

  • B: Batch size
  • L: Target sequence length
  • M: Memory/source sequence length
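  • D: Model hidden dimension (hid_dim); in per-head tensors such as Q_BHLD, the trailing D instead denotes the per-head size hid_dim // n_heads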
  • H: Number of attention heads
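The suffixes make the head split in forward() easy to follow. Below is a NumPy sketch of the same reshaping, assuming D is divisible by H: transpose and reshape stand in for PyTorch's permute(0, 2, 1, 3) and .contiguous().view(), and the lowercase d suffix for the per-head width is a hypothetical addition to the convention.

```python
import numpy as np

B, L, D, H = 2, 5, 8, 4
head_dim = D // H  # per-head width; D must be divisible by H

x_BLD = np.arange(B * L * D, dtype=np.float32).reshape(B, L, D)

# Split heads: (B, L, D) -> (B, L, H, head_dim) -> (B, H, L, head_dim)
x_BHLd = x_BLD.reshape(B, L, H, head_dim).transpose(0, 2, 1, 3)

# Merge heads: (B, H, L, head_dim) -> (B, L, H, head_dim) -> (B, L, D)
merged_BLD = x_BHLd.transpose(0, 2, 1, 3).reshape(B, L, D)
```

Head h owns the contiguous feature slice [h * head_dim, (h + 1) * head_dim) of D, and merging restores the original tensor exactly.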

Source Code

directmultistep.model.components.attention

MultiHeadAttentionLayer

Bases: Module

Multi-head attention layer.

This layer applies multi-head attention to the input tensors.

Shape suffixes convention

  • B: batch size
  • L: sequence length for decoder
  • M: memory length (length of sequence being attended to)
  • D: model dimension (sometimes called d_model or embedding_dim)
  • H: number of attention heads in a layer

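Parameters:

  • hid_dim (int, required): The hidden dimension size D. Must be divisible by n_heads; each head has size hid_dim // n_heads.
  • n_heads (int, required): The number of attention heads.
  • dropout (float, required): The dropout probability applied to the attention weights during training. It is passed to scaled_dot_product_attention as dropout_p and replaced by 0.0 in eval mode.
  • attn_bias (bool, required): Whether the query, key, and value projections use a bias term. The output projection always has one.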
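As a worked example of how these parameters determine the layer's size, the sketch below counts trainable parameters. It assumes the layer contains exactly the four nn.Linear(hid_dim, hid_dim) modules created in __init__ (dropout has no weights); mha_param_count is a hypothetical helper.

```python
def mha_param_count(hid_dim: int, n_heads: int, attn_bias: bool) -> int:
    """Trainable parameters of a layer built like MultiHeadAttentionLayer."""
    if hid_dim % n_heads != 0:
        raise ValueError("hid_dim must be divisible by n_heads")
    weight = hid_dim * hid_dim  # one (hid_dim x hid_dim) weight matrix
    # query, key, value: bias only when attn_bias is True
    qkv = 3 * (weight + (hid_dim if attn_bias else 0))
    # output projection: nn.Linear default, always has a bias
    projection = weight + hid_dim
    return qkv + projection


small = mha_param_count(256, 8, attn_bias=False)
with_bias = mha_param_count(256, 8, attn_bias=True)
```

For hid_dim = 256 this gives 262,400 parameters without attention bias and 263,168 with it. n_heads changes only how hid_dim is partitioned across heads, not the parameter count.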
Source code in src/directmultistep/model/components/attention.py
from typing import cast

from torch import Tensor, nn


class MultiHeadAttentionLayer(nn.Module):
    """
    Multi-head attention layer.

    This layer applies multi-head attention to the input tensors.

    Shape suffixes convention:
        B: batch size
        L: sequence length for decoder
        M: memory length (length of sequence being attended to)
        D: model dimension (sometimes called d_model or embedding_dim)
        H: number of attention heads in a layer

    Args:
        hid_dim: The hidden dimension size. Must be divisible by n_heads.
        n_heads: The number of attention heads.
        dropout: Dropout probability for the attention weights (training only).
        attn_bias: Whether to use bias in the query, key, and value projections.
    """

    def __init__(
        self,
        hid_dim: int,
        n_heads: int,
        dropout: float,
        attn_bias: bool,
        # device: torch.device,
    ):
        super().__init__()

        self.hid_dim = hid_dim
        self.n_heads = n_heads
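        if hid_dim % n_heads != 0:
            raise ValueError(f"hid_dim ({hid_dim}) must be divisible by n_heads ({n_heads})")
        self.head_dim = hid_dim // n_heads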

        self.query = nn.Linear(hid_dim, hid_dim, bias=attn_bias)
        self.key = nn.Linear(hid_dim, hid_dim, bias=attn_bias)
        self.value = nn.Linear(hid_dim, hid_dim, bias=attn_bias)

        self.projection = nn.Linear(hid_dim, hid_dim)

        self.dropout = nn.Dropout(dropout)
        # self.scale = torch.sqrt(torch.FloatTensor([self.head_dim])).to(device)

    def forward(
        self,
        query_BLD: Tensor,
        key_BMD: Tensor,
        value_BMD: Tensor,
        mask_B11M: Tensor | None = None,
    ) -> Tensor:
        """
        Forward pass of the multi-head attention layer.

        Shape suffixes convention:
            B: batch size
            L: sequence length for decoder
            M: memory length (length of sequence being attended to)
            D: model dimension (sometimes called d_model or embedding_dim)
            H: number of attention heads in a layer

        Args:
            query_BLD: The query tensor of shape (B, L, D).
            key_BMD: The key tensor of shape (B, M, D).
            value_BMD: The value tensor of shape (B, M, D).
            mask_B11M: The attention mask of shape (B, 1, 1, M), broadcast over
                heads and query positions. If None, causal masking is applied.

        Returns:
            The output tensor of shape (B, L, D).
        """
        B, L, _ = query_BLD.shape
        Q_BLD = self.query(query_BLD)
        K_BMD = self.key(key_BMD)
        V_BMD = self.value(value_BMD)
        # Reshape into multiple heads
        Q_BHLD = Q_BLD.view(B, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        K_BHMD = K_BMD.view(B, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)
        V_BHMD = V_BMD.view(B, -1, self.n_heads, self.head_dim).permute(0, 2, 1, 3)

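        if mask_B11M is not None:
            # Expand mask for all heads and query positions; no causal masking
            mask_BHLM = mask_B11M.expand(B, self.n_heads, L, -1)
            is_causal = False
        else:
            # No mask given: fall back to causal (autoregressive) masking
            mask_BHLM = None
            is_causal = True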

        attn_output_BHLD = nn.functional.scaled_dot_product_attention(
            query=Q_BHLD,
            key=K_BHMD,
            value=V_BHMD,
            attn_mask=mask_BHLM,
            dropout_p=self.dropout.p if self.training else 0.0,
            is_causal=is_causal,
            # scale=self.scale.item(),
        )
        # Merge heads: (B, H, L, head_dim) -> (B, L, H, head_dim) -> (B, L, D)
        attn_output_BLD = attn_output_BHLD.permute(0, 2, 1, 3).contiguous().view(B, L, self.hid_dim)
        output_BLD = cast(Tensor, self.projection(attn_output_BLD))
        return output_BLD
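
To make the two masking branches explicit, here is a hypothetical NumPy mirror of forward() for inference (no dropout). It assumes a boolean mask and stores each linear layer as a (W, b) pair applied like nn.Linear; it is an illustrative sketch, not a drop-in replacement for the PyTorch layer. With no mask it builds the lower-triangular (L, M) mask that scaled_dot_product_attention uses for is_causal=True.

```python
import numpy as np


def mha_forward(query_BLD, key_BMD, value_BMD, weights, n_heads, mask_B11M=None):
    """Hypothetical NumPy mirror of MultiHeadAttentionLayer.forward (no dropout).

    `weights` maps "query", "key", "value" and "projection" to (W, b) pairs,
    applied as x @ W.T + b like nn.Linear. mask_B11M is assumed to be boolean.
    """
    B, L, D = query_BLD.shape
    head_dim = D // n_heads

    def linear(x, name):
        W, b = weights[name]
        return x @ W.T + b

    def split_heads(x_BTD):  # (B, T, D) -> (B, H, T, head_dim)
        return x_BTD.reshape(B, -1, n_heads, head_dim).transpose(0, 2, 1, 3)

    Q = split_heads(linear(query_BLD, "query"))
    K = split_heads(linear(key_BMD, "key"))
    V = split_heads(linear(value_BMD, "value"))
    M = K.shape[2]

    scores_BHLM = Q @ K.transpose(0, 1, 3, 2) / np.sqrt(head_dim)
    if mask_B11M is not None:
        # Same role as mask_B11M.expand(B, H, L, -1): True = may attend.
        keep = np.broadcast_to(mask_B11M, (B, n_heads, L, M))
    else:
        # is_causal=True: query i may attend to keys 0..i (lower triangle).
        keep = np.tri(L, M, dtype=bool)
    # Disallowed positions get -inf, i.e. zero weight after the softmax.
    # (A row with no allowed key would produce NaN here.)
    scores_BHLM = np.where(keep, scores_BHLM, -np.inf)

    scores_BHLM = scores_BHLM - scores_BHLM.max(axis=-1, keepdims=True)
    attn_BHLM = np.exp(scores_BHLM)
    attn_BHLM = attn_BHLM / attn_BHLM.sum(axis=-1, keepdims=True)

    # Merge heads back to (B, L, D), then apply the output projection.
    out_BLD = (attn_BHLM @ V).transpose(0, 2, 1, 3).reshape(B, L, D)
    return linear(out_BLD, "projection")


rng = np.random.default_rng(0)
D, H = 8, 2
# Random stand-ins for trained parameters (hypothetical values).
weights = {
    name: (rng.standard_normal((D, D)) / np.sqrt(D), 0.1 * rng.standard_normal(D))
    for name in ("query", "key", "value", "projection")
}
x_BLD = rng.standard_normal((1, 4, D))
y_BLD = mha_forward(x_BLD, x_BLD, x_BLD, weights, n_heads=H)  # no mask -> causal
```

Because the layer sets is_causal only when no mask is given, supplying any mask turns causal masking off; the two are mutually exclusive in this layer.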

forward(query_BLD, key_BMD, value_BMD, mask_B11M=None)

Forward pass of the multi-head attention layer.


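Parameters:

  • query_BLD (Tensor, required): The query tensor of shape (B, L, D).
  • key_BMD (Tensor, required): The key tensor of shape (B, M, D).
  • value_BMD (Tensor, required): The value tensor of shape (B, M, D).
  • mask_B11M (Tensor | None, default None): The attention mask of shape (B, 1, 1, M). It is expanded to (B, H, L, M) and passed to scaled_dot_product_attention as attn_mask; for a boolean mask, True marks key positions that may be attended to. If None, a causal mask is applied instead (is_causal=True).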

Returns:

  • Tensor: The output tensor of shape (B, L, D).
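A common way to build a mask of shape (B, 1, 1, M) is from the true lengths of padded memory sequences. The sketch below assumes boolean masks, for which scaled_dot_product_attention treats True as "may attend"; whether the DMS model constructs its masks this way is not shown in this section, and padding_mask_B11M is a hypothetical helper. np.broadcast_to plays the role of Tensor.expand.

```python
import numpy as np


def padding_mask_B11M(lengths, M):
    """Boolean key-padding mask of shape (B, 1, 1, M); True marks real tokens."""
    lengths = np.asarray(lengths)
    keep_BM = np.arange(M)[None, :] < lengths[:, None]
    return keep_BM[:, None, None, :]


mask = padding_mask_B11M([3, 5], M=5)  # first sequence has 2 padded positions
B, H, L = 2, 4, 6
# The same expansion forward() performs with mask_B11M.expand(B, H, L, -1).
mask_BHLM = np.broadcast_to(mask, (B, H, L, 5))
```

Every head and every query position sees the same key-padding pattern, which is why a single (B, 1, 1, M) mask is enough.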
