Mojo function
tma_tile_scales
def tma_tile_scales[BN_QK: Int](ctx: DeviceContext, ptr: UnsafePointer[Float32, MutAnyOrigin], total_elements: Int, out res: TMATensorTile[DType.float32, Int(2), IndexList(Int(1), BN_QK, __list_literal__=NoneType(None))])
Create a TMA descriptor for per-token float32 scales.
The scales are a flat array of float32 values, indexed by the same row_idx as the KV cache blocks. We create a 2D TMA descriptor over a tensor of shape [1, total_elements] with a tile of [1, BN_QK]. Each async_copy therefore loads BN_QK contiguous float32 values (BN_QK * 4 bytes), starting at the specified column offset.
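The following minimal Python sketch models the addressing arithmetic on the host to show what a single tile copy covers. It is not the Mojo TMA API. `load_scale_tile` and `tile_bytes` are hypothetical helpers. The sketch assumes that the column offset passed to the copy is the starting row_idx of a KV block, and that the tile lies fully within the array. Out-of-bounds behavior is not modeled.

```python
import struct

FLOAT32_BYTES = 4  # size of one float32 scale


def load_scale_tile(scales: list[float], bn_qk: int, col_offset: int) -> list[float]:
    """Return the BN_QK scales one [1, BN_QK] tile copy would fetch.

    The flat scale array is viewed as a [1, total_elements] tensor, so the
    row coordinate is always 0 and only the column offset varies.
    Hypothetical helper: assumes the tile is fully in bounds.
    """
    total_elements = len(scales)
    if not 0 <= col_offset <= total_elements - bn_qk:
        raise ValueError("tile out of bounds (not modeled in this sketch)")
    # BN_QK contiguous values starting at the column offset.
    return scales[col_offset : col_offset + bn_qk]


def tile_bytes(tile: list[float]) -> int:
    """Bytes occupied by the tile when packed as little-endian float32."""
    return len(struct.pack(f"<{len(tile)}f", *tile))


if __name__ == "__main__":
    scales = [float(i) for i in range(256)]  # one scale per token row
    tile = load_scale_tile(scales, bn_qk=64, col_offset=128)
    print(tile[0], tile[-1], tile_bytes(tile))
```

Suppose there are 256 scales, BN_QK = 64, and a block that starts at row_idx 128. One copy then returns the scales for rows 128 through 191. That is 64 * 4 = 256 bytes.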
Returns:
TMATensorTile[DType.float32, Int(2), IndexList(Int(1), BN_QK, __list_literal__=NoneType(None))]