Mojo function
fp8_index
fp8_index[dtype: DType, output_layout: Layout, q_layout: Layout, qs_layout: Layout, k_layout: Layout, ks_layout: Layout, //, batch_size: Int, num_heads: Int, depth: Int](output: LayoutTensor[DType.float32, output_layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], q: LayoutTensor[dtype, q_layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], q_s: LayoutTensor[DType.float32, qs_layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], k: LayoutTensor[dtype, k_layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], k_s: LayoutTensor[DType.float32, ks_layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], valid_length: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], cache_row_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment], max_seq_len: Int, max_num_keys: Int, ctx: DeviceContext)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!