Mojo struct
VBufferTransposeLoads
struct VBufferTransposeLoads[out_type: DType, in_type: DType, shape: IndexList[3], group_size: Int, transpose_b: Bool, mut: Bool, dtype: DType, layout: Layout, address_space: AddressSpace, alignment: Int, origin: Origin[mut], masked: Bool, layout_int_type: DType, linear_idx_type: DType, //, tensor_core_mma: TiledTensorCore[out_type, in_type, shape, group_size, transpose_b], BN: Int, BK: Int, depth: Int, num_threads: Int, num_stages: Int = 1]
Fields
- load_tile (LayoutTensor[dtype, Layout.row_major(((VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].loads_per_thread_per_depth_tile * (depth // VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].depth_tile_size)) * num_stages), VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].load_width), MutAnyOrigin, address_space=AddressSpace.LOCAL])
- mma_tile (LayoutTensor[dtype, VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].mma_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL])
- smem_iter (LayoutTensorIter[dtype, VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, circular=True])
- global_iterator (LayoutTensorIter[dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, Layout(IntTuple(1), IntTuple(1)), layout_int_type, linear_idx_type, masked, alignment, BK, depth]()[0], origin, address_space=address_space, axis=0, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[layout, BK, depth]()])
- global_base_tile (LayoutTensor[dtype, layout, origin, address_space=address_space, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment])
- current_stage (Int)
Implemented traits
AnyType, KVBuffer, UnknownDestructibility
Aliases
__del__is_trivial
comptime __del__is_trivial = True
base_layout
comptime base_layout = Layout.row_major(VBufferTransposeLoads.pad[out_type, in_type, shape, group_size, transpose_b, mut, dtype, layout, address_space, alignment, origin, masked, layout_int_type, linear_idx_type, tensor_core_mma, BN, BK, depth, num_threads, num_stages, depth](), VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].simd_width)
depth_tile_size
comptime depth_tile_size = min(depth, 128)
GlobalTensorType
comptime GlobalTensorType = LayoutTensor[dtype, layout, origin, address_space=address_space, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment]
GlobalTiledIteratorType
comptime GlobalTiledIteratorType = LayoutTensorIter[dtype, LayoutTensor._compute_tile_layout[mut, dtype, layout, origin, address_space, Layout(IntTuple(1), IntTuple(1)), layout_int_type, linear_idx_type, masked, alignment, BK, depth]()[0], origin, address_space=address_space, axis=0, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked if masked else _tile_is_masked[layout, BK, depth]()]
load_width
comptime load_width = 4 if (depth == 64) else VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].simd_width
loads_per_thread_per_depth_tile
comptime loads_per_thread_per_depth_tile = ((VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].depth_tile_size * BK) // (VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].load_width * num_threads))
LoadTileType
comptime LoadTileType = LayoutTensor[dtype, Layout.row_major(((VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].loads_per_thread_per_depth_tile * (depth // VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].depth_tile_size)) * num_stages), VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].load_width), MutAnyOrigin, address_space=AddressSpace.LOCAL]
MMA_K
comptime MMA_K = shape.__getitem__[3, DType.int64, Int](2)
MMA_M
comptime MMA_M = shape.__getitem__[3, DType.int64, Int](0)
mma_tile_layout
comptime mma_tile_layout = Layout.row_major((depth // VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].MMA_M), VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].simd_width)
MMATileType
comptime MMATileType = LayoutTensor[dtype, VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].mma_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]
num_depth_tiles
comptime num_depth_tiles = (depth // VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].MMA_M)
num_k_tiles
comptime num_k_tiles = ceildiv(BK, (VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].MMA_K * group_size))
num_repeats
comptime num_repeats = (BK // VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].simd_width)
SharedIterType
comptime SharedIterType = LayoutTensorIter[dtype, VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, circular=True]
SharedTileType
comptime SharedTileType = LayoutTensorIter[dtype, VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].smem_layout, MutAnyOrigin, address_space=AddressSpace.SHARED, circular=True].LayoutTensorType
simd_width
comptime simd_width = simd_width_of[dtype]()
smem_layout
comptime smem_layout = blocked_product(VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].base_layout, VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].tiler_layout, True)
tiler_layout
comptime tiler_layout = Layout.row_major(1, VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].num_repeats)
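As a worked example of how these aliases resolve (all parameter values below are assumptions for illustration, not defaults): with dtype = bfloat16, an MMA shape of (16, 16, 16), group_size = 1, BK = 32, depth = 128, num_threads = 256, num_stages = 2, and simd_width_of[DType.bfloat16]() == 8 on the target, the compile-time arithmetic works out as follows.

```mojo
from math import ceildiv

# Worked example with assumed (not default) parameters:
#   dtype = bfloat16, MMA shape = (16, 16, 16), group_size = 1,
#   BK = 32, depth = 128, num_threads = 256, num_stages = 2,
#   simd_width_of[DType.bfloat16]() == 8 on the target.
alias simd_width = 8
alias depth_tile_size = min(128, 128)                                    # = 128
alias load_width = simd_width                                            # depth != 64, so = 8
alias loads_per_thread_per_depth_tile = (depth_tile_size * 32) // (load_width * 256)  # = 2
alias num_depth_tiles = 128 // 16                                        # depth // MMA_M = 8
alias num_k_tiles = ceildiv(32, 16 * 1)                                  # ceildiv(BK, MMA_K * group_size) = 2
alias num_repeats = 32 // simd_width                                     # BK // simd_width = 4
# mma_tile_layout  = Layout.row_major(8, 8)
# load_tile layout = Layout.row_major(2 * (128 // 128) * 2, 8) = row_major(4, 8)
# tiler_layout     = Layout.row_major(1, 4)
```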
Methods
__init__
__init__(out self, global_tile: LayoutTensor[dtype, layout, origin, address_space=address_space, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], shared_ptr: LegacyUnsafePointer[Scalar[dtype], address_space=AddressSpace.SHARED, mut=mut, origin=origin])
get_dtype
pad
load_from_dram
load_from_dram(mut self)
get_mma_tile
get_mma_tile(self) -> LayoutTensor[dtype, VBufferTransposeLoads[tensor_core_mma, BN, BK, depth, num_threads, num_stages].mma_tile_layout, MutAnyOrigin, address_space=AddressSpace.LOCAL]
Returns: The MMA tile (MMATileType) in local memory.
copy_to_shared
copy_to_shared[tile_id: Int = 0](self)
load_from_shared
load_from_shared[k_mma: Int](self)
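A minimal usage sketch under assumed parameters follows. The kernel context, the mma alias bound to a TiledTensorCore instance, the v_global_tile LayoutTensor, and the v_smem_ptr shared-memory pointer are all hypothetical, and the call order is inferred from the method names rather than documented behavior.

```mojo
from gpu import barrier

# Hypothetical sketch inside a GPU kernel. `mma` is assumed to be an alias bound
# to a TiledTensorCore instance; `v_global_tile` a LayoutTensor over the V tile
# in global memory; `v_smem_ptr` a shared-memory pointer of the matching dtype.
var v_buffer = VBufferTransposeLoads[
    tensor_core_mma=mma, BN=64, BK=32, depth=128, num_threads=256, num_stages=2
](v_global_tile, v_smem_ptr)

# Stage one BK x depth tile: global memory -> per-thread registers -> shared memory.
v_buffer.load_from_dram()
v_buffer.copy_to_shared[0]()
barrier()

# Consume the staged tile one K sub-tile at a time.
@parameter
for k_mma in range(2):  # num_k_tiles = 2 in the worked example above
    v_buffer.load_from_shared[k_mma]()
    var v_frag = v_buffer.get_mma_tile()
    # ... feed v_frag to the tensor-core MMA ...
```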