NKI Backend

trnfft auto-detects Trainium hardware and dispatches to optimized NKI kernels when available.

Backend control

import trnfft

trnfft.HAS_NKI          # True if neuronxcc is installed
trnfft.get_backend()    # Current setting: "auto", "pytorch", or "nki"
trnfft.set_backend("auto")     # NKI if available, else PyTorch (default)
trnfft.set_backend("pytorch")  # Force PyTorch (any device)
trnfft.set_backend("nki")      # Force NKI (fails if not on Trainium)
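The "auto" rule above can be sketched as a small pure function. This is an illustrative model of the dispatch logic, not trnfft's actual internals; `resolve_backend` is a hypothetical name.

```python
def resolve_backend(setting: str, has_nki: bool) -> str:
    """Map a user-facing backend setting to the backend actually used.

    Illustrative sketch of the dispatch rule described above:
    "auto" picks NKI when available, otherwise PyTorch; forcing
    "nki" without Trainium support is an error.
    """
    if setting == "auto":
        return "nki" if has_nki else "pytorch"
    if setting == "nki" and not has_nki:
        raise RuntimeError("NKI backend requested but neuronxcc is not available")
    return setting
```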

NKI kernels

Complex GEMM

Stationary tile reuse on the Tensor Engine systolic array:

  • Phase 1: A_real stationary, stream B_real and B_imag
  • Phase 2: A_imag stationary, stream -B_imag and B_real
  • 4 SBUF loads instead of 8 (50% fewer HBM transfers)
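The two phases compute the four real matmuls of a complex product, with each phase reusing one stationary component of A. A NumPy reference sketch of that schedule (illustrative only; the real kernel runs on the Tensor Engine via `nisa.nc_matmul`):

```python
import numpy as np

def complex_gemm_two_phase(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Reference for the two-phase schedule: each phase keeps one real
    component of A stationary and streams two components of B."""
    a_r, a_i = a.real, a.imag
    b_r, b_i = b.real, b.imag
    # Phase 1: A_real stationary, stream B_real then B_imag
    out_r = a_r @ b_r
    out_i = a_r @ b_i
    # Phase 2: A_imag stationary, stream -B_imag then B_real
    out_r += a_i @ (-b_i)
    out_i += a_i @ b_r
    return out_r + 1j * out_i
```

The result matches `a @ b` on complex inputs, since (A_r + iA_i)(B_r + iB_i) = (A_rB_r - A_iB_i) + i(A_rB_i + A_iB_r).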

Fused complex multiply

Computes the element-wise product (a+bi)(c+di) in a single kernel invocation: loads all four real inputs in one pass, computes ac-bd and ad+bc in SBUF, and writes the two outputs, replacing six separate HBM round-trips.
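The arithmetic the fused kernel performs is just the standard complex multiply on split real/imaginary arrays. A minimal NumPy sketch (illustrative; the NKI kernel does this in SBUF in one pass):

```python
import numpy as np

def fused_complex_mul(a_r, a_i, b_r, b_i):
    """Element-wise (a + bi)(c + di) = (ac - bd) + (ad + bc)i,
    on split real/imaginary inputs, returning the two real outputs."""
    return a_r * b_r - a_i * b_i, a_r * b_i + a_i * b_r
```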

Butterfly FFT

Each Cooley-Tukey butterfly stage dispatches to an NKI kernel that processes all butterflies using the Vector Engine. Twiddle factors are preloaded to SBUF and reused across the batch dimension.

Architecture

+------------------+--------------------------+
|  PyTorch ops     |  NKI kernels             |
|  (any device)    |  (Trainium only)         |
|  torch.matmul    |  nisa.nc_matmul          |
|  element-wise    |  Tensor Engine           |
|                  |  Vector Engine           |
|                  |  SBUF <-> PSUM pipeline  |
+------------------+--------------------------+