Attention Mechanism

The Attention Formula

The scaled dot-product attention from "Attention Is All You Need":

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

Let's decompose this in index notation. Throughout, we use the Einstein summation convention: an index repeated within a term is summed over.

Step-by-Step Breakdown

Step 1: Score Computation

$$S^{ij} = \frac{1}{\sqrt{d_k}} Q^{ia} K^{ja}$$

Step 2: Softmax Normalization

$$A^{ij} = \frac{\exp(S^{ij})}{\sum_{j'} \exp(S^{ij'})}$$

Step 3: Value Aggregation

$$O^{ib} = A^{ij} V^{jb}$$
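The three steps above map directly onto einsum calls. A minimal sketch in JAX (an illustration of the formulas, not the `attn_tensors` implementation):

```python
import jax
import jax.numpy as jnp

def attention(Q, K, V):
    d_k = Q.shape[-1]
    # Step 1: S^{ij} = Q^{ia} K^{ja} / sqrt(d_k)
    S = jnp.einsum("ia,ja->ij", Q, K) / jnp.sqrt(d_k)
    # Step 2: row-wise softmax over the key index j
    A = jax.nn.softmax(S, axis=-1)
    # Step 3: O^{ib} = A^{ij} V^{jb}
    return jnp.einsum("ij,jb->ib", A, V)
```

The index strings in `jnp.einsum` mirror the tensor equations one-to-one, which makes the shape bookkeeping explicit.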

Tensor Shapes

| Tensor | Shape | Indices |
|---|---|---|
| Queries $Q$ | $(n_q, d_k)$ | $Q^{ia}$ |
| Keys $K$ | $(n_k, d_k)$ | $K^{ja}$ |
| Values $V$ | $(n_k, d_v)$ | $V^{jb}$ |
| Scores $S$ | $(n_q, n_k)$ | $S^{ij}$ |
| Attention $A$ | $(n_q, n_k)$ | $A^{ij}$ |
| Output $O$ | $(n_q, d_v)$ | $O^{ib}$ |

Self-Attention vs Cross-Attention

Self-attention: $Q$, $K$, $V$ all come from the same sequence

Cross-attention: $Q$ from one sequence, $K$, $V$ from another
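Both cases run the same computation; only the sources of the arguments differ. A sketch with a toy scaled dot-product function (the name `sdpa` is illustrative, not part of `attn_tensors`):

```python
import jax
import jax.numpy as jnp

def sdpa(Q, K, V):
    # plain scaled dot-product attention, for illustration
    A = jax.nn.softmax(Q @ K.T / jnp.sqrt(Q.shape[-1]), axis=-1)
    return A @ V

X = jnp.ones((5, 8))   # e.g. a decoder sequence, 5 positions
Y = jnp.ones((7, 8))   # e.g. an encoder sequence, 7 positions

self_out = sdpa(X, X, X)   # self-attention: Q, K, V all from X
cross_out = sdpa(X, Y, Y)  # cross-attention: Q from X; K, V from Y
# both outputs have shape (5, 8): the query sequence sets the output length
```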

Causal (Masked) Attention

For autoregressive models, we mask future positions:

$$S^{ij}_{\text{masked}} = \begin{cases} S^{ij} & \text{if } j \leq i \\ -\infty & \text{if } j > i \end{cases}$$

This ensures each position can attend only to itself and earlier positions.
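In practice the case split is usually implemented by building a lower-triangular boolean mask and substituting $-\infty$ before the softmax. A sketch in JAX (the helper name `causal_mask` is an assumption, not an `attn_tensors` function):

```python
import jax.numpy as jnp

def causal_mask(S):
    n_q, n_k = S.shape
    # mask[i, j] is True where j <= i (allowed), False where j > i (future)
    mask = jnp.tril(jnp.ones((n_q, n_k), dtype=bool))
    # keep allowed scores; send future positions to -inf
    return jnp.where(mask, S, -jnp.inf)
```

After the softmax, the $-\infty$ entries become exactly zero attention weight.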

Multi-Head Attention

Introduce a head index $h$:

$$Q^{hia} = X^{ib} W_Q^{hba}$$

$$K^{hja} = X^{jb} W_K^{hba}$$

$$V^{hjb} = X^{jc} W_V^{hcb}$$

Each head computes attention independently:

$$O^{hib} = A^{hij} V^{hjb}$$

Then concatenate and project:

$$\text{Output}^{ic} = O^{hib} W_O^{hbc}$$
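The head equations above wire together with einsum, carrying $h$ along as a batch axis; summing over both $h$ and $b$ in the final contraction is the concatenate-and-project step in one call. A minimal sketch (the function and weight names are assumptions for illustration):

```python
import jax
import jax.numpy as jnp

def multihead(X, W_Q, W_K, W_V, W_O):
    # Per-head projections: Q^{hia} = X^{ib} W_Q^{hba}, etc.
    Q = jnp.einsum("ib,hba->hia", X, W_Q)
    K = jnp.einsum("jb,hba->hja", X, W_K)
    V = jnp.einsum("jc,hcb->hjb", X, W_V)
    # Each head attends independently; h is a batch axis
    S = jnp.einsum("hia,hja->hij", Q, K) / jnp.sqrt(Q.shape[-1])
    A = jax.nn.softmax(S, axis=-1)
    O = jnp.einsum("hij,hjb->hib", A, V)
    # Concatenate heads and project: Output^{ic} = O^{hib} W_O^{hbc}
    return jnp.einsum("hib,hbc->ic", O, W_O)
```

With $H$ heads, model width $d_{\text{model}}$, and per-head widths $d_k, d_v$, the weight shapes are $W_Q, W_K \in (H, d_{\text{model}}, d_k)$, $W_V \in (H, d_{\text{model}}, d_v)$, $W_O \in (H, d_v, d_{\text{model}})$.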

Code Example

import jax
import jax.numpy as jnp

from attn_tensors import scaled_dot_product_attention
from attn_tensors.multihead import multihead_attention

# Single-head attention (JAX has no jnp.randn; sample via jax.random)
key = jax.random.PRNGKey(0)
kq, kk, kv = jax.random.split(key, 3)
Q = jax.random.normal(kq, (10, 64))  # 10 queries, 64 dims
K = jax.random.normal(kk, (20, 64))  # 20 keys
V = jax.random.normal(kv, (20, 64))  # 20 values

output = scaled_dot_product_attention(Q, K, V)
# output.shape = (10, 64)

# Get attention weights too
output, weights = scaled_dot_product_attention(Q, K, V, return_weights=True)
# weights.shape = (10, 20)

# Multi-head attention
output = multihead_attention(Q, K, V, num_heads=8)