class MiniMaxText01ForCausalLM(nn.Module, HasInnerState, IsHybrid):
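    """Causal LM wrapper for MiniMax-Text-01.

    The decoder interleaves linear-attention and softmax-attention layers as
    selected per layer by ``decoder_attention_types`` (0 = linear attention,
    1 = softmax attention), alongside sparse MoE experts and a shared MLP.
    Linear-attention state is managed via the model's minimax cache rather
    than the regular KV cache.
    """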
def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
super().__init__()
config = vllm_config.model_config.hf_config
self.config = config
if not hasattr(config, "sliding_window"):
config.sliding_window = None
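        # When True, the shared-MLP gate and up projections are loaded into a
        # single fused gate_up_proj parameter (see load_shared_mlp_weight).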
self.CONCAT_FFN = True
if hasattr(vllm_config.model_config, "max_model_len"):
self.config.max_model_len = vllm_config.model_config.max_model_len
self.model = MiniMaxText01Model(
vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
)
if get_pp_group().is_last_rank:
self.lm_head = ParallelLMHead(
config.vocab_size,
self.config.hidden_size,
prefix=maybe_prefix(prefix, "lm_head"),
)
            self.logits_processor = LogitsProcessor(config.vocab_size)
else:
self.lm_head = PPMissingLayer()
self.lm_head.float()
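        # One KV-cache placeholder per softmax-attention layer; attention type
        # 1 marks softmax ("flash") attention, type 0 linear attention.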
flash_layer_count = sum(
1 for attn_type in self.model.decoder_attention_types if attn_type == 1
)
self.kv_cache = [torch.tensor([]) for _ in range(flash_layer_count)]
return
def copy_inputs_before_cuda_graphs(self, input_buffers, **kwargs):
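        # Copy per-request linear-attention state into the static input
        # buffers before CUDA graph replay.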
return self.model.minimax_cache.copy_inputs_before_cuda_graphs(
input_buffers, **kwargs
)
def get_seqlen_agnostic_capture_inputs(self, batch_size: int):
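        # Static linear-attention state buffers reused across CUDA graph
        # replays for the given batch size.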
return self.model.minimax_cache.get_seqlen_agnostic_capture_inputs(batch_size)
def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.embed_input_ids(input_ids)
def forward(
self,
input_ids: torch.Tensor | None,
positions: torch.Tensor,
intermediate_tensors: IntermediateTensors | None = None,
inputs_embeds: torch.Tensor | None = None,
**kwargs,
) -> torch.Tensor:
hidden_states = self.model(
input_ids, positions, intermediate_tensors, inputs_embeds, **kwargs
)
return hidden_states
def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
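        # The lm_head was cast to float32 in __init__, so hidden states are
        # cast to match (presumably for numerically stable logits).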
logits = self.logits_processor(self.lm_head, hidden_states.float())
return logits
def make_empty_intermediate_tensors(
self, batch_size: int, dtype: torch.dtype, device: torch.device
) -> IntermediateTensors:
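        # Zero-filled hidden_states/residual placeholders for pipeline-parallel
        # ranks that have not yet received tensors from the previous stage.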
return IntermediateTensors(
{
"hidden_states": torch.zeros(
(batch_size, self.config.hidden_size), dtype=dtype, device=device
),
"residual": torch.zeros(
(batch_size, self.config.hidden_size), dtype=dtype, device=device
),
}
)
def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
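        # Dispatch each checkpoint tensor to a specialized loader: layer norms,
        # attention weights (linear vs. softmax, chosen per layer), sparse MoE
        # experts, the shared MLP, and a default loader for everything else.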
params_dict = dict(self.named_parameters())
loaded_params: set[str] = set()
        def which_layer(name: str) -> int | None:
            # Extract the decoder layer index from names like
            # "model.layers.12.self_attn.q_proj.weight"; return None for
            # non-layer weights (embeddings, final norm, lm_head, ...).
            if "layers" in name:
                after_layer = name.split("layers")[-1]
                return int(after_layer.split(".")[1])
            return None
def is_linear_attn_layer(layer_idx: int) -> bool:
if layer_idx is None or layer_idx >= len(
self.model.decoder_attention_types
):
return False
return self.model.decoder_attention_types[layer_idx] == 0
def is_moe_weight(name: str) -> bool:
return "block_sparse_moe" in name and not name.endswith(".bias")
        def get_expert_id(param_name: str) -> str | None:
            # Pull the expert index out of names like
            # "model.layers.3.block_sparse_moe.experts.7.w1.weight".
            pattern = r"model\.layers\.\d+\.block_sparse_moe\.experts\.(\d+)\."
            match = re.search(pattern, param_name)
            if match:
                return match.group(1)
            return None
def load_sparse_moe_weight(
name: str, loaded_weight: torch.Tensor, self
) -> None:
            if isinstance(self.config.num_local_experts, list):
                # Per-layer expert counts: build the mapping up to the largest
                # count and let the expert-id check below skip missing experts.
                expert_params_mapping = [
                    (
                        "w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
                        f"experts.{expert_id}.{weight_name}.weight",
                        expert_id,
                        weight_name,
                    )
                    for expert_id in range(max(self.config.num_local_experts))
                    for weight_name in ["w1", "w2", "w3"]
                ]
else:
expert_params_mapping = [
(
"w13_scale" if weight_name in ["w1", "w3"] else "w2_scale",
f"{expert_id}.{weight_name}.weight_scale",
expert_id,
weight_name,
)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
] + [
(
"w13_weight" if weight_name in ["w1", "w3"] else "w2_weight",
f"{expert_id}.{weight_name}.weight",
expert_id,
weight_name,
)
for expert_id in range(self.config.num_local_experts)
for weight_name in ["w1", "w2", "w3"]
]
for param_name, weight_name, expert_id, shard_id in expert_params_mapping:
name_expert_id = get_expert_id(name)
if name_expert_id is not None and int(name_expert_id) != int(expert_id):
continue
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
return
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader = weight_loader_with_alias(name)(weight_loader)
weight_loader(
param,
loaded_weight,
weight_name,
expert_id=expert_id,
shard_id=shard_id,
)
loaded_params.add(name)
break
else:
if is_pp_missing_parameter(name, self):
return
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader = weight_loader_with_alias(name)(weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return
def is_shared_mlp_weight(name: str) -> bool:
return "shared_mlp" in name and not name.endswith(".bias")
def load_shared_mlp_weight(
name: str, loaded_weight: torch.Tensor, self
) -> None:
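            # Unfused path maps HF gate/up/down projections to w1/w3/w2;
            # fused path packs gate_proj/up_proj into gate_up_proj shards 0/1.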
if not self.CONCAT_FFN:
if "gate_proj" in name:
name = name.replace("gate_proj", "w1", 1)
elif "up_proj" in name:
name = name.replace("up_proj", "w3", 1)
elif "down_proj" in name:
name = name.replace("down_proj", "w2", 1)
else:
if "gate_proj" in name:
name = name.replace("gate_proj", "gate_up_proj", 1)
loaded_shard_id = 0
elif "up_proj" in name:
name = name.replace("up_proj", "gate_up_proj", 1)
loaded_shard_id = 1
if is_pp_missing_parameter(name, self):
return
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader = weight_loader_with_alias(name)(weight_loader)
if not self.CONCAT_FFN:
weight_loader(param, loaded_weight)
else:
if "gate_up_proj" in name:
weight_loader(param, loaded_weight, loaded_shard_id)
elif "down_proj" in name:
weight_loader(param, loaded_weight)
else:
raise AssertionError("MLP weight not in [gate_up_proj, down_proj]")
loaded_params.add(name)
return
def is_mha_weight(name: str) -> bool:
return "self_attn" in name and not name.endswith(".bias")
def load_linear_attn_weight(
name: str, loaded_weight: torch.Tensor, self
) -> None:
if is_pp_missing_parameter(name, self):
return
param = params_dict[name]
weight_loader = getattr(
param, "weight_loader", MiniMaxText01LinearAttention.weight_direct_load
)
weight_loader = weight_loader_with_alias(name)(weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return
def load_flash_attn_weight(
name: str, loaded_weight: torch.Tensor, self
) -> None:
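            # Map separate q/k/v (and gate/up) checkpoint tensors onto the
            # fused qkv_proj / gate_up_proj parameters, tagging each shard.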
flash_mha_params_mapping = [
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
for param_name, weight_name, shard_id in flash_mha_params_mapping:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if is_pp_missing_parameter(name, self):
return
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader = weight_loader_with_alias(name)(weight_loader)
weight_loader(param, loaded_weight, shard_id)
loaded_params.add(name)
break
else:
if is_pp_missing_parameter(name, self):
return
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader = weight_loader_with_alias(name)(weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return
def is_layer_norm_weight(name: str) -> bool:
return "norm" in name and not name.endswith(".bias") and name in params_dict
def load_layer_norm_weight(
name: str, loaded_weight: torch.Tensor, self
) -> None:
if is_pp_missing_parameter(name, self):
return
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader = weight_loader_with_alias(name)(weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return
def load_basic_weight(name: str, loaded_weight: torch.Tensor, self) -> None:
if is_pp_missing_parameter(name, self):
return
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader = weight_loader_with_alias(name)(weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
return
for name, loaded_weight in weights:
weight_at_layer = which_layer(name)
            # Skip weights addressed to layers beyond the configured decoder depth.
            if weight_at_layer is not None and weight_at_layer >= len(
                self.model.decoder_attention_types
            ):
                continue
if is_layer_norm_weight(name):
load_layer_norm_weight(name, loaded_weight, self)
continue
if is_mha_weight(name):
if is_linear_attn_layer(weight_at_layer):
load_linear_attn_weight(name, loaded_weight, self)
else:
load_flash_attn_weight(name, loaded_weight, self)
continue
if is_moe_weight(name):
load_sparse_moe_weight(name, loaded_weight, self)
continue
if is_shared_mlp_weight(name):
load_shared_mlp_weight(name, loaded_weight, self)
continue
if "rotary_emb.inv_freq" in name:
continue
load_basic_weight(name, loaded_weight, self)
return loaded_params
@classmethod
def get_mamba_state_dtype_from_config(
cls,
vllm_config: "VllmConfig",
) -> tuple[torch.dtype, torch.dtype]:
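        """Resolve the dtypes of the linear-attention state from the model
        dtype and the configured ``mamba_cache_dtype``."""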
return MambaStateDtypeCalculator.linear_attention_state_dtype(
vllm_config.model_config.dtype,
vllm_config.cache_config.mamba_cache_dtype,
)
@classmethod
def get_mamba_state_shape_from_config(
cls,
vllm_config: "VllmConfig",
) -> tuple[tuple[int, ...], ...]:
"""Calculate shape for MiniMaxText01LinearAttention cache.
Args:
vllm_config: vLLM config
Returns:
Tuple containing:
- state_shape: Shape of the cache
"""
parallel_config = vllm_config.parallel_config
hf_config = vllm_config.model_config.hf_config
return MambaStateShapeCalculator.linear_attention_state_shape(
num_heads=hf_config.num_attention_heads,
tp_size=parallel_config.tensor_parallel_size,
head_dim=hf_config.head_dim,
)
@classmethod
def get_mamba_state_copy_func(cls) -> tuple[MambaStateCopyFunc]:
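        """Return the state-copy callable used for the linear-attention
        (mamba-style) cache."""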
return MambaStateCopyFuncCalculator.linear_attention_state_copy_func()