Mojo function
flatten_leading
flatten_leading[dtype: DType, layout: TensorLayout, //](tensor: TileTensor[dtype, layout, tensor.origin, address_space=tensor.address_space, linear_idx_type=tensor.linear_idx_type, element_size=tensor.element_size]) -> TileTensor[dtype, Layout[RuntimeInt[DType.int64], layout._shape_types[(add layout.rank._mlir_value, -1)], #kgen.variadic.concat(RuntimeInt[layout._shape_types[(add layout.rank._mlir_value, -1)].DTYPE if layout._shape_types[(add layout.rank._mlir_value, -1)].is_static_value.__bool__().__invert__()._mlir_value else DType.int] if layout._shape_types[(add layout.rank._mlir_value, -1)].is_static_value.__bool__().__invert__() if layout._shape_types[(add layout.rank._mlir_value, -1)].is_static_value.__bool__().__invert__()._mlir_value else False else ComptimeInt[(layout._shape_types[(add layout.rank._mlir_value, -1)].static_value * ComptimeInt[1].static_value)], ComptimeInt[1])], tensor.origin, address_space=tensor.address_space, linear_idx_type=tensor.linear_idx_type, element_size=tensor.element_size]
Merge the first two dimensions of a rank-3 TileTensor: (A, B, C) -> (A*B, C).
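Returns a new TileTensor that shares the input's pointer and uses row-major strides computed from the merged shape. No data is copied, so the operation is zero-cost. In the result layout, dim[0] (A*B) is a runtime int64 value and dim[1] keeps the type of the input's last dimension. The strides are (C, 1), and the stride C is a compile-time value when C is static.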
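The shape and stride bookkeeping can be checked with a small Python model of the transform. This is not the Mojo API: the helper names are illustrative, and tensors are reduced to a shape plus element strides. The sketch assumes the input is contiguous in row-major order. That assumption matters because the merged view's strides are recomputed from the merged shape rather than derived from the input's strides. The model verifies that every element keeps its flat memory offset after the merge.

```python
from itertools import product


def row_major_strides(shape):
    """Row-major (C-order) strides, in elements, for the given shape."""
    strides = [1] * len(shape)
    for i in range(len(shape) - 2, -1, -1):
        strides[i] = strides[i + 1] * shape[i + 1]
    return tuple(strides)


def flatten_leading_shape(shape):
    """Model of the layout transform (A, B, C) -> (A*B, C) with row-major strides."""
    if len(shape) != 3:
        raise ValueError("expected a rank-3 shape")
    a, b, c = shape
    merged = (a * b, c)
    return merged, row_major_strides(merged)


def offset(index, strides):
    """Flat element offset of a multi-dimensional index."""
    return sum(i * s for i, s in zip(index, strides))


# Illustrative (num_experts, N, K) = (4, 3, 5), assumed contiguous row-major.
shape3 = (4, 3, 5)
strides3 = row_major_strides(shape3)              # (15, 5, 1)
shape2, strides2 = flatten_leading_shape(shape3)  # (12, 5), (5, 1)

# Element (a, b, c) of the rank-3 view is element (a*B + b, c) of the
# rank-2 view, and both resolve to the same offset from the shared pointer.
for a, b, c in product(*map(range, shape3)):
    assert offset((a, b, c), strides3) == offset((a * shape3[1] + b, c), strides2)

print(shape2, strides2)  # (12, 5) (5, 1)
```

Because only the shape and strides change, the same base pointer serves both views. This is why the real operation needs no copy.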
A common use case is converting a 3D batched tensor of shape (num_experts, N, K) into a 2D tensor of shape (num_experts*N, K) for creating TMA (Tensor Memory Accelerator) descriptors in mixture-of-experts (MoE) kernels.
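In the flattened (num_experts*N, K) view, the rows that belong to expert e form the contiguous block from e*N through (e+1)*N - 1. The Python sketch below models only that row-index arithmetic. It is not TMA descriptor creation, and `expert_row_range` and `flat_row` are hypothetical helpers, not part of the Mojo API. The sizes are illustrative assumptions.

```python
def expert_row_range(expert: int, n: int) -> range:
    """Rows of the flattened (num_experts*N, K) view owned by one expert.

    Hypothetical helper: models index arithmetic only.
    """
    start = expert * n
    return range(start, start + n)


def flat_row(expert: int, row: int, n: int) -> int:
    """Map (expert, row within that expert) to a row of the flattened view."""
    if not 0 <= row < n:
        raise IndexError("row out of range for this expert")
    return expert * n + row


# Illustrative sizes; K does not affect the row mapping.
num_experts, N = 8, 128
flat_rows = num_experts * N  # the flattened view has 1024 rows

print(expert_row_range(3, N))             # range(384, 512)
print(flat_row(3, 10, N))                 # 394
print(flat_row(num_experts - 1, N - 1, N))  # 1023, the last row (flat_rows - 1)
```

One 2D descriptor can then cover all experts, and a kernel selects expert e by offsetting its row coordinate by e*N.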
Parameters:
- dtype (DType): Element type (inferred from tensor).
- layout (TensorLayout): Layout type (inferred from tensor).
Args:
- tensor (TileTensor): A rank-3 TileTensor.
Returns:
TileTensor: A rank-2 TileTensor of shape (A*B, C). Its dim[0] equals the input's dim[0] * dim[1], and its dim[1] equals the input's dim[2].