Mojo function

mha_decoding_single_batch_amd

mha_decoding_single_batch_amd[output_type: DType, q_type: DType, k_t: MHAOperand, v_t: MHAOperand, mask_t: MHAMask, group: Int, config: MHAConfig](output: UnsafePointer[SIMD[output_type, 1]], q: UnsafePointer[SIMD[q_type, 1]], k: k_t, v: v_t, exp_sum_ptr: UnsafePointer[SIMD[get_accum_type[::DType,::DType](), 1]], qk_max_ptr: UnsafePointer[SIMD[get_accum_type[::DType,::DType](), 1]], seq_len: Int, num_keys: Int, num_partitions: Int, scale: SIMD[float32, 1], batch_idx: Int, start_pos: Int, mask: mask_t)