class Plamo2ForCausalLM(
torch.nn.Module, HasInnerState, SupportsLoRA, SupportsPP, IsHybrid
):
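# LoRA packed-module mapping. PLaMo2 checkpoints appear to store these
# projections already fused, so each packed module maps onto itself.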
packed_modules_mapping = {
"qkv_proj": ["qkv_proj"],
"gate_up_proj": ["gate_up_proj"],
"in_proj": ["in_proj"],
}
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
scheduler_config = vllm_config.scheduler_config
self.config = config
self.vllm_config = vllm_config
self.model_config = vllm_config.model_config
self.scheduler_config = scheduler_config
# ModelConfig.get_head_size assumes head_dim is set or calculated as
# hidden_size // num_attention_heads (see the FIXME in that method),
# which does not always hold for PLaMo2, so set head_dim explicitly.
self.config.head_dim = self.config.hidden_size_per_head
self.model = Plamo2Model(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
self.vocab_size = self.config.vocab_size
self.lm_head = ParallelLMHead(
self.vocab_size,
self.config.hidden_size,
prefix=f"{prefix}.lm_head",
)
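# With tied word embeddings, reuse the input embedding matrix for the
# LM head instead of keeping a separate output projection.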
if self.config.tie_word_embeddings:
self.lm_head = self.lm_head.tie_weights(self.model.embed_tokens)
self.logits_processor = LogitsProcessor(
self.vocab_size, self.config.vocab_size
)
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors
)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs,
):
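# With pipeline parallelism, non-first ranks receive the running hidden
# states via intermediate_tensors rather than input_ids/inputs_embeds.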
hidden_states = self.model(
input_ids, positions, intermediate_tensors, inputs_embeds
)
return hidden_states
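# Dtypes of the mamba2 conv and SSM state caches, derived from the model
# dtype and the cache-config override fields.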
@classmethod
def get_mamba_state_dtype_from_config(
cls,
vllm_config: "VllmConfig",
) -> tuple[torch.dtype, torch.dtype]:
return MambaStateDtypeCalculator.mamba2_state_dtype(
vllm_config.model_config.dtype,
vllm_config.cache_config.mamba_cache_dtype,
vllm_config.cache_config.mamba_ssm_cache_dtype,
)
@classmethod
def get_mamba_state_shape_from_config(
cls,
vllm_config: "VllmConfig",
) -> tuple[tuple[int, int], tuple[int, int, int]]:
"""Calculate shapes for Mamba's convolutional and state caches.
Args:
vllm_config: vLLM config
Returns:
Tuple containing:
- conv_state_shape: Shape for convolutional state cache
- temporal_state_shape: Shape for state space model cache
"""
parallel_config = vllm_config.parallel_config
hf_config = vllm_config.model_config.hf_config
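# The mamba mixer's intermediate size is num_heads * head_dim.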
intermediate_size = hf_config.mamba_num_heads * hf_config.hidden_size_per_head
return MambaStateShapeCalculator.mamba2_state_shape(
intermediate_size=intermediate_size,
tp_world_size=parallel_config.tensor_parallel_size,
n_groups=0,
num_heads=hf_config.mamba_num_heads,
head_dim=hf_config.hidden_size_per_head,
state_size=hf_config.mamba_d_state,
conv_kernel=hf_config.mamba_d_conv,
)
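# Copy functions used to duplicate mamba conv / SSM state between cache
# entries.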
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc, MambaStateCopyFunc]:
return MambaStateCopyFuncCalculator.mamba2_state_copy_func()
def compute_logits(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor | None:
logits = self.logits_processor(self.lm_head, hidden_states)
return logits
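# Load checkpoint weights, renaming and reshaping PLaMo2 parameters to
# match vLLM's module layout (see the step-by-step comments below).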
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
for name, loaded_weight in weights:
# With tie_word_embeddings=True the lm_head weight is shared with the
# input embedding and is not a separate entry in params_dict, so
# loading "lm_head.weight" from the checkpoint would raise a KeyError.
if name == "lm_head.weight" and self.config.tie_word_embeddings:
assert "lm_head.weight" not in params_dict
continue
# Skip unused rotary-embedding buffers that some checkpoints (e.g. from
# GPTQModel) contain; same workaround as in AutoWeightsLoader.
if any(
substr in name
for substr in AutoWeightsLoader.ROTARY_EMBEDS_UNUSED_WEIGHTS
):
continue
# Rename checkpoint weight names to be compatible with the vLLM
# version of the model.
# Do not change the order of the replacements.
replacements = {
# Rename incompatible weight names.
".A_log": ".A",
".B_norm_weight": ".B_norm.weight",
".C_norm_weight": ".C_norm.weight",
".dt_norm_weight": ".dt_norm.weight",
".q_weight": ".q_norm.weight",
".k_weight": ".k_norm.weight",
}
# Apply replacements based on the defined mappings
for old, new in replacements.items():
if old in name:
name = name.replace(old, new)
# Reshape the in_proj weights to match the layout expected by
# MergedColumnParallelLinear. This works for both unquantized and
# quantized weights: in the quantized case the weights are already
# transposed, and the zero points and scales have to be reshaped as
# well. Packing is not affected by this.
if (
".mixer.in_proj.weight" in name
or "mixer.in_proj.qweight" in name
or "mixer.in_proj.scales" in name
or "mixer.in_proj.qzeros" in name
):
if "mixer.in_proj.weight" in name:
loaded_weight = loaded_weight.transpose(0, 1)
# for weight:
# loaded_weight.shape[0] == self.config.hidden_size
# for qweight:
# loaded_weight.shape[0] == self.config.hidden_size // param.pack_factor # noqa
# for scales and qzeros:
# loaded_weight.shape[0] == self.config.hidden_size // self.vllm_config.quant_config.group_size # noqa
loaded_weight = loaded_weight.reshape(
loaded_weight.shape[0], self.config.mamba_num_heads, -1
)
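# In the checkpoint, each head's gate and hidden-state columns are
# stored together; regroup them so all gate columns come first, as
# MergedColumnParallelLinear expects.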
gate_weight, hidden_states_weight = loaded_weight.chunk(2, dim=-1)
gate_weight = gate_weight.reshape(loaded_weight.shape[0], -1)
hidden_states_weight = hidden_states_weight.reshape(
loaded_weight.shape[0], -1
)
loaded_weight = torch.cat([gate_weight, hidden_states_weight], dim=-1)
if "mixer.in_proj.weight" in name:
loaded_weight = loaded_weight.transpose(0, 1)
# vLLM's RMSNorm does not support an offset parameter yet, so fold the
# PLaMo2 per-norm offsets into the loaded weights here.
if ".pre_mixer_norm" in name:
loaded_weight += 1.0
elif ".post_mixer_norm" in name:
loaded_weight += 1.0 / 5
elif ".pre_mlp_norm" in name:
loaded_weight += 1.0
elif ".post_mlp_norm" in name:
loaded_weight += 1.0 / (5**1.5)
elif "model.norm.weight" in name:
loaded_weight += 1.0
# Skip layers on other devices.
if is_pp_missing_parameter(name, self):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)