Skip to content

vllm.distributed.weight_transfer.base

Base class for weight transfer engines.

WeightTransferEngine

Bases: ABC, Generic[TInitInfo, TUpdateInfo]

Base class for weight transfer engines that handle transport of model weights from a trainer to inference workers.

This abstraction separates weight-transfer transport logic from the worker implementation, allowing different backends to be plugged in (NCCL; CUDA IPC and RDMA are still TODO).

Subclasses should define:

- `init_info_cls`: the type of the backend-specific initialization info
- `update_info_cls`: the type of the backend-specific update info

Source code in vllm/distributed/weight_transfer/base.py
class WeightTransferEngine(ABC, Generic[TInitInfo, TUpdateInfo]):
    """
    Abstract base for engines that transport model weights from a trainer to
    inference workers.

    Keeping the transport behind this interface separates it from the worker
    implementation, so different backends can be plugged in (NCCL; CUDA IPC
    and RDMA are still TODO).

    Subclasses should define these class attributes:
        init_info_cls: Type of the backend-specific initialization info,
            instantiated by ``parse_init_info``.
        update_info_cls: Type of the backend-specific update info,
            instantiated by ``parse_update_info``.
    """

    # Only annotated here; each concrete backend assigns its own info types.
    init_info_cls: type[TInitInfo]
    update_info_cls: type[TUpdateInfo]

    def __init__(
        self, config: WeightTransferConfig, parallel_config: ParallelConfig
    ) -> None:
        """
        Initialize the weight transfer engine by storing its configuration.

        Args:
            config: Configuration for the weight transfer engine.
            parallel_config: Configuration for the parallel setup.
        """
        self.config, self.parallel_config = config, parallel_config

    def parse_init_info(self, init_dict: dict[str, Any]) -> TInitInfo:
        """
        Build the typed init info from a plain dict, with validation.

        The dict entries are passed as keyword arguments to ``init_info_cls``.

        Args:
            init_dict: Backend-specific initialization parameters.

        Returns:
            An ``init_info_cls`` instance (a backend-specific dataclass).

        Raises:
            ValueError: If ``init_dict`` is invalid for this backend; the
                original ``TypeError`` is chained as the cause.
        """
        try:
            parsed = self.init_info_cls(**init_dict)
        except TypeError as err:
            # Missing/unexpected keys (or a non-mapping argument) surface as a
            # TypeError from the call; re-raise as ValueError naming the engine.
            engine = self.__class__.__name__
            raise ValueError(f"Invalid init_info for {engine}: {err}") from err
        return parsed

    def parse_update_info(self, update_dict: dict[str, Any]) -> TUpdateInfo:
        """
        Build the typed update info from a plain dict, with validation.

        The dict entries are passed as keyword arguments to ``update_info_cls``.

        Args:
            update_dict: Backend-specific update parameters.

        Returns:
            An ``update_info_cls`` instance (a backend-specific dataclass).

        Raises:
            ValueError: If ``update_dict`` is invalid for this backend; the
                original ``TypeError`` is chained as the cause.
        """
        try:
            parsed = self.update_info_cls(**update_dict)
        except TypeError as err:
            # Same translation as parse_init_info: TypeError -> ValueError.
            engine = self.__class__.__name__
            raise ValueError(f"Invalid update_info for {engine}: {err}") from err
        return parsed

    @abstractmethod
    def init_transfer_engine(self, init_info: TInitInfo) -> None:
        """
        Initialize the weight transfer mechanism.

        This is called once at the beginning of training.

        Args:
            init_info: Backend-specific initialization info.

        Raises:
            NotImplementedError: Always, in this abstract base.
        """
        raise NotImplementedError

    @abstractmethod
    def receive_weights(
        self,
        update_info: TUpdateInfo,
        load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
    ) -> None:
        """
        Receive weights from the trainer and pass them to ``load_weights``
        incrementally.

        Args:
            update_info: Backend-specific update info carrying parameter
                metadata plus any other data the backend needs.
            load_weights: Callback that loads weights into the model. It takes
                a list of ``(name, tensor)`` pairs and is meant to be called
                incrementally rather than once with every weight, to avoid OOM.

        Raises:
            NotImplementedError: Always, in this abstract base.
        """
        raise NotImplementedError

    @abstractmethod
    def shutdown(self) -> None:
        """
        Shut down the weight transfer engine.

        This should be called when the worker is shutting down.

        Raises:
            NotImplementedError: Always, in this abstract base.
        """
        raise NotImplementedError

__init__

__init__(
    config: WeightTransferConfig,
    parallel_config: ParallelConfig,
) -> None

Initialize the weight transfer engine.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `config` | `WeightTransferConfig` | The configuration for the weight transfer engine | required |
| `parallel_config` | `ParallelConfig` | The configuration for the parallel setup | required |
Source code in vllm/distributed/weight_transfer/base.py
def __init__(
    self, config: WeightTransferConfig, parallel_config: ParallelConfig
) -> None:
    """
    Initialize the weight transfer engine by storing its configuration.

    Args:
        config: Configuration for the weight transfer engine.
        parallel_config: Configuration for the parallel setup.
    """
    self.config, self.parallel_config = config, parallel_config

init_transfer_engine abstractmethod

init_transfer_engine(init_info: TInitInfo) -> None

Initialize the weight transfer mechanism. This is called once at the beginning of training.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `init_info` | `TInitInfo` | Backend-specific initialization info | required |
Source code in vllm/distributed/weight_transfer/base.py
@abstractmethod
def init_transfer_engine(self, init_info: TInitInfo) -> None:
    """
    Initialize the weight transfer mechanism.

    This is called once at the beginning of training.

    Args:
        init_info: Backend-specific initialization info.

    Raises:
        NotImplementedError: Always, in this abstract base.
    """
    raise NotImplementedError

parse_init_info

parse_init_info(init_dict: dict[str, Any]) -> TInitInfo

Construct typed init info from dict with validation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `init_dict` | `dict[str, Any]` | Dictionary containing backend-specific initialization parameters | required |

Returns:

| Type | Description |
| --- | --- |
| `TInitInfo` | Typed backend-specific init info dataclass |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If `init_dict` is invalid for this backend |

Source code in vllm/distributed/weight_transfer/base.py
def parse_init_info(self, init_dict: dict[str, Any]) -> TInitInfo:
    """
    Build the typed init info from a plain dict, with validation.

    The dict entries are passed as keyword arguments to ``init_info_cls``.

    Args:
        init_dict: Backend-specific initialization parameters.

    Returns:
        An ``init_info_cls`` instance (a backend-specific dataclass).

    Raises:
        ValueError: If ``init_dict`` is invalid for this backend; the
            original ``TypeError`` is chained as the cause.
    """
    try:
        parsed = self.init_info_cls(**init_dict)
    except TypeError as err:
        # Missing/unexpected keys (or a non-mapping argument) surface as a
        # TypeError from the call; re-raise as ValueError naming the engine.
        engine = self.__class__.__name__
        raise ValueError(f"Invalid init_info for {engine}: {err}") from err
    return parsed

parse_update_info

parse_update_info(
    update_dict: dict[str, Any],
) -> TUpdateInfo

Construct typed update info from dict with validation.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `update_dict` | `dict[str, Any]` | Dictionary containing backend-specific update parameters | required |

Returns:

| Type | Description |
| --- | --- |
| `TUpdateInfo` | Typed backend-specific update info dataclass |

Raises:

| Type | Description |
| --- | --- |
| `ValueError` | If `update_dict` is invalid for this backend |

Source code in vllm/distributed/weight_transfer/base.py
def parse_update_info(self, update_dict: dict[str, Any]) -> TUpdateInfo:
    """
    Build the typed update info from a plain dict, with validation.

    The dict entries are passed as keyword arguments to ``update_info_cls``.

    Args:
        update_dict: Backend-specific update parameters.

    Returns:
        An ``update_info_cls`` instance (a backend-specific dataclass).

    Raises:
        ValueError: If ``update_dict`` is invalid for this backend; the
            original ``TypeError`` is chained as the cause.
    """
    try:
        parsed = self.update_info_cls(**update_dict)
    except TypeError as err:
        # Missing/unexpected keys (or a non-mapping argument) surface as a
        # TypeError from the call; re-raise as ValueError naming the engine.
        engine = self.__class__.__name__
        raise ValueError(f"Invalid update_info for {engine}: {err}") from err
    return parsed

receive_weights abstractmethod

receive_weights(
    update_info: TUpdateInfo,
    load_weights: Callable[
        [list[tuple[str, Tensor]]], None
    ],
) -> None

Receive weights from the trainer and load them incrementally.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `update_info` | `TUpdateInfo` | Backend-specific update info containing parameter metadata and any backend-specific data | required |
| `load_weights` | `Callable[[list[tuple[str, Tensor]]], None]` | Callable that loads weights into the model. It is called incrementally, each time with a list of `(name, tensor)` pairs, to avoid OOM. | required |
Source code in vllm/distributed/weight_transfer/base.py
@abstractmethod
def receive_weights(
    self,
    update_info: TUpdateInfo,
    load_weights: Callable[[list[tuple[str, torch.Tensor]]], None],
) -> None:
    """
    Receive weights from the trainer and pass them to ``load_weights``
    incrementally.

    Args:
        update_info: Backend-specific update info carrying parameter metadata
            plus any other data the backend needs.
        load_weights: Callback that loads weights into the model. It takes a
            list of ``(name, tensor)`` pairs and is meant to be called
            incrementally rather than once with every weight, to avoid OOM.

    Raises:
        NotImplementedError: Always, in this abstract base.
    """
    raise NotImplementedError

shutdown abstractmethod

shutdown() -> None

Shut down the weight transfer engine. This should be called when the worker is shutting down.

Source code in vllm/distributed/weight_transfer/base.py
@abstractmethod
def shutdown(self) -> None:
    """
    Shut down the weight transfer engine.

    This should be called when the worker is shutting down.

    Raises:
        NotImplementedError: Always, in this abstract base.
    """
    raise NotImplementedError

WeightTransferInitInfo dataclass

Bases: ABC

Base class for backend-specific initialization info.

Source code in vllm/distributed/weight_transfer/base.py
@dataclass
class WeightTransferInitInfo(ABC):  # noqa: B024
    """
    Common base for backend-specific initialization info.

    It declares no fields of its own; each backend subclasses it to describe
    the data its engine needs when the transfer mechanism is initialized.
    """

WeightTransferInitRequest dataclass

API-level weight transfer initialization request.

Source code in vllm/distributed/weight_transfer/base.py
@dataclass
class WeightTransferInitRequest:
    """
    API-level request to initialize weight transfer.

    ``init_info`` carries the raw backend-specific parameters as a plain dict;
    each instance gets its own fresh empty dict by default.
    """

    init_info: dict[str, Any] = field(default_factory=dict)

WeightTransferUpdateInfo dataclass

Bases: ABC

Base class for backend-specific weight update info.

Source code in vllm/distributed/weight_transfer/base.py
@dataclass
class WeightTransferUpdateInfo(ABC):  # noqa: B024
    """Base class for backend-specific weight update info."""

    _: KW_ONLY
    is_checkpoint_format: bool = True
    """Set to True if weights are in checkpoint/original model format and need
    layerwise processing. Set to False if weights have already been processed
    into kernel format (repacking, renaming, etc.)."""

is_checkpoint_format class-attribute instance-attribute

is_checkpoint_format: bool = True

Set to True if weights are in checkpoint/original model format and need layerwise processing. Set to False if weights have already been processed into kernel format (repacking, renaming, etc.).

WeightTransferUpdateRequest dataclass

API-level weight update request.

Source code in vllm/distributed/weight_transfer/base.py
@dataclass
class WeightTransferUpdateRequest:
    """
    API-level request to update weights.

    ``update_info`` carries the raw backend-specific parameters as a plain
    dict; each instance gets its own fresh empty dict by default.
    """

    update_info: dict[str, Any] = field(default_factory=dict)