Skip to content

vllm.model_executor.models.grok1

Inference-only Grok (Grok1/Grok2) model.

Grok1ForCausalLM

Bases: GrokBaseForCausalLM

Grok1-specific implementation.

Source code in vllm/model_executor/models/grok1.py
class Grok1ForCausalLM(GrokBaseForCausalLM):
    """Grok-1 model variant.

    Grok-1 checkpoints name the per-expert MLP projections ``linear``
    (gate), ``linear_1`` (down) and ``linear_v`` (up), and their keys are
    used without any additional rewriting.
    """

    # Checkpoint names of the expert MLP projections.
    ckpt_gate_proj_name = "linear"
    ckpt_up_proj_name = "linear_v"
    ckpt_down_proj_name = "linear_1"

    def get_weight_name_remapping(self) -> dict[str, str]:
        """Return an empty mapping: Grok-1 checkpoint keys are used as-is."""
        return dict()

Grok1MoE

Bases: Module

A tensor-parallel MoE implementation for Grok1 that shards each expert across all ranks.

Each expert's weights are sharded across all ranks; the forward pass uses a fused MoE kernel, and the outputs are then reduced across ranks.

Source code in vllm/model_executor/models/grok1.py
class Grok1MoE(nn.Module):
    """Tensor-parallel mixture-of-experts layer used by Grok1.

    Rather than assigning whole experts to individual ranks, every expert's
    weights are sharded across all tensor-parallel ranks. The forward pass
    runs a fused MoE kernel and the partial outputs are then reduced across
    ranks (``reduce_results=True``).
    """

    def __init__(
        self,
        num_experts: int,
        top_k: int,
        hidden_size: int,
        intermediate_size: int,
        router_logit_soft_cap: float = 0.0,
        params_dtype: torch.dtype | None = None,
        quant_config: QuantizationConfig | None = None,
        tp_size: int | None = None,
        renormalize: bool = False,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        # Values <= 0 disable router-logit soft-capping in forward().
        self.router_logit_soft_cap = router_logit_soft_cap

        # The router is never quantized (quant_config=None); it always runs
        # at half / full precision for now.
        self.gate = ReplicatedLinear(
            hidden_size,
            num_experts,
            bias=False,
            params_dtype=params_dtype,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

        # Fused expert MLPs with GELU activation; results are reduced across
        # tensor-parallel ranks inside the layer.
        self.experts = FusedMoE(
            num_experts=num_experts,
            top_k=top_k,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            params_dtype=params_dtype,
            reduce_results=True,
            renormalize=renormalize,
            quant_config=quant_config,
            tp_size=tp_size,
            activation="gelu",
            prefix=f"{prefix}.experts",
        )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Route each token to its experts and return the combined output.

        The input may be 1D or 2D. It is flattened to
        ``(num_tokens, hidden_size)`` for routing, and the result is reshaped
        back to the original input shape.
        """
        input_shape = hidden_states.shape
        tokens = hidden_states.view(-1, self.hidden_size)

        # Router scores, one column per expert: (num_tokens, num_experts).
        logits, _ = self.gate(tokens)

        cap = self.router_logit_soft_cap
        if cap > 0:
            # Smoothly squash the logits into the open interval (-cap, cap).
            logits = cap * F.tanh(logits / cap)

        return self.experts(tokens, logits).view(input_shape)

Grok2ForCausalLM

Bases: GrokBaseForCausalLM

Grok2-specific implementation.

Source code in vllm/model_executor/models/grok1.py
class Grok2ForCausalLM(GrokBaseForCausalLM):
    """Grok-2 model variant.

    Compared with the base mapping, Grok-2 additionally packs the MLP gate
    and up projections into ``gate_up_proj``. Its expert projections are
    named ``w1``/``w2``/``w3``, and some checkpoint keys must be rewritten
    before loading.
    """

    # Fused (packed) module -> constituent checkpoint modules. Besides the
    # usual QKV fusion, Grok-2 also fuses the MLP gate/up projections.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    # Checkpoint names of the expert MLP projections.
    ckpt_gate_proj_name = "w1"
    ckpt_up_proj_name = "w3"
    ckpt_down_proj_name = "w2"

    def get_weight_name_remapping(self) -> dict[str, str]:
        """Map Grok-2 checkpoint key fragments onto this model's module names."""
        remapping: dict[str, str] = dict()
        remapping[".self_attn."] = ".attn."
        remapping[".block_sparse_moe."] = ".moe_block."
        return remapping

GrokBaseForCausalLM

Bases: Module, SupportsLoRA, SupportsPP

Base class for Grok models with shared logic.

Source code in vllm/model_executor/models/grok1.py
class GrokBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
    """Shared causal-LM implementation for the Grok model versions.

    Version-specific subclasses customise this class through class
    attributes (``packed_modules_mapping`` and the ``ckpt_*_proj_name``
    expert-weight names) and by overriding ``get_weight_name_remapping``.
    """

    fall_back_to_pt_during_load = False

    # Fused module -> constituent checkpoint modules; subclasses override.
    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    }

    # Checkpoint names of the expert gate / down / up projections
    # (Grok-1 naming by default); subclasses override these.
    ckpt_gate_proj_name: str = "linear"
    ckpt_down_proj_name: str = "linear_1"
    ckpt_up_proj_name: str = "linear_v"

    def get_weight_name_remapping(self) -> dict[str, str]:
        """Return weight name remapping for this version. Override in subclasses."""
        return dict()

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        hf_config = vllm_config.model_config.hf_config
        quant_cfg = vllm_config.quant_config
        self.config = hf_config
        self.quant_config = quant_cfg

        # Decoder stack. The expert weight names and the checkpoint-key
        # remapping are supplied by the concrete (sub)class.
        self.model = Grok1Model(
            vllm_config=vllm_config,
            prefix=maybe_prefix(prefix, "model"),
            ckpt_gate_proj_name=self.ckpt_gate_proj_name,
            ckpt_down_proj_name=self.ckpt_down_proj_name,
            ckpt_up_proj_name=self.ckpt_up_proj_name,
            weight_name_remapping=self.get_weight_name_remapping(),
        )

        self.lm_head = ParallelLMHead(
            hf_config.vocab_size,
            hf_config.hidden_size,
            quant_config=quant_cfg,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if hf_config.tie_word_embeddings:
            # Share the input-embedding matrix with the output projection.
            self.lm_head.weight = self.model.embed_tokens.weight

        # Logits are scaled by output_multiplier_scale (module default when
        # the config lacks it) and optionally soft-capped.
        self.output_multiplier_scale = getattr(
            hf_config, "output_multiplier_scale", DEFAULT_OUTPUT_MULTIPLIER_SCALE
        )
        self.logits_processor = LogitsProcessor(
            hf_config.vocab_size,
            scale=self.output_multiplier_scale,
            soft_cap=getattr(hf_config, "final_logit_softcapping", None),
        )

        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate token embedding to the underlying model."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder stack and return its output unchanged."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to vocabulary logits via the LM head."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors and return the names that were loaded."""
        # With tied embeddings lm_head shares the embedding weight, so any
        # lm_head tensor present in the checkpoint is skipped.
        skip_prefixes = None
        if self.config.tie_word_embeddings:
            skip_prefixes = ["lm_head"]

        weights_loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
        return weights_loader.load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Delegate the expert parameter mapping to the underlying model."""
        return self.model.get_expert_mapping()

get_weight_name_remapping

get_weight_name_remapping() -> dict[str, str]

Return weight name remapping for this version. Override in subclasses.

Source code in vllm/model_executor/models/grok1.py
def get_weight_name_remapping(self) -> dict[str, str]:
    """Return the checkpoint-key remapping for this Grok version.

    The base implementation maps nothing and returns a fresh empty dict.
    Subclasses override it to supply version-specific rewrites.
    """
    return dict()

GrokForCausalLM

Bases: GrokBaseForCausalLM

Factory class that dispatches to the version-specific implementation.

Source code in vllm/model_executor/models/grok1.py
class GrokForCausalLM(GrokBaseForCausalLM):
    """Version-dispatching entry point for Grok checkpoints.

    Constructing this class detects the Grok version from the HF config and
    returns an instance of the matching class from ``_GROK_VERSIONS``.
    """

    def __new__(cls, *, vllm_config: VllmConfig, prefix: str = ""):
        version = _get_grok_version(vllm_config.model_config.hf_config)

        target_cls = _GROK_VERSIONS.get(version)
        if target_cls is None:
            raise ValueError(f"Unsupported Grok version: {version}")

        # Merge class attributes for LoRA/quantization compatibility. The
        # dispatcher's mapping gains the target's entries (the target wins on
        # key clashes). A new dict is bound to `cls`, so the inherited
        # base-class mapping object is never mutated in place.
        cls.packed_modules_mapping = {
            **cls.packed_modules_mapping,
            **target_cls.packed_modules_mapping,
        }

        return target_cls(vllm_config=vllm_config, prefix=prefix)

_get_grok_version

_get_grok_version(config) -> str

Detect the Grok version from the HF config using multiple heuristics.

Source code in vllm/model_executor/models/grok1.py
def _get_grok_version(config) -> str:
    """Detect the Grok version ("grok1" or "grok2") from an HF config.

    A config is treated as Grok2 if *either* Grok2-specific marker is present:

    * ``residual_moe`` is truthy, or
    * ``moe_intermediate_size`` is set to a non-``None`` value.

    Anything else defaults to ``"grok1"``.

    Args:
        config: A Hugging Face model config (any object whose attributes can
            be read with ``getattr``).

    Returns:
        ``"grok2"`` or ``"grok1"``.
    """
    # Either marker alone is sufficient (the checks are OR-ed, not AND-ed).
    has_residual_moe = bool(getattr(config, "residual_moe", False))
    # An attribute explicitly set to None (e.g. a null entry in config.json)
    # carries no Grok2 signal; a bare hasattr() would misclassify it.
    has_moe_intermediate_size = (
        getattr(config, "moe_intermediate_size", None) is not None
    )

    if has_residual_moe or has_moe_intermediate_size:
        return "grok2"

    return "grok1"  # Default to Grok1