Mojo struct
GlobalMemoryManager
struct GlobalMemoryManager[dtype: DType, BM: UInt32, BN: UInt32, BK: UInt32, depth: UInt32, num_heads: UInt32, group: UInt32, token_gen: Bool, q_depth: UInt32 = depth, output_depth: UInt32 = depth]
Fields
- q_offset (UInt32)
- output_offset (UInt32)
- valid_rows (UInt32)
Implemented traits
AnyType, ImplicitlyDestructible
comptime members
kv_gmem_layout
comptime kv_gmem_layout = Layout(IntTuple(Int[UInt32](BN), Int[UInt32](depth)), IntTuple(Int[UInt32]((GlobalMemoryManager[dtype, BM, BN, BK, depth, num_heads, group, token_gen, q_depth, output_depth].kv_num_heads * depth)), 1))
kv_num_heads
comptime kv_num_heads = (num_heads // group)
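The two definitions above boil down to simple index arithmetic. As a rough illustration only, here is a plain-Python model (not Mojo, and not the library's implementation) of the shape and strides that kv_gmem_layout describes. The helper names kv_gmem_layout_model and linear_offset are hypothetical and exist only for this sketch.

```python
def kv_gmem_layout_model(BN, depth, num_heads, group):
    """Model of kv_gmem_layout as (shape, stride) tuples.

    Hypothetical helper: mirrors the documented expression
    Layout((BN, depth), (kv_num_heads * depth, 1)).
    """
    kv_num_heads = num_heads // group      # documented: num_heads // group
    shape = (BN, depth)                    # rows x head depth
    stride = (kv_num_heads * depth, 1)     # next row skips all KV heads
    return shape, stride


def linear_offset(coord, stride):
    """Element offset of a coordinate under a strided layout."""
    return sum(c * s for c, s in zip(coord, stride))


# Example values are assumptions chosen for illustration.
shape, stride = kv_gmem_layout_model(BN=64, depth=128, num_heads=32, group=4)
print(shape, stride)                   # (64, 128) (1024, 1)
print(linear_offset((2, 5), stride))   # 2 * 1024 + 5 = 2053
```

With 32 query heads and a group size of 4 there are 8 KV heads, so consecutive rows of one head's tile sit 8 * 128 = 1024 elements apart.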
KvTileLayout
comptime KvTileLayout = Layout[*?, *?]
output_gmem_layout
comptime output_gmem_layout = Layout(IntTuple(Int[UInt32](BM), Int[UInt32](output_depth)), IntTuple(Int[UInt32]((num_heads * output_depth)), 1)) if not token_gen else Layout.row_major(Int[UInt32](BM), Int[UInt32](output_depth))
OutputTileLayout
comptime OutputTileLayout = Layout[*?, *?]
q_gmem_layout
comptime q_gmem_layout = Layout(IntTuple(Int[UInt32](BM), Int[UInt32](q_depth)), IntTuple(Int[UInt32]((num_heads * q_depth)), 1)) if not token_gen else Layout.row_major(Int[UInt32](BM), Int[UInt32](q_depth))
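Both q_gmem_layout and output_gmem_layout switch on token_gen in the same way. The sketch below models that choice in plain Python (not Mojo). The function name gmem_layout_model is hypothetical, and the row-major branch relies on the usual meaning of a row-major layout, namely a row stride equal to the number of columns.

```python
def gmem_layout_model(BM, d, num_heads, token_gen):
    """Model of q_gmem_layout / output_gmem_layout as (shape, stride).

    d stands for q_depth or output_depth. Hypothetical helper that
    mirrors the documented conditional expression.
    """
    shape = (BM, d)
    if token_gen:
        # Layout.row_major(BM, d): rows are packed back to back.
        stride = (d, 1)
    else:
        # Rows of one head are separated by the data of all heads.
        stride = (num_heads * d, 1)
    return shape, stride


# Example values are assumptions chosen for illustration.
print(gmem_layout_model(128, 128, 32, token_gen=False))  # ((128, 128), (4096, 1))
print(gmem_layout_model(128, 128, 32, token_gen=True))   # ((128, 128), (128, 1))
```

The shape is the same in both modes; only the row stride changes.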
QTileLayout
comptime QTileLayout = Layout[*?, *?]
Methods
__init__
__init__(out self, q_tile_idx: UInt32, kv_head_idx: UInt32, seq_len: Int, q_offset: UInt32, output_offset: UInt32)
get_q_tensor
get_q_tensor[qtype: DType](self, ptr: UnsafePointer[Scalar[qtype], ImmutAnyOrigin]) -> LayoutTensor[qtype, GlobalMemoryManager[dtype, BM, BN, BK, depth, num_heads, group, token_gen, q_depth, output_depth].q_gmem_layout, ImmutAnyOrigin, layout_int_type=DType.int32, linear_idx_type=DType.int32, masked=True]
Returns:
LayoutTensor: A masked LayoutTensor over q_gmem_layout.
get_q_tile
get_q_tile[qtype: DType](self, ptr: UnsafePointer[Scalar[qtype], ImmutAnyOrigin]) -> TileTensor[qtype, Layout[*?, *?], ImmutAnyOrigin]
Return the Q DRAM tile as a TileTensor with RuntimeInt valid_rows.
Args:
- ptr (UnsafePointer): Base pointer to the Q buffer.
Returns:
TileTensor: A TileTensor with RuntimeInt rows and ComptimeInt strides.
get_output_tensor
get_output_tensor[out_type: DType](self, ptr: UnsafePointer[Scalar[out_type], MutAnyOrigin]) -> LayoutTensor[out_type, GlobalMemoryManager[dtype, BM, BN, BK, depth, num_heads, group, token_gen, q_depth, output_depth].output_gmem_layout, MutAnyOrigin, layout_int_type=DType.int32, linear_idx_type=DType.int32, masked=True]
Returns:
LayoutTensor: A masked LayoutTensor over output_gmem_layout.
get_output_tile
get_output_tile[out_type: DType](self, ptr: UnsafePointer[Scalar[out_type], MutAnyOrigin]) -> TileTensor[out_type, Layout[*?, *?], MutAnyOrigin]
Return the output DRAM tile as a TileTensor with RuntimeInt valid_rows.
The RuntimeInt dim[0] ensures that make_amd_buffer_resource computes correct out-of-bounds (OOB) clamping bounds when the tile extends past the valid data.
Args:
- ptr (UnsafePointer): Base pointer to the output buffer.
Returns:
TileTensor: A TileTensor with RuntimeInt rows and ComptimeInt strides.
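To see why a runtime row count matters for bounds clamping, consider the following plain-Python sketch (not Mojo). It assumes the addressable extent of a tile is derived from its row count and row stride. That is an assumption made for illustration, not the documented behavior of make_amd_buffer_resource, and the helper names are hypothetical.

```python
def tile_extent_elems(rows, cols, row_stride):
    """Number of elements spanned by a rows x cols tile (assumed model)."""
    if rows == 0:
        return 0
    return (rows - 1) * row_stride + cols


def in_bounds(row, col, row_stride, extent):
    """True if the element's linear offset falls inside the extent."""
    return row * row_stride + col < extent


# Assumed example: BM = 128 tile rows, only 100 of them valid.
BM, depth, row_stride, valid_rows = 128, 128, 4096, 100

static_extent = tile_extent_elems(BM, depth, row_stride)           # comptime rows
runtime_extent = tile_extent_elems(valid_rows, depth, row_stride)  # runtime rows

# Row 100 is past the valid data. Only the runtime extent rejects it.
print(in_bounds(100, 0, row_stride, static_extent))    # True
print(in_bounds(100, 0, row_stride, runtime_extent))   # False
print(in_bounds(99, 127, row_stride, runtime_extent))  # True
```

Under this model, an extent built from the compile-time tile height would admit rows beyond valid_rows, while one built from the runtime row count excludes them.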
get_kv_tensor
get_kv_tensor[kvtype: DType, //](self, ptr: UnsafePointer[Scalar[kvtype], ImmutAnyOrigin], kv_tile_num_rows: UInt32) -> LayoutTensor[kvtype, GlobalMemoryManager[dtype, BM, BN, BK, depth, num_heads, group, token_gen, q_depth, output_depth].kv_gmem_layout, ImmutAnyOrigin, masked=True]
Returns:
LayoutTensor: A masked LayoutTensor over kv_gmem_layout.
get_kv_tile
get_kv_tile[kvtype: DType, //](self, ptr: UnsafePointer[Scalar[kvtype], ImmutAnyOrigin], kv_tile_num_rows: UInt32) -> TileTensor[kvtype, Layout[*?, *?], ImmutAnyOrigin]
Return the KV DRAM tile as a TileTensor with RuntimeInt valid_rows.
The RuntimeInt dim[0] ensures that make_amd_buffer_resource computes correct out-of-bounds (OOB) clamping bounds when the tile extends past the valid data.
Args:
- ptr (UnsafePointer): Base pointer to the KV cache buffer.
- kv_tile_num_rows (UInt32): Number of valid rows in this tile.
Returns:
TileTensor: A TileTensor with RuntimeInt rows and ComptimeInt strides.
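A caller has to supply kv_tile_num_rows for each tile. One plausible way to obtain it, sketched in plain Python (not Mojo), is to clamp the tile height BN against the remaining rows. The helper kv_tile_rows and the variable num_keys are hypothetical and are not taken from the library.

```python
def kv_tile_rows(num_keys, BN):
    """Yield (tile_start, valid_rows) for each BN-row tile over num_keys rows.

    Hypothetical caller-side helper: every tile is full except possibly
    the last, which holds only the remaining rows.
    """
    for start in range(0, num_keys, BN):
        yield start, min(BN, num_keys - start)


# Assumed example: 150 KV rows split into tiles of BN = 64 rows.
print(list(kv_tile_rows(150, 64)))  # [(0, 64), (64, 64), (128, 22)]
```

In this example, the last tile would be passed kv_tile_num_rows = 22 rather than 64.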