Skip to content

Quickstart

einsum

import torch
import trntensor

# Contraction of two matrices over their shared index p (dispatches to matmul).
# A is (32, 64) and B is (48, 64). p appears in both inputs but not in the
# output, so by einsum convention it is summed over:
#     C[a, b] = sum_p A[a, p] * B[b, p]      (i.e. A @ B.T)
A = torch.randn(32, 64)
B = torch.randn(48, 64)
C = trntensor.einsum("ap,bp->ab", A, B)   # (32, 48)

# Batched (dispatches to bmm)
# i is a batch index that is kept in the output; p is summed over within each
# batch element:
#     Cb[i, a, b] = sum_p Ab[i, a, p] * Bb[i, b, p]
Ab = torch.randn(16, 32, 64)
Bb = torch.randn(16, 48, 64)
Cb = trntensor.einsum("iap,ibp->iab", Ab, Bb)   # (16, 32, 48)

Contraction planning

# Reuses A and B from the einsum example above.
plan = trntensor.plan_contraction("ap,bp->ab", A, B)
print(plan.dispatch)   # "matmul" | "bmm" | "torch" | "nki"
print(plan.flops)      # estimated FLOP count

# FLOP estimate for the same expression without keeping a plan object.
# NOTE(review): presumably equal to plan.flops above; confirm.
flops = trntensor.estimate_flops("ap,bp->ab", A, B)

Decompositions

# CP (CANDECOMP/PARAFAC)
# Models X as a sum of `rank` rank-one terms. With rank=8 on a random tensor,
# `reconstructed` is in general an approximation of X, not an exact copy.
# NOTE(review): `factors` is presumably one factor matrix per mode of X
# (three here, each 16 x 8); confirm against cp_decompose's documentation.
X = torch.randn(16, 16, 16)
factors = trntensor.cp_decompose(X, rank=8)
reconstructed = trntensor.cp_reconstruct(factors)

# Tucker (HOSVD)
# `ranks` gives one target rank per mode of X. Truncating each 16-wide mode to
# 4 means `reconstructed` is again an approximation of X.
# NOTE(review): `core` is presumably the (4, 4, 4) core tensor and `factors`
# one 16 x 4 matrix per mode; confirm.
# The next two lines rebind `factors` and `reconstructed` from the CP example.
core, factors = trntensor.tucker_decompose(X, ranks=(4, 4, 4))
reconstructed = trntensor.tucker_reconstruct(core, factors)

Backend selection

# These three calls are alternatives, not a sequence to run together.
# Presumably each call replaces the previous setting, so run as written only
# the last one ("nki") would be in effect. Pick the one you need.
trntensor.set_backend("auto")     # default
trntensor.set_backend("pytorch")  # force PyTorch fallback
trntensor.set_backend("nki")      # force NKI (requires Neuron hardware)