Mojo function
fp8_index_naive
fp8_index_naive[dtype: DType, output_layout: Layout, q_layout: Layout, qs_layout: Layout, k_layout: Layout, ks_layout: Layout, //, num_heads: Int, depth: Int](output: LayoutTensor[DType.float32, output_layout, output.origin, element_layout=output.element_layout, layout_int_type=output.layout_int_type, linear_idx_type=output.linear_idx_type, masked=output.masked, alignment=output.alignment], q: LayoutTensor[dtype, q_layout, q.origin, element_layout=q.element_layout, layout_int_type=q.layout_int_type, linear_idx_type=q.linear_idx_type, masked=q.masked, alignment=q.alignment], q_s: LayoutTensor[DType.float32, qs_layout, q_s.origin, element_layout=q_s.element_layout, layout_int_type=q_s.layout_int_type, linear_idx_type=q_s.linear_idx_type, masked=q_s.masked, alignment=q_s.alignment], k: LayoutTensor[dtype, k_layout, k.origin, element_layout=k.element_layout, layout_int_type=k.layout_int_type, linear_idx_type=k.linear_idx_type, masked=k.masked, alignment=k.alignment], k_s: LayoutTensor[DType.float32, ks_layout, k_s.origin, element_layout=k_s.element_layout, layout_int_type=k_s.layout_int_type, linear_idx_type=k_s.linear_idx_type, masked=k_s.masked, alignment=k_s.alignment], valid_length: LayoutTensor[DType.uint32, valid_length.layout, valid_length.origin, element_layout=valid_length.element_layout, layout_int_type=valid_length.layout_int_type, linear_idx_type=valid_length.linear_idx_type, masked=valid_length.masked, alignment=valid_length.alignment], cache_row_offsets: LayoutTensor[DType.uint32, cache_row_offsets.layout, cache_row_offsets.origin, element_layout=cache_row_offsets.element_layout, layout_int_type=cache_row_offsets.layout_int_type, linear_idx_type=cache_row_offsets.linear_idx_type, masked=cache_row_offsets.masked, alignment=cache_row_offsets.alignment], batch_size: Int, max_seq_len: Int, max_num_keys: Int, ctx: DeviceContext)
Was this page helpful?
Thank you! We'll create more content like this.
Thank you for helping us improve!