Coalesce duplicate split_with_sizes nodes that operate on the same input tensor with the same split sizes.
On certain hardware/dtype combinations (e.g. B200 + FP8) the Inductor graph may contain multiple split_with_sizes calls on the same tensor that common subexpression elimination (CSE) fails to merge. This pass detects the duplicates and redirects their users to a single canonical node, so that downstream pattern-matching passes (e.g. QK-Norm+RoPE fusion) see one split node with all users attached.
See also
- vLLM #33295 (original issue)
- PyTorch #174472 (upstream CSE gap)
SplitCoalescingPass
Bases: VllmInductorPass
Replace duplicate split_with_sizes nodes with a single canonical node when they share the same input tensor and split sizes.
Source code in vllm/compilation/passes/utility/split_coalescing.py
```python
class SplitCoalescingPass(VllmInductorPass):
    """Replace duplicate ``split_with_sizes`` nodes with a single canonical
    node when they share the same input tensor and split sizes."""

    @VllmInductorPass.time_and_log
    def __call__(self, graph: fx.Graph) -> None:
        count = 0

        # Map from input tensor node -> list of split nodes seen so far.
        split_nodes: dict[fx.Node, list[fx.Node]] = {}

        for node in graph.nodes:
            if not is_func(node, torch.ops.aten.split_with_sizes.default):
                continue
            if not all(is_func(user, operator.getitem) for user in node.users):
                continue

            arg_node, split_sizes = node.args[:2]

            if arg_node not in split_nodes:
                split_nodes[arg_node] = [node]
                continue

            # Find existing node with same split_sizes
            canonical = next(
                (
                    n
                    for n in split_nodes[arg_node]
                    if list(n.args[1]) == list(split_sizes)
                ),
                None,
            )
            if canonical is not None:
                node.replace_all_uses_with(canonical)
                graph.erase_node(node)
                count += 1
            else:
                split_nodes[arg_node].append(node)

        logger.debug("Coalesced %d duplicate split_with_sizes nodes", count)
```
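To see the effect on a graph without a vLLM installation, the following is a minimal standalone sketch of the same idea using only `torch.fx`. It is not the vLLM implementation: `coalesce_duplicate_splits` is a hypothetical helper, the hand-built graph is illustrative, and the helper keys on any remaining positional arguments (such as the split dimension) in addition to the input node and split sizes, which is a choice made for this sketch.

```python
import operator

import torch
from torch import fx

SPLIT = torch.ops.aten.split_with_sizes.default


def coalesce_duplicate_splits(graph: fx.Graph) -> int:
    """Hypothetical standalone helper; returns the number of nodes removed."""
    canonical_by_key: dict[tuple, fx.Node] = {}
    removed = 0
    # Snapshot the node list because nodes are erased during the walk.
    for node in list(graph.nodes):
        if node.op != "call_function" or node.target != SPLIT:
            continue
        # Only handle splits whose results are consumed through getitem.
        if not all(
            user.op == "call_function" and user.target is operator.getitem
            for user in node.users
        ):
            continue
        # Assumption of this sketch: key on input node, sizes, and extra args.
        key = (node.args[0], tuple(node.args[1]), tuple(node.args[2:]))
        canonical = canonical_by_key.get(key)
        if canonical is None:
            canonical_by_key[key] = node
            continue
        # Redirect all users to the earlier node, then drop the duplicate.
        node.replace_all_uses_with(canonical)
        graph.erase_node(node)
        removed += 1
    return removed


# Build a toy graph with two identical splits of the same input.
graph = fx.Graph()
x = graph.placeholder("x")
split_a = graph.call_function(SPLIT, (x, [2, 3], -1))
split_b = graph.call_function(SPLIT, (x, [2, 3], -1))  # duplicate of split_a
q = graph.call_function(operator.getitem, (split_a, 0))
k = graph.call_function(operator.getitem, (split_b, 1))
graph.output((q, k))

# Before the pass, q and k hang off different split nodes.
same_parent_before = q.args[0] is k.args[0]

removed = coalesce_duplicate_splits(graph)
graph.lint()  # sanity-check that the graph is still well-formed

# After the pass, both getitem nodes read from the single canonical split.
same_parent_after = q.args[0] is k.args[0]

gm = fx.GraphModule(torch.nn.Module(), graph)
first, second = gm(torch.arange(10.0).reshape(2, 5))
```

A matcher that expects every `getitem` in a pattern to share one split parent would not fire on the original graph, since `q` and `k` read from different nodes; after coalescing they share `split_a`, and the module still returns the same two slices.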