
vllm.model_executor.models.nemotron_h

Inference-only NemotronH model.
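
For orientation, here is a minimal offline-inference sketch using vLLM's LLM API; the checkpoint name is an illustrative assumption and must be replaced with an actual NemotronH checkpoint.

# Minimal usage sketch. Assumption: the model id below is hypothetical and
# should be swapped for a real NemotronH checkpoint you have access to.
from vllm import LLM, SamplingParams

llm = LLM(model="nvidia/Nemotron-H-8B-Base-8K")  # hypothetical checkpoint id
outputs = llm.generate(
    ["The Nemotron-H architecture interleaves"],
    SamplingParams(max_tokens=32, temperature=0.0),
)
print(outputs[0].outputs[0].text)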

NemotronHForCausalLM

Bases: Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid, SupportsQuant, MixtureOfExperts, SupportsMambaPrefixCaching

Source code in vllm/model_executor/models/nemotron_h.py
class NemotronHForCausalLM(
    nn.Module,
    HasInnerState,
    SupportsLoRA,
    SupportsPP,
    IsHybrid,
    SupportsQuant,
    MixtureOfExperts,
    SupportsMambaPrefixCaching,
):
    # Relevant only if self.has_moe is True
    is_non_gated_moe: bool = True

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={"backbone": "model"},
        orig_to_new_substr={"A_log": "A", "embeddings": "embed_tokens"},
    )

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
    }

    # LoRA specific attributes
    embedding_modules = {
        "embed_tokens": "input_embeddings",
        "lm_head": "output_embeddings",
    }

    # Skip MTP (Multi-Token Prediction) layers during LoRA loading
    lora_skip_prefixes = ["mtp."]

    @classmethod
    def get_mamba_state_dtype_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[torch.dtype, torch.dtype]:
        return MambaStateDtypeCalculator.mamba2_state_dtype(
            vllm_config.model_config.dtype,
            vllm_config.cache_config.mamba_cache_dtype,
            vllm_config.cache_config.mamba_ssm_cache_dtype,
        )

    @classmethod
    def get_mamba_state_shape_from_config(
        cls,
        vllm_config: "VllmConfig",
    ) -> tuple[tuple[int, int], tuple[int, int, int]]:
        """Calculate shapes for Mamba's convolutional and state caches.

        Args:
            vllm_config: vLLM config

        Returns:
            Tuple containing:
            - conv_state_shape: Shape for convolutional state cache
            - temporal_state_shape: Shape for state space model cache
        """
        parallel_config = vllm_config.parallel_config
        hf_config = vllm_config.model_config.hf_config
        intermediate_size = hf_config.mamba_num_heads * hf_config.mamba_head_dim

        return MambaStateShapeCalculator.mamba2_state_shape(
            intermediate_size=intermediate_size,
            tp_world_size=parallel_config.tensor_parallel_size,
            n_groups=hf_config.n_groups,
            num_heads=hf_config.mamba_num_heads,
            head_dim=hf_config.mamba_head_dim,
            state_size=hf_config.ssm_state_size,
            conv_kernel=hf_config.conv_kernel,
        )

    @classmethod
    def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
        return MambaStateCopyFuncCalculator.mamba2_state_copy_func()

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        self.vllm_config = vllm_config
        self.model_config = vllm_config.model_config

        scheduler_config = vllm_config.scheduler_config

        self.quant_config = vllm_config.quant_config

        super().__init__()
        self.config = config
        self.scheduler_config = scheduler_config
        self.model = NemotronHModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            prefix=maybe_prefix(prefix, "lm_head"),
        )

        self.logits_processor = LogitsProcessor(config.vocab_size)

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

        # Set MoE hyperparameters
        if self.model.has_moe:
            self.expert_weights = []
            self.num_expert_groups = config.n_group

            self.moe_layers = []
            example_moe = None
            for layer in self.model.layers:
                if isinstance(layer, NemotronHMoEDecoderLayer):
                    # Pick the last MoE layer, since the earlier
                    # layers may be dense.
                    example_moe = layer.mixer
                    self.moe_layers.append(layer.mixer.experts)

            self.num_moe_layers = len(self.moe_layers)
            self.num_logical_experts = example_moe.n_logical_experts
            self.num_physical_experts = example_moe.n_physical_experts
            self.num_local_physical_experts = example_moe.n_local_physical_experts  # noqa: E501
            self.num_routed_experts = example_moe.n_routed_experts
            self.num_shared_experts = example_moe.n_shared_experts
            self.num_redundant_experts = example_moe.n_redundant_experts

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
        for layer in self.model.layers:
            if isinstance(layer, NemotronHMoEDecoderLayer):
                moe = layer.mixer
                moe.n_local_physical_experts = num_local_physical_experts
                moe.n_physical_experts = num_physical_experts
                moe.n_redundant_experts = self.num_redundant_experts
                moe.experts.update_expert_map()

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs,
    ):
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self, skip_prefixes=["mtp"])
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)
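
As a rough illustration of the hf_to_vllm_mapper defined above, checkpoint parameter names are renamed before loading: the "backbone" prefix becomes "model", and the "A_log" and "embeddings" substrings become "A" and "embed_tokens". The sketch below only mimics that renaming with plain string operations; it is not the WeightsMapper implementation, and the example names are illustrative.

# Illustrative sketch of the effect of hf_to_vllm_mapper on checkpoint names.
# Plain string handling only; the real remapping is done by WeightsMapper.
def remap(name: str) -> str:
    if name.startswith("backbone"):
        name = "model" + name[len("backbone"):]
    return name.replace("A_log", "A").replace("embeddings", "embed_tokens")

assert remap("backbone.embeddings.weight") == "model.embed_tokens.weight"
assert remap("backbone.layers.0.mixer.A_log") == "model.layers.0.mixer.A"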

get_mamba_state_shape_from_config classmethod

get_mamba_state_shape_from_config(
    vllm_config: VllmConfig,
) -> tuple[tuple[int, int], tuple[int, int, int]]

Calculate shapes for Mamba's convolutional and state caches.

Parameters:

    vllm_config (VllmConfig): vLLM config. Required.

Returns:

    tuple[tuple[int, int], tuple[int, int, int]]: Tuple containing:
      • conv_state_shape: Shape for convolutional state cache
      • temporal_state_shape: Shape for state space model cache
Source code in vllm/model_executor/models/nemotron_h.py
@classmethod
def get_mamba_state_shape_from_config(
    cls,
    vllm_config: "VllmConfig",
) -> tuple[tuple[int, int], tuple[int, int, int]]:
    """Calculate shapes for Mamba's convolutional and state caches.

    Args:
        vllm_config: vLLM config

    Returns:
        Tuple containing:
        - conv_state_shape: Shape for convolutional state cache
        - temporal_state_shape: Shape for state space model cache
    """
    parallel_config = vllm_config.parallel_config
    hf_config = vllm_config.model_config.hf_config
    intermediate_size = hf_config.mamba_num_heads * hf_config.mamba_head_dim

    return MambaStateShapeCalculator.mamba2_state_shape(
        intermediate_size=intermediate_size,
        tp_world_size=parallel_config.tensor_parallel_size,
        n_groups=hf_config.n_groups,
        num_heads=hf_config.mamba_num_heads,
        head_dim=hf_config.mamba_head_dim,
        state_size=hf_config.ssm_state_size,
        conv_kernel=hf_config.conv_kernel,
    )
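
For orientation, the sketch below unpacks the value returned by this classmethod. It assumes an already constructed VllmConfig (the variable vllm_config here is hypothetical) and simply restates the documented tuple structure.

# Sketch, assuming `vllm_config` is an already-built VllmConfig for a
# NemotronH checkpoint (how it is constructed is out of scope here).
from vllm.model_executor.models.nemotron_h import NemotronHForCausalLM

conv_state_shape, temporal_state_shape = (
    NemotronHForCausalLM.get_mamba_state_shape_from_config(vllm_config)
)
# conv_state_shape:     2-tuple describing the convolutional state cache
# temporal_state_shape: 3-tuple describing the SSM (temporal) state cache
print(conv_state_shape, temporal_state_shape)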

NemotronHModel

Bases: Module

Source code in vllm/model_executor/models/nemotron_h.py
@support_torch_compile
class NemotronHModel(nn.Module):
    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config: NemotronHConfig = vllm_config.model_config.hf_config
        model_config = vllm_config.model_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        parallel_config = vllm_config.parallel_config

        self.config = config

        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
        )

        self.has_moe = "E" in config.hybrid_override_pattern

        def get_layer(prefix: str):
            layer_idx = int(prefix.rsplit(".", 1)[1])
            layer_class = ALL_DECODER_LAYER_TYPES[
                config.hybrid_override_pattern[layer_idx]
            ]
            return layer_class(
                config=config,
                layer_idx=layer_idx,
                model_config=model_config,
                cache_config=cache_config,
                quant_config=quant_config,
                parallel_config=parallel_config,
                prefix=prefix,
            )

        self.start_layer, self.end_layer, self.layers = make_layers(
            len(config.hybrid_override_pattern), get_layer, prefix=f"{prefix}.layers"
        )
        self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory(
            ["hidden_states", "residual"], config.hidden_size
        )

        self.norm_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states, residual = layer(
                positions=positions,
                hidden_states=hidden_states,
                residual=residual,
            )

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )
        hidden_states, _ = self.norm_f(hidden_states, residual)
        return hidden_states

    def _get_max_n_routed_experts(self) -> int:
        """Get max n_routed_experts from config or block_configs for puzzle models.

        For heterogeneous models with varying expert counts per layer,
        returns the MAX to ensure all expert weights can be loaded.
        """
        # First try top-level attribute
        n_routed_experts = getattr(self.config, "n_routed_experts", None)
        if n_routed_experts is not None:
            return n_routed_experts

        # For puzzle models, get MAX from all MoE blocks in block_configs
        # (different layers may have different expert counts)
        max_experts = 0
        block_configs = getattr(self.config, "block_configs", None)
        if block_configs:
            for block in block_configs:
                if isinstance(block, dict):
                    if block.get("block_type") == "moe":
                        max_experts = max(max_experts, block.get("n_routed_experts", 0))
                else:
                    # HF converts dicts to objects with attributes
                    if getattr(block, "block_type", "") == "moe":
                        max_experts = max(
                            max_experts, getattr(block, "n_routed_experts", 0)
                        )
        return max_experts

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        if self.has_moe:
            # (param_name, weight_name, expert_id, shard_id)
            expert_params_mapping = FusedMoE.make_expert_params_mapping(
                # - FusedMoe.w1 (aka gate_proj) should be up_proj since that's
                #   what the activation is applied to
                # - FusedMoe.w3 (aka up_proj) should be ignored since we're
                #   using non-gated MoE
                self,
                ckpt_gate_proj_name="up_proj",
                ckpt_down_proj_name="down_proj",
                ckpt_up_proj_name="",
                num_experts=self._get_max_n_routed_experts(),
                num_redundant_experts=getattr(self, "num_redundant_experts", 0),
            )
            return expert_params_mapping

        return []

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]

        expert_params_mapping = self.get_expert_mapping()

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "scale" in name or "zero_point" in name:
                # Remapping the name of FP8 kv-scale.
                name = maybe_remap_kv_scale_name(name, params_dict)
                if name is None:
                    continue

            # load stacked params
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue

                if is_pp_missing_parameter(name, self):
                    continue

                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break

            # load other params
            else:
                is_expert_weight = False
                for mapping in expert_params_mapping:
                    param_name, weight_name, expert_id, shard_id = mapping
                    if weight_name not in name:
                        continue

                    # This is an expert weight and should not be
                    # loaded as a regular weight later.
                    is_expert_weight = True

                    # Do not modify `name` since the loop may continue here
                    # Instead, create a new variable
                    name_mapped = name.replace(weight_name, param_name)

                    if is_pp_missing_parameter(name_mapped, self):
                        continue
                    param = params_dict[name_mapped]
                    # We should ask the weight loader to return success or not
                    # here since otherwise we may skip experts with other
                    # available replicas.
                    weight_loader = typing.cast(
                        Callable[..., bool], param.weight_loader
                    )
                    success = weight_loader(
                        param,
                        loaded_weight,
                        name_mapped,
                        shard_id=shard_id,
                        expert_id=expert_id,
                        return_success=True,
                    )
                    if success:
                        name = name_mapped
                        break
                else:
                    if is_expert_weight:
                        continue

                    if is_pp_missing_parameter(name, self):
                        continue

                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)

            loaded_params.add(name)
        return loaded_params
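
To make the stacked_params_mapping in load_weights above concrete, the sketch below mimics how a per-projection checkpoint name is renamed onto the fused qkv_proj parameter together with a shard id. It is illustrative string handling only, and the example parameter name is hypothetical.

# Illustrative sketch of the stacked_params_mapping rename in load_weights.
stacked_params_mapping = [
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]

name = "model.layers.2.mixer.k_proj.weight"  # hypothetical checkpoint name
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name in name:
        # The fused parameter receives this shard under the given shard id.
        print(name.replace(weight_name, param_name), shard_id)
        # -> model.layers.2.mixer.qkv_proj.weight k
        break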

_get_max_n_routed_experts

_get_max_n_routed_experts() -> int

Get max n_routed_experts from config or block_configs for puzzle models.

For heterogeneous models with varying expert counts per layer, returns the MAX to ensure all expert weights can be loaded.

Source code in vllm/model_executor/models/nemotron_h.py
def _get_max_n_routed_experts(self) -> int:
    """Get max n_routed_experts from config or block_configs for puzzle models.

    For heterogeneous models with varying expert counts per layer,
    returns the MAX to ensure all expert weights can be loaded.
    """
    # First try top-level attribute
    n_routed_experts = getattr(self.config, "n_routed_experts", None)
    if n_routed_experts is not None:
        return n_routed_experts

    # For puzzle models, get MAX from all MoE blocks in block_configs
    # (different layers may have different expert counts)
    max_experts = 0
    block_configs = getattr(self.config, "block_configs", None)
    if block_configs:
        for block in block_configs:
            if isinstance(block, dict):
                if block.get("block_type") == "moe":
                    max_experts = max(max_experts, block.get("n_routed_experts", 0))
            else:
                # HF converts dicts to objects with attributes
                if getattr(block, "block_type", "") == "moe":
                    max_experts = max(
                        max_experts, getattr(block, "n_routed_experts", 0)
                    )
    return max_experts
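
A small self-contained sketch of the MAX-over-blocks fallback described above, using a hypothetical block_configs list of dicts; the real method reads these entries from the HF config and also handles the attribute-object form.

# Toy illustration of the block_configs fallback: take the MAX
# n_routed_experts over all MoE blocks (hypothetical config below).
block_configs = [
    {"block_type": "attention"},
    {"block_type": "moe", "n_routed_experts": 8},
    {"block_type": "moe", "n_routed_experts": 16},
]

max_experts = 0
for block in block_configs:
    if block.get("block_type") == "moe":
        max_experts = max(max_experts, block.get("n_routed_experts", 0))

print(max_experts)  # -> 16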