
Mojo function

q_tma

q_tma[dtype: DType, //, swizzle_mode: TensorMapSwizzle, *, BM: Int, depth: Int, q_num_heads: Int, group: Int, decoding: Bool, num_qk_stages: Int = 1](ctx: DeviceContext, ptr: UnsafePointer[Scalar[dtype], ptr.origin], rows: Int) -> TMATensorTile[dtype, 4 if decoding else 3, _padded_shape[4 if decoding else 3, dtype, q_smem_shape[dtype, swizzle_mode, BM=BM, group=group, depth=depth, decoding=decoding, num_qk_stages=num_qk_stages](), swizzle_mode](), _ragged_shape[4 if decoding else 3, dtype, q_smem_shape[dtype, swizzle_mode, BM=BM, group=group, depth=depth, decoding=decoding, num_qk_stages=num_qk_stages](), swizzle_mode]()]

Returns:

TMATensorTile: a tile of rank 4 when decoding is True, otherwise rank 3.
