vllm.model_executor.models.gemma3n_mm

Gemma3nAudioInputs

Bases: TensorSchema

Dimensions
  • bn: Batch size * number of audios
  • s: seq_length
  • f: num_features
Source code in vllm/model_executor/models/gemma3n_mm.py
class Gemma3nAudioInputs(TensorSchema):
    """
    Dimensions:
        - bn: Batch size * number of audios
        - s: seq_length
        - f: num_features
    """

    type: Literal["audio"] = "audio"
    input_features_padded: Annotated[torch.Tensor, TensorShape("bn", "s", "f")]
    input_features_mask: Annotated[torch.Tensor, TensorShape("bn", "s")]
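
A minimal construction sketch with made-up sizes (2 clips, 1280 padded frames, 128 mel features per frame); the real values come from the audio processor. The keyword construction and dict-style access mirror how this schema is used elsewhere in this file:

import torch

bn, s, f = 2, 1280, 128  # illustrative sizes only
audio_input = Gemma3nAudioInputs(
    input_features_padded=torch.zeros(bn, s, f),              # (bn, s, f)
    input_features_mask=torch.ones(bn, s, dtype=torch.bool),  # (bn, s)
)
assert audio_input["input_features_padded"].shape == (bn, s, f)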

Gemma3nForConditionalGeneration

Bases: Module, SupportsMultiModal, SupportsTranscription

Source code in vllm/model_executor/models/gemma3n_mm.py
@MULTIMODAL_REGISTRY.register_processor(
    Gemma3nMultiModalProcessor,
    info=Gemma3nProcessingInfo,
    dummy_inputs=Gemma3nDummyInputsBuilder,
)
class Gemma3nForConditionalGeneration(
    nn.Module, SupportsMultiModal, SupportsTranscription
):
    supported_languages = ISO639_1_SUPPORTED_LANGS

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            # mapping for new names in checkpoint saved after transformers v4.52
            "model.embed_audio.": "embed_audio.",
            "model.embed_vision.": "embed_vision.",
            "model.language_model.": "language_model.model.",
            "model.vision_tower.": "vision_tower.",
            "model.audio_tower.": "audio_tower.",
            "model.multi_modal_projector.": "multi_modal_projector.",
            "lm_head.": "language_model.lm_head.",
            "model": "language_model.model",
        }
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        multimodal_config = vllm_config.model_config.multimodal_config
        self.config = config
        self.quant_config = quant_config
        self.multimodal_config = multimodal_config
        self.vocab_size = config.text_config.vocab_size

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_tower = AutoModel.from_config(config=config.vision_config)
            self.embed_vision = Gemma3nMultimodalEmbedder(
                config.vision_config, config.text_config
            )

        with self._mark_tower_model(vllm_config, "audio"):
            self.audio_tower = AutoModel.from_config(config=config.audio_config)
            self.embed_audio = Gemma3nMultimodalEmbedder(
                config.audio_config, config.text_config
            )

        with self._mark_language_model(vllm_config):
            self.language_model: Gemma3nForCausalLM = init_vllm_registered_model(
                vllm_config=vllm_config,
                hf_config=config.text_config,
                prefix=maybe_prefix(prefix, "language_model"),
                architectures=["Gemma3nForCausalLM"],
            )

            # NOTE (NickLucche) In order to be compatible with cudagraph, the
            # buffer needs to be consistent, so we pre-allocate here.
            self.per_layer_embeddings = torch.zeros(
                vllm_config.scheduler_config.max_num_batched_tokens,
                self.config.text_config.num_hidden_layers,
                self.config.text_config.hidden_size_per_layer_input,
                device=self.language_model.model.embed_tokens.weight.device,
                dtype=self.language_model.model.embed_tokens.weight.dtype,
            )

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> Gemma3nImageInputs | None:
        pixel_values = kwargs.pop("pixel_values", None)
        image_embeds = kwargs.pop("image_embeds", None)
        # TODO is this the case?
        assert image_embeds is None, "Gemma3n does not support image_embeds."
        if pixel_values is None:
            return None

        return Gemma3nImagePixelInputs(pixel_values=pixel_values)

    def _parse_and_validate_audio_input(
        self, **kwargs: object
    ) -> Gemma3nAudioInputs | None:
        input_features_padded = kwargs.pop("input_features_padded", None)
        if input_features_padded is None:
            return None

        input_features_mask = kwargs.pop("input_features_mask", None)
        if input_features_mask is None:
            return None

        return Gemma3nAudioInputs(
            input_features_padded=input_features_padded,
            input_features_mask=input_features_mask,
        )

    def _parse_and_validate_multimodal_inputs(self, **kwargs: object) -> dict:
        mm_input_by_modality = {}

        # Preserve the order of modalities if there are multiple of them
        # from the order of kwargs.
        for input_key in kwargs:
            if (
                input_key in ("pixel_values", "image_embeds")
                and "image" not in mm_input_by_modality
            ):
                mm_input_by_modality["image"] = self._parse_and_validate_image_input(
                    **kwargs
                )
            if (
                input_key == "input_features_padded"
                and "audio" not in mm_input_by_modality
            ):
                mm_input_by_modality["audio"] = self._parse_and_validate_audio_input(
                    **kwargs
                )
        return mm_input_by_modality

    def _process_image_input(
        self,
        image_input: Gemma3nImageInputs,
    ) -> list[torch.Tensor]:
        pixel_values = image_input["pixel_values"]
        vision_outputs = self.vision_tower(
            pixel_values=pixel_values, do_pooling=False, return_dict=True
        ).last_hidden_state
        # TODO try to avoid copy here
        # (batch, channels, height, width) to (batch, height * width, channels)
        vision_outputs = (
            vision_outputs.reshape(
                vision_outputs.shape[0],
                self.config.vision_config.hidden_size,
                self.config.vision_soft_tokens_per_image,
            )
            .permute(0, 2, 1)
            .contiguous()
        )
        # Normalize and embed the soft tokens into language model space.
        vision_outputs *= self.config.vision_config.hidden_size**0.5
        # Return a list of embeddings instead of a batched tensor
        return self.embed_vision(inputs_embeds=vision_outputs).unbind(0)

    def _process_audio_input(
        self,
        audio_input: Gemma3nAudioInputs,
    ) -> list[torch.Tensor]:
        # Run on padded features to enable batching
        input_features = audio_input["input_features_padded"].squeeze(1)
        input_features_mask = audio_input["input_features_mask"].squeeze(1)
        audio_outputs = self.audio_tower(input_features, ~input_features_mask)
        if isinstance(audio_outputs, tuple):
            # Transformers v4
            audio_encodings, audio_mask = audio_outputs
        else:
            # Transformers v5
            audio_encodings = audio_outputs.last_hidden_state
            audio_mask = audio_outputs.audio_mel_mask
        audio_features = self.embed_audio(inputs_embeds=audio_encodings)

        # The Gemma3nProcessor expects all audio will be 30s in length and
        # inserts 188 audio soft tokens into the text to account for this.
        # However, the audio preprocessing and encoder do not guarantee they
        # will produce exactly 188 soft tokens; they may produce fewer tokens
        # (for shorter audio) or more tokens (for longer audio or due to
        # BOA/EOA special tokens in the placeholder sequence).
        # We handle both cases:
        # - If fewer tokens: pad with the embedding of the last vocab token
        # - If more tokens: truncate to the expected count
        # TODO precompute and cache padding
        audio_padding_toks = torch.tensor(
            [[self.vocab_size - 1]], dtype=torch.long, device=audio_features.device
        )
        audio_padding_embs = self.embed_audio(input_ids=audio_padding_toks)
        audio_features = torch.where(
            audio_mask.unsqueeze(-1), audio_padding_embs, audio_features
        )

        expected_tokens = self.config.audio_soft_tokens_per_image
        audio_features, tokens_truncated = adjust_audio_features_to_expected_length(
            audio_features, expected_tokens, audio_padding_embs
        )
        if tokens_truncated > 0:
            logger.warning(
                "Gemma3n audio encoder produced %d extra tokens. "
                "Truncating to match placeholder count of %d.",
                tokens_truncated,
                expected_tokens,
            )

        # Return a list of embeddings instead of a batched tensor
        return audio_features.unbind(0)

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        mm_input_by_modality = self._parse_and_validate_multimodal_inputs(**kwargs)
        if mm_input_by_modality is None:
            return []

        multimodal_embeddings: list[torch.Tensor] = []

        # NOTE: It is important to iterate over the keys in this dictionary
        # to preserve the order of the modalities.
        for modality in mm_input_by_modality:
            multimodal_input = mm_input_by_modality[modality]
            if modality == "image":
                vision_embeddings = self._process_image_input(multimodal_input)
                multimodal_embeddings.extend(vision_embeddings)
            if modality == "audio":
                audio_embeddings = self._process_audio_input(multimodal_input)
                multimodal_embeddings.extend(audio_embeddings)
        return multimodal_embeddings

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        # NOTE (NickLucche) Each pass needs tokens to compute PLEs, so we cache
        # them here, as the model forward only has access to inputs_embeds.
        if input_ids is not None:
            per_layer_inputs = self.language_model.model.get_per_layer_input_embeddings(
                input_ids
            )
            per_layer_inputs = per_layer_inputs.reshape(
                -1,
                self.config.text_config.num_hidden_layers,
                self.config.text_config.hidden_size_per_layer_input,
            )
            self.per_layer_embeddings[: per_layer_inputs.shape[0]].copy_(
                per_layer_inputs
            )

        # This is to satisfy the type checker for each overload
        if multimodal_embeddings is None or is_multimodal is None:
            return super().embed_input_ids(input_ids)

        return super().embed_input_ids(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> IntermediateTensors:
        if intermediate_tensors is not None:
            inputs_embeds = None

        # NOTE (NickLucche) During profiling, `embed_input_ids` is not
        # called, hence we don't have input_ids to compute PLEs. We simply
        # select a chunk of pre-allocated PLEs. During normal execution,
        # `embed_input_ids` is called before forward, hence this slice
        # will contain PLEs computed from the actual input_ids.
        per_layer_inputs = self.per_layer_embeddings[: inputs_embeds.shape[0]]

        hidden_states = self.language_model.model(
            input_ids,
            positions,
            per_layer_inputs=per_layer_inputs,
            intermediate_tensors=intermediate_tensors,
            inputs_embeds=inputs_embeds,
            **kwargs,
        )

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        return self.language_model.compute_logits(hidden_states)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights, mapper=self.hf_to_vllm_mapper)

    def get_mm_mapping(self) -> MultiModelKeys:
        """
        Get the module prefix in multimodal models
        """
        return MultiModelKeys.from_string_field(
            language_model="language_model",
            connector="multi_modal_projector",
            tower_model="vision_tower",
        )

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality == "image":
            return "<image_soft_token>"
        elif modality == "audio":
            return "<audio_soft_token>"
        else:
            raise ValueError(f"Unsupported modality: {modality}")

    @classmethod
    def get_generation_prompt(
        cls,
        audio: np.ndarray,
        stt_config: SpeechToTextConfig,
        model_config: ModelConfig,
        language: str | None,
        task_type: Literal["transcribe", "translate"],
        request_prompt: str,
        to_language: str | None,
    ) -> PromptType:
        """
        Gemma3n supports "free-form" transcription.
        We fix its prompt here to standardize transcription/translation
        requests.
        """
        # Transcribe this audio [into <>] | for transcription
        # Translate this audio [from <> into <>] | for translation
        prompt = "<start_of_turn>user\n"
        prompt += "Transcribe" if task_type == "transcribe" else "Translate"
        prompt += " this audio"

        # We assume the language is a valid ISO 639-1 code.
        full_lang_name = cls.supported_languages.get(language, "")
        # Translation only for now
        full_lang_name_to = cls.supported_languages.get(to_language, "")

        if task_type == "transcribe" and full_lang_name:
            prompt += f" into {full_lang_name}"
        elif task_type == "translate":
            if full_lang_name:
                prompt += f" from {full_lang_name}"
            if full_lang_name_to:
                prompt += f" into {full_lang_name_to}"

        prompt += ": <audio_soft_token><end_of_turn>\n<start_of_turn>model\n"

        return TextPrompt(
            prompt=prompt,
            multi_modal_data={"audio": (audio, stt_config.sample_rate)},
        )

    @classmethod
    def get_speech_to_text_config(
        cls, model_config: ModelConfig, task_type: str
    ) -> SpeechToTextConfig:
        return SpeechToTextConfig(
            # Let's set this to 30 as suggested in the docs for now, although
            # the model is only limited by its context length.
            max_audio_clip_s=30,
            sample_rate=16000,
            # TODO enable chunking after more thorough testing.
            min_energy_split_window_size=None,
        )
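
The hf_to_vllm_mapper above renames checkpoint parameters by prefix before loading. The following is a plain-Python sketch of the intended effect of a subset of those rules, not the real WeightsMapper (whose matching logic lives in vLLM and may differ); the longest-prefix-first ordering is an assumption of the sketch:

orig_to_new_prefix = {
    "model.audio_tower.": "audio_tower.",
    "model.language_model.": "language_model.model.",
    "lm_head.": "language_model.lm_head.",
    "model": "language_model.model",
}

def remap(name: str) -> str:
    # Try longer (more specific) prefixes first so that "model.audio_tower."
    # wins over the bare "model" fallback (sketch assumption).
    for old in sorted(orig_to_new_prefix, key=len, reverse=True):
        if name.startswith(old):
            return orig_to_new_prefix[old] + name[len(old):]
    return name

assert remap("model.audio_tower.conv.weight") == "audio_tower.conv.weight"
assert remap("lm_head.weight") == "language_model.lm_head.weight"

Similarly, adjust_audio_features_to_expected_length is not defined in this file; a hedged sketch of the pad-or-truncate contract described in the comments of _process_audio_input (too few soft tokens: pad with the given embedding; too many: truncate and report the overflow) could look like this:

import torch

def adjust_to_expected_length(feats: torch.Tensor, expected: int,
                              pad_emb: torch.Tensor) -> tuple[torch.Tensor, int]:
    # feats: (batch, n, dim); pad_emb: (1, 1, dim) padding embedding.
    bsz, n, dim = feats.shape
    if n < expected:
        pad = pad_emb.expand(bsz, expected - n, dim)
        return torch.cat([feats, pad], dim=1), 0
    return feats[:, :expected], n - expected

feats = torch.randn(1, 190, 8)   # encoder produced 190 soft tokens
pad_emb = torch.zeros(1, 1, 8)   # stand-in for embed_audio(last vocab id)
out, truncated = adjust_to_expected_length(feats, 188, pad_emb)
assert out.shape == (1, 188, 8) and truncated == 2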

get_generation_prompt classmethod

get_generation_prompt(
    audio: ndarray,
    stt_config: SpeechToTextConfig,
    model_config: ModelConfig,
    language: str | None,
    task_type: Literal["transcribe", "translate"],
    request_prompt: str,
    to_language: str | None,
) -> PromptType

Gemma3n supports "free-form" transcription. We fix its prompt here to standardize transcription/translation requests.

Source code in vllm/model_executor/models/gemma3n_mm.py
@classmethod
def get_generation_prompt(
    cls,
    audio: np.ndarray,
    stt_config: SpeechToTextConfig,
    model_config: ModelConfig,
    language: str | None,
    task_type: Literal["transcribe", "translate"],
    request_prompt: str,
    to_language: str | None,
) -> PromptType:
    """
    Gemma3n supports "free-form" transcription.
    We fix its prompt here to standardize transcription/translation
    requests.
    """
    # Transcribe this audio [into <>] | for transcription
    # Translate this audio [from <> into <>] | for translation
    prompt = "<start_of_turn>user\n"
    prompt += "Transcribe" if task_type == "transcribe" else "Translate"
    prompt += " this audio"

    # We assume the language is a valid ISO 639-1 code.
    full_lang_name = cls.supported_languages.get(language, "")
    # Translation only for now
    full_lang_name_to = cls.supported_languages.get(to_language, "")

    if task_type == "transcribe" and full_lang_name:
        prompt += f" into {full_lang_name}"
    elif task_type == "translate":
        if full_lang_name:
            prompt += f" from {full_lang_name}"
        if full_lang_name_to:
            prompt += f" into {full_lang_name_to}"

    prompt += ": <audio_soft_token><end_of_turn>\n<start_of_turn>model\n"

    return TextPrompt(
        prompt=prompt,
        multi_modal_data={"audio": (audio, stt_config.sample_rate)},
    )
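
A hypothetical invocation for a French-to-English translation request. Per the source above, only stt_config.sample_rate and the language fields are actually read here, so stand-in arguments are enough for a sketch, and ISO639_1_SUPPORTED_LANGS is assumed to map "fr"/"en" to "French"/"English":

import numpy as np

stt_config = SpeechToTextConfig(
    max_audio_clip_s=30, sample_rate=16000, min_energy_split_window_size=None
)
prompt = Gemma3nForConditionalGeneration.get_generation_prompt(
    audio=np.zeros(16000, dtype=np.float32),  # 1 s of silence at 16 kHz
    stt_config=stt_config,
    model_config=None,   # not used by this method (see source above)
    language="fr",
    task_type="translate",
    request_prompt="",   # also unused here
    to_language="en",
)
# prompt["prompt"] ==
#   "<start_of_turn>user\nTranslate this audio from French into English: "
#   "<audio_soft_token><end_of_turn>\n<start_of_turn>model\n"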

get_mm_mapping

get_mm_mapping() -> MultiModelKeys

Get the module prefix in multimodal models

Source code in vllm/model_executor/models/gemma3n_mm.py
def get_mm_mapping(self) -> MultiModelKeys:
    """
    Get the module prefix in multimodal models
    """
    return MultiModelKeys.from_string_field(
        language_model="language_model",
        connector="multi_modal_projector",
        tower_model="vision_tower",
    )

Gemma3nImagePixelInputs

Bases: TensorSchema

Dimensions
  • bn: Batch size * number of images
  • c: Number of channels (3)
  • h: Height of each patch
  • w: Width of each patch
Source code in vllm/model_executor/models/gemma3n_mm.py
class Gemma3nImagePixelInputs(TensorSchema):
    """
    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height of each patch
        - w: Width of each patch
    """

    type: Literal["pixel_values"] = "pixel_values"
    pixel_values: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
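
As with the audio schema, a brief construction sketch with illustrative sizes (the real patch resolution comes from the HF processor, not this file):

import torch

pixel = Gemma3nImagePixelInputs(pixel_values=torch.zeros(3, 3, 768, 768))  # (bn, 3, h, w)
assert pixel["pixel_values"].shape == (3, 3, 768, 768)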

Gemma3nMultimodalEmbedder

Bases: Module

Embeds token ids or soft tokens for multimodal content into language model space.

Source code in vllm/model_executor/models/gemma3n_mm.py
class Gemma3nMultimodalEmbedder(nn.Module):
    """Embeds token ids or soft tokens for multimodal content into language
    model space."""

    def __init__(
        self,
        multimodal_config: Gemma3nAudioConfig | Gemma3nVisionConfig,
        text_config: Gemma3nTextConfig,
    ):
        super().__init__()

        self.multimodal_hidden_size = multimodal_config.hidden_size
        self.eps = multimodal_config.rms_norm_eps
        self.vocab_offset = multimodal_config.vocab_offset
        self.vocab_size = multimodal_config.vocab_size
        self.text_hidden_size = text_config.hidden_size

        self.embedding = VocabParallelEmbedding(
            self.vocab_size,
            self.multimodal_hidden_size,
        )

        self.hard_embedding_norm = RMSNorm(
            self.multimodal_hidden_size,
            eps=self.eps,
        )

        self.soft_embedding_norm = RMSNorm(
            self.multimodal_hidden_size,
            eps=self.eps,
        )

        self.embedding_projection = RowParallelLinear(
            self.multimodal_hidden_size,
            self.text_hidden_size,
            bias=False,
        )

        self.embedding_post_projection_norm = RMSNorm(
            self.text_hidden_size,
            eps=self.eps,
            has_weight=False,
        )

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Embeds token ids or soft tokens for multimodal content into language model space.

        Args:
            input_ids: A torch.LongTensor containing the token ids to embed. Values should be in the range
                `[vocab_offset, vocab_offset + vocab_size)`.
            inputs_embeds: A torch.Tensor containing the soft tokens to embed.

        Returns:
            A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`.
        """  # noqa: E501
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )

        if inputs_embeds is not None:
            emb_norm = self.soft_embedding_norm(inputs_embeds)
        else:
            hard_emb = self.embedding(input_ids - self.vocab_offset)
            emb_norm = self.hard_embedding_norm(hard_emb)

        emb_norm_proj, _ = self.embedding_projection(emb_norm)
        return self.embedding_post_projection_norm(emb_norm_proj)

forward

forward(
    input_ids: LongTensor | None = None,
    inputs_embeds: Tensor | None = None,
) -> Tensor

Embeds token ids or soft tokens for multimodal content into language model space.

Parameters:

  • input_ids (LongTensor | None, default None): A torch.LongTensor containing the token ids to embed. Values should be in the range [vocab_offset, vocab_offset + vocab_size).
  • inputs_embeds (Tensor | None, default None): A torch.Tensor containing the soft tokens to embed.

Returns:

  • Tensor: A torch.Tensor of embeddings with shape [batch_size, seq_len, self.config.text_config.hidden_size].

Source code in vllm/model_executor/models/gemma3n_mm.py
def forward(
    self,
    input_ids: torch.LongTensor | None = None,
    inputs_embeds: torch.Tensor | None = None,
) -> torch.Tensor:
    """Embeds token ids or soft tokens for multimodal content into language model space.

    Args:
        input_ids: A torch.LongTensor containing the token ids to embed. Values should be in the range
            `[vocab_offset, vocab_offset + vocab_size)`.
        inputs_embeds: A torch.Tensor containing the soft tokens to embed.

    Returns:
        A torch.Tensor of embeddings with shape `[batch_size, seq_len, self.config.text_config.hidden_size]`.
    """  # noqa: E501
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError(
            "You must specify exactly one of input_ids or inputs_embeds"
        )

    if inputs_embeds is not None:
        emb_norm = self.soft_embedding_norm(inputs_embeds)
    else:
        hard_emb = self.embedding(input_ids - self.vocab_offset)
        emb_norm = self.hard_embedding_norm(hard_emb)

    emb_norm_proj, _ = self.embedding_projection(emb_norm)
    return self.embedding_post_projection_norm(emb_norm_proj)
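
A plain-PyTorch sketch of the two paths (illustrative only: nn.Embedding, nn.LayerNorm and nn.Linear stand in for vLLM's VocabParallelEmbedding, RMSNorm and RowParallelLinear, and all sizes are made up). It shows why hard token ids are shifted by vocab_offset before the lookup, while soft tokens skip the table entirely:

import torch
import torch.nn as nn

mm_hidden, text_hidden, vocab_size, vocab_offset = 32, 64, 128, 1000

embedding = nn.Embedding(vocab_size, mm_hidden)
norm = nn.LayerNorm(mm_hidden)                       # stand-in for RMSNorm
projection = nn.Linear(mm_hidden, text_hidden, bias=False)

# Hard path: multimodal ids live in [vocab_offset, vocab_offset + vocab_size),
# so they are shifted back into the local table before the lookup.
input_ids = torch.tensor([[vocab_offset + 5, vocab_offset + 7]])
hard = norm(embedding(input_ids - vocab_offset))

# Soft path: encoder outputs are normalized directly, with no lookup.
soft = norm(torch.randn(1, 2, mm_hidden))

# Both paths end in the same projection into language-model space.
assert projection(hard).shape == projection(soft).shape == (1, 2, text_hidden)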

Gemma3nProcessingInfo

Bases: BaseProcessingInfo

Source code in vllm/model_executor/models/gemma3n_mm.py
class Gemma3nProcessingInfo(BaseProcessingInfo):
    def get_hf_config(self):
        return self.ctx.get_hf_config(Gemma3nConfig)

    def get_hf_processor(self, **kwargs: object):
        return self.ctx.get_hf_processor(Gemma3nProcessor, **kwargs)

    def get_feature_extractor(self, **kwargs: object) -> Gemma3nAudioFeatureExtractor:
        return self.get_hf_processor(**kwargs).feature_extractor

    def get_data_parser(self):
        feature_extractor = self.get_feature_extractor()

        return MultiModalDataParser(
            target_sr=feature_extractor.sampling_rate,
            expected_hidden_size=self._get_expected_hidden_size(),
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        return {"image": None, "audio": None}

    def get_max_tokens_per_item(
        self, seq_len: int, mm_counts: Mapping[str, int]
    ) -> Mapping[str, int] | None:
        return {"image": TOKENS_PER_IMAGE, "audio": TOKENS_PER_AUDIO}

    def get_image_repl(
        self,
        *,
        image_width: int,
        image_height: int,
        processor: Gemma3nProcessor | None,
    ) -> str:
        """
        Get the replacement text for image tokens.

        For Gemma3n, this should return the full_image_sequence which includes
        BOI token, repeated image tokens, and EOI token.
        """
        if processor is None:
            processor = self.get_hf_processor()

        return PromptUpdateDetails.select_token_id(
            processor.full_image_sequence, processor.image_token_id
        )

    def get_audio_repl(
        self,
        *,
        processor: Gemma3nProcessor | None,
    ) -> str:
        """
        Get the replacement text for audio tokens.

        For Gemma3n, this should return the full_audio_sequence which includes
        BOA token, repeated audio tokens, and EOA token.
        """
        if processor is None:
            processor = self.get_hf_processor()

        # Return the full audio sequence as defined by the processor
        return PromptUpdateDetails.select_token_id(
            processor.full_audio_sequence, processor.audio_token_id
        )

get_audio_repl

get_audio_repl(
    *, processor: Gemma3nProcessor | None
) -> str

Get the replacement text for audio tokens.

For Gemma3n, this should return the full_audio_sequence which includes BOA token, repeated audio tokens, and EOA token.

Source code in vllm/model_executor/models/gemma3n_mm.py
def get_audio_repl(
    self,
    *,
    processor: Gemma3nProcessor | None,
) -> str:
    """
    Get the replacement text for audio tokens.

    For Gemma3n, this should return the full_audio_sequence which includes
    BOA token, repeated audio tokens, and EOA token.
    """
    if processor is None:
        processor = self.get_hf_processor()

    # Return the full audio sequence as defined by the processor
    return PromptUpdateDetails.select_token_id(
        processor.full_audio_sequence, processor.audio_token_id
    )

get_image_repl

get_image_repl(
    *,
    image_width: int,
    image_height: int,
    processor: Gemma3nProcessor | None,
) -> str

Get the replacement text for image tokens.

For Gemma3n, this should return the full_image_sequence which includes BOI token, repeated image tokens, and EOI token.

Source code in vllm/model_executor/models/gemma3n_mm.py
def get_image_repl(
    self,
    *,
    image_width: int,
    image_height: int,
    processor: Gemma3nProcessor | None,
) -> str:
    """
    Get the replacement text for image tokens.

    For Gemma3n, this should return the full_image_sequence which includes
    BOI token, repeated image tokens, and EOI token.
    """
    if processor is None:
        processor = self.get_hf_processor()

    return PromptUpdateDetails.select_token_id(
        processor.full_image_sequence, processor.image_token_id
    )