
Mojo function

tma_tile_scales

tma_tile_scales[BN: Int](ctx: DeviceContext, ptr: UnsafePointer[Float32, MutAnyOrigin], total_elements: Int, out res: TMATensorTile[DType.float32, 2, IndexList(1, BN, Tuple())])

Create a TMA descriptor for per-token float32 scales.

The scales are a flat array of float32 values indexed by the same row_idx as the KV cache blocks. The descriptor is a 2D TMA descriptor with shape [1, total_elements] and tile shape [1, BN], so each async_copy loads BN contiguous float32 values (BN * 4 bytes) starting at the specified column offset.
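The tile arithmetic described above can be modeled on the host. The following is a minimal Python sketch, not the Mojo API: `tile_bytes` and `load_scale_tile` are hypothetical helper names. It assumes the column offset is given in elements, not bytes, and that the requested tile lies fully inside the array. It only illustrates which elements and how many bytes a single [1, BN] tile covers.

```python
FLOAT32_BYTES = 4  # size of one float32 scale value


def tile_bytes(bn):
    """Bytes moved by one [1, bn] tile of float32 scales."""
    return bn * FLOAT32_BYTES


def load_scale_tile(scales, col_offset, bn):
    """Host-side model of one tile load: bn contiguous scales from col_offset.

    Hypothetical helper for illustration only. It assumes col_offset is an
    element index and that the tile fits entirely inside the flat array.
    """
    if col_offset < 0 or col_offset + bn > len(scales):
        raise IndexError("tile exceeds the [1, total_elements] extent")
    return scales[col_offset:col_offset + bn]


# Flat scale array viewed as shape [1, total_elements], total_elements = 256.
scales = [float(i) for i in range(256)]

# One tile with BN = 64 starting at column 128 covers elements 128..191.
tile = load_scale_tile(scales, col_offset=128, bn=64)
print(len(tile), tile[0], tile[-1], tile_bytes(64))
```

Because the leading dimension is 1, only the column offset varies between loads. In this model, stepping the offset by BN visits consecutive, non-overlapping groups of scales.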

Returns:

TMATensorTile
