NKI Backend

The NKI dispatch layer controls whether BLAS operations run on the native Trainium Tensor Engine or fall back to PyTorch.

Backend selection

import trnblas

trnblas.set_backend("auto")     # NKI on Trainium, PyTorch elsewhere (default)
trnblas.set_backend("pytorch")  # force PyTorch fallback
trnblas.set_backend("nki")      # force NKI (requires nki>=0.3.0)

trnblas.HAS_NKI is True when the nki package (NKI 0.3.0+, Neuron SDK 2.29+) is importable. trnblas uses the canonical nki.* namespace directly; the legacy neuronxcc.nki.* shim is not used.
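The flag roughly reflects an import probe like the following (a sketch; the exact check inside trnblas may differ):

try:
    import nki                     # canonical namespace, Neuron SDK 2.29+
    HAS_NKI = True
except ImportError:                # no Neuron SDK / NKI on this host
    HAS_NKI = False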

Environment variables

TRNBLAS_REQUIRE_NKI=1
    Re-raise kernel exceptions instead of falling back to torch.matmul; surfaces silent breakage in validation runs.

TRNBLAS_USE_SIMULATOR=1
    Route kernel dispatch through nki.simulate(kernel)(numpy_args) on CPU. Bypasses the torch_xla + NEFF compile path; used for fast correctness iteration. See docs/developing_kernels.md.
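For example, a validation run that should fail loudly on kernel breakage can set the variable before importing trnblas (illustrative; assumes the flag is read no later than the first kernel dispatch):

import os

os.environ["TRNBLAS_REQUIRE_NKI"] = "1"      # re-raise kernel exceptions, no silent fallback
# os.environ["TRNBLAS_USE_SIMULATOR"] = "1"  # optional: CPU simulation for correctness work

import trnblas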

GEMM kernel

trnblas.nki.nki_gemm — NKI-dispatched GEMM with stationary tile reuse:

  • A tile (128×128) loaded once to SBUF, held stationary in the systolic array.
  • B tiles streamed through as the moving operand.
  • Partial products accumulated in PSUM.
  • HBM padding: M/K rounded to 128, N rounded to 512 (when N > 512); the kernel uses TILE_N = min(N, 512) for the single-tile small-N path. The result is sliced back to the original (M, N); see the padding sketch after this list.
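A host-side sketch of that padding arithmetic (the helper name and the torch-based padding are illustrative; the real implementation applies the same rounding rules):

import torch
import torch.nn.functional as F

def _pad_for_gemm(a, b):
    # Hypothetical helper mirroring the documented rounding rules.
    M, K = a.shape
    _, N = b.shape
    up = lambda x, m: (-x) % m                   # padding needed to round x up to a multiple of m
    tile_n = min(N, 512)                         # single-tile path when N <= 512
    a_pad = F.pad(a, (0, up(K, 128), 0, up(M, 128)))     # M, K -> multiples of 128
    n_pad = up(N, 512) if N > 512 else 0                 # N -> multiple of 512 only when N > 512
    b_pad = F.pad(b, (0, n_pad, 0, up(K, 128)))
    return a_pad, b_pad, tile_n, (M, N)          # the kernel result is sliced back to (M, N)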

Status: validated on trn1.2xlarge with neuronxcc 2.24; all 17 hardware tests pass. Warm (cached-NEFF) runs are ~2.8× faster than cold runs; per-call kernel time on a warm cache is 1.6 ms at 512³ and 4.5 ms at 1024³.

Batched GEMM

trnblas.nki.nki_batched_gemm — per-slice dispatch through the cached 2D _gemm_kernel. Every slice after the first hits the NEFF cache (identical signature), so per-slice cost is HBM transfer + Tensor Engine dispatch only.
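A usage sketch (operand shapes, dtype, and the stacked-batch layout are assumptions here):

import torch
import trnblas

A = torch.randn(8, 512, 512)                 # (batch, M, K), assumed layout
B = torch.randn(8, 512, 512)                 # (batch, K, N), assumed layout

# Slice 0 pays NEFF compilation; every later slice reuses the cached 2D kernel.
C = trnblas.nki.nki_batched_gemm(A, B)       # expected result shape: (8, 512, 512)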

Fused MP2 energy-reduction kernel

trnblas.nki.nki_mp2_energy(T_flat, eps_occ_chunk, eps_occ_full, eps_vir) — computes

E_chunk = Σ_{i, j, a, b} T[i,j,a,b] * (2·T[i,j,a,b] - T[i,j,b,a]) / Δ[i,j,a,b]

where T[i,j,a,b] = T_flat[i*nvir + a, j*nvir + b] and Δ[i,j,a,b] = eps_occ_chunk[i] + eps_occ_full[j] - eps_vir[a] - eps_vir[b] (i runs over the occupied chunk, j over all occupied orbitals).

Partition-dim sub-tiling: P_TILE picked at trace time as the largest divisor of nvir that is ≤ 128 (the NKI partition limit). All three DF-MP2 bench shapes (nvir = 112 / 448 / 672) share P_TILE = 112, so one compiled kernel serves all.
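The selection rule is a one-liner at trace time; a direct transcription (function name is illustrative):

def pick_p_tile(nvir: int) -> int:
    # Largest divisor of nvir that fits within the 128-partition limit.
    return max(d for d in range(1, min(nvir, 128) + 1) if nvir % d == 0)

assert pick_p_tile(112) == pick_p_tile(448) == pick_p_tile(672) == 112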

Returns: (ic, nocc) fp32 tensor of per-(i,j) partials; caller reduces host-side via .sum().
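A plain-torch reference of the same reduction, handy for checking kernel output on small shapes (input shapes and the chunk/full split over i and j are assumptions inferred from the signature and return shape):

import torch

def mp2_energy_partials_reference(T_flat, eps_occ_chunk, eps_occ_full, eps_vir):
    # Assumed shapes: T_flat (ic*nvir, nocc*nvir), eps_occ_chunk (ic,),
    # eps_occ_full (nocc,), eps_vir (nvir,).
    ic, nocc, nvir = eps_occ_chunk.numel(), eps_occ_full.numel(), eps_vir.numel()
    T = T_flat.reshape(ic, nvir, nocc, nvir).permute(0, 2, 1, 3)   # -> T[i, j, a, b]
    denom = (eps_occ_chunk[:, None, None, None] + eps_occ_full[None, :, None, None]
             - eps_vir[None, None, :, None] - eps_vir[None, None, None, :])
    contrib = T * (2 * T - T.transpose(2, 3)) / denom              # T * (2·T[i,j,a,b] - T[i,j,b,a]) / Δ
    return contrib.sum(dim=(-2, -1)).float()                       # (ic, nocc) per-(i, j) partials

partials.sum() over the returned tensor then gives E_chunk, matching the host-side reduction described above.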

Status: on-hardware correctness validated across nvir ∈ {8, 16, 64, 256, 448} (all 5 tests pass on trn1). Perf caveat: at the medium DF-MP2 shape the kernel only matches the torch reduction rather than beating it; the per-(i,j) dispatch/load chain swamps the compute savings. Production examples/df_mp2.py keeps the torch path; further perf work is tracked under #15.