
Mojo function

spatial_merge_kernel

spatial_merge_kernel[dtype: DType, input_origin: ImmutOrigin, input_shape_types: Variadic[CoordLike], input_stride_types: Variadic[CoordLike], output_origin: MutOrigin, output_shape_types: Variadic[CoordLike], output_stride_types: Variadic[CoordLike], grid_thw_origin: ImmutOrigin, grid_thw_shape_types: Variadic[CoordLike], grid_thw_stride_types: Variadic[CoordLike]](output: TileTensor[dtype, output_origin], input: TileTensor[dtype, input_origin], grid_thw: TileTensor[DType.int64, grid_thw_origin], batch_size: Int, hidden_size: Int, merge_size: Int)

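Spatial merge kernel. Merges each merge_size x merge_size window of spatially adjacent input patches into a single output patch by concatenating their hidden vectors, so every output patch carries hidden_size x merge_size^2 channels.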

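Launch configuration: a 1D grid with one thread block per output patch, covering the output patches of every item in the batch. Within a block, threads loop over the hidden_size x merge_size^2 output channels of that patch.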
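The launch structure above can be illustrated with a small CPU reference in Python (not Mojo, and not the kernel's actual code). It is a sketch under stated assumptions: the input patches of one item are stored in raster (t, h, w) order, h and w are divisible by merge_size, and output channel c maps to window slot c // hidden_size (row-major inside the window) and hidden channel c % hidden_size. The function name `spatial_merge_reference` is hypothetical. The outer loop stands in for the 1D grid of blocks and the inner loop for the threads striding over channels.

```python
def spatial_merge_reference(inp, t, h, w, hidden_size, merge_size):
    """Hypothetical CPU sketch of a spatial merge for one item with grid (t, h, w).

    inp: t*h*w patches in raster (t, h, w) order, each a list of hidden_size values.
    Returns t*(h//m)*(w//m) output patches of hidden_size * m*m values each.
    The channel layout used here is an ASSUMPTION, not the documented one.
    """
    m = merge_size
    assert h % m == 0 and w % m == 0, "h and w must be divisible by merge_size"
    hb, wb = h // m, w // m          # number of merge windows along h and w
    out_channels = hidden_size * m * m
    out = []
    # Outer loop: plays the role of the 1D grid (one "block" per output patch).
    for p in range(t * hb * wb):
        ti, rem = divmod(p, hb * wb)  # temporal index, window index in that frame
        bh, bw = divmod(rem, wb)      # window row and column
        row = [0.0] * out_channels
        # Inner loop: plays the role of the threads striding over channels.
        for c in range(out_channels):
            k, d = divmod(c, hidden_size)  # window slot k, hidden channel d
            i, j = divmod(k, m)            # row/col offset inside the window
            src = ti * h * w + (bh * m + i) * w + (bw * m + j)
            row[c] = inp[src][d]
        out.append(row)
    return out


# A 2x4 frame with hidden_size=1 merged by 2x2 windows yields two output patches.
print(spatial_merge_reference([[v] for v in range(8)], 1, 2, 4, 1, 2))
```

Because each output channel is computed independently from a closed-form source index, the inner loop parallelizes trivially across threads, which is consistent with the one-block-per-output-patch layout described above.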

Args:

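  • output (TileTensor): Output tensor of merged patches, each holding hidden_size x merge_size^2 channels.
  • input (TileTensor): Input tensor of patch embeddings, each holding hidden_size channels.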
  • grid_thw (TileTensor): int64 grid dimensions tensor of shape (batch_size, 3) containing [t, h, w] (temporal, height, width) for each item.
  • batch_size (Int): Number of items in batch.
  • hidden_size (Int): Hidden dimension size.
  • merge_size (Int): Side length of each square spatial merge window; merge_size^2 input patches are combined into one output patch.
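Since the grid is 1D over all output patches of the whole batch, each block must work out which item in grid_thw it belongs to. The Python sketch below shows one way to do that with prefix sums over the per-item output patch counts t * h * w / merge_size^2. Both helper names are hypothetical, and the sketch assumes h and w are divisible by merge_size; the kernel itself may compute this mapping differently.

```python
from bisect import bisect_right
from itertools import accumulate


def output_patch_offsets(grid_thw, merge_size):
    """Hypothetical helper: prefix sums of merged-patch counts per batch item.

    grid_thw: list of (t, h, w) rows, one per item (batch_size rows).
    Returns batch_size + 1 offsets; entry b is the first flat output-patch
    index belonging to item b, and the last entry is the total patch count.
    """
    m2 = merge_size * merge_size
    counts = []
    for t, h, w in grid_thw:
        assert h % merge_size == 0 and w % merge_size == 0
        counts.append(t * h * w // m2)  # output patches produced by this item
    return [0] + list(accumulate(counts))


def locate_output_patch(p, offsets):
    """Hypothetical helper: map flat output-patch index p to (item b, local index)."""
    b = bisect_right(offsets, p) - 1  # last item whose start offset is <= p
    return b, p - offsets[b]


offsets = output_patch_offsets([(1, 4, 4), (2, 2, 6)], merge_size=2)
print(offsets, locate_output_patch(5, offsets))
```

The last offset gives the total number of output patches, which is the size a 1D launch grid would need. The same prefix-sum idea over t * h * w (without dividing by merge_size^2) would give each item's starting row in the input tensor.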
