trnrand: the integer-multiply gap pointed to a better algorithm
The previous trnrand post closed with: "the silicon just needs one more op to let the library say it out loud." aws-neuron-sdk#1308 is still open. trnrand 0.4.0 ships hardware-validated uniform RNG on trn1 anyway — not by fixing Philox, but by using Threefry4x32-20, the PRNG Salmon et al. designed in the same SC'11 paper for hardware without fast integer multiply. The library said it out loud without waiting for the op.
The problem
Philox 4x32-10 needs one primitive NKI cannot provide exactly: a 32x32-bit integer multiply returning both halves (hi, lo) of the 64-bit result. NKI routes all 32-bit tile operations through the Vector Engine's float32 activation path, which represents integers exactly only up to 2^24 (approximately 16.7M).
The first attempted fix was to decompose the multiply into 8-bit byte
sub-products, keeping each sub-product under 2^16 — well inside the float32
ceiling. The decomposition is arithmetically correct. A pure-numpy port
(_mul32_hi_lo_numpy) verifies it bit-exact against a Python unbounded-integer
ground truth. The kernel compiled, the sub-products were right, and the Philox
output was still wrong.
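The decomposition is easy to check on the host. The sketch below is plain Python. Its function name, mul32_hi_lo_bytes, is hypothetical; it is not the library's _mul32_hi_lo_numpy. It splits each operand into little-endian bytes and forms the 16 byte-by-byte partial products, each at most 255 * 255 = 65025 < 2^16. It then recombines the partial products with column carries. As the paragraph above says, getting this arithmetic right was never the issue.

```python
def split_bytes(x: int) -> list:
    """Little-endian bytes of a 32-bit word, so x == sum(b[i] << (8 * i))."""
    return [(x >> (8 * i)) & 0xFF for i in range(4)]


def mul32_hi_lo_bytes(a: int, b: int) -> tuple:
    """Hypothetical sketch: 32x32 -> 64-bit multiply built from 8-bit sub-products.

    Returns (hi, lo), the upper and lower 32-bit halves of a * b.
    """
    ab = split_bytes(a)
    bb = split_bytes(b)
    cols = [0] * 8  # cols[k] sums every a_i * b_j with i + j == k
    for i in range(4):
        for j in range(4):
            partial = ab[i] * bb[j]
            assert partial < 2**16  # each sub-product fits in 16 bits
            cols[i + j] += partial
    # Propagate carries column by column; every column sum stays far below 2**24.
    product, carry = 0, 0
    for k in range(8):
        s = cols[k] + carry
        assert s < 2**24
        product |= (s & 0xFF) << (8 * k)
        carry = s >> 8
    assert carry == 0  # a * b < 2**64, so nothing spills past byte 7
    return product >> 32, product & 0xFFFFFFFF


hi, lo = mul32_hi_lo_bytes(0xFFFFFFFF, 0xFFFFFFFF)
print(hex(hi), hex(lo))  # 0xfffffffe 0x1
```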
The problem was not the multiply. Philox counter values exceed 2^24 the moment
the counter advances past the first lane. The moment such a value enters an NKI
tile — via nl.copy, nl.bitwise_and, nl.right_shift, anything — it gets
rounded through float32 before the multiply decomposition can run. The input is
already corrupted. No kernel-level decomposition can work around a lossless-load
problem. The concrete failure: for input 0x7FFFFFFF the kernel returns
0x80000000 outright. 0x7FFFFFFF has no exact float32 representation and rounds
up to 2^31, one past INT32_MAX, so the conversion back to a 32-bit integer
lands on 0x80000000 (INT_MIN, the NaN-cast sentinel). Distribution mean on
trn1: 0.31 vs the expected 0.50.
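numpy's float32 conversion reproduces the same rounding on the host. The helper through_float32 is hypothetical. It models only the round trip through float32 and does not model the Vector Engine's conversion back to an integer.

```python
import numpy as np

FLOAT32_EXACT_LIMIT = 2**24  # every integer in [0, 2**24] is exact in float32


def through_float32(x: int) -> int:
    """Model an integer value rounded through float32 (round-to-nearest-even)."""
    return int(np.float32(x))


# Below the ceiling the round trip is lossless.
print(through_float32(FLOAT32_EXACT_LIMIT - 1))  # 16777215
# Just above it, odd integers collapse onto even neighbors.
print(through_float32(FLOAT32_EXACT_LIMIT + 1))  # 16777216
# 0x7FFFFFFF rounds up to 2**31, whose 32-bit pattern is 0x80000000.
print(hex(through_float32(0x7FFFFFFF)))  # 0x80000000
```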
What the architecture suggests
Threefry4x32-20 comes from the same Salmon et al. SC'11 paper. It has its own published known-answer test vectors and the same statistical guarantees as Philox. It uses only three primitives: 32-bit add, XOR, and rotate-left. None of them requires an integer multiply, and all three decompose cleanly into byte arithmetic on NKI.
If a 32-bit word is stored as four separate 8-bit byte tiles — each a (P, 1)
uint32 tile with value in [0, 255] — then:
- Addition propagates carries byte-by-byte. Each intermediate sum is at most 255 + 255 + 1 = 511 < 2^9, more than 32,000 times below float32's exact-integer ceiling of 2^24.
- XOR operates per-byte independently. Each output byte is in [0, 255].
- Rotate-left by (q bytes + r bits) reindexes bytes and then shifts each byte by at most 7 bits. Sub-byte rotation intermediates are at most 255 * 2^7 = 32640 < 2^15.
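Here is a host-side model of the add and XOR bullets, written in plain Python rather than NKI, so loops and comprehensions are fine. Each word is a little-endian list of four ints in [0, 255]. The names add32_bytes and xor32_bytes are hypothetical. add32_bytes also reports the largest intermediate it produced, so the 511 bound can be checked directly.

```python
MASK32 = 0xFFFFFFFF


def to_bytes(x: int) -> list:
    """Split a 32-bit word into little-endian bytes b0..b3."""
    return [(x >> (8 * i)) & 0xFF for i in range(4)]


def from_bytes(b: list) -> int:
    """Reassemble little-endian bytes into a 32-bit word."""
    return b[0] | (b[1] << 8) | (b[2] << 16) | (b[3] << 24)


def add32_bytes(a: list, b: list) -> tuple:
    """Carry-propagating 32-bit add on byte lists; returns (bytes, peak intermediate)."""
    out, carry, peak = [], 0, 0
    for i in range(4):
        s = a[i] + b[i] + carry  # at most 255 + 255 + 1 = 511
        peak = max(peak, s)
        out.append(s & 0xFF)
        carry = s >> 8  # the carry out of byte 3 is dropped: addition mod 2**32
    return out, peak


def xor32_bytes(a: list, b: list) -> list:
    """Per-byte XOR; bytes never interact, so every output stays in [0, 255]."""
    return [a[i] ^ b[i] for i in range(4)]


s, peak = add32_bytes(to_bytes(0xFFFFFFFF), to_bytes(0xFFFFFFFF))
print(hex(from_bytes(s)), peak)  # 0xfffffffe 511
```

The all-ones case is the worst case for the carry chain: every byte after the first sees 255 + 255 + 1.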
The architecture does not merely permit byte-tile Threefry — it makes it exact where uint32-tile Philox is approximate. The precision problem is not "float32 is imprecise"; it is "Philox needs one primitive float32 cannot model; Threefry does not." Salmon et al. designed Threefry specifically for hardware without fast multiply. Trainium's current NKI substrate is exactly that hardware.
```mermaid
flowchart LR
    subgraph input ["Input: counter + key words"]
        W["32-bit int32 tile\n(P, 1)"]
    end
    subgraph bytes ["_b_split: byte decomposition"]
        B0["b0 in [0,255]"]
        B1["b1 in [0,255]"]
        B2["b2 in [0,255]"]
        B3["b3 in [0,255]"]
    end
    subgraph gpsimd ["GpSimd: 20 rounds, all intermediates < 2^15"]
        ADD["_add32_b\ncarry-propagating add"]
        ROT["_rotl32_b\nbyte-shift + bit-shift"]
        XOR["_xor32_b\nper-byte XOR"]
        ADD --> ROT --> XOR --> ADD
    end
    subgraph conv ["Output conversion (no uint32 tile)"]
        MAN["mantissa = b0 + b1*256 + b2*65536\n<= 2^24 - 1, exact in float32"]
        F32["/ 2^24 -> float32 U[0,1)"]
        MAN --> F32
    end
    subgraph ve ["Vector Engine"]
        BM["Box-Muller\nlog / sqrt / cos / sin\n-> N(0,1)"]
    end
    W --> B0 & B1 & B2 & B3 --> ADD
    XOR -->|"key injection\nevery 4 rounds"| MAN
    F32 -.->|SBUF-resident| BM
```
The output conversion is where the approach pays off a second time: the three least-significant bytes of each output word give a 24-bit mantissa — the maximum resolution a float32 significand can carry. No uint32 tile is ever assembled. The kernel goes from counter inputs to float32 uniforms without touching any representation that exceeds float32's exact-integer ceiling.
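The conversion can be checked with numpy float32 arithmetic. word_bytes_to_uniform is a hypothetical host-side mirror of the idea, not the kernel code. It relies only on the facts stated above: every partial sum is an integer below 2^24, and scaling by 2^-24 is a power-of-two multiply, so each step is exact.

```python
import numpy as np

INV24 = np.float32(1.0 / 16_777_216.0)  # 2**-24, exactly representable in float32


def word_bytes_to_uniform(b0, b1, b2):
    """Combine the three low bytes of a word into a float32 uniform in [0, 1).

    Every intermediate is an integer below 2**24, so each float32 step is exact,
    and the final multiply by a power of two is exact as well.
    """
    f0 = np.asarray(b0, dtype=np.float32)
    f1 = np.asarray(b1, dtype=np.float32)
    f2 = np.asarray(b2, dtype=np.float32)
    mantissa = f0 + f1 * np.float32(256.0) + f2 * np.float32(65536.0)  # <= 2**24 - 1
    return mantissa * INV24


top = word_bytes_to_uniform(255, 255, 255)
print(float(top) == 1.0 - 2.0**-24)  # True: the largest output, still below 1.0
```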
The approach
Every 32-bit word in the kernel lives as four separate (P, 1) uint32 tiles. That
covers the four counter words, the four key words, the running Threefry state, and
all intermediate computation. The helper _b_split extracts four bytes from a
(P, 1) int32 tile immediately on load. Between helper calls, every state tile
holds a value in [0, 255]. Transient values inside the helpers peak at 511
(add) and 32640 (rotate), so nothing in the 20 Threefry rounds and 5 key
injections comes near 2^24.
The fused threefry_normal_kernel carries byte-tile state directly into the
Vector Engine Box-Muller stage. Output uniforms are SBUF-resident between the
GpSimd computation and the transcendental stage — no HBM round-trip between the
RNG and the transform. This is the four-engine pipeline described in the prior
post, now realized end to end: counter arithmetic on GpSimd, transcendentals on
the Vector Engine, SBUF-resident output available to downstream consumers such
as trnfft noise injection or the stochastic-trace estimators in trnblas.
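The math the Vector Engine stage performs can be sketched in numpy float32. box_muller is a hypothetical host-side function, not the kernel. This sketch computes 1 - u1 before taking the log. Because the 24-bit uniforms include 0 but never reach 1, that keeps the log argument in (0, 1]. This guard is the sketch's own choice; the kernel may guard log(0) differently.

```python
import numpy as np


def box_muller(u1, u2):
    """Map two arrays of U[0,1) float32 samples to two arrays of N(0,1) samples."""
    u1 = np.asarray(u1, dtype=np.float32)
    u2 = np.asarray(u2, dtype=np.float32)
    # 1 - u1 lies in (0, 1] for 24-bit uniforms, so the log is always finite.
    radius = np.sqrt(np.float32(-2.0) * np.log(np.float32(1.0) - u1))
    theta = np.float32(2.0 * np.pi) * u2
    return radius * np.cos(theta), radius * np.sin(theta)
```

Each pair of uniforms yields two independent normals, so one Threefry output word set feeds two Box-Muller evaluations.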
Threefry is stateless: outputs are a pure function of (counter, key). Like
Philox, this makes partition-axis splitting trivially correct. Each of 128
partition lanes holds an independent Threefry stream; no state synchronization
across lanes is needed. The batch index occupies the second counter word (c1),
giving disjoint counter ranges across calls by construction.
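A plain-Python Threefry4x32-20 reference makes the statelessness concrete. The structure follows the Random123 reference implementation from the SC'11 paper: the key-schedule parity constant 0x1BD11BDA, the rotation table, and a key injection after every fourth round. This sketch has not been checked against the published known-answer vectors here, so verify it before reusing it. lane_outputs and its (lane, batch, 0, 0) counter layout are hypothetical; the post only fixes the batch index in c1.

```python
MASK32 = 0xFFFFFFFF
SKEIN_KS_PARITY32 = 0x1BD11BDA
# Threefry4x32 rotation constants, indexed by round mod 8.
R_32X4 = [(10, 26), (11, 21), (13, 27), (23, 5), (6, 20), (17, 11), (25, 10), (18, 20)]


def rotl32(x: int, r: int) -> int:
    return ((x << r) | (x >> (32 - r))) & MASK32


def threefry4x32_20(ctr: tuple, key: tuple) -> tuple:
    """Pure function of (counter, key): four 32-bit words in, four out."""
    ks = list(key) + [SKEIN_KS_PARITY32]
    for k in key:
        ks[4] ^= k  # fifth key-schedule word: parity of the four key words
    x = [(ctr[j] + ks[j]) & MASK32 for j in range(4)]  # initial key injection
    for rnd in range(20):
        ra, rb = R_32X4[rnd % 8]
        if rnd % 2 == 0:
            x[0] = (x[0] + x[1]) & MASK32
            x[1] = rotl32(x[1], ra) ^ x[0]
            x[2] = (x[2] + x[3]) & MASK32
            x[3] = rotl32(x[3], rb) ^ x[2]
        else:
            x[0] = (x[0] + x[3]) & MASK32
            x[3] = rotl32(x[3], ra) ^ x[0]
            x[2] = (x[2] + x[1]) & MASK32
            x[1] = rotl32(x[1], rb) ^ x[2]
        if rnd % 4 == 3:  # key injection after every fourth round (5 in total)
            inj = (rnd + 1) // 4
            for j in range(4):
                x[j] = (x[j] + ks[(inj + j) % 5]) & MASK32
            x[3] = (x[3] + inj) & MASK32
    return tuple(x)


def lane_outputs(key: tuple, batch: int, n_lanes: int = 128) -> list:
    """Hypothetical layout: lane index in c0, batch index in c1."""
    return [threefry4x32_20((lane, batch, 0, 0), key) for lane in range(n_lanes)]
```

Every step (add, rotate, XOR, key injection) is invertible, so for a fixed key the cipher is a bijection on counters. Distinct (lane, batch) counters therefore can never produce the same output block, and any lane can be recomputed on its own.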
Implementation
The three core byte-tile helpers, from
trnrand/nki/dispatch.py:
```python
import neuronxcc.nki.language as nl


def _add32_b(a_b, b_b):
    """Carry-propagating 32-bit addition. Each intermediate <= 511.

    a_b, b_b: lists of four (P, 1) uint32 byte tiles, least-significant byte first.
    """
    s0 = nl.add(a_b[0], b_b[0], dtype=nl.uint32)  # <= 510
    c0 = nl.right_shift(s0, 8, dtype=nl.uint32)  # carry into byte 1
    r0 = nl.bitwise_and(s0, 0xFF, dtype=nl.uint32)
    s1 = nl.add(nl.add(a_b[1], b_b[1], dtype=nl.uint32), c0, dtype=nl.uint32)  # <= 511
    c1 = nl.right_shift(s1, 8, dtype=nl.uint32)
    r1 = nl.bitwise_and(s1, 0xFF, dtype=nl.uint32)
    s2 = nl.add(nl.add(a_b[2], b_b[2], dtype=nl.uint32), c1, dtype=nl.uint32)  # <= 511
    c2 = nl.right_shift(s2, 8, dtype=nl.uint32)
    r2 = nl.bitwise_and(s2, 0xFF, dtype=nl.uint32)
    s3 = nl.add(nl.add(a_b[3], b_b[3], dtype=nl.uint32), c2, dtype=nl.uint32)  # <= 511
    r3 = nl.bitwise_and(s3, 0xFF, dtype=nl.uint32)  # carry out of byte 3 dropped: mod 2^32
    return [r0, r1, r2, r3]
```
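A numpy mirror of _add32_b is handy as a CPU parity reference for the 128-lane kernel. add32_b_numpy, b_split_numpy, and b_join_numpy are hypothetical helpers written for this illustration. They assume the kernel's byte order is b[0] = least-significant byte, which is consistent with the output conversion below.

```python
import numpy as np


def b_split_numpy(x):
    """Split uint32 words into four little-endian byte arrays (values in [0, 255])."""
    x = np.asarray(x, dtype=np.uint32)
    return [(x >> np.uint32(8 * i)) & np.uint32(0xFF) for i in range(4)]


def b_join_numpy(b):
    """Reassemble four byte arrays into uint32 words."""
    return b[0] | (b[1] << np.uint32(8)) | (b[2] << np.uint32(16)) | (b[3] << np.uint32(24))


def add32_b_numpy(a_b, b_b):
    """Line-by-line CPU mirror of _add32_b on (P, 1) uint32 byte arrays."""
    s0 = np.add(a_b[0], b_b[0], dtype=np.uint32)  # <= 510
    c0, r0 = np.right_shift(s0, np.uint32(8)), np.bitwise_and(s0, np.uint32(0xFF))
    s1 = a_b[1] + b_b[1] + c0  # <= 511
    c1, r1 = np.right_shift(s1, np.uint32(8)), np.bitwise_and(s1, np.uint32(0xFF))
    s2 = a_b[2] + b_b[2] + c1  # <= 511
    c2, r2 = np.right_shift(s2, np.uint32(8)), np.bitwise_and(s2, np.uint32(0xFF))
    s3 = a_b[3] + b_b[3] + c2  # <= 511; its carry-out is discarded (mod 2**32)
    r3 = np.bitwise_and(s3, np.uint32(0xFF))
    return [r0, r1, r2, r3]
```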
Output conversion per word (unrolled x4 to satisfy the hardware compiler):
```python
# Word 0 of 4 (the block is unrolled once per output word).
# mantissa = b0 + b1*256 + b2*65536 (< 2^24, exact in float32), then scaled by 2^-24.
b = x0_b
out[:, 0:1] = nl.multiply(
    nl.add(
        nl.add(nl.copy(b[0], dtype=nl.float32),
               nl.multiply(nl.copy(b[1], dtype=nl.float32), _s256,
                           dtype=nl.float32), dtype=nl.float32),
        nl.multiply(nl.copy(b[2], dtype=nl.float32), _s65536,
                    dtype=nl.float32), dtype=nl.float32,
    ),
    inv24, dtype=nl.float32,  # inv24 = 1.0 / 16_777_216.0
)
```
The _rotl32_b helper covers all (q, r) combinations through fully unrolled
branches rather than a loop — 64 cases total. The reason is in the next section.
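What one of those unrolled branches computes can be modeled in plain Python for any rotation amount n = 8q + r. rotl32_bytes is a hypothetical host-side reference, not the kernel helper. It also returns its largest intermediate, so the 32640 bound is visible.

```python
MASK32 = 0xFFFFFFFF


def to_bytes(x: int) -> list:
    return [(x >> (8 * i)) & 0xFF for i in range(4)]


def from_bytes(b: list) -> int:
    return b[0] | (b[1] << 8) | (b[2] << 16) | (b[3] << 24)


def rotl32_bytes(b: list, n: int) -> tuple:
    """Rotate a little-endian byte word left by n bits (0 <= n < 32).

    n = 8*q + r: q whole bytes (pure reindexing) plus r residual bits (0..7).
    Returns (rotated bytes, largest intermediate value).
    """
    q, r = divmod(n, 8)
    y = [b[(i - q) % 4] for i in range(4)]  # byte i moves to position i + q
    if r == 0:
        return y, max(y)
    out, peak = [], 0
    for i in range(4):
        shifted = y[i] << r  # at most 255 * 2**7 = 32640
        spill = y[(i - 1) % 4] >> (8 - r)  # top r bits of the next-lower byte
        peak = max(peak, shifted)
        out.append((shifted & 0xFF) | spill)
    return out, peak
```

In the kernel, q and r are fixed per call site, which is why each combination can be emitted as its own straight-line branch.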
What didn't work
Three things: one at the algorithm level, and two from the NKI hardware compiler.
Fixing the Philox multiply works; fixing the Philox inputs doesn't. The
8-bit byte decomposition of the 32x32-bit multiply is correct and bit-exact.
The mistake was assuming that fixing the multiply would fix the output. The
root cause was one layer up, at the tile-load boundary, where counter inputs
above 2^24 were being silently rounded before any arithmetic began. An
nki.isa.tensor_copy path was also considered as a potential bypass around
the VE float32 cast; the AWS Neuron team's response on aws-neuron-sdk#1308
("VE/Scalar engines use FP32 casting for all 32-bit types") indicates this is
a systemic property of the current architecture rather than a single-op
workaround. Threefry removes the question. The outstanding upstream ask remains
specific: document the 2^24 exact-integer ceiling in the NKI type-casting
reference, and provide either a bitwise-exact nl.copy path for integer tiles
or a compile-time error when the cast truncates. Filed with a reproducer at
aws-neuron-sdk#1308.
The CPU simulator and the trn1 hardware compiler accept different Python constructs. Three categories of syntax are silently accepted by the simulator and rejected by the real compiler, found across three separate SSM hardware runs after the kernel passed all simulator tests:
- Inner function definitions inside any function in the @nki.jit call tree. Fix: extract to module level. The _mul32_hi_lo helper originally defined four inner functions; all became module-level helpers.
- List comprehensions ([expr for i in range(n)]) inside jit-traced code. Fix: explicit element-by-element construction. This is why _rotl32_b is fully unrolled rather than a compact loop.
- Subscript expressions as left-hand assignment targets in tuple unpacking. x_b_list[0], x_b_list[1] = _mix_b(...) is rejected; named-variable unpacking (x0_b, x1_b = _mix_b(...)) is accepted.
The pattern: where a construct touches Python's dynamic object model at the AST level, the real NKI compiler may not trace it. Use the simplest syntactic form available. None of these failures produced a useful error message on first encounter — each surfaced as a generic compile failure on the trn1 host that had no simulator analog. A specific ask for AWS: the hardware compiler's rejection message for inner function defs should name the outer function and the line number of the inner def; the current message does not.
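Here are the three rewrites side by side, in plain Python. Both versions compute the same thing in CPython; the only difference is syntactic shape. According to the list above, the trn1 compiler rejected the first form and accepted the second. The names mix_pair, body_rejected_style, and body_accepted_style are hypothetical, and none of this is NKI code.

```python
MASK32 = 0xFFFFFFFF


def mix_pair(a, b):
    """Hypothetical two-output helper standing in for _mix_b."""
    s = (a + b) & MASK32
    return s, s ^ b


def body_rejected_style(words):
    def double(v):  # (1) inner function definition
        return (2 * v) & MASK32

    out = [double(v) for v in words]  # (2) list comprehension
    out[0], out[1] = mix_pair(out[0], out[1])  # (3) subscript unpacking targets
    return out


def _double(v):  # (1) fix: module-level helper
    return (2 * v) & MASK32


def body_accepted_style(words):
    d0 = _double(words[0])  # (2) fix: explicit element-by-element construction
    d1 = _double(words[1])
    d2 = _double(words[2])
    d3 = _double(words[3])
    y0, y1 = mix_pair(d0, d1)  # (3) fix: named-variable unpacking
    return [y0, y1, d2, d3]
```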
NCC_IBIR605 (pre-existing, not a Threefry regression). The fused
threefry_normal_kernel Box-Muller stage is blocked on trn1 by the same
compiler restriction that blocked the standalone box_muller_kernel in 0.3.0:
InstActivation rejects non-immediate bias parameters when the activation is
Ln. The two affected hardware tests are marked xfail(strict=False) and
tracked in trnrand#2. This is
trn1-only; trn2+ and the CPU simulator are unaffected, and it has no bearing
on the Threefry algorithm or the uniform kernel.
Numbers
No throughput benchmark yet — that's 0.5 scope (trnrand#3), deferred until both uniform and normal have clean trn1 paths. Hardware correctness as of 0.4.0:
| Test | Simulator | trn1 hardware |
|---|---|---|
| KAT vectors (3 Salmon SC'11 reference vectors) | pass | pass |
| Reference parity (128-lane numpy vs NKI output) | pass | pass |
| U[0,1) distribution (mean 0.500 +/- 0.01, var 1/12 +/- 0.005) | pass | pass |
| Seed determinism + seed isolation | pass | pass |
| threefry_normal N(0,1) distribution | pass | xfail (NCC_IBIR605, trnrand#2) |
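The U[0,1) row's acceptance criterion can be sketched as a small host-side check. check_uniform_moments and the sample size are assumptions for illustration, and the usage feeds it 24-bit uniforms built from numpy's generator, not from Threefry. The tolerances come from the table.

```python
import numpy as np


def check_uniform_moments(u, mean_tol=0.01, var_tol=0.005):
    """True if samples lie in [0, 1) and their mean and variance match U[0,1).

    Defaults mirror the table: mean 0.500 +/- 0.01, variance 1/12 +/- 0.005.
    """
    u = np.asarray(u, dtype=np.float64)
    if u.size == 0 or u.min() < 0.0 or u.max() >= 1.0:
        return False
    mean_ok = abs(u.mean() - 0.5) <= mean_tol
    var_ok = abs(u.var() - 1.0 / 12.0) <= var_tol
    return bool(mean_ok and var_ok)


# Usage: 24-bit uniforms from stand-in random words (numpy's generator, not Threefry).
rng = np.random.default_rng(0)
words = rng.integers(0, 2**32, size=128 * 1024, dtype=np.uint64)
u = (words & 0xFFFFFF).astype(np.float32) / np.float32(2**24)
print(check_uniform_moments(u))         # True
print(check_uniform_moments(u * 0.62))  # False: mean near 0.31, like the broken Philox path
```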
Four of five TestThreefryNKI hardware cases pass. The fifth xfail is
pre-existing trn1 compiler behavior and does not affect the uniform kernel
or the Threefry algorithm on any other platform.
What's next
aws-neuron-sdk#1308 stays open. Philox remains the intended long-term primary for on-device RNG. It is a standard cuRAND generator and the basis of PyTorch's CUDA RNG, and it is stateless and partition-parallel by construction. (JAX, notably, defaults to Threefry.) Threefry is the production path until AWS ships a true uint32 integer-multiply primitive; at that point Philox hardware validation reopens.
trnrand#2 (NCC_IBIR605) —
the threefry_normal_kernel trn1 path unblocks when the trn1 compiler fix
ships. No workaround exists at the kernel level; trn2+ and the simulator path
are clean today.
trnrand#3: benchmarks vs cuRAND on equivalent generation sizes. Deferred to 0.5; the comparison is only useful once both uniform and normal generation have end-to-end correct on-device paths.
The gamma, chi-squared, beta, and Poisson distributions (added in 0.2.0, CPU paths only) still await on-device acceleration via NKI kernels built on top of the uniform primitive validated here.
Takeaway
The float32 exact-integer ceiling that blocks Philox on Trainium does not block Threefry, because Threefry was designed for exactly that constraint. Byte-tile arithmetic is not a workaround for the ceiling. It is the representation that keeps every intermediate more than 500 times below it: the largest, 32640 from a sub-byte rotation, sits against a ceiling of 2^24. The four-engine pipeline (GpSimd byte arithmetic, SBUF-resident intermediate output, Vector Engine transcendentals) is now hardware-validated for uniforms on trn1. The integer-multiply gap that blocked Philox did not block trnrand; it identified the algorithm the architecture already preferred.