Skip to content

trnfft: the NeuronCore is the unit of parallelism

v0.20 ships multi-NeuronCore batch-split FFT. B transforms dispatched across C NeuronCores gives B/C per core — no shared state, no communication, linear scaling up to B = C. The architecture made this obvious. Getting the dispatch infrastructure right was less so.

The problem

precision="bf16" (v0.17) delivers ~1.2 ms per transform at N=256 on trn1. For a single spectrogram or a one-shot filter bank that is fine. For workloads that run thousands of transforms — STFT over a multi-second audio stream, spectral PDE solvers, RF signal pipelines — single-core throughput caps the pipeline. trn1 has 16 NeuronCores. Until v0.20 they were all idle except one.

The question: can a batch of B transforms run B/C transforms per core with C NeuronCores?

What the architecture suggests

A GPU spreads work within a kernel via threads and warps; NeuronCores are independent processors with separate SBUF, PSUM, and Tensor Engines. There is no shared scratchpad between NeuronCores. The natural parallelism is data-parallel, not instruction-parallel — dispatch an independent transform to each core, collect results.

The Tensor Engine and NEFF cache make this more interesting than it sounds. On first call, torch_neuronx.trace compiles the FFT kernel to a NEFF artifact. Subsequent calls on the same shape hit the cache and skip compilation. torch_neuronx.DataParallel wraps one compiled model and replicates it across C cores. The compilation cost is paid once per (N, inverse, num_cores) triple; after warmup the per-call overhead is dispatch latency only.

This is the framing cuFFT's cufftPlanMany works in: allocate once, execute many times. On Trainium the plan is a NEFF artifact, and DataParallel is the "plan many" layer.

The approach

class _FFTModule(torch.nn.Module):
    def forward(self, real, imag):
        y = fft_core(ComplexTensor(real, imag), inverse=self.inverse)
        return y.real, y.imag

# Compile once
traced = torch.jit.trace(module, [sample_real, sample_imag])
neuron_model = torch_neuronx.trace(traced, [sample_real, sample_imag])
dp_model = torch_neuronx.DataParallel(neuron_model)
_dp_model_cache[(n, inverse, num_cores)] = dp_model

The batch split itself is three lines:

real_shards = x.real.chunk(num_cores, dim=0)
imag_shards = x.imag.chunk(num_cores, dim=0)
results = [dp_model(r, i) for r, i in zip(real_shards, imag_shards, strict=True)]

On CPU (no torch_neuronx): the same chunk-and-dispatch loop runs sequentially. The output is identical; there's no hardware parallelism, but the code path is tested and the architecture is clear.

What didn't work

The import torch_neuronx problem. The initial implementation used a try/except to check availability:

try:
    import torch_neuronx
    results = _neuron_dp_dispatch(...)
except ImportError:
    results = [...]

Ruff's F401 rule flagged torch_neuronx as "imported but unused" — because from ruff's static analysis, torch_neuronx isn't referenced by name after the import (the dispatch goes through _neuron_dp_dispatch). The fix: importlib.util.find_spec("torch_neuronx") at module load time, then branch on HAS_TORCH_NEURONX instead of a runtime try/except. The error message was accurate; the fix was non-obvious.

zip() strictness. B905 requires an explicit strict= argument on all zip() calls. strict=True is semantically correct here (real_shards and imag_shards always have the same length from the same chunk call), but ruff won't infer that. Added strict=True throughout.

First-call compilation cost. On hardware, torch_neuronx.trace takes 30–90 s per (N, inverse, num_cores) key. This is a property of the Neuron compiler, not of the dispatch design — it compiles a full NEFF artifact and writes it to the NEFF cache on disk. After warmup (or after the NEFF cache warms across runs), subsequent calls take milliseconds. The model cache in _dp_model_cache prevents re-compilation within a session; the NEFF disk cache persists across sessions. The first call per shape per instance type is slow. This needs a note in the API docs; it is not a bug, but it surprises users who expect multi_core_fft to be faster than fft on the first call.

Upstream ask (Neuron SDK): A torch_neuronx.is_available() function analogous to torch.cuda.is_available() would be cleaner than importlib.util.find_spec. The recommended pattern in the current SDK docs is to catch ImportError, which triggers static analysis warnings in lint-enforced codebases.

Numbers

Hardware validation pending. The batch-split architecture is CPU-verified: 16 sizes and core counts, roundtrip tested, shape contracts confirmed. Theoretical throughput for a batched workload with B=16 transforms at N=256 across C=16 NeuronCores: ~1.2 ms total (≈ single-core per-transform latency, all 16 in parallel). The actual number depends on dispatch overhead, NEFF cache state, and batch balance — hardware measurement is the next step.

Fit assessment: multi-NeuronCore batch split is well-indexed for workloads with large batch dimensions (STFT, spectral solvers) where each transform is independent. It is poorly indexed for single-transform workloads, which is what trn1's training-first design is optimised for — a single N=256 FFT leaves 15 of 16 NeuronCores completely idle. The stage-parallel path (v0.21, row-column decomposition) addresses single-transform scaling for large N.

What's next

  • Hardware benchmark: set_multicore(True, num_cores=16) on trn1 with batch sizes {1, 4, 8, 16, 32}. Measure wall-clock throughput and per-core utilisation.
  • Stage parallelism (v0.21): single-transform multi-core via row-column FFT decomposition. N = n1×n2: column DFTs on C cores, twiddle multiply, row DFTs on C cores. The twiddle is the "inter-core exchange" and runs in FP32 on the host — no NKI inter-core primitive needed. See the v0.21 CHANGELOG.
  • ozaki_hq precision characterisation: hardware measurement of 2-level Ozaki rel error (theory: ~2e-9 at N=64) via scripts/run_precision_characterization.sh.

Full roadmap: trnsci/trnsci.

Takeaway

The NeuronCore is the unit of data parallelism on Trainium, not the thread or warp. The NEFF cache is the "plan" abstraction — compile once, dispatch many. torch_neuronx.DataParallel wraps both into a model-parallel interface that looks like any other PyTorch DataParallel, except the first call compiles a NEFF artifact rather than allocating a CUDA kernel. The throughput model is simple: B transforms, C cores, B/C per core, linear scaling up to B = C. What's left is the single-transform case, where the row-column decomposition replaces the missing inter-core allreduce with a host-side twiddle multiply.

Comments