
Benchmarks

v0.4.x trn1 numbers were CPU torch.matmul, not NKI (fixed in v0.4.3)

Releases v0.4.0 / v0.4.1 / v0.4.2 published "trn1 NKI" tables on this page and in the CHANGELOG. A PJRT-plugin path resolution bug (our SSM runners didn't put the Neuron venv's bin/ on $PATH) caused every NKI dispatch to fail with FileNotFoundError: 'libneuronpjrt-path'; the _nki_*_impl try/except wrappers then silently fell back to torch.matmul on every one of those runs. As a result, every "trn1 NKI" warm number on this page through v0.4.2 reflects trn1's 8-vCPU Xeon, not the Trainium Tensor Engine.

Fix landed in v0.4.3 (commit d1b481f): a PATH prepend in the SSM runners, a new NkiFallbackWarning, and test_nki_really_runs.py, which forces TRNBLAS_REQUIRE_NKI=1. The tables below are re-measured from the same commit under real NKI dispatch (NEFF compile is visible on cold runs, and cold/warm ratios of 10× to 15,000× confirm the kernel actually runs).

The MP2 energy kernel (trnblas.nki.nki_mp2_energy) turned out to have a partition-limit bug that was masked by the silent fallback; its tests are skipped pending rewrite (tracked in #15). Not in the production DF-MP2 path.
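The fix pairs a loud fallback with a strict mode. Below is a minimal sketch of that pattern, assuming a generic `dispatch_with_fallback(nki_impl, cpu_impl, *args)` helper. The helper is hypothetical (not the actual trnblas wrapper), and `NkiFallbackWarning` is defined locally as a stand-in for the real class:

```python
import os
import warnings


class NkiFallbackWarning(RuntimeWarning):
    """Stand-in for trnblas's warning: an NKI dispatch failed and CPU was used."""


def dispatch_with_fallback(nki_impl, cpu_impl, *args):
    """Run nki_impl; on failure, warn and use cpu_impl, or raise in strict mode."""
    try:
        return nki_impl(*args)
    except Exception as exc:
        # Strict mode (TRNBLAS_REQUIRE_NKI=1): a fallback becomes a hard error,
        # so a test cannot pass while silently running on the host CPU.
        if os.environ.get("TRNBLAS_REQUIRE_NKI") == "1":
            raise RuntimeError(f"NKI dispatch failed: {exc!r}") from exc
        # Default: still fall back, but never silently.
        warnings.warn(
            f"NKI dispatch failed ({exc!r}); falling back to CPU",
            NkiFallbackWarning,
            stacklevel=2,
        )
        return cpu_impl(*args)
```

With the environment variable unset, a failing `nki_impl` returns the CPU result plus a warning. With `TRNBLAS_REQUIRE_NKI=1`, the same failure raises, which is what keeps a test like test_nki_really_runs.py honest.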

All numbers on trn1.2xlarge, neuronxcc 2.24.5133, warm NEFF cache unless noted.

NKI GEMM — per-call kernel timing

Warm cache, mean of 5 calls. Aligned shapes (multiples of 128). Real NKI dispatch verified — test_compile_vs_cache_timing[1024³] reports cold=26.7ms warm=2.3ms speedup=11.8×, which is a NEFF-compile signature not reproducible on CPU.

| Shape (M×K×N) | Warm |
|---|---|
| 512 × 512 × 512 | 1.3 ms |
| 1024 × 1024 × 1024 | 2.3 ms |
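The cold/warm comparison behind these checks can be reproduced with a small timing helper. This is a sketch. `cold_warm_timing` is a hypothetical name, and on a Neuron XLA device the timed function would also need to block until device work finishes (for example, with torch_xla's `xm.mark_step()` and `xm.wait_device_ops()`). That synchronization is not shown here:

```python
import statistics
import time


def cold_warm_timing(fn, *args, n_warm=5):
    """Time the first (cold) call separately from the mean of n_warm later calls."""
    start = time.perf_counter()
    fn(*args)                                   # cold: may include compile / cache load
    cold = time.perf_counter() - start

    warm_times = []
    for _ in range(n_warm):                     # warm: steady-state per-call cost
        start = time.perf_counter()
        fn(*args)
        warm_times.append(time.perf_counter() - start)
    warm = statistics.mean(warm_times)
    ratio = cold / warm if warm > 0 else float("inf")
    return cold, warm, ratio
```

A large cold/warm ratio only shows that the first call paid a one-off cost. Pairing it with a strict no-fallback mode is what rules out the CPU path.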

NKI TRSM — per-call timing (#19)

trnblas.trsm on Trainium uses a blocked panel algorithm: diagonal panels solved via torch.linalg.solve_triangular (tiny P×P, intrinsically sequential); trailing off-diagonal updates run through nki_gemm (dominant work for large M). Block size fixed at 128; autotuning is Phase 3 work (#26). Correctness: 7/7 @pytest.mark.neuron tests pass on trn1 across {lower, upper} × {trans, not} + unit-diag.
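The blocked panel structure can be illustrated on the CPU with NumPy. This sketch handles only the lower, non-transposed case (L X = B). The DF-MP2 pattern (uplo="lower", trans=True) solves against the upper-triangular transpose and would walk the blocks from the bottom up. `blocked_trsm_lower` is a hypothetical name, and the diagonal panels use a dense `np.linalg.solve` in place of a true triangular solve:

```python
import numpy as np


def blocked_trsm_lower(L, B, block=128):
    """Solve L @ X = B for X, with L lower-triangular (M x M) and B (M x N).

    Each step solves one small diagonal panel (sequential), then applies a
    GEMM update to all trailing rows (the part the NKI path sends to nki_gemm).
    """
    M = L.shape[0]
    X = np.array(B, dtype=np.float64, copy=True)
    for start in range(0, M, block):
        end = min(start + block, M)
        # Sequential part: P x P diagonal panel solve.
        X[start:end] = np.linalg.solve(L[start:end, start:end], X[start:end])
        # Parallel part: remove this panel's contribution from the rows below.
        if end < M:
            X[end:] -= L[end:, start:end] @ X[start:end]
    return X
```

Only the trailing update grows with M, which is why it dominates the work for large shapes while the panel solves stay tiny.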

Warm-cache per-call timings (mean of 5, using the DF-MP2 call pattern uplo="lower", trans=True; real NKI + trailing GEMM, v0.4.3-measured):

| Shape (M × N) | trn1 NKI warm | trn1 TFLOPS | A10G warm | A10G TFLOPS | A10G vs trn1 |
|---|---|---|---|---|---|
| 512 × 512 | 5.59 ms | 0.02 | 0.21 ms | 0.65 | 27× |
| 1024 × 512 | 13.27 ms | 0.04 | 0.36 ms | 1.50 | 37× |
| 1024 × 1024 | 18.72 ms | 0.06 | 0.47 ms | 2.29 | 40× |
| 2048 × 512 | 35.82 ms | 0.06 | 0.81 ms | 2.67 | 44× |

Cold (first call, includes NEFF compile of each trailing-GEMM tile signature): 5.8–12.8 s.

Lower TFLOPS than GEMM/SYRK is inherent to TRSM — the sequential panel solve limits parallelism. On trn1 the blocked structure adds Python-loop + per-block nki_gemm dispatch overhead on top; closing that gap is a Phase 3 follow-up (autotuner #26 and eventually a pure NKI substitution kernel).
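The TFLOPS columns in the TRSM table are consistent with counting M²·N flops for an M × M triangular factor and an M × N right-hand side. That is the usual cost of a triangular solve; the exact count used by the benchmark script is an assumption here. A small helper that reproduces the trn1 column under that assumption:

```python
def trsm_tflops(m, n, seconds):
    """Effective TFLOPS for a triangular solve: m x m factor, n right-hand sides."""
    flops = m * m * n  # M^2 * N multiply-adds counted as flops
    return flops / seconds / 1e12


# e.g. the 1024 x 512 row: trsm_tflops(1024, 512, 13.27e-3) is about 0.040
```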

NKI SYRK — per-call timing (#18)

trnblas.syrk on Trainium dispatches to a dedicated kernel (single-A HBM load via two load_transpose2d calls) rather than gemm(A, A.T). Correctness: 7/7 @pytest.mark.neuron tests pass on trn1; outputs match torch.matmul(A, A.T) to atol=1e-3, rtol=1e-4.

Warm-cache per-call timings and effective TFLOPS (mean of 5 runs on real NKI, v0.4.3-measured):

| Shape (M×K) | trn1 NKI warm | trn1 TFLOPS | A10G warm | A10G TFLOPS | A10G vs trn1 |
|---|---|---|---|---|---|
| 512×512 | 2.14 ms | 0.13 | 0.11 ms | 2.39 | 19× |
| 1024×512 | 6.21 ms | 0.17 | 0.16 ms | 6.90 | 39× |
| 1024×1024 | 5.71 ms | 0.38 | 0.21 ms | 10.07 | 27× |
| 2048×512 | 23.89 ms | 0.18 | 0.53 ms | 8.11 | 45× |

Cold (first call, includes NEFF compile): 1.6–11.4 s depending on shape.
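The SYRK TFLOPS figures match a count of 2·M²·K flops. That is the cost of the full gemm(A, A.T) product, not the M²·K flops strictly needed for one triangle of the symmetric result. A helper under that assumption:

```python
def syrk_tflops(m, k, seconds):
    """Effective TFLOPS for C = A @ A.T with A of shape (m, k), counted as a full GEMM."""
    flops = 2 * m * m * k  # full m x m output; a triangle-only count would be m*m*k
    return flops / seconds / 1e12
```

Using the full-GEMM count makes these numbers directly comparable with the GEMM timings above. A triangle-only count would halve every SYRK TFLOPS figure.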

Same pattern as the DF-MP2 end-to-end: the NKI kernel is correct and well-tiled, but A10G's cuBLAS remains 19–45× faster per call (~30× typical) on Ampere-era single-GPU hardware at these sizes. To reproduce:

```shell
AWS_PROFILE=aws ./scripts/run_neuron_tests.sh     # trn1 correctness
# Then ad-hoc:
python examples/bench_syrk.py                     # cpu
python examples/bench_syrk.py --device cuda       # on a g5.xlarge
```

NKI batched GEMM

Warm cache, batch=32 of 256×128×256. Per-slice cost after the first is HBM transfer + Tensor Engine dispatch only (NEFF cache hit).

| Metric | Value |
|---|---|
| Total | 39.3 ms |
| Per-slice | 1.23 ms |

DF-MP2 energy step — 3-way kernel comparison

Small shape (nbasis=128, nocc=16, nvir=112, naux=384 — 256 pairs), trn1.2xlarge, warm NEFF cache, v0.5.2:

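| Energy path | Warm energy | Warm total | Energy vs torch | Total vs torch |
|---|---|---|---|---|
| torch (chunk-GEMM baseline) | 0.018 s | 0.096 s | baseline | baseline |
| fused-gemm (per-pair, v0.5.1) | 0.381 s | 0.454 s | 21× slower | 4.7× slower |
| batched-pair (v0.5.2) | 0.005 s | 0.081 s | 3.6× faster | 1.19× faster |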

The batched-pair kernel is the first energy path that beats the chunk-GEMM torch baseline, on both the energy step (3.6×) and the warm total (1.19×). Cold energy (first call, includes NEFF compile) is 6.7 s for batched-pair, paid once per instance lifetime and amortised across all subsequent calls.

Energies agree to FP32 noise: torch / fused-gemm = -1.619250e-04, batched-pair = -1.619249e-04.

Medium shape (nbasis=512, nocc=64, nvir=448, naux=1536 — 4096 pairs), trn1.2xlarge, warm NEFF cache:

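| Energy path | Version | Warm energy | Warm total | Energy vs torch | Total vs torch |
|---|---|---|---|---|---|
| torch (chunk-GEMM baseline) | v0.5.2 | 8.035 s | 9.795 s | baseline | baseline |
| fused-gemm (per-pair) | v0.5.2 | 9.174 s | 10.877 s | 1.14× slower | 1.11× slower |
| batched-pair (CPU fallback†) | v0.5.2 | 5.239 s | 7.111 s | 1.53× faster | 1.38× faster |
| batched-pair (chunked NKI‡) | v0.5.4 | 1.536 s | 4.784 s | 5.2× faster | 2.05× faster |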

† v0.5.2: NEFF compile for medium-shape batched-pair kernel failed — nl.affine_range traces all loop iterations eagerly at compile time, producing an 18 GB XLA graph IR that exceeded the trn1 root volume's 16 GB free space. The warm row used the cached-failed-NEFF path → torch.matmul fallback on CPU.

‡ v0.5.4: Chunked dispatch (issue #46). The outer i-loop moved to Python; one @nki.jit call per i-row processes all nocc j-pairs. 64 i-dispatches × ~24 ms each = 1.536 s warm energy (XLA dispatch overhead dominates; the Tensor Engine executes each kernel in ~1 ms). Cold energy = 34 min (77 NEFF compilations at ~27 s each; paid once per instance lifetime). Device HBM note (confirmed 2026-04-21): at medium shape, all 64 loaded energy NEFFs remain resident after the cold pass, and device usage reaches 15.9 GB in total, including 12.6 GB of DMA spill and 900 MB of model code. A warm pass in the same process fails with Failed to allocate 1.500GB (usage: tensors) because no headroom remains. Warm timing must therefore be measured in a separate process that loads from the EBS NEFF cache; run_bench.sh does this via --passes cold followed by --passes warm.

Energies agree to FP32 noise: -2.487220e+00 (torch), -2.487219e+00 (fused-gemm), -2.487221e+00 (batched-pair fallback), -2.487218e+00 (chunked NKI).

Reading the medium numbers: the chunked dispatch closes the medium-shape gap. On the energy step it is 5.2× faster than the torch chunk-GEMM baseline and 3.4× faster than the v0.5.2 CPU fallback. Each i-dispatch is ~24 ms (XLA overhead plus ~1 ms of Tensor Engine execution), so 64 dispatches cost 1.536 s instead of the ~409 s that a per-pair loop would require (4096 pairs × ~100 ms per dispatch). The full-batch kernel's 18 GB XLA graph is replaced by 64 × ~1.4 GB graphs that compile and cache normally.
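The chunked structure can be sketched in NumPy: an outer Python loop over i, with each call handling every j for that row. The sketch makes two assumptions: B is laid out as (nocc, nvir, naux), and the energy uses the standard closed-shell DF-MP2 pair expression. `row_energy` and `chunked_energy` are hypothetical stand-ins for the @nki.jit kernel and its driver:

```python
import numpy as np


def row_energy(B, i, eps_occ, eps_vir):
    """Energy contribution of occupied row i, summed over all j (one 'dispatch')."""
    # T[j, a, b] = sum_P B[i, a, P] * B[j, b, P], for every j at once
    T = np.einsum("aP,jbP->jab", B[i], B)
    denom = (eps_occ[i] + eps_occ[:, None, None]
             - eps_vir[None, :, None] - eps_vir[None, None, :])
    # Closed-shell pair expression: T_ab * (2 T_ab - T_ba) / (e_i + e_j - e_a - e_b)
    return np.sum(T * (2.0 * T - T.transpose(0, 2, 1)) / denom)


def chunked_energy(B, eps_occ, eps_vir):
    """Outer i-loop on the host: nocc calls instead of nocc**2 per-pair calls."""
    return sum(row_energy(B, i, eps_occ, eps_vir) for i in range(B.shape[0]))
```

At medium (nocc=64), this driver makes 64 calls where a per-pair loop would make 4096. That count is the whole source of the dispatch savings.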

DF-MP2 end-to-end — Trainium1 vs NVIDIA A10G

Synthetic inputs, same seed, same three shapes on both platforms. Energies agree to within fp32 reduction-order noise.

Vintage parity: Trainium1 launched Oct 2022; NVIDIA A10G (GA102 Ampere) launched Apr 2021 — closest single-GPU match on AWS. A10G via g5.xlarge (~$1/hr), trn1 via trn1.2xlarge (~$1.34/hr).

| Shape (nbasis/nocc/naux) | Flops | trn1 compile-cold | trn1 EBS-warm‡ | trn1 HBM-warm† | A10G warm | A10G vs trn1 |
|---|---|---|---|---|---|---|
| small (128/16/384) | 3.4 G | | | 0.091 s | 0.001 s | 91× |
| medium (512/64/1536) | 2 757 G | 2100.4 s†† | 119.8 s‡‡ | 4.784 s | 0.266 s | 18× |
| large (768/96/2304) | 20 352 G | 174.5 s††† | 80.5 s§ | | 2.018 s | 40× |

†† Medium compile-cold (2026-04-22, empty cache): chol 32.4 s, half 24.6 s, metric 2.0 s, energy 2041.3 s = 2100.4 s total (77 unique NEFF compilations at ~27 s each). Prior measurement of 238.5 s (2026-04-21) was from a partially-warm EBS cache where GEMM/SYRK/TRSM NEFFs hit cache and only the energy kernel compiled fresh.

‡ EBS-warm = fresh process, all NEFFs loaded from the EBS NEFF cache (no compilation) but not yet resident in device HBM. This is the timing any fresh process sees once the instance has been used at least once at this shape.

‡‡ Medium EBS-warm (2026-04-22, fresh process): chol 11.5 s, half 9.1 s, metric 0.5 s, energy 98.7 s = 119.8 s total. Energy is still ~99 s because the 64 energy NEFFs load serially at ~1.5 s each (≈ 96 s of DMA), plus kernel time.

† HBM-warm = NEFFs already resident in device HBM (in-process second pass). The energy step then costs only kernel dispatch: 64 i-dispatches × ~24 ms = 1.536 s with v0.5.4 chunked dispatch. HBM-warm cannot currently be re-measured at medium. After the cold pass, the 64 energy NEFFs plus the GEMM/SYRK/TRSM NEFFs fill 15.9 GB of the 16 GB HBM, so an in-process second pass has no headroom for tensor allocation, and a separate process only reaches EBS-warm. The 4.784 s figure comes from an earlier tracing run.

††† Large compile-cold (2026-04-24, empty energy cache): chol 38.5 s, half 96.5 s, metric 3.8 s, energy 35.6 s (PyTorch CPU fallback†††‡), total 174.5 s. Half-transform 96.5 s reflects GEMM NEFF compilation for large matrices. E = −4.351185×10¹ Ha.

†††‡ Large energy uses PyTorch CPU fallback, not NKI. _j_batched_kernel inner loop tile count N_A × N_B × N_K = 6 × 6 × 18 = 648 exceeds the NeuronCore NRT hardware load limit (~192; medium's 4 × 4 × 12 = 192 loads cleanly). Confirmed: j_chunk = 32, 16, 1 all fail with NRT_RESOURCE. The dispatch proactively detects this and returns _torch_batched_pair_energy immediately (avoids ~30 min of wasted NEFF compilation). Chol, half-transform, and metric steps use NKI; energy alone falls back to trn1 CPUs.

§ Large EBS-warm (2026-04-24): chol 12.5 s (NKI), half 30.4 s (NKI EBS load), metric 2.0 s (NKI), energy 35.5 s (PyTorch CPU, same as cold — no cache benefit), total 80.5 s. The 40× A10G gap reflects the energy step running on trn1 CPUs vs A10G cuBLAS.

† The medium HBM-warm figure uses v0.5.4 chunked dispatch. v0.5.3 used the CPU fallback (9.910 s), so the A10G gap has narrowed from 37× to 18×.

Energies agree across platforms to within fp32 noise: E_MP2 matches for small (-1.619250e-04) and medium (-2.487218) under real NKI dispatch.
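The large-shape fallback in the †††‡ note comes down to a tile count checked before compiling. The sketch assumes N_A and N_B count 128-wide tiles over nvir and N_K counts 128-wide tiles over naux, which reproduces 4 × 4 × 12 for medium and 6 × 6 × 18 for large (nvir = 768 − 96 = 672). The helper names and the 192 threshold constant are illustrative, not trnblas API:

```python
import math

TILE = 128              # assumed tile width along nvir and naux
NRT_LOAD_LIMIT = 192    # approximate limit quoted in the note; not a library constant


def energy_tile_count(nvir, naux, tile=TILE):
    """N_A * N_B * N_K for the batched-pair energy inner loop (assumed mapping)."""
    n_a = math.ceil(nvir / tile)
    n_b = n_a
    n_k = math.ceil(naux / tile)
    return n_a * n_b * n_k


def use_nki_energy(nvir, naux, limit=NRT_LOAD_LIMIT):
    """True if the NKI energy kernel should load; False means use the torch path."""
    return energy_tile_count(nvir, naux) <= limit
```

Checking the count up front is cheap. Discovering the same limit through NRT_RESOURCE would cost a full NEFF compilation first.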

Reading this table

At medium with v0.5.4 chunked dispatch, cuBLAS on A10G is ~18× faster than trnblas NKI on trn1, down from 37× with the v0.5.3 CPU fallback. The chunked dispatch (64 NKI calls at ~24 ms each) produces 1.536 s of warm energy vs the 8.035 s torch baseline, bringing the end-to-end total to 4.784 s. At small, the gap balloons to 91× because NKI dispatch overhead dominates the actual ~3.4 GFlops of compute.

The XLA dispatch overhead is the honest ceiling. Each chunked i-call costs ~24 ms: roughly 23 ms of fixed XLA overhead plus ~1 ms of Tensor Engine execution. At medium shape, the 64 calls spend ~1.47 s of the 1.536 s energy step on dispatch overhead; the hardware is effectively executing for 64 ms and idling for 1.472 s waiting for the host. Reducing the dispatch count (larger chunks, more i-rows per call) requires solving the XLA graph-size problem for larger batches, which is the remaining Phase 3 frontier.
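The overhead arithmetic follows a simple model. Total time is calls × per-call cost, and only calls × execution time keeps the Tensor Engine busy. A sketch with the medium-shape figures (~24 ms per call, ~1 ms of execution); `dispatch_model` is a hypothetical helper:

```python
def dispatch_model(n_calls, per_call_s, exec_s):
    """Split a dispatch-bound step into total, device-busy, and host-wait time."""
    total = n_calls * per_call_s
    busy = n_calls * exec_s
    idle = total - busy
    return total, busy, idle


total, busy, idle = dispatch_model(64, 0.024, 0.001)
# total ~1.536 s, busy ~0.064 s, idle ~1.472 s
```

Under this model, halving the call count halves the idle time only if the per-call cost stays flat as each call handles more i-rows. That condition is exactly what the graph-size problem puts in doubt.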

Closing the A10G gap further requires either batching more pairs per dispatch (reduces dispatch count) or investing in trn2 (2× NeuronCores, lower per-call overhead). See #25 — trn2 benchmarks and #26 — tile autotuner.

NEFF cache warmup

Same suite run twice on a freshly started instance:

| Pass | Wall time |
|---|---|
| Cold (first run after instance start) | 7.01 s |
| Warm (NEFF cache hit + warm XLA graph) | 2.52 s (2.8× faster) |

The cache at /var/tmp/neuron-compile-cache/ persists across instance stop/start (EBS-backed), so kernel compile cost is paid exactly once per shape per cache lifetime.

Reproducing locally

```shell
# Micro-benchmark harness (CPU baselines + NKI when available):
pytest benchmarks/ --benchmark-only

# Full DF-MP2 bench on trn1 (provisions + runs + stops instance):
AWS_PROFILE=aws ./scripts/run_df_mp2_bench.sh --shape medium

# Same workload on A10G (cuBLAS reference for the same vintage):
AWS_PROFILE=aws ./scripts/run_cuda_bench.sh --shape medium
```

See AWS Setup for the one-time Terraform provisioning for each instance (infra/terraform/ for trn1, infra/terraform-cuda/ for the A10G).

Tile-shape autotuner (v0.5.0)

nki_gemm now sweeps six tile candidates {64,128} × {128} × {128,256,512} on the first call per shape bucket and caches the winner to /var/tmp/trnblas-autotune/cache.json (overridable via TRNBLAS_AUTOTUNE_CACHE).

How it works

| Step | Detail |
|---|---|
| Shape bucket | ceil_pow2(M) × ceil_pow2(K) × ceil_pow2(N); all shapes in a DF-MP2 run land in the same bucket |
| Sweep | 3 warm runs per candidate; candidates that don't evenly divide the padded shape are skipped |
| Winner | Stored in-process in _autotune_mem; written to JSON cache |
| Cache hit | Same bucket → dict lookup only, no re-sweep |
| Escape hatch | TRNBLAS_AUTOTUNE=0 → fixed (128,128,512), identical to v0.4.x |

The sweep runs once per shape bucket per instance lifetime (the cache file persists on EBS across stop/start). DF-MP2's nocc² pair loop sees zero sweep overhead after the first call.

Measured numbers (trn1.2xlarge, warm NEFF cache)

Hardware sweep timings will be recorded during the first DF-MP2 bench run with v0.5.0 and added here once that run completes (tracked in #26).

Fused GEMM+energy kernel (v0.5.1)

nki_fused_gemm_energy(b_i, b_j, eps_occ_i, eps_occ_j, eps_vir) fuses the two GEMMs (T and T_T) and the VE energy expression into a single @nki.jit. Eliminates the (nvir, nvir) T_flat HBM round-trip.
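For reference, the math being fused can be written out for one (i, j) pair in NumPy. This page does not give the exact trnblas expression, so the sketch assumes the standard closed-shell DF-MP2 pair energy, with b_i and b_j as (nvir, naux) slices. `fused_pair_energy` is a hypothetical name:

```python
import numpy as np


def fused_pair_energy(b_i, b_j, eps_occ_i, eps_occ_j, eps_vir):
    """Reference for one (i, j) pair: both GEMMs plus the energy reduction."""
    T = b_i @ b_j.T        # first GEMM: (nvir, nvir)
    T_T = b_j @ b_i.T      # second GEMM: equals T.T
    denom = eps_occ_i + eps_occ_j - eps_vir[:, None] - eps_vir[None, :]
    # In the unfused path, T would be written to HBM before this reduction.
    return float(np.sum(T * (2.0 * T - T_T) / denom))
```

Fusing these steps means T and T_T never leave on-chip memory. Only the scalar result goes back to HBM.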

Measured on trn1 (small shape: nocc=16, 256 pairs):

| Path | Warm energy step |
|---|---|
| Chunk-GEMM baseline | 0.13 s |
| Per-pair fused (#41, v0.5.1) | 27.8 s |

The per-pair kernel is correct (energies match to 6 significant figures) but 215× slower — root cause is Neuron XLA's ~100ms per-NEFF-dispatch overhead multiplied by 256 pairs. Fixed in v0.5.2 below.

Batched-pair energy kernel (v0.5.2, #43)

nki_batched_pair_energy(B, eps_occ, eps_vir) computes all nocc² pair energies in a single @nki.jit dispatch, reducing dispatch overhead from O(nocc²) to O(1).
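The batched form computes every pair in one vectorized call. Below is a NumPy sketch under the same assumptions as before (B as (nocc, nvir, naux), closed-shell pair expression); `batched_pair_energies` is a hypothetical name. It materializes an nocc² × nvir² intermediate, so it illustrates the math rather than the tiling a real kernel needs:

```python
import numpy as np


def batched_pair_energies(B, eps_occ, eps_vir):
    """All nocc**2 pair energies in one call; returns an (nocc, nocc) array E[i, j]."""
    # T[i, j, a, b] = sum_P B[i, a, P] * B[j, b, P]
    T = np.einsum("iaP,jbP->ijab", B, B)
    denom = (eps_occ[:, None, None, None] + eps_occ[None, :, None, None]
             - eps_vir[None, None, :, None] - eps_vir[None, None, None, :])
    # Sum the closed-shell pair expression over the virtual indices a, b.
    return np.sum(T * (2.0 * T - T.transpose(0, 1, 3, 2)) / denom, axis=(2, 3))
```

E[i, j] equals E[j, i] because the denominator is symmetric in a and b. A real kernel could exploit this to do roughly half the pairs.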

Measured on trn1 (SHA 7dabe88, warm NEFF cache, nocc=4 / 16 pairs):

| Metric | Value |
|---|---|
| Batched warm | 1.9 ms |
| Per-pair loop (16 pairs, warm) | 25.4 ms |
| Speedup (warm cache) | 13.5× |

Reading the 13.5× number: with a warm NEFF cache, each nki_fused_gemm_energy call takes ~1.59 ms (Tensor Engine compute only), so 16 pairs × ~1.59 ms ≈ 25.4 ms versus 1.9 ms for one batched dispatch. In production on a cold instance (first DF-MP2 call, nocc=16 / 256 pairs), each per-pair invocation costs ~100 ms, so 256 × 100 ms = 25.6 s versus one batched dispatch, a speedup of ~1340×. The Spike B measurement (800× at nocc=4, 16 pairs) used the cold-cache scenario.

All 10 TestBatchedPairEnergy tests passed on trn1 (aligned, unaligned, vs-fused-gemm cross-check, zero-B). Total suite: 62/62.

Out of scope

  • cuBLAS head-to-head at batched-pair scale: planned once PR #44 merges and trn1 numbers are available.
  • trn2 benchmarks: infrastructure provisioned (infra/terraform-trn2/), hardware investigation deferred (#25).