Planning

The planner analyzes einsum subscripts to select a concrete execution strategy, estimate cost, and cache the decision so repeated calls are cheap.

plan_contraction(subscripts, *operands, precision="fast") -> ContractionPlan

Inspect the dispatch decision without executing.

plan = trntensor.plan_contraction("ij,jk->ik", A, B)
plan.strategy             # "matmul" | "bmm" | "torch" | "path"
plan.contraction_indices  # list of indices being summed over
plan.batch_indices        # list of batch indices
plan.output_indices       # list of output indices in order
plan.transA, plan.transB  # whether to pre-transpose before matmul
plan.contraction_path     # [(i,j), ...] for "path" strategy
plan.precision            # "fast" | "kahan" | "dd"

Plans are cached by (subscripts, operand shapes, precision). Repeated calls with the same subscripts, shapes, and precision skip replanning entirely. Call clear_plan_cache() to invalidate the cache (e.g. after a backend change).
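
The caching behaviour described above can be sketched in plain Python. This is a minimal illustration, not trntensor's implementation: `make_key`, `cached_plan`, `clear_cache`, `cache_info`, and `toy_planner` are hypothetical names, and operands are represented only by their shapes.

```python
from typing import Any, Callable, Dict, Sequence, Tuple

# Hypothetical cache key, mirroring (subscripts, operand shapes, precision).
CacheKey = Tuple[str, Tuple[Tuple[int, ...], ...], str]

_plan_cache: Dict[CacheKey, Any] = {}


def make_key(subscripts: str, shapes: Sequence[Sequence[int]], precision: str) -> CacheKey:
    """Build a hashable key; shapes are normalized to tuples of ints."""
    return (subscripts, tuple(tuple(int(d) for d in s) for s in shapes), precision)


def cached_plan(subscripts: str, shapes: Sequence[Sequence[int]],
                precision: str, planner: Callable[..., Any]) -> Any:
    """Return the cached plan for this key, calling the planner only on a miss."""
    key = make_key(subscripts, shapes, precision)
    if key not in _plan_cache:
        _plan_cache[key] = planner(subscripts, shapes, precision)
    return _plan_cache[key]


def clear_cache() -> None:
    """Drop every cached plan (analogous to clear_plan_cache)."""
    _plan_cache.clear()


def cache_info() -> Dict[str, int]:
    """Report the number of cached plans (analogous to plan_cache_info)."""
    return {"size": len(_plan_cache)}


# Toy planner (an assumption for the demo) that records how often it runs.
calls = []


def toy_planner(subscripts, shapes, precision):
    calls.append((subscripts, precision))
    return {"subscripts": subscripts, "precision": precision}


p1 = cached_plan("ij,jk->ik", [(4, 8), (8, 3)], "fast", toy_planner)
p2 = cached_plan("ij,jk->ik", [(4, 8), (8, 3)], "fast", toy_planner)   # cache hit
p3 = cached_plan("ij,jk->ik", [(4, 8), (8, 3)], "kahan", toy_planner)  # new entry
```

Because precision is part of the key, the "fast" and "kahan" lookups occupy separate entries, while the second "fast" lookup returns the object stored by the first.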

precision values

| Value | Effect on planning |
| --- | --- |
| "fast" | Default; NKI kernels eligible for matmul/bmm strategies |
| "kahan" | Signals fp64 promotion at execution time; plan is still computed but einsum will use torch.einsum in fp64 regardless of strategy |
| "dd" | Raises NotImplementedError at execution time (trnblas#22 pending) |

plan_fast  = trntensor.plan_contraction("ij,jk->ik", A, B, precision="fast")
plan_kahan = trntensor.plan_contraction("ij,jk->ik", A, B, precision="kahan")
# Two distinct cache entries — plan_fast is not plan_kahan
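
One way to picture how the three precision values affect execution is a small dispatch function. This is a sketch under assumptions: `resolve_execution` is a hypothetical helper, the "native" dtype label is invented for illustration, and rejecting unknown values with ValueError is a guess rather than documented behaviour.

```python
from typing import Tuple


def resolve_execution(strategy: str, precision: str = "fast") -> Tuple[str, str]:
    """Return (executor, dtype policy) for a planned strategy.

    Illustrative only: the labels are placeholders, not trntensor internals.
    """
    if precision == "fast":
        # The planned strategy runs as-is in the operands' own dtype.
        return (strategy, "native")
    if precision == "kahan":
        # The plan still exists, but execution goes through torch.einsum in fp64.
        return ("torch.einsum", "float64")
    if precision == "dd":
        # Documented as not yet implemented at execution time.
        raise NotImplementedError("precision='dd' is not implemented yet")
    # Assumption: anything else is rejected.
    raise ValueError(f"unknown precision: {precision!r}")


fast_route = resolve_execution("matmul", "fast")
kahan_route = resolve_execution("bmm", "kahan")
```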

ContractionPlan

Dataclass returned by plan_contraction. All fields are read-only after the planner returns.

| Field | Type | Description |
| --- | --- | --- |
| subscripts | str | Original subscript string |
| strategy | str | "matmul", "bmm", "torch", or "path" |
| backend | str | "nki" or "pytorch" |
| transA | bool | Pre-transpose first operand (matmul only) |
| transB | bool | Pre-transpose second operand (matmul only) |
| contraction_indices | list[str] | Indices summed over |
| batch_indices | list[str] | Shared batch indices |
| output_indices | list[str] | Output index order |
| estimated_flops | int | Multiply-add estimate |
| contraction_path | list[tuple[int,int]] | Greedy pair order for "path" strategy (opt_einsum convention) |
| precision | str | "fast", "kahan", or "dd" |
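
For readers who want to mock a plan in tests, the fields above can be mirrored with a frozen dataclass. This is a sketch, not the library's definition: `PlanSketch` is a hypothetical name, the default values are assumptions, and whether the real ContractionPlan uses `frozen=True` is not stated in this page.

```python
from dataclasses import FrozenInstanceError, dataclass, field
from typing import List, Tuple


@dataclass(frozen=True)
class PlanSketch:
    """Stand-in with the same field names as the table above (defaults are assumed)."""
    subscripts: str
    strategy: str                      # "matmul", "bmm", "torch", or "path"
    backend: str = "pytorch"           # "nki" or "pytorch"
    transA: bool = False
    transB: bool = False
    contraction_indices: List[str] = field(default_factory=list)
    batch_indices: List[str] = field(default_factory=list)
    output_indices: List[str] = field(default_factory=list)
    estimated_flops: int = 0
    contraction_path: List[Tuple[int, int]] = field(default_factory=list)
    precision: str = "fast"            # "fast", "kahan", or "dd"


plan = PlanSketch(
    subscripts="ij,jk->ik",
    strategy="matmul",
    contraction_indices=["j"],
    output_indices=["i", "k"],
    estimated_flops=4 * 8 * 3,
)

# frozen=True blocks attribute reassignment after construction.
try:
    plan.strategy = "bmm"
    reassigned = True
except FrozenInstanceError:
    reassigned = False
```

Note that `frozen=True` only prevents rebinding attributes; the list fields themselves remain mutable objects.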

Strategy selection

  • "matmul" — 2-operand, 2D tensors, single contracted index → torch.matmul (or NKI kernel when backend is "nki")
  • "bmm" — 2-operand, 3D tensors, single contracted index + single batch index → torch.bmm (or NKI batched kernel)
  • "torch" — 2-operand fallback for patterns that don't fit matmul/bmm → torch.einsum
  • "path" — 3+ operands → greedy binary contraction ordering; each binary step is dispatched through the full backend-selection stack so large sub-contractions still reach NKI
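
The rules in the list above can be approximated by a subscript-only classifier. This is a minimal sketch under stated assumptions: `classify` is a hypothetical function, subscripts must have an explicit "->" and no ellipsis, tensor rank is taken from the number of letters in each term, single-operand expressions are sent to "torch" as a guess, and output ordering or transposition is not modeled.

```python
def classify(subscripts: str) -> str:
    """Pick a strategy name from the subscripts alone (illustrative only)."""
    inputs, output = subscripts.replace(" ", "").split("->")
    terms = inputs.split(",")

    if len(terms) >= 3:
        return "path"
    if len(terms) < 2:
        return "torch"  # assumption: single-operand expressions use the fallback

    a, b = terms
    # Indices that do not appear in the output are summed over.
    contracted = [c for c in dict.fromkeys(a + b) if c not in output]
    # Indices shared by both operands and kept in the output act as batch indices.
    batch = [c for c in a if c in b and c in output]

    if len(a) == 2 and len(b) == 2 and len(contracted) == 1 and not batch:
        return "matmul"
    if len(a) == 3 and len(b) == 3 and len(contracted) == 1 and len(batch) == 1:
        return "bmm"
    return "torch"


examples = {
    s: classify(s)
    for s in ["ij,jk->ik", "bij,bjk->bik", "ij,ij->ij", "iap,jbp->ijab", "ij,jk,kl->il"]
}
```

In this sketch "iap,jbp->ijab" falls through to "torch" because it has a single contracted index but no batch index, so it matches neither the matmul nor the bmm pattern.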

estimate_flops(subscripts, *operands) -> int

Estimate the number of multiply-add operations for a contraction as the product of the sizes of all distinct indices.

trntensor.estimate_flops("ij,jk->ik", A, B)  # M*K*N
trntensor.estimate_flops("iap,jbp->ijab", B, B)  # nocc² * nvir² * naux
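
The estimate can be reproduced by hand from operand shapes. The following is a sketch of the stated rule only: `estimate_flops_sketch` is a hypothetical name, it takes shapes rather than tensors, and the validation errors are assumptions added for safety.

```python
from math import prod
from typing import Dict, Sequence


def estimate_flops_sketch(subscripts: str, *shapes: Sequence[int]) -> int:
    """Product of the sizes of all distinct indices appearing in the inputs."""
    terms = subscripts.replace(" ", "").split("->")[0].split(",")
    if len(terms) != len(shapes):
        raise ValueError("number of subscript terms and operands differ")

    sizes: Dict[str, int] = {}
    for term, shape in zip(terms, shapes):
        if len(term) != len(shape):
            raise ValueError(f"term {term!r} does not match shape {tuple(shape)}")
        for idx, dim in zip(term, shape):
            # Record the size on first sight; later sightings must agree.
            if sizes.setdefault(idx, dim) != dim:
                raise ValueError(f"inconsistent size for index {idx!r}")
    return prod(sizes.values())


# M=4, K=8, N=3 for a plain matrix product.
matmul_flops = estimate_flops_sketch("ij,jk->ik", (4, 8), (8, 3))

# nocc=2, nvir=5, naux=7 for the "iap,jbp->ijab" pattern.
df_flops = estimate_flops_sketch("iap,jbp->ijab", (2, 5, 7), (2, 5, 7))
```

With these shapes the matrix product gives 4 * 8 * 3 = 96, and the second pattern gives 2² * 5² * 7 = 700, matching the nocc² * nvir² * naux form in the comment above.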

Plan cache helpers

clear_plan_cache() -> None

Discard all cached contraction plans. Call after changing the active backend or when memory is a concern.

trntensor.set_backend("nki")
trntensor.clear_plan_cache()  # flush plans that may have chosen "pytorch"

plan_cache_info() -> dict[str, int]

Return cache statistics. Currently returns {"size": N} where N is the number of cached plans.

info = trntensor.plan_cache_info()
print(f"cached plans: {info['size']}")