vllm.utils.deep_gemm

Compatibility wrapper for DeepGEMM API changes.

Users of vLLM should always import only these wrappers.

DeepGemmQuantScaleFMT

Bases: Enum

Source code in vllm/utils/deep_gemm.py
class DeepGemmQuantScaleFMT(Enum):
    # Float32 scales in Float32 tensor
    FLOAT32 = 0
    # Compute float32 scales and ceil the scales to UE8M0.
    # Keep the scales in Float32 tensor.
    FLOAT32_CEIL_UE8M0 = 1
    # Compute float32 scales and ceil the scales to UE8M0.
    # Pack the scales into a int32 tensor where each int32
    # element contains 4 scale values.
    UE8M0 = 2

    @classmethod
    def init_oracle_cache(cls) -> None:
        """Initialize the oracle decision and store it in the class cache"""
        cached = getattr(cls, "_oracle_cache", None)
        if cached is not None:
            return

        use_e8m0 = (
            envs.VLLM_USE_DEEP_GEMM_E8M0
            and is_deep_gemm_supported()
            and (_fp8_gemm_nt_impl is not None)
        )
        if not use_e8m0:
            cls._oracle_cache = cls.FLOAT32  # type: ignore
            return

        cls._oracle_cache = (  # type: ignore
            cls.UE8M0
            if current_platform.is_device_capability_family(100)
            else cls.FLOAT32_CEIL_UE8M0
        )

    @classmethod
    def from_oracle(cls) -> "DeepGemmQuantScaleFMT":
        """Return the pre-initialized oracle decision"""
        cached = getattr(cls, "_oracle_cache", None)
        assert cached is not None, "DeepGemmQuantScaleFMT oracle cache not initialized"
        return cached

from_oracle classmethod

from_oracle() -> DeepGemmQuantScaleFMT

Return the pre-initialized oracle decision

Source code in vllm/utils/deep_gemm.py
@classmethod
def from_oracle(cls) -> "DeepGemmQuantScaleFMT":
    """Return the pre-initialized oracle decision"""
    cached = getattr(cls, "_oracle_cache", None)
    assert cached is not None, "DeepGemmQuantScaleFMT oracle cache not initialized"
    return cached

init_oracle_cache classmethod

init_oracle_cache() -> None

Initialize the oracle decision and store it in the class cache

Source code in vllm/utils/deep_gemm.py
@classmethod
def init_oracle_cache(cls) -> None:
    """Initialize the oracle decision and store it in the class cache"""
    cached = getattr(cls, "_oracle_cache", None)
    if cached is not None:
        return

    use_e8m0 = (
        envs.VLLM_USE_DEEP_GEMM_E8M0
        and is_deep_gemm_supported()
        and (_fp8_gemm_nt_impl is not None)
    )
    if not use_e8m0:
        cls._oracle_cache = cls.FLOAT32  # type: ignore
        return

    cls._oracle_cache = (  # type: ignore
        cls.UE8M0
        if current_platform.is_device_capability_family(100)
        else cls.FLOAT32_CEIL_UE8M0
    )
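
A minimal usage sketch, assuming DeepGEMM and a supported GPU are available (the branch bodies are placeholders): initialize the oracle once, then read the cached decision wherever the quantization path needs to pick a scale layout.

```python
from vllm.utils.deep_gemm import DeepGemmQuantScaleFMT

# Normally populated as part of _lazy_init(); calling it again is a no-op.
DeepGemmQuantScaleFMT.init_oracle_cache()

fmt = DeepGemmQuantScaleFMT.from_oracle()
if fmt is DeepGemmQuantScaleFMT.UE8M0:
    # Blackwell family (compute capability 10.x): scales packed 4-per-int32.
    ...
elif fmt is DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0:
    # E8M0 requested on other supported GPUs: float32 scales, ceiled to UE8M0.
    ...
else:
    # Plain float32 scales (E8M0 disabled or DeepGEMM unavailable).
    ...
```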

_lazy_init

_lazy_init() -> None

Import deep_gemm and resolve symbols on first use.

Source code in vllm/utils/deep_gemm.py
def _lazy_init() -> None:
    """Import deep_gemm and resolve symbols on first use."""
    global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl
    global _fp8_mqa_logits_impl, _fp8_paged_mqa_logits_impl
    global _get_paged_mqa_logits_metadata_impl
    global _get_mn_major_tma_aligned_tensor_impl
    global _get_mk_alignment_for_contiguous_layout_impl
    global _transform_sf_into_required_layout_impl
    # fast path
    if (
        _fp8_gemm_nt_impl is not None
        or _grouped_impl is not None
        or _grouped_masked_impl is not None
        or _fp8_mqa_logits_impl is not None
        or _fp8_paged_mqa_logits_impl is not None
        or _get_paged_mqa_logits_metadata_impl is not None
        or _get_mk_alignment_for_contiguous_layout_impl is not None
        or _transform_sf_into_required_layout_impl is not None
    ):
        return

    if not has_deep_gemm():
        return

    # Set up deep_gemm cache path
    DEEP_GEMM_JIT_CACHE_ENV_NAME = "DG_JIT_CACHE_DIR"
    if not os.environ.get(DEEP_GEMM_JIT_CACHE_ENV_NAME, None):
        os.environ[DEEP_GEMM_JIT_CACHE_ENV_NAME] = os.path.join(
            envs.VLLM_CACHE_ROOT, "deep_gemm"
        )

    _dg = importlib.import_module("deep_gemm")

    _fp8_gemm_nt_impl = getattr(_dg, "fp8_gemm_nt", None)
    _grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None)
    _grouped_masked_impl = getattr(_dg, "fp8_m_grouped_gemm_nt_masked", None)
    _fp8_mqa_logits_impl = getattr(_dg, "fp8_mqa_logits", None)
    _fp8_paged_mqa_logits_impl = getattr(_dg, "fp8_paged_mqa_logits", None)
    _get_paged_mqa_logits_metadata_impl = getattr(
        _dg, "get_paged_mqa_logits_metadata", None
    )
    _get_mn_major_tma_aligned_tensor_impl = getattr(
        _dg, "get_mn_major_tma_aligned_tensor", None
    )
    _get_mk_alignment_for_contiguous_layout_impl = getattr(
        _dg, "get_mk_alignment_for_contiguous_layout", None
    )
    _transform_sf_into_required_layout_impl = getattr(
        _dg, "transform_sf_into_required_layout", None
    )
    DeepGemmQuantScaleFMT.init_oracle_cache()

_missing

_missing(*_: Any, **__: Any) -> NoReturn

Placeholder for unavailable DeepGEMM backend.

Source code in vllm/utils/deep_gemm.py
def _missing(*_: Any, **__: Any) -> NoReturn:
    """Placeholder for unavailable DeepGEMM backend."""
    raise RuntimeError(
        "DeepGEMM backend is not available or outdated. Please install or "
        "update the `deep_gemm` to a newer version to enable FP8 kernels."
    )
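
Every public wrapper in this module follows the same lazy-resolution pattern. The sketch below uses a hypothetical example_wrapper name to show the shape of that pattern rather than any specific DeepGEMM call.

```python
def example_wrapper(*args, **kwargs):
    _lazy_init()                    # import deep_gemm and bind impls on first use
    if _fp8_gemm_nt_impl is None:   # DeepGEMM missing or too old
        return _missing()           # raises RuntimeError with install guidance
    return _fp8_gemm_nt_impl(*args, **kwargs)
```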

calc_diff

calc_diff(x: Tensor, y: Tensor)

Return a global difference metric for unit tests.

DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable per-element error, causing torch.testing.assert_close to fail. Instead of checking every element, we compute a cosine-style similarity over the whole tensor and report 1 - sim. Once kernel accuracy improves this helper can be removed.

Source code in vllm/utils/deep_gemm.py
def calc_diff(x: torch.Tensor, y: torch.Tensor):
    """Return a global difference metric for unit tests.

    DeepGEMM kernels on Blackwell/B200 currently exhibit noticeable per-element
    error, causing `torch.testing.assert_close` to fail.  Instead of checking
    every element, we compute a cosine-style similarity over the whole tensor
    and report `1 - sim`.  Once kernel accuracy improves this helper can be
    removed.
    """

    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim
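
A small usage sketch for tests (tensor sizes and the tolerance are illustrative): calc_diff returns 0 for identical tensors and approaches 1 as they decorrelate, so a test asserts the metric stays below a small threshold instead of comparing element-wise.

```python
import torch

from vllm.utils.deep_gemm import calc_diff

ref = torch.randn(128, 256, dtype=torch.float32)
out = ref + 1e-3 * torch.randn_like(ref)  # stand-in for a DeepGEMM kernel output

# Global similarity check instead of torch.testing.assert_close.
assert calc_diff(out, ref) < 1e-3
```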

fp8_mqa_logits

fp8_mqa_logits(
    q: Tensor,
    kv: tuple[Tensor, Tensor],
    weights: Tensor,
    cu_seqlen_ks: Tensor,
    cu_seqlen_ke: Tensor,
    clean_logits: bool,
) -> Tensor

Compute FP8 MQA logits for a single sequence without KV paging.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| q | Tensor | Query tensor of shape [M, H, D]. Cast to torch.float8_e4m3fn by the caller. | required |
| kv | tuple[Tensor, Tensor] | Tuple (k_fp8, k_scales) where k_fp8 has shape [N, D] with dtype torch.float8_e4m3fn and k_scales has shape [N] with dtype torch.float32. | required |
| weights | Tensor | Weights of shape [M, H], dtype torch.float32. | required |
| cu_seqlen_ks | Tensor | Start indices (inclusive) for valid K per query position, shape [M], dtype int32. | required |
| cu_seqlen_ke | Tensor | End indices (exclusive) for valid K per query position, shape [M], dtype int32. | required |
| clean_logits | bool | Whether to set unfilled logits to -inf. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Logits tensor of shape [M, N], dtype torch.float32. |

Source code in vllm/utils/deep_gemm.py
def fp8_mqa_logits(
    q: torch.Tensor,
    kv: tuple[torch.Tensor, torch.Tensor],
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
    clean_logits: bool,
) -> torch.Tensor:
    """Compute FP8 MQA logits for a single sequence without KV paging.

    Args:
        q: Query tensor of shape [M, H, D]. Casted to
            `torch.float8_e4m3fn` by caller.
        kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with
            dtype `torch.float8_e4m3fn` and `k_scales` has shape [N])
            with dtype `torch.float32`.
        weights: weights of shape [M, H], dtype `torch.float32`.
        cu_seqlen_ks: Start indices (inclusive) for valid K per query position,
            shape [M], dtype int32.
        cu_seqlen_ke: End indices (exclusive) for valid K per query position,
            shape [M], dtype int32.
        clean_logits: Whether to clean the unfilled logits into `-inf`.

    Returns:
        Logits tensor of shape [M, N], dtype `torch.float32`.
    """
    _lazy_init()
    if _fp8_mqa_logits_impl is None:
        return _missing()
    return _fp8_mqa_logits_impl(
        q, kv, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits=clean_logits
    )
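
A shape-only sketch of how a caller prepares these inputs. The sizes, the all-ones scales/weights, and the trivial attention window are illustrative, and running it requires a Hopper/Blackwell GPU with DeepGEMM installed.

```python
import torch

from vllm.utils.deep_gemm import fp8_mqa_logits

M, H, D, N = 4, 16, 128, 64  # illustrative sizes

q = torch.randn(M, H, D, device="cuda").to(torch.float8_e4m3fn)
k_fp8 = torch.randn(N, D, device="cuda").to(torch.float8_e4m3fn)
k_scales = torch.ones(N, device="cuda", dtype=torch.float32)
weights = torch.ones(M, H, device="cuda", dtype=torch.float32)

# Query position m attends to key positions [cu_seqlen_ks[m], cu_seqlen_ke[m]).
cu_seqlen_ks = torch.zeros(M, dtype=torch.int32, device="cuda")
cu_seqlen_ke = torch.full((M,), N, dtype=torch.int32, device="cuda")

logits = fp8_mqa_logits(
    q, (k_fp8, k_scales), weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits=True
)
assert logits.shape == (M, N) and logits.dtype == torch.float32
```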

fp8_paged_mqa_logits

fp8_paged_mqa_logits(
    q_fp8: Tensor,
    kv_cache_fp8: Tensor,
    weights: Tensor,
    context_lens: Tensor,
    block_tables: Tensor,
    schedule_metadata: Tensor,
    max_model_len: int,
    clean_logits: bool,
) -> Tensor

Compute FP8 MQA logits using paged KV-cache.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| q_fp8 | Tensor | Query tensor of shape [B, next_n, H, D]. Cast to torch.float8_e4m3fn by the caller. | required |
| kv_cache_fp8 | Tensor | Paged KV-cache in packed FP8+scale layout with shape [num_blocks, block_size, 1, D+4], dtype torch.uint8. The last 4 bytes per (block, position) store the float32 dequant scale. | required |
| weights | Tensor | Tensor of shape [B * next_n, H], dtype torch.float32. | required |
| context_lens | Tensor | Tensor of shape [B], dtype int32; effective context length for each batch element. | required |
| block_tables | Tensor | Tensor of shape [B, max_blocks], dtype int32; maps logical block indices to physical blocks in the paged cache. | required |
| schedule_metadata | Tensor | Returned by get_paged_mqa_logits_metadata; used to distribute work across SMs. | required |
| max_model_len | int | Maximum sequence length used to size the logits output. | required |
| clean_logits | bool | Whether to set unfilled logits to -inf. | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Logits tensor of shape [B * next_n, max_model_len], dtype torch.float32. |

Source code in vllm/utils/deep_gemm.py
def fp8_paged_mqa_logits(
    q_fp8: torch.Tensor,
    kv_cache_fp8: torch.Tensor,
    weights: torch.Tensor,
    context_lens: torch.Tensor,
    block_tables: torch.Tensor,
    schedule_metadata: torch.Tensor,
    max_model_len: int,
    clean_logits: bool,
) -> torch.Tensor:
    """Compute FP8 MQA logits using paged KV-cache.

    Args:
        q_fp8: Query tensor of shape [B, next_n, H, D]. Casted to
            `torch.float8_e4m3fn` by caller.
        kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape
            [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last
            4 bytes per (block,pos) store the `float` dequant scale.
        weights: Tensor of shape [B * next_n, H], dtype `torch.float32`.
        context_lens: Tensor of shape [B], dtype int32; effective context length
            for each batch element.
        block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical
            block indices to physical blocks in the paged cache.
        schedule_metadata: Returned by `get_paged_mqa_logits_metadata`;
            used to distribute work across SMs.
        max_model_len: Maximum sequence length used to size the logits output.
        clean_logits: Whether to clean the unfilled logits into `-inf`.

    Returns:
        Logits tensor of shape [B * next_n, max_model_len], dtype
        `torch.float32`.
    """
    _lazy_init()
    if _fp8_paged_mqa_logits_impl is None:
        return _missing()
    return _fp8_paged_mqa_logits_impl(
        q_fp8,
        kv_cache_fp8,
        weights,
        context_lens,
        block_tables,
        schedule_metadata,
        max_model_len,
        clean_logits=clean_logits,
    )
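
A hedged sketch of the packed [num_blocks, block_size, 1, D+4] layout described above: D FP8 payload bytes per position followed by the 4 bytes of the float32 dequant scale. The per-position max-abs scaling is an assumption used only for illustration.

```python
import torch

num_blocks, block_size, D = 8, 64, 128  # illustrative sizes

k = torch.randn(num_blocks, block_size, 1, D, device="cuda")
# Per-position dequant scale; 448 is the max representable value of float8_e4m3fn.
scales = k.abs().amax(dim=-1, keepdim=True).clamp(min=1e-4) / 448.0
k_fp8 = (k / scales).to(torch.float8_e4m3fn)

kv_cache_fp8 = torch.empty(
    num_blocks, block_size, 1, D + 4, dtype=torch.uint8, device="cuda"
)
kv_cache_fp8[..., :D] = k_fp8.view(torch.uint8)           # D FP8 payload bytes
kv_cache_fp8[..., D:] = scales.float().view(torch.uint8)  # 4 scale bytes (float32)
```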

get_col_major_tma_aligned_tensor

get_col_major_tma_aligned_tensor(x: Tensor) -> Tensor

Wrapper for DeepGEMM's get_mn_major_tma_aligned_tensor

Source code in vllm/utils/deep_gemm.py
def get_col_major_tma_aligned_tensor(x: torch.Tensor) -> torch.Tensor:
    """Wrapper for DeepGEMM's get_mn_major_tma_aligned_tensor"""
    _lazy_init()
    if _get_mn_major_tma_aligned_tensor_impl is None:
        return _missing()
    return _get_mn_major_tma_aligned_tensor_impl(x)

get_paged_mqa_logits_metadata

get_paged_mqa_logits_metadata(
    context_lens: Tensor, block_size: int, num_sms: int
) -> Tensor

Build scheduling metadata for paged MQA logits.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| context_lens | Tensor | Tensor of shape [B], dtype int32; effective context length per batch element. | required |
| block_size | int | KV-cache block size in tokens (e.g., 64). | required |
| num_sms | int | Number of SMs available (e.g., 132 on Hopper). | required |

Returns:

| Type | Description |
| --- | --- |
| Tensor | Backend-specific tensor consumed by fp8_paged_mqa_logits to schedule work across SMs. |

Source code in vllm/utils/deep_gemm.py
def get_paged_mqa_logits_metadata(
    context_lens: torch.Tensor, block_size: int, num_sms: int
) -> torch.Tensor:
    """Build scheduling metadata for paged MQA logits.

    Args:
        context_lens: Tensor of shape [B], dtype int32; effective context length
            per batch element.
        block_size: KV-cache block size in tokens (e.g., 64).
        num_sms: Number of SMs available. 132 for Hopper

    Returns:
        Backend-specific tensor consumed by `fp8_paged_mqa_logits` to
        schedule work across SMs.
    """
    _lazy_init()
    if _get_paged_mqa_logits_metadata_impl is None:
        return _missing()
    return _get_paged_mqa_logits_metadata_impl(context_lens, block_size, num_sms)
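
A hedged end-to-end sketch tying this to fp8_paged_mqa_logits. All sizes are illustrative, and the zero-filled cache merely stands in for a real packed KV-cache such as the one sketched above.

```python
import torch

from vllm.utils.deep_gemm import (
    fp8_paged_mqa_logits,
    get_paged_mqa_logits_metadata,
)

B, next_n, H, D = 2, 1, 16, 128
block_size, num_blocks, max_model_len = 64, 8, 512
num_sms = torch.cuda.get_device_properties(0).multi_processor_count

context_lens = torch.tensor([70, 130], dtype=torch.int32, device="cuda")
schedule_metadata = get_paged_mqa_logits_metadata(context_lens, block_size, num_sms)

q_fp8 = torch.randn(B, next_n, H, D, device="cuda").to(torch.float8_e4m3fn)
weights = torch.ones(B * next_n, H, dtype=torch.float32, device="cuda")
block_tables = torch.arange(num_blocks, dtype=torch.int32, device="cuda").view(B, -1)
kv_cache_fp8 = torch.zeros(
    num_blocks, block_size, 1, D + 4, dtype=torch.uint8, device="cuda"
)

logits = fp8_paged_mqa_logits(
    q_fp8, kv_cache_fp8, weights, context_lens, block_tables,
    schedule_metadata, max_model_len, clean_logits=True,
)
assert logits.shape == (B * next_n, max_model_len)
```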

is_deep_gemm_e8m0_used cached

is_deep_gemm_e8m0_used() -> bool

Return True if vLLM is configured to use DeepGEMM E8M0 scales on a Hopper or Blackwell-class GPU.

Source code in vllm/utils/deep_gemm.py
@functools.cache
def is_deep_gemm_e8m0_used() -> bool:
    """Return `True` if vLLM is configured to use DeepGEMM "
    "E8M0 scale on a Hopper or Blackwell-class GPU.
    """
    if not is_deep_gemm_supported():
        logger.debug_once(
            "DeepGEMM E8M0 disabled: DeepGEMM not supported on this system."
        )
        return False

    _lazy_init()

    if _fp8_gemm_nt_impl is None:
        logger.info_once(
            "DeepGEMM E8M0 disabled: _fp8_gemm_nt_impl not found", scope="local"
        )
        return False

    if envs.VLLM_USE_DEEP_GEMM_E8M0:
        logger.info_once("DeepGEMM E8M0 enabled on current platform.", scope="local")
        return True

    logger.info_once("DeepGEMM E8M0 disabled on current configuration.", scope="local")
    return False

is_deep_gemm_supported cached

is_deep_gemm_supported() -> bool

Return True if DeepGEMM is supported on the current platform. Currently, only Hopper and Blackwell GPUs are supported.

Source code in vllm/utils/deep_gemm.py
@functools.cache
def is_deep_gemm_supported() -> bool:
    """Return `True` if DeepGEMM is supported on the current platform.
    Currently, only Hopper and Blackwell GPUs are supported.
    """
    is_supported_arch = current_platform.is_cuda() and (
        current_platform.is_device_capability(90)
        or current_platform.is_device_capability_family(100)
    )
    return envs.VLLM_USE_DEEP_GEMM and has_deep_gemm() and is_supported_arch