Mojo function
matmul
matmul[use_tf32: Bool = False](ctx: DeviceContext, c: NullableTileTensor[c.dtype, c.LayoutType, c.origin, address_space=c.address_space, linear_idx_type=c.linear_idx_type, element_size=c.element_size], a: TileTensor[a.dtype, a.LayoutType, a.origin, address_space=a.address_space, linear_idx_type=a.linear_idx_type, element_size=a.element_size], b: TileTensor[b.dtype, b.LayoutType, b.origin, address_space=b.address_space, linear_idx_type=b.linear_idx_type, element_size=b.element_size], *, c_row_major: Bool = False, transpose_a: Bool = False, transpose_b: Bool = False, alpha: Float32 = 1, beta: Float32 = 0, batch_size: Int = 1)
Matmul using the vendor BLAS library, taking a NullableTileTensor for the output C and TileTensors for the A and B operands.
Note: This overload does not support a_scales/b_scales; add scale parameters here when a TileTensor caller needs a scaled vendor matmul.
matmul[c_type: DType, a_type: DType, b_type: DType, c_layout: Layout, a_layout: Layout, b_layout: Layout, *, use_tf32: Bool = False, scales_type: DType = DType.invalid, a_scales_layout: Layout = Layout.row_major(-1), b_scales_layout: Layout = Layout.row_major(-1)](ctx: DeviceContext, c_tensor: LayoutTensor[c_type, c_layout, c_tensor.origin], a_tensor: LayoutTensor[a_type, a_layout, a_tensor.origin], b_tensor: LayoutTensor[b_type, b_layout, b_tensor.origin], *, a_scales: OptionalReg[LayoutTensor[scales_type, a_scales_layout, ImmutAnyOrigin]] = None, b_scales: OptionalReg[LayoutTensor[scales_type, b_scales_layout, ImmutAnyOrigin]] = None, c_row_major: Bool = False, transpose_a: Bool = False, transpose_b: Bool = False, alpha: Float32 = 1, beta: Float32 = 0, batch_size: Int = 1)
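Below is a minimal usage sketch for the LayoutTensor overload above. It assumes vendor_blas is importable from the linalg kernel package and that LayoutTensor views can be constructed directly from device buffers; the exact construction API may differ between versions, so treat this as an illustration rather than a definitive recipe.

```mojo
from gpu.host import DeviceContext
from layout import Layout, LayoutTensor
from linalg import vendor_blas


def main():
    alias M = 128
    alias N = 256
    alias K = 64
    alias dtype = DType.float32

    var ctx = DeviceContext()

    # Device buffers for C (MxN), A (MxK), and B (KxN).
    var c_buf = ctx.enqueue_create_buffer[dtype](M * N)
    var a_buf = ctx.enqueue_create_buffer[dtype](M * K)
    var b_buf = ctx.enqueue_create_buffer[dtype](K * N)

    # Row-major LayoutTensor views over the device buffers
    # (assumed constructor; fill A and B with real data before the call).
    var c = LayoutTensor[dtype, Layout.row_major(M, N)](c_buf)
    var a = LayoutTensor[dtype, Layout.row_major(M, K)](a_buf)
    var b = LayoutTensor[dtype, Layout.row_major(K, N)](b_buf)

    # C = alpha * A @ B + beta * C, dispatched to the vendor BLAS backend.
    vendor_blas.matmul(ctx, c, a, b, c_row_major=True)

    ctx.synchronize()
```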
matmul[c_type: DType, a_type: DType, b_type: DType, *, use_tf32: Bool = False, scales_type: DType](ctx: DeviceContext, c_tensor: TileTensor[c_type, c_tensor.LayoutType, c_tensor.origin, address_space=c_tensor.address_space, linear_idx_type=c_tensor.linear_idx_type, element_size=c_tensor.element_size], a_tensor: TileTensor[a_type, a_tensor.LayoutType, a_tensor.origin, address_space=a_tensor.address_space, linear_idx_type=a_tensor.linear_idx_type, element_size=a_tensor.element_size], b_tensor: TileTensor[b_type, b_tensor.LayoutType, b_tensor.origin, address_space=b_tensor.address_space, linear_idx_type=b_tensor.linear_idx_type, element_size=b_tensor.element_size], *, a_scales: TileTensor[scales_type, a_scales.LayoutType, a_scales.origin, address_space=a_scales.address_space, linear_idx_type=a_scales.linear_idx_type, element_size=a_scales.element_size], b_scales: TileTensor[scales_type, b_scales.LayoutType, b_scales.origin, address_space=b_scales.address_space, linear_idx_type=b_scales.linear_idx_type, element_size=b_scales.element_size], c_row_major: Bool = False, transpose_a: Bool = False, transpose_b: Bool = False, alpha: Float32 = 1, beta: Float32 = 0, batch_size: Int = 1)
Overload accepting TileTensors for all operands and scale factors.
Converts the TileTensor scale factors to LayoutTensor for the core dispatch, which passes them through to the cuBLASLt backend.
matmul[c_type: DType, a_type: DType, b_type: DType, use_tf32: Bool = False, scales_type: DType = DType.invalid, a_scales_layout: Layout = Layout.row_major(-1), b_scales_layout: Layout = Layout.row_major(-1)](ctx: DeviceContext, handle: Handle[handle.backend], c_tensor: NullableTileTensor[c_type, c_tensor.LayoutType, c_tensor.origin, address_space=c_tensor.address_space, linear_idx_type=c_tensor.linear_idx_type, element_size=c_tensor.element_size], a_tensor: TileTensor[a_type, a_tensor.LayoutType, a_tensor.origin, address_space=a_tensor.address_space, linear_idx_type=a_tensor.linear_idx_type, element_size=a_tensor.element_size], b_tensor: TileTensor[b_type, b_tensor.LayoutType, b_tensor.origin, address_space=b_tensor.address_space, linear_idx_type=b_tensor.linear_idx_type, element_size=b_tensor.element_size], *, a_scales: OptionalReg[LayoutTensor[scales_type, a_scales_layout, ImmutAnyOrigin]] = None, b_scales: OptionalReg[LayoutTensor[scales_type, b_scales_layout, ImmutAnyOrigin]] = None, c_row_major: Bool = False, transpose_a: Bool = False, transpose_b: Bool = False, alpha: Float32 = 1, beta: Float32 = 0, batch_size: Int = 1)
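For the Handle-taking overload, the sketch below shows the intended pattern of creating one backend handle and reusing it across several calls. It is a fragment, not a complete program: the operand tensors (c0, a0, b0, c1, a1, b1) are assumed to be TileTensor / NullableTileTensor views constructed elsewhere, and the Backend.CUBLASLT parameterization is an assumption about the Handle API.

```mojo
from linalg.vendor_blas import Backend, Handle, matmul

# Create one cuBLASLt handle and reuse it for several GEMMs on the same
# DeviceContext, avoiding per-call handle setup. Operand construction is
# not shown; c0/a0/b0 and c1/a1/b1 are assumed pre-built tensor views.
var handle = Handle[Backend.CUBLASLT]()
matmul(ctx, handle, c0, a0, b0, c_row_major=True)
matmul(ctx, handle, c1, a1, b1, c_row_major=True)
```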