
Mojo function

flash_attention_kv_cache

flash_attention_kv_cache[dtype: DType, cache_t: KVCacheT, //](
    q: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment],
    k: cache_t,
    v: cache_t,
    mask: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment],
    scale: Float32,
    output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment],
    sink_weights: OptionalReg[LayoutTensor[dtype, Layout.row_major(-1), MutableAnyOrigin]] = OptionalReg[LayoutTensor[dtype, Layout.row_major(-1), MutableAnyOrigin]]({:i1 0, 1})
)

Overload that takes the attention mask as a materialized LayoutTensor.

flash_attention_kv_cache[dtype: DType, cache_t: KVCacheT, mask_t: MHAMask, //](
    q: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment],
    k: cache_t,
    v: cache_t,
    mask: mask_t,
    scale: Float32,
    output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment],
    sink_weights: OptionalReg[LayoutTensor[dtype, Layout.row_major(-1), MutableAnyOrigin]] = OptionalReg[LayoutTensor[dtype, Layout.row_major(-1), MutableAnyOrigin]]({:i1 0, 1})
)

Overload that takes the mask as a generic MHAMask implementation instead of a materialized tensor.

flash_attention_kv_cache[dtype: DType, cache_t: KVCacheT, mask_t: MHAMask, //](
    q: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment],
    q_input_row_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment],
    kv_input_row_offsets: LayoutTensor[DType.uint32, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment],
    k: cache_t,
    v: cache_t,
    mask: mask_t,
    scale: Float32,
    output: LayoutTensor[dtype, layout, origin, element_layout=element_layout, layout_int_type=layout_int_type, linear_idx_type=linear_idx_type, masked=masked, alignment=alignment],
    sink_weights: OptionalReg[LayoutTensor[dtype, Layout.row_major(-1), MutableAnyOrigin]] = OptionalReg[LayoutTensor[dtype, Layout.row_major(-1), MutableAnyOrigin]]({:i1 0, 1})
)

Entrypoint for ragged tensors, where variable-length sequences are packed into a single flattened batch and the q_input_row_offsets and kv_input_row_offsets arrays mark the per-sequence boundaries within the query and KV tensors.
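As a rough illustration of the ragged row-offsets convention (sketched in Python rather than Mojo, and assuming the usual interpretation that offsets[i]..offsets[i+1] delimit sequence i, which the signature alone does not guarantee):

```python
# Hypothetical illustration of a ragged batch layout; not this API's
# implementation. Assumes offsets[i]..offsets[i+1] delimit sequence i
# inside a flattened [total_tokens, ...] tensor.

def split_ragged(flat_rows, row_offsets):
    """Recover per-sequence slices from a flattened ragged batch."""
    return [
        flat_rows[row_offsets[i]:row_offsets[i + 1]]
        for i in range(len(row_offsets) - 1)
    ]

# Three sequences of lengths 2, 3, and 1 packed into six rows.
tokens = ["a0", "a1", "b0", "b1", "b2", "c0"]
offsets = [0, 2, 5, 6]  # len(offsets) == num_sequences + 1
print(split_ragged(tokens, offsets))
# → [['a0', 'a1'], ['b0', 'b1', 'b2'], ['c0']]
```

Packing sequences this way avoids padding every sequence to the maximum length, which is why the ragged overload takes separate offset arrays for the query and KV sides.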
