Python class

MHAMaskVariant

class max.nn.attention.MHAMaskVariant(value, names=<not given>, *values, module=None, qualname=None, type=None, start=1, boundary=None)

Bases: str, Enum

Defines the mask variant codes used by multi-head attention kernels. Each member's value is an integer code stored as a string (for example, CAUSAL_MASK is '0').

CAUSAL_MASK

CAUSAL_MASK = '0'

CHUNKED_CAUSAL_MASK

CHUNKED_CAUSAL_MASK = '3'

NULL_MASK

NULL_MASK = '2'

SLIDING_WINDOW_CAUSAL_MASK

SLIDING_WINDOW_CAUSAL_MASK = '4'

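
The members listed above can be exercised with a small sketch. To stay runnable without the max package installed, it defines a local stand-in enum that copies the documented bases (str, Enum) and member values; the stand-in is an assumption for illustration, and in real code the class would presumably be imported from max.nn.attention instead.

```python
from enum import Enum


class MHAMaskVariant(str, Enum):
    """Local stand-in mirroring the documented members (not the max import)."""

    CAUSAL_MASK = "0"
    CHUNKED_CAUSAL_MASK = "3"
    NULL_MASK = "2"
    SLIDING_WINDOW_CAUSAL_MASK = "4"


# Look up a member from its string code.
variant = MHAMaskVariant("4")

# Because the enum also subclasses str, a member compares equal to its raw code.
is_raw_equal = MHAMaskVariant.CAUSAL_MASK == "0"

# The values are strings, so convert explicitly where an integer code is needed.
code = int(MHAMaskVariant.CHUNKED_CAUSAL_MASK.value)
```

The explicit int(...) conversion matters because the codes are stored as strings: comparing a member against the integer 0 rather than the string '0' evaluates to False.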