trnblas: fusing DF-MP2 energy reduction into one NKI kernel

trnblas v0.4.0 shipped hardware-validated NKI kernels for GEMM, SYRK, and a fused MP2 energy reduction on trn1. End-to-end density-fitted MP2 matches PySCF to 10 µHa (1×10⁻⁵ Ha) on H₂O, CH₄, and NH₃ at cc-pVDZ. The interesting story isn't the GEMM. It's the fused energy kernel: a single NKI pass that keeps the contraction, the orbital-denominator division, and the scalar sum-reduction SBUF-resident, and whose design looks nothing like a cuBLAS port.

The problem

Density-fitted MP2 is a second-order perturbation-theory method used across quantum chemistry for computing correlation energies on systems too large for coupled-cluster. Its computational signature is dominated by a single expression, evaluated per pair of occupied orbitals (i, j) and summed over the virtual orbital pair (a, b):

E_MP2 = Σ_{i,j,a,b} T[i,j,a,b] · (2·T[i,j,a,b] − T[i,j,b,a]) / Δ[i,j,a,b]

where T[i,j,a,b] = Σ_P B[i,a,P]·B[j,b,P] (a three-center tensor contraction over the auxiliary basis) and Δ[i,j,a,b] = ε_i + ε_j − ε_a − ε_b is the orbital-energy denominator. The tensor T at realistic basis sizes is dense and large — for a medium benchmark shape (nocc=64, nvir=448), the full per-chunk T is 16.6 GB in fp32.
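
For reference, the expression above can be written densely in NumPy. This is a minimal sketch on toy sizes, not trnblas code: the function name `df_mp2_energy_reference` and the `B[i, a, P]` array layout are assumptions made for illustration, and it materializes the full T, which is exactly what stops being practical at realistic sizes.

```python
import numpy as np

def df_mp2_energy_reference(B, eps_occ, eps_vir):
    """Dense closed-shell DF-MP2 energy; B has shape (nocc, nvir, naux)."""
    # T[i,j,a,b] = sum_P B[i,a,P] * B[j,b,P]
    T = np.einsum("iaP,jbP->ijab", B, B)
    # Delta[i,j,a,b] = eps_i + eps_j - eps_a - eps_b
    delta = (eps_occ[:, None, None, None] + eps_occ[None, :, None, None]
             - eps_vir[None, None, :, None] - eps_vir[None, None, None, :])
    # T[i,j,b,a] is T with its last two axes swapped.
    return float(np.sum(T * (2.0 * T - T.swapaxes(2, 3)) / delta))

# Toy inputs (hypothetical sizes, random data).
rng = np.random.default_rng(0)
nocc, nvir, naux = 3, 5, 7
B = rng.standard_normal((nocc, nvir, naux))
eps_occ = -1.0 - rng.random(nocc)   # occupied energies below zero
eps_vir = 1.0 + rng.random(nvir)    # virtual energies above zero
e_mp2 = df_mp2_energy_reference(B, eps_occ, eps_vir)
```

Because every Δ is negative when occupied energies lie below virtual ones, the result is a negative correlation energy.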

The cuBLAS / PyTorch-on-GPU mental model for this pattern is a chain of per-op kernel calls: one GEMM to produce T, three elementwise kernels for 2T − Tᵀ and the division, one reduction. Each call writes its intermediate to HBM before the next one reads it back. For the medium shape, that's ~50 GB of intermediate HBM traffic per chunk — dominated not by arithmetic but by memory round-trips. Profiled on trn1 before any fusion, the energy-reduction step takes 8.03 s of a 9.79 s total wall time, ~82 % of the workload.

cuBLAS cannot collapse this into one call. Its primitives are matrix products: gemmStridedBatched and similar. The multiply-subtract-divide-sum chain sits between cuBLAS calls, in a separately written CUDA reduction kernel that the user ships alongside. A cuBLAS-shaped library has no natural place for it.

What the architecture suggests

NKI kernels compile to NEFF — Neuron Executable File Format — which expresses the whole kernel as a single scheduled program across four engines (Tensor, Vector, Scalar, GpSimd), with explicit SBUF (on-chip 128-partition scratchpad) and PSUM (32-bit tensor accumulator) as named memories. Three properties of this model point at a different design than cuBLAS+reduction:

SBUF-resident intermediates cost nothing extra. Once a tile lands in SBUF, every engine that reads it reads from on-chip memory at orders of magnitude higher bandwidth than HBM. An expression chain T * (2T − Tᵀ) / Δ that would ping-pong through HBM on a per-op model sits inside one SBUF tile for its entire lifetime if the kernel is written as one @nki.jit function.

Partition dim is the native parallel axis, not a thread block. NKI's partition dim is 128 physical lanes. A single nl.subtract(a, b) on a (128, N) tile runs across all 128 partitions simultaneously; there is no intra-tile thread-block boundary to cross. Reductions have an asymmetry: the free dim is reducible, the partition dim is not (directly). This shapes the tile layout more than anything else.

NEFF cache amortizes compile across invocations. Every @nki.jit function with the same trace-time shape compiles once, then every subsequent invocation hits the cache and pays only dispatch + execution. For a pair-energy loop invoking the same kernel nocc² times (4 096 to 9 216 pairs at the bench shapes), the compile cost divides across thousands of calls. The NEFF cache is what makes "one kernel per pair" cheap in the first place.

These three point at the same design: one kernel, SBUF-resident intermediates through the full expression chain, and a tiny host-side sum at the end.

The approach

trnblas's fused MP2 kernel — _mp2_energy_kernel — is that design made concrete:

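  • One @nki.jit function, one NEFF. Signature takes T_flat and the ε vectors (the occupied chunk, the full occupied set, and the virtual energies in both column and row layout), and returns a (P_TILE, IC, NOCC) partial.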
  • Tile layout fixes partition-dim alignment through the whole computation. The virtual-orbital index a sits on the partition dim; b sits on the free dim. T and its transpose both load with the same partition semantics. Nothing reshapes across partition axes mid-kernel.
  • Δ is built SBUF-resident per strip from the three ε vectors, never materialized in HBM. This is the step that cuBLAS can't express at all — Δ is not a matrix product, just an outer-sum of three small vectors.
  • Free-dim sum reduces each strip. The result is (P_TILE, 1) per strip. Strips accumulate into an SBUF (P_TILE, NSTRIP) tile and reduce to (P_TILE, 1) after the strip loop — one HBM store per (i, j) pair, not per strip. NKI can't reduce along the partition dim, so the caller does one final .sum() on the tiny (P_TILE, IC, NOCC) partial (≤ 258 KB at the large bench shape; host-side noise).
flowchart LR
    subgraph HBM
      T["T_flat"]
      E["ε_occ, ε_vir"]
      P["e_partial<br/>(P_TILE, IC, NOCC)"]
    end
    subgraph SBUF["One NKI pass · per (i, j) pair"]
      L["load T tile +<br/>load_transpose2d T.T"]
      D["build Δ SBUF-resident<br/>from ε via broadcast_to"]
      M["T · (2T − Tᵀ) · 1/Δ<br/>Vector Engine"]
      R["nl.sum axis=1<br/>free-dim reduce"]
    end
    T --> L
    E --> D
    L --> M
    D --> M
    M --> R
    R -->|one store per pair| P
    P --> S["host .sum()<br/>→ scalar E_MP2"]
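
The tile-size rule (the largest divisor of nvir that fits in 128 partitions) and the size of the partial can be checked with a few lines of arithmetic. This is a sketch: `pick_p_tile` is a hypothetical helper name, and the chunk size IC = 6 is an assumption, chosen only because it reproduces the 258 KB bound quoted above.

```python
def pick_p_tile(nvir, max_partitions=128):
    """Largest divisor of nvir that fits in the partition dim."""
    p_tile = min(nvir, max_partitions)
    while nvir % p_tile != 0:
        p_tile -= 1
    return p_tile

for name, nocc, nvir in [("medium", 64, 448), ("large", 96, 672)]:
    p_tile = pick_p_tile(nvir)
    nstrip = nvir // p_tile
    print(name, "P_TILE =", p_tile, "NSTRIP =", nstrip)

# Size in bytes of the fp32 (P_TILE, IC, NOCC) partial at the large shape.
ic_assumed = 6   # assumed chunk size, not stated in the post
partial_bytes = pick_p_tile(672) * ic_assumed * 96 * 4
```

Both bench shapes land on P_TILE = 112, since neither 448 nor 672 is divisible by 128.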

Implementation

The core shape, simplified (full source linked above):

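# Assumes the NKI modules are already imported, with `nki` providing the jit
# decorator and `nl` bound to the NKI language module.
@nki.jit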
def _mp2_energy_kernel(T_flat, eps_occ_chunk, eps_occ_full,
                       eps_vir_col, eps_vir_row):
    NVIR = eps_vir_row.shape[1]
    IC = eps_occ_chunk.shape[1]
    NOCC = eps_occ_full.shape[1]
    P_TILE = min(NVIR, 128)
    while NVIR % P_TILE != 0:
        P_TILE -= 1
    NSTRIP = NVIR // P_TILE

    e_partial = nl.ndarray((P_TILE, IC, NOCC),
                           dtype=nl.float32,
                           buffer=nl.shared_hbm)
    ev_row = nl.load(eps_vir_row[0:1, 0:NVIR])

    for i in nl.affine_range(IC):
        eo_i = nl.load(eps_occ_chunk[0:1, i:i+1])
        for j in nl.affine_range(NOCC):
            eo_j = nl.load(eps_occ_full[0:1, j:j+1])
            eo_sum = nl.add(eo_i, eo_j)
            acc_rows = nl.zeros((P_TILE, NSTRIP),
                                dtype=nl.float32, buffer=nl.sbuf)
            for s in nl.affine_range(NSTRIP):
                a_off = s * P_TILE
                t = nl.load(T_flat[i*NVIR + a_off : i*NVIR + a_off + P_TILE,
                                    j*NVIR : (j+1)*NVIR])
                t_T = nl.load_transpose2d(T_flat[
                    i*NVIR : (i+1)*NVIR,
                    j*NVIR + a_off : j*NVIR + a_off + P_TILE])
                ev_col = nl.load(eps_vir_col[a_off:a_off+P_TILE, 0:1])

                # Δ built SBUF-resident: all three eps operands
                # lifted to (P_TILE, NVIR) for partition-matched arith.
                eo_sum_bc = nl.broadcast_to(eo_sum, (P_TILE, NVIR))
                ev_col_bc = nl.broadcast_to(ev_col, (P_TILE, NVIR))
                ev_row_bc = nl.broadcast_to(ev_row, (P_TILE, NVIR))
                denom = nl.subtract(nl.subtract(eo_sum_bc, ev_col_bc),
                                    ev_row_bc)

                # The whole expression chain, SBUF-resident.
                term = nl.multiply(
                    nl.multiply(t, nl.subtract(nl.multiply(t, 2.0), t_T)),
                    nl.reciprocal(denom),
                )
                strip_partial = nl.sum(term, axis=1, keepdims=True)
                acc_rows[0:P_TILE, s:s+1] = strip_partial

            acc_row = nl.sum(acc_rows, axis=1, keepdims=True)
            nl.store(e_partial[0:P_TILE, i:i+1, j:j+1], value=acc_row)

    return e_partial
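
The strip loop can be emulated in NumPy to check the indexing and the two-stage reduction without hardware. This is a sketch under stated assumptions: `emulate_mp2_kernel` is a hypothetical name, it assumes the row-major layout T_flat[i·NVIR + a, j·NVIR + b] implied by the slicing in the listing, and it takes a single `eps_vir` vector rather than separate column and row copies.

```python
import numpy as np

def emulate_mp2_kernel(T_flat, eps_occ_chunk, eps_occ_full, eps_vir, p_max=128):
    """NumPy stand-in for the strip loop; returns the (P_TILE, IC, NOCC) partial."""
    nvir, ic, nocc = eps_vir.size, eps_occ_chunk.size, eps_occ_full.size
    p_tile = min(nvir, p_max)
    while nvir % p_tile != 0:
        p_tile -= 1
    e_partial = np.zeros((p_tile, ic, nocc), dtype=T_flat.dtype)
    for i in range(ic):
        for j in range(nocc):
            block = T_flat[i*nvir:(i+1)*nvir, j*nvir:(j+1)*nvir]  # rows a, cols b
            for s in range(nvir // p_tile):
                rows = slice(s * p_tile, (s + 1) * p_tile)
                t = block[rows, :]        # T[i,j,a,b] for this strip of a
                t_T = block[:, rows].T    # T[i,j,b,a] in the same (a, b) layout
                denom = (eps_occ_chunk[i] + eps_occ_full[j]
                         - eps_vir[rows, None] - eps_vir[None, :])
                term = t * (2.0 * t - t_T) / denom
                e_partial[:, i, j] += term.sum(axis=1)  # free-dim reduce only
    return e_partial

# Toy inputs; p_max=4 forces more than one strip (P_TILE becomes 3 for nvir=6).
rng = np.random.default_rng(1)
nocc, nvir, naux = 2, 6, 4
B = rng.standard_normal((nocc, nvir, naux))
eps_occ = -1.0 - rng.random(nocc)
eps_vir = 1.0 + rng.random(nvir)
T_flat = np.einsum("iaP,jbP->iajb", B, B).reshape(nocc * nvir, nocc * nvir)
partial = emulate_mp2_kernel(T_flat, eps_occ, eps_occ, eps_vir, p_max=4)
e_mp2 = partial.sum()   # the caller's final reduction over the partition dim
```

The final `partial.sum()` plays the role of the host-side `.sum()`: everything before it reduces only along the free dim.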

What didn't work

Four items.

The examples/df_mp2.py revert. The first attempt to make nki_mp2_energy the default (#15) flipped the DF-MP2 example's energy-reduction call to the fused kernel. The 1.48× speedup on the energy step was real but below the 3× minimum target laid out in the kernel's design RFC, and the CHANGELOG framing accompanying the flip over-claimed. The example's default was reverted to the torch reduction, --fused-energy was added as an opt-in flag, and the CHANGELOG was rewritten to report the measured numbers. The kernel works; the speedup isn't yet large enough to justify being the default.

NKI 0.3.0 partition-broadcast strictness. Neuron SDK 2.29's MLIR verifier is stricter than 2.28 on tensor-tensor arithmetic with mismatched partition dims. An earlier version of the kernel built Δ with two subtracts of shape (1,1) − (P_TILE, 1), which 0.3.0 rejects outright. The fix (commit c1769c6) lifts all three eps operands to (P_TILE, NVIR) via nl.broadcast_to before subtracting; the 5 MP2 test cases were re-skipped for a release before it landed. The simulator gate catches four of the five 0.3.0 breaking-change classes; this one is MLIR-verifier-level and still requires hardware.

The v0.4.x "NKI" numbers were silent torch.matmul fallback. The SSM runners in v0.4.0 through v0.4.2 invoked the Neuron venv's python without prepending its bin/ to $PATH. torch_neuronx calls libneuronpjrt-path at import to locate the PJRT plugin; that binary lives in the venv's bin/, so it failed with FileNotFoundError, the _nki_*_impl try/except wrappers swallowed it, and every reported "trn1 NKI" warm number was actually trn1's 8-vCPU Xeon running torch.matmul. Correctness tests still passed — torch gives the same answer as the kernel — but perf attribution was wrong for three releases. v0.4.3 published the full retraction, added NkiFallbackWarning so future regressions surface without requiring TRNBLAS_REQUIRE_NKI=1, and added tests/test_nki_really_runs.py as an anti-regression gate.
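
The failure mode and its remedy can be sketched in a few lines. The names `NkiFallbackWarning` and `TRNBLAS_REQUIRE_NKI` come from the text above; everything else here (the `call_with_fallback` wrapper, the warning's base class, the stand-in functions) is a hypothetical illustration, and the real trnblas wrappers may be structured differently.

```python
import os
import warnings

class NkiFallbackWarning(RuntimeWarning):
    """Emitted when the NKI path fails and the reference path runs instead."""

def call_with_fallback(nki_impl, fallback_impl, *args):
    """Try the NKI path; on failure either raise or warn loudly and fall back."""
    try:
        return nki_impl(*args)
    except Exception as exc:
        if os.environ.get("TRNBLAS_REQUIRE_NKI") == "1":
            raise RuntimeError("NKI dispatch failed and TRNBLAS_REQUIRE_NKI=1") from exc
        warnings.warn(f"NKI path failed ({exc!r}); using fallback",
                      NkiFallbackWarning, stacklevel=2)
        return fallback_impl(*args)

def broken_nki(x):
    # Stand-in for an import-time failure like the one described above.
    raise FileNotFoundError("libneuronpjrt-path")

def reference(x):
    # Stand-in for the torch reference path.
    return x * 2

os.environ.pop("TRNBLAS_REQUIRE_NKI", None)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = call_with_fallback(broken_nki, reference, 21)
```

The point of the design is that a bare `except` which returns the fallback silently keeps correctness tests green while invalidating every timing; a warning (or a hard failure under the environment flag) makes the substitution visible.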

The post-M2 perf-tuning rewrite didn't help. A follow-up PR (#32) hoisted the pair-invariant part of Δ out of the (i, j) loop, collapsing denom construction from 5 scheduled ops per iteration to 1 plus a pre-loop precompute. Measured result: 1.48× → 1.50× at medium, 1.47× → 1.49× at large. Within noise. The NEFF compiler was already doing this work on trnblas's behalf; the explicit hoist was a no-op. The profiler investigation (#33) meant to diagnose the remaining gap hit its own blocker: the 2.29 DLAMI's neuron-profile show-session tool rejects the trace format its own inspect command produces, and view --disable-ui requires InfluxDB the DLAMI doesn't pre-install. The remaining hypothesis (a cross-pair HBM-store fence) is queued as #35; until it's tested, the gap is documented, not solved.
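
The hoist itself is easy to state in NumPy: the −ε_a − ε_b part of Δ does not depend on (i, j), so it can be built once and combined with the scalar ε_i + ε_j per pair. This is a sketch of the idea on toy sizes, not the code from #32; the variable names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(2)
nocc, nvir = 4, 6
eps_occ = -1.0 - rng.random(nocc)
eps_vir = 1.0 + rng.random(nvir)

# Pair-invariant part: depends only on (a, b), so build it once before the loop.
ev_grid = -eps_vir[:, None] - eps_vir[None, :]

max_diff = 0.0
for i in range(nocc):
    for j in range(nocc):
        # Unhoisted: rebuild the denominator from every operand for each pair.
        denom_full = eps_occ[i] + eps_occ[j] - eps_vir[:, None] - eps_vir[None, :]
        # Hoisted: one scalar-plus-tile add per pair.
        denom_hoisted = (eps_occ[i] + eps_occ[j]) + ev_grid
        max_diff = max(max_diff, float(np.abs(denom_full - denom_hoisted).max()))
```

The two forms agree to rounding error, which is consistent with a compiler being free to perform the same transformation on its own.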

Numbers

All on trn1.2xlarge, neuronxcc 2.24.5133, warm NEFF cache, v0.4.3-measured under real NKI dispatch.

Per-call kernel timings:

| Op | Shape | Warm |
| --- | --- | --- |
| NKI GEMM | 1024 × 1024 × 1024 | 2.3 ms |
| NKI TRSM (DF-MP2) | 2048 × 512 | 35.82 ms |

Fused MP2 energy kernel (end-to-end energy step):

| Shape | Torch reduction | Fused kernel | Speedup |
| --- | --- | --- | --- |
| medium (nbasis=512, nocc=64, nvir=448) | 8.03 s | 5.43 s | 1.48× |
| large (nbasis=768, nocc=96, nvir=672) | 44.57 s | 30.27 s | 1.47× |

Cross-platform DF-MP2 medium end-to-end:

| Platform | Warm wall | TFLOPS |
| --- | --- | --- |
| trn1.2xlarge | 9.91 s | 0.28 |
| A10G g5.xlarge | 0.266 s | 10.3 |

A10G is 37× faster at the medium bench; the cost story shifts only at scales where memory bandwidth and multi-chip topology matter more than single-GPU throughput.

Chemistry validation (via test_df_mp2_pyscf.py): H₂O/STO-3G matches PySCF's mp.dfmp2.DFMP2 to 1 µHa (1×10⁻⁶ Ha). H₂O/cc-pVDZ, CH₄/cc-pVDZ, and NH₃/cc-pVDZ match to 10 µHa (1×10⁻⁵ Ha). These tolerances come from FP32 accumulation in the Tensor Engine against PySCF's FP64 reference. For closed-shell correlation energies in the ~10⁻¹ to 10⁰ Ha range, an absolute error of 10⁻⁵ Ha sits two orders of magnitude below the chemical-accuracy threshold (~1 mHa ≈ 2.6 kJ/mol).
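
The unit conversion behind that comparison, as a small worked example. The conversion factor is the standard hartree to kJ/mol value; the variable names are illustrative.

```python
HARTREE_TO_KJ_PER_MOL = 2625.4996  # 1 Ha expressed in kJ/mol

def hartree_to_kj_per_mol(e_ha):
    """Convert an energy in hartree to kJ/mol."""
    return e_ha * HARTREE_TO_KJ_PER_MOL

tolerance_kj = hartree_to_kj_per_mol(1e-5)   # the cc-pVDZ test tolerance (10 µHa)
threshold_kj = hartree_to_kj_per_mol(1e-3)   # the ~1 mHa threshold
headroom = threshold_kj / tolerance_kj       # how far the tolerance sits below it
```

Here `threshold_kj` is about 2.6 kJ/mol and `tolerance_kj` about 0.026 kJ/mol, so the test tolerance is a factor of 100 tighter than the threshold.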

What's next

  • #35 — cross-pair batching. The remaining perf hypothesis for the 1.48× gap: batching K pair partials in SBUF before one HBM store relaxes the per-pair synchronization fence. Phase 3 perf work.
  • #26 — NKI GEMM tile-shape autotuner. Phase 3. Plan-time tile selection feeds the NEFF cache key; measured-best tile per shape replaces the current fixed (128, 128, 512).
  • Phase 2 — double-double FP64 GEMM. Emulated FP64 via two FP32 values. Opens the door to workloads where 10⁻⁵ Ha relative error is not enough.
  • Phase 4 — tensor-parallel GEMM across NeuronCores. trn1.32xlarge has 16 chips; sharding a DF-MP2 chunk across them is the next architectural exercise.
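
For the double-double item above, the basic building block is an error-free addition that returns both the rounded FP32 sum and the rounding error it discarded. The sketch below is Knuth's TwoSum in NumPy float32, shown only to illustrate the idea; the post does not describe how trnblas plans to implement Phase 2, and `two_sum` is a hypothetical helper.

```python
import numpy as np

def two_sum(a, b):
    """Knuth's TwoSum in fp32: returns (s, e) with s = fl(a + b) and s + e = a + b exactly."""
    s = np.float32(a + b)
    bb = np.float32(s - a)
    e = np.float32(np.float32(a - np.float32(s - bb)) + np.float32(b - bb))
    return s, e

a = np.float32(1.0)
b = np.float32(1e-8)      # too small to survive a plain fp32 add to 1.0
hi, lo = two_sum(a, b)    # hi carries the rounded sum, lo carries what rounding lost
```

A plain FP32 add would return 1.0 and drop `b` entirely; carrying the `(hi, lo)` pair keeps it, which is what lets two FP32 values stand in for one higher-precision number.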

Phase tracker: trnsci ROADMAP.

Takeaway

The fused MP2 energy kernel is what happens when a library takes Trainium's architecture seriously instead of porting cuBLAS primitives one-to-one. One kernel, one NEFF, one dispatch per pair, every intermediate SBUF-resident. cuBLAS can't express this — not because it's hard, but because its primitive is matrix products and this expression isn't one. Phase 1's measurable win (1.48×) is smaller than the 3× design target; the gap is documented and queued as Phase 3 work.
