
Mojo module

attention

comptime values

ImmutTileTensor1D

comptime ImmutTileTensor1D[dtype: DType] = TileTensor[dtype, Layout[*?, *?], ImmutAnyOrigin]

Parameters

dtype (DType): Element data type of the tensor.

KVTMATile

comptime KVTMATile[dtype: DType, swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int] = TMATensorTile[dtype, 3, _padded_shape[3, dtype, IndexList(BN, 1, BK, __list_literal__=Tuple()), swizzle_mode](), _ragged_shape[3, dtype, IndexList(BN, 1, BK, __list_literal__=Tuple()), swizzle_mode]()]

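Parameters

dtype (DType): Element data type of the tile.
swizzle_mode (TensorMapSwizzle): Swizzle mode used for the TMA tensor map.
BN (Int): First extent of the 3-D tile shape (BN, 1, BK).
BK (Int): Last extent of the 3-D tile shape (BN, 1, BK).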

QTMATile

comptime QTMATile[dtype: DType, swizzle_mode: TensorMapSwizzle, *, BM: Int, depth: Int, group: Int, decoding: Bool, fuse_gqa: Bool = False, num_qk_stages: Int = 1] = TMATensorTile[dtype, 4 if decoding or fuse_gqa else 3, _padded_shape[4 if decoding or fuse_gqa else 3, dtype, q_smem_shape[dtype, swizzle_mode, BM=BM, group=group, depth=depth, decoding=decoding, fuse_gqa=fuse_gqa, num_qk_stages=num_qk_stages](), swizzle_mode](), _ragged_shape[4 if decoding or fuse_gqa else 3, dtype, q_smem_shape[dtype, swizzle_mode, BM=BM, group=group, depth=depth, decoding=decoding, fuse_gqa=fuse_gqa, num_qk_stages=num_qk_stages](), swizzle_mode]()]
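The rank of the QTMATile tile depends on two flags: it is 4 when decoding or fuse_gqa is True, and 3 otherwise. The sketch below is a hypothetical helper, q_tma_rank, that is not part of the attention module. It evaluates the same expression at run time so the rule is easy to check. In the real alias, the expression is a compile-time parameter.

```mojo
# Hypothetical helper, not part of the attention module. It mirrors the
# `4 if decoding or fuse_gqa else 3` rank expression in QTMATile.
fn q_tma_rank(decoding: Bool, fuse_gqa: Bool = False) -> Int:
    # Either flag selects a 4-D TMA tile; with both False the tile is 3-D.
    return 4 if decoding or fuse_gqa else 3


def main():
    print(q_tma_rank(decoding=False))                 # prints 3
    print(q_tma_rank(decoding=True))                  # prints 4
    print(q_tma_rank(decoding=False, fuse_gqa=True))  # prints 4
```

Because fuse_gqa defaults to False, a non-decoding QTMATile with default parameters uses a 3-D tile.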

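Parameters

dtype (DType): Element data type of the tile.
swizzle_mode (TensorMapSwizzle): Swizzle mode used for the TMA tensor map.
BM (Int), depth (Int), group (Int): Forwarded to q_smem_shape to compute the tile shape.
decoding (Bool): If True, the tile is 4-D instead of 3-D.
fuse_gqa (Bool, default False): If True, the tile is 4-D instead of 3-D.
num_qk_stages (Int, default 1): Forwarded to q_smem_shape.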

Structs

Traits

Functions
