Mojo function
spatial_merge_kernel
spatial_merge_kernel[dtype: DType, InputLayoutType: TensorLayout, input_origin: ImmutOrigin, OutputLayoutType: TensorLayout, output_origin: MutOrigin, GridThwLayoutType: TensorLayout, grid_thw_origin: ImmutOrigin](output: TileTensor[dtype, OutputLayoutType, output_origin], input: TileTensor[dtype, InputLayoutType, input_origin], grid_thw: TileTensor[DType.int64, GridThwLayoutType, grid_thw_origin], batch_size: Int, hidden_size: Int, merge_size: Int)
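Spatial merge kernel: merges each merge_size × merge_size window of spatially adjacent input patches into a single output patch, whose hidden_size × merge_size² channels hold the hidden vectors of the patches in that window.
Grid: 1D over all output patches in the batch (one block per output patch). Threads: each block's threads loop over the hidden_size × merge_size² channels of its output patch.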
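The block/thread mapping above can be illustrated with a CPU reference loop. This is a minimal sketch, not the kernel's implementation, and it rests on assumptions that the page does not state:

- Input rows are stored per item in raster (t, h, w) order and concatenated across items, one row of hidden_size values per patch.
- h and w are multiples of merge_size.
- Output channel `ch` maps to window position `k = ch // hidden_size` (row-major over the window) and hidden channel `c = ch % hidden_size`.

The function name `spatial_merge_reference` is hypothetical.

```python
from typing import List, Sequence


def spatial_merge_reference(
    inp: Sequence[Sequence[float]],
    grid_thw: Sequence[Sequence[int]],
    hidden_size: int,
    merge_size: int,
) -> List[List[float]]:
    """Hypothetical CPU sketch of a spatial merge under an ASSUMED raster-order layout."""
    m = merge_size
    out: List[List[float]] = []
    in_offset = 0  # index of the first input row belonging to the current item
    for t, h, w in grid_thw:
        # ASSUMPTION: h and w are multiples of merge_size.
        mh, mw = h // m, w // m
        for ti in range(t):
            for oy in range(mh):
                for ox in range(mw):
                    # On the GPU, each (item, ti, oy, ox) output patch is one block.
                    row = [0.0] * (hidden_size * m * m)
                    # On the GPU, the block's threads stride over these channels.
                    for ch in range(hidden_size * m * m):
                        k, c = divmod(ch, hidden_size)  # k: patch index within the window
                        dy, dx = divmod(k, m)  # row-major position inside the window
                        y, x = oy * m + dy, ox * m + dx
                        src = in_offset + (ti * h + y) * w + x  # ASSUMED raster order
                        row[ch] = inp[src][c]
                    out.append(row)
        in_offset += t * h * w
    return out
```

For one item with grid `[1, 2, 4]`, hidden_size 1, merge_size 2, and input values 0..7 in raster order, the sketch produces two output patches, `[0, 1, 4, 5]` and `[2, 3, 6, 7]`: each output patch gathers one 2 × 2 window.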
Args:
- output (TileTensor[dtype, OutputLayoutType, output_origin]): Output tensor of merged patches, each with hidden_size × merge_size² channels.
- input (TileTensor[dtype, InputLayoutType, input_origin]): Input tensor of patches, each with hidden_size channels.
- grid_thw (TileTensor[DType.int64, GridThwLayoutType, grid_thw_origin]): Grid dimensions tensor of shape (B, 3) containing [t, h, w] for each item.
- batch_size (Int): Number of items in the batch (B).
- hidden_size (Int): Hidden dimension size of each input patch.
- merge_size (Int): Side length of the square spatial merge window; each output patch merges merge_size × merge_size input patches.
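Because the kernel launches one block per output patch, the caller needs the total number of output patches across the batch to size both the grid and the output tensor. Below is a minimal sketch of that computation. It assumes each item contributes t × (h / merge_size) × (w / merge_size) output patches and that h and w divide evenly. The function name `spatial_merge_launch_shape` is hypothetical, not part of the Mojo API.

```python
from typing import Sequence, Tuple


def spatial_merge_launch_shape(
    grid_thw: Sequence[Sequence[int]],
    batch_size: int,
    hidden_size: int,
    merge_size: int,
) -> Tuple[int, int]:
    """Hypothetical helper: return (num_output_patches, output_channels)."""
    if len(grid_thw) != batch_size:
        raise ValueError("grid_thw must have one [t, h, w] row per batch item")
    num_out = 0
    for t, h, w in grid_thw:
        # ASSUMPTION: only the spatial dims (h, w) are merged, and they divide evenly.
        if h % merge_size or w % merge_size:
            raise ValueError("h and w must be divisible by merge_size")
        num_out += t * (h // merge_size) * (w // merge_size)
    # One block per output patch; each patch has hidden_size * merge_size^2 channels.
    return num_out, hidden_size * merge_size * merge_size
```

For example, `grid_thw = [[1, 4, 4], [2, 2, 6]]` with merge_size 2 gives 4 + 6 = 10 output patches (the 1D grid size). With hidden_size 8, each patch has 32 channels.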