Skip to content

Quickstart

import torch
import trnblas

# Quickstart walkthrough: one call per trnblas routine, using small random
# operands whose shapes are chosen so every product below is conformable.
# A is (256, 128) and B is (128, 64); the inner dimensions (128) match.
A = torch.randn(256, 128)
B = torch.randn(128, 64)

# Level 3 — Matrix multiply (the hot path)
# NOTE(review): presumably C = alpha * (A @ B), shape (256, 64) — confirm
# against the trnblas.gemm reference.
C = trnblas.gemm(alpha=1.0, A=A, B=B)

# Batched GEMM (DF-MP2 tensor contractions)
# Both operands carry a leading dimension of 32; per slice the shapes are
# (128, 64) and (64, 32), so the inner dimensions (64) match.
A_batch = torch.randn(32, 128, 64)
B_batch = torch.randn(32, 64, 32)
C_batch = trnblas.batched_gemm(1.0, A_batch, B_batch)

# Triangular solve (Cholesky-based density fitting)
# A.T @ A is (128, 128) and positive semi-definite; adding the identity makes
# it positive definite, so the Cholesky factorization is well defined.
L = torch.linalg.cholesky(A.T @ A + torch.eye(128))
# Right-hand side: 16 columns, each of length 128 to match L.
B2 = torch.randn(128, 16)
# NOTE(review): presumably solves L @ X = alpha * B2 using the lower triangle
# of L — confirm the side/transpose defaults of trnblas.trsm.
X = trnblas.trsm(1.0, L, B2, uplo="lower")

# Symmetric rank-k update (metric construction)
# NOTE(review): with trans=True this presumably forms alpha * A.T @ A,
# shape (128, 128) — confirm against the trnblas.syrk reference.
J = trnblas.syrk(1.0, A, trans=True)

# Level 2 — Matrix-vector
# x has length 128, matching the column count of A (256, 128).
x = torch.randn(128)
y = trnblas.gemv(1.0, A, x)

# Level 1 — Vector operations
# u and v are both length-256 vectors.
u = torch.randn(256)
v = torch.randn(256)
# NOTE(review): presumably w = 2.0 * u + v (BLAS axpy convention); whether v
# is updated in place is not visible here — confirm.
w = trnblas.axpy(2.0, u, v)
# NOTE(review): presumably the inner product of u and v — confirm.
d = trnblas.dot(u, v)
# NOTE(review): presumably the Euclidean (2-)norm of u — confirm.
n = trnblas.nrm2(u)

Backend selection

import trnblas

# The three calls below illustrate the accepted backend names. They read as
# alternatives: call only the one you need.
# NOTE(review): if run verbatim in sequence, presumably the last call ("nki")
# is the one left in effect — confirm set_backend semantics.
trnblas.set_backend("auto")     # NKI on Trainium, PyTorch elsewhere (default)
trnblas.set_backend("pytorch")  # force PyTorch
trnblas.set_backend("nki")      # force NKI (requires Neuron hardware)

DF-MP2 example

# Run the bundled DF-MP2 example script with its --demo flag.
python examples/df_mp2.py --demo
# Run the same script with explicit problem sizes.
# NOTE(review): presumably --nbasis is the number of basis functions and
# --nocc the number of occupied orbitals — confirm in examples/df_mp2.py.
python examples/df_mp2.py --nbasis 100 --nocc 20