Mojo function
moe_create_indices_bucket_sort_kernel
moe_create_indices_bucket_sort_kernel[input_type: DType, token_expert_order_layout: Layout, expert_start_indices_layout: Layout, restore_token_order_layout: Layout, expert_ids_layout: Layout, expert_usage_stats_layout: Layout, topk_ids_layout: Layout, num_threads: Int = _resolve_warp_size(), expected_count: Int = 8192](token_expert_order: LayoutTensor[DType.uint32, token_expert_order_layout, MutableAnyOrigin], lock: LayoutTensor[DType.uint32, Layout.row_major(1), MutableAnyOrigin], expert_start_indices: LayoutTensor[DType.uint32, expert_start_indices_layout, MutableAnyOrigin], restore_token_order: LayoutTensor[DType.uint32, restore_token_order_layout, MutableAnyOrigin], expert_ids: LayoutTensor[DType.int32, expert_ids_layout, MutableAnyOrigin], expert_usage_stats: LayoutTensor[DType.uint32, expert_usage_stats_layout, MutableAnyOrigin], topk_ids: LayoutTensor[input_type, topk_ids_layout, MutableAnyOrigin])
The main goal of this kernel is to group together tokens that use the same expert. This allows for efficient batching when the result is consumed by other kernels such as grouped matmul.
topk_ids: a 1D tensor of expert ids, where the index of each expert id corresponds to a token. For example, if topk_ids is [1, 0, 1, 3, 4, 2], then the corresponding tokens are [0, 1, 2, 3, 4, 5].
token_expert_order: a 1D tensor of tokens grouped together by expert id. Using the previous topk_ids, the token_expert_order could be [0, 2, 1, 3, 4, 5]: tokens 0 and 2 (both expert 1) come first, followed by the tokens of the remaining experts.
expert_ids: a 1D tensor of all the experts that are being used. Using the previous topk_ids, our expert_ids would be [1, 0, 3, 4, 2].
expert_start_indices: tells us where each expert starts and ends in token_expert_order. Based on the order of our expert_ids, our expert_start_indices would be [0, 2, 3, 4, 5, 6]. To see where expert 1 starts and ends, find the index 'i' of expert 1 in expert_ids, then read expert_start_indices[i] and expert_start_indices[i + 1], which are 0 and 2 respectively.
lock: a 1D tensor that holds a single scalar value. This integer is used to atomically synchronize the writes back to global memory by storing how many blocks have finished writing and the current global memory offset.
expert_usage_stats: contains two values: the maximum number of tokens assigned to any expert, and the number of active experts. For our example the stats would be [2, 5].
restore_token_order: a 1D tensor where each index represents a token and holds the new index of that token in the token_expert_order tensor. For our example the restore_token_order would be [0, 2, 1, 3, 4, 5].
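The relationship between these tensors can be illustrated with a small CPU reference written in plain Python. This is only a sketch of the semantics described above, not the kernel's implementation: the function name moe_create_indices_reference is hypothetical, the lock tensor and all GPU synchronization are omitted, and experts are emitted in order of first appearance purely so that the output matches the running example (the kernel may produce a different but equally valid grouping order).

```python
def moe_create_indices_reference(topk_ids):
    """Hypothetical CPU reference for the outputs described above.

    Assumption: experts are ordered by first appearance in topk_ids.
    The real kernel may order the experts differently.
    """
    # Bucket token indices by expert id. A dict keeps insertion order,
    # so experts stay in order of first appearance.
    buckets = {}
    for token, expert in enumerate(topk_ids):
        buckets.setdefault(expert, []).append(token)

    token_expert_order = []
    expert_ids = []
    expert_start_indices = [0]
    for expert, tokens in buckets.items():
        expert_ids.append(expert)
        token_expert_order.extend(tokens)
        # The end of this expert's range is the start of the next one.
        expert_start_indices.append(len(token_expert_order))

    # Inverse permutation: restore_token_order[token] is the position
    # of that token inside token_expert_order.
    restore_token_order = [0] * len(topk_ids)
    for position, token in enumerate(token_expert_order):
        restore_token_order[token] = position

    # [max tokens assigned to any expert, number of active experts]
    max_tokens = max((len(t) for t in buckets.values()), default=0)
    expert_usage_stats = [max_tokens, len(buckets)]

    return (token_expert_order, expert_start_indices,
            restore_token_order, expert_ids, expert_usage_stats)


def tokens_for_expert(expert, expert_ids, expert_start_indices,
                      token_expert_order):
    """Look up the tokens routed to one expert, as described for
    expert_start_indices."""
    i = expert_ids.index(expert)
    start, end = expert_start_indices[i], expert_start_indices[i + 1]
    return token_expert_order[start:end]
```

Calling moe_create_indices_reference([1, 0, 1, 3, 4, 2]) reproduces the example values listed above, and tokens_for_expert(1, ...) on that result returns the slice from index 0 to index 2, that is, tokens 0 and 2.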