Mojo module
attention
comptime values
ImmutTileTensor1D
comptime ImmutTileTensor1D[dtype: DType] = TileTensor[dtype, Layout[#kgen.variadic.reduce(#kgen.variadic.tabulate(len[IntTuple](Layout.row_major(VariadicList(-1)).shape), [idx: __mlir_type.index] _int_to_dim(Layout.row_major(VariadicList(-1)).shape[idx].value())), base=, reducer=[PrevV: Variadic[CoordLike], VA: Variadic[Dim], idx: __mlir_type.index] #kgen.variadic.concat(PrevV, ComptimeInt[VA[idx]._value_or_missing] if (VA[idx] != -31337) else RuntimeInt[DType.int64])), #kgen.variadic.reduce(#kgen.variadic.tabulate(len[IntTuple](Layout.row_major(VariadicList(-1)).stride), [idx: __mlir_type.index] _int_to_dim(Layout.row_major(VariadicList(-1)).stride[idx].value())), base=, reducer=[PrevV: Variadic[CoordLike], VA: Variadic[Dim], idx: __mlir_type.index] #kgen.variadic.concat(PrevV, ComptimeInt[VA[idx]._value_or_missing] if (VA[idx] != -31337) else RuntimeInt[DType.int64]))], ImmutAnyOrigin]
Parameters
- dtype (DType)
KVTMATile
comptime KVTMATile[dtype: DType, swizzle_mode: TensorMapSwizzle, *, BN: Int, BK: Int] = TMATensorTile[dtype, 3, _padded_shape[3, dtype, IndexList(VariadicList(BN, 1, BK), Tuple()), swizzle_mode](), _ragged_shape[3, dtype, IndexList(VariadicList(BN, 1, BK), Tuple()), swizzle_mode]()]
Parameters
- dtype (DType)
- swizzle_mode (TensorMapSwizzle)
- BN (Int)
- BK (Int)
QTMATile
comptime QTMATile[dtype: DType, swizzle_mode: TensorMapSwizzle, *, BM: Int, depth: Int, group: Int, decoding: Bool, num_qk_stages: Int = 1] = TMATensorTile[dtype, 4 if decoding else 3, _padded_shape[4 if decoding else 3, dtype, q_smem_shape[dtype, swizzle_mode, BM=BM, group=group, depth=depth, decoding=decoding, num_qk_stages=num_qk_stages](), swizzle_mode](), _ragged_shape[4 if decoding else 3, dtype, q_smem_shape[dtype, swizzle_mode, BM=BM, group=group, depth=depth, decoding=decoding, num_qk_stages=num_qk_stages](), swizzle_mode]()]
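Parameters
- dtype (DType)
- swizzle_mode (TensorMapSwizzle)
- BM (Int)
- depth (Int)
- group (Int)
- decoding (Bool)
- num_qk_stages (Int), defaults to 1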
Structs
- MHAPosition: Position of the MHA-kernel. When decoding=False, q_head_stride == q_num_heads. When decoding=True, q_head_stride == 1.
- NonNullPointer
- NullPointer
- Pack
- PositionSummary
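The q_head_stride rule stated for MHAPosition can be written as a small sketch. This is plain Python for illustration only, not the module's Mojo API: the helper name expected_q_head_stride is hypothetical, and only the two equalities are taken from the description above.

```python
def expected_q_head_stride(decoding: bool, q_num_heads: int) -> int:
    """Hypothetical helper mirroring the rule in the MHAPosition description.

    It does not call or reproduce the Mojo implementation.
    """
    if q_num_heads <= 0:
        raise ValueError("q_num_heads must be positive")
    # decoding=True  -> q_head_stride == 1
    # decoding=False -> q_head_stride == q_num_heads
    return 1 if decoding else q_num_heads
```

For example, with 32 query heads the sketch gives a stride of 32 when decoding=False and a stride of 1 when decoding=True.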
Traits
Functions
- get_q_head_idx
- get_seq_info
- kv_coord
- output_reg_to_smem
- output_reg_to_smem_st_matrix
- produce
- q_coord: Returns the coordinates for a TMA load on the Q matrix. This load can be 3D, 4D, or 5D.
- q_gmem_shape
- q_smem_shape
- q_tma