
trnfft

trnfft: what trn1 and trn2 tell us about the Ozaki frontier

The v0.18 and v0.19 posts claimed hardware precision of O(sqrt(N)·u_bf16²) ≈ 1.6e-5 and O(sqrt(N)·u_bf16⁴) ≈ 2e-9 for the two Ozaki modes. The trn1 hardware measurement says those numbers are wrong: both modes deliver ~1.7e-3, equivalent to single-pass BF16. trn2 was then put through the same characterization, and the result is identical: both generations land at single-pass BF16 accuracy. The conclusion is still not "Ozaki is a dead end", but the generational-gap theory needs revision.
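To see what the earlier posts' error model predicts, here is a minimal NumPy simulation; it is not trnfft code. Assumptions: BF16 rounding is emulated by bit masking (`to_bf16`, `split2` and `ozaki_1level` are hypothetical helpers), a float32 matmul stands in for a Tensor Engine matmul with FP32 accumulation, and the one-level mode is taken to be the standard hi/lo split with three matmuls, which the excerpt does not spell out. On a CPU the split beats single-pass BF16 by a wide margin. That prediction is exactly what the trn1 and trn2 measurements failed to reproduce, so the sketch says nothing about why the hardware behaves differently.

```python
import numpy as np


def to_bf16(x):
    """Round to the nearest BF16 value (ties to even), returned as float32.

    Simulation helper: NumPy has no native bfloat16. Finite inputs only.
    """
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    bits = bits + (np.uint32(0x7FFF) + ((bits >> np.uint32(16)) & np.uint32(1)))
    return (bits & np.uint32(0xFFFF0000)).view(np.float32)


def split2(a):
    """One-level split: a ~= hi + lo, both BF16-representable."""
    hi = to_bf16(a)
    lo = to_bf16(a - hi)  # a - hi is exact in float32
    return hi, lo


def ozaki_1level(a, b):
    a_hi, a_lo = split2(a)
    b_hi, b_lo = split2(b)
    # Three BF16 x BF16 matmuls summed in float32; the a_lo @ b_lo term is dropped
    return a_hi @ b_hi + a_hi @ b_lo + a_lo @ b_hi


rng = np.random.default_rng(0)
n = 64
a = rng.standard_normal((n, n)).astype(np.float32)
b = rng.standard_normal((n, n)).astype(np.float32)
ref = a.astype(np.float64) @ b.astype(np.float64)  # float64 reference


def rel_err(c):
    return float(np.linalg.norm(c - ref) / np.linalg.norm(ref))


err_single = rel_err(to_bf16(a) @ to_bf16(b))  # single-pass BF16
err_ozaki = rel_err(ozaki_1level(a, b))        # simulated one-level Ozaki
print(f"single-pass BF16: {err_single:.1e}   one-level Ozaki: {err_ozaki:.1e}")
```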

trnfft: the residual must stay FP32

v0.19 ships precision="ozaki_hq": six BF16 matmuls that together reach O(sqrt(N)·u_bf16⁴) ≈ 2e-9 relative error, near-FP64 accuracy on the Tensor Engine. The implementation is a 40-line extension of the v0.18 Ozaki scheme, and the whole algorithm turns on one non-obvious constraint. Get it wrong and a 2-level design silently delivers 1-level accuracy, with no error raised.
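The excerpt does not say which residual it means, so the sketch below assumes the natural reading: the first residual a - a0 must stay in FP32 before it is split again. It is a NumPy simulation with hypothetical names (`to_bf16`, `split3`, `ozaki_hq_sim`), three BF16 pieces per operand, and an assumed choice of the six cross products (piece indices summing to at most 2). Rounding the residual to BF16 lets the second piece absorb it exactly, so the third piece comes out identically zero and the six matmuls quietly fall back to one-level accuracy. Partial products are summed in float64 to isolate splitting error from accumulator rounding; PSUM is not modeled.

```python
import numpy as np


def to_bf16(x):
    """Round to the nearest BF16 value (ties to even), returned as float32.

    Simulation helper: NumPy has no native bfloat16. Finite inputs only.
    """
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    bits = bits + (np.uint32(0x7FFF) + ((bits >> np.uint32(16)) & np.uint32(1)))
    return (bits & np.uint32(0xFFFF0000)).view(np.float32)


def split3(a, residual_fp32=True):
    """Hypothetical two-level split: a ~= a0 + a1 + a2, each BF16-representable."""
    a0 = to_bf16(a)
    r1 = a - a0                # exact in float32
    if not residual_fp32:
        r1 = to_bf16(r1)       # the mistake: residual stored as BF16
    a1 = to_bf16(r1)
    a2 = to_bf16(r1 - a1)      # all zeros if r1 was already BF16
    return a0, a1, a2


def ozaki_hq_sim(a, b, residual_fp32=True):
    a0, a1, a2 = split3(a, residual_fp32)
    b0, b1, b2 = split3(b, residual_fp32)
    # Assumed term set: the six cross products whose piece indices sum to <= 2
    terms = [(a0, b0), (a0, b1), (a1, b0), (a0, b2), (a1, b1), (a2, b0)]
    # float64 accumulation isolates splitting error; this does not model PSUM
    return sum(x.astype(np.float64) @ y.astype(np.float64) for x, y in terms)


rng = np.random.default_rng(0)
a = rng.standard_normal((64, 64)).astype(np.float32)
b = rng.standard_normal((64, 64)).astype(np.float32)
ref = a.astype(np.float64) @ b.astype(np.float64)


def rel_err(c):
    return float(np.linalg.norm(c - ref) / np.linalg.norm(ref))


err_ok = rel_err(ozaki_hq_sim(a, b, residual_fp32=True))
err_bad = rel_err(ozaki_hq_sim(a, b, residual_fp32=False))
print(f"FP32 residual: {err_ok:.1e}   BF16 residual: {err_bad:.1e}")
```

Nothing in the buggy path raises or warns; the only symptom is the error figure, which is why the failure is silent.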

trnfft: the missing dtype and the 80× cliff

The first working version of trnfft's NKI butterfly kernel passed every correctness test. It was also 80× slower than the PyTorch fallback for batched STFT, a regression so large that the benchmark itself was assumed to be broken. It wasn't. The kernel was calling NKI once per batch row in a Python loop, paying the full XLA graph-compilation overhead for every row.

That discovery, and the fix, is what Phase 1 is mostly about.
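The shape of the fix can be shown without NKI. In this sketch a hypothetical `Dispatcher` only counts launches, a complex DFT-as-GEMM stands in for the butterfly kernel, and each row plays the role of one STFT frame; none of this is trnfft's API, and the per-launch compilation cost the post describes is not modeled. What matters is the launch count: the loop issues one launch per row, the batched call issues one in total, and both produce the same numbers.

```python
import numpy as np


class Dispatcher:
    """Hypothetical stand-in for a kernel launch: it only counts calls."""

    def __init__(self):
        self.launches = 0

    def __call__(self, kernel, *args):
        self.launches += 1
        return kernel(*args)


def dft_kernel(x, w):
    """One DFT-as-GEMM over every row of x (shape: rows x n)."""
    return x @ w.T


def per_row(launch, x, w):
    # The slow shape: a Python loop issues one launch per batch row
    return np.concatenate([launch(dft_kernel, x[i:i + 1], w) for i in range(x.shape[0])])


def batched(launch, x, w):
    # The fix: hand the whole batch to a single launch
    return launch(dft_kernel, x, w)


n, rows = 64, 32
k = np.arange(n)
w = np.exp(-2j * np.pi * np.outer(k, k) / n)  # DFT matrix, W[k, j] = exp(-2*pi*i*j*k/n)
x = np.random.default_rng(0).standard_normal((rows, n))

slow, fast = Dispatcher(), Dispatcher()
y_slow = per_row(slow, x, w)
y_fast = batched(fast, x, w)
print(slow.launches, fast.launches)  # 32 1
```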

trnfft: the FP32 accumulator you didn't know you had

trnfft v0.17 ships two new precision modes for the DFT-GEMM fast path, "bf16" and "bf16_refined". The headline numbers: 1.4–1.5× faster than FP32 at N=64–256 on trn1, with near-FP32 accuracy after one correction step. The mechanism is an architectural property of Trainium that was already present in every kernel but never exploited: the FP32 accumulator behind every Tensor Engine matmul.
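A NumPy sketch of why the accumulator matters, assuming (as the title says) that Tensor Engine matmuls on BF16 inputs accumulate in FP32. The product of two BF16 values is exact in float32, so the only question is how much of the running sum survives. With a float32 accumulator the result stays close to the exact dot product of the BF16 inputs; with a simulated BF16 accumulator the sum cannot grow past 256, because BF16 values there are spaced 2.0 apart and no product in this example exceeds 1. `to_bf16` and `dot_bf16_inputs` are hypothetical helpers, and the bf16_refined correction step is not modeled.

```python
import numpy as np


def to_bf16(x):
    """Round to the nearest BF16 value (ties to even), returned as float32.

    Simulation helper: NumPy has no native bfloat16. Finite inputs only.
    """
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    bits = bits + (np.uint32(0x7FFF) + ((bits >> np.uint32(16)) & np.uint32(1)))
    return (bits & np.uint32(0xFFFF0000)).view(np.float32)


def dot_bf16_inputs(x, y, bf16_accumulator=False):
    """Dot product of BF16-rounded inputs with a simulated accumulator."""
    products = to_bf16(x) * to_bf16(y)  # each BF16 x BF16 product is exact in float32
    acc = np.zeros(1, dtype=np.float32)
    for p in products:
        acc = acc + p                   # float32 add
        if bf16_accumulator:
            acc = to_bf16(acc)          # keep only what a BF16 accumulator could hold
    return float(acc[0])


rng = np.random.default_rng(0)
x = rng.random(4096).astype(np.float32)
y = rng.random(4096).astype(np.float32)
# Reference: exact-enough float64 dot product of the same BF16 inputs
exact = float(to_bf16(x).astype(np.float64) @ to_bf16(y).astype(np.float64))

err_fp32 = abs(dot_bf16_inputs(x, y) - exact) / exact
err_bf16 = abs(dot_bf16_inputs(x, y, bf16_accumulator=True) - exact) / exact
print(f"FP32 accumulator: {err_fp32:.1e}   BF16 accumulator: {err_bf16:.1e}")
```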

trnfft: FFT is a GEMM, and then it isn't

trnfft v0.12–v0.15 shipped three new FFT dispatch paths (DFT-GEMM, Stockham radix-4 with twiddle precomputation, and Stockham radix-8 with a Tensor Engine W₈ kernel), producing 20–37% improvements over the butterfly baseline at medium and large N. The architectural argument running through all three is the same: on Trainium, the bottleneck is not arithmetic but engine utilization and kernel launches. Whether that argument holds at a given N, and at what cost, is where most of the engineering work actually lived.
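Since trnfft avoids a complex dtype, a DFT-GEMM has to be written with real matrices. A minimal NumPy sketch, assuming one particular layout (the post does not say how trnfft arranges its matrices): splitting W = C - iS into cosine and sine parts turns the complex DFT of a batch of rows into a single real GEMM against a 2N × 2N block matrix. `dft_gemm_real` is a hypothetical name.

```python
import numpy as np


def dft_gemm_real(x_re, x_im):
    """Row-wise DFT using one real GEMM and no complex dtype.

    With W[k, j] = exp(-2*pi*i*j*k/n) = C - iS (C, S symmetric cosine/sine
    matrices), the DFT of x_re + i*x_im is
        X_re = x_re @ C + x_im @ S
        X_im = x_im @ C - x_re @ S
    i.e. one matmul of [x_re, x_im] against [[C, -S], [S, C]].
    """
    n = x_re.shape[-1]
    k = np.arange(n)
    theta = 2.0 * np.pi * np.outer(k, k) / n
    c, s = np.cos(theta), np.sin(theta)
    m = np.block([[c, -s], [s, c]])                    # (2n, 2n), real
    y = np.concatenate([x_re, x_im], axis=-1) @ m      # (rows, 2n), real
    return y[..., :n], y[..., n:]


rng = np.random.default_rng(0)
x_re = rng.standard_normal((8, 64))
x_im = rng.standard_normal((8, 64))
y_re, y_im = dft_gemm_real(x_re, x_im)
ref = np.fft.fft(x_re + 1j * x_im, axis=-1)
print(np.max(np.abs((y_re + 1j * y_im) - ref)))
```

The block matrix has (2N)² entries, so this form does O(N²) work against the FFT's O(N log N); it only pays off while N is small enough that launches and engine utilization, not arithmetic, dominate.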

trnfft: FFT on hardware that doesn't want to be an FFT engine

Between v0.7 and v0.12, trnfft's NKI story moved from one butterfly dispatch per row to a batched butterfly plus a fused DFT-as-GEMM fast path, with opt-in Kahan-compensated precision. All of it is hardware-validated on trn1.2xlarge. What landed on silicon looks very little like cuFFT: no complex dtype, no thread-per-butterfly, no bit-reversal in the fast path. Trainium's architecture (four programmable engines, a fixed 128-partition × 512-moving tile, explicit SBUF/PSUM memory) suggested a different decomposition, and this post is the retrospective on what that turned out to be.
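For readers unfamiliar with the compensated mode, here is textbook Kahan summation carried out in float32, using NumPy scalars so every operation rounds to float32. How trnfft wires compensation into its NKI kernels isn't described here, so this illustrates the technique only; `naive_sum_f32` and `kahan_sum_f32` are hypothetical names.

```python
import numpy as np


def naive_sum_f32(values):
    total = np.float32(0.0)
    for v in values:
        total = np.float32(total + v)
    return total


def kahan_sum_f32(values):
    total = np.float32(0.0)
    comp = np.float32(0.0)                  # low-order bits lost so far
    for v in values:
        y = np.float32(v - comp)            # re-inject what the previous add dropped
        t = np.float32(total + y)           # large + small: low bits of y are lost here
        comp = np.float32((t - total) - y)  # ...and recovered here
        total = t
    return total


values = np.full(20_000, 0.1, dtype=np.float32)
exact = 20_000 * float(np.float32(0.1))     # float64 reference for the float32 inputs
naive_err = abs(float(naive_sum_f32(values)) - exact)
kahan_err = abs(float(kahan_sum_f32(values)) - exact)
print(f"naive: {naive_err:.2e}   kahan: {kahan_err:.2e}")
```

The compensation costs three extra operations per element, which is why it makes sense as an opt-in mode rather than the default.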