einsum¶
Einstein summation with contraction planning. Mirrors the subscript
notation used by torch.einsum and numpy.einsum, with additional
control over scaling, accumulation precision, and mixed-precision compute.
einsum(subscripts, *operands, alpha, beta, out, dtype, precision) -> torch.Tensor¶
Execute a single contraction.
einsum(
subscripts: str,
*operands: torch.Tensor,
alpha: float = 1.0,
beta: float = 0.0,
out: torch.Tensor | None = None,
dtype: str | torch.dtype | None = None,
precision: str = "fast",
) -> torch.Tensor
The planner selects matmul, bmm, path, or torch.einsum based on
the subscript pattern and operand shapes. When an NKI backend is active and
the pattern maps to a 2-index matmul, the kernel in
trntensor.nki._kernels.matmul_kernel is used.
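To make the idea of pattern-based dispatch concrete, here is a minimal sketch of how a subscript string could be sorted into a matmul, bmm, or generic bucket. The function `classify_pattern` and its rules are hypothetical and written only for illustration; they are not trntensor's planner, they ignore operand shapes, and they do not model the path strategy.

```python
def _is_matmul(a: str, b: str, out: str) -> bool:
    """True when two 2-index terms share one summed index and the output
    holds exactly the two remaining free indices (transposes allowed)."""
    shared = set(a) & set(b)
    free = (set(a) | set(b)) - shared
    return (
        len(a) == 2 and len(b) == 2 and len(out) == 2
        and len(shared) == 1 and set(out) == free
    )


def classify_pattern(subscripts: str) -> str:
    """Hypothetical classifier returning 'matmul', 'bmm', or 'einsum'."""
    inputs, _, output = subscripts.replace(" ", "").partition("->")
    terms = inputs.split(",")
    if len(terms) != 2 or not output:
        return "einsum"  # not a binary contraction with explicit output
    a, b = terms
    # Repeated indices inside one term (traces, diagonals) go to the fallback.
    if any(len(set(t)) != len(t) for t in (a, b, output)):
        return "einsum"
    if _is_matmul(a, b, output):
        return "matmul"
    # Batched case: one leading batch index shared by both inputs and the output.
    if (
        len(a) == len(b) == len(output) == 3
        and a[0] == b[0] == output[0]
        and _is_matmul(a[1:], b[1:], output[1:])
    ):
        return "bmm"
    return "einsum"


for pattern in ("ij,jk->ik", "ap,bp->ab", "bij,bjk->bik", "mi,mnP->inP"):
    print(pattern, "->", classify_pattern(pattern))
```

Under these assumed rules, "ap,bp->ab" counts as a matmul with a transposed second operand, while "mi,mnP->inP" falls through to the generic bucket.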
Arguments¶
| Argument | Default | Description |
|---|---|---|
| subscripts | (required) | Einstein summation string, e.g. "ij,jk->ik" |
| *operands | (required) | Input tensors |
| alpha | 1.0 | Scalar multiplier on the contraction result |
| beta | 0.0 | Scalar multiplier on out before accumulation |
| out | None | Accumulation tensor; result is α·contract + β·out |
| dtype | None | Cast all operands to this dtype before contracting; result is returned in this dtype. Accepts torch.dtype or strings: "bf16", "bfloat16", "fp16", "float16", "f32", "float32" |
| precision | "fast" | Accumulation precision: "fast", "kahan", or "dd" (see below) |
Basic examples¶
import trntensor
# Matrix multiply
C = trntensor.einsum("ij,jk->ik", A, B)
# Batched matmul
C = trntensor.einsum("bij,bjk->bik", A, B)
# DF-MP2 pair contraction
T_ab = trntensor.einsum("ap,bp->ab", B_i, B_j)
# 4-index AO→MO transform step
Inu = trntensor.einsum("mi,mnP->inP", C, eri)
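The subscript strings above fully determine the result shape: every index that appears after "->" keeps its size, and every other index is summed away. The following sketch shows that rule in plain Python. The helper `output_shape` is hypothetical (it is not part of trntensor), requires an explicit "->", and does not handle broadcasting or ellipsis.

```python
def output_shape(subscripts: str, *shapes: tuple) -> tuple:
    """Infer the result shape of an explicit-output einsum expression."""
    inputs, output = subscripts.replace(" ", "").split("->")
    terms = inputs.split(",")
    if len(terms) != len(shapes):
        raise ValueError("one shape per subscript term is required")

    sizes = {}  # index label -> dimension size
    for term, shape in zip(terms, shapes):
        if len(term) != len(shape):
            raise ValueError(f"term {term!r} does not match rank {len(shape)}")
        for label, size in zip(term, shape):
            # The same label must have the same size wherever it appears.
            if sizes.setdefault(label, size) != size:
                raise ValueError(f"size mismatch for index {label!r}")

    # Labels missing from the output are contracted (summed over).
    return tuple(sizes[label] for label in output)


# AO->MO step from the examples: m is summed, i/n/P survive.
print(output_shape("mi,mnP->inP", (7, 3), (7, 7, 20)))
# Full contraction to a scalar: empty output means an empty shape.
print(output_shape("ijab,ijab->", (2, 2, 4, 4), (2, 2, 4, 4)))
```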
alpha / beta scaling¶
Matches cuTENSOR's GEMM-style interface:
alpha·contract(A, B) + beta·C.
# Scaled GEMM: 2·A@B + 0.5·C
result = trntensor.einsum("ij,jk->ik", A, B, alpha=2.0, beta=0.5, out=C)
# In-place accumulation (alpha=1, beta=1): acc = A@B + acc
trntensor.einsum("ij,jk->ik", A, B, beta=1.0, out=acc)
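The scaling rule can be checked by hand on a small case. The sketch below evaluates alpha·(A@B) + beta·C on 2×2 nested lists in plain Python, so it runs without torch or trntensor. The helpers `matmul` and `scaled_contract` are illustrative names, not library functions, and the sketch returns a new list rather than writing into `out`.

```python
def matmul(A, B):
    """Plain-Python "ij,jk->ik" on nested lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def scaled_contract(A, B, C, alpha=1.0, beta=0.0):
    """Reference for alpha * (A @ B) + beta * C, element by element."""
    AB = matmul(A, B)
    return [
        [alpha * ab + beta * c for ab, c in zip(ab_row, c_row)]
        for ab_row, c_row in zip(AB, C)
    ]


A = [[1.0, 2.0], [3.0, 4.0]]
B = [[5.0, 6.0], [7.0, 8.0]]
C = [[10.0, 20.0], [30.0, 40.0]]

# A @ B = [[19, 22], [43, 50]]
scaled = scaled_contract(A, B, C, alpha=2.0, beta=0.5)
print(scaled)  # [[43.0, 54.0], [101.0, 120.0]]

# alpha=1, beta=1 corresponds to accumulating A @ B onto C.
accumulated = scaled_contract(A, B, C, alpha=1.0, beta=1.0)
print(accumulated)  # [[29.0, 42.0], [73.0, 90.0]]
```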
dtype — mixed-precision compute¶
Cast operands to a lower-precision dtype to hit the NKI bf16 matmul path without changing the model's weight dtype. The result is returned in the requested dtype.
# Route through NKI bf16 kernel even if weights are fp32
result_bf16 = trntensor.einsum("ij,jk->ik", A, B, dtype="bf16")
# Pass a torch.dtype directly (requires import torch)
result = trntensor.einsum("ij,jk->ik", A, B, dtype=torch.float16)
precision — accumulation precision¶
| Value | Behavior |
|---|---|
| "fast" | Native operand dtype; NKI kernels accumulate in fp32 via PSUM |
| "kahan" | Promotes operands to fp64 before contracting, then casts back to the original dtype; ~15.9 significant digits; bypasses NKI dispatch (runs on CPU) |
| "dd" | Double-double accumulation via trnblas Phase 2; raises NotImplementedError until trnblas#22 lands |
# High-precision DF-MP2 energy accumulation
E_corr = trntensor.einsum("ijab,ijab->", T, T2, precision="kahan")
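To see why the accumulation width matters for long reductions such as an energy sum, the sketch below compares an fp32 running total with an fp64 running total over the same fp32 inputs. It is a pure-Python emulation under stated assumptions: fp32 rounding is imitated with `struct`, the accumulation is strictly sequential, and nothing here calls trntensor or models PSUM behavior.

```python
import math
import struct


def to_fp32(x: float) -> float:
    """Round a Python float (fp64) to the nearest fp32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def dot_fp32_accum(xs, ys):
    """Dot product whose products and running total are rounded to fp32."""
    total = 0.0
    for x, y in zip(xs, ys):
        total = to_fp32(total + to_fp32(x * y))
    return total


def dot_fp64_accum(xs, ys):
    """Same inputs, but the running total is kept in fp64."""
    total = 0.0
    for x, y in zip(xs, ys):
        total += x * y
    return total


n = 10_000
xs = [to_fp32(0.1)] * n  # inputs are already fp32-representable
ys = [1.0] * n

# math.fsum gives a correctly rounded reference sum of the fp64 products.
exact = math.fsum(x * y for x, y in zip(xs, ys))
err_fp32 = abs(dot_fp32_accum(xs, ys) - exact)
err_fp64 = abs(dot_fp64_accum(xs, ys) - exact)

print(f"fp32 accumulation error: {err_fp32:.3e}")
print(f"fp64 accumulation error: {err_fp64:.3e}")
```

On this input the fp32 running total drifts by many orders of magnitude more than the fp64 one, which is the kind of error that promoting to fp64 before contracting is meant to avoid.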
multi_einsum(*contractions) -> list[torch.Tensor]¶
Execute several contractions in one call. Each contraction is a tuple
of (subscripts, *operands). Results are returned in input order.
When NKI dispatch is active, operand tensors that appear in more than one contraction (by object identity) are pre-pinned to the XLA device once before executing the loop, eliminating redundant host↔device transfers.
results = trntensor.multi_einsum(
("ap,bp->ab", B_i, B_j), # Coulomb term
("ap,bp->ab", B_i, B_k), # exchange term — B_i pinned once
)
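The identity-based sharing rule described above can be sketched without any device code. The helper below, `shared_operands`, is a hypothetical illustration of how operands that occur in more than one contraction might be detected with `id()`; it is not trntensor's implementation and performs no transfers.

```python
from collections import Counter


def shared_operands(*contractions):
    """Return operands that appear in more than one contraction, by identity."""
    counts = Counter()  # id(operand) -> number of contractions using it
    by_id = {}          # id(operand) -> the operand object itself
    for _subscripts, *operands in contractions:
        # Deduplicate within a contraction so "ap,bp->ab", X, X counts once.
        for op in {id(o): o for o in operands}.values():
            counts[id(op)] += 1
            by_id[id(op)] = op
    return [by_id[key] for key, n in counts.items() if n > 1]


# Stand-ins for tensors; only object identity matters for this check.
B_i, B_j, B_k = object(), object(), object()

shared = shared_operands(
    ("ap,bp->ab", B_i, B_j),
    ("ap,bp->ab", B_i, B_k),
)
print(len(shared), shared[0] is B_i)
```

Because the check uses identity rather than value equality, two separate tensors with equal contents would not be treated as shared; in this sketch only a reused object such as B_i qualifies.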
Plan cache helpers¶
These are documented on the Planning page but are also relevant to einsum performance:
- trntensor.clear_plan_cache(): discard all cached plans
- trntensor.plan_cache_info(): return {"size": N} cache statistics