Mojo function
tma_tile_scales
tma_tile_scales[BN: Int](ctx: DeviceContext, ptr: UnsafePointer[Float32, MutAnyOrigin], total_elements: Int, out res: TMATensorTile[DType.float32, 2, IndexList(1, BN, Tuple())])
Create a TMA descriptor for per-token float32 scales.
The scales are stored as a flat array of float32 values, indexed by the same row_idx as the KV cache blocks. The function creates a 2D TMA descriptor with shape [1, total_elements] and tile shape [1, BN], so that each async_copy loads BN contiguous float32 values (BN * 4 bytes) starting at the specified column offset.
Returns:
TMATensorTile[DType.float32, 2, IndexList(1, BN, Tuple())]
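The addressing described above can be illustrated with a small host-side model. This is a plain Python sketch, not the Mojo API: the helper names `tile_bytes` and `load_scale_tile` are hypothetical, and the bounds check is a simplification that does not describe how the hardware treats out-of-range tiles. It only shows which elements a single [1, BN] tile covers and how many bytes that represents.

```python
FLOAT32_BYTES = 4  # size of one float32 scale value


def tile_bytes(bn: int) -> int:
    """Bytes covered by one [1, BN] tile of float32 scales."""
    return bn * FLOAT32_BYTES


def load_scale_tile(scales: list, bn: int, col_offset: int) -> list:
    """Model one tile copy (hypothetical helper, not the Mojo API).

    The flat scales array is viewed as a 2D tensor of shape
    [1, total_elements]; a tile of shape [1, BN] at column `col_offset`
    is the BN contiguous values starting at that index.
    """
    total_elements = len(scales)
    # Simplifying assumption: only in-bounds tiles are modeled here.
    if col_offset < 0 or col_offset + bn > total_elements:
        raise ValueError("tile must lie within [0, total_elements)")
    return scales[col_offset:col_offset + bn]


# 256 per-token scales; load the tile that starts at row_idx 128 with BN = 64.
scales = [float(i) for i in range(256)]
tile = load_scale_tile(scales, bn=64, col_offset=128)
print(len(tile), tile[0], tile[-1], tile_bytes(64))  # 64 128.0 191.0 256
```

Because the row dimension has size 1, the column offset is the only coordinate that varies between copies, which is why the same row_idx used for the KV cache blocks can be used directly as the column offset for the scales.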