vllm.model_executor.layers.quantization.utils.quant_utils

This file is used for /tests and /benchmarks

GroupShape

Bases: _GroupShape

This class describes the quantization group shape. It includes static members for common shapes (per-tensor, per-token).

Source code in vllm/model_executor/layers/quantization/utils/quant_utils.py
class GroupShape(_GroupShape):
    """
    This class describes the quantization group shape.
    It includes static members for common shapes (per-tensor, per-token).
    """

    # Aliases for common quantization group shapes
    PER_TENSOR: ClassVar["GroupShape"]
    PER_TOKEN: ClassVar["GroupShape"]
    PER_CHANNEL: ClassVar["GroupShape"]

    def is_per_tensor(self) -> bool:
        return self.row == -1 and self.col == -1

    def is_per_token(self) -> bool:
        return self.row == 1 and self.col == -1

    def is_per_channel(self) -> bool:
        return self.row == -1 and self.col == 1

    def is_per_group(self) -> bool:
        return self.row == 1 and self.col >= 1
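
A minimal usage sketch, assuming the GroupShape(row, col) constructor and the class-level aliases declared above:

from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

per_tensor = GroupShape.PER_TENSOR   # one scale for the whole tensor
per_token = GroupShape.PER_TOKEN     # one scale per row (token)
blockwise = GroupShape(1, 128)       # one scale per 128-element group along the last dim

assert per_tensor.is_per_tensor()
assert per_token.is_per_token()
assert blockwise.is_per_group()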

QuantKey dataclass

Class for identifying the type of quantization.

    dtype: quantized data type
    scale: scale descriptor
    scale2: second-level scale descriptor
    symmetric: symmetric if True, asymmetric if False

Source code in vllm/model_executor/layers/quantization/utils/quant_utils.py
@dataclass(frozen=True)
class QuantKey:
    """
    Class for identifying the type of quantization.
    dtype: quantized data type
    scale: scale descriptor
    scale2: second-level scale descriptor
    symmetric: symmetric if True, asymmetric if False
    """

    dtype: torch.dtype
    scale: ScaleDesc
    scale2: ScaleDesc | None = None
    symmetric: bool = True

    def __str__(self):
        scale2_str = f"scale2({self.scale2})," if self.scale2 else ""
        return (
            f"QuantKey({fx.graph.dtype_abbrs[self.dtype]},"
            f"scale({self.scale}),{scale2_str}"
            f"{'a' if not self.symmetric else ''}symmetric)"
        )

ScaleDesc dataclass

Class for describing a single quantization scaling factor.

    dtype: data type of the scale
    static: static scale if True, dynamic if False
    group_shape: group shape of the scale

Source code in vllm/model_executor/layers/quantization/utils/quant_utils.py
@dataclass(frozen=True)
class ScaleDesc:
    """
    Class for describing a single quantization scaling factor.
    dtype: data type of the scale
    static: static scale if True, dynamic if False
    group_shape: group shape of the scale
    """

    dtype: torch.dtype
    static: bool
    group_shape: GroupShape

    def __str__(self):
        d = {
            GroupShape.PER_TENSOR: "per_tensor",
            GroupShape.PER_TOKEN: "per_token",
            GroupShape.PER_CHANNEL: "per_channel",
        }
        group_shape = d.get(self.group_shape, str(self.group_shape))
        return (
            f"{fx.graph.dtype_abbrs[self.dtype]},"
            f"{'static' if self.static else 'dynamic'},{group_shape}"
        )
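
An illustrative sketch describing static, per-tensor, symmetric FP8 quantization with these two classes; the printed string in the comment is indicative, since the exact dtype abbreviations come from torch.fx:

import torch

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape, QuantKey, ScaleDesc)

scale = ScaleDesc(dtype=torch.float32, static=True, group_shape=GroupShape.PER_TENSOR)
key = QuantKey(dtype=torch.float8_e4m3fn, scale=scale, symmetric=True)
print(key)  # e.g. QuantKey(f8e4m3fn,scale(f32,static,per_tensor),symmetric)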

convert_bf16_scales_to_fp8

convert_bf16_scales_to_fp8(
    quant_fp8: Callable, scales: Tensor
) -> tuple[Tensor, Tensor]

Convert a BF16 scale tensor into the pair of (fp8_scales, channel_scales) expected by W4A8 GEMM kernels.

Source code in vllm/model_executor/layers/quantization/utils/quant_utils.py
def convert_bf16_scales_to_fp8(
    quant_fp8: Callable, scales: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Convert a BF16 scale tensor into the pair of (fp8_scales, channel_scales)
    expected by W4A8 GEMM kernels.
    """
    assert scales.is_contiguous(), (
        f"scale tensor must be contiguous, got {scales.stride()=}"
    )
    assert scales.is_cuda, "scales must be on gpu"

    orig_shape = scales.shape
    k_groups = orig_shape[-1]
    flat_scales = scales.view(-1, k_groups)

    fp8_scales, chan_scales = quant_fp8(flat_scales)
    fp8_scales = (fp8_scales.float() / 8.0).to(torch.float8_e4m3fn)
    chan_scales *= 8.0

    # restore original shape
    fp8_scales = fp8_scales.view(orig_shape)
    chan_scales = chan_scales.view(*orig_shape[:-1], -1)

    return fp8_scales, chan_scales
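
A hypothetical usage sketch on a CUDA device. toy_quant_fp8 below is a stand-in, not vllm's QuantFP8; it only illustrates the expected contract of a callable that returns FP8-quantized values plus one dynamic scale per row:

import torch

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    convert_bf16_scales_to_fp8)


def toy_quant_fp8(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # One dynamic scale per row, then cast to float8_e4m3fn.
    amax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12).float()
    row_scale = amax / torch.finfo(torch.float8_e4m3fn).max
    q = (x.float() / row_scale).to(torch.float8_e4m3fn)
    return q, row_scale


bf16_scales = torch.rand(64, 16, dtype=torch.bfloat16, device="cuda")
fp8_scales, chan_scales = convert_bf16_scales_to_fp8(toy_quant_fp8, bf16_scales)
# fp8_scales carry an extra 1/8 factor; chan_scales are scaled by 8 to compensate.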

convert_packed_uint4b8_to_signed_int4_inplace

convert_packed_uint4b8_to_signed_int4_inplace(
    t: Tensor,
) -> Tensor

Convert int4b8 (packed to int32) to signed int4

Source code in vllm/model_executor/layers/quantization/utils/quant_utils.py
def convert_packed_uint4b8_to_signed_int4_inplace(t: torch.Tensor) -> torch.Tensor:
    """
    Convert int4b8 (packed to int32) to signed int4
    """
    assert t.is_cuda, "tensor must be on gpu"
    assert t.dtype == torch.int32, f"expected int32 packed weights but got {t.dtype}"

    # loop through the 8 4-bit nibbles in each int32 entry
    for i in range(8):
        shift = 4 * i
        # extract the i-th 4-bit nibble
        nib = (t >> shift) & 0xF
        # clear the original nibble by masking out
        t &= ~(0xF << shift)
        # convert int4b8 [0..15] to signed int4 [-8..7] by subtracting 8
        # and update in-place
        t |= ((nib - 8) & 0xF) << shift

    return t
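
A small sketch (a CUDA device is required by the asserts above): pack eight int4 values with a +8 bias into a single int32 word, then undo the bias in place. The manual packing loop is only for illustration:

import torch

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    convert_packed_uint4b8_to_signed_int4_inplace)

values = [-8, -3, 0, 1, 2, 5, 7, -1]

word = 0
for i, v in enumerate(values):
    word |= ((v + 8) & 0xF) << (4 * i)  # int4b8 encoding: unsigned nibble = value + 8
if word >= 1 << 31:
    word -= 1 << 32  # wrap into the signed int32 range for tensor construction

packed = torch.tensor([word], dtype=torch.int32, device="cuda")
convert_packed_uint4b8_to_signed_int4_inplace(packed)

# Each nibble now holds the two's-complement encoding of the original signed value.
unpacked = [(packed.item() >> (4 * i)) & 0xF for i in range(8)]
assert unpacked == [v & 0xF for v in values]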

get_and_maybe_dequant_weights

get_and_maybe_dequant_weights(
    layer: LinearBase, out_dtype: dtype = float32
)

Return layer's unquantized weights in [out, in] layout

Source code in vllm/model_executor/layers/quantization/utils/quant_utils.py
def get_and_maybe_dequant_weights(
    layer: "LinearBase", out_dtype: torch.dtype = torch.float32
):
    """Return layer's unquantized weights in [out, in] layout"""
    from vllm.model_executor.layers.linear import UnquantizedLinearMethod
    from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod

    weight = get_attribute_fallback(layer, ["weight", "qweight", "weight_packed"])

    # Unquantized layer: just return base weights
    if layer.quant_method is None or isinstance(
        layer.quant_method, UnquantizedLinearMethod
    ):
        return weight.to(out_dtype)

    # Simple Fp8 case: rescale with tensor or block weight scales
    if (
        isinstance(layer.quant_method, Fp8LinearMethod)
        and not layer.quant_method.use_marlin
        # DeepGEMM transforms the scales using `transform_sf_into_required_layout` into
        # a layout that is not compatible with `scaled_dequantize`.
        and not layer.quant_method.use_deep_gemm
    ):
        weight_scales = get_attribute_fallback(
            layer, ["weight_scale", "weight_scale_inv"]
        )
        dequant_weights = scaled_dequantize(
            weight,
            weight_scales,
            group_shape=layer.weight_block_size,
            out_dtype=out_dtype,
        )
        # per-tensor scaling stores weights in [in, out] layout
        if not layer.quant_method.block_quant:
            dequant_weights = dequant_weights.T
        return dequant_weights

    # NOTE: Most generic base case
    # - Call the layer with identity matrix which returns unquantized weights.
    # - Must be used with extra care when dealing with static activation quantization:
    #   quantizing 1.0 may lead to over/underflows
    # - Should only be used offline, since it's O(N^3)
    assert hasattr(layer, "input_size_per_partition")
    eye = torch.eye(
        layer.input_size_per_partition,
        dtype=out_dtype,
        device=weight.device,
    )
    dequant_weights = layer.quant_method.apply(layer, eye, bias=None).to(out_dtype)
    return dequant_weights.T
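
The generic fallback relies on the fact that applying a linear layer to the identity matrix recovers its weight. A plain-torch illustration of that idea (it does not call the function itself, which expects a vLLM LinearBase layer):

import torch

in_features, out_features = 8, 4
weight = torch.randn(out_features, in_features)  # [out, in] layout
eye = torch.eye(in_features)

# A linear layer computes x @ W.T, so feeding the identity yields W.T.
recovered = (eye @ weight.T).T
assert torch.allclose(recovered, weight)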

get_fp8_min_max

get_fp8_min_max() -> tuple[float, float]

Get the min and max values for FP8 quantization.

Source code in vllm/model_executor/layers/quantization/utils/quant_utils.py
def get_fp8_min_max() -> tuple[float, float]:
    """Get the min and max values for FP8 quantization."""
    # Using the default value (240.0) from pytorch will cause accuracy
    # issue on dynamic quantization models on ROCm. Here, use 224.0 for fnuz
    # on ROCm platforms that use the torch.float8_e4m3fnuz dtype.
    if current_platform.is_fp8_fnuz():
        return -224.0, 224.0
    finfo = torch.finfo(current_platform.fp8_dtype())
    return finfo.min, finfo.max
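
A brief usage sketch: clamp values into the platform's representable FP8 range before casting:

import torch

from vllm.model_executor.layers.quantization.utils.quant_utils import get_fp8_min_max

fp8_min, fp8_max = get_fp8_min_max()
x = torch.randn(8, dtype=torch.float32) * 1000.0
x_clamped = x.clamp(min=fp8_min, max=fp8_max)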

prep_scale_for_group_broadcast

prep_scale_for_group_broadcast(
    scale: Tensor, x: Tensor, group_shape: GroupShape | None
) -> Tensor

Prepare the input quantization scale for group broadcasting.

Parameters:

    scale (Tensor): The scale tensor (scalar or 1D). Required.
    x (Tensor): Target tensor whose shape determines broadcast dimensions. Required.
    group_shape (GroupShape | None): GroupShape to broadcast over. Required.

Returns:

    Tensor: scale reshaped for correct broadcasting.

Source code in vllm/model_executor/layers/quantization/utils/quant_utils.py
def prep_scale_for_group_broadcast(
    scale: torch.Tensor,
    x: torch.Tensor,
    group_shape: GroupShape | None,
) -> torch.Tensor:
    """
    Prepare the input quantization scale for group broadcasting.

    Args:
        scale: The scale tensor (scalar or 1D).
        x: Target tensor whose shape determines broadcast dimensions.
        group_shape: GroupShape to broadcast over.

    Returns:
        scale reshaped for correct broadcasting.
    """
    if scale.numel() == 1:
        # For per-tensor quant, keep the scale as a scalar (not reshaped to (1, 1)).
        # This avoids misclassifying it as channelwise quant in Fp8LinearOp.apply,
        # where the "per_tensor_activations" check relies on "x_scale.dim() < 2":
        #   per_tensor_activations = (x_scale.numel() == 1) and x_scale.dim() < 2
        # For all other cases, reshape scalar scales to (1, 1) for broadcasting.
        return (
            scale
            if group_shape is not None and group_shape.is_per_tensor()
            else scale.reshape(1, 1)
        )
    if scale.ndim == 1:
        assert group_shape is not None, (
            "group_shape must be provided to correctly broadcast 1D scale"
        )
        rows, cols = _normalize_quant_group_shape(x, group_shape)
        # Determine broadcasting dimension: either rows or columns match group size
        if rows == x.shape[-2]:
            scale = scale.unsqueeze(-2)
        elif cols == x.shape[-1]:
            scale = scale.unsqueeze(-1)
        else:
            raise ValueError(
                f"1D scale with shape {scale.shape} cannot be broadcast to x with shape"
                f" {x.shape}, group_shape={(rows, cols)}"
            )
    return scale
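
A minimal sketch broadcasting a per-token (1D) scale against a 2D activation tensor:

import torch

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape, prep_scale_for_group_broadcast)

x = torch.randn(16, 128)           # [tokens, hidden]
per_token_scale = torch.rand(16)   # one scale per token (row)

s = prep_scale_for_group_broadcast(per_token_scale, x, GroupShape.PER_TOKEN)
assert s.shape == (16, 1)
dequantized = x * s                # broadcasts across the hidden dimension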

scaled_quantize

scaled_quantize(
    x: Tensor,
    group_shape: GroupShape,
    quant_dtype: dtype,
    compute_dtype: dtype | None = None,
) -> tuple[Tensor, Tensor]

Parameters:

    x (Tensor): Input tensor to quantize. Required.
    group_shape (GroupShape): Shape of quantization groups. Required.
    quant_dtype (dtype): Target quantized dtype (e.g., torch.float8_e4m3fn). Required.
    compute_dtype (dtype | None): Optional dtype for intermediate computations.
        If None, uses input dtype. Use torch.float32 for higher precision. Default: None.

Source code in vllm/model_executor/layers/quantization/utils/quant_utils.py
def scaled_quantize(
    x: torch.Tensor,
    group_shape: GroupShape,
    quant_dtype: torch.dtype,
    compute_dtype: torch.dtype | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Args:
        x: Input tensor to quantize
        group_shape: Shape of quantization groups
        quant_dtype: Target quantized dtype (e.g., torch.float8_e4m3fn)
        compute_dtype: Optional dtype for intermediate computations.
            If None, uses input dtype. Use torch.float32 for higher precision.
    """
    group_shape = _normalize_quant_group_shape(x, group_shape)
    assert quant_dtype.is_floating_point, (
        "currently `scaled_quantize` only supports floating point dtypes "
        "but could be extended to support other dtypes"
    )

    finfo = torch.finfo(quant_dtype)

    # Convert to compute dtype if specified
    x_compute = x if compute_dtype is None else x.to(compute_dtype)

    # Reshape (M, N) into (BLK_M, BLOCK_SIZE_M, BLK_N, BLOCK_SIZE_N)
    assert x.ndim == 2
    assert x.shape[0] % group_shape[0] == 0 and x.shape[1] % group_shape[1] == 0
    blk_m, blk_n = x.shape[0] // group_shape[0], x.shape[1] // group_shape[1]
    x_blkd = x_compute.reshape(blk_m, group_shape[0], blk_n, group_shape[1])

    # Permute to (BLK_M, BLK_N, BLOCK_SIZE_M, BLOCK_SIZE_N)
    x_blkd_permd = x_blkd.permute(0, 2, 1, 3)
    # Flatten to (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N)
    x_blkd_permd = x_blkd_permd.flatten(start_dim=2)

    # Compute scales
    min_val, max_val = x_blkd_permd.aminmax(dim=-1)
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    _, fp8_max = get_fp8_min_max()
    scale = fp8_max / amax

    # Apply scale and convert from:
    # (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) to (M, N)
    x_scl_sat = (
        (x_blkd_permd * scale.unsqueeze(-1))
        .clamp(min=finfo.min, max=finfo.max)
        .reshape(blk_m, blk_n, group_shape[0], group_shape[1])
        .permute(0, 2, 1, 3)
        .reshape(x.shape)
    )

    return x_scl_sat.to(quant_dtype).contiguous(), scale.float().reciprocal()
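
A sketch quantizing a BF16 tensor to FP8 with 128x128 groups and then roughly inverting it with the returned reciprocal scales; the block shapes assume a 256x512 input:

import torch

from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape, scaled_quantize)

x = torch.randn(256, 512, dtype=torch.bfloat16)
x_q, scales = scaled_quantize(
    x, GroupShape(128, 128), torch.float8_e4m3fn, compute_dtype=torch.float32
)

assert x_q.dtype == torch.float8_e4m3fn
assert scales.shape == (2, 4)  # one scale per 128x128 block

# Approximate dequantization: broadcast each block's scale back over its elements.
x_deq = (
    x_q.float().reshape(2, 128, 4, 128) * scales.view(2, 1, 4, 1)
).reshape(256, 512)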