Architecture¶
Layer diagram¶
```
+--------------------------------------------+
|              User Code / Model             |
+--------------------------------------------+
|         trnfft.api (torch.fft API)         |
|     fft() ifft() rfft() stft() fftn()      |
+--------------------------------------------+
| trnfft.fft_core       | trnfft.nn          |
|   Cooley-Tukey        |   ComplexLinear    |
|   Bluestein           |   ComplexConv1d    |
|   Plan caching        |   ComplexModReLU   |
+-----------------------+--------------------+
|            trnfft.nki.dispatch             |
|       "auto" | "pytorch" | "nki"           |
+--------------------------------------------+
| PyTorch ops           | NKI kernels        |
|  (any device)         |  (Trainium only)   |
+-----------------------+--------------------+
```
Key design decisions¶
- **Split real/imaginary** — Trainium has no complex dtype, so `ComplexTensor` wraps a pair of real tensors.
- **Iterative Cooley-Tukey** — radix-2 decimation-in-time with bit-reversal; each butterfly stage is vectorized across all groups.
- **Bluestein for arbitrary sizes** — converts an arbitrary-N DFT into a circular convolution computed with three power-of-2 FFTs.
- **Plan-based execution** — like FFTW/cuFFT, plans are cached, keyed by `(size, inverse)`.
- **NKI dispatch** — auto-detects Neuron hardware and falls back to PyTorch ops on CPU/GPU.
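The first two decisions can be sketched together in pure Python: an iterative radix-2 decimation-in-time FFT that operates on split real/imaginary lists. This is an illustrative sketch only, not trnfft's internals; the real `fft_core` vectorizes each butterfly stage across all groups instead of looping element-wise, and the names below are made up for the example.

```python
import math

def fft_split(re, im, inverse=False):
    """Radix-2 DIT FFT on split real/imag lists (power-of-2 N only).

    Illustrative sketch; not trnfft's vectorized implementation.
    """
    n = len(re)
    assert n and n & (n - 1) == 0, "power-of-2 only; Bluestein handles the rest"
    re, im = list(re), list(im)
    # Bit-reversal permutation puts inputs in the order the butterflies consume.
    j = 0
    for i in range(1, n):
        bit = n >> 1
        while j & bit:
            j ^= bit
            bit >>= 1
        j |= bit
        if i < j:
            re[i], re[j] = re[j], re[i]
            im[i], im[j] = im[j], im[i]
    sign = 1.0 if inverse else -1.0
    size = 2
    while size <= n:                      # log2(n) butterfly stages
        ang = sign * 2.0 * math.pi / size
        half = size // 2
        for start in range(0, n, size):
            for k in range(half):
                w_re, w_im = math.cos(ang * k), math.sin(ang * k)
                a, b = start + k, start + k + half
                # Complex multiply t = w * x[b], done on the split parts.
                t_re = w_re * re[b] - w_im * im[b]
                t_im = w_re * im[b] + w_im * re[b]
                re[b], im[b] = re[a] - t_re, im[a] - t_im
                re[a], im[a] = re[a] + t_re, im[a] + t_im
        size *= 2
    if inverse:
        re = [v / n for v in re]
        im = [v / n for v in im]
    return re, im
```

The split representation means every complex operation becomes a handful of real multiply-adds, which maps directly onto hardware with no complex dtype.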
File structure¶
```
trnfft/
├── __init__.py      # Public API re-exports
├── api.py           # torch.fft-compatible functions
├── complex.py       # ComplexTensor
├── fft_core.py      # Cooley-Tukey + Bluestein
├── nn.py            # Complex NN layers
├── plan.py          # FFTPlan caching
├── precision.py     # set_precision / get_precision
└── nki/
    ├── dispatch.py  # Backend dispatch + GEMM/mul kernels
    ├── butterfly.py # NKI butterfly kernel (fast + Kahan variants)
    ├── autograd.py  # torch.autograd.Function wrappers for kernels
    └── multicore.py # Multi-NeuronCore (scaffold)
```
Precision modes¶
Introduced in v0.11.0. Select via `trnfft.set_precision(mode)`. The mode
mainly changes Bluestein's arbitrary-size FFT path; power-of-2 FFTs run
through Cooley-Tukey and accumulate far less rounding error.
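To make the Bluestein path concrete, here is a self-contained sketch of the decomposition: chirp-modulate the input, realize the resulting circular convolution with three power-of-2 FFTs, and demodulate. This is the textbook algorithm, not trnfft's code; the recursive helper `_fft_pow2` stands in for the library's iterative core.

```python
import cmath

def _fft_pow2(x, inverse=False):
    """Minimal recursive radix-2 FFT for the power-of-2 helper transforms."""
    n = len(x)
    if n == 1:
        return list(x)
    even = _fft_pow2(x[0::2], inverse)
    odd = _fft_pow2(x[1::2], inverse)
    sign = 1 if inverse else -1
    out = [0j] * n
    for k in range(n // 2):
        t = cmath.exp(sign * 2j * cmath.pi * k / n) * odd[k]
        out[k] = even[k] + t
        out[k + n // 2] = even[k] - t
    return out

def bluestein_dft(x):
    """Arbitrary-N DFT via Bluestein's chirp-z trick (textbook sketch)."""
    n = len(x)
    m = 1
    while m < 2 * n - 1:       # convolution length, rounded up to a power of 2
        m *= 2
    # chirp[k] = exp(-i*pi*k^2/N); k^2 mod 2N keeps the argument small.
    chirp = [cmath.exp(-1j * cmath.pi * (k * k % (2 * n)) / n) for k in range(n)]
    a = [x[k] * chirp[k] for k in range(n)] + [0j] * (m - n)
    b = [0j] * m               # conj-chirp, wrapped for circular convolution
    for k in range(n):
        b[k] = chirp[k].conjugate()
        if k:
            b[m - k] = chirp[k].conjugate()
    fa, fb = _fft_pow2(a), _fft_pow2(b)                              # FFTs 1, 2
    conv = _fft_pow2([p * q for p, q in zip(fa, fb)], inverse=True)  # FFT 3
    return [chirp[k] * conv[k] / m for k in range(n)]
```

The three serial FFTs in the last two lines are exactly why this path dominates the error budget discussed below.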
| Mode | Scope | Typical rel. error (Bluestein) | Cost |
|---|---|---|---|
| `"fast"` (default) | FP32 throughout | ~2e-2 at N ∈ [500, 1000]; worse beyond | Baseline |
| `"kahan"` | Compensated Kahan/Dekker multiplies in Bluestein chirps; twoProd variant of the NKI butterfly kernel | Matches `"fast"` on CPU (chirps aren't the error hotspot); meaningful reduction on the NKI butterfly | ~2× butterfly op count on NKI |
| `"double"` | Promotes all Bluestein host math to FP64, casts back on exit | ~1e-11 at any N; 10+ orders of magnitude better than `"fast"` | Bluestein uses PyTorch FP64 (no NKI); power-of-2 FFTs unaffected |
Recommendation: use `"fast"` for STFT, batched FFT, and any
power-of-2 case. Use `"double"` when you need precision at Bluestein
sizes (non-power-of-2, N ≥ 500). `"kahan"` is useful only on Trainium
hardware, where FP64 isn't available; on CPU it is equivalent to `"fast"`
for this API.
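The compensated multiply behind `"kahan"` can be illustrated with the classic Dekker/Veltkamp twoProd, which recovers the exact rounding error of a product without an FMA. The sketch below uses Python's float64 for testability; the NKI butterfly applies the same idea in FP32 (where the splitting constant would be 2^12 + 1):

```python
def two_prod(a, b):
    """Dekker/Veltkamp twoProd: return (p, e) with p = fl(a*b) and
    p + e == a*b exactly (float64, barring overflow/underflow).
    Illustrative sketch, not trnfft's kernel code."""
    SPLIT = 134217729.0  # 2**27 + 1, Veltkamp splitting constant for float64
    p = a * b
    # Split each operand into high/low halves with non-overlapping bits.
    ta = SPLIT * a
    a_hi = ta - (ta - a)
    a_lo = a - a_hi
    tb = SPLIT * b
    b_hi = tb - (tb - b)
    b_lo = b - b_hi
    # Reassemble the partial products to recover the rounding error of p.
    e = ((a_hi * b_hi - p) + a_hi * b_lo + a_lo * b_hi) + a_lo * b_lo
    return p, e
```

Carrying `e` alongside `p` through the butterfly's `t_re*o_re - t_im*o_im` is what roughly doubles the op count, as the table above notes.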
Source of the error: Bluestein runs three power-of-2 FFTs in series.
Each butterfly stage accumulates rounding error in the complex multiply
`prod_re = t_re*o_re - t_im*o_im`, and the three-FFT chain compounds
that loss. The chirp multiplications on the host are only O(N), so
compensating them alone (as `"kahan"` does on CPU) never reaches the
dominant O(N log N) error source inside the FFTs themselves.
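The compounding is easy to see in isolation. The hypothetical demo below emulates FP32 by round-tripping every intermediate through a 4-byte float (a standard `struct` trick, not trnfft code) and chains the same split complex multiply against a float64 reference:

```python
import cmath
import struct

def f32(x):
    """Round a Python float to the nearest FP32 value."""
    return struct.unpack('f', struct.pack('f', x))[0]

w = cmath.exp(0.001j)                  # unit-modulus twiddle factor
wr, wi = f32(w.real), f32(w.imag)      # FP32 copy of the twiddle
re, im = 1.0, 0.0                      # emulated-FP32 accumulator
z = 1 + 0j                             # float64 reference
for _ in range(10_000):                # many butterfly-style multiplies
    t_re = f32(f32(re * wr) - f32(im * wi))  # the multiply from the text
    t_im = f32(f32(re * wi) + f32(im * wr))
    re, im = t_re, t_im
    z *= w
err = abs(complex(re, im) - z)         # grows far beyond one-op FP32 error
```

A single FP32 multiply errs by about 6e-8 relative, but after ten thousand chained multiplies the accumulated error sits orders of magnitude higher, which is the O(N log N) compounding the paragraph above describes.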