torch.compile 简介及其在 vLLM 中的工作原理
说明
这篇博客源自我们每两周举办一次的 vLLM office hours,这是一个由 Red Hat 与 vLLM 项目提交者及加州大学伯克利分校团队共同主持的社区论坛。每次会议都会涵盖最新动态、特邀嘉宾的深度分享以及开放问答环节。欢迎每隔一个周四美国东部时间下午 2:00 / 太平洋时间上午 11:00 在 Google Meet 上加入我们,会后可以在我们的 YouTube 播放列表上获取录像和幻灯片。
引言
如今,要实现大型语言模型(LLM)的快速推理,需要在多样化的硬件、工作负载和规模下尽可能高效地执行模型。高效执行需要高度优化的算子(kernel),而这些算子通常需要针对不同的模型和平台进行手动调优。torch.compile 是 PyTorch 的即时(JIT)编译器,它可以自动生成优化的算子,从而显著加快 PyTorch 代码的运行速度,而无需开发者为所有支持的硬件平台手动优化算子。
对于 vLLM 这个用于可移植和高效 LLM 推理的事实标准开源推理引擎来说,torch.compile 不仅仅是一个性能增强器。它是一个核心组件,将优化的责任从模型开发者转移到了编译器。优化是在编译期间应用的,而不是要求修改模型定义,从而实现了更清晰的关注点分离,并获得了最大的性能。在这篇文章中,我们将详细介绍 torch.compile 的工作原理、它如何集成到 vLLM 中,以及 vLLM 如何使用自定义编译器通道(pass)来最大化性能。我们还将讨论 vLLM 中 torch.compile 集成的正在进行和未来的工作,以进一步提高其可用性和性能。
什么是 torch.compile?
torch.compile 让您能以最小的努力优化 PyTorch 代码:使用 torch.compile 非常简单,就像给一个函数或 torch.nn.Module 添加一个装饰器一样。torch.compile 会自动将张量操作捕获到一个计算图中,然后为该图生成优化的代码。
在下面的例子中,torch.compile 为函数 fn 中所有的逐点(pointwise)操作生成了一个单一的融合算子。它会即时捕获并编译该函数,如果任何捕获条件(例如输入形状)发生变化,可能会重新编译。
图 1:torch.compile 是 PyTorch 代码的 JIT 编译器。你可以用 torch.compile 包装函数、nn.Module 和其他可调用对象。
有多种使用 torch.compile 的方式。你可以将它用作算子生成器(如图 1 所示),我们编译一个函数。但你也可以将 torch.compile 应用于你的整个 nn.Module 模型或其子模块。根据模型的结构和你的需求(例如编译时间),我们建议在不同的地方应用 torch.compile。
为什么要使用 torch.compile?
优化模型的一种方法是编写自定义的 CPU/CUDA 操作,这些操作执行与模型中相同的运算但速度更快。为每个模型编写自定义算子非常耗时,并且需要对性能和硬件有深入的理解。torch.compile 几乎不需要额外的工程努力就能让你达到接近峰值性能。例如,PyTorch 的开源 TorchBench 基准测试套件显示,在 80 多个模型上,几何平均速度提升了 1.8-2 倍。
图 2:torch.compile 为您提供了快速的基线性能,从而节省了您调优模型性能的开发时间。
torch.compile 的工作原理
torch.compile 流水线包括两个主要阶段:前端(TorchDynamo)和后端(TorchInductor)。我们将做一个简要概述,更多详情请参阅官方 PyTorch 2 论文。
1. 前端(TorchDynamo):图捕获
torch.compile 的前端是一个自定义的字节码解释器。它追踪任意的 Python 函数,并提取出仅包含张量操作的线性 torch.fx 图。torch.compile 的一个关键特性是图断点(graph breaks),这使其能够很好地覆盖所有 Python 代码。每当 torch.compile 遇到它不支持的操作时,它不会报错。相反,它会结束当前正在追踪的图,运行该操作,然后开始追踪一个新的图。torch.compile 将每个追踪到的图发送到后端进行优化。
在下面的代码示例中,torch.save 是一个不支持的操作:torch.compile 不知道如何执行磁盘 I/O。将 torch.compile 应用于函数 f,相当于将 torch.compile 分别应用于调用 torch.save 之前的计算区域和调用 torch.save 之后的区域。
图 3:torch.compile 捕获张量操作的线性图,并绕过像 torch.save 这样的不支持的操作。
2. 后端(TorchInductor):优化与算子生成
torch.compile 的后端接收来自前端的图,并通过图优化通道以及降级(lowering)到优化的 C++、Triton 或其他算子来进行优化。它能够:
- 融合逐点和归约(reduction)操作
- 自动调优算子配置,如块大小
- 为矩阵乘法(matmul)选择不同的后端(cuBLAS, Triton, CUTLASS),并执行前序(prologue)和后序(epilogue)融合
- 使用 CUDA Graphs 高效地缓存和重放算子启动
CUDA Graphs 是一个例子,说明了拥有一个编译器是多么有帮助。CUDA Graphs 减少了启动开销,但要求你的代码满足某些假设(例如,它必须只使用 CUDA 操作,输入张量必须有静态内存地址)。torch.compile 能够自动在不支持的操作处分割图,创建更小的、可以安全使用 CUDA Graph 的图,并自动管理静态输入缓冲区。
vLLM 集成
vLLM V1 默认集成了 torch.compile,用于在线和离线推理。你可以使用 -O0 或 --enforce-eager 来禁用它,但在大多数用例中,保持开启状态会带来性能优势。更多详情请参见文档。
编译缓存
vLLM 在冷启动期间编译模型,并将产物(FX 图、Triton 算子)保存在一个缓存目录中(默认为 ~/.cache/vllm/torch_compile_cache)。在热启动时,会从缓存中检索这些产物。你可以通过 VLLM_DISABLE_COMPILE_CACHE=1 或删除缓存目录来禁用缓存。
编译的产物和缓存可以在具有相同环境的机器之间重用。如果你有自动扩展的用例,请确保只生成一次缓存目录并在实例之间共享它。
图 4:编译产物在冷启动后被缓存,并且可以在机器之间重用,以确保在正确设置下实现快速、一致的启动。
动态批处理大小和特化
默认情况下,vLLM 编译一个具有动态批处理大小的图,该图支持所有可能的批处理大小。这意味着一个产物可以服务于可变的输入大小。然而,针对已知的批处理大小(如 1、2 或 4)进行特化可以带来性能提升。
在你的配置中使用 compile_sizes: [1, 2, 4] 来触发这种特化。在底层,这会告诉 torch.compile 针对这些静态大小进行编译,并可能执行更多的自动调优来选择最佳的算子。
图 5:如何指定针对特定批处理大小进行特化编译。
分段 CUDA Graphs
并非所有操作都与 CUDA Graphs 兼容;例如,级联注意力(cascade attention)就不兼容。vLLM 通过将捕获的图分解为 CUDA Graph 安全和不安全的部分,并分别执行它们来解决这个问题。这使我们既能获得 CUDA Graphs 的性能优势,又不会损失正确性。
图 6:vLLM 中的分段 CUDA Graphs 捕获并重放支持的 GPU 算子序列以实现低开销执行,同时跳过不支持的操作,如级联注意力。
vLLM 中的自定义编译器通道
虽然 torch.compile 包含许多内置优化,但 vLLM 添加了自定义编译器通道,应用额外的优化以进一步提高性能。
为何需要自定义通道?
模型作者编写声明式的、模块化的代码,侧重于正确性并使用清晰的抽象,将更高级别的操作分离到不同的子模块中,并按层进行分组。然而,要达到峰值性能,通常需要打破这些抽象,比如跨子模块和层融合操作。vLLM 的自定义通道重写 torch.fx 图,而不是重写模型本身。
这些通道:
- 融合内存密集型的自定义操作,如激活函数和量化
- 添加 Inductor 中没有的优化(例如移除多余的无操作)
示例:SiLU + 量化融合
在量化的 MLP 中,一个常见的模式是 SiLU 激活函数后接一个量化的下投影线性层。量化的线性层包括对输入进行量化操作,然后是量化的矩阵乘法。单独来看,SiLU 和量化操作速度慢且受内存限制。利用 Inductor 的模式匹配器工具,vLLM 中的 ActivationFusionPass 自定义通道将它们替换为单个融合算子,吞吐量提升高达 8%。
图 7:在 8x AMD MI300s 上对 Llama 3.1 405B 模型进行 FP8 量化测试,融合算子(`fusion`,黄色)的性能优于 `default`(使用 torch ops 实现 RMSNorm 和 SiLU,以及自定义 FP8 量化算子)和 `custom`(未融合的自定义算子)。
图 8:详细的吞吐量加速对比,比较了上述的 `fusion` 和 `default` 两种模式。如果通过融合完全消除了所有量化开销(8%),理论上吞吐量的最大提升将是 8%,我们可以看到在某些情况下确实达到了这个提升。
说明
自从那次 office hours 之后,我们增加了一个使用 torch 操作实现量化的方法,该方法(经 Inductor 编译后)比自定义的 CUDA/ROCm 算子更快。因为 Inductor 可以自动将这些 torch 操作与 SiLU 的 torch 操作融合,所以在某些情况下,SiLU+量化和 RMSNorm+量化通道现在已经过时了。然而,任何涉及自定义操作(注意力、集合通信、亚字节量化)的融合仍然需要自定义通道。我们在这里展示 SiLU+量化的例子是为了与 office hours 的幻灯片和录像保持一致,但其他融合通道的工作方式非常相似。
示例:序列并行 + 异步张量并行
当使用张量并行(TP)时,线性层会对权重进行分片并计算不完整的矩阵乘法结果,这些结果需要在 GPU 之间同步。如果对计算和通信部分使用独立的算子,我们会因为 GPU 在等待通信结果的网络延迟时处于空闲状态而产生通信开销。
相反,我们可以通过使用融合了 GEMM 和集合通信的算子来重叠计算和通信。这类算子的一个例子是 GEMM+reduce_scatter 和 all_gather+GEMM 算子。为了利用这些算子,我们需要将 all_reduce 集合操作分解为 reduce_scatter 和 all_gather,同时将 all_gather 推迟到 layernorm 之后,以便它能与接下来的 GEMM 融合。
如果我们要将这种优化实现在模型定义中,我们就必须修改 vLLM 支持的每一个模型(有数百个!)。这将是侵入性的,会破坏抽象,增加开发者摩擦,并且很可能一开始就不会被 vLLM 接受。相反,通过在 torch.compile 中实现该优化,它被限制在仅仅 2 个自定义通道中,并且可以通过命令行标志开启,为 vLLM 支持的所有模型提供更好的性能。
说明
这项优化由社区成员 @cascade812 完全实现,我们感谢他做出的卓越贡献。关于异步 TP 的更多信息可以在 PyTorch 博客上找到。
当前和即将推出的通道
今日可用
- 融合通道
- RMSNorm + 量化 (FP8) 融合
- SiLU-Mul + 量化 (FP8) 融合
- Attention + 量化 (FP8) 融合(最高提升 7%)
- AllReduce + RMSNorm 融合(最高提升 15%)
- AllReduce + RMSNorm + 量化 (FP8) 融合(最高提升 8%)
- AllReduce + RMSNorm + 量化 (FP4) 融合(最高提升 10%)
- 序列并行 & 异步 TP(最高提升 10%)
- 其他通道
- 无操作消除:消除或简化冗余的 reshape 操作
- 修复函数化:手动替换 auto_functionalized 操作,以避免冗余拷贝和内存使用
即将推出
通道可以通过 PostGradPassManager、命令行(--compilation-config)或在离线模式下指定一个配置对象来添加。这允许 vLLM 用户执行其用例所需的自定义图转换(算子替换或其他),而无需修改 vLLM 源代码。
未来工作
我们在 vLLM-torch.compile 集成方面已经取得了很大进展。以下是我们未来六个月将重点关注的一些领域。
提高稳定性
vLLM-torch.compile 集成使用了许多私有的(以下划线开头)torch.compile API,并依赖于不稳定的实现细节。我们这样做是因为使用公共的 torch.compile API 不足以满足我们的需求——vLLM 需要快速的服务性能,并且在模型服务期间不能有重新编译。这导致了一些问题,比如奇怪的缓存问题,或者需要为某些模型禁用 vLLM 的 torch.compile 缓存。PyTorch 编译器团队正在努力将 vLLM(以及通用推理)相关的功能从 vLLM 上游贡献到 torch.compile,并将 vLLM 迁移到使用更稳定的 API。其中许多功能已经存在于 torch 2.8 中,该版本将很快登陆 vLLM!
改善启动时间
我们了解到,对于 vLLM torch.compile 和 CUDAGraphs 来说,启动时间是一个巨大的痛点,尤其是在自动扩展的场景中,需要根据需求动态启动新机器。我们计划显著减少 vLLM 的冷启动(首次)和热启动(第二次及以后)时间,特别是与 Dynamo 和 Inductor 编译相关的时间。请关注 GitHub 上的 startup-ux 标签或加入 vLLM Slack 上的 #feat-startup-ux 频道以获取最新进展!
一个重要的用户体验改进是计划中的对 -O 命令行标志的改造。通过在 vLLM 命令行中指定 -O<n>(其中 n 是 0-3 之间的整数),用户将能更轻松地直接控制在启动时间和性能之间进行权衡。其中 -O0 几乎不执行任何优化,以最快速度启动,而 -O3 则会花费更长的时间,但能提供最佳性能。
自定义通道改进
我们计划对自定义通道机制进行一些广泛的改进,以增加其灵活性并使其更易于编写,同时提高应用优化后的最终性能。
- 编译多个动态形状的
torch.fx图。这将使我们能够根据批次的大小来特化前向传递图,而无需为每个静态大小单独编译。更多信息请参见 RFC。 - 启用对自定义操作的 torch 实现的匹配。目前,需要启用自定义操作(rms_norm、quant 等)才能进行模式匹配和融合,但可能有些自定义操作最终没有被融合(特别是对于每层发生 4 次的量化)。这些操作比它们的 torch 等价物慢,从而降低了融合带来的好处。我们有一个工作原型,可以对自定义操作的 torch 实现进行模式匹配,有望带来进一步的性能提升。
实验性的 torch.compile 后端集成
我们还在探索一个实验性的 MPK/Mirage 编译器集成。MPK 是一个精度调度的大算子(megakernel)编译器,这意味着它为整个模型的前向传递生成一个单一的算子,与 CUDA Graphs 相比,这可以进一步减少 CPU 开销并消除算子启动开销。关于提议的集成的更多信息请参见 RFC。
其他性能改进
vLLM 的 torch.compile 集成的目标是提供良好的基线性能,以避免需要编写和维护大量的自定义算子。我们将继续维护和提高性能。正在进行的工作的一些亮点包括:
- 改进的 FlexAttention 支持。FlexAttention 是一个 API,它允许使用不同的注意力变体,而无需为每种变体编写自定义的注意力算子。在底层,它使用 torch.compile 来生成一个自定义的 Triton 模板。
- 对 Flash Attention v2 和 FlashInfer 的完整 CUDA Graphs 支持。完整的 CUDAGraphs 比分段 CUDA Graphs 的开销更小,应该能在那些高开销的场景中提高性能。
结论
torch.compile 提供了一种强大且易于使用的方式来加速 PyTorch 模型。在 vLLM 中,它是推理流水线的核心部分。结合缓存、动态形状支持、CUDA Graphs 和自定义通道,它实现了在任何环境下的高效、可扩展的 LLM 服务。
随着编译器堆栈的成熟和对新硬件支持的扩展,torch.compile 和 vLLM 将继续推动推理性能的边界——同时保持模型开发的整洁和模块化。阅读更多关于 torch.compile 的信息,请参阅 PyTorch 文档和 vLLM 文档,并加入 vLLM Slack 上的 #sig-torch-compile 频道来提问、分享反馈,并贡献您自己的自定义通道!