
Architecture

trntensor/
├── trntensor/
│   ├── __init__.py
│   ├── einsum.py        # einsum(), multi_einsum()
│   ├── plan.py          # ContractionPlan, plan_contraction(), estimate_flops()
│   ├── decompose.py     # cp_decompose, tucker_decompose (HOSVD)
│   └── nki/
│       ├── __init__.py
│       └── dispatch.py  # Fused contraction kernels
├── tests/
├── examples/
│   └── df_mp2_einsum.py # DF-MP2 energy via einsum

Contraction planning

The planner analyzes einsum subscripts and selects a dispatch target:

  • matmul: 2D contraction over a single shared index → torch.matmul
  • bmm: batched 2D contraction → torch.bmm
  • torch: complex patterns → torch.einsum
  • nki (future): fused multi-index contractions on the Tensor Engine

ContractionPlan carries the dispatch decision, the FLOPs estimate, and any reshape/transpose preamble so the same plan can be re-executed cheaply.
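The dispatch rules above can be sketched as a small subscript classifier. This is illustrative only: `plan_contraction()` in `trntensor/plan.py` is the real entry point, and its actual criteria and signature may differ.

```python
# Hypothetical sketch of the dispatch decision, NOT the real planner logic.
def classify(spec: str) -> str:
    """Map an einsum subscript string to a dispatch target name."""
    inputs, out = spec.split("->")
    ins = inputs.split(",")
    if len(ins) != 2:
        return "torch"                  # multi-operand: fall back to torch.einsum
    a, b = ins
    shared = set(a) & set(b)
    contracted = shared - set(out)      # indices summed away
    batch = shared & set(out)           # indices carried through both sides
    if len(contracted) == 1 and not batch and len(a) == 2 and len(b) == 2:
        return "matmul"                 # plain 2D contraction -> torch.matmul
    if len(contracted) == 1 and len(batch) == 1:
        return "bmm"                    # one batch dim -> torch.bmm
    return "torch"                      # complex patterns -> torch.einsum

assert classify("ik,kj->ij") == "matmul"
assert classify("bik,bkj->bij") == "bmm"
assert classify("iap,jbp->ijab") == "torch"
```

The cheap-re-execution property follows from this split: classification and any reshape/transpose preamble are computed once, and only the tensor operands change between calls.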

Fused contraction-reduction: the architectural core

What makes trntensor different from torch.einsum or trnblas.gemm is fusion across the boundaries between contraction, elementwise work, and reduction. On Trainium:

  • the Tensor Engine does the contraction, accumulating into PSUM
  • the Vector Engine does the elementwise reshaping, reading from SBUF
  • a small SBUF scalar accumulates the reduction
  • nothing lands in HBM until the whole program finishes

A cuTENSOR clone would compose these three phases as separate ops with HBM round-trips between them. NKI lets us write them as one @nki.jit program with the same data living in PSUM → SBUF → accumulator across phases. That's the architectural lever.

The canonical demonstration is trntensor.mp2_energy — the DF-MP2 correlation energy is

T_{i,j,a,b}    = Σ_P B_{i,a,P} B_{j,b,P}                                   (contraction)
term_{i,j,a,b} = T_{i,j,a,b} (2 T_{i,j,a,b} - T_{i,j,b,a}) / Δ_{i,j,a,b}   (elementwise)
E              = Σ_{i,j,a,b} term_{i,j,a,b}                                (reduction)

The fused kernel in trntensor/nki/_kernels.py::mp2_energy_kernel does all three in one program, iterating over (i, j) pairs in affine_range. No intermediate T tensor is ever materialized to HBM.
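For contrast, the unfused computation can be written in a few lines of NumPy. The sizes, random inputs, and variable names here are illustrative only; note that this version fully materializes the intermediate T, which is exactly what the fused kernel avoids.

```python
import numpy as np

no, nv, naux = 4, 6, 10                  # illustrative occupied/virtual/auxiliary sizes
rng = np.random.default_rng(0)
B = rng.standard_normal((no, nv, naux))  # DF factor B_{i,a,P} (random stand-in)
eps_o = -rng.random(no) - 1.0            # occupied orbital energies (negative)
eps_v = rng.random(nv) + 1.0             # virtual orbital energies (positive)

# contraction: T_{ijab} = sum_P B_{iaP} B_{jbP}  -- materializes T in memory
T = np.einsum("iap,jbp->ijab", B, B)

# elementwise: denominator Delta_{ijab} = eps_i + eps_j - eps_a - eps_b
delta = (eps_o[:, None, None, None] + eps_o[None, :, None, None]
         - eps_v[None, None, :, None] - eps_v[None, None, None, :])
term = T * (2.0 * T - T.transpose(0, 1, 3, 2)) / delta   # T_{ijba} via axis swap

# reduction: scalar correlation energy
E = term.sum()
```

The fused kernel produces the same scalar, but streams (i, j) tiles through PSUM and SBUF instead of ever holding the full T.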

This pattern — fused contract + elementwise + reduce with SBUF-resident intermediates — is what trntensor extends across the suite. Generic einsum dispatch to matmul / bmm kernels is the foundation; primitives like mp2_energy are the architectural flagships.

An honest performance note. The fused kernel is architecturally correct but currently carries ~15–40 ms of XLA dispatch + compile overhead per call, which means the CPU fallback path still wins at small chemistry sizes. The fusion work pays off once per-call overhead is amortized (tracked in #33/#34). See Benchmarks for the current numbers.

Decompositions for quantum chemistry

  • CP (CANDECOMP/PARAFAC) — Tensor hypercontraction (THC) of two-electron integrals. Reduces O(N^4) storage to O(N^2 R).
  • Tucker (HOSVD) — Low-rank approximation of DF coefficient tensors. Reduces memory for large auxiliary basis sets.

Both use alternating least squares (ALS) on top of trnblas-style GEMM primitives. The decompositions can be fed back through einsum at evaluation time, e.g. contracting Tucker factors directly against a state vector without materializing the full tensor.
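To illustrate the Tucker/HOSVD side, a minimal truncated HOSVD can be sketched with plain NumPy SVDs. The function name, signature, and ranks here are hypothetical, not the API of trntensor.decompose.

```python
import numpy as np

def hosvd(X, ranks):
    """Hypothetical truncated HOSVD sketch: one SVD per mode unfolding,
    then project X onto the factors to get the core."""
    factors = []
    for mode, r in enumerate(ranks):
        unfolding = np.moveaxis(X, mode, 0).reshape(X.shape[mode], -1)
        U, _, _ = np.linalg.svd(unfolding, full_matrices=False)
        factors.append(U[:, :r])          # leading r left singular vectors
    core = X
    for mode, U in enumerate(factors):    # core = X x_1 U1^T x_2 U2^T ...
        core = np.moveaxis(np.tensordot(core, U, axes=(mode, 0)), -1, mode)
    return core, factors

rng = np.random.default_rng(0)
X = rng.standard_normal((8, 9, 10))
core, factors = hosvd(X, (8, 9, 10))      # full ranks: lossless

# Reconstruction: X_hat = core x_1 U1 x_2 U2 x_3 U3
Xhat = core
for mode, U in enumerate(factors):
    Xhat = np.moveaxis(np.tensordot(Xhat, U, axes=(mode, 1)), -1, mode)
assert np.allclose(Xhat, X)               # exact at full ranks
```

Truncating the ranks trades reconstruction error for the O(N^2)-style memory savings described above; the factors can then be contracted directly via einsum without rebuilding the full tensor.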