Mojo function
spatial_merge_kernel
spatial_merge_kernel[dtype: DType, input_layout: Layout, output_layout: Layout, grid_thw_layout: Layout](output: LayoutTensor[dtype, output_layout, MutAnyOrigin], input: LayoutTensor[dtype, input_layout, MutAnyOrigin], grid_thw: LayoutTensor[DType.int64, grid_thw_layout, MutAnyOrigin], batch_size: Int, hidden_size: Int, merge_size: Int)
Spatial merge kernel.
Grid: 1D over all output patches (one block per output patch). Threads: loop over channels (hidden_size x merge_size^2).
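The launch shape described above can be sized on the host before the kernel is enqueued. The Python sketch below is illustrative, not the Mojo implementation. It assumes that each batch item contributes t * (h / merge_size) * (w / merge_size) output patches and that h and w are divisible by merge_size; neither point is stated on this page. `launch_geometry` is a hypothetical helper name.

```python
def launch_geometry(grid_thw, hidden_size, merge_size):
    """Return (num_blocks, channels_per_patch) for a 1D launch.

    Assumption (not confirmed by the docs): an item with grid [t, h, w]
    yields t * (h // merge_size) * (w // merge_size) output patches.
    """
    num_blocks = 0
    for t, h, w in grid_thw:
        if h % merge_size != 0 or w % merge_size != 0:
            raise ValueError("h and w must be divisible by merge_size")
        # One block per output patch.
        num_blocks += t * (h // merge_size) * (w // merge_size)
    # Each block's threads loop over this many channels.
    channels_per_patch = hidden_size * merge_size * merge_size
    return num_blocks, channels_per_patch


if __name__ == "__main__":
    # Two items: [t=1, h=4, w=4] and [t=2, h=2, w=6], merge_size=2.
    print(launch_geometry([(1, 4, 4), (2, 2, 6)], hidden_size=8, merge_size=2))
```

Under these assumptions the number of blocks depends only on grid_thw and merge_size, while hidden_size affects only the per-block channel loop.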
Args:
- output (LayoutTensor): Output tensor.
- input (LayoutTensor): Input tensor.
- grid_thw (LayoutTensor): Grid dimensions tensor (B, 3) containing [t, h, w] for each item.
- batch_size (Int): Number of items in batch.
- hidden_size (Int): Hidden dimension size.
- merge_size (Int): Size of spatial merge blocks.
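To make the roles of these arguments concrete, here is a small CPU reference model in Python. It is a sketch under stated assumptions, not the kernel's actual code: it assumes the input holds one row of hidden_size values per patch, that each item's patches are stored in row-major (t, h, w) order, and that every output row concatenates a merge_size x merge_size block of neighboring patches in row-major order. The real memory layout may differ. `spatial_merge_reference` is a hypothetical name.

```python
def spatial_merge_reference(inp, grid_thw, batch_size, hidden_size, merge_size):
    """Pure-Python model of a spatial merge (layout is an assumption).

    inp:      list of patch rows, each a list of hidden_size values
    grid_thw: list of (t, h, w) tuples, one per batch item
    Returns a list of rows, each with hidden_size * merge_size**2 values.
    """
    out = []
    base = 0  # index of the first input patch of the current item
    for t, h, w in grid_thw[:batch_size]:
        for ti in range(t):
            for oh in range(h // merge_size):
                for ow in range(w // merge_size):
                    row = []
                    # Gather the merge_size x merge_size block of patches.
                    for dh in range(merge_size):
                        for dw in range(merge_size):
                            ih = oh * merge_size + dh
                            iw = ow * merge_size + dw
                            src = base + (ti * h + ih) * w + iw
                            row.extend(inp[src][:hidden_size])
                    out.append(row)
        base += t * h * w
    return out


if __name__ == "__main__":
    # One item with t=1, h=2, w=4; patches hold their own index.
    patches = [[i] for i in range(8)]
    print(spatial_merge_reference(patches, [(1, 2, 4)], 1, 1, 2))
```

In this model each appended row corresponds to one output patch, which matches the "one block per output patch" launch shape; the innermost gather corresponds to the per-thread channel loop.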