Attention as Bilinear Form

A Physicist's Guide to Transformer Attention using Tensor Calculus


The Core Insight

The attention mechanism in transformers can be understood through the lens of tensor calculus and differential geometry. This perspective connects the score computation to bilinear forms and metric tensors, the softmax to Gibbs distributions from statistical mechanics, and the full update to associative memory retrieval in modern Hopfield networks.

Standard attention formula:

$$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$$

In index notation:

$$O^{ib} = A^{ij} V^{jb}, \quad A^{ij} = \frac{\exp(S^{ij})}{\sum_{j'} \exp(S^{ij'})}, \quad S^{ij} = \frac{1}{\sqrt{d_k}} Q^{ia} K^{ja}$$

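As a minimal sketch (plain jax.numpy rather than the package's own API; the helper name attention_einsum is purely illustrative), each contraction above maps onto an explicit einsum:

import jax.numpy as jnp
from jax.nn import softmax

def attention_einsum(Q, K, V):
    """Scaled dot-product attention as explicit index contractions.

    Shapes: Q (n_q, d_k), K (n_kv, d_k), V (n_kv, d_v).
    """
    d_k = Q.shape[-1]
    S = jnp.einsum("ia,ja->ij", Q, K) / jnp.sqrt(d_k)  # S^{ij} = Q^{ia} K^{ja} / sqrt(d_k)
    A = softmax(S, axis=-1)                            # A^{ij}: softmax over the key index
    return jnp.einsum("ij,jb->ib", A, V)               # O^{ib} = A^{ij} V^{jb}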

Key Perspectives

1. Bilinear Forms and Metric Tensors

The score computation is a bilinear form:

$$S^{ij} = Q^{ia} g_{ab} K^{jb}$$

where $g_{ab} = \frac{1}{\sqrt{d_k}} \delta_{ab}$ is the metric tensor. With this scaled Euclidean metric, the bilinear form reduces exactly to the standard scaled dot-product score; other choices of $g_{ab}$ generalize the score to non-Euclidean geometries on the query/key space.
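A minimal sketch of the same computation with an explicit metric, written in plain jax.numpy (the helper bilinear_scores below is illustrative, not the package's bilinear_form_batch):

import jax.numpy as jnp

def bilinear_scores(Q, K, g):
    # S^{ij} = Q^{ia} g_{ab} K^{jb}
    return jnp.einsum("ia,ab,jb->ij", Q, g, K)

d_k = 4
g = jnp.eye(d_k) / jnp.sqrt(d_k)        # scaled Euclidean metric g_{ab} = delta_{ab} / sqrt(d_k)
Q = jnp.ones((2, d_k))
K = jnp.ones((3, d_k))
S = bilinear_scores(Q, K, g)
assert jnp.allclose(S, Q @ K.T / jnp.sqrt(d_k))   # recovers the standard scaled dot product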

2. Softmax as Gibbs Distribution

The attention weights form a Gibbs distribution from statistical mechanics:

$$A^{ij} = \frac{e^{\beta S^{ij}}}{Z^i}, \quad Z^i = \sum_j e^{\beta S^{ij}}$$

where $\beta = 1$ is the inverse temperature (the $1/\sqrt{d_k}$ factor already absorbed into $S^{ij}$ sets the effective temperature scale). This reveals that $Z^i$ is a per-query partition function and that each row of attention weights is a Boltzmann distribution over keys, so tools from statistical mechanics such as entropy and free energy apply directly.
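A small sketch in plain jax.numpy of the per-query Gibbs weights, partition function, and entropy (the helper gibbs_weights is illustrative, not the package's softmax module):

import jax.numpy as jnp
from jax.scipy.special import logsumexp

def gibbs_weights(S, beta=1.0):
    # A^{ij} = exp(beta * S^{ij}) / Z^i, with log Z^i computed stably via logsumexp
    logZ = logsumexp(beta * S, axis=-1, keepdims=True)
    return jnp.exp(beta * S - logZ), jnp.squeeze(logZ, -1)

S = jnp.array([[1.0, 0.5, -0.2],
               [0.0, 2.0, 1.0]])
A, logZ = gibbs_weights(S)
entropy = -jnp.sum(A * jnp.log(A), axis=-1)   # Shannon entropy of each query's weight distribution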

3. Hopfield Network Interpretation

Modern Hopfield networks show attention is an associative memory:

$$\xi^{\text{new}} = V^T \cdot \text{softmax}(\beta \cdot K \cdot \xi)$$

Each query acts as a probe state $\xi$: it is matched against the patterns stored in $K$, and the associated contents in $V$ are retrieved in a single softmax update step.
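A minimal sketch of one retrieval step in plain jax.numpy (hopfield_update and the choice beta = 8.0 are illustrative, not the package's hopfield module):

import jax.numpy as jnp
from jax.nn import softmax

def hopfield_update(xi, K, V, beta=8.0):
    # xi_new = V^T softmax(beta * K xi); larger beta sharpens retrieval toward the best match
    weights = softmax(beta * (K @ xi))
    return V.T @ weights

K = jnp.array([[1.0, 0.0], [0.0, 1.0]])    # stored key patterns (rows)
V = jnp.array([[2.0, 0.0], [0.0, 3.0]])    # associated contents
xi = jnp.array([0.9, 0.1])                 # noisy probe of the first pattern
retrieved = hopfield_update(xi, K, V)      # close to [2.0, 0.0]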


Gradient Derivations

Using index notation, we derive all gradients explicitly:

Gradient w.r.t. Queries: $$\frac{\partial L}{\partial Q^{kl}} = \frac{1}{\sqrt{d_k}} \frac{\partial L}{\partial S^{kj}} K^{jl}$$

Gradient through Softmax: $$\frac{\partial L}{\partial S^{ij}} = A^{ij} \left( \frac{\partial L}{\partial A^{ij}} - \sum_{j'} A^{ij'} \frac{\partial L}{\partial A^{ij'}} \right)$$

Gradient w.r.t. Values: $$\frac{\partial L}{\partial V^{kl}} = A^{ik} \frac{\partial L}{\partial O^{il}}$$

All gradients are verified against JAX autodiff.
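As a minimal sketch of such a check (plain JAX rather than the repository's test suite), the value gradient above can be compared against jax.grad; with the loss $L = \sum O$, every $\partial L / \partial O^{il}$ equals 1:

import jax
import jax.numpy as jnp
from jax.nn import softmax

def attention(Q, K, V):
    S = Q @ K.T / jnp.sqrt(Q.shape[-1])
    return softmax(S, axis=-1) @ V

Q = jax.random.normal(jax.random.PRNGKey(0), (2, 4))
K = jax.random.normal(jax.random.PRNGKey(1), (3, 4))
V = jax.random.normal(jax.random.PRNGKey(2), (3, 4))

loss = lambda V_: jnp.sum(attention(Q, K, V_))            # so dL/dO^{il} = 1

A = softmax(Q @ K.T / jnp.sqrt(Q.shape[-1]), axis=-1)
manual = jnp.einsum("ik,il->kl", A, jnp.ones((2, 4)))      # dL/dV^{kl} = A^{ik} dL/dO^{il}

assert jnp.allclose(manual, jax.grad(loss)(V), atol=1e-6)  # matches autodiff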


Quick Start

import jax.numpy as jnp
from attn_tensors import scaled_dot_product_attention
from attn_tensors.bilinear import bilinear_form_batch, scaled_euclidean_metric

# Standard attention
Q = jnp.array([[1.0, 0.0], [0.0, 1.0]])
K = jnp.array([[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]])
V = jnp.array([[1.0, 0.0], [0.5, 0.5], [0.0, 1.0]])

output = scaled_dot_product_attention(Q, K, V)

# With explicit metric tensor
g = scaled_euclidean_metric(d=2)
scores = bilinear_form_batch(Q, K, g)

Installation

# Clone the repository
git clone https://github.com/planckeon/attn-as-bilinear-form
cd attn-as-bilinear-form

# Install with uv
uv sync

# Run tests
uv run pytest tests/ -v

Modules

Module       Description
attention    Core attention operations
bilinear     Metric tensors and bilinear forms
gradients    Manual gradient derivations
softmax      Softmax, entropy, Gibbs distribution
multihead    Multi-head attention
masking      Causal and padding masks
hopfield     Hopfield network interpretation

Theory Deep Dives


References

  1. Vaswani et al. (2017). Attention Is All You Need
  2. Ramsauer et al. (2020). Hopfield Networks is All You Need
  3. Amari (1998). Natural Gradient Works Efficiently in Learning