Mojo function
apply_packed_bitmask
def apply_packed_bitmask[dtype: DType, //, target: StringSlice[StaticConstantOrigin]](output: TileTensor[dtype, Storage=output.Storage, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size], logits: TileTensor[dtype, Storage=logits.Storage, address_space=logits.address_space, linear_idx_type=logits.linear_idx_type, element_size=logits.element_size], packed: TileTensor[DType.int32, Storage=packed.Storage, address_space=packed.address_space, linear_idx_type=packed.linear_idx_type, element_size=packed.element_size], fill_value: Scalar[dtype], ctx: DeviceContext)
Apply a packed-int32 grammar bitmask to logits in a single fused pass.
Unpacks a packed bitmask (1 bit per token, 32 tokens per int32 word) and
masks logits with it without ever materializing a bool tensor. For each
(b, v), output[b, v] = logits[b, v] when bit v % 32 of word
packed[b, v // 32] is set; otherwise output[b, v] = fill_value (the
masked-out sentinel, e.g. a large negative number). This replaces a CPU
unpack followed by ops.where in constrained decoding.
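The masking rule can be checked against a small CPU reference. The sketch below is in plain Python, with nested lists standing in for the [batch, vocab] and [batch, ceil(vocab / 32)] tensors. `apply_packed_bitmask_reference` is a hypothetical name for illustration only. It is not part of MAX or Mojo, and it is not the kernel's actual implementation.

```python
def apply_packed_bitmask_reference(logits, packed, fill_value):
    """Hypothetical CPU reference for the packed-bitmask masking rule.

    logits: list of rows of floats, shape [batch, vocab].
    packed: list of rows of int32 words, shape [batch, ceil(vocab / 32)].
    """
    output = []
    for b, row in enumerate(logits):
        out_row = []
        for v, logit in enumerate(row):
            word = packed[b][v // 32]
            # Keep the logit when bit (v % 32) of the word is set. Python's >>
            # is arithmetic on negative ints, so this also works when bit 31
            # is set and the int32 word is therefore negative.
            if (word >> (v % 32)) & 1:
                out_row.append(logit)
            else:
                out_row.append(fill_value)
        output.append(out_row)
    return output


# Bits 1 and 3 are set, so only tokens 1 and 3 keep their logits.
print(apply_packed_bitmask_reference([[0.5, 1.0, -2.0, 3.0]], [[0b1010]], -1e9))
```

Because the words are signed int32, a word whose bit 31 is set is negative. The bit test must therefore treat the word as a 32-bit pattern rather than as a magnitude.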
Args:
- output (TileTensor[dtype, Storage=output.Storage, address_space=output.address_space, linear_idx_type=output.linear_idx_type, element_size=output.element_size]): Masked logits, shape [batch, vocab].
- logits (TileTensor[dtype, Storage=logits.Storage, address_space=logits.address_space, linear_idx_type=logits.linear_idx_type, element_size=logits.element_size]): Input logits, shape [batch, vocab].
- packed (TileTensor[DType.int32, Storage=packed.Storage, address_space=packed.address_space, linear_idx_type=packed.linear_idx_type, element_size=packed.element_size]): Packed int32 bitmask, shape [batch, ceil(vocab / 32)]. A set bit means the token is grammar-valid. Extra trailing bits beyond vocab (32-bit alignment padding from llguidance) are never read.
- fill_value (Scalar[dtype]): Value written for masked-out (grammar-invalid) tokens.
- ctx (DeviceContext): The device context.
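To illustrate the packed layout expected for packed, here is a sketch that builds it from a per-token bool mask. It is plain Python, and `pack_bitmask` is a hypothetical helper, not a MAX, Mojo, or llguidance API. In practice the packed words come from the grammar engine. This sketch only shows the bit and word arithmetic.

```python
def pack_bitmask(allowed):
    """Hypothetical helper: pack a [batch, vocab] bool mask into int32 words.

    Bit (v % 32) of word (v // 32) is set when token v is allowed. Trailing
    bits past vocab in the last word are left as 0; the kernel never reads them.
    """
    packed = []
    for row in allowed:
        num_words = (len(row) + 31) // 32  # ceil(vocab / 32)
        words = [0] * num_words
        for v, ok in enumerate(row):
            if ok:
                words[v // 32] |= 1 << (v % 32)
        # Reinterpret each 32-bit pattern as a signed int32 (bit 31 -> negative).
        packed.append([w - (1 << 32) if w >= (1 << 31) else w for w in words])
    return packed


print(pack_bitmask([[False, True, False, True]]))  # tokens 1 and 3 allowed
```

For vocab = 33, each row needs ceil(33 / 32) = 2 words. Only bit 0 of the second word is meaningful, and the other 31 bits are alignment padding.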