vllm.model_executor.models.siglip

SiglipAttention

Bases: Module

Source code in vllm/model_executor/models/siglip.py
class SiglipAttention(nn.Module):
    def __init__(
        self,
        config: SiglipVisionConfig | SiglipTextConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        prefix: str = "",
        attn_cls: type[EncoderOnlyAttention] | type[MMEncoderAttention],
    ) -> None:
        super().__init__()

        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads "
                f"(got `embed_dim`: {self.embed_dim} and "
                f"`num_heads`: {self.num_heads})."
            )

        self.scale = self.head_dim**-0.5

        use_data_parallel = is_vit_use_data_parallel()
        self.qkv_proj = QKVParallelLinear(
            hidden_size=self.embed_dim,
            head_size=self.head_dim,
            total_num_heads=self.num_heads,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
            disable_tp=use_data_parallel,
        )

        self.out_proj = RowParallelLinear(
            input_size=self.embed_dim,
            output_size=self.embed_dim,
            quant_config=quant_config,
            prefix=f"{prefix}.out_proj",
            disable_tp=use_data_parallel,
        )

        self.tp_size = (
            1 if use_data_parallel else get_tensor_model_parallel_world_size()
        )
        self.num_heads_per_partition = divide(self.num_heads, self.tp_size)

        # Both branches construct the attention layer identically, so a single
        # call covers EncoderOnlyAttention and MMEncoderAttention.
        self.attn = attn_cls(
            self.num_heads_per_partition,
            self.head_dim,
            self.scale,
            prefix=f"{prefix}.attn",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> tuple[torch.Tensor, None]:
        """Input shape: Batch x Time x Channel"""
        qkv_states, _ = self.qkv_proj(hidden_states)
        query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
        out = self.attn(query_states, key_states, value_states)
        attn_output, _ = self.out_proj(out)

        return attn_output, None

forward

forward(hidden_states: Tensor) -> tuple[Tensor, None]

Input shape: Batch x Time x Channel

Source code in vllm/model_executor/models/siglip.py
def forward(
    self,
    hidden_states: torch.Tensor,
) -> tuple[torch.Tensor, None]:
    """Input shape: Batch x Time x Channel"""
    qkv_states, _ = self.qkv_proj(hidden_states)
    query_states, key_states, value_states = qkv_states.chunk(3, dim=-1)
    out = self.attn(query_states, key_states, value_states)
    attn_output, _ = self.out_proj(out)

    return attn_output, None
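
The fused qkv_proj output is split into equal query, key, and value chunks along the last dimension before being handed to the attention backend. A minimal shape-flow sketch with plain torch stand-ins and assumed toy sizes (the real layers are the tensor-parallel QKVParallelLinear/RowParallelLinear above):

import torch

# Toy sizes; plain nn.Linear stands in for the parallel linear layers.
batch, seq_len, embed_dim = 2, 16, 64
hidden_states = torch.randn(batch, seq_len, embed_dim)  # Batch x Time x Channel

qkv_proj = torch.nn.Linear(embed_dim, 3 * embed_dim)    # fused Q/K/V projection
out_proj = torch.nn.Linear(embed_dim, embed_dim)

qkv_states = qkv_proj(hidden_states)                    # (B, T, 3 * C)
query, key, value = qkv_states.chunk(3, dim=-1)         # three (B, T, C) tensors

# The attention backend (self.attn) consumes these chunks; SDPA is used here
# purely to complete the shape flow of the sketch.
attn_out = torch.nn.functional.scaled_dot_product_attention(query, key, value)
attn_output = out_proj(attn_out)                        # (B, T, C)
assert attn_output.shape == hidden_states.shape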

SiglipEmbeddingModel

Bases: Module, SupportsMultiModal, SupportsQuant

Source code in vllm/model_executor/models/siglip.py
@default_pooling_type(seq_pooling_type="CLS")
@MULTIMODAL_REGISTRY.register_processor(
    SiglipMultiModalProcessor,
    info=SiglipProcessingInfo,
    dummy_inputs=SiglipDummyInputsBuilder,
)
class SiglipEmbeddingModel(nn.Module, SupportsMultiModal, SupportsQuant):
    is_pooling_model = True

    packed_modules_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"]}

    @classmethod
    def get_placeholder_str(cls, modality: str, i: int) -> str | None:
        if modality.startswith("image"):
            return None

        raise ValueError("Only image modality is supported")

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config: SiglipConfig = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config

        if hasattr(config, "num_labels"):
            config.num_labels = 0

        text_config = config.text_config
        vision_config = config.vision_config

        self.text_embed_dim = text_config.hidden_size
        self.vision_embed_dim = vision_config.hidden_size
        self.text_projection_size = text_config.projection_size

        with self._mark_language_model(vllm_config):
            self.text_model = SiglipTextTransformer(
                text_config,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "text_model"),
            )

        with self._mark_tower_model(vllm_config, "image"):
            self.vision_model = SiglipVisionTransformer(
                vision_config,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "vision_model"),
                use_head=None,  # Allows potential pooling head
            )

        pooler_config = vllm_config.model_config.pooler_config
        assert pooler_config is not None
        self.pooler_config = pooler_config

        self.pooler = DispatchPooler.for_embedding(pooler_config)

        self._is_text_input = True

    def get_text_features(
        self,
        input_ids: torch.Tensor | None,
        position_ids: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        last_hidden_state = self.text_model(
            input_ids=input_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
        )
        text_features = self.text_model.head(last_hidden_state)

        # SigLIP uses reversed position_ids;
        # flip sequences to move EOS token to first position
        text_features = self._flip_sequences_by_position_ids(
            text_features, position_ids
        )

        return text_features

    def _flip_sequences_by_position_ids(
        self,
        features: torch.Tensor,
        position_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Flip sequences so EOS token moves to first position for CLS pooling.

        SigLIP position_ids are reversed within each sequence. This method detects
        sequence boundaries and flips each sequence individually.
        """
        if len(features) == 1:
            return features

        # Detect sequence boundaries where position_ids decrease
        position_diffs = position_ids[1:] - position_ids[:-1]
        boundary_mask = position_diffs <= 0

        boundary_indices = torch.cat(
            [
                torch.tensor([0], device=features.device),
                torch.where(boundary_mask)[0] + 1,
                torch.tensor([len(features)], device=features.device),
            ]
        )

        # For each sequence [start, end), position i flips to: start + end - 1 - i
        lengths = boundary_indices[1:] - boundary_indices[:-1]
        starts = boundary_indices[:-1]
        ends = boundary_indices[1:]

        # Assign sequence ID to each element
        sequence_ids = torch.arange(
            len(lengths), device=features.device
        ).repeat_interleave(lengths)

        # Calculate flipped indices for all positions at once
        current_positions = torch.arange(len(features), device=features.device)
        flip_indices = starts[sequence_ids] + ends[sequence_ids] - 1 - current_positions

        return features[flip_indices]

    def get_image_features(
        self,
        pixel_values: torch.Tensor,
        feature_select_strategy: VisionFeatureSelectStrategy | None = None,
    ) -> torch.Tensor:
        if feature_select_strategy is None:
            feature_select_strategy = _get_vision_feature_select_strategy(
                self.pooler_config.seq_pooling_type
            )

        pooled_output = self.vision_model(
            pixel_values=pixel_values,
            select_layers=None,
            feature_select_strategy=feature_select_strategy,
        )

        return pooled_output

    def _parse_and_validate_image_input(
        self, **kwargs: object
    ) -> SiglipImagePixelInputs | None:
        pixel_values = kwargs.pop("pixel_values", None)
        if pixel_values is None:
            return None

        expected_h = expected_w = self.config.vision_config.image_size
        return SiglipImagePixelInputs(
            type="pixel_values",
            data=pixel_values,
            resolve_bindings={"h": expected_h, "w": expected_w},
        )

    def _process_image_inputs(self, inputs: SiglipImagePixelInputs) -> torch.Tensor:
        pixel_values = inputs["data"]

        return self.get_image_features(pixel_values)

    def _embed_text_input_ids(
        self,
        input_ids: torch.Tensor,
        embed_input_ids: Callable[[torch.Tensor], torch.Tensor],
        *,
        is_multimodal: torch.Tensor | None,
        handle_oov_mm_token: bool,
    ) -> torch.Tensor:
        inputs_embeds = super()._embed_text_input_ids(
            input_ids,
            embed_input_ids,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

        # NOTE: inputs_embeds in model runner has size text_config.projection_size
        # (instead of text_config.hidden_size) to accommodate image embeddings
        inputs_embeds_size = self.text_projection_size
        if inputs_embeds.shape[1] < inputs_embeds_size:
            inputs_embeds = torch.cat(
                [
                    inputs_embeds,
                    inputs_embeds.new_empty(
                        inputs_embeds.shape[0],
                        inputs_embeds_size - inputs_embeds.shape[1],
                    ),
                ],
                dim=1,
            )
        elif inputs_embeds.shape[1] > inputs_embeds_size:
            # No need to handle this case for now
            raise NotImplementedError

        return inputs_embeds

    def embed_input_ids(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: MultiModalEmbeddings | None = None,
        *,
        is_multimodal: torch.Tensor | None = None,
        handle_oov_mm_token: bool = False,
    ) -> torch.Tensor:
        self._is_text_input = (
            multimodal_embeddings is None or len(multimodal_embeddings) == 0
        )

        if multimodal_embeddings is None or is_multimodal is None:
            return super().embed_input_ids(input_ids)

        return super().embed_input_ids(
            input_ids,
            multimodal_embeddings=multimodal_embeddings,
            is_multimodal=is_multimodal,
            handle_oov_mm_token=handle_oov_mm_token,
        )

    def embed_multimodal(self, **kwargs: object) -> MultiModalEmbeddings:
        image_input = self._parse_and_validate_image_input(**kwargs)
        if image_input is None:
            return []

        vision_embeddings = self._process_image_inputs(image_input)
        return vision_embeddings

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: object,
    ) -> torch.Tensor:
        if intermediate_tensors is not None:
            raise RuntimeError("PP is not supported for this model")

        # Multimodal inputs (image embeddings)
        if not self._is_text_input:
            return inputs_embeds

        # NOTE: inputs_embeds in model runner has size text_config.projection_size
        # (instead of text_config.hidden_size) to accommodate image embeddings
        hidden_size = self.text_embed_dim
        if inputs_embeds.shape[1] > hidden_size:
            inputs_embeds = inputs_embeds[:, :hidden_size]
        elif inputs_embeds.shape[1] < hidden_size:
            # No need to handle this case for now
            raise NotImplementedError

        return self.get_text_features(input_ids, positions, inputs_embeds)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        loader = AutoWeightsLoader(
            self,
            skip_substrs=[".position_ids"],
            ignore_unexpected_prefixes=["logit_scale.", "logit_bias."],
        )

        return loader.load_weights(weights)
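
The NOTE comments in _embed_text_input_ids and forward describe a widen/slice round-trip: text embeddings of width text_config.hidden_size are padded with an uninitialized tail up to text_config.projection_size so they can share the model-runner buffer with image embeddings, and forward slices that tail back off. A toy illustration with assumed sizes (hidden_size=16, projection_size=24):

import torch

# Assumed toy sizes; in the model these come from text_config.
hidden_size, projection_size = 16, 24
text_embeds = torch.randn(5, hidden_size)

# _embed_text_input_ids widens text embeddings to projection_size so they can
# share the model-runner buffer with image embeddings of that width ...
padded = torch.cat(
    [text_embeds, text_embeds.new_empty(5, projection_size - hidden_size)],
    dim=1,
)

# ... and forward() slices the extra columns back off before the text tower runs.
assert torch.equal(padded[:, :hidden_size], text_embeds)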

_flip_sequences_by_position_ids

_flip_sequences_by_position_ids(
    features: Tensor, position_ids: Tensor
) -> Tensor

Flip sequences so EOS token moves to first position for CLS pooling.

SigLIP position_ids are reversed within each sequence. This method detects sequence boundaries and flips each sequence individually.

Source code in vllm/model_executor/models/siglip.py
def _flip_sequences_by_position_ids(
    self,
    features: torch.Tensor,
    position_ids: torch.Tensor,
) -> torch.Tensor:
    """Flip sequences so EOS token moves to first position for CLS pooling.

    SigLIP position_ids are reversed within each sequence. This method detects
    sequence boundaries and flips each sequence individually.
    """
    if len(features) == 1:
        return features

    # Detect sequence boundaries where position_ids decrease
    position_diffs = position_ids[1:] - position_ids[:-1]
    boundary_mask = position_diffs <= 0

    boundary_indices = torch.cat(
        [
            torch.tensor([0], device=features.device),
            torch.where(boundary_mask)[0] + 1,
            torch.tensor([len(features)], device=features.device),
        ]
    )

    # For each sequence [start, end), position i flips to: start + end - 1 - i
    lengths = boundary_indices[1:] - boundary_indices[:-1]
    starts = boundary_indices[:-1]
    ends = boundary_indices[1:]

    # Assign sequence ID to each element
    sequence_ids = torch.arange(
        len(lengths), device=features.device
    ).repeat_interleave(lengths)

    # Calculate flipped indices for all positions at once
    current_positions = torch.arange(len(features), device=features.device)
    flip_indices = starts[sequence_ids] + ends[sequence_ids] - 1 - current_positions

    return features[flip_indices]
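
A worked example of the boundary detection and flip index math, using toy values chosen so that positions increase within each packed sequence (the layout the position_diffs <= 0 boundary test assumes):

import torch

# Two packed sequences of lengths 3 and 4; toy per-token features.
position_ids = torch.tensor([0, 1, 2, 0, 1, 2, 3])
features = torch.arange(7)

diffs = position_ids[1:] - position_ids[:-1]
boundary_indices = torch.cat(
    [torch.tensor([0]), torch.where(diffs <= 0)[0] + 1, torch.tensor([7])]
)                                                          # tensor([0, 3, 7])
lengths = boundary_indices[1:] - boundary_indices[:-1]     # tensor([3, 4])
starts, ends = boundary_indices[:-1], boundary_indices[1:]
sequence_ids = torch.arange(2).repeat_interleave(lengths)  # [0, 0, 0, 1, 1, 1, 1]

current = torch.arange(7)
flip_indices = starts[sequence_ids] + ends[sequence_ids] - 1 - current
print(flip_indices)            # tensor([2, 1, 0, 6, 5, 4, 3])
print(features[flip_indices])  # each sequence reversed: its last (EOS) token comes first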

SiglipImagePixelInputs

Bases: TensorSchema

Dimensions
  • bn: Batch size * number of images
  • c: Number of channels (3)
  • h: Height of each image
  • w: Width of each image

Source code in vllm/model_executor/models/siglip.py
class SiglipImagePixelInputs(TensorSchema):
    """
    Dimensions:
        - bn: Batch size * number of images
        - c: Number of channels (3)
        - h: Height of each image
        - w: Width of each image
    """

    type: Literal["pixel_values"]
    data: Annotated[torch.Tensor, TensorShape("bn", 3, "h", "w")]
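
A minimal sketch of constructing a valid input, assuming a 224x224 checkpoint (config.vision_config.image_size == 224), mirroring _parse_and_validate_image_input above:

import torch

from vllm.model_executor.models.siglip import SiglipImagePixelInputs

# 2 images, RGB channels, height x width matching the assumed image_size.
pixel_values = torch.rand(2, 3, 224, 224)

image_input = SiglipImagePixelInputs(
    type="pixel_values",
    data=pixel_values,
    resolve_bindings={"h": 224, "w": 224},  # shape bindings checked by TensorSchema
)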

SiglipMultiheadAttentionPoolingHead

Bases: Module

Multihead Attention Pooling.

Source code in vllm/model_executor/models/siglip.py
class SiglipMultiheadAttentionPoolingHead(nn.Module):
    """Multihead Attention Pooling."""

    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.probe = nn.Parameter(torch.randn(1, 1, config.hidden_size))
        # TODO(ChristopherCho): Implement vLLM version of MultiheadAttention
        self.attention = torch.nn.MultiheadAttention(
            config.hidden_size, config.num_attention_heads, batch_first=True
        )
        self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mlp = SiglipMLP(
            config=config,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )

    def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
        batch_size = hidden_state.size(0)

        probe = self.probe.expand(batch_size, -1, -1)

        hidden_state = self.attention(probe, hidden_state, hidden_state)[0]

        residual = hidden_state
        hidden_state = self.layernorm(hidden_state)
        hidden_state = self.mlp(hidden_state)
        hidden_state += residual

        # Handled by resolve_visual_encoder_outputs
        # return hidden_state[:, 0]
        return hidden_state
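
The head attends a single learned probe token over all patch tokens, so each image is pooled to one vector. A shape-only sketch with assumed toy sizes:

import torch

batch, num_patches, hidden = 2, 196, 64
patch_tokens = torch.randn(batch, num_patches, hidden)
probe = torch.randn(1, 1, hidden).expand(batch, -1, -1)   # one query per image

attn = torch.nn.MultiheadAttention(hidden, num_heads=4, batch_first=True)
pooled, _ = attn(probe, patch_tokens, patch_tokens)       # query=probe, key/value=patches
assert pooled.shape == (batch, 1, hidden)                 # one pooled token per image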

SiglipVisionTransformer

Bases: Module

Source code in vllm/model_executor/models/siglip.py
class SiglipVisionTransformer(nn.Module):
    def __init__(
        self,
        config: SiglipVisionConfig,
        quant_config: QuantizationConfig | None = None,
        *,
        num_hidden_layers_override: int | None = None,
        require_post_norm: bool | None = None,
        prefix: str = "",
        use_head: bool | None = False,
    ) -> None:
        super().__init__()

        self.config = config
        embed_dim = config.hidden_size

        self.embeddings = SiglipVisionEmbeddings(config)

        self.encoder = SiglipEncoder(
            config,
            quant_config=quant_config,
            num_hidden_layers_override=num_hidden_layers_override,
            prefix=f"{prefix}.encoder",
            attn_cls=MMEncoderAttention,
        )

        num_hidden_layers = config.num_hidden_layers
        if len(self.encoder.layers) > config.num_hidden_layers:
            raise ValueError(
                f"The original encoder only has {num_hidden_layers} "
                f"layers, but you requested {len(self.encoder.layers)} layers."
            )

        # If possible, skip post_layernorm to conserve memory
        if require_post_norm is None:
            require_post_norm = len(self.encoder.layers) == num_hidden_layers

        if require_post_norm:
            self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
        else:
            self.post_layernorm = None

        # Fall back to the config if a bool is not provided explicitly;
        # note that many config types, including SiglipVisionConfig,
        # do not have vision_use_head as a defined attribute.
        if isinstance(use_head, bool):
            self.use_head = use_head
        else:
            self.use_head = (
                True
                if not hasattr(config, "vision_use_head")
                else config.vision_use_head
            )

        # Only create and load the head weights if we actually need them
        self.head = (
            SiglipMultiheadAttentionPoolingHead(
                config=config,
                quant_config=quant_config,
                prefix=f"{prefix}.head",
            )
            if self.use_head
            else None
        )
        self.last_hs_proc = partial(self.maybe_layer_norm_and_apply_head)

    @property
    def dtype(self):
        return next(self.parameters()).dtype

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(
        self,
        pixel_values: torch.Tensor,
        *,
        interpolate_pos_encoding: bool = False,
        select_layers: list[int] | None = None,
        feature_select_strategy: VisionFeatureSelectStrategy | None = None,
    ) -> torch.Tensor:
        hidden_states = self.embeddings(
            pixel_values,
            interpolate_pos_encoding=interpolate_pos_encoding,
        )
        # Produces either the last layer output or all of the hidden states,
        # depending on whether select_layers is set
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            return_all_hidden_states=select_layers is not None,
        )

        # In the case that we have multiple feature layers,
        # we stack and concatenate them into a tensor.
        # NOTE: post layer norm and the attention pooling head
        # are handled by last_hs_proc, which runs before applying
        # the vision feature selection strategy.
        encoder_outputs = resolve_visual_encoder_outputs(
            encoder_outputs,
            None,
            select_layers=select_layers,
            max_possible_layers=self.config.num_hidden_layers,
            last_hs_proc=self.last_hs_proc,
            feature_select_strategy=feature_select_strategy,
        )

        return encoder_outputs

    def maybe_layer_norm_and_apply_head(
        self, encoder_outputs: torch.Tensor
    ) -> torch.Tensor:
        """Apply the post layer norm and head if they are enabled,
        given the last hidden states tensor.

        Args:
            encoder_outputs: The last hidden states from the visual encoder.
        """
        if self.post_layernorm is not None:
            encoder_outputs = self.post_layernorm(encoder_outputs)
        if self.head is not None:
            encoder_outputs = self.head(encoder_outputs)
        return encoder_outputs

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        layer_count = len(self.encoder.layers)

        for name, loaded_weight in weights:
            # Skip post_layernorm weights if the layer was not created
            if name.startswith("post_layernorm") and self.post_layernorm is None:
                continue

            # if the model configuration is not going to use
            # the pooling head for inference, don't load its weights
            if self.head is None and name.startswith("head"):
                continue

            # omit layers when num_hidden_layers_override is set
            if name.startswith("encoder.layers"):
                layer_idx = int(name.split(".")[2])
                if layer_idx >= layer_count:
                    continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
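
The use_head resolution in __init__ falls back to config.vision_use_head (treated as True when the attribute is absent) unless a bool is passed explicitly. A toy check of that logic, using a hypothetical config stand-in:

class _Cfg:  # hypothetical stand-in for SiglipVisionConfig without vision_use_head
    pass

def resolve_use_head(use_head, config):
    if isinstance(use_head, bool):
        return use_head
    return True if not hasattr(config, "vision_use_head") else config.vision_use_head

assert resolve_use_head(None, _Cfg()) is True    # fallback: pooling head is created
assert resolve_use_head(False, _Cfg()) is False  # explicit bool wins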

maybe_layer_norm_and_apply_head

maybe_layer_norm_and_apply_head(
    encoder_outputs: Tensor,
) -> Tensor

Apply the post layer norm and head if they are enabled, given the last hidden states tensor.

Parameters:

  • encoder_outputs (Tensor, required): The last hidden states from the visual encoder.
Source code in vllm/model_executor/models/siglip.py
def maybe_layer_norm_and_apply_head(
    self, encoder_outputs: torch.Tensor
) -> torch.Tensor:
    """Apply the post layer norm and head if they are enabled,
    given the last hidden states tensor.

    Args:
        encoder_outputs: The last hidden states from the visual encoder.
    """
    if self.post_layernorm is not None:
        encoder_outputs = self.post_layernorm(encoder_outputs)
    if self.head is not None:
        encoder_outputs = self.head(encoder_outputs)
    return encoder_outputs