Python class
MHAMaskVariant
class max.nn.attention.MHAMaskVariant(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)
Defines the mask variant codes used by multihead attention kernels. Each member's value is an integer code stored as a string (for example, `'0'` for `CAUSAL_MASK`).
CAUSAL_MASK
CAUSAL_MASK = '0'
CHUNKED_CAUSAL_MASK
CHUNKED_CAUSAL_MASK = '3'
NULL_MASK
NULL_MASK = '2'
SLIDING_WINDOW_CAUSAL_MASK
SLIDING_WINDOW_CAUSAL_MASK = '4'
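As a rough illustration of how these values behave, the sketch below uses a local stand-in enum that mirrors the four members and string values listed above. `MaskVariantSketch` and `variant_code` are hypothetical names introduced only for this example. They are not part of `max.nn.attention`, and the real class may be defined differently.

```python
from enum import Enum


class MaskVariantSketch(Enum):
    """Hypothetical stand-in that mirrors the documented member values."""

    CAUSAL_MASK = "0"
    NULL_MASK = "2"
    CHUNKED_CAUSAL_MASK = "3"
    SLIDING_WINDOW_CAUSAL_MASK = "4"


def variant_code(variant: MaskVariantSketch) -> int:
    """Convert a member's string value into its integer code."""
    return int(variant.value)


# Look a member up by its string value, as with any standard Enum.
sliding = MaskVariantSketch("4")

# Build a name-to-code table for every member.
codes = {variant.name: variant_code(variant) for variant in MaskVariantSketch}
```

Because the values are strings, a lookup by value needs `"4"`, not `4`. Converting with `int()` gives the numeric code when one is needed.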