class DeepEPHTPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
"""
Prepare/Finalize using DeepEP High-Throughput kernels.
"""

    @staticmethod
    def maybe_roundup_layer_hidden_size(
        hidden_size: int, dtype: torch.dtype
    ) -> int:
# Round up hidden size so it is compatible with DeepEP High Throughput
# kernels.
        # DeepEP intranode kernels make copies in units of 32 (warp-size)
        # int4 elements, so round up the hidden size to respect this.
# For example, an input hidden size of 2880 with dtype torch.bfloat16
# will be rounded up to 3072.
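        # Worked through: 2880 elems * 2 bytes (bf16) = 5760 bytes;
        # round_up(5760, 512) = 6144 bytes; 6144 // 2 = 3072 elems.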
hidden_size_bytes = hidden_size * dtype.itemsize
        xfer_atom_size = 512  # 32 elements * 16 bytes (sizeof(int4))
if hidden_size_bytes % xfer_atom_size == 0:
return hidden_size
hidden_size_bytes = round_up(hidden_size_bytes, xfer_atom_size)
return hidden_size_bytes // dtype.itemsize

    def __init__(
self,
buffer: deep_ep.Buffer,
num_dispatchers: int,
dp_size: int,
rank_expert_offset: int,
):
super().__init__()
self.buffer = buffer
self.num_dispatchers_ = num_dispatchers
self.dp_size = dp_size
self.rank_expert_offset = rank_expert_offset
self.async_prepare = True
# The dispatch function returns a handle that the combine function
# requires. Under DBO microbatching we must track one handle per
# micro-batch to avoid races between threads.
self.handles = [None, None]
# From https://github.com/deepseek-ai/DeepEP/blob/9fe9021f29c9083cd1808ab36b740208524d9f63/deep_ep/buffer.py#L164
self.available_rank_configs = [2, 4, 8, 16, 24, 32, 64, 128, 144, 160]
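        # Tuned DeepEP configs exist only for these world sizes;
        # _get_dispatch_config() / _get_combine_config() below return None
        # for anything else.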

    def num_dispatchers(self) -> int:
return self.num_dispatchers_

    def output_is_reduced(self) -> bool:
return True

    @property
def activation_format(self) -> mk.FusedMoEActivationFormat:
return mk.FusedMoEActivationFormat.Standard

    def max_num_tokens_per_rank(self) -> int | None:
return None

    def topk_indices_dtype(self) -> torch.dtype | None:
return torch.int64

    def _get_dispatch_config(self) -> deep_ep.Config | None:
if self.num_dispatchers_ not in self.available_rank_configs:
return None
return deep_ep.Buffer.get_dispatch_config(self.num_dispatchers_)

    def _get_combine_config(self) -> deep_ep.Config | None:
if self.num_dispatchers_ not in self.available_rank_configs:
return None
return deep_ep.Buffer.get_combine_config(self.num_dispatchers_)

    def _do_dispatch(
self,
tokens: torch.Tensor,
token_scales: torch.Tensor | None,
rank_topk_ids: torch.Tensor,
rank_topk_weights: torch.Tensor,
num_experts: int,
a1_scale: torch.Tensor | None,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool,
) -> Callable:
has_scales = token_scales is not None
        # Yield before launching the dispatch kernel: the dispatch kernel
        # blocks the CPU, so we want to queue up all the compute for the
        # other ubatch before the dispatch kernel starts.
dbo_yield_and_switch_from_compute_to_comm()
        # Capture a DeepEP event and pass it as previous_event so that
        # DeepEP honors the dependency internally.
previous_event = dbo_get_previous_event(self.buffer.capture)
(
num_tokens_per_rank,
num_tokens_per_rdma_rank,
dispatch_expert_num_tokens,
is_token_in_rank,
event,
) = self.buffer.get_dispatch_layout(
topk_idx=rank_topk_ids,
num_experts=num_experts,
previous_event=previous_event,
async_finish=False,
allocate_on_comm_stream=False,
)
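        # Per our reading of the DeepEP API: num_tokens_per_rank /
        # num_tokens_per_rdma_rank are per-destination send counts,
        # dispatch_expert_num_tokens is the number of tokens routed to each
        # expert, and is_token_in_rank is the boolean routing mask consumed
        # by buffer.dispatch below.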
token_data = tokens
if has_scales:
token_data = (tokens, token_scales)
(
token_data,
expert_topk_ids,
expert_topk_weights,
expert_num_tokens_per_expert_list,
handle,
event,
) = self.buffer.dispatch(
x=token_data,
handle=None,
num_tokens_per_rank=num_tokens_per_rank,
num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
is_token_in_rank=is_token_in_rank,
num_tokens_per_expert=dispatch_expert_num_tokens,
topk_idx=rank_topk_ids,
topk_weights=rank_topk_weights,
# expert_alignment rounds the number of tokens per expert
# to this value.
expert_alignment=1,
config=self._get_dispatch_config(),
previous_event=previous_event,
async_finish=self.async_prepare and not dbo_enabled(),
allocate_on_comm_stream=False,
)
        # Record the dispatch handle for this ubatch.
a2a_idx = dbo_current_ubatch_id()
self.handles[a2a_idx] = handle
dbo_switch_to_compute_sync()
return lambda: self._receiver(
event,
has_scales,
token_data,
expert_topk_ids,
num_experts,
expert_num_tokens_per_expert_list,
expert_topk_weights,
a1_scale,
quant_config,
defer_input_quant=defer_input_quant,
)
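    # A minimal sketch of the deferred-receive pattern _do_dispatch returns
    # (hypothetical call site; prepare()/prepare_async() below are the real
    # ones):
    #   receiver = self._do_dispatch(...)
    #   ... queue up compute for the other ubatch here ...
    #   expert_x, expert_x_scale, meta, topk_ids, topk_weights = receiver()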

    def _receiver(
self,
event: deep_ep.EventOverlap,
has_scales: bool,
token_data: tuple[torch.Tensor, torch.Tensor] | torch.Tensor,
expert_topk_ids: torch.Tensor | None,
num_experts: int,
expert_num_tokens_per_expert_list: list[int],
expert_topk_weights: torch.Tensor | None,
a1_scale: torch.Tensor | None,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool,
) -> mk.PrepareResultType:
if event.event is not None:
event.current_stream_wait()
if has_scales:
expert_x, expert_x_scale = token_data
else:
expert_x, expert_x_scale = token_data, None
        # The existing MOE kernels assume that all entries of topk_ids are
        # valid. To that effect, set the -1s in expert_topk_ids to some
        # expert outside this rank so that expert_map can remap them to -1.
        # With Expert Parallel, the experts are divided among the ranks
        # sequentially. For rank 0, set it to num_experts - 1; for all other
        # ranks, set it to 0, since we know expert_map holds -1 in those
        # regions for those ranks.
#
# DeepEP's topk_ids output refers to the local experts directly. Offset
# the topk_ids to move it back to the global experts space so it aligns
# with existing vLLM interfaces.
assert expert_topk_ids is not None
expert_topk_ids = torch.where(
expert_topk_ids == -1,
num_experts - 1 if self.rank_expert_offset == 0 else 0,
expert_topk_ids + self.rank_expert_offset,
)
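        # Example with hypothetical values: num_experts = 16, 8 experts per
        # rank. On rank 0 (rank_expert_offset == 0), local ids [0, -1, 3]
        # become [0, 15, 3]; expert 15 lives on another rank, so expert_map
        # turns it back into -1. On rank 1 (offset 8), [2, -1] -> [10, 0],
        # and expert_map maps 0 back to -1 there.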
# Makes a GPU-CPU copy.
# TODO (varun): Maybe it is better to re-compute the expert_num_tokens
# on GPU.
expert_tokens_meta = mk.ExpertTokensMetadata.make_from_list(
expert_num_tokens_per_expert_list, device=expert_x.device
)
        # * For non-block quant, dispatch in bf16 and quantize now, as the
        #   DeepEP kernels only support dispatching block scales.
        # * For expert kernels that require unquantized inputs, defer
        #   quantization to FusedMoEExpertsPermuteUnpermute.
if not quant_config.is_block_quantized and not defer_input_quant:
# Quantize after dispatch.
expert_x_scale = None
if expert_x.numel() != 0:
                # TODO: support per_act_token_quant
expert_x, expert_x_scale = moe_kernel_quantize_input(
expert_x,
a1_scale,
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=False,
block_shape=quant_config.block_shape,
)
return (
expert_x,
expert_x_scale,
expert_tokens_meta,
expert_topk_ids,
expert_topk_weights,
)

    def supports_async(self) -> bool:
return True

    def prepare_async(
self,
a1: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.ReceiverType:
if apply_router_weight_on_input:
topk = topk_ids.size(1)
            # TODO: this only works for topK=1; will need updating for topK>1
assert topk == 1, (
"apply_router_weight_on_input is only implemented for topk=1"
)
a1 = a1 * topk_weights.to(a1.dtype)
        # * DeepEP only supports fp8 block scales, so quantize before
        #   dispatch for these models.
        # * For all other quantization schemes, dispatch first and quantize
        #   afterwards.
        # * For expert kernels that require unquantized inputs, defer
        #   quantization to FusedMoEExpertsPermuteUnpermute.
if quant_config.is_block_quantized and not defer_input_quant:
a1q, a1q_scale = moe_kernel_quantize_input(
a1,
quant_config.a1_scale,
quant_dtype=quant_config.quant_dtype,
per_act_token_quant=quant_config.per_act_token_quant,
block_shape=quant_config.block_shape,
)
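            # An assumption on our part: the scalar per-tensor scale is
            # viewed as (1, 1) below so downstream kernels see a 2-D scale.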
if a1q_scale is not None and a1q_scale.numel() == 1:
a1q_scale = a1q_scale.view(1, 1)
a1_post_scale = None
else:
a1q = a1
a1q_scale = None
a1_post_scale = (
quant_config.a1_gscale
if quant_config.quant_dtype == "nvfp4"
else quant_config.a1_scale
)
return self._do_dispatch(
tokens=a1q,
token_scales=a1q_scale,
rank_topk_ids=topk_ids,
rank_topk_weights=topk_weights,
num_experts=num_experts,
a1_scale=a1_post_scale,
quant_config=quant_config,
defer_input_quant=defer_input_quant,
)

    def prepare(
self,
a1: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
num_experts: int,
expert_map: torch.Tensor | None,
apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig,
defer_input_quant: bool = False,
) -> mk.PrepareResultType:
receiver = self.prepare_async(
a1,
topk_weights,
topk_ids,
num_experts,
expert_map,
apply_router_weight_on_input,
quant_config,
defer_input_quant,
)
return receiver()

    def _finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
do_async: bool,
) -> Callable | None:
a2a_idx = dbo_current_ubatch_id()
handle = self.handles[a2a_idx]
assert handle is not None
        # fused_expert_output can have 0 tokens; this happens when none of
        # the tokens from the all2all reach this EP rank.
if fused_expert_output.numel() != 0:
if isinstance(weight_and_reduce_impl, TopKWeightAndReduceDelegate):
weight_and_reduce_impl = TopKWeightAndReduceContiguous()
fused_expert_output = weight_and_reduce_impl.apply(
output=None,
fused_expert_output=fused_expert_output,
topk_weights=topk_weights,
topk_ids=topk_ids,
apply_router_weight_on_input=apply_router_weight_on_input,
)
dbo_yield_and_switch_from_compute_to_comm()
        # HT combine only supports bf16.
        assert fused_expert_output.dtype == torch.bfloat16, (
            f"Expected fused_expert_output to be bfloat16, "
            f"got {fused_expert_output.dtype}"
        )
previous_event = dbo_get_previous_event(self.buffer.capture)
        combined_x, _, event = self.buffer.combine(
            x=fused_expert_output,
handle=handle,
topk_weights=None,
config=self._get_combine_config(),
previous_event=previous_event,
async_finish=do_async and not dbo_enabled(),
allocate_on_comm_stream=False,
)
dbo_switch_to_compute()
if do_async:
def _receiver():
if event.event is not None:
event.current_stream_wait()
dbo_switch_to_comm()
# Respect inplace outputs.
output.copy_(combined_x, non_blocking=True)
# TODO(lucas): refactor the modular kernel so this will be
# handled there
dbo_yield_and_switch_from_comm_to_compute()
return _receiver
else:
# TODO(lucas): support this case with the refactored modular kernel
assert not dbo_enabled()
# Respect inplace outputs.
output.copy_(combined_x, non_blocking=True)
return None

    def finalize_async(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> Callable:
receiver = self._finalize(
output,
fused_expert_output,
topk_weights,
topk_ids,
apply_router_weight_on_input,
weight_and_reduce_impl,
True,
)
assert receiver is not None
return receiver

    def finalize(
self,
output: torch.Tensor,
fused_expert_output: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
apply_router_weight_on_input: bool,
weight_and_reduce_impl: mk.TopKWeightAndReduce,
) -> None:
self._finalize(
output,
fused_expert_output,
topk_weights,
topk_ids,
apply_router_weight_on_input,
weight_and_reduce_impl,
False,
)