Motivation

Pairwise attention is powerful, but it restricts interaction structure to second-order (token-pair) forms. Most efficient attention methods try to approximate or factor the $N \times N$ attention matrix. Triple attention takes a different perspective:

Instead of modeling pairwise token interactions, build a structured higher-order memory and let tokens read from it.

This post explains how triple attention works conceptually, and how we implement it in Triton using fused kernels that scale linearly in sequence length.


Why third-order memory?

Standard linear attention builds a pooled memory $S \in \mathbb{R}^{D \times D}$:

$$ S_{ij} = \sum_{n=1}^{N} K_{ni} V_{nj}, $$

and predicts with:

$$ Y_{nj} = \sum_{i=1}^{D} Q_{ni} S_{ij}. $$

Triple attention instead accumulates a third-order state $S \in \mathbb{R}^{D_q \times D_v \times D_q}$:

$$ S_{ijk} = \sum_{n=1}^{N} K^{(1)}_{ni} V_{nj} K^{(2)}_{nk}, $$

with output:

$$ Y_{nj} = \sum_{i=1}^{D_q} \sum_{k=1}^{D_q} Q^{(1)}_{ni} S_{ijk} Q^{(2)}_{nk}. $$

The sequence reduction stays linear in $N$, while representational capacity expands from a $D \times D$ memory to a $D_q \times D_v \times D_q$ one. The tradeoff: compute scales as $O(N D_q^2 D_v)$, so this mechanism suits settings where the head dimensions $D_q, D_v$ are fixed and $N$ is large.
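The contrast is easy to see as a pair of einsums. The following is a minimal PyTorch sketch of the two reductions above, with illustrative shapes and names; it mirrors the equations, not the fused kernels:

import torch

# Minimal einsum sketch of the two reductions above (illustrative shapes,
# not the fused kernels). N tokens, query/key dim Dq, value dim Dv.
N, Dq, Dv = 1024, 32, 32
Q, K = torch.randn(N, Dq), torch.randn(N, Dq)
V = torch.randn(N, Dv)
Q1, Q2, K1, K2 = torch.randn(4, N, Dq).unbind(0)

# Linear attention: pooled second-order memory S in R^{Dq x Dv}.
S2 = torch.einsum('ni,nj->ij', K, V)            # S_ij  = sum_n K_ni V_nj
Y2 = torch.einsum('ni,ij->nj', Q, S2)           # Y_nj  = sum_i Q_ni S_ij

# Triple attention: third-order memory S in R^{Dq x Dv x Dq}.
S3 = torch.einsum('ni,nj,nk->ijk', K1, V, K2)   # S_ijk = sum_n K1_ni V_nj K2_nk
Y3 = torch.einsum('ni,ijk,nk->nj', Q1, S3, Q2)  # Y_nj  = sum_{i,k} Q1_ni S_ijk Q2_nk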


From quadratic attention to structured memory

Standard self-attention forms $A = \text{softmax}(QK^\top)V$, which requires materializing or implicitly computing an $N \times N$ interaction. Triple attention avoids this entirely. Instead of computing token-to-token interactions, we construct a third-order state tensor $S \in \mathbb{R}^{D_q \times D_v \times D_q}$ that aggregates global information from all tokens in a streaming fashion. Each token then reads from this state via two learned query projections. No $N \times N$ matrix is ever formed.


Tensor Shapes

We begin with projected tensors:

  • $Q_1, Q_2, K_1, K_2 \in \mathbb{R}^{B \times H \times N \times D_q}$
  • $V \in \mathbb{R}^{B \times H \times N \times D_v}$

For kernel simplicity, batch and heads are flattened:

Q1, Q2, K1, K2: [BH, N, Dq]
V:              [BH, N, Dv]

We allocate:

  • STATE: [BH, Dq, Dv, Dq] (accumulated in fp32)
  • O: [BH, N, Dv]

The memory cost is independent of sequence length $N$.
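A small sketch of this bookkeeping, assuming illustrative shapes and dtypes in PyTorch (not the repository's exact code):

import torch

# Illustrative bookkeeping sketch (assumed shapes and dtypes, not the repo's code).
B, H, N, Dq, Dv = 2, 8, 4096, 32, 32
q1 = torch.randn(B, H, N, Dq, dtype=torch.bfloat16)          # likewise q2, k1, k2; v uses Dv
q1 = q1.reshape(B * H, N, Dq)                                # [BH, N, Dq]
state = torch.zeros(B * H, Dq, Dv, Dq, dtype=torch.float32)  # STATE, accumulated in fp32
out = torch.empty(B * H, N, Dv, dtype=torch.bfloat16)        # O, written in phase 2
# state costs BH * Dq * Dv * Dq floats: independent of sequence length N.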


Forward pass

The forward pass is split into two fused Triton kernels.

Phase 1: Streaming state construction

Kernel:

triple_fwd_state_kernel

Each instance of the kernel:

  • Tiles indices $(i, j, k)$ of the state tensor.
  • Streams over tokens in chunks (CHUNK_N, e.g., 4096).
  • Accumulates contributions into STATE using atomic adds.

Conceptually:

$$ S_{ijk} = \sum_{n=1}^N K_{1,n,i} \cdot V_{n,j} \cdot K_{2,n,k} $$

Key properties:

  • Complexity is $O(N D_q^2 D_v)$.
  • No token-token matrix is constructed.
  • Streaming over sequence makes it linear in $N$.
  • fp32 accumulation ensures stability.

This phase builds a compact global memory.
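For reference, the semantics of this phase can be written as a chunked einsum in PyTorch. This sketch assumes the flattened [BH, N, D] layout from above and is meant for clarity, not as the fused Triton kernel:

import torch

# Reference semantics for phase 1 (not the fused Triton kernel): stream over
# the sequence in CHUNK_N-sized chunks and accumulate the third-order state in fp32.
def build_state_reference(k1, k2, v, chunk_n=4096):
    # k1, k2: [BH, N, Dq] (fp16/bf16); v: [BH, N, Dv]
    BH, N, Dq = k1.shape
    Dv = v.shape[-1]
    state = torch.zeros(BH, Dq, Dv, Dq, dtype=torch.float32, device=k1.device)
    for start in range(0, N, chunk_n):                  # single pass: linear in N
        sl = slice(start, start + chunk_n)
        # S_ijk += sum_n K1_ni * V_nj * K2_nk for the tokens n in this chunk
        state += torch.einsum('bni,bnj,bnk->bijk',
                              k1[:, sl].float(), v[:, sl].float(), k2[:, sl].float())
    return state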


Phase 2: Output contraction

Kernel:

triple_fwd_out_kernel

This kernel:

  • Tiles over tokens (BLOCK_N_OUT)
  • Contracts $Q_1$ and $Q_2$ with the precomputed STATE
  • Writes output blocks to O

Conceptually:

$$ O_{nj} = \sum_{i,k} Q_{1,n,i} \cdot S_{ijk} \cdot Q_{2,n,k} $$

Interpretation:

  • The state is a global communication hub.
  • Each token decides how to read from it via two learned projections.

This is structurally similar to routing through a learned latent space, but implemented as a multilinear contraction.
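A matching PyTorch reference for this contraction, tiled over tokens (block_n is an illustrative tile size, not the kernel's BLOCK_N_OUT tuning):

import torch

# Reference semantics for phase 2 (not the fused kernel): each token block reads
# from the precomputed state through its two query projections.
def read_state_reference(q1, q2, state, block_n=1024, out_dtype=torch.bfloat16):
    # q1, q2: [BH, N, Dq]; state: [BH, Dq, Dv, Dq] in fp32
    BH, N, _ = q1.shape
    out = torch.empty(BH, N, state.shape[2], dtype=out_dtype, device=q1.device)
    for start in range(0, N, block_n):                  # tile over tokens
        sl = slice(start, start + block_n)
        # O_nj = sum_{i,k} Q1_ni * S_ijk * Q2_nk for the tokens n in this block
        blk = torch.einsum('bni,bijk,bnk->bnj',
                           q1[:, sl].float(), state, q2[:, sl].float())
        out[:, sl] = blk.to(out_dtype)
    return out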


Why this scales

Standard attention scales as:

$$ O(N^2 D) $$

Triple attention scales as:

$$ O(N D_q^2 D_v) $$

If the head dimensions $D_q$ and $D_v$ are held fixed as $N$ grows, this is linear in sequence length.

Memory never grows with $N^2$: the state costs $O(D_q^2 D_v)$ per head, regardless of sequence length.
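As an illustrative data point, take $N = 2^{16}$ and $D_q = D_v = D = 64$. Then

$$ N^2 D \approx 2.7 \times 10^{11}, \qquad N D_q^2 D_v \approx 1.7 \times 10^{10}, $$

and the gap widens linearly as $N$ grows, while the state stays at $D_q \cdot D_v \cdot D_q \approx 2.6 \times 10^{5}$ entries per head.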

This makes it viable for long-sequence workloads, including:

  • PDE surrogate models
  • Large point-cloud processing
  • Long-context sequence modeling

Numerical considerations

Several implementation details are critical:

  • Accumulate STATE in fp32.
  • Use tensor cores (TF32 where available).
  • Store outputs in fp16/bf16.
  • Chunk over sequence dimension to fit memory.
  • Use atomic adds carefully to avoid race conditions.

The mixed-precision strategy preserves stability while keeping memory bandwidth low.
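To make the last two points concrete, here is a toy Triton kernel sketching the accumulation pattern only (chunked streaming, fp32 accumulation from bf16 inputs, atomic adds into a shared buffer). It is not the triple-attention kernel itself and requires a CUDA GPU:

import torch
import triton
import triton.language as tl

# Toy kernel sketching the accumulation pattern only (not the triple-attention
# kernel): each program streams one chunk, reduces it in fp32, and atomically
# adds its partial result into a shared fp32 buffer.
@triton.jit
def chunk_sum_kernel(x_ptr, out_ptr, N, D: tl.constexpr, CHUNK_N: tl.constexpr):
    pid = tl.program_id(0)
    offs_n = pid * CHUNK_N + tl.arange(0, CHUNK_N)
    offs_d = tl.arange(0, D)
    mask = (offs_n[:, None] < N) & (offs_d[None, :] < D)
    x = tl.load(x_ptr + offs_n[:, None] * D + offs_d[None, :],
                mask=mask, other=0.0).to(tl.float32)     # bf16 in, fp32 accumulate
    partial = tl.sum(x, axis=0)                          # [D] partial sum of this chunk
    tl.atomic_add(out_ptr + offs_d, partial)             # race-free combine across programs

x = torch.randn(10_000, 64, device='cuda', dtype=torch.bfloat16)
out = torch.zeros(64, device='cuda', dtype=torch.float32)
grid = (triton.cdiv(x.shape[0], 1024),)
chunk_sum_kernel[grid](x, out, x.shape[0], D=64, CHUNK_N=1024)
# out now approximates x.float().sum(dim=0)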


Backward pass

The backward pass mirrors the forward decomposition.

Implemented as a custom torch.autograd.Function, it performs:

1. Accumulate dSTATE

Compute:

$$ dS_{ijk} = \sum_{n} Q_{1,n,i} \cdot dO_{n,j} \cdot Q_{2,n,k}, $$

accumulated by streaming over tokens in chunks.

2. Gradients for K1, K2, V

Contract dSTATE with remaining factors:

  • dK1
  • dK2
  • dV

Each has its own fused Triton kernel.

3. Gradients for Q1, Q2

Contract saved STATE with dO.

The code includes explicit einsum expressions for gradient verification, making parity testing against a reference implementation straightforward.
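A sketch of what those verification einsums look like, derived directly from the forward definitions above; the function and argument names here are illustrative, not the repository's API:

import torch

# Reference gradients for parity testing, derived from the forward definitions
# O_nj = sum_{i,k} Q1_ni S_ijk Q2_nk and S_ijk = sum_n K1_ni V_nj K2_nk.
def triple_backward_reference(q1, q2, k1, k2, v, state, d_out):
    # q1, q2, k1, k2: [BH, N, Dq]; v, d_out: [BH, N, Dv]; state: [BH, Dq, Dv, Dq]
    d_state = torch.einsum('bni,bnj,bnk->bijk', q1, d_out, q2)   # dS_ijk
    d_q1 = torch.einsum('bnj,bijk,bnk->bni', d_out, state, q2)   # dQ1_ni
    d_q2 = torch.einsum('bni,bijk,bnj->bnk', q1, state, d_out)   # dQ2_nk
    d_k1 = torch.einsum('bnj,bijk,bnk->bni', v, d_state, k2)     # dK1_ni
    d_k2 = torch.einsum('bni,bijk,bnj->bnk', k1, d_state, v)     # dK2_nk
    d_v  = torch.einsum('bni,bijk,bnk->bnj', k1, d_state, k2)    # dV_nj
    return d_state, d_q1, d_q2, d_k1, d_k2, d_v

Comparing these against the kernel outputs at relaxed tolerances (bf16 inputs, fp32 accumulation) is usually enough for a parity check.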


Conceptual perspective

Triple attention reflects a broader idea:

Global communication does not require pairwise token interaction.

Instead of asking “which tokens attend to which?”, we ask:

Can we compress global structure into a structured tensor, and let tokens read from it?

This viewpoint connects to:

  • Low-rank attention
  • Latent routing methods
  • Multilinear tensor contractions
  • Structured operator learning

It also suggests a hypothesis:

If a structured self-attention mechanism captures global communication efficiently, it may extend naturally to causal and autoregressive settings.

Exploring that extension is ongoing work.


Open problems

  • Better approximations to reduce the $D_q^2 D_v$ state pressure
  • Hybrid blocks combining low-rank gather–scatter with triple memory
  • Causal decoding variants for language modeling

Closing thoughts

Triple attention is not just a kernel experiment — it is an exploration of structured global memory.

By fusing state construction and output contraction in Triton, we obtain linear scaling in sequence length, stable mixed-precision execution, and a flexible multilinear attention primitive. This kernel serves as a foundation for further experiments in structured and adaptive attention mechanisms.

The full implementation is available in the FLARE repository alongside the paper (arXiv:2508.12594).


References

  1. Vaswani, A. et al. Attention Is All You Need. NeurIPS (2017). https://arxiv.org/abs/1706.03762
  2. Katharopoulos, A. et al. Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention. ICML (2020). https://arxiv.org/abs/2006.16236
  3. Kozachinskiy, A. et al. Strassen Attention, Split VC Dimension and Compositionality in Transformers. arXiv (2025). https://arxiv.org/abs/2501.19215
  4. Roy, A. et al. Fast and Simplex: 2-Simplicial Attention in Triton. arXiv (2025). https://arxiv.org/abs/2507.02754
  5. Qin, Z. et al. The Devil in Linear Transformer. arXiv (2022). https://arxiv.org/abs/2210.10340