
Mojo function

flash_attention

flash_attention[
    dtype: DType,
    rank: Int,
    mask_rank: Int,
    q_origin: ImmutOrigin,
    output_origin: MutOrigin, //,
    input_k_fn: def[simd_width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, simd_width],
    input_v_fn: def[simd_width: Int, rank: Int](IndexList[rank]) capturing -> SIMD[dtype, simd_width],
    input_mask_fn: def[simd_width: Int, mask_rank: Int](IndexList[mask_rank]) capturing -> SIMD[dtype, simd_width]
](
    q: LayoutTensor[dtype, q.layout, q_origin, element_layout=q.element_layout, layout_int_type=q.layout_int_type, linear_idx_type=q.linear_idx_type, masked=q.masked, alignment=q.alignment],
    k_shape: IndexList[rank],
    v_shape: IndexList[rank],
    mask_shape: IndexList[mask_rank],
    output: LayoutTensor[dtype, output.layout, output_origin, element_layout=output.element_layout, layout_int_type=output.layout_int_type, linear_idx_type=output.linear_idx_type, masked=output.masked, alignment=output.alignment],
    scale: Float32,
    sink_weights: OptionalReg[LayoutTensor[dtype, Layout.row_major(VariadicList(-1)), ImmutAnyOrigin]] = None,
)
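Judging from the signature, the kernel reads K, V, and the attention mask through the three `input_*_fn` parameters, which are capturing functions that return a SIMD vector of values for a given index; only `q` and `output` are passed as materialized `LayoutTensor`s, and `sink_weights` defaults to `None`. The sketch below shows one plausible way to wire this up. It is an illustration, not the library's documented usage: the import path for `flash_attention` is left as a placeholder, the flat row-major buffers `k_ptr`/`v_ptr`/`mask_ptr` and the wrapper `call_flash_attention` are hypothetical names, and the `@parameter` closures are the usual Mojo pattern for satisfying `capturing` function parameters. Verify the details against your Mojo/MAX version.

```mojo
from layout import Layout, LayoutTensor
from memory import UnsafePointer
from utils.index import IndexList

# from ... import flash_attention  # fill in the real module path

alias dtype = DType.float32
alias rank = 4       # e.g. (batch, seq, heads, head_dim) -- an assumption
alias mask_rank = 2  # e.g. (q_seq, k_seq) -- an assumption

fn call_flash_attention[
    q_layout: Layout, out_layout: Layout
](
    q: LayoutTensor[dtype, q_layout, ImmutAnyOrigin],
    output: LayoutTensor[dtype, out_layout, MutAnyOrigin],
    k_ptr: UnsafePointer[Scalar[dtype]],     # flat, contiguous row-major K
    v_ptr: UnsafePointer[Scalar[dtype]],     # flat, contiguous row-major V
    mask_ptr: UnsafePointer[Scalar[dtype]],  # flat, contiguous row-major mask
    k_shape: IndexList[rank],
    v_shape: IndexList[rank],
    mask_shape: IndexList[mask_rank],
    scale: Float32,
):
    # Each input_*_fn maps a multi-dimensional index to a SIMD vector, so
    # the kernel can pull K/V/mask tiles on demand instead of requiring
    # materialized tensors. Here we just flatten a row-major index into
    # the hypothetical flat buffers above.
    @parameter
    @always_inline
    fn load_k[simd_width: Int, _rank: Int](idx: IndexList[_rank]) -> SIMD[dtype, simd_width]:
        var offset = 0
        for i in range(_rank):  # _rank matches `rank` at the call site
            offset = offset * k_shape[i] + idx[i]
        return k_ptr.load[width=simd_width](offset)

    @parameter
    @always_inline
    fn load_v[simd_width: Int, _rank: Int](idx: IndexList[_rank]) -> SIMD[dtype, simd_width]:
        var offset = 0
        for i in range(_rank):
            offset = offset * v_shape[i] + idx[i]
        return v_ptr.load[width=simd_width](offset)

    @parameter
    @always_inline
    fn load_mask[simd_width: Int, _rank: Int](idx: IndexList[_rank]) -> SIMD[dtype, simd_width]:
        var offset = 0
        for i in range(_rank):  # _rank matches `mask_rank` at the call site
            offset = offset * mask_shape[i] + idx[i]
        return mask_ptr.load[width=simd_width](offset)

    # dtype/rank/mask_rank and the two origins sit before `//` in the
    # signature, so they are infer-only; only the three input functions
    # need to be passed explicitly, and `sink_weights` keeps its default.
    flash_attention[
        input_k_fn=load_k,
        input_v_fn=load_v,
        input_mask_fn=load_mask,
    ](q, k_shape, v_shape, mask_shape, output, scale)
```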
