Compilers match structure, not meaning
Two identical computations can look different to a compiler. A vLLM bug showed me why — and how to fix it.
When a GPU model runs slower than it should but produces correct results, something subtle is usually wrong. I hit this while contributing to vLLM, the open-source LLM serving engine. The fix was 40 lines. Grasping why took far longer.
The memory wall
GPUs are fast at arithmetic. The bottleneck is memory. Each kernel — a function that runs on the GPU — must read its inputs from high-bandwidth memory (HBM) and write its outputs back. Registers and shared memory sit on-chip and run 10-100x faster, but they only live for the duration of a single kernel.
So if five operations run as five kernels, each intermediate result takes a round trip through HBM. Fuse them into one kernel and intermediates stay in registers. That is the core idea behind kernel fusion: fewer memory round trips, less wasted time.
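The savings can be sketched with back-of-the-envelope arithmetic. Under a simplified model where every elementwise kernel reads one input from HBM and writes one output back (the function below is illustrative, not a real profiler API):

```python
def hbm_elements_moved(n_elems: int, n_ops: int, fused: bool) -> int:
    """Elements transferred through HBM for a chain of elementwise ops.

    Simplified model: each kernel reads its input and writes its output;
    fusion keeps all intermediates in on-chip registers.
    """
    kernels = 1 if fused else n_ops
    return 2 * n_elems * kernels

n = 1_000_000  # elements in the tensor
unfused = hbm_elements_moved(n, n_ops=5, fused=False)
fused = hbm_elements_moved(n, n_ops=5, fused=True)
print(unfused // fused)  # 5: fusing five ops cuts HBM traffic fivefold
```

The model ignores multi-input ops and cache effects, but it captures why fusing a five-kernel chain into one kernel matters on memory-bound workloads.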
What torch.compile does
PyTorch’s torch.compile automates this. It traces your model into an FX graph, a graph of tensor operations, then looks for patterns it can replace with fused kernels.
The pipeline runs in six stages. Custom passes from libraries like vLLM hook into the fifth — after graph optimization but before code generation.
One such pass fuses QK normalization and rotary position embeddings (RoPE) in transformer attention. Without fusion, the sequence — split QKV, reshape, normalize Q and K, reshape back, apply RoPE — launches five kernels. Fused, it launches one.
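A minimal NumPy sketch of that unfused sequence, stripped to a single head (shapes, the weight, and the rotation layout are simplified stand-ins; vLLM's real kernels operate on batched multi-head tensors):

```python
import numpy as np

def rms_norm(x, weight, eps=1e-6):
    # Normalize over the head dimension (last axis).
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * weight

def rope(x, positions, base=10000.0):
    # Rotate feature pairs by position-dependent angles.
    half = x.shape[-1] // 2
    freqs = base ** (-np.arange(half) / half)
    angles = positions[:, None] * freqs[None, :]
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[..., :half], x[..., half:]
    return np.concatenate([x1 * cos - x2 * sin, x1 * sin + x2 * cos], axis=-1)

# The unfused sequence: each step is (roughly) one kernel launch.
seq_len, head_dim = 4, 8
qkv = np.random.randn(seq_len, 3 * head_dim)
q, k, v = np.split(qkv, 3, axis=-1)   # split QKV
w = np.ones(head_dim)
q = rms_norm(q, w)                    # normalize Q
k = rms_norm(k, w)                    # normalize K
pos = np.arange(seq_len, dtype=np.float64)
q = rope(q, pos)                      # RoPE on Q
k = rope(k, pos)                      # RoPE on K
```

Every intermediate here (`q`, `k` between steps) is what the fused kernel keeps in registers instead of writing back to HBM.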
The bug
On NVIDIA’s B200 GPU with FP8 quantization, the fusion stopped firing. The model still worked. It just ran slower, and a log line quietly read: “Fused QK Norm+RoPE on 0 sites”.
The graph’s shape had changed. The fusion pass works by pattern matching: you define the unfused computation as a template, and PyTorch searches the FX graph for an exact structural match.
In the standard graph, a split_with_sizes node appears once with three consumers. The pattern expects this shape:

split_with_sizes ──┬── getitem[0] → Q
                   ├── getitem[1] → K
                   └── getitem[2] → V

On the B200, the FP8 quantization lowering created three separate split_with_sizes nodes, one per output. Each had a single consumer:

split_with_sizes_0 ── getitem[0] → Q
split_with_sizes_1 ── getitem[0] → K
split_with_sizes_2 ── getitem[0] → V

The key insight
The two graphs compute the same thing. But the pattern matcher is structural, not semantic. It saw three nodes with one consumer each, not one node with three consumers, and found no match.
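The mismatch is easy to reproduce with a toy graph representation (a hypothetical stand-in for FX and its matcher, reduced to the structural check that matters here):

```python
class Node:
    """Minimal graph node: an op name plus the nodes it consumes."""
    def __init__(self, op, inputs=()):
        self.op = op
        self.inputs = tuple(inputs)

def users(graph, node):
    # Consumers of `node`: every graph node that takes it as an input.
    return [n for n in graph if node in n.inputs]

def matches_pattern(graph):
    """Structural check: ONE split_with_sizes node with three consumers."""
    return any(
        n.op == "split_with_sizes" and len(users(graph, n)) == 3
        for n in graph
    )

# Canonical graph: one split feeding three getitems -> pattern fires.
split = Node("split_with_sizes")
canonical = [split] + [Node("getitem", [split]) for _ in range(3)]

# FP8-lowered graph: three splits, one getitem each. Same computation,
# different structure -> pattern does NOT fire.
splits = [Node("split_with_sizes") for _ in range(3)]
lowered = splits + [Node("getitem", [s]) for s in splits]

print(matches_pattern(canonical), matches_pattern(lowered))  # True False
```

Both graphs split the same tensor into the same three pieces; only the node topology differs, and topology is all a structural matcher sees.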
A compiler pass called common subexpression elimination (CSE) would usually merge these identical splits. But on this hardware and dtype path, the FP8 lowering created the splits after CSE had already run. This is a known PyTorch issue.
The fix
Rather than make the fusion pass handle every graph variant, I wrote a separate canonicalization pass. The SplitCoalescingPass runs before fusion and normalizes the graph:
- Walk every node in topological order.
- Find split_with_sizes nodes whose consumers are all getitem nodes.
- Group them by input tensor and split sizes.
- If duplicates exist, redirect all consumers to the first instance and delete the rest.
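The steps above can be sketched on a toy graph representation (hypothetical names; vLLM's real pass operates on torch.fx nodes, and per-output getitem index bookkeeping is omitted here):

```python
class Node:
    def __init__(self, op, inputs=(), meta=None):
        self.op = op
        self.inputs = list(inputs)
        self.meta = meta  # for splits: (input tensor id, split sizes)

def _users(graph, node):
    return [n for n in graph if node in n.inputs]

def coalesce_splits(graph):
    """Merge duplicate split_with_sizes nodes sharing input and sizes."""
    canonical = {}   # (input id, sizes) -> first split seen
    redirect = {}    # duplicate split -> canonical split
    for node in graph:                          # topological order assumed
        if node.op != "split_with_sizes":
            continue
        if any(u.op != "getitem" for u in _users(graph, node)):
            continue                            # only all-getitem splits
        key = node.meta
        if key in canonical:
            redirect[node] = canonical[key]     # duplicate found
        else:
            canonical[key] = node
    # Redirect consumers of duplicates, then delete the duplicates.
    for node in graph:
        node.inputs = [redirect.get(i, i) for i in node.inputs]
    return [n for n in graph if n not in redirect]

# Three duplicate splits of the same QKV tensor (the FP8-lowered shape).
splits = [Node("split_with_sizes", meta=("qkv", (8, 8, 8))) for _ in range(3)]
gets = [Node("getitem", [s]) for s in splits]
graph = coalesce_splits(splits + gets)

n_splits = sum(n.op == "split_with_sizes" for n in graph)
print(n_splits)  # 1: all three getitems now feed from a single split
```

After coalescing, the toy graph has the one-split, three-consumer shape the fusion pattern expects.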
After this pass, the graph always has the shape the fusion pattern expects. The fusion pass needs no changes. And any future pass that expects canonical splits will find them.
Pass ordering matters: canonicalization must precede matching. Reverse them and the fusion pass sees the unnormalized graph, fails, and the coalescing runs for nothing.
Lessons
Silent performance bugs are hard to catch. No error, no crash. The model produces correct outputs. The only signal is a counter that reads zero instead of one — per layer, per token.
Upstream compiler passes are not reliable. You cannot assume CSE or any other optimization will always normalize your graph the same way across every hardware and dtype combination. Different lowering paths produce different graph shapes.
Separate canonicalization from matching. A dedicated normalization pass is simpler, reusable, and keeps the fusion pass focused on its one job.
The fix was small. Reaching it was not. Forty lines of code, but getting there required tracing the full torch.compile pipeline, from Python bytecode through FX graphs to Triton codegen, and understanding exactly where and why the graph shape changed.
If you work on GPU compilers or vLLM, my full notes and the annotated fix are on GitHub.