Transformers and Attention

150 min
Deep Learning

The Attention Mechanism

Attention is the core mechanism behind transformers. For each position, it computes a weighted combination of all input positions, with larger weights on the positions most relevant to it.

Why Attention?

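Traditional RNNs process a sequence one step at a time, so information from distant positions must pass through every intermediate hidden state. This creates a bottleneck and prevents parallel processing across positions. Attention instead connects any two positions directly.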

RNN: Word1 → Word2 → Word3 → ... → WordN (sequential)
Attention: Every word can directly attend to every other word!

Scaled Dot-Product Attention

The core attention formula:

Attention(Q, K, V) = softmax(Q × K^T / √d_k) × V

Where:
- Q (Query): What am I looking for?
- K (Key): What do I contain?
- V (Value): What information do I provide?
- d_k: Dimension of the keys (for scaling)
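To see why the scores are divided by √d_k, a rough numerical check can help. The sketch below assumes the query and key components are independent draws with zero mean and unit variance (an illustrative assumption, not something the formula requires). Under that assumption the raw dot product has a standard deviation near √d_k, and dividing by √d_k brings it back to about 1, which keeps the softmax from saturating. The helper name `dot_product_std` is hypothetical.

```python
import math
import random
import statistics

def dot_product_std(d_k, trials=2000, seed=0):
    """Estimate the std of q·k for random q, k with unit-variance components."""
    rng = random.Random(seed)
    dots = []
    for _ in range(trials):
        q = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d_k)]
        dots.append(sum(qi * ki for qi, ki in zip(q, k)))
    return statistics.pstdev(dots)

for d_k in (4, 64, 256):
    raw = dot_product_std(d_k)
    scaled = raw / math.sqrt(d_k)  # std of (q·k) / sqrt(d_k)
    print(f"d_k={d_k:4d}  raw std={raw:6.2f}  scaled std={scaled:4.2f}")
```

The raw standard deviation grows with d_k while the scaled one stays close to 1, so the softmax inputs keep a similar spread regardless of the key dimension.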

Intuition

Think of it like a search engine:
- Query: Your search terms
- Keys: Document titles
- Values: Document contents
- Attention weights: Relevance scores
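The analogy can be made concrete with a tiny soft lookup in plain Python. The vectors below are made up for illustration; a real model learns its queries, keys and values. Unlike a search engine that returns only the top hit, attention returns a blend of all values, weighted by relevance.

```python
import math

def softmax(xs):
    # Subtract the max for numerical stability before exponentiating.
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def attend(query, keys, values):
    """Scaled dot-product attention for a single query over a list of keys."""
    d_k = len(query)
    scores = [sum(q * k for q, k in zip(query, key)) / math.sqrt(d_k)
              for key in keys]
    weights = softmax(scores)
    d_v = len(values[0])
    output = [sum(w * v[j] for w, v in zip(weights, values))
              for j in range(d_v)]
    return output, weights

# Hypothetical toy data: three "documents", each with a key and a value.
keys = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]
query = [4.0, 0.0]  # points in the same direction as the first key

output, weights = attend(query, keys, values)
print("weights:", [round(w, 3) for w in weights])
print("output: ", [round(o, 3) for o in output])
```

The first key matches the query best and so receives the largest weight, but the other two values still contribute to the output in proportion to their weights.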

PyTorch Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def scaled_dot_product_attention(query, key, value, mask=None):
    """
    Args:
        query: (batch, seq_len, d_k)
        key: (batch, seq_len, d_k)
        value: (batch, seq_len, d_v)
        mask: optional mask for padding/causal attention
    Returns:
        attention output: (batch, seq_len, d_v)
        attention weights: (batch, seq_len, seq_len)
    """
    d_k = query.size(-1)

    # Compute attention scores
    # (batch, seq_len, d_k) @ (batch, d_k, seq_len) → (batch, seq_len, seq_len)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

    # Apply mask (for padding or causal attention)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))

    # Softmax to get attention weights
    attention_weights = F.softmax(scores, dim=-1)

    # Weighted sum of values
    output = torch.matmul(attention_weights, value)
    return output, attention_weights

# Example usage
batch_size, seq_len, d_model = 2, 10, 64
Q = torch.randn(batch_size, seq_len, d_model)
K = torch.randn(batch_size, seq_len, d_model)
V = torch.randn(batch_size, seq_len, d_model)

output, weights = scaled_dot_product_attention(Q, K, V)
print(f"Output shape: {output.shape}")  # (2, 10, 64)
print(f"Attention weights shape: {weights.shape}")  # (2, 10, 10)
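The `mask` argument can be exercised with a causal (lower-triangular) mask, in which position i may attend only to positions up to and including i. The sketch below repeats a compact version of the function so that it runs on its own. The mask convention (0 means blocked) follows the `masked_fill(mask == 0, ...)` call in the implementation, and the variable names are illustrative.

```python
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(query, key, value, mask=None):
    # Compact copy of the function from this section.
    d_k = query.size(-1)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, value), weights

torch.manual_seed(0)
batch_size, seq_len, d_model = 2, 5, 8
x = torch.randn(batch_size, seq_len, d_model)

# Lower-triangular matrix of ones: row i has ones in columns 0..i.
# Shape (seq_len, seq_len) broadcasts across the batch dimension.
causal_mask = torch.tril(torch.ones(seq_len, seq_len))

output, weights = scaled_dot_product_attention(x, x, x, mask=causal_mask)

print(weights[0])           # entries above the diagonal are 0
print(weights.sum(dim=-1))  # each row still sums to 1
```

If a row is fully masked, every score in that row is -inf and the softmax yields NaN, so a padding mask generally needs at least one visible key per query row.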

Self-Attention

When Q, K, V all come from the same sequence, it's called self-attention. Each position attends to all positions (including itself).
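In practice Q, K and V are usually separate learned linear projections of the same input. A minimal single-head sketch follows. The class name `SelfAttention` and the projection names are hypothetical, and the sketch omits the multi-head splitting, output projection and dropout that full transformer layers add.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    """Single-head self-attention: Q, K, V are projections of the same x."""

    def __init__(self, d_model):
        super().__init__()
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        # x: (batch, seq_len, d_model)
        q, k, v = self.q_proj(x), self.k_proj(x), self.v_proj(x)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))
        weights = F.softmax(scores, dim=-1)
        return torch.matmul(weights, v), weights

torch.manual_seed(0)
layer = SelfAttention(d_model=64)
x = torch.randn(2, 10, 64)
output, weights = layer(x)
print(output.shape)   # torch.Size([2, 10, 64])
print(weights.shape)  # torch.Size([2, 10, 10])
```

Because all three projections read from the same `x`, each row of `weights` describes how strongly one position of the sequence attends to every position of that same sequence.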