Mojo function
amd_4wave_scheduled_matmul
amd_4wave_scheduled_matmul[a_type: DType, b_type: DType, c_type: DType, //, enable_swizzle: Bool = True, block_m_override: Int = 0, block_n_override: Int = 0, dump_asm_path: StringSlice[StaticConstantOrigin] = StringSlice("")](a: TileTensor[a_type, address_space=a.address_space, linear_idx_type=a.linear_idx_type, element_size=a.element_size], b: TileTensor[b_type, address_space=b.address_space, linear_idx_type=b.linear_idx_type, element_size=b.element_size], c: TileTensor[c_type, address_space=c.address_space, linear_idx_type=c.linear_idx_type, element_size=c.element_size], ctx: DeviceContext)
Launches the schedule-compiler-driven 4-wave matmul on the device.
Dispatch is identical to amd_4wave_matmul (same auto-pick
heuristic, same override gates, same chiplet/L2 swizzle, same 1D launch
grid), except that this function invokes AMD4WaveMatmul.run with the
comptime flag use_framework_schedule=True. Use it as the
framework arm of an A/B comparison against the inline arm to determine
whether a performance gap comes from op ordering or from scaffolding.
Parameters:
- a_type (DType): Element type of a.
- b_type (DType): Element type of b.
- c_type (DType): Element type of c.
- enable_swizzle (Bool): Enable LDS bank-conflict avoidance.
- block_m_override (Int): If > 0, force BM to this value (must be 64 or 128).
- block_n_override (Int): If > 0, force BN to this value (must be 64, 128, or 256). Default 0 uses BM=BN.
- dump_asm_path (StringSlice[StaticConstantOrigin]): If non-empty, dumps the compiled GCN assembly to the given file path. Only used for ASM-level diff-debugging.
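The override rules for block_m_override and block_n_override can be modeled on the host side. The following is a minimal Python sketch of those rules as documented above, not the library's implementation. The function name resolve_block_shape and the auto_bm argument are hypothetical: auto_bm stands in for whatever BM the auto-pick heuristic selects, which this page does not describe. Two further assumptions are made: an invalid override is reported as a ValueError (the real function may reject it at compile time instead), and BN falls back to BM even when BM itself was auto-picked.

```python
# Python model of the documented override gates (illustrative only).
ALLOWED_BM = (64, 128)        # documented legal values for block_m_override
ALLOWED_BN = (64, 128, 256)   # documented legal values for block_n_override


def resolve_block_shape(auto_bm, block_m_override=0, block_n_override=0):
    """Return (BM, BN) according to the documented parameter rules.

    auto_bm is a hypothetical placeholder for the heuristic's choice of BM.
    """
    if block_m_override > 0:
        if block_m_override not in ALLOWED_BM:
            raise ValueError("block_m_override must be 64 or 128")
        bm = block_m_override
    else:
        bm = auto_bm  # no override: keep the auto-picked value

    if block_n_override > 0:
        if block_n_override not in ALLOWED_BN:
            raise ValueError("block_n_override must be 64, 128, or 256")
        bn = block_n_override
    else:
        bn = bm  # documented default: 0 uses BM=BN

    return bm, bn
```

For example, under these assumptions resolve_block_shape(128, block_n_override=256) yields a 128x256 block, while leaving both overrides at 0 yields a square block of the auto-picked size.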
Args:
- a (TileTensor[a_type, address_space=a.address_space, linear_idx_type=a.linear_idx_type, element_size=a.element_size]): Input tile-tensor for A.
- b (TileTensor[b_type, address_space=b.address_space, linear_idx_type=b.linear_idx_type, element_size=b.element_size]): Input tile-tensor for B.
- c (TileTensor[c_type, address_space=c.address_space, linear_idx_type=c.linear_idx_type, element_size=c.element_size]): Output tile-tensor for C.
- ctx (DeviceContext): Device context used to enqueue the kernel.
Raises:
An error if device enqueue fails.