class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
"""DeepGemm-based fused MoE expert implementation."""
def __init__(self, moe_config: FusedMoEConfig, quant_config: FusedMoEQuantConfig):
super().__init__(moe_config=moe_config, quant_config=quant_config)
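        # This implementation only supports fp8 (e4m3) block quantization whose
        # block shape matches DeepGEMM's contiguous-layout M/K alignment;
        # per-act-token and per-out-channel quantization are not supported.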
assert quant_config.block_shape == get_mk_alignment_for_contiguous_layout()
assert quant_config.quant_dtype == torch.float8_e4m3fn
assert not quant_config.per_act_token_quant
        assert not quant_config.per_out_ch_quant

@staticmethod
def activation_format() -> mk.FusedMoEActivationFormat:
        return mk.FusedMoEActivationFormat.Standard

@staticmethod
def _supports_current_device() -> bool:
        return is_deep_gemm_supported()

@staticmethod
def _supports_no_act_and_mul() -> bool:
        return False

@staticmethod
def _supports_quant_scheme(
weight_key: QuantKey | None,
activation_key: QuantKey | None,
) -> bool:
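        # Supported (weight, activation) quantization pairs: fp8 weights with
        # static per-128-block symmetric scales paired with fp8 activations
        # using dynamic per-128-group symmetric scales.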
SUPPORTED_W_A = [
(kFp8Static128BlockSym, kFp8Dynamic128Sym),
]
        return (weight_key, activation_key) in SUPPORTED_W_A

@staticmethod
def _supports_activation(activation: MoEActivation) -> bool:
        return activation in [MoEActivation.SILU, MoEActivation.SWIGLUSTEP]

@staticmethod
def _supports_parallel_config(moe_parallel_config: FusedMoEParallelConfig) -> bool:
        # NOTE(rob): discovered an IMA (illegal memory access) with this
        # combination. Needs investigation.
        return not moe_parallel_config.use_fi_all2allv_kernels

def supports_chunking(self) -> bool:
        return True

def supports_expert_map(self) -> bool:
        return True

def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce:
        return TopKWeightAndReduceNoOP()

def workspace_shapes(
self,
M: int,
N: int,
K: int,
topk: int,
global_num_experts: int,
local_num_experts: int,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
activation: MoEActivation,
) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]:
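        """Return shapes for the two intermediate workspaces and the output.

        workspace1 must hold both the permuted fp8 input (M_sum, K) and the
        quantized activation output (M_sum, activation_out_dim); workspace2
        holds the two grouped-GEMM outputs, (M_sum, N) and (M_sum, K).
        """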
assert self.block_shape is not None
block_m = self.block_shape[0]
M_sum = compute_aligned_M(
M, topk, local_num_experts, block_m, expert_tokens_meta
)
assert M_sum % block_m == 0
activation_out_dim = self.adjust_N_for_activation(N, activation)
workspace1 = (M_sum, max(activation_out_dim, K))
workspace2 = (M_sum, max(N, K))
output = (M, K)
        return (workspace1, workspace2, output)

def _act_mul_quant(
self, input: torch.Tensor, output: torch.Tensor, activation: MoEActivation
) -> tuple[torch.Tensor, torch.Tensor]:
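        """Apply `activation` to `input` and re-quantize the result to fp8
        with per-token-group scales, writing the quantized values into
        `output`."""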
assert self.block_shape is not None
block_k = self.block_shape[1]
scale_fmt = DeepGemmQuantScaleFMT.from_oracle()
M_sum, N = input.size()
activation_out_dim = self.adjust_N_for_activation(N, activation)
# 1. DeepGemm UE8M0: use packed per-token-group quant
if scale_fmt == DeepGemmQuantScaleFMT.UE8M0:
act_out = torch.empty(
(M_sum, activation_out_dim), dtype=input.dtype, device=input.device
)
self.activation(activation, act_out, input)
a2q, a2q_scale = per_token_group_quant_fp8_packed_for_deepgemm(
act_out,
block_k,
out_q=output,
)
return a2q, a2q_scale
        # 2. Hopper / non-packed-UE8M0 scales: prefer the fused
        #    SiLU+mul+quant kernel.
if activation == MoEActivation.SILU:
use_ue8m0 = scale_fmt == DeepGemmQuantScaleFMT.FLOAT32_CEIL_UE8M0
return silu_mul_per_token_group_quant_fp8_colmajor(
input=input,
output=output,
use_ue8m0=use_ue8m0,
)
        # 3. Fallback path for non-SiLU activations with non-packed-UE8M0
        #    scales.
act_out = torch.empty(
(M_sum, activation_out_dim), dtype=input.dtype, device=input.device
)
self.activation(activation, act_out, input)
return per_token_group_quant_fp8(
act_out, block_k, column_major_scales=True, out_q=output
        )

def apply(
self,
output: torch.Tensor,
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: MoEActivation,
global_num_experts: int,
expert_map: torch.Tensor | None,
a1q_scale: torch.Tensor | None,
a2_scale: torch.Tensor | None,
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: mk.ExpertTokensMetadata | None,
apply_router_weight_on_input: bool,
):
assert a1q_scale is not None
assert a2_scale is None
assert self.block_shape is not None
assert self.w1_scale is not None
assert self.w2_scale is not None
a1q = hidden_states
_, N, K = w1.size()
local_num_experts = w1.size(0)
if global_num_experts == -1:
global_num_experts = local_num_experts
assert w2.size(1) == K
M_sum = compute_aligned_M(
M=topk_ids.size(0),
num_topk=topk_ids.size(1),
local_num_experts=local_num_experts,
alignment=get_mk_alignment_for_contiguous_layout()[0],
expert_tokens_meta=expert_tokens_meta,
)
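
        # Permute the quantized activations into expert-contiguous order,
        # staging the result in workspace13 (reused below for the quantized
        # activation output).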
a1q_perm = _resize_cache(
workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, K)
)
a1q, a1q_scale, expert_ids, inv_perm = deepgemm_moe_permute(
aq=a1q,
aq_scale=a1q_scale,
topk_ids=topk_ids,
local_num_experts=local_num_experts,
expert_map=expert_map,
expert_tokens_meta=expert_tokens_meta,
aq_out=a1q_perm,
)
assert a1q.size(0) == M_sum
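
        # Grouped GEMM 1: a1q (M_sum, K) x w1 -> mm1_out (M_sum, N), staged
        # in workspace2.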
mm1_out = _resize_cache(workspace2, (M_sum, N))
m_grouped_fp8_gemm_nt_contiguous(
(a1q, a1q_scale), (w1, self.w1_scale), mm1_out, expert_ids
)
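
        # Activation (and mul) followed by fp8 re-quantization of the GEMM1
        # output, written into workspace13.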
activation_out_dim = self.adjust_N_for_activation(N, activation)
quant_out = _resize_cache(
workspace13.view(dtype=torch.float8_e4m3fn), (M_sum, activation_out_dim)
)
a2q, a2q_scale = self._act_mul_quant(
input=mm1_out.view(-1, N), output=quant_out, activation=activation
)
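
        # Grouped GEMM 2: a2q (M_sum, activation_out_dim) x w2 ->
        # mm2_out (M_sum, K), staged in workspace2.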
mm2_out = _resize_cache(workspace2, (M_sum, K))
m_grouped_fp8_gemm_nt_contiguous(
(a2q, a2q_scale), (w2, self.w2_scale), mm2_out, expert_ids
)
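
        # When the router weights were already applied to the input, reduce
        # with unit weights to avoid applying them twice.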
if apply_router_weight_on_input:
topk_weights = torch.ones_like(topk_weights)
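
        # Un-permute back to token order and reduce over the top-k experts
        # into the final output.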
deepgemm_unpermute_and_reduce(
a=mm2_out,
topk_ids=topk_ids,
topk_weights=topk_weights,
inv_perm=inv_perm,
expert_map=expert_map,
output=output,
)