
Mojo function

batched_matmul

batched_matmul[rank: Int, a_type: DType, b_type: DType, c_type: DType, //, *, transpose_a: Bool, transpose_b: Bool, elementwise_epilogue_fn: Optional[elementwise_epilogue_type] = None, saturated_vnni: Bool = False, single_thread_blocking_override: Bool = False, target: StringSlice[StaticConstantOrigin] = "cpu"](c_buf: NDBuffer[c_type, c_buf.origin, c_buf.shape, c_buf.strides], a_buf: NDBuffer[a_type, a_buf.origin, a_buf.shape, a_buf.strides], b_buf: NDBuffer[b_type, b_buf.origin, b_buf.shape, b_buf.strides], *, context: DeviceContextPtr = DeviceContextPtr())

NDBuffer overload of batched_matmul. Converts the NDBuffer arguments to TileTensor and delegates to the TileTensor overload.

batched_matmul[rank: Int, a_type: DType, b_type: DType, c_type: DType, //, *, transpose_b: Bool, elementwise_epilogue_fn: Optional[elementwise_epilogue_type] = None, saturated_vnni: Bool = False, target: StringSlice[StaticConstantOrigin] = "cpu"](c_buf: NDBuffer[c_type, c_buf.origin, c_buf.shape, c_buf.strides], a_buf: NDBuffer[a_type, a_buf.origin, a_buf.shape, a_buf.strides], b_buf: NDBuffer[b_type, b_buf.origin, b_buf.shape, b_buf.strides], *, context: DeviceContextPtr = DeviceContextPtr())

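NDBuffer overload of batched_matmul that omits the transpose_a and single_thread_blocking_override parameters. Converts the NDBuffer arguments to TileTensor and delegates to the TileTensor overload.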

batched_matmul[*, transpose_a: Bool = False, transpose_b: Bool = False, elementwise_epilogue_fn: Optional[elementwise_epilogue_type] = None, saturated_vnni: Bool = False, single_thread_blocking_override: Bool = False, target: StringSlice[StaticConstantOrigin] = "cpu"](c_buf: TileTensor[c_buf.dtype, c_buf.LayoutType, c_buf.origin, linear_idx_type=c_buf.linear_idx_type, element_size=c_buf.element_size], a_buf: TileTensor[a_buf.dtype, a_buf.LayoutType, a_buf.origin, linear_idx_type=a_buf.linear_idx_type, element_size=a_buf.element_size], b_buf: TileTensor[b_buf.dtype, b_buf.LayoutType, b_buf.origin, linear_idx_type=b_buf.linear_idx_type, element_size=b_buf.element_size], *, context: DeviceContextPtr = DeviceContextPtr())

Primary implementation of batched_matmul, operating on TileTensor arguments. The NDBuffer overloads delegate to this one.
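
The page does not spell out what the function computes, so here is a rough reference sketch of the usual batched matmul semantics, written in plain Python rather than Mojo: for each batch index, c = op(a) @ op(b), where op optionally transposes the matrix. The name `batched_matmul_reference`, the nested-list layout, and the value-to-value `epilogue` hook are all assumptions made for illustration. In particular, `epilogue` is a simplified stand-in, since `elementwise_epilogue_type` is not described on this page, and `saturated_vnni`, `single_thread_blocking_override`, `target`, and `context` are not modeled at all.

```python
def batched_matmul_reference(a, b, transpose_a=False, transpose_b=False,
                             epilogue=None):
    """Reference semantics only (hypothetical helper, not the Mojo API).

    a, b: lists of 2D matrices (each a list of rows), one matrix per batch.
    Returns a list of result matrices, one per batch.
    """
    if len(a) != len(b):
        raise ValueError("a and b must have the same batch size")

    out = []
    for a_mat, b_mat in zip(a, b):
        # Apply the optional transposes before multiplying.
        if transpose_a:
            a_mat = [list(col) for col in zip(*a_mat)]
        if transpose_b:
            b_mat = [list(col) for col in zip(*b_mat)]

        m, k, n = len(a_mat), len(b_mat), len(b_mat[0])
        if any(len(row) != k for row in a_mat):
            raise ValueError("inner dimensions do not match")

        # Plain triple loop: c[i][j] = sum over p of a[i][p] * b[p][j].
        c_mat = [
            [sum(a_mat[i][p] * b_mat[p][j] for p in range(k)) for j in range(n)]
            for i in range(m)
        ]

        # Simplified stand-in for an elementwise epilogue: map each output value.
        if epilogue is not None:
            c_mat = [[epilogue(v) for v in row] for row in c_mat]

        out.append(c_mat)
    return out


a = [[[1, 2], [3, 4]]]  # batch of one 2x2 matrix
b = [[[5, 6], [7, 8]]]
print(batched_matmul_reference(a, b))                    # [[[19, 22], [43, 50]]]
print(batched_matmul_reference(a, b, transpose_b=True))  # [[[17, 23], [39, 53]]]
```

A sketch like this can serve as a ground truth when checking kernel output on small inputs. It says nothing about how the Mojo implementation tiles, vectorizes, or dispatches work across targets.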
