
trnfft: FFT is a GEMM, and then it isn't

trnfft v0.12–v0.15 shipped three new FFT dispatch paths — DFT-GEMM, Stockham radix-4 with twiddle precomputation, and Stockham radix-8 with a Tensor-engine W₈ kernel — producing 20–37% improvements over the butterfly baseline at medium and large N. The architectural argument running through all three is the same: on Trainium, the bottleneck is not arithmetic but engine utilization and kernel launches. Whether that argument holds at a given N, and at what cost, is where most of the engineering work actually lived.

The problem

cuFFT works by dispatching radix-2 (or mixed-radix) butterfly stages, each one a massively parallel warp-per-butterfly operation on the GPU. A Trainium port runs into two problems immediately.

The first is mostly cosmetic: Trainium has no complex dtype. trnfft's ComplexTensor wraps a split real/imaginary pair and expresses every complex multiply as four real FP32 tensor ops. That's the same four-multiply identity the existing _complex_gemm_kernel already exploits:

C_real = A_real @ B_real - A_imag @ B_imag
C_imag = A_real @ B_imag + A_imag @ B_real

Four real matmuls, which map cleanly to four nisa.nc_matmul accumulations into PSUM.
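The identity above is easy to check off-device. Here is a minimal NumPy sketch, assuming plain float64 arrays rather than trnfft's ComplexTensor; the function name complex_gemm_split is hypothetical and is not the library's kernel.

```python
import numpy as np

def complex_gemm_split(a_re, a_im, b_re, b_im):
    """Complex matmul on split real/imag arrays using four real matmuls."""
    c_re = a_re @ b_re - a_im @ b_im
    c_im = a_re @ b_im + a_im @ b_re
    return c_re, c_im

# Compare against NumPy's native complex matmul on random inputs.
rng = np.random.default_rng(0)
a = rng.standard_normal((4, 8)) + 1j * rng.standard_normal((4, 8))
b = rng.standard_normal((8, 3)) + 1j * rng.standard_normal((8, 3))
c_re, c_im = complex_gemm_split(a.real, a.imag, b.real, b.imag)
reference = a @ b
```

On hardware the subtraction is typically folded into the accumulation by negating one operand, which is what the neg_w8_i tile in the radix-8 kernel later in this post does.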

The second problem is structural and not cosmetic: butterfly stages run on the Vector engine. Each stage is a twiddle multiply (FP32 element-wise) followed by a butterfly add/subtract. Both operations land on the Vector engine. The Tensor engine — the part of the chip with most of the compute throughput — never fires during an FFT.

The question driving v0.12–v0.15: is there a formulation of FFT that routes meaningfully through the Tensor engine?

What the architecture suggests

The Tensor engine's native operation is nc_matmul: a batched systolic-array matrix multiply accumulating into an FP32 PSUM tile. Its partition dimension is fixed at 128; its free dimension goes up to 512. In training workloads, the engine is fed contiguous tiles of large matrices, and the per-launch overhead amortizes over thousands of multiply-accumulate operations.

What the architecture suggests, then, depends on N.

Small N (≤ 256): the entire DFT fits in one systolic-array matmul. The DFT matrix W is N×N. For N=256, (B, 256) @ (256, 256) is a single nc_matmul call. One Tensor-engine launch replaces log₂(N) = 8 Vector-engine butterfly stages. The O(N²) arithmetic is usually cited as a reason not to do this; on Trainium, the one-launch/Tensor-engine advantage outweighs the FLOP count for as long as FP32 PSUM accumulation error stays within tolerance, which was measured to hold through N=256.

Medium N (power-of-4, > 256): one matmul can't cover the whole transform without exceeding the precision budget. Stockham radix-4 decomposes N into log₄(N) stages of 4-point DFTs. The critical observation: W₄ has coefficients {1, −1, i, −i}, meaning the W₄ matvec reduces to adds and sign flips. No multiplications. The Tensor engine is idle in each stage; the twiddle multiply (which does require real FP32 multiplications) dominates the per-stage cost.
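To make the "no multiplications" claim concrete, here is a minimal NumPy sketch of a 4-point forward DFT on split real/imaginary data. Multiplying by −i is a swap of the real and imaginary parts with one sign flip, so the body contains only additions and subtractions. The function name dft4_split and the last-axis layout are assumptions for illustration, not trnfft's code.

```python
import numpy as np

def dft4_split(re, im):
    """4-point forward DFT along the last axis: adds, subtracts and re/im swaps only."""
    t0_re, t0_im = re[..., 0] + re[..., 2], im[..., 0] + im[..., 2]   # x0 + x2
    t1_re, t1_im = re[..., 0] - re[..., 2], im[..., 0] - im[..., 2]   # x0 - x2
    t2_re, t2_im = re[..., 1] + re[..., 3], im[..., 1] + im[..., 3]   # x1 + x3
    t3_re, t3_im = re[..., 1] - re[..., 3], im[..., 1] - im[..., 3]   # x1 - x3
    # X1 = t1 - i*t3 and X3 = t1 + i*t3; -i*(a + ib) = b - ia, so no multiply is needed.
    out_re = np.stack([t0_re + t2_re, t1_re + t3_im, t0_re - t2_re, t1_re - t3_im], axis=-1)
    out_im = np.stack([t0_im + t2_im, t1_im - t3_re, t0_im - t2_im, t1_im + t3_re], axis=-1)
    return out_re, out_im

rng = np.random.default_rng(1)
x = rng.standard_normal((16, 4)) + 1j * rng.standard_normal((16, 4))
y_re, y_im = dft4_split(x.real, x.imag)
```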

Large N (power-of-8, > 256): W₈ has entries exp(−2πi·j·k/8), which include ±√2/2 ± i√2/2. These require actual multiplications. Using the Tensor engine for the W₈ matmul now earns its keep.

The dispatch table that results:

flowchart TD
  A["trnfft.fft(x, N)"] --> B{N ≤ 256?}
  B -->|yes| C["DFT-GEMM\none nc_matmul\nTensor engine"]
  B -->|no| D{N = 8^k?}
  D -->|yes| E["Radix-8 Stockham\nlog₈(N) stages\nVector twiddle + Tensor W₈"]
  D -->|no| F{N = 4^k?}
  F -->|yes| G["Radix-4 Stockham\nlog₄(N) stages\nVector twiddle + Vector W₄"]
  F -->|no| H["Butterfly\nlog₂(N) stages\nVector engine"]
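The same decision tree can be written as a few lines of Python. This is a sketch of the flowchart only: the names choose_path and is_power_of are hypothetical, the threshold value 256 is taken from the text, and the precision check and the non-power-of-2 (Bluestein) case mentioned elsewhere in this post are left out.

```python
DFT_GEMM_THRESHOLD = 256  # value taken from the text; the constant name here is illustrative

def is_power_of(n: int, base: int) -> bool:
    """True if n == base**k for some integer k >= 0."""
    if n < 1:
        return False
    while n % base == 0:
        n //= base
    return n == 1

def choose_path(n: int) -> str:
    """Mirror the flowchart: DFT-GEMM, then radix-8, then radix-4, else butterfly."""
    if n <= DFT_GEMM_THRESHOLD:
        return "dft-gemm"
    if is_power_of(n, 8):
        return "radix-8"
    if is_power_of(n, 4):
        return "radix-4"
    return "butterfly"

paths = {n: choose_path(n) for n in (64, 256, 512, 1024, 2048, 4096)}
```

Note the ordering: 4096 is both 8⁴ and 4⁶, and checking radix-8 first is what sends it down the Tensor-engine path.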

The approach

DFT-GEMM (v0.12)

_fft_via_gemm constructs the N×N DFT matrix W on CPU, transfers it once, and dispatches complex_gemm(x, W) — a single nc_matmul call per transform. Dispatch condition: n <= _DFT_GEMM_THRESHOLD and precision != "kahan". The threshold is precision-bound, not performance-bound: at N=512, FP32 PSUM accumulation exceeds the 1e-3 relative-error tolerance. N=256 stays inside it.
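A minimal NumPy sketch of the DFT-as-matmul idea, assuming split real/imaginary inputs of shape (B, N). The names dft_matrix and fft_via_gemm are illustrative stand-ins and this is not the body of _fft_via_gemm; it runs in float64 on the host, so it says nothing about the FP32 PSUM error threshold described above.

```python
import numpy as np

def dft_matrix(n):
    """Real and imaginary parts of the N x N forward DFT matrix W[j, k] = exp(-2*pi*i*j*k/N)."""
    idx = np.arange(n)
    angle = -2.0 * np.pi * np.outer(idx, idx) / n
    return np.cos(angle), np.sin(angle)

def fft_via_gemm(x_re, x_im):
    """Forward DFT of each row of a (B, N) split-complex batch as one complex GEMM."""
    n = x_re.shape[-1]
    w_re, w_im = dft_matrix(n)
    y_re = x_re @ w_re - x_im @ w_im
    y_im = x_re @ w_im + x_im @ w_re
    return y_re, y_im

rng = np.random.default_rng(2)
x = rng.standard_normal((3, 256)) + 1j * rng.standard_normal((3, 256))
y_re, y_im = fft_via_gemm(x.real, x.imag)
```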

Stockham radix-4 with twiddle precomputation (v0.13)

The initial radix-4 implementation computed twiddle factors per-stage inside the loop — one torch.cos/torch.sin + expand + CPU-to-device transfer per stage. Profiling revealed that twiddle recomputation was the dominant overhead term, not the kernel time or permute cost. The fix: precompute all log₄(N) twiddle tensors on CPU before the loop, transfer them as independent standalone tensors (slicing an XLA device tensor creates DynamicSlice HLO ops that bust the NEFF cache). This alone delivered the 6–9% improvement over butterfly that made radix-4 the default path for power-of-four N > 256.
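The shape of that fix can be sketched in NumPy. The code below is one possible Stockham radix-4 formulation (a decimation-in-frequency variant, using NumPy complex arrays instead of split real/imaginary tensors), not the trnfft driver. What it shares with the description above is the structure: every per-stage twiddle tensor is built as its own standalone array before the loop, and each stage is a W₄ apply, a twiddle multiply, and a reshape/permute/reshape.

```python
import numpy as np

# 4-point DFT matrix: entries are only 1, -1, i, -i.
W4 = np.array([[1, 1, 1, 1],
               [1, -1j, -1, 1j],
               [1, -1, 1, -1],
               [1, 1j, -1, -1j]])

def precompute_twiddles(n_total):
    """One standalone twiddle array per stage, built before the stage loop."""
    if n_total < 4 or n_total & (n_total - 1) or (n_total.bit_length() - 1) % 2:
        raise ValueError("n_total must be a power of 4")
    twiddles = []
    n = n_total
    while n > 1:
        k = np.arange(4).reshape(4, 1, 1)
        p = np.arange(n // 4).reshape(1, n // 4, 1)
        twiddles.append(np.exp(-2j * np.pi * k * p / n))  # shape (4, n/4, 1)
        n //= 4
    return twiddles

def stockham_radix4(x, twiddles):
    """Autosort radix-4 FFT of a 1-D complex array using precomputed twiddles."""
    n_total = x.shape[0]
    n, s = n_total, 1
    for tw in twiddles:
        v = x.reshape(4, n // 4, s)                       # [radix index, p, stride]
        y = np.tensordot(W4, v, axes=(1, 0)) * tw         # 4-point DFT, then twiddle
        x = y.transpose(1, 0, 2).reshape(n_total)         # permute for the next stage
        n //= 4
        s *= 4
    return x

rng = np.random.default_rng(3)
signal = rng.standard_normal(1024) + 1j * rng.standard_normal(1024)
tw_list = precompute_twiddles(1024)
result = stockham_radix4(signal, tw_list)
```

Keeping tw_list as a list of independent arrays, rather than views into one large buffer, is the host-side analogue of the "independent standalone tensors" point above.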

Stockham radix-8 with Tensor-engine W₈ (v0.15)

stockham_radix8_w8_kernel takes pre-twiddled (total_groups, 8) input and applies W₈ via four nisa.nc_matmul calls — the standard complex GEMM decomposition:

@nki.jit
def stockham_radix8_w8_kernel(a_re, a_im, w8_re, w8_im):
    ...
    for m in nl.affine_range(n_partition_tiles):
        ar_t = nl.load_transpose2d(a_re[m_off:m_off + groups_chunk, :8])
        ai_t = nl.load_transpose2d(a_im[m_off:m_off + groups_chunk, :8])
        nisa.nc_matmul(dst=psum_cr, stationary=ar_t, moving=w8_r, accumulate=True)
        nisa.nc_matmul(dst=psum_cr, stationary=ai_t, moving=neg_w8_i, accumulate=True)
        nisa.nc_matmul(dst=psum_ci, stationary=ar_t, moving=w8_i, accumulate=True)
        nisa.nc_matmul(dst=psum_ci, stationary=ai_t, moving=w8_r, accumulate=True)
        ...

W₈ is symmetric (W₈[j,k] = W₈[k,j] since j·k = k·j), so W₈ = W₈ᵀ and the matrix can be passed directly as the moving tile without transposition. The twiddle multiply runs on the XLA device as a PyTorch element-wise op before the kernel call — not in NKI — for a reason documented in the next section.
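Both claims, the symmetry of W₈ and the four-accumulation structure, can be checked on the host. This NumPy sketch assumes pre-twiddled input of shape (total_groups, 8) in split form; apply_w8 is a hypothetical name and stands in for the NKI kernel only in terms of the arithmetic it performs.

```python
import numpy as np

# Build W8[j, k] = exp(-2*pi*i*j*k/8) as split real/imag matrices.
j, k = np.meshgrid(np.arange(8), np.arange(8), indexing="ij")
angle = -2.0 * np.pi * (j * k) / 8
w8_re, w8_im = np.cos(angle), np.sin(angle)

def apply_w8(a_re, a_im):
    """8-point DFT of every row via four real matmuls, two per output plane."""
    c_re = a_re @ w8_re + a_im @ (-w8_im)   # real plane: second term uses the negated imag matrix
    c_im = a_re @ w8_im + a_im @ w8_re      # imag plane
    return c_re, c_im

rng = np.random.default_rng(4)
groups = rng.standard_normal((512, 8)) + 1j * rng.standard_normal((512, 8))
c_re, c_im = apply_w8(groups.real, groups.imag)
```

Because the matrix equals its own transpose, the right-hand operand can be used as-is regardless of which side the hardware expects to be transposed.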

What didn't work

The benchmarking saga. Nine consecutive hardware runs all returned a 17,997-byte JSON file — exactly SSM's 24,000-character StandardOutputContent limit applied to a base64-encoded 18 KB payload. pytest-benchmark stores every raw timing sample; a 5-benchmark JSON was ~500 KB. Nine different approaches to stripping the file before fetching — single-quote Python args, double-quote args, base64-encoded scripts, set -e compound commands — all silently succeeded at the SSM level while producing empty stdout. Root cause was never definitively identified (the strip script's file writes never persisted, for reasons the remote execution environment didn't surface). Fix: a single SSM command reads, strips, and base64-encodes the JSON in memory via Python, writing ~400 chars to stdout instead of 18 KB.
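The in-memory strip step might look roughly like the sketch below. It assumes the usual pytest-benchmark JSON layout (a top-level "benchmarks" list whose entries carry a "stats" dict including the raw "data" samples); the field selection and the function name strip_benchmark_json are illustrative, and the SSM invocation itself is not shown. The sample report is synthetic.

```python
import base64
import json

def strip_benchmark_json(raw_json: str) -> str:
    """Keep only summary stats per benchmark and return compact base64 text."""
    report = json.loads(raw_json)
    summary = [
        {
            "name": bench["name"],
            "mean": bench["stats"]["mean"],
            "min": bench["stats"]["min"],
            "rounds": bench["stats"]["rounds"],
        }
        for bench in report["benchmarks"]
    ]
    compact = json.dumps(summary, separators=(",", ":"))
    return base64.b64encode(compact.encode("utf-8")).decode("ascii")

# Synthetic stand-in for a pytest-benchmark report (placeholder values, not measurements).
raw = json.dumps({
    "benchmarks": [
        {"name": "bench_example", "stats": {"mean": 2.0, "min": 1.0, "rounds": 3,
                                             "data": [1.0, 2.0, 3.0]}}
    ]
})
encoded = strip_benchmark_json(raw)
decoded = json.loads(base64.b64decode(encoded))
```

Printing only the encoded string keeps stdout well under the output limit, and nothing has to be written to disk on the remote host.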

Thread C gather regression. Profiling showed the radix-4 driver's reshape + permute + .contiguous() + reshape chain costs ~97 µs per stage — ~10% of total. The intuition: replace these four PyTorch ops with a single precomputed flat-index gather, reducing XLA graph nodes from 8 to 2 per stage. Hardware result: 11–39% slower across all tested N. Neuron's transpose HLO is a hardware-optimized DMA permute path; GatherOp with non-affine indices is not. The permute stays.
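For reference, the two formulations are numerically identical; only the lowering differs. A NumPy sketch of the equivalence, assuming a (4, n1, s) stage layout chosen for illustration (the actual driver's shapes are not given in this post):

```python
import numpy as np

def permute_stage(x, n1, s):
    """reshape + permute + reshape: the chain that stayed."""
    return x.reshape(4, n1, s).transpose(1, 0, 2).reshape(-1)

def gather_indices(n1, s):
    """Precomputed flat source index for every output position."""
    return np.arange(4 * n1 * s).reshape(4, n1, s).transpose(1, 0, 2).reshape(-1)

n1, s = 4, 4
x = np.arange(4 * n1 * s, dtype=np.float64)
via_permute = permute_stage(x, n1, s)
via_gather = x[gather_indices(n1, s)]   # single indexed gather, same result
```

The gather version has fewer graph nodes, but as the hardware numbers above show, node count was not the quantity that mattered.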

Radix-8 kernel-local scratch buffer. The initial radix-8 kernel used an internal nl.ndarray(buffer=nl.shared_hbm) scratch to bridge two phases: twiddle multiply (Vector engine, writes scratch) and W₈ matmul (Tensor engine, reads scratch via nl.load_transpose2d). This compiled and passed simulator tests, then failed NEFF compilation on hardware with no useful error. The constraint, discovered by inspection: nl.load_transpose2d in NKI 0.3.0 only accepts function-argument HBM tensors as its source — kernel-local allocations are not addressable. Fix: move twiddle multiply to the PyTorch driver as an element-wise XLA op, and make the NKI kernel W₈-only, taking external HBM tensors as input. A concrete upstream ask: nl.load_transpose2d should either accept kernel-local shared_hbm allocations or produce a compile-time error that names the constraint instead of failing silently.

Numbers

Hardware bench: trn1.2xlarge, Neuron SDK 2.29.0, NKI 0.3.0, 2026-04-20.

| N | DFT-GEMM (µs) | Radix-8 (µs) | Radix-4 (µs) | Butterfly (µs) | Auto-dispatch path |
|---|---|---|---|---|---|
| 64 | ~1 883 | 3 402 | 4 254 | 4 767 | DFT-GEMM |
| 256 | ~1 882 | | 5 446 | 6 067 | DFT-GEMM |
| 512 | n/a | 4 483 | | ~6 600 (est.) | Radix-8 |
| 1024 | n/a | | 6 628 | 7 399 | Radix-4 |
| 4096 | n/a | 5 917 | 8 424 | 9 387 | Radix-8 |

DFT-GEMM values at N=64/256 are from v0.12 on SDK 2.24; other values on SDK 2.29. N=512 butterfly is extrapolated from per-stage timing (~739 µs/stage × 9 stages).

Forward error matters as much as wall-clock time. The precision="fast" butterfly accumulates O(u log₂N) FP32 rounding error; precision="kahan" (Dekker 2Prod compensated complex multiply) cuts it by roughly 7–8× with no algorithmic change:

| N | fast rel error | kahan rel error | improvement |
|---|---|---|---|
| 256 | 1.41e-6 | 1.92e-7 | 7.3× |
| 512 | 2.15e-6 | 2.69e-7 | 8.0× |
| 1024 | 2.04e-6 | 3.02e-7 | 6.8× |
| 4096 | 3.60e-6 | 4.55e-7 | 7.9× |

Both paths are below 1e-3. Use set_precision("kahan") when your forward-error budget is tight — iterative solvers, spectral methods, anything that chains multiple FFTs.
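The building block behind a compensated multiply is Dekker's error-free product: it returns the rounded FP32 product together with the rounding error as a second FP32 value. The sketch below shows that real-valued primitive in NumPy float32. How trnfft composes it into a complex twiddle multiply is not described in this post, so treat the function names and the framing as assumptions.

```python
import numpy as np

_SPLITTER = np.float32(4097.0)  # 2**12 + 1, the Veltkamp constant for a 24-bit significand

def split(a):
    """Split float32 values into high and low halves with a == hi + lo."""
    c = _SPLITTER * a
    hi = c - (c - a)
    lo = a - hi
    return hi, lo

def two_prod(a, b):
    """Dekker product: p is the rounded product, err the rounding error of p."""
    p = a * b
    a_hi, a_lo = split(a)
    b_hi, b_lo = split(b)
    err = ((a_hi * b_hi - p) + a_hi * b_lo + a_lo * b_hi) + a_lo * b_lo
    return p, err

rng = np.random.default_rng(5)
a = rng.standard_normal(1000).astype(np.float32)
b = rng.standard_normal(1000).astype(np.float32)
p, err = two_prod(a, b)

# float64 holds the product of two float32 values exactly, so it serves as the reference.
exact = a.astype(np.float64) * b.astype(np.float64)
plain_error = np.abs(p.astype(np.float64) - exact)
compensated_error = np.abs(p.astype(np.float64) + err.astype(np.float64) - exact)
```

Carrying err through the subsequent additions is what recovers the lost bits; the cost is several extra Vector-engine ops per multiply.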

Where Trainium is well-indexed for this: the Tensor engine maps naturally to both the small-N DFT-GEMM case (one large matmul) and the larger-N W₈ case (batched 8×8 matmuls across all groups simultaneously). The partition dimension absorbs all total_groups rows in one pass.

Where it is not well-indexed: power-of-2 N values above 256 that are neither powers of 4 nor powers of 8 fall to the butterfly path, which stays Vector-engine throughout. Non-power-of-2 FFTs use Bluestein's algorithm (three power-of-2 FFTs), which chains errors through the 3-FFT pipeline. precision="double" is the escape hatch for those paths, at the cost of CPU roundtripping since Trainium's PSUM is always FP32.

What's next

  • Mixed-radix Stockham shipped (v0.16). N=1024 ([8,8,4,4], 4 stages, −15% vs radix-4) and N=2048 ([8,8,8,4], 4 stages, first Stockham coverage) are live. _mixed_radix_plan(n) finds the optimal [8^a, 4^b] decomposition for any power-of-2 N.
  • Iterative FFT refinement (research direction). Compute FFT in BF16 using the Tensor engine; use the FP32 PSUM accumulator as a residual buffer; apply one correction step. The result: BF16 throughput, near-FP32 accuracy — enabled by PSUM being a structural FP32 accumulator that no one has exploited for FFT on a production deterministic systolic array. This is the next post.
  • Multi-NeuronCore distribution. Large FFTs (N > 4096) partitioned across NeuronCores, with linear speedup in core count as the target. This has been marked "future" in CLAUDE.md since v0.11.

Issues tracking the above are open on trnsci/trnsci.
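As an illustration of the mixed-radix planning mentioned in the first bullet, here is a minimal sketch that picks radix-8 stages greedily and fills the remainder with radix-4. It assumes "optimal" means the fewest stages; the name mixed_radix_plan is a stand-in and this is not the body of _mixed_radix_plan.

```python
def mixed_radix_plan(n: int) -> list:
    """Factor a power-of-2 n >= 4 into radix-8 and radix-4 stages, preferring radix-8."""
    if n < 4 or n & (n - 1):
        raise ValueError("n must be a power of 2 and at least 4")
    exponent = n.bit_length() - 1            # n == 2**exponent
    # Need 3*eights + 2*fours == exponent; try the largest radix-8 count first.
    for eights in range(exponent // 3, -1, -1):
        remainder = exponent - 3 * eights
        if remainder % 2 == 0:
            return [8] * eights + [4] * (remainder // 2)
    raise ValueError("no radix-8/radix-4 factorization")  # unreachable for exponent >= 2

plans = {n: mixed_radix_plan(n) for n in (512, 1024, 2048, 4096)}
```

For N=1024 and N=2048 this reproduces the stage lists quoted in the bullet.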

Takeaway

FFT on Trainium is a problem about matching transform structure to engine capabilities, not about minimizing FLOP count. The DFT IS a matmul: at small N, skipping the butterfly entirely and using one Tensor-engine call is the right answer. At larger N, where the FP32 accumulation error of a single O(N²) matmul is prohibitive, the question becomes which radix decomposition lets the Tensor engine participate in the per-stage computation. W₄ uses {1, −1, i, −i} and needs no multiplications; W₈ uses irrational entries and does. That distinction, rather than the stage count or the arithmetic complexity, is what determined the dispatch hierarchy. The Tensor engine is idle during radix-4 stages; it is active during radix-8. Hardware validated both claims.
