Mojo function
group_norm_gpu_multi_block_stats
group_norm_gpu_multi_block_stats[StatsLayoutType: TensorLayout, stats_origin: MutOrigin, //, dtype: DType, simd_width: UInt, input_fn: fn[width: Int](row: Int, col: Int) capturing -> SIMD[dtype, width]](stats: TileTensor[get_accum_type[dtype](), StatsLayoutType, stats_origin], num_splits: Int, group_size: Int)
Multi-block stats kernel: computes partial Welford statistics per split.
Grid: num_rows * num_splits blocks. Each block handles one split of one group and writes its partial (mean, m2, count) to the stats buffer, where m2 is the sum of squared deviations from that split's mean. Stats layout: stats[block_idx * 3 + {0,1,2}] = {mean, m2, count}.
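The per-block work described above can be modeled on the CPU as a rough sketch. This is not the kernel itself. `write_partial_stats` and `merge_row` are hypothetical helpers, and two details are assumptions not stated in this reference: that `block_idx = row * num_splits + split`, and that each split covers a contiguous chunk of `ceil(group_size / num_splits)` elements. The merge step shows one standard way (Chan et al.'s pairwise combination) that partial (mean, m2, count) triples can be reduced to full-group statistics; how the library actually combines them downstream is not described here.

```mojo
# CPU sketch of the per-block partial Welford pass plus a downstream merge.
# Hypothetical helpers for illustration; not part of the library API.


fn write_partial_stats(
    data: List[Float64],
    group_size: Int,
    num_splits: Int,
    block_idx: Int,
    mut stats: List[Float64],
):
    # Assumed mapping: block_idx = row * num_splits + split.
    var row = block_idx // num_splits
    var split = block_idx % num_splits
    # Assumed partition: contiguous chunks of ceil(group_size / num_splits).
    var chunk = (group_size + num_splits - 1) // num_splits
    var start = split * chunk
    var end = min(start + chunk, group_size)

    # Welford's online update over this split only.
    var mean: Float64 = 0.0
    var m2: Float64 = 0.0
    var count: Float64 = 0.0
    for col in range(start, end):
        var x = data[row * group_size + col]
        count += 1.0
        var delta = x - mean
        mean += delta / count
        m2 += delta * (x - mean)

    # Documented layout: stats[block_idx * 3 + {0,1,2}] = {mean, m2, count}.
    stats[block_idx * 3 + 0] = mean
    stats[block_idx * 3 + 1] = m2
    stats[block_idx * 3 + 2] = count


fn merge_row(
    stats: List[Float64], num_splits: Int, row: Int
) -> SIMD[DType.float64, 2]:
    # Pairwise (Chan et al.) combination of one row's split partials.
    # Returns (mean, population variance); assumes the row's total count > 0.
    var mean: Float64 = 0.0
    var m2: Float64 = 0.0
    var count: Float64 = 0.0
    for split in range(num_splits):
        var b = (row * num_splits + split) * 3
        var n_b = stats[b + 2]
        if n_b == 0.0:
            continue  # empty split contributes nothing
        var delta = stats[b] - mean
        var n = count + n_b
        mean += delta * n_b / n
        m2 += stats[b + 1] + delta * delta * count * n_b / n
        count = n
    return SIMD[DType.float64, 2](mean, m2 / count)


def main():
    var num_rows = 2
    var group_size = 5
    var num_splits = 2
    var data = List[Float64]()
    for i in range(num_rows * group_size):
        data.append(Float64(i))
    var stats = List[Float64]()
    for _ in range(num_rows * num_splits * 3):
        stats.append(0.0)
    # One "block" per (row, split) pair, as in the documented grid.
    for block_idx in range(num_rows * num_splits):
        write_partial_stats(data, group_size, num_splits, block_idx, stats)
    for row in range(num_rows):
        var r = merge_row(stats, num_splits, row)
        print("row", row, "mean", r[0], "variance", r[1])
```

With rows `0..4` and `5..9`, both rows should merge to a population variance of 2 (means 2 and 7), matching a single-pass computation over each whole group. Splitting the reduction this way lets more blocks work on one large group at the cost of a second combine pass.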