vllm.model_executor.layers.quantization.fp_quant

FPQuantConfig

Bases: QuantizationConfig

Config class for FPQuant.

Source code in vllm/model_executor/layers/quantization/fp_quant.py
class FPQuantConfig(QuantizationConfig):
    """Config class for FPQuant."""

    def __init__(
        self,
        hadamard_group_size: int = 32,
        forward_dtype: str = "mxfp4",
        forward_method: str = "abs_max",
        pseudoquantization: bool = False,
        modules_to_not_convert: list[str] | None = None,
    ) -> None:
        super().__init__()
        self.hadamard_group_size = hadamard_group_size
        self.forward_dtype = forward_dtype
        self.forward_method = forward_method
        self.pseudoquantization = pseudoquantization
        self.modules_to_not_convert = modules_to_not_convert

        if pseudoquantization:
            raise ValueError("Pseudoquantization is not supported for vLLM")

    def __repr__(self) -> str:
        return (
            f"FPQuantConfig(hadamard_group_size={self.hadamard_group_size}, "
            f"forward_dtype={self.forward_dtype}, "
            f"forward_method={self.forward_method}, "
            f"pseudoquantization={self.pseudoquantization}, "
            f"modules_to_not_convert={self.modules_to_not_convert})"
        )

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        return "fp_quant"

    @classmethod
    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
        return [torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 100

    @classmethod
    def get_config_filenames(cls) -> list[str]:
        return []  # no extra configs.

    @classmethod
    def from_config(cls, config: dict[str, Any]) -> "FPQuantConfig":
        hadamard_group_size = cls.get_from_keys(config, ["hadamard_group_size"])
        forward_dtype = cls.get_from_keys(config, ["forward_dtype"])
        forward_method = cls.get_from_keys(config, ["forward_method"])
        pseudoquantization = cls.get_from_keys(config, ["pseudoquantization"])
        modules_to_not_convert = cls.get_from_keys(config, ["modules_to_not_convert"])
        return cls(
            hadamard_group_size,
            forward_dtype,
            forward_method,
            pseudoquantization,
            modules_to_not_convert,
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> LinearMethodBase | None:
        if self.modules_to_not_convert is not None and any(
            prefix.endswith(module) for module in self.modules_to_not_convert
        ):
            return UnquantizedLinearMethod()

        if isinstance(layer, LinearBase):
            return FPQuantLinearMethod(self)
        return None
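
The keys consumed by from_config above mirror the quantization_config block an FPQuant checkpoint is expected to carry. Below is a minimal sketch of building the config from such a dict; the values are illustrative rather than taken from a real checkpoint, and the "lm_head" entry only demonstrates how modules_to_not_convert feeds the prefix.endswith check in get_quant_method.

from vllm.model_executor.layers.quantization.fp_quant import FPQuantConfig

# Illustrative dict; every key below is fetched by from_config via get_from_keys.
config_dict = {
    "hadamard_group_size": 32,
    "forward_dtype": "mxfp4",
    "forward_method": "abs_max",
    "pseudoquantization": False,  # True raises ValueError in __init__
    "modules_to_not_convert": ["lm_head"],  # layers whose prefix ends with this stay unquantized
}

quant_config = FPQuantConfig.from_config(config_dict)
print(repr(quant_config))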

FPQuantLinearMethod

Bases: LinearMethodBase

Linear method for FPQuant.

Parameters:

    quant_config (FPQuantConfig): The FPQuant quantization config. Required.
Source code in vllm/model_executor/layers/quantization/fp_quant.py
class FPQuantLinearMethod(LinearMethodBase):
    """Linear method for FPQuant.

    Args:
        quant_config: The FPQuant quantization config.
    """

    def __init__(self, quant_config: FPQuantConfig):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: list[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        del output_size  # Unused.
        del input_size  # Unused.

        if params_dtype != torch.bfloat16:
            raise ValueError("Only bfloat16 is currently supported by FPQuant")
        if input_size_per_partition % self.quant_config.hadamard_group_size != 0:  # noqa: E501
            raise ValueError(
                "The input size is not aligned with the quantized "
                "weight shape. This can be caused by a tensor "
                "parallel size that is too large."
            )

        if self.quant_config.forward_dtype == "mxfp4":
            group_size = 32
        elif self.quant_config.forward_dtype == "nvfp4":
            group_size = 16
        else:
            raise ValueError(
                f"Unsupported forward_dtype: {self.quant_config.forward_dtype}; "
                "only mxfp4 and nvfp4 are supported for now"
            )

        qweight = Parameter(
            torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition // 2,
                dtype=torch.uint8,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            qweight,
            {
                "input_dim": 1,
                "output_dim": 0,
                "packed_dim": 1,
                "pack_factor": 2,
            }
            | extra_weight_attrs,
        )
        layer.register_parameter("qweight", qweight)

        scales = Parameter(
            torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition // group_size,
                dtype=torch.uint8,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            scales,
            {
                "input_dim": 1,
                "output_dim": 0,
                "packed_dim": 1,
                "pack_factor": group_size,
            }
            | extra_weight_attrs,
        )
        layer.register_parameter("scales", scales)

        weight_global_scale = Parameter(
            torch.empty(1, dtype=torch.float32),
            requires_grad=False,
        )
        set_weight_attrs(
            weight_global_scale, {"ignore_warning": True} | extra_weight_attrs
        )
        layer.register_parameter("weight_global_scale", weight_global_scale)

        act_global_scale = Parameter(
            torch.empty(1, dtype=torch.float32),
            requires_grad=False,
        )
        set_weight_attrs(
            act_global_scale, {"ignore_warning": True} | extra_weight_attrs
        )
        layer.register_parameter("act_global_scale", act_global_scale)

        forward_hadamard_matrix = Parameter(
            torch.empty(
                self.quant_config.hadamard_group_size,
                self.quant_config.hadamard_group_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            forward_hadamard_matrix, {"ignore_warning": True} | extra_weight_attrs
        )
        layer.register_parameter("forward_hadamard_matrix", forward_hadamard_matrix)

        backward_hadamard_matrix = Parameter(
            torch.empty(
                self.quant_config.hadamard_group_size,
                self.quant_config.hadamard_group_size,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(
            backward_hadamard_matrix, {"ignore_warning": True} | extra_weight_attrs
        )
        layer.register_parameter("backward_hadamard_matrix", backward_hadamard_matrix)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: torch.Tensor | None = None,
    ) -> torch.Tensor:
        return quantized_forward(
            x,
            layer.qweight,
            layer.scales,
            layer.weight_global_scale,
            layer.act_global_scale,
            bias,
            layer.forward_hadamard_matrix,
            self.quant_config.forward_method,
            self.quant_config.forward_dtype,
        )