Mojo struct
VBufferTransposeLoads
struct VBufferTransposeLoads[dtype: DType, kv_tile_layout: TensorLayout, //, tensor_core_mma: TiledTensorCore[tensor_core_mma.out_type, tensor_core_mma.in_type, tensor_core_mma.shape, tensor_core_mma.group_size, tensor_core_mma.transpose_b], BN: Int, BK: Int, depth: Int, num_threads: Int, num_stages: Int = 1]
Fields
- load_tile (VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].LoadTile)
- mma_tile (VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].MmaTile)
- smem_ptr (UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED])
- gmem_tile (VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].GmemTileType)
- reg_loader (VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].RegLoaderType)
- tile_idx (Int)
- current_stage (Int)
Implemented traits
AnyType,
ImplicitlyDestructible,
KVBuffer
comptime members
depth_tile_size
comptime depth_tile_size = min(depth, 128)
GmemTileType
comptime GmemTileType = TileTensor[dtype, kv_tile_layout, ImmutAnyOrigin]
load_width
comptime load_width = 4 if (depth == 64) else VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].simd_width
loads_per_thread_per_depth_tile
comptime loads_per_thread_per_depth_tile = ((VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].depth_tile_size * BK) // (VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].load_width * num_threads))
LoadTile
comptime LoadTile = TileTensor[dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
MMA_K
comptime MMA_K = shape[2]
MMA_M
comptime MMA_M = shape[0]
mma_tile_layout
comptime mma_tile_layout = Layout.row_major((depth // VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].MMA_M), VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].simd_width)
MmaTile
comptime MmaTile = TileTensor[dtype, Layout[*?, *?], MutExternalOrigin, address_space=AddressSpace.LOCAL]
num_depth_tiles
comptime num_depth_tiles = (depth // VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].MMA_M)
num_k_tiles
comptime num_k_tiles = ceildiv(BK, (VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].MMA_K * group_size))
num_repeats
comptime num_repeats = (BK // VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].simd_width)
RegLoaderType
comptime RegLoaderType = RegTileLoader[dtype, row_major[4, (VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].depth_tile_size // VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].load_width)](), warp_scope=True]
simd_width
comptime simd_width = simd_width_of[dtype]()
Methods
__init__
__init__(out self, gmem_tile: TileTensor[dtype, kv_tile_layout, ImmutAnyOrigin], shared_ptr: UnsafePointer[Scalar[dtype], MutAnyOrigin, address_space=AddressSpace.SHARED])
pad
load_from_dram
load_from_dram(mut self)
get_mma_tile
get_mma_tile(self) -> LayoutTensor[dtype, VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].mma_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]
Returns:
copy_to_shared
copy_to_shared[tile_id: Int = 0](self)
load_from_shared
load_from_shared[k_mma: Int](self)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!