vllm.model_executor.models.baichuan

Inference-only BaiChuan model compatible with HuggingFace weights.

BaiChuanAttention

Bases: Module

Multi-headed attention from 'Attention Is All You Need' paper

Source code in vllm/model_executor/models/baichuan.py
class BaiChuanAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        position_embedding: str,
        rope_parameters: dict,
        max_position_embeddings: int = 8192,
        cache_config: CacheConfig | None = None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
        self.head_dim = hidden_size // self.total_num_heads
        self.position_embedding = position_embedding
        self.max_position_embeddings = max_position_embeddings

        # pylint: disable=invalid-name
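        # Baichuan fuses the Q, K and V projections into one packed weight
        # ("W_pack"); QKVParallelLinear shards that fused weight across
        # tensor-parallel ranks.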
        self.W_pack = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.W_pack",
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )
        # Create the alibi slopes and slice them.
        if self.position_embedding == "ALIBI":
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
            head_end = (tp_rank + 1) * self.num_heads
            alibi_slopes = _get_alibi_slopes(self.total_num_heads)
            alibi_slopes = alibi_slopes[head_start:head_end].tolist()

            scaling = self.head_dim**-0.5
            self.attn = Attention(
                self.num_heads,
                self.head_dim,
                scaling,
                alibi_slopes=alibi_slopes,
                quant_config=quant_config,
                prefix=f"{prefix}.attn",
            )
        else:
            self.rotary_emb = get_rope(
                self.head_dim,
                max_position=self.max_position_embeddings,
                rope_parameters=rope_parameters,
            )
            self.scaling = self.head_dim**-0.5
            self.attn = Attention(
                self.num_heads,
                self.head_dim,
                self.scaling,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.attn",
            )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv, _ = self.W_pack(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        if self.position_embedding != "ALIBI":
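            # ALiBi injects position information via per-head biases inside
            # the attention kernel, so rotary embeddings are applied only in
            # the RoPE branch.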
            q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output
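In the ALiBi branch, the constructor calls a module-level helper, _get_alibi_slopes, which is not shown on this page, and then slices the result down to the heads owned by the current tensor-parallel rank. The following is a minimal sketch assuming the standard geometric-sequence formulation from the ALiBi paper; it is illustrative and not necessarily identical to vLLM's actual helper.

import math

import torch


def _alibi_slopes_sketch(total_num_heads: int) -> torch.Tensor:
    """Standard ALiBi slopes: a geometric sequence per attention head.

    Sketch of the common formulation; not guaranteed to match vLLM's
    _get_alibi_slopes exactly.
    """
    closest_power_of_2 = 2 ** math.floor(math.log2(total_num_heads))
    base = 2.0 ** (-(2.0 ** -(math.log2(closest_power_of_2) - 3)))
    slopes = [base ** (i + 1) for i in range(closest_power_of_2)]
    if closest_power_of_2 != total_num_heads:
        # For non-power-of-2 head counts, interleave slopes computed for the
        # next power of 2, as in the original ALiBi implementation.
        extra_base = 2.0 ** (-(2.0 ** -(math.log2(2 * closest_power_of_2) - 3)))
        num_extra = total_num_heads - closest_power_of_2
        slopes += [extra_base ** (2 * i + 1) for i in range(num_extra)]
    return torch.tensor(slopes)


# Tensor-parallel slicing, mirroring the constructor above: each rank keeps
# only the slopes of its local heads (example values, not real config).
total_num_heads, tp_size, tp_rank = 32, 4, 1
num_heads = total_num_heads // tp_size
slopes = _alibi_slopes_sketch(total_num_heads)
local_slopes = slopes[tp_rank * num_heads : (tp_rank + 1) * num_heads].tolist()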

BaiChuanForCausalLM

Bases: BaiChuanBaseForCausalLM

Baichuan 7B. NOTE: the class name has an upper case 'C'.

Source code in vllm/model_executor/models/baichuan.py
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 7B.
    NOTE: the class name has an upper case 'C'.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(
            vllm_config=vllm_config, prefix=prefix, position_embedding="ROPE"
        )
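This RoPE variant corresponds to the original Baichuan 7B checkpoint, whose HuggingFace config names the architecture with the upper case 'C'. A minimal usage sketch through vLLM's top-level API; the repo id is an example, and Baichuan checkpoints require trust_remote_code=True so their custom model/tokenizer code can load.

from vllm import LLM, SamplingParams

# Illustrative checkpoint; it is resolved to BaiChuanForCausalLM via the
# "architectures" entry in its HuggingFace config.
llm = LLM(model="baichuan-inc/Baichuan-7B", trust_remote_code=True)
outputs = llm.generate(
    ["Baichuan is"],
    SamplingParams(temperature=0.8, max_tokens=32),
)
print(outputs[0].outputs[0].text)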

BaichuanForCausalLM

Bases: BaiChuanBaseForCausalLM

Baichuan 13B and Baichuan2 7B/13B. NOTE: the class name has a lower case 'c'.

Source code in vllm/model_executor/models/baichuan.py
class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 13B and Baichuan2 7B/13B.
    NOTE: the class name has a lower case 'c'.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        if config.hidden_size == 4096:  # baichuan2 7b
            super().__init__(
                vllm_config=vllm_config, prefix=prefix, position_embedding="ROPE"
            )
        else:  # baichuan 13b, baichuan2 13b
            super().__init__(
                vllm_config=vllm_config, prefix=prefix, position_embedding="ALIBI"
            )
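The hidden_size == 4096 check distinguishes Baichuan2 7B (RoPE) from the 13B variants (ALiBi), since all of them share the lower case 'c' architecture name. A small sketch, assuming the public checkpoint ids, of how one could inspect which branch a given config would take:

from transformers import AutoConfig

# Example repo ids; both resolve to BaichuanForCausalLM.
for repo in ("baichuan-inc/Baichuan2-7B-Base", "baichuan-inc/Baichuan2-13B-Base"):
    cfg = AutoConfig.from_pretrained(repo, trust_remote_code=True)
    branch = "ROPE" if cfg.hidden_size == 4096 else "ALIBI"
    print(repo, cfg.hidden_size, branch)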