Transformers and Attention
The Attention Mechanism
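Attention is the core mechanism of the transformer architecture. For each position, the model computes a weighted combination of all input positions, so it can focus on the parts of the input most relevant to that position.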
Why Attention?
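Traditional RNNs process sequences one step at a time, so information from an early position must pass through every intermediate hidden state to reach a later one. This creates a bottleneck and prevents parallel computation across positions. Attention instead connects any two positions directly.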
RNN: Word1 → Word2 → Word3 → ... → WordN (sequential)
Attention: Every word can directly attend to every other word!
Scaled Dot-Product Attention
The core attention formula:
Attention(Q, K, V) = softmax(Q × K^T / √d_k) × V
Where:
- Q (Query): What am I looking for?
- K (Key): What do I contain?
- V (Value): What information do I provide?
- d_k: Dimension of keys (for scaling)
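To see why the formula divides by √d_k, the following pure-Python sketch compares the softmax of some raw dot products with the softmax of the same values after scaling. The three raw scores are made-up numbers chosen for illustration, and `softmax` is a small helper written here rather than a library call. Without scaling, large dot products tend to push softmax toward a near one-hot distribution, which usually leaves very small gradients for the other positions.

```python
import math


def softmax(xs):
    # Subtract the max before exponentiating for numerical stability.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


d_k = 64
# Hypothetical raw dot products of one query with three keys.
raw_scores = [16.0, 8.0, 0.0]

unscaled = softmax(raw_scores)
scaled = softmax([s / math.sqrt(d_k) for s in raw_scores])

print([round(w, 3) for w in unscaled])  # → [1.0, 0.0, 0.0]
print([round(w, 3) for w in scaled])    # → [0.665, 0.245, 0.09]
```

With scaling, the second and third keys still receive meaningful weight instead of being rounded away.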
Intuition
Think of it like a search engine:
- Query: Your search terms
- Keys: Document titles
- Values: Document contents
- Attention weights: Relevance scores
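The search-engine analogy can be sketched as a "soft lookup": rather than returning the single best match, attention blends all values in proportion to their relevance. The 2-dimensional vectors below are hypothetical toy embeddings, and `softmax` and `dot` are helpers defined here for illustration only.

```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


# Hypothetical toy data: each key describes a "document", each value is its content.
keys = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]
query = [2.0, 0.0]  # points in the same direction as the first key

d_k = len(query)
scores = [dot(query, k) / math.sqrt(d_k) for k in keys]
weights = softmax(scores)  # relevance scores, one per key, summing to 1

# Blend the values by relevance instead of picking one winner.
output = [
    sum(w * v[i] for w, v in zip(weights, values))
    for i in range(len(values[0]))
]

print("weights:", [round(w, 3) for w in weights])
print("output: ", [round(o, 3) for o in output])
```

The first key gets the largest weight because it is most similar to the query, so the output leans toward the first value while still mixing in the others.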
PyTorch Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
import math


def scaled_dot_product_attention(query, key, value, mask=None):
    """
    Args:
        query: (batch, seq_len, d_k)
        key: (batch, seq_len, d_k)
        value: (batch, seq_len, d_v)
        mask: optional mask for padding/causal attention
    Returns:
        attention output: (batch, seq_len, d_v)
        attention weights: (batch, seq_len, seq_len)
    """
    d_k = query.size(-1)

    # Compute attention scores
    # (batch, seq_len, d_k) @ (batch, d_k, seq_len) → (batch, seq_len, seq_len)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    # Apply mask (for padding or causal attention)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Softmax to get attention weights
    attention_weights = F.softmax(scores, dim=-1)

    # Weighted sum of values
    output = torch.matmul(attention_weights, value)

    return output, attention_weights


# Example usage
batch_size, seq_len, d_model = 2, 10, 64
Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model)
V = torch.randn(batch_size, seq_len, d_model)

output, weights = scaled_dot_product_attention(Q, K, V)
print(f"Output shape: {output.shape}")  # (2, 10, 64)
print(f"Attention weights shape: {weights.shape}")  # (2, 10, 10)
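The `mask` argument is not exercised in the usage above, so here is a minimal sketch of causal masking under the same convention (positions where `mask == 0` are blocked). The `attention` helper restates the computation of `scaled_dot_product_attention` so the block runs on its own. The lower-triangular mask built with `torch.tril` is one common way to build a causal mask, not the only one, and the tensor sizes are arbitrary.

```python
import math

import torch
import torch.nn.functional as F


def attention(query, key, value, mask=None):
    # Restated from the implementation above so this example is self-contained.
    scores = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1))
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    weights = F.softmax(scores, dim=-1)
    return weights @ value, weights


torch.manual_seed(0)
batch_size, seq_len, d_model = 2, 5, 8
x = torch.randn(batch_size, seq_len, d_model)

# Lower-triangular mask: 1 = may attend, 0 = blocked.
# Shape (seq_len, seq_len) broadcasts over the batch dimension.
causal_mask = torch.tril(torch.ones(seq_len, seq_len))

output, weights = attention(x, x, x, mask=causal_mask)

# Each row i has non-zero weights only for positions 0..i.
print(weights[0])
```

Because masked scores are set to negative infinity before the softmax, the blocked positions receive exactly zero weight and each row still sums to 1.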
Self-Attention
When Q, K, V all come from the same sequence, it's called self-attention. Each position attends to all positions (including itself).
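As a rough illustration, a single-head self-attention layer might look like the sketch below. It assumes that Q, K, and V are produced from the same input `x` by three learned linear projections, which is the usual arrangement in transformers but is not spelled out in this section. The class name `SelfAttention` and the attribute names `w_q`, `w_k`, `w_v` are illustrative choices.

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SelfAttention(nn.Module):
    """Single-head self-attention sketch; names and sizes are illustrative."""

    def __init__(self, d_model):
        super().__init__()
        # Assumption: separate learned projections for queries, keys, and values.
        self.w_q = nn.Linear(d_model, d_model, bias=False)
        self.w_k = nn.Linear(d_model, d_model, bias=False)
        self.w_v = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        # Q, K, and V all come from the same sequence x.
        q, k, v = self.w_q(x), self.w_k(x), self.w_v(x)
        scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
        weights = F.softmax(scores, dim=-1)
        return weights @ v, weights


torch.manual_seed(0)
x = torch.randn(2, 10, 64)  # (batch, seq_len, d_model)
layer = SelfAttention(d_model=64)
output, weights = layer(x)

print(output.shape)   # torch.Size([2, 10, 64])
print(weights.shape)  # torch.Size([2, 10, 10])
```

Row i of `weights` describes how much position i draws from every position in the same sequence, including itself.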