NKI Backend
trnfft auto-detects Trainium hardware and dispatches to optimized NKI kernels when available.
Backend control
```python
import trnfft

trnfft.HAS_NKI                 # True if neuronxcc is installed
trnfft.get_backend()           # Current setting: "auto", "pytorch", or "nki"
trnfft.set_backend("auto")     # NKI if available, else PyTorch (default)
trnfft.set_backend("pytorch")  # Force PyTorch (any device)
trnfft.set_backend("nki")      # Force NKI (fails if not on Trainium)
```
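The fallback rule behind the three settings can be sketched as a small resolution function. This is a hypothetical illustration, not trnfft's internal code. `resolve_backend` and its `has_nki` flag are assumptions, and the sketch collapses "neuronxcc is installed" and "running on Trainium" into a single boolean.

```python
# Hypothetical sketch of resolving a backend setting to the backend that runs.
# Not trnfft's implementation; "has_nki" stands in for "NKI is usable here".

VALID_BACKENDS = ("auto", "pytorch", "nki")


def resolve_backend(setting: str, has_nki: bool) -> str:
    """Map "auto" / "pytorch" / "nki" to a concrete backend name."""
    if setting not in VALID_BACKENDS:
        raise ValueError(f"unknown backend {setting!r}; expected one of {VALID_BACKENDS}")
    if setting == "auto":
        # Prefer NKI when it is usable, otherwise fall back to PyTorch.
        return "nki" if has_nki else "pytorch"
    if setting == "nki" and not has_nki:
        # Forcing NKI without it being usable is an error, not a silent fallback.
        raise RuntimeError("NKI backend requested but NKI is not available")
    # "pytorch" always resolves to itself; "nki" does when NKI is usable.
    return setting
```

With this rule, `"auto"` never fails. Only an explicit `"nki"` request raises when NKI cannot run.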
NKI kernels
Complex GEMM
Stationary tile reuse on the Tensor Engine systolic array:
- Phase 1: A_real stationary, stream B_real and B_imag
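- Phase 2: A_imag stationary, stream -B_imag and B_real. Accumulating onto the Phase 1 results gives real = A_real*B_real - A_imag*B_imag and imag = A_real*B_imag + A_imag*B_real
- 4 SBUF loads instead of the 8 needed by four independent real matmuls that each load both operands (50% fewer HBM transfers)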
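The arithmetic of the two phases can be checked with a pure-Python sketch. It models only the decomposition of (A_real + i*A_imag)(B_real + i*B_imag) into four real products accumulated into real and imaginary outputs. The helper names (`matmul`, `add_into`, `complex_gemm`) are hypothetical, and nothing here models SBUF, PSUM, or the systolic array.

```python
# Illustrative model of the two-phase complex GEMM arithmetic (not an NKI kernel).
# "Stationary" here only means the operand reused across two streamed products.


def matmul(a, b):
    """Real matrix product of nested lists: (rows x inner) @ (inner x cols)."""
    inner, cols = len(b), len(b[0])
    return [[sum(row[k] * b[k][j] for k in range(inner)) for j in range(cols)] for row in a]


def add_into(acc, x):
    """Accumulate x into acc element-wise (stand-in for accumulating partial sums)."""
    for i, row in enumerate(x):
        for j, v in enumerate(row):
            acc[i][j] += v


def complex_gemm(a_real, a_imag, b_real, b_imag):
    rows, cols = len(a_real), len(b_real[0])
    c_real = [[0.0] * cols for _ in range(rows)]
    c_imag = [[0.0] * cols for _ in range(rows)]
    neg_b_imag = [[-v for v in row] for row in b_imag]

    # Phase 1: A_real stationary, stream B_real and B_imag.
    add_into(c_real, matmul(a_real, b_real))      # +A_r B_r
    add_into(c_imag, matmul(a_real, b_imag))      # +A_r B_i

    # Phase 2: A_imag stationary, stream -B_imag and B_real.
    add_into(c_real, matmul(a_imag, neg_b_imag))  # -A_i B_i
    add_into(c_imag, matmul(a_imag, b_real))      # +A_i B_r
    return c_real, c_imag
```

Both phases accumulate into the same `c_real` and `c_imag`, which is how the four real products combine into one complex result.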
Fused complex multiply
Computes the element-wise product (a+bi)(c+di) in a single kernel invocation. The kernel loads all 4 inputs in one pass, computes ac-bd and ad+bc in SBUF, and writes 2 outputs. This replaces the 6 separate HBM round-trips of an unfused implementation.
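Here is a minimal pure-Python sketch of the fused computation. It assumes the real and imaginary parts arrive as four separate equal-length sequences. The single loop stands in for the kernel's one pass over the inputs, and `fused_complex_mul` is a hypothetical name, not a trnfft function.

```python
def fused_complex_mul(a, b, c, d):
    """Element-wise (a + bi)(c + di) in one pass over four equal-length inputs.

    Returns (real, imag) lists holding ac - bd and ad + bc for each element.
    """
    if not (len(a) == len(b) == len(c) == len(d)):
        raise ValueError("all four inputs must have the same length")
    real, imag = [], []
    # One traversal reads each input once and produces both outputs together,
    # instead of materializing ac, bd, ad, bc as separate intermediates.
    for xa, xb, xc, xd in zip(a, b, c, d):
        real.append(xa * xc - xb * xd)
        imag.append(xa * xd + xb * xc)
    return real, imag
```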
Butterfly FFT
Each Cooley-Tukey butterfly stage dispatches to an NKI kernel that processes all butterflies in that stage on the Vector Engine. Twiddle factors are preloaded into SBUF and reused across the batch dimension.
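The stage-by-stage structure can be sketched with a radix-2 decimation-in-time FFT in pure Python. Each call to the hypothetical `butterfly_stage` handles every butterfly of one stage, which is the role the NKI kernel plays. `batched_fft` builds the twiddle table once and reuses it for every row of the batch. The sketch assumes power-of-two lengths and illustrates the technique only; it is not trnfft's kernel.

```python
import cmath


def make_twiddles(n):
    """Twiddle table W_n^k = exp(-2*pi*i*k/n) for k in [0, n/2)."""
    return [cmath.exp(-2j * cmath.pi * k / n) for k in range(n // 2)]


def bit_reverse(i, bits):
    """Reverse the lowest `bits` bits of i (input reordering for in-place DIT FFT)."""
    r = 0
    for _ in range(bits):
        r = (r << 1) | (i & 1)
        i >>= 1
    return r


def butterfly_stage(a, size, twiddles):
    """Apply every butterfly of one stage (sub-transform length `size`) in place."""
    n = len(a)
    half = size // 2
    step = n // size  # W_size^k == W_n^(k * step), so one size-n table serves all stages
    for start in range(0, n, size):
        for k in range(half):
            top = a[start + k]
            bottom = a[start + k + half] * twiddles[k * step]
            a[start + k] = top + bottom
            a[start + k + half] = top - bottom


def fft_radix2(x, twiddles):
    """Iterative Cooley-Tukey FFT: bit-reverse the input, then one call per stage."""
    n = len(x)
    bits = n.bit_length() - 1
    a = [complex(x[bit_reverse(i, bits)]) for i in range(n)]
    size = 2
    while size <= n:
        butterfly_stage(a, size, twiddles)
        size *= 2
    return a


def batched_fft(batch):
    """FFT each row; the twiddle table is built once and shared across the batch."""
    n = len(batch[0])
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a positive power of two")
    if any(len(row) != n for row in batch):
        raise ValueError("all rows must have the same length")
    twiddles = make_twiddles(n)
    return [fft_radix2(row, twiddles) for row in batch]
```

For `n = 8` the table holds 4 twiddles, and a batch of any size reuses those same 4 values.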
Architecture
```text
+------------------+------------------------+
| PyTorch ops      | NKI kernels            |
| (any device)     | (Trainium only)        |
| torch.matmul     | nisa.nc_matmul         |
| element-wise     | Tensor Engine          |
|                  | Vector Engine          |
|                  | SBUF <-> PSUM pipeline |
+------------------+------------------------+
```