Understanding Attention: From Theory to Implementation

Author

Ravi Kalia

Published

January 31, 2025

Attention

Made with ❤️ and GitHub Copilot

Overview

This article provides a structured mathematical explanation of attention mechanisms in deep learning, focusing on their application in transformer architectures. We’ll explore how sequences are processed through attention layers and understand the mathematical foundations of these powerful neural network components.

1 Key Concepts

Before diving into the mathematics, let’s establish our key concepts:

  • Attention: A mechanism allowing models to focus on relevant parts of input data
  • Self-Attention: A specific form where each element in a sequence attends to all others
  • Multi-Head Attention: Multiple parallel attention mechanisms working together
  • Positional Encoding: Method to incorporate sequential information

2 Data Preprocessing

2.1 From Text to Vectors

The transformation of text into numerical representations involves several steps:

  1. Tokenization: Convert text into token IDs
  2. Embedding: Map tokens to dense vectors
  3. Position Encoding: Add sequential information

Let’s examine each step in detail.

2.2 Tokenization Process

Given an input sentence:

"The cat sat on the mat"

We convert it to token IDs using subword tokenization:

tokens = [101, 1996, 4937, 2938, 2006, 1996, 13523, 102]  # illustrative BERT-style IDs with [CLS]=101 and [SEP]=102; exact values depend on the tokenizer
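
As a concrete sketch, the IDs can be obtained with the Hugging Face transformers library; the checkpoint here is only an assumption for illustration:

# Sketch: subword tokenization with an assumed BERT-style checkpoint
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
encoded = tokenizer("The cat sat on the mat")
print(encoded["input_ids"])                                    # token IDs, including [CLS] and [SEP]
print(tokenizer.convert_ids_to_tokens(encoded["input_ids"]))   # the corresponding subword strings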

2.3 Embedding Layer

The embedding process transforms discrete tokens into continuous vectors:

\[ E \in \mathbb{R}^{V \times d} \]

where:

  • \(V\) = vocabulary size
  • \(d\) = embedding dimension

For each token \(t_i\), we compute:

\[ x_i = E[t_i] \in \mathbb{R}^d \]

Stacking these vectors row-wise yields the input matrix:

\[ X = \begin{bmatrix} x_1 \\ x_2 \\ \vdots \\ x_n \end{bmatrix} \in \mathbb{R}^{n \times d} \]
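
A minimal sketch of this lookup in PyTorch; the vocabulary size and embedding dimension below are illustrative placeholders:

import torch
import torch.nn as nn

V, d = 30522, 512                      # illustrative vocabulary size and embedding dimension
embedding = nn.Embedding(V, d)         # the matrix E in R^{V x d}

token_ids = torch.tensor([[101, 1996, 4937, 2938, 2006, 1996, 13523, 102]])  # shape (1, n)
X = embedding(token_ids)               # shape (1, n, d): one row x_i per token
print(X.shape)                         # torch.Size([1, 8, 512])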

2.4 Positional Encoding

To preserve sequence order, we add positional encodings:

\[ \begin{aligned} P_{i,2j} &= \sin\left(\frac{i}{10000^{2j/d}}\right) \\ P_{i,2j+1} &= \cos\left(\frac{i}{10000^{2j/d}}\right) \end{aligned} \]

Final input representation:

\[ X_{\text{final}} = X + P \]
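
A sketch of the sinusoidal encoding above, assuming an even embedding dimension; sequence length and dimension are placeholders:

import math
import torch

def sinusoidal_positional_encoding(n: int, d: int) -> torch.Tensor:
    """Return P in R^{n x d} with sine on even and cosine on odd dimensions."""
    position = torch.arange(n).unsqueeze(1)                               # (n, 1)
    div_term = torch.exp(torch.arange(0, d, 2) * (-math.log(10000.0) / d))  # 10000^{-2j/d}
    P = torch.zeros(n, d)
    P[:, 0::2] = torch.sin(position * div_term)                           # P_{i,2j}
    P[:, 1::2] = torch.cos(position * div_term)                           # P_{i,2j+1}
    return P

P = sinusoidal_positional_encoding(n=8, d=512)
# X_final = X + P (broadcast over the batch dimension)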

3 Self-Attention Mechanism

3.1 Query, Key, and Value Transformations

Self-attention begins by creating three matrices from the input:

\[ \begin{aligned} Q &= X W_Q \\ K &= X W_K \\ V &= X W_V \end{aligned} \]

where \(W_Q, W_K, W_V \in \mathbb{R}^{d \times d}\) are learnable parameters.

3.2 Attention Computation

  1. Compute Attention Scores:

    \[ S = \frac{Q K^T}{\sqrt{d}} \]

  2. Apply Softmax:

    \[ A = \text{softmax}(S) \]

  3. Compute Weighted Values:

    \[ Z = A V \]
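
The three steps above collapse into a few tensor operations. A minimal single-head sketch, with arbitrary sizes for the sequence length and model dimension:

import math
import torch

n, d = 8, 512                                    # illustrative sequence length and model dimension
X = torch.randn(n, d)
W_Q, W_K, W_V = (torch.randn(d, d) for _ in range(3))

Q, K, V = X @ W_Q, X @ W_K, X @ W_V              # projections
S = Q @ K.T / math.sqrt(d)                       # scaled attention scores, shape (n, n)
A = torch.softmax(S, dim=-1)                     # each row sums to 1
Z = A @ V                                        # weighted values, shape (n, d)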

4 Multi-Head Attention

4.1 Parallel Attention Heads

For \(H\) heads, indexed by \(h = 1, \dots, H\), we project the input into per-head subspaces of dimension \(d_k = d / H\):

\[ \begin{aligned} Q^{(h)} &= X W_Q^{(h)} \\ K^{(h)} &= X W_K^{(h)} \\ V^{(h)} &= X W_V^{(h)} \end{aligned} \]

4.2 Head Outputs

Each head produces its output:

\[ Z^{(h)} = \text{softmax} \left( \frac{Q^{(h)} (K^{(h)})^T}{\sqrt{d_k}} \right) V^{(h)} \]

4.3 Combining Head Outputs

  1. Concatenate:

    \[ H_{\text{concat}} = [Z^{(1)} \| Z^{(2)} \| \cdots \| Z^{(H)}] \]

  2. Project:

    \[ H_{\text{output}} = H_{\text{concat}} W_O \]
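
As a shape check (the full module appears in the practical example below), concatenating the \(H\) heads of width \(d_k = d / H\) restores the model dimension before the output projection; the sizes here are illustrative:

import torch

H, d = 8, 512
d_k = d // H
n = 10                                           # illustrative sequence length

heads = [torch.randn(n, d_k) for _ in range(H)]  # per-head outputs Z^{(h)}
H_concat = torch.cat(heads, dim=-1)              # shape (n, H * d_k) = (n, d)
W_O = torch.randn(d, d)
H_output = H_concat @ W_O                        # final shape (n, d)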

5 Training Process

5.1 Gradient Descent Updates

The attention weights are updated using:

\[ \begin{aligned} W_Q &\leftarrow W_Q - \eta \frac{\partial L}{\partial W_Q} \\ W_K &\leftarrow W_K - \eta \frac{\partial L}{\partial W_K} \\ W_V &\leftarrow W_V - \eta \frac{\partial L}{\partial W_V} \end{aligned} \]

5.2 Optimization Strategy

  • Use Adam optimizer for stable training
  • Apply gradient clipping to prevent exploding gradients
  • Implement learning rate warmup (a combined sketch follows this list)
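
A hedged sketch of these three choices in PyTorch; the learning rate, clipping threshold, and warmup length are placeholder values, and the loss function and targets are assumed to come from the surrounding training loop:

import torch

model = SelfAttention(d_model=512, n_heads=8)             # the module from the practical example below
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Linear warmup over the first 4000 steps (illustrative schedule)
warmup_steps = 4000
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lambda step: min(1.0, (step + 1) / warmup_steps)
)

def training_step(x, target, loss_fn):
    # Hypothetical training step combining Adam, clipping, and warmup
    optimizer.zero_grad()
    loss = loss_fn(model(x), target)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # gradient clipping
    optimizer.step()
    scheduler.step()
    return loss.item()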

6 Implementation Considerations

When implementing attention mechanisms, consider:

  1. Memory efficiency
  2. Numerical stability
  3. Parallelization opportunities
  4. Attention masking for padding (see the sketch after this list)
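
On the last point, a common sketch is to set the scores at padded key positions to a large negative value before the softmax, so those positions receive near-zero weight. The mask convention below is an assumption (True marks real tokens):

import torch

scores = torch.randn(2, 8, 10, 10)                  # (batch, heads, query, key); illustrative shapes
key_padding_mask = torch.ones(2, 10, dtype=torch.bool)
key_padding_mask[:, 7:] = False                     # last three key positions are padding

masked = scores.masked_fill(~key_padding_mask[:, None, None, :], float("-inf"))
attn = torch.softmax(masked, dim=-1)                # padded keys get ~0 attention weight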

7 Practical Example

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, d_model, n_heads):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        batch_size = x.size(0)

        # Linear projections
        q = self.w_q(x).view(batch_size, -1, self.n_heads, self.d_k)
        k = self.w_k(x).view(batch_size, -1, self.n_heads, self.d_k)
        v = self.w_v(x).view(batch_size, -1, self.n_heads, self.d_k)

        # Transpose for attention computation
        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Compute attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        attn = F.softmax(scores, dim=-1)

        # Apply attention to values
        out = torch.matmul(attn, v)

        # Reshape and project
        out = out.transpose(1, 2).contiguous()
        out = out.view(batch_size, -1, self.d_model)
        return self.w_o(out)
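
A quick usage check of the module above, with illustrative batch, sequence, and model sizes:

attention = SelfAttention(d_model=512, n_heads=8)
x = torch.randn(2, 10, 512)          # (batch, sequence length, d_model)
out = attention(x)
print(out.shape)                     # torch.Size([2, 10, 512])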