Skip to content

vllm.inputs.preprocess

InputPreprocessor

Source code in vllm/inputs/preprocess.py
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
class InputPreprocessor:
    """
    Turn user prompts into the processed inputs consumed by the engine.

    Depending on the prompt, this tokenizes text, truncates pre-tokenized
    prompts, validates prompt embeddings, or runs the model's multi-modal
    processor. For encoder-decoder models (`model_config.is_encoder_decoder`)
    the encoder and decoder prompts are processed separately and combined into
    [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs].

    When a multi-modal processor cache is supplied, per-request cache
    statistics are accumulated and can be drained with `stat_mm_cache()`.
    """

    def __init__(
        self,
        model_config: ModelConfig,
        observability_config: ObservabilityConfig | None = None,
        renderer: BaseRenderer | None = None,
        mm_registry: MultiModalRegistry = MULTIMODAL_REGISTRY,
        mm_processor_cache: BaseMultiModalProcessorCache | None = None,
    ) -> None:
        """
        Arguments:

        * model_config: configuration of the model whose prompts are processed
        * observability_config: forwarded to the multi-modal processor
        * renderer: provides the tokenizer; built with `renderer_from_config`
          from `model_config` when omitted
        * mm_registry: registry used to create the multi-modal processor
        * mm_processor_cache: optional cache handed to the multi-modal
          processor; cache statistics are tracked only when it is provided
        """
        super().__init__()

        self.model_config = model_config
        self.observability_config = observability_config
        self.renderer = renderer or renderer_from_config(model_config)
        self.mm_registry = mm_registry
        self.mm_processor_cache = mm_processor_cache

        self.mm_cache_stats = MultiModalCacheStats() if mm_processor_cache else None

    @property
    def tokenizer(self) -> TokenizerLike | None:
        """The renderer's tokenizer, or `None` if it is not initialized."""
        return self.renderer.tokenizer

    def get_tokenizer(self) -> TokenizerLike:
        """Return the tokenizer via `renderer.get_tokenizer()`."""
        return self.renderer.get_tokenizer()

    def get_bos_token_id(self) -> int | None:
        """
        Return the tokenizer's BOS token id, or `None` (with a one-time
        warning) when no tokenizer is initialized.
        """
        if self.tokenizer is None:
            logger.warning_once(
                "Using None for BOS token id because tokenizer is not initialized"
            )
            return None

        return self.tokenizer.bos_token_id

    def get_eos_token_id(self) -> int | None:
        """
        Return the tokenizer's EOS token id, or `None` (with a one-time
        warning) when no tokenizer is initialized.
        """
        if self.tokenizer is None:
            logger.warning_once(
                "Using None for EOS token id because tokenizer is not initialized"
            )
            return None

        return self.tokenizer.eos_token_id

    def get_decoder_start_token_id(self) -> int:
        """
        Obtain the decoder start token id employed by an encoder/decoder
        model. Falls back to the BOS token id (with a one-time warning) when
        the HF config does not define `decoder_start_token_id`.

        Raises:

        * RuntimeError: if neither id is available
        """
        dec_start_token_id = getattr(
            self.model_config.hf_config, "decoder_start_token_id", None
        )

        if dec_start_token_id is None:
            logger.warning_once(
                "Falling back on <BOS> for decoder start token "
                "id because decoder start token id is not "
                "available."
            )
            dec_start_token_id = self.get_bos_token_id()

        if dec_start_token_id is None:
            raise RuntimeError("Cannot find decoder start token id or <BOS>")

        return dec_start_token_id

    def _prepare_decoder_input_ids(self, decoder_input_ids: list[int]) -> list[int]:
        """
        Prepares `decoder_input_ids` for generation with encoder-decoder models.

        Based on:
        https://github.com/huggingface/transformers/blob/4037a2b5b1278736e566aec12e169100275545ea/src/transformers/generation/utils.py
        specifically,
        `GenerationMixin._prepare_decoder_input_ids_for_generation()`.

        Arguments:

        * decoder_input_ids: input token ids to preprocess

        Returns:

        * The input list unchanged if it already starts with the decoder start
          token, otherwise a new list with that token prepended
        """
        decoder_start_token_id = self.get_decoder_start_token_id()

        if (
            len(decoder_input_ids) == 0
            or decoder_input_ids[0] != decoder_start_token_id
        ):
            decoder_input_ids = [decoder_start_token_id] + decoder_input_ids

        return decoder_input_ids

    def _get_tokenization_kw(
        self,
        overrides: dict[str, Any] | None = None,
    ) -> dict[str, Any]:
        """
        Build keyword arguments for `tokenizer.encode`.

        Encoder-decoder models default to `add_special_tokens=False`; any
        entries in `overrides` take precedence over these defaults.
        """
        kwargs = dict[str, Any]()

        if self.model_config.is_encoder_decoder:
            # For Whisper, special tokens should be provided by the user based
            # on the task and language of their request. Also needed to avoid
            # appending an EOS token to the prompt which disrupts generation.
            kwargs["add_special_tokens"] = False

        if overrides:
            kwargs.update(overrides)

        return kwargs

    def _tokenize_prompt(
        self,
        prompt: str,
        tokenization_kwargs: dict[str, Any] | None = None,
    ) -> list[int]:
        """
        Apply the model's tokenizer to a text prompt, returning the
        corresponding token IDs. The prompt is lower-cased first when the
        encoder config sets `do_lower_case`.
        """
        tokenizer = self.get_tokenizer()
        tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)

        encoder_config = self.model_config.encoder_config

        if encoder_config and encoder_config.get("do_lower_case", False):
            prompt = prompt.lower()

        return tokenizer.encode(prompt, **tokenization_kwargs)

    def _get_mm_processor(self) -> BaseMultiModalProcessor:
        """
        Return the multi-modal processor, creating it from `mm_registry` on
        first use and caching it on the instance afterwards.
        """
        if not hasattr(self, "_mm_processor"):
            self._mm_processor = self.mm_registry.create_processor(
                self.model_config,
                self.observability_config,
                tokenizer=self.tokenizer,
                cache=self.mm_processor_cache,
            )

        return self._mm_processor

    def _process_multimodal(
        self,
        prompt: str | list[int],
        mm_data: MultiModalDataDict,
        mm_processor_kwargs: Mapping[str, object] | None,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> MultiModalInputs:
        """
        Apply the model's multi-modal processor to a multi-modal prompt,
        returning the corresponding token IDs and metadata.

        Raises:

        * ValueError: if the processor returns any non-string hash leaf
        """
        mm_processor = self._get_mm_processor()

        if mm_processor_kwargs is None:
            mm_processor_kwargs = {}

        mm_items = mm_processor.info.parse_mm_data(mm_data)
        mm_input = mm_processor.apply(
            prompt,
            mm_items,
            hf_processor_mm_kwargs=mm_processor_kwargs,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )
        mm_hashes = mm_input["mm_hashes"]

        # Validate that all mm items have a string as their hash
        contains_only_strings = all(
            isinstance(leaf, str) for leaf in json_iter_leaves(mm_hashes)
        )
        if not contains_only_strings:
            raise ValueError(
                f"mm_hashes must contain only strings, got: {mm_hashes}. "
                "This is likely due to an incorrect custom implementation of "
                "MultiModalProcessor.apply method."
            )

        return mm_input

    def _process_embeds(
        self,
        parsed_content: EmbedsPrompt,
    ) -> EmbedsInputs:
        """
        Validate user-supplied prompt embeddings and wrap them as
        `EmbedsInputs`, moving them to CPU.

        Raises:

        * ValueError: if `enable_prompt_embeds` is off, or the embeddings are
          not 2-D after squeezing a leading batch dimension
        """
        if not self.model_config.enable_prompt_embeds:
            raise ValueError(
                "You must set `--enable-prompt-embeds` to input `prompt_embeds`."
            )

        prompt_embeds = parsed_content["prompt_embeds"]

        # prompt_embeds must be (seq_len, hidden_size), but if the user
        # passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
        # we can unambiguously process the intent by squeezing the batch
        # dimension.
        if prompt_embeds.ndim == 3:
            prompt_embeds = prompt_embeds.squeeze(dim=0)

        if prompt_embeds.ndim != 2:
            raise ValueError("prompt_embeds must be of shape (seq_len, hidden_size).")

        # Tensors must be on CPU for serialization between processes
        # in the MsgpackEncoder. Casting to CPU here ensures that there is no
        # hidden device transfer in the critical path of generation.
        prompt_embeds = prompt_embeds.cpu()

        return embeds_inputs(
            prompt_embeds=prompt_embeds, cache_salt=parsed_content.get("cache_salt")
        )

    def _truncate_inputs(
        self, inputs: list[int], tokenization_kwargs: dict[str, Any] | None = None
    ) -> list[int]:
        """
        Truncate pre-tokenized prompt IDs according to `tokenization_kwargs`.

        Truncation happens only when `tokenization_kwargs` enables
        `truncation`, provides a non-None `max_length`, and a tokenizer is
        available; the tokenizer's `truncation_side` decides which end is
        kept. In every other case `inputs` is returned unchanged.
        """
        if not tokenization_kwargs or self.tokenizer is None:
            return inputs

        # `truncation=False` (or absent) means "do not truncate", matching
        # tokenizer semantics; merely having the key present is not enough.
        if not tokenization_kwargs.get("truncation"):
            return inputs

        max_length = tokenization_kwargs.get("max_length")
        if max_length is None or len(inputs) <= max_length:
            return inputs

        if self.tokenizer.truncation_side == "left":
            # Keep the last `max_length` tokens. An explicit start index avoids
            # `inputs[-0:]` returning the whole list when max_length == 0.
            return inputs[len(inputs) - max_length :]

        return inputs[:max_length]

    def _process_tokens(
        self,
        parsed_content: TokensPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> TokenInputs | MultiModalInputs:
        """
        Build inputs from a pre-tokenized prompt.

        Token IDs are truncated per `tokenization_kwargs`; multi-modal data, if
        present, is run through the multi-modal processor; `cache_salt` is
        copied onto the result when set.
        """
        prompt_token_ids = self._truncate_inputs(
            parsed_content["prompt_token_ids"], tokenization_kwargs
        )

        inputs: TokenInputs | MultiModalInputs
        if multi_modal_data := parsed_content.get("multi_modal_data"):
            inputs = self._process_multimodal(
                prompt_token_ids,
                multi_modal_data,
                parsed_content.get("mm_processor_kwargs") or {},
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )
        else:
            inputs = token_inputs(prompt_token_ids)

        if cache_salt := parsed_content.get("cache_salt"):
            inputs["cache_salt"] = cache_salt

        return inputs

    def _process_text(
        self,
        parsed_content: TextPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> TokenInputs | MultiModalInputs:
        """
        Build inputs from a text prompt.

        Multi-modal prompts are handed to the multi-modal processor as raw
        text; plain prompts are tokenized. `cache_salt` is copied onto the
        result when set.
        """
        prompt_text = parsed_content["prompt"]

        inputs: TokenInputs | MultiModalInputs
        if multi_modal_data := parsed_content.get("multi_modal_data"):
            inputs = self._process_multimodal(
                prompt_text,
                multi_modal_data,
                parsed_content.get("mm_processor_kwargs") or {},
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )
        else:
            prompt_token_ids = self._tokenize_prompt(
                prompt_text,
                tokenization_kwargs=tokenization_kwargs,
            )
            inputs = token_inputs(prompt_token_ids)

        if cache_salt := parsed_content.get("cache_salt"):
            inputs["cache_salt"] = cache_salt

        return inputs

    @overload
    def _prompt_to_llm_inputs(
        self,
        prompt: EncoderDictPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> EncoderInputs: ...

    @overload
    def _prompt_to_llm_inputs(  # type: ignore[misc]
        self,
        prompt: DecoderDictPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> DecoderInputs: ...

    @overload
    def _prompt_to_llm_inputs(  # type: ignore[misc]
        self,
        prompt: DecoderOnlyDictPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> DecoderOnlyInputs: ...

    def _prompt_to_llm_inputs(
        self,
        prompt: SingletonDictPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> SingletonInputs:
        """
        Extract the singleton inputs from a prompt.

        Arguments:

        * prompt: single encoder or decoder input prompt
        * tokenization_kwargs: forwarded to token and text processing
        * mm_uuids: forwarded to token and text processing

        Returns:

        * [`SingletonInputs`][vllm.inputs.data.SingletonInputs] instance
        """
        if "prompt_embeds" in prompt:
            return self._process_embeds(prompt)  # type: ignore[arg-type]

        if "prompt_token_ids" in prompt:
            # tokenization_kwargs must be forwarded here too, otherwise the
            # truncation in `_process_tokens` never runs for token prompts.
            return self._process_tokens(
                prompt,  # type: ignore[arg-type]
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )

        if "prompt" in prompt:
            return self._process_text(
                prompt,  # type: ignore[arg-type]
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            )

        assert_never(prompt)  # type: ignore[arg-type]

    def _validate_enc_inputs(self, inputs: SingletonInputs) -> EncoderInputs:
        """
        Check that singleton inputs can serve as encoder inputs.

        Raises:

        * ValueError: for embedding inputs
        * RuntimeError: for multi-modal inputs lacking
          `encoder_prompt_token_ids`
        """
        if inputs["type"] == "embeds":
            raise ValueError(
                "Embedding inputs are not supported for encoder-decoder models"
            )

        if inputs["type"] == "multimodal" and "encoder_prompt_token_ids" not in inputs:
            raise RuntimeError(
                "You should register an encoder-decoder "
                "multi-modal processor for encoder-decoder models."
            )

        return inputs  # type: ignore[return-value]

    def _validate_dec_inputs(self, inputs: SingletonInputs) -> DecoderInputs:
        """
        Check that singleton inputs can serve as decoder inputs.

        Raises:

        * ValueError: for embedding inputs
        """
        if inputs["type"] == "embeds":
            raise ValueError(
                "Embedding inputs are not supported for encoder-decoder models"
            )

        return inputs

    def _build_enc_dec_inputs(
        self,
        encoder_inputs: SingletonInputs,
        decoder_inputs: SingletonInputs | None = None,
    ) -> EncoderDecoderInputs:
        """
        Combine processed encoder (and optional decoder) inputs.

        When `decoder_inputs` is None the encoder inputs are reused for the
        decoder. For multi-modal encoder inputs the encoder receives
        `encoder_prompt_token_ids` and the multi-modal data moves to the
        decoder; for token encoder inputs the encoder prompt is left empty and
        the decoder inputs are used as-is. Decoder token IDs are always
        prefixed with the decoder start token, and the encoder's `cache_salt`
        is propagated to the decoder.
        """
        enc_inputs = self._validate_enc_inputs(encoder_inputs)

        if decoder_inputs is None:
            dec_inputs: DecoderInputs = enc_inputs  # type: ignore[assignment]
        else:
            dec_inputs = self._validate_dec_inputs(decoder_inputs)

        enc_inputs_new: EncoderInputs
        dec_inputs_new: DecoderInputs

        if enc_inputs["type"] == "multimodal":
            enc_inputs_new = token_inputs(enc_inputs["encoder_prompt_token_ids"])
            dec_inputs_new = MultiModalInputs(
                type="multimodal",
                prompt_token_ids=dec_inputs["prompt_token_ids"],
                mm_kwargs=enc_inputs["mm_kwargs"],
                mm_hashes=enc_inputs["mm_hashes"],
                mm_placeholders=enc_inputs["mm_placeholders"],
            )
        elif enc_inputs["type"] == "token":
            enc_inputs_new = token_inputs(prompt_token_ids=[])
            dec_inputs_new = dec_inputs
        else:
            assert_never(enc_inputs)

        dec_inputs_new["prompt_token_ids"] = self._prepare_decoder_input_ids(
            dec_inputs_new["prompt_token_ids"]
        )
        if cache_salt := enc_inputs.get("cache_salt"):
            dec_inputs_new["cache_salt"] = cache_salt

        return EncoderDecoderInputs(encoder=enc_inputs_new, decoder=dec_inputs_new)

    def _process_encoder_decoder_prompt(
        self,
        prompt: EncoderDecoderDictPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> EncoderDecoderInputs:
        """
        For encoder/decoder models only:
        Process an input prompt into an
        [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
        instance. `mm_uuids` is only forwarded for the encoder prompt.

        Arguments:

        * prompt: an input prompt

        Returns:

        * [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
          instance
        """
        encoder_prompt = prompt["encoder_prompt"]
        decoder_prompt = prompt["decoder_prompt"]

        return self._build_enc_dec_inputs(
            encoder_inputs=self._prompt_to_llm_inputs(
                encoder_prompt,
                tokenization_kwargs=tokenization_kwargs,
                mm_uuids=mm_uuids,
            ),
            decoder_inputs=(
                None
                if decoder_prompt is None
                else self._prompt_to_llm_inputs(
                    decoder_prompt,
                    tokenization_kwargs=tokenization_kwargs,
                )
            ),
        )

    def _process_decoder_only_prompt(
        self,
        prompt: DecoderOnlyDictPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> DecoderOnlyInputs:
        """
        For decoder-only models:
        Process an input prompt into a
        [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance.

        Arguments:

        * prompt: input prompt

        Returns:

        * [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance
        """
        return self._prompt_to_llm_inputs(
            prompt,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

    def _preprocess(
        self,
        prompt: PromptType | DictPrompt | TokPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> ProcessorInputs:
        """
        Parse `prompt` and dispatch to encoder-decoder or decoder-only
        processing based on `model_config.is_encoder_decoder`.
        """
        if self.model_config.is_encoder_decoder:
            # Encoder-decoder model requires special mapping of
            # input prompts to encoder & decoder.
            return self._process_encoder_decoder_prompt(
                parse_enc_dec_prompt(prompt),
                tokenization_kwargs,
                mm_uuids=mm_uuids,
            )

        return self._process_decoder_only_prompt(
            parse_dec_only_prompt(prompt),
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

    def preprocess(
        self,
        prompt: PromptType | DictPrompt | TokPrompt,
        tokenization_kwargs: dict[str, Any] | None = None,
        *,
        mm_uuids: MultiModalUUIDDict | None = None,
    ) -> ProcessorInputs:
        """
        Preprocess the input prompt.

        When multi-modal cache statistics are tracked, the cache's delta stats
        for this request are added to the running totals.
        """
        res = self._preprocess(prompt, tokenization_kwargs, mm_uuids=mm_uuids)

        if self.mm_processor_cache and self.mm_cache_stats is not None:
            delta = self.mm_processor_cache.make_stats(delta=True)
            self.mm_cache_stats.requests += 1
            self.mm_cache_stats.queries += delta.total
            self.mm_cache_stats.hits += delta.hits

        return res

    def stat_mm_cache(self) -> MultiModalCacheStats | None:
        """
        Return the multi-modal cache stats accumulated since the previous call
        and start a fresh accumulator; `None` when stats are not tracked.
        """
        mm_cache_stats = self.mm_cache_stats
        if mm_cache_stats is None:
            return None

        self.mm_cache_stats = MultiModalCacheStats()

        return mm_cache_stats

    def clear_mm_cache(self) -> None:
        """
        Clear the multi-modal processor cache (if any) and flag the current
        stats accumulator as reset.
        """
        if self.mm_processor_cache is not None:
            self.mm_processor_cache.clear_cache()

        if self.mm_cache_stats is not None:
            self.mm_cache_stats.reset = True

_prepare_decoder_input_ids

_prepare_decoder_input_ids(
    decoder_input_ids: list[int],
) -> list[int]

Prepares decoder_input_ids for generation with encoder-decoder models.

Based on: https://github.com/huggingface/transformers/blob/4037a2b5b1278736e566aec12e169100275545ea/src/transformers/generation/utils.py specifically, GenerationMixin._prepare_decoder_input_ids_for_generation().

Arguments:

  • decoder_input_ids: input token ids to preprocess

Returns:

  • Processed token list
Source code in vllm/inputs/preprocess.py
def _prepare_decoder_input_ids(self, decoder_input_ids: list[int]) -> list[int]:
    """
    Prepares `decoder_input_ids` for generation with encoder-decoder models.

    Based on:
    https://github.com/huggingface/transformers/blob/4037a2b5b1278736e566aec12e169100275545ea/src/transformers/generation/utils.py
    specifically,
    `GenerationMixin._prepare_decoder_input_ids_for_generation()`.

    Arguments:

    * decoder_input_ids: input token ids to preprocess

    Returns:

    * Processed token list
    """
    decoder_start_token_id = self.get_decoder_start_token_id()

    if (
        len(decoder_input_ids) == 0
        or decoder_input_ids[0] != decoder_start_token_id
    ):
        decoder_input_ids = [decoder_start_token_id] + decoder_input_ids

    return decoder_input_ids

_process_decoder_only_prompt

_process_decoder_only_prompt(
    prompt: DecoderOnlyDictPrompt,
    tokenization_kwargs: dict[str, Any] | None = None,
    *,
    mm_uuids: MultiModalUUIDDict | None = None,
) -> DecoderOnlyInputs

For decoder-only models: Process an input prompt into a DecoderOnlyInputs instance.

Arguments:

  • prompt: input prompt

Returns:

  • DecoderOnlyInputs instance

Source code in vllm/inputs/preprocess.py
def _process_decoder_only_prompt(
    self,
    prompt: DecoderOnlyDictPrompt,
    tokenization_kwargs: dict[str, Any] | None = None,
    *,
    mm_uuids: MultiModalUUIDDict | None = None,
) -> DecoderOnlyInputs:
    """
    For decoder-only models:
    Process an input prompt into a
    [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance.

    Arguments:

    * prompt: input prompt

    Returns:

    * [`DecoderOnlyInputs`][vllm.inputs.data.DecoderOnlyInputs] instance
    """
    return self._prompt_to_llm_inputs(
        prompt,
        tokenization_kwargs=tokenization_kwargs,
        mm_uuids=mm_uuids,
    )

_process_encoder_decoder_prompt

_process_encoder_decoder_prompt(
    prompt: EncoderDecoderDictPrompt,
    tokenization_kwargs: dict[str, Any] | None = None,
    *,
    mm_uuids: MultiModalUUIDDict | None = None,
) -> EncoderDecoderInputs

For encoder/decoder models only: Process an input prompt into an EncoderDecoderInputs instance.

Arguments:

  • prompt: an input prompt

Returns:

  • EncoderDecoderInputs instance

Source code in vllm/inputs/preprocess.py
def _process_encoder_decoder_prompt(
    self,
    prompt: EncoderDecoderDictPrompt,
    tokenization_kwargs: dict[str, Any] | None = None,
    *,
    mm_uuids: MultiModalUUIDDict | None = None,
) -> EncoderDecoderInputs:
    """
    For encoder/decoder models only:
    Process an input prompt into an
    [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
    instance.

    Arguments:

    * prompt: an input prompt

    Returns:

    * [`EncoderDecoderInputs`][vllm.inputs.data.EncoderDecoderInputs]
      instance
    """
    encoder_prompt = prompt["encoder_prompt"]
    decoder_prompt = prompt["decoder_prompt"]

    return self._build_enc_dec_inputs(
        encoder_inputs=self._prompt_to_llm_inputs(
            encoder_prompt,
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        ),
        decoder_inputs=(
            None
            if decoder_prompt is None
            else self._prompt_to_llm_inputs(
                decoder_prompt,
                tokenization_kwargs=tokenization_kwargs,
            )
        ),
    )

_process_multimodal

_process_multimodal(
    prompt: str | list[int],
    mm_data: MultiModalDataDict,
    mm_processor_kwargs: Mapping[str, object] | None,
    tokenization_kwargs: dict[str, Any] | None = None,
    *,
    mm_uuids: MultiModalUUIDDict | None = None,
) -> MultiModalInputs

Apply the model's multi-modal processor to a multi-modal prompt, returning the corresponding token IDs and metadata.

Source code in vllm/inputs/preprocess.py
def _process_multimodal(
    self,
    prompt: str | list[int],
    mm_data: MultiModalDataDict,
    mm_processor_kwargs: Mapping[str, object] | None,
    tokenization_kwargs: dict[str, Any] | None = None,
    *,
    mm_uuids: MultiModalUUIDDict | None = None,
) -> MultiModalInputs:
    """
    Run the model's multi-modal processor over a prompt (text or token IDs)
    and its multi-modal data, returning token IDs plus multi-modal metadata.

    Arguments:

    * prompt: text prompt or pre-tokenized prompt IDs
    * mm_data: raw multi-modal data, parsed by the processor's `info`
    * mm_processor_kwargs: extra HF processor kwargs (`None` means none)
    * tokenization_kwargs: optional tokenization options for the processor
    * mm_uuids: optional multi-modal item UUIDs for the processor

    Raises:

    * ValueError: if the processor output's `mm_hashes` contains any
      non-string leaf
    """
    processor = self._get_mm_processor()
    hf_kwargs = {} if mm_processor_kwargs is None else mm_processor_kwargs

    parsed_items = processor.info.parse_mm_data(mm_data)
    processed = processor.apply(
        prompt,
        parsed_items,
        hf_processor_mm_kwargs=hf_kwargs,
        tokenization_kwargs=tokenization_kwargs,
        mm_uuids=mm_uuids,
    )

    mm_hashes = processed["mm_hashes"]

    # A custom processor could return non-string hashes; reject them early.
    if any(not isinstance(leaf, str) for leaf in json_iter_leaves(mm_hashes)):
        raise ValueError(
            f"mm_hashes must contain only strings, got: {mm_hashes}. "
            "This is likely due to an incorrect custom implementation of "
            "MultiModalProcessor.apply method."
        )

    return processed

_prompt_to_llm_inputs

_prompt_to_llm_inputs(
    prompt: EncoderDictPrompt,
    tokenization_kwargs: dict[str, Any] | None = None,
    *,
    mm_uuids: MultiModalUUIDDict | None = None,
) -> EncoderInputs
_prompt_to_llm_inputs(
    prompt: DecoderDictPrompt,
    tokenization_kwargs: dict[str, Any] | None = None,
    *,
    mm_uuids: MultiModalUUIDDict | None = None,
) -> DecoderInputs
_prompt_to_llm_inputs(
    prompt: DecoderOnlyDictPrompt,
    tokenization_kwargs: dict[str, Any] | None = None,
    *,
    mm_uuids: MultiModalUUIDDict | None = None,
) -> DecoderOnlyInputs
_prompt_to_llm_inputs(
    prompt: SingletonDictPrompt,
    tokenization_kwargs: dict[str, Any] | None = None,
    *,
    mm_uuids: MultiModalUUIDDict | None = None,
) -> SingletonInputs

Extract the singleton inputs from a prompt.

Arguments:

  • prompt: single encoder or decoder input prompt

Returns:

  • SingletonInputs instance

Source code in vllm/inputs/preprocess.py
def _prompt_to_llm_inputs(
    self,
    prompt: SingletonDictPrompt,
    tokenization_kwargs: dict[str, Any] | None = None,
    *,
    mm_uuids: MultiModalUUIDDict | None = None,
) -> SingletonInputs:
    """
    Extract the singleton inputs from a prompt.

    Arguments:

    * prompt: single encoder or decoder input prompt

    Returns:

    * [`SingletonInputs`][vllm.inputs.data.SingletonInputs] instance
    """
    if "prompt_embeds" in prompt:
        return self._process_embeds(prompt)  # type: ignore[arg-type]

    if "prompt_token_ids" in prompt:
        return self._process_tokens(
            prompt,  # type: ignore[arg-type]
            mm_uuids=mm_uuids,
        )

    if "prompt" in prompt:
        return self._process_text(
            prompt,  # type: ignore[arg-type]
            tokenization_kwargs=tokenization_kwargs,
            mm_uuids=mm_uuids,
        )

    assert_never(prompt)  # type: ignore[arg-type]

_tokenize_prompt

_tokenize_prompt(
    prompt: str,
    tokenization_kwargs: dict[str, Any] | None = None,
) -> list[int]

Apply the model's tokenizer to a text prompt, returning the corresponding token IDs.

Source code in vllm/inputs/preprocess.py
def _tokenize_prompt(
    self,
    prompt: str,
    tokenization_kwargs: dict[str, Any] | None = None,
) -> list[int]:
    """
    Apply the model's tokenizer to a text prompt, returning the
    corresponding token IDs.
    """
    tokenizer = self.get_tokenizer()
    tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)

    encoder_config = self.model_config.encoder_config

    if encoder_config and encoder_config.get("do_lower_case", False):
        prompt = prompt.lower()

    return tokenizer.encode(prompt, **tokenization_kwargs)

get_decoder_start_token_id

get_decoder_start_token_id() -> int

Obtain the decoder start token id employed by an encoder/decoder model. Raises an error if it is not available.

Source code in vllm/inputs/preprocess.py
def get_decoder_start_token_id(self) -> int:
    """
    Return the token id the decoder of an encoder/decoder model starts from.

    Uses `decoder_start_token_id` from the model's HF config when it is set;
    otherwise logs a one-time warning and falls back to the BOS token id.

    Raises:

    * RuntimeError: if neither a decoder start token id nor a BOS token id
      is available
    """
    start_id = getattr(self.model_config.hf_config, "decoder_start_token_id", None)
    if start_id is not None:
        return start_id

    logger.warning_once(
        "Falling back on <BOS> for decoder start token "
        "id because decoder start token id is not "
        "available."
    )
    bos_id = self.get_bos_token_id()
    if bos_id is None:
        raise RuntimeError("Cannot find decoder start token id or <BOS>")
    return bos_id

preprocess

preprocess(
    prompt: PromptType | DictPrompt | TokPrompt,
    tokenization_kwargs: dict[str, Any] | None = None,
    *,
    mm_uuids: MultiModalUUIDDict | None = None,
) -> ProcessorInputs

Preprocess the input prompt.

Source code in vllm/inputs/preprocess.py
def preprocess(
    self,
    prompt: PromptType | DictPrompt | TokPrompt,
    tokenization_kwargs: dict[str, Any] | None = None,
    *,
    mm_uuids: MultiModalUUIDDict | None = None,
) -> ProcessorInputs:
    """
    Preprocess the input prompt.

    After delegating to `_preprocess`, if a multi-modal processor cache is
    configured and statistics are being tracked, the cache's delta stats
    for this request are added to the running totals (one request, plus
    its queries and hits).
    """
    processed = self._preprocess(prompt, tokenization_kwargs, mm_uuids=mm_uuids)

    cache, stats = self.mm_processor_cache, self.mm_cache_stats
    if not cache or stats is None:
        return processed

    delta = cache.make_stats(delta=True)
    stats.requests += 1
    stats.queries += delta.total
    stats.hits += delta.hits

    return processed