Skip to main content

Mojo function

flash_attention_kv_cache

flash_attention_kv_cache[dtype: DType, cache_t: KVCacheT, q_origin: ImmutOrigin, output_origin: MutOrigin, //](q: LayoutTensor[dtype, q.layout, q_origin, element_layout=q.element_layout, layout_int_type=q.layout_int_type, linear_idx_type=q.linear_idx_type, masked=q.masked, alignment=q.alignment], k: cache_t, v: cache_t, mask: LayoutTensor[dtype, mask.layout, mask.origin, element_layout=mask.element_layout, layout_int_type=mask.layout_int_type, linear_idx_type=mask.linear_idx_type, masked=mask.masked, alignment=mask.alignment], scale: Float32, output: LayoutTensor[dtype, output.layout, output_origin, element_layout=output.element_layout, layout_int_type=output.layout_int_type, linear_idx_type=output.linear_idx_type, masked=output.masked, alignment=output.alignment], sink_weights: OptionalReg[LayoutTensor[dtype, Layout.row_major(VariadicList(-1)), ImmutAnyOrigin]] = None)

flash_attention_kv_cache[dtype: DType, cache_t: KVCacheT, mask_t: MHAMask, q_origin: ImmutOrigin, output_origin: MutOrigin, //](q: LayoutTensor[dtype, q.layout, q_origin, element_layout=q.element_layout, layout_int_type=q.layout_int_type, linear_idx_type=q.linear_idx_type, masked=q.masked, alignment=q.alignment], k: cache_t, v: cache_t, mask: mask_t, scale: Float32, output: LayoutTensor[dtype, output.layout, output_origin, element_layout=output.element_layout, layout_int_type=output.layout_int_type, linear_idx_type=output.linear_idx_type, masked=output.masked, alignment=output.alignment], sink_weights: OptionalReg[LayoutTensor[dtype, Layout.row_major(VariadicList(-1)), ImmutAnyOrigin]] = None)

flash_attention_kv_cache[dtype: DType, cache_t: KVCacheT, mask_t: MHAMask, q_origin: ImmutOrigin, output_origin: MutOrigin, //](q: LayoutTensor[dtype, q.layout, q_origin, element_layout=q.element_layout, layout_int_type=q.layout_int_type, linear_idx_type=q.linear_idx_type, masked=q.masked, alignment=q.alignment], q_input_row_offsets: LayoutTensor[DType.uint32, q_input_row_offsets.layout, q_input_row_offsets.origin, element_layout=q_input_row_offsets.element_layout, layout_int_type=q_input_row_offsets.layout_int_type, linear_idx_type=q_input_row_offsets.linear_idx_type, masked=q_input_row_offsets.masked, alignment=q_input_row_offsets.alignment], kv_input_row_offsets: LayoutTensor[DType.uint32, kv_input_row_offsets.layout, kv_input_row_offsets.origin, element_layout=kv_input_row_offsets.element_layout, layout_int_type=kv_input_row_offsets.layout_int_type, linear_idx_type=kv_input_row_offsets.linear_idx_type, masked=kv_input_row_offsets.masked, alignment=kv_input_row_offsets.alignment], k: cache_t, v: cache_t, mask: mask_t, scale: Float32, output: LayoutTensor[dtype, output.layout, output_origin, element_layout=output.element_layout, layout_int_type=output.layout_int_type, linear_idx_type=output.linear_idx_type, masked=output.masked, alignment=output.alignment], sink_weights: OptionalReg[LayoutTensor[dtype, Layout.row_major(VariadicList(-1)), ImmutAnyOrigin]] = None)

Entrypoint for ragged tensors: this overload accepts explicit `q_input_row_offsets` and `kv_input_row_offsets` tensors (both `DType.uint32`) that delimit the per-sequence row ranges within the packed query and key/value inputs.

Was this page helpful?