Skip to content

vllm.compilation.passes.fusion.matcher_utils

MatcherCustomOp

Bases: ABC

Source code in vllm/compilation/passes/fusion/matcher_utils.py
class MatcherCustomOp(ABC):
    """Abstract base for an op that has both a custom and a native implementation.

    Subclasses implement ``forward_custom`` and ``forward_native``. The one
    that ``__call__`` dispatches to is chosen once, at construction time,
    from the ``enabled`` flag.

    The ``empty*`` helpers allocate uninitialized tensors using the model
    dtype (or a fixed int64/float32 dtype) and the device captured from the
    current vLLM config.
    """

    def __init__(self, enabled: bool) -> None:
        """Capture dtype/device from the current config and bind ``forward``.

        Args:
            enabled: if True, calls are routed to ``forward_custom``;
                otherwise they are routed to ``forward_native``.
        """
        vllm_config = get_current_vllm_config()

        # Either sub-config may be missing; store None in that case so that
        # torch.empty falls back to its own default dtype/device.
        model_cfg = vllm_config.model_config
        self.model_dtype = model_cfg.dtype if model_cfg else None
        device_cfg = vllm_config.device_config
        self.device = device_cfg.device if device_cfg else None

        self.enabled = enabled
        # Resolve the implementation once instead of branching on every call.
        if enabled:
            self.forward = self.forward_custom
        else:
            self.forward = self.forward_native

    @abstractmethod
    def forward_custom(self, *args: Any, **kwargs: Any) -> Any:
        """Custom-op implementation; must be provided by subclasses."""

    @abstractmethod
    def forward_native(self, *args: Any, **kwargs: Any) -> Any:
        """Native implementation; must be provided by subclasses."""

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        """Forward all arguments to the implementation bound in ``__init__``."""
        impl = self.forward
        return impl(*args, **kwargs)

    def empty(self, *shape: Any, **extra: Any) -> torch.Tensor:
        """Allocate an uninitialized tensor in the model dtype on ``self.device``."""
        return torch.empty(
            *shape, dtype=self.model_dtype, device=self.device, **extra
        )

    def empty_int64(self, *shape: Any, **extra: Any) -> torch.Tensor:
        """Allocate an uninitialized int64 tensor on ``self.device``."""
        return torch.empty(*shape, dtype=torch.int64, device=self.device, **extra)

    def empty_f32(self, *shape: Any, **extra: Any) -> torch.Tensor:
        """Allocate an uninitialized float32 tensor on ``self.device``."""
        return torch.empty(*shape, dtype=torch.float32, device=self.device, **extra)

    def inputs(self) -> list[torch.Tensor]:
        """Return the input tensors for this op's pattern.

        The base class provides none and raises ``NotImplementedError``;
        subclasses that need pattern inputs override this.
        """
        raise NotImplementedError()

inputs

inputs() -> list[Tensor]

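Return the list of input tensors for the pattern. The base implementation raises `NotImplementedError`; subclasses are expected to override it.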

Source code in vllm/compilation/passes/fusion/matcher_utils.py
def inputs(self) -> list[torch.Tensor]:
    """Return the input tensors for this op's pattern.

    This base version supplies none: it always raises
    ``NotImplementedError`` and is meant to be overridden.
    """
    raise NotImplementedError()