class LoRAModelRunnerMixin:
    def load_lora_model(
        self,
        model: nn.Module,
        vllm_config: VllmConfig,
        device: torch.device,
    ) -> nn.Module:
        if not supports_lora(model):
            raise ValueError(f"{model.__class__.__name__} does not support LoRA yet.")

        # Add LoRA Manager to the Model Runner
        self.lora_manager = LRUCacheWorkerLoRAManager(
            vllm_config,
            device,
            model.embedding_modules,
        )
        return self.lora_manager.create_lora_manager(model, vllm_config)
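
    # Usage sketch (hypothetical caller, not part of this mixin): a model
    # runner that mixes this in would typically wrap its model right after it
    # is built, e.g.
    #
    #     if self.lora_config:
    #         self.model = self.load_lora_model(
    #             self.model, self.vllm_config, self.device
    #         )
    #
    # The returned model has its supported layers managed by the worker LoRA
    # manager created above.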

    def _set_active_loras(
        self,
        prompt_lora_mapping: tuple[int, ...],
        token_lora_mapping: tuple[int, ...],
        lora_requests: set[LoRARequest],
        mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
    ) -> None:
        self._ensure_lora_enabled()

        # Set is_prefill to True, so we always use the SGMV kernels on
        # non-cuda platforms.
        # On cuda platforms we use the same kernels for prefill and
        # decode and this flag is generally ignored.
        lora_mapping = LoRAMapping(
            token_lora_mapping,
            prompt_lora_mapping,
            is_prefill=True,
            type=mapping_type,
        )
        self.lora_manager.set_active_adapters(lora_requests, lora_mapping)
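
    # Shape note (hypothetical numbers for illustration): for three requests
    # using LoRA ids (1, 0, 2), where 0 means "no LoRA", scheduled for
    # (3, 5, 2) tokens and sampling one token each, the caller would pass
    #
    #     prompt_lora_mapping = (1, 0, 2)
    #     token_lora_mapping = (1, 1, 1, 0, 0, 0, 0, 0, 2, 2)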

    def _ensure_lora_enabled(self) -> None:
        if not hasattr(self, "lora_manager"):
            raise RuntimeError("LoRA is not enabled. Use --enable-lora to enable LoRA.")

    def set_active_loras(
        self,
        input_batch: InputBatch,
        num_scheduled_tokens: np.ndarray,
        num_sampled_tokens: np.ndarray | None = None,
        mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
    ) -> None:
        if num_sampled_tokens is None:
            num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32)

        prompt_lora_mapping: tuple[int, ...]  # of size np.sum(num_sampled_tokens)
        token_lora_mapping: tuple[int, ...]  # of size np.sum(num_scheduled_tokens)
        lora_requests: set[LoRARequest]
        prompt_lora_mapping, token_lora_mapping, lora_requests = (
            input_batch.make_lora_inputs(num_scheduled_tokens, num_sampled_tokens)
        )
        return self._set_active_loras(
            prompt_lora_mapping, token_lora_mapping, lora_requests, mapping_type
        )
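
    # Per-step sketch (assumed caller, simplified): the model runner activates
    # the batch's adapters before each forward pass, e.g.
    #
    #     if self.lora_config:
    #         self.set_active_loras(self.input_batch, num_scheduled_tokens)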

    @contextmanager
    def maybe_setup_dummy_loras(
        self, lora_config: LoRAConfig | None, remove_lora: bool = True
    ):
        if lora_config is None:
            yield
        else:
            # __enter__ code
            assert self.lora_manager is not None, "LoRA is not enabled"

            num_loras = lora_config.max_loras
            lora_warmup_rank = (
                lora_config.max_lora_rank if lora_config.max_lora_rank < 8 else 8
            )

            # Make dummy lora requests
            lora_requests: set[LoRARequest] = {
                LoRARequest(
                    lora_name=f"warmup_{lora_id}",
                    lora_int_id=lora_id,
                    lora_path="/not/a/real/path",
                )
                for lora_id in range(1, num_loras + 1)
            }

            with self.lora_manager.dummy_lora_cache():
                # Add the dummy LoRAs here so _set_active_loras doesn't try to
                # load from disk.
                for lr in lora_requests:
                    self.lora_manager.add_dummy_lora(lr, rank=lora_warmup_rank)

                yield

                # __exit__ code
                if remove_lora:
                    self.lora_manager.remove_all_adapters()
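
    # Warmup sketch (assumed caller): profiling and memory-estimation runs can
    # wrap a dummy forward pass so that max_loras dummy adapters occupy their
    # slots, e.g.
    #
    #     with self.maybe_setup_dummy_loras(self.lora_config):
    #         self._dummy_run(max_num_tokens)  # hypothetical runner method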

    @contextmanager
    def maybe_select_dummy_loras(
        self,
        lora_config: LoRAConfig | None,
        num_scheduled_tokens: np.ndarray,
        mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
        num_sampled_tokens: np.ndarray | None = None,
        num_active_loras: int = 0,
    ):
        """
        Context manager to select dummy LoRAs for capture/warmup.

        Args:
            lora_config: LoRA configuration, or None if LoRA is disabled.
            num_scheduled_tokens: Array of scheduled token counts per request.
            mapping_type: Type of LoRA mapping to build. Defaults to LANGUAGE.
            num_sampled_tokens: Array of sampled token counts per request.
                Defaults to one sampled token per request.
            num_active_loras: Number of distinct active LoRAs to use.
                - 0: No LoRA active (set up zero mappings).
                - >0: Use exactly this many distinct LoRAs.
        """
        if num_sampled_tokens is None:
            num_sampled_tokens = np.ones_like(num_scheduled_tokens, dtype=np.int32)

        # Skip LoRA setup entirely only if no LoRA config
        if lora_config is None:
            yield
        else:
            # __enter__ code
            assert self.lora_manager is not None, "LoRA is not enabled"

            num_reqs = len(num_scheduled_tokens)
            max_loras = lora_config.max_loras

            # Determine how many distinct LoRAs to use and whether to include
            # no-LoRA tokens (-1 entries).
            # When num_active_loras > max_loras (e.g., max_loras + 1), we need
            # to include -1 entries to simulate batches with both LoRA and
            # no-LoRA tokens. This ensures prepare_tensors computes the correct
            # num_active_loras that matches the cudagraph capture key.
            if num_active_loras == 0:
                # No LoRA active - use 0 mappings like the original code
                effective_num_loras = 0
                include_no_lora = False
            elif num_active_loras > max_loras:
                # num_active_loras > max_loras means we want max_loras adapters
                # PLUS no-LoRA tokens (-1). This is the max_loras + 1 case.
                effective_num_loras = max_loras
                include_no_lora = True
            else:
                # Specific number of active LoRAs requested
                effective_num_loras = min(num_active_loras, max_loras)
                include_no_lora = False

            # Make prompt lora mapping
            # Assign LoRA IDs cyclically to simulate a worst-case scenario.
            # LoRA IDs are 1-indexed (1 to max_loras) as required by LoRARequest.
            # convert_mapping() will convert these to 0-indexed slot indices.
            if effective_num_loras > 0:
                if include_no_lora:
                    # Include -1 (no-LoRA) entries by cycling through
                    # -1, 1, 2, ..., effective_num_loras.
                    # This ensures prepare_tensors sees both LoRA and no-LoRA
                    # tokens, computing num_active_loras = effective_num_loras + 1.
                    cycle_values = np.array(
                        [-1] + list(range(1, effective_num_loras + 1)),
                        dtype=np.int32,
                    )
                    prompt_lora_mapping = cycle_values[
                        np.arange(num_reqs, dtype=np.int32) % len(cycle_values)
                    ]
                else:
                    # Use 1 to effective_num_loras (1-indexed lora IDs)
                    prompt_lora_mapping = (
                        np.arange(num_reqs, dtype=np.int32) % effective_num_loras
                    ) + 1
            else:
                # No LoRA active - use 0 for all tokens (original behavior)
                prompt_lora_mapping = np.zeros(num_reqs, dtype=np.int32)

            # Make sample lora mapping
            sample_lora_mapping = np.repeat(prompt_lora_mapping, num_sampled_tokens)
            # Make token lora mapping
            token_lora_mapping = np.repeat(prompt_lora_mapping, num_scheduled_tokens)

            # Make dummy lora requests (only for the active LoRAs)
            lora_requests: set[LoRARequest] = {
                LoRARequest(
                    lora_name=f"warmup_{lora_id}",
                    lora_int_id=lora_id,
                    lora_path="/not/a/real/path",
                )
                for lora_id in range(1, effective_num_loras + 1)
            }
            self._set_active_loras(
                tuple(sample_lora_mapping),
                tuple(token_lora_mapping),
                lora_requests,
                mapping_type,
            )

            yield
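
    # Worked example (hypothetical numbers): with num_reqs = 5 and
    # effective_num_loras = 2, the cyclic assignment above yields
    #
    #     prompt_lora_mapping = [1, 2, 1, 2, 1]
    #
    # and with include_no_lora = True the cycle also contains the -1 no-LoRA
    # marker, e.g. [-1, 1, 2, -1, 1], so the dummy batch mixes LoRA and
    # no-LoRA tokens.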

    @contextmanager
    def maybe_dummy_run_with_lora(
        self,
        lora_config: LoRAConfig | None,
        num_scheduled_tokens: np.ndarray,
        num_sampled_tokens: np.ndarray,
        remove_lora: bool = True,
        num_active_loras: int = 0,
        mapping_type: LoRAMappingType = LoRAMappingType.LANGUAGE,
    ):
        """
        Context manager for dummy runs with LoRA.

        Args:
            lora_config: LoRA configuration.
            num_scheduled_tokens: Array of scheduled token counts per request.
            num_sampled_tokens: Array of sampled token counts per request.
            remove_lora: Whether to remove LoRAs after the context exits.
            num_active_loras: Number of distinct active LoRAs to use.
                LoRA is activated when num_active_loras > 0.
            mapping_type: Type of LoRA mapping to build. Defaults to LANGUAGE.
        """
        with (
            self.maybe_setup_dummy_loras(lora_config, remove_lora),
            self.maybe_select_dummy_loras(
                lora_config,
                num_scheduled_tokens,
                mapping_type,
                num_sampled_tokens,
                num_active_loras,
            ),
        ):
            yield
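
    # Capture/warmup sketch (assumed caller): cudagraph capture and dummy runs
    # can combine both context managers through this helper, e.g.
    #
    #     with self.maybe_dummy_run_with_lora(
    #         self.lora_config, num_scheduled_tokens, num_sampled_tokens
    #     ):
    #         outputs = self.model(input_ids, positions)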

    def maybe_remove_all_loras(self, lora_config: LoRAConfig | None):
        if lora_config is None:
            return
        self.lora_manager.remove_all_adapters()

    def add_lora(self, lora_request: LoRARequest) -> bool:
        self._ensure_lora_enabled()
        return self.lora_manager.add_adapter(lora_request)

    def remove_lora(self, lora_id: int) -> bool:
        self._ensure_lora_enabled()
        return self.lora_manager.remove_adapter(lora_id)

    def pin_lora(self, lora_id: int) -> bool:
        self._ensure_lora_enabled()
        return self.lora_manager.pin_adapter(lora_id)

    def list_loras(self) -> set[int]:
        self._ensure_lora_enabled()
        return self.lora_manager.list_adapters()
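
# Adapter-management sketch (hypothetical caller, illustration only): a worker
# that mixes in LoRAModelRunnerMixin exposes the usual adapter operations, e.g.
#
#     runner.add_lora(LoRARequest("sql_adapter", 1, "/path/to/adapter"))
#     assert 1 in runner.list_loras()
#     runner.pin_lora(1)   # keep the adapter resident in the LRU cache
#     runner.remove_lora(1)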