NKI Backend

The NKI dispatch layer controls whether BLAS operations run on the native Trainium Tensor Engine or fall back to PyTorch.

Backend selection

import trnblas

trnblas.set_backend("auto")     # NKI on Trainium, PyTorch elsewhere (default)
trnblas.set_backend("pytorch")  # force PyTorch fallback
trnblas.set_backend("nki")      # force NKI (requires nki>=0.3.0)

trnblas.HAS_NKI is True when the nki package (NKI 0.3.0+, Neuron SDK 2.29+) is importable. trnblas uses the canonical nki.* namespace directly; the legacy neuronxcc.nki.* shim is not used.
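
The selection rules above can be sketched in plain Python. This is an illustration only: detect_nki and resolve_backend are hypothetical names, not trnblas functions, and the sketch assumes "auto" depends solely on whether nki can be located (the real library may also check for a Trainium device).

```python
import importlib.util


def detect_nki() -> bool:
    """Return True when a top-level `nki` package can be located."""
    return importlib.util.find_spec("nki") is not None


def resolve_backend(requested: str, has_nki: bool) -> str:
    """Map a requested backend name to the backend that would actually run."""
    if requested not in ("auto", "pytorch", "nki"):
        raise ValueError(f"unknown backend: {requested!r}")
    if requested == "auto":
        # Assumption: "auto" prefers NKI whenever it is available.
        return "nki" if has_nki else "pytorch"
    if requested == "nki" and not has_nki:
        raise RuntimeError("backend 'nki' requested but nki is not available")
    return requested


print(resolve_backend("auto", has_nki=False))  # pytorch
```

Raising on an unsatisfiable "nki" request, rather than silently downgrading, mirrors the "force" wording in the comments above; whether trnblas does exactly this is not stated here.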

Environment variables

| Variable | Effect |
| --- | --- |
| TRNBLAS_REQUIRE_NKI=1 | Re-raise kernel exceptions instead of falling back to torch.matmul; surfaces silent breakage in validation runs. |
| TRNBLAS_USE_SIMULATOR=1 | Route kernel dispatch through nki.simulate(kernel)(numpy_args) on CPU. Bypasses torch_xla + NEFF compile; used for fast correctness iteration. See docs/developing_kernels.md. |
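
The TRNBLAS_REQUIRE_NKI behaviour can be pictured as a try/except around the kernel call. The sketch below is a minimal illustration under stated assumptions: dispatch_with_fallback, broken_kernel, reference_matmul, and the env parameter are hypothetical names, and a pure-Python matmul stands in for torch.matmul so the example runs anywhere.

```python
import os


def dispatch_with_fallback(kernel, fallback, *args, env=None):
    """Try the kernel; on failure either re-raise or use the fallback."""
    env = os.environ if env is None else env
    try:
        return kernel(*args)
    except Exception:
        if env.get("TRNBLAS_REQUIRE_NKI") == "1":
            raise  # strict mode: surface the kernel failure
        return fallback(*args)


def broken_kernel(a, b):
    """Hypothetical kernel that always fails, to exercise both paths."""
    raise RuntimeError("simulated kernel failure")


def reference_matmul(a, b):
    """Pure-Python stand-in for torch.matmul on nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


identity = [[1, 0], [0, 1]]
# With the variable unset, the failure is swallowed and the fallback runs.
result = dispatch_with_fallback(broken_kernel, reference_matmul, identity, identity, env={})
```

Passing env={"TRNBLAS_REQUIRE_NKI": "1"} to the same call would propagate the RuntimeError instead, which is the mode intended for validation runs.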

GEMM kernel

trnblas.nki.nki_gemm — NKI-dispatched GEMM with stationary tile reuse:

  • A tile (128×128) loaded once to SBUF, held stationary in the systolic array.
  • B tiles streamed through as the moving operand.
  • Partial products accumulated in PSUM.
  • HBM padding: M and K are rounded up to a multiple of 128; N is rounded up to a multiple of 512 only when N > 512. For small N the kernel runs a single tile with TILE_N = min(N, 512). The result is sliced back to the original (M, N).
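
The padding rule in the last bullet is plain integer arithmetic. A minimal sketch, assuming N ≤ 512 is left unpadded; round_up and padded_gemm_shape are hypothetical helper names, not part of trnblas:

```python
def round_up(x: int, multiple: int) -> int:
    """Round x up to the next multiple of `multiple`."""
    return -(-x // multiple) * multiple


def padded_gemm_shape(m: int, k: int, n: int, tile: int = 128, max_tile_n: int = 512):
    """Return (M_pad, K_pad, N_pad, TILE_N) for an (M, K) x (K, N) GEMM."""
    m_pad = round_up(m, tile)
    k_pad = round_up(k, tile)
    # Assumption: N is only padded once it exceeds the 512-wide tile.
    n_pad = round_up(n, max_tile_n) if n > max_tile_n else n
    tile_n = min(n, max_tile_n)
    return m_pad, k_pad, n_pad, tile_n


print(padded_gemm_shape(100, 200, 300))     # (128, 256, 300, 300)
print(padded_gemm_shape(1000, 1000, 1000))  # (1024, 1024, 1024, 512)
```

After the padded product is computed, slicing the output to [:M, :N] recovers the result for the original shape.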

Status: validated on trn1.2xlarge with neuronxcc 2.24. 17/17 hardware tests pass. Cached-NEFF speedup ~2.8× on warm runs; per-call kernel timings land at 1.6 ms (512³) and 4.5 ms (1024³) on warm cache.

Batched GEMM

trnblas.nki.nki_batched_gemm — per-slice dispatch through the cached 2D _gemm_kernel. Every slice after the first hits the NEFF cache (identical signature), so per-slice cost is HBM transfer + Tensor Engine dispatch only.

Fused MP2 energy-reduction kernel

trnblas.nki.nki_mp2_energy(T_flat, eps_occ_chunk, eps_occ_full, eps_vir) — computes

E_chunk = Σ_{i, j, a, b} T[i,j,a,b] * (2·T[i,j,a,b] - T[i,j,b,a]) / Δ[i,j,a,b]

where T[i,j,a,b] = T_flat[i*nvir + a, j*nvir + b] and Δ[i,j,a,b] = eps_occ[i] + eps_occ[j] - eps_vir[a] - eps_vir[b].
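
For checking the formula on small inputs, a NumPy reference can be written directly from the definitions above. This is a sketch, not the fused kernel: mp2_energy_reference is a hypothetical name, and it assumes i indexes eps_occ_chunk while j indexes eps_occ_full, which the text does not state explicitly.

```python
import numpy as np


def mp2_energy_reference(T_flat, eps_occ_chunk, eps_occ_full, eps_vir):
    """Evaluate E_chunk with dense NumPy broadcasting (float64)."""
    eps_i = np.asarray(eps_occ_chunk, dtype=np.float64)
    eps_j = np.asarray(eps_occ_full, dtype=np.float64)
    eps_v = np.asarray(eps_vir, dtype=np.float64)
    ic, nocc, nvir = eps_i.size, eps_j.size, eps_v.size

    # T_flat[i*nvir + a, j*nvir + b] has axes (i, a, j, b); reorder to (i, j, a, b).
    T = np.asarray(T_flat, dtype=np.float64).reshape(ic, nvir, nocc, nvir)
    T = T.transpose(0, 2, 1, 3)

    # Denominator as written above (assumed chunk index for i, full index for j).
    delta = (eps_i[:, None, None, None] + eps_j[None, :, None, None]
             - eps_v[None, None, :, None] - eps_v[None, None, None, :])

    T_swapped = T.transpose(0, 1, 3, 2)  # T[i, j, b, a]
    return float(np.sum(T * (2.0 * T - T_swapped) / delta))
```

The reshape followed by a transpose is the only subtle step: the flat layout interleaves occupied and virtual indices, so the virtual axis a sits between i and j until it is moved.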

Partition-dim sub-tiling: P_TILE picked at trace time as the largest divisor of nvir that is ≤ 128 (the NKI partition limit). All three DF-MP2 bench shapes (nvir = 112 / 448 / 672) share P_TILE = 112, so one compiled kernel serves all.
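
The P_TILE rule is small enough to state as code. A minimal sketch; pick_p_tile is a hypothetical name and the 128 default is the partition limit quoted above:

```python
def pick_p_tile(nvir: int, partition_limit: int = 128) -> int:
    """Return the largest divisor of nvir that does not exceed partition_limit."""
    for p in range(min(nvir, partition_limit), 0, -1):
        if nvir % p == 0:
            return p
    raise ValueError("nvir must be a positive integer")


print([pick_p_tile(n) for n in (112, 448, 672)])  # [112, 112, 112]
```

For the bench shapes, 448 = 4 × 112 and 672 = 6 × 112, and neither has a divisor between 113 and 128, which is why all three land on the same tile size.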

Returns: (P_TILE, ic, nocc) fp32 tensor of per-partition partials; caller reduces host-side via .sum().

Status: on-hardware correctness validated across nvir ∈ {8, 16, 64, 256, 448} — all 5 tests pass on trn1 (under the NKI 0.3.0 MLIR verifier after the #15 M2 broadcast-fix commit c1769c6).

Measured perf (trn1.2xlarge, warm NEFF cache, SDK 2.29):

| DF-MP2 shape | Energy step, torch | Energy step, fused | Speedup |
| --- | --- | --- | --- |
| medium (nbasis=512, nocc=64, nvir=448) | 8.03 s | 5.43 s | 1.48× |
| large (nbasis=768, nocc=96, nvir=672) | 44.57 s | 30.27 s | 1.47× |

Results agree with the torch reference at both shapes within atol=1e-4, rtol=1e-4. The fused kernel beats torch consistently, but the speedup is roughly shape-invariant: per-(i,j) dispatch cost scales with nocc², the same as the torch path, so the RFC's 3–5× target isn't hit.

The remaining perf work (larger free-dim tiles, atomic-add variant, multi-engine pipeline evidence) is tracked under #15. The production example examples/df_mp2.py exposes the kernel via --fused-energy; the default remains the torch path until a future milestone hits the 3× bar.