追踪挂起和复杂的 GPU 内核到源代码
几个月前,我们发表了一篇题为《CUDA 核心转储:调试内存访问问题及其他问题的有效工具》的博客文章,介绍了一种调试 CUDA 内核中非法内存访问问题的强大技术。这代表了 GPU 内核调试的一个重要里程碑,因为它使开发人员能够查明导致故障的确切内核。以前,由于 GPU 执行的异步特性,识别问题内核几乎是不可能的,并且错误消息常常具有误导性。
随着 CUDA 核心转储技术的普及,开发人员表达了对更细粒度信息的需求——特别是触发问题的源代码的确切行。在这篇博客文章中,我们首先介绍如何识别挂起的内核,然后演示如何将问题内核追溯到其源代码,从而解决这一空白。
如何查找挂起的内核
GPU 计算能力呈指数级增长,但内存带宽未能跟上。这种不平衡导致了日益复杂的内存访问模式。近年来,旗舰数据中心 GPU 引入了异步内存访问模式,这需要在实现高性能内核时进行复杂的同步。这些同步机制容易出现竞态条件和死锁,尤其是在复杂的代码库中。
当 GPU 内核挂起时,程序通常会冻结或变得无响应——即使按下 Ctrl-C 也无法停止它。最直接的解决方案是终止进程,但这种方法无法提供有关根本原因的信息。开发人员只能盲目猜测,通过二分法代码更改并迭代运行测试,直到识别出问题。
说明
当 CUDA 内核挂起时,为什么按 Ctrl-C 不能停止进程?按 Ctrl-C 会向进程发送 SIGINT 信号。如果进程正在运行 Python 代码,SIGINT 信号会被 Python 解释器捕获,它将其转换为 KeyboardInterrupt 异常并将异常排队,以便在进程返回运行 Python 代码后进行处理。但是,如果进程正在运行 CUDA 内核并等待 GPU 完成,它正在等待低级 CUDA API 返回,而没有 Python 代码正在运行,因此无法引发 KeyboardInterrupt 异常。在下面的 conditional_hang.py 示例中,如果您想通过 Ctrl-C 终止进程,您需要在脚本开头添加 import signal; signal.signal(signal.SIGINT, signal.SIG_DFL),以便 Python 解释器不捕获 SIGINT 信号,这样 Ctrl-C 才能成功终止进程。缺点是 Python 解释器在被 Ctrl-C 停止时将无法显示错误堆栈。
幸运的是,有更好的方法。CUDA 驱动程序包含一个名为 用户诱导 GPU 核心转储生成 的功能:驱动程序在操作系统中打开管道,允许用户通过写入它们来触发核心转储。触发时,CUDA 驱动程序将 GPU 状态转储到核心转储文件中,从而能够检查 GPU 内部发生的情况,最重要的是,识别哪个 GPU 内核正在挂起。
考虑一个条件挂起内核的简单示例
# save as conditional_hang.py
import triton
import triton.language as tl
import torch
@triton.jit
def conditional_hang_kernel(x_ptr,
flag, # int32 scalar
n_elements, # int32 scalar
BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(0)
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
mask = offs < n_elements
# Load values
x = tl.load(x_ptr + offs, mask=mask, other=0)
# If flag == 1: do a normal "+1" update
if flag == 1:
x = x + 1
tl.store(x_ptr + offs, x, mask=mask)
else:
# Else: non-terminating loop, no break.
# The loop condition depends on `flag`, which is invariant,
# so this is effectively an infinite loop when flag == 0.
while flag == 0:
# do something trivial so the loop isn't optimized away
x = x + 1
tl.store(x_ptr + offs, x, mask=mask)
x = torch.ones(16, dtype=torch.float32, device="cuda")
n_elements = x.numel()
BLOCK_SIZE = 16
# 1) Normal behavior: increment by 1
conditional_hang_kernel[(1,)](
x,
flag=1,
n_elements=n_elements,
BLOCK_SIZE=BLOCK_SIZE,
)
print("After flag=1:", x) # should be all 2s
# 2) Hanging behavior: this will spin forever
conditional_hang_kernel[(1,)](
x,
flag=0,
n_elements=n_elements,
BLOCK_SIZE=BLOCK_SIZE,
)
# this print will hang, because printing x will synchronize the device,
# and the kernel will never finish.
print("After flag=0:", x)
# the following line will never be reached
x = x + 2
torch.cuda.synchronize()
执行此代码将无限期挂起。要调试此问题,我们可以启用用户诱导的 GPU 核心转储生成
CUDA_ENABLE_USER_TRIGGERED_COREDUMP=1 \
CUDA_COREDUMP_PIPE="/tmp/cuda_coredump_pipe_%h.%p.%t" \
CUDA_ENABLE_COREDUMP_ON_EXCEPTION=1 \
CUDA_COREDUMP_SHOW_PROGRESS=1 \
CUDA_COREDUMP_GENERATION_FLAGS='skip_nonrelocated_elf_images,skip_global_memory,skip_shared_memory,skip_local_memory,skip_constbank_memory' \
CUDA_COREDUMP_FILE="/tmp/cuda_coredump_%h.%p.%t" \
python conditional_hang.py
当代码无限期运行时,我们可以通过写入管道来触发 CUDA 核心转储
dd if=/dev/zero bs=1M count=1 > /tmp/cuda_coredump_pipe_hostname.3000837.1764236276
我们向管道写入 1MB 零以触发 CUDA 核心转储。请注意,由于管道缓冲,简单的 echo 命令可能无法工作。
触发核心转储后,运行 python conditional_hang.py 的原始终端将显示核心转储进度
[01:39:15.256278] coredump: Writing ELF file to /tmp/cuda_coredump_hostname.3000837.1764236276
[01:39:15.256350] coredump: Writing out global memory (0 bytes)
[01:39:15.256354] coredump: Writing out device table
[01:39:15.292027] coredump: Writing out metadata
[01:39:15.292039] coredump: Finalizing
[01:39:15.292124] coredump: Writing done
[01:39:15.292128] coredump: All done (took 00s)
然后我们可以使用 cuda-gdb 打开核心转储文件并准确查看内核挂起的位置
Opening GPU coredump: /tmp/cuda_coredump_hostname.3000837.1764236276
[Current focus set to CUDA kernel 0, grid 53, block (0,0,0), thread (0,0,0), device 0, sm 124, warp 0, lane 0]
#0 0x00007f2e6fbff300 in conditional_hang_kernel<<<(1,1,1),(128,1,1)>>> () at conditional_hang.py:31
31 tl.store(x_ptr + offs, x, mask=mask)
这种方法不仅可以识别挂起的内核 (conditional_hang_kernel),还可以查明它挂起的代码行。这与以前的情况相比是一个显著的改进,以前识别问题内核是不可能的,更不用说导致挂起的特定行了。
一个轻微的不便之处是核心转储管道的路径由 CUDA 驱动程序动态生成,使其难以定位。我们可以通过使用 CUDA_COREDUMP_PIPE 环境变量来指定核心转储管道的模板路径来解决这个问题,从而通过检查进程的文件描述符轻松找到它
$ ls /proc/3037675/fd/ -alth | grep /tmp/cuda_coredump_pipe_
lr-x------ 1 user user 64 Nov 27 01:50 98 -> /tmp/cuda_coredump_pipe_hostname.3037675.1764237014
如何追溯复杂内核的源代码
在上一篇博客文章中,我们提到使用 export NVCC_PREPEND_FLAGS='-lineinfo' 环境变量进行编译会将行信息嵌入到编译后的二进制文件中,从而使我们能够追溯到导致问题的确切代码行。在讨论和调试几个实际问题后,我们发现 cuda-gdb 显示行信息的默认方式并不完美
-
对于某些复杂内核,即使行信息嵌入在编译后的二进制文件中,
cuda-gdb也无法找到导致问题的正确代码行。 -
即使
cuda-gdb可以找到正确的代码行,它也只显示编译器内联后的最后一行,这可能不是导致问题的实际行。由于 C++ 代码严重依赖内联来消除运行时函数调用开销,我们需要完整的内联堆栈来理解问题。
让我们通过一个具体的例子来说明这一点。以下 Python 脚本演示了一个非法内存访问问题
# save as illegal_memory_access.py
from dataclasses import dataclass
import torch
@dataclass
class TensorWrapper:
data_ptr: int
size_in_bytes: int
@property
def __cuda_array_interface__(self):
return {
"shape": (self.size_in_bytes,),
"typestr": '|u1',
"data": (self.data_ptr, False),
"version": 3,
}
def from_buffer(data_ptr: int, size_in_bytes: int, device: str, dtype: torch.dtype) -> torch.Tensor:
return torch.as_tensor(TensorWrapper(data_ptr, size_in_bytes), device=device).view(dtype)
data = from_buffer(123456, 1024, device="cuda:0", dtype=torch.uint8)
index = torch.ones(10, device="cuda", dtype=torch.int32) + 100
print(data[index])
使用 PyTorch >= 2.9.0 运行此代码(特别是,确保它包含此提交;否则您将看到诸如 RuntimeError: The specified pointer resides on host memory and is not registered with any CUDA device. 之类的错误)。这将触发非法内存访问错误。
首先,让我们在启用 CUDA 核心转储的情况下运行代码
CUDA_ENABLE_COREDUMP_ON_EXCEPTION=1 \
CUDA_COREDUMP_SHOW_PROGRESS=1 \
CUDA_COREDUMP_GENERATION_FLAGS='skip_nonrelocated_elf_images,skip_global_memory,skip_shared_memory,skip_local_memory,skip_constbank_memory' \
CUDA_COREDUMP_FILE="/tmp/cuda_coredump_%h.%p.%t" \
python illegal_memory_access.py
核心转储进度将明确识别导致问题的内核
_ZN2at6native24index_elementwise_kernelILi128ELi4EZNS0_16gpu_index_kernelIZNS0_17index_kernel_implINS0_10OpaqueTypeILi1EEEEEvRNS_18TensorIteratorBaseEN3c108ArrayRefIlEESA_EUlPcPKclE_EEvS7_SA_SA_RKT_bEUliE_EEvlT1_
从内核名称中,我们可以看到问题是由 PyTorch 的 index_elementwise_kernel 引起的。要找到导致问题的确切代码行,我们需要使用 export NVCC_PREPEND_FLAGS='-lineinfo' 环境变量从源代码构建 PyTorch,然后再次运行代码。
当编译后的 GPU 内核嵌入了行信息时,我们可以使用 cuda-gdb 打开核心转储文件并准确查看导致问题的代码行
(cuda-gdb) target cudacore /tmp/cuda_coredump_flow-matic.3756036.1764250282
Opening GPU coredump: /tmp/cuda_coredump_flow-matic.3756036.1764250282
[Current focus set to CUDA kernel 0, grid 4, block (0,0,0), thread (0,0,0), device 0, sm 124, warp 3, lane 0]
CUDA Exception: Warp Illegal Address
The exception was triggered at PC 0x7ff533bb91d0 ...
#0 void at::native::index_elementwise_kernel<128, 4, at::native::gpu_index_kernel<at::native::index_kernel_impl<at::native::OpaqueType<1> >(at
::TensorIteratorBase&, c10::ArrayRef<long>, c10::ArrayRef<long>)::{lambda(char*, char const*, long)#1}>(at::TensorIteratorBase&, c10::ArrayRef<
long>, c10::ArrayRef<long>, at::native::index_kernel_impl<at::native::OpaqueType<1> >(at::TensorIteratorBase&, c10::ArrayRef<long>, c10::ArrayR
ef<long>)::{lambda(char*, char const*, long)#1} const&, bool)::{lambda(int)#1}>(long, at::native::gpu_index_kernel<at::native::index_kernel_imp
l<at::native::OpaqueType<1> >(at::TensorIteratorBase&, c10::ArrayRef<long>, c10::ArrayRef<long>)::{lambda(char*, char const*, long)#1}>(at::Ten
sorIteratorBase&, c10::ArrayRef<long>, c10::ArrayRef<long>, at::native::index_kernel_impl<at::native::OpaqueType<1> >(at::TensorIteratorBase&,
c10::ArrayRef<long>, c10::ArrayRef<long>)::{lambda(char*, char const*, long)#1} const&, bool)::{lambda(int)#1})<<<(1,1,1),(128,1,1)>>> ()
at /data/youkaichao/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu:203 in _ZZN2at6native17index_kernel_implINS0_10OpaqueTypeILi1EEEEEvRNS
_18TensorIteratorBaseEN3c108ArrayRefIlEES8_ENKUlPcPKclE_clES9_SB_l inlined from IndexKernel.cu:118
203 *reinterpret_cast<scalar_t*>(out_data) = *reinterpret_cast<const scalar_t*>(in_data + offset);
接下来,在 cuda-gdb 中,我们可以使用 info symbol $errorpc 来获取有关错误位置的更多信息
(cuda-gdb) info symbol $errorpc
void at::native::index_elementwise_kernel<128, 4, at::native::gpu_index_kernel<at::native::index_kernel_impl<at::native::OpaqueType<1> >(at::TensorIteratorBase&, c10::ArrayRef<long>, c10::ArrayRef<long>)::{lambda(char*, char const*, long)#1}>(at::TensorIteratorBase&, c10::ArrayRef<long>, c10::ArrayRef<long>, at::native::index_kernel_impl<at::native::OpaqueType<1> >(at::TensorIteratorBase&, c10::ArrayRef<long>, c10::ArrayRef<long>)::{lambda(char*, char const*, long)#1} const&, bool)::{lambda(int)#1}>(long, at::native::gpu_index_kernel<at::native::index_kernel_impl<at::native::OpaqueType<1> >(at::TensorIteratorBase&, c10::ArrayRef<long>, c10::ArrayRef<long>)::{lambda(char*, char const*, long)#1}>(at::TensorIteratorBase&, c10::ArrayRef<long>, c10::ArrayRef<long>, at::native::index_kernel_impl<at::native::OpaqueType<1> >(at::TensorIteratorBase&, c10::ArrayRef<long>, c10::ArrayRef<long>)::{lambda(char*, char const*, long)#1} const&, bool)::{lambda(int)#1}) + 11472 in section .text._ZN2at6native24index_elementwise_kernelILi128ELi4EZNS0_16gpu_index_kernelIZNS0_17index_kernel_implINS0_10OpaqueTypeILi1EEEEEvRNS_18TensorIteratorBaseEN3c108ArrayRefIlEESA_EUlPcPKclE_EEvS7_SA_SA_RKT_bEUliE_EEvlT1_ of /tmp/cuda-dbg/2123124/session1/elf.21407f80.24fe2940.o.4gyLzn
这提供了有关错误位置的更多信息。cuda-gdb 解压缩编译后的二进制文件,/tmp/cuda-dbg/2123124/session1/elf.21407f80.24fe2940.o.4gyLzn 是一个包含 index_elementwise_kernel 的 cubin 文件。错误发生在 cubin 文件中的位置 0x7ff533bb91d0。我们可以使用 nvdisasm 反汇编 cubin 文件并准确查看是哪行代码导致了问题
$ nvdisasm -ndf -c -gi /tmp/cuda-dbg/2123124/session1/elf.21407f80.24fe2940.o.4gyLzn > output.txt
$ grep -C20 7ff533bb91d0 output.txt
...
/*7ff533bb9190*/ IMAD.IADD R19, R23, 0x1, R3 ;
.L_x_27840:
//## File "/data/youkaichao/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu", line 203 inlined at "/data/youkaichao/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu", line 118
//## File "/data/youkaichao/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu", line 118 inlined at "/data/youkaichao/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu", line 37
//## File "/data/youkaichao/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu", line 37
/*7ff533bb91a0*/ ULDC.64 UR4, c[0x0][0x480] ;
/*7ff533bb91b0*/ IADD3 R2, P0, P1, R22, UR4, R2 ;
/*7ff533bb91c0*/ IADD3.X R3, R19, UR5, RZ, P0, P1 ;
/*7ff533bb91d0*/ LDG.E.U8 R3, desc[UR36][R2.64] ;
...
现在我们可以看到导致问题的代码的完整内联堆栈。默认情况下,cuda-gdb 只显示最后的内联展开。
命令的简要说明
-ndf:反汇编后禁用数据流分析器。-c:只打印代码段。-gi:使用从 .debug_line 部分获取的源行信息以及函数内联信息(如果存在)来注释反汇编。-C20:一个grep参数,显示在找到的程序计数器地址7ff533bb91d0周围的 20 行上下文。
如果 cubin 文件包含多个具有相同程序计数器地址的内核(即 grep 显示多个匹配项),我们需要进一步过滤信息
$ cuobjdump -elf /tmp/cuda-dbg/2123124/session1/elf.21407f80.24fe2940.o.4gyLzn > elf.txt
$ cat elf.txt | grep ".text._ZN2at6native24index_elementwise_kernelILi128ELi4EZNS0_16gpu_index_kernelIZNS0_17index_kernel_implINS0_10OpaqueTypeILi1EEEEEvRNS_18TensorIteratorBaseEN3c108ArrayRefIlEESA_EUlPcPKclE_EEvS7_SA_SA_RKT_bEUliE_EEvlT1_" | grep PROGBITS
1ac 1b83f80 b200 0 80 PROGBITS 6 3 26a .text._ZN2at6native24index_elementwise_kernelILi128ELi4EZNS0_16gpu_index_kernelIZNS0_17index_kernel_implINS0_10OpaqueTypeILi1EEEEEvRNS_18TensorIteratorBaseEN3c108ArrayRefIlEESA_EUlPcPKclE_EEvS7_SA_SA_RKT_bEUliE_EEvlT1_
$ nvdisasm -ndf -c -gi -fun 0x26a /tmp/cuda-dbg/2123124/session1/elf.21407f80.24fe2940.o.4gyLzn > output.txt
$ grep -C20 7ff533bb91d0 output.txt
...
/*7ff533bb9190*/ IMAD.IADD R19, R23, 0x1, R3 ;
.L_x_27840:
//## File "/data/youkaichao/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu", line 203 inlined at "/data/youkaichao/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu", line 118
//## File "/data/youkaichao/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu", line 118 inlined at "/data/youkaichao/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu", line 37
//## File "/data/youkaichao/pytorch/aten/src/ATen/native/cuda/IndexKernel.cu", line 37
/*7ff533bb91a0*/ ULDC.64 UR4, c[0x0][0x480] ;
/*7ff533bb91b0*/ IADD3 R2, P0, P1, R22, UR4, R2 ;
/*7ff533bb91c0*/ IADD3.X R3, R19, UR5, RZ, P0, P1 ;
/*7ff533bb91d0*/ LDG.E.U8 R3, desc[UR36][R2.64] ;
...
主要区别在于通过搜索函数的 ELF 段从 cuobjdump 获取 CUDA 函数索引(-fun 参数),在本例中为 26a。
请注意,这是一个简化示例,旨在演示该技术。实际内核可能复杂得多。例如,这是一个复杂的内联情况
//## File "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/arch/copy_sm90.hpp", line 93 inlined at "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/arch/util.hpp", line 158
//## File "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/arch/util.hpp", line 158 inlined at "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/arch/util.hpp", line 185
//## File "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/arch/util.hpp", line 185 inlined at "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/atom/copy_traits.hpp", line 133
//## File "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/atom/copy_traits.hpp", line 133 inlined at "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/atom/copy_atom.hpp", line 103
//## File "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/atom/copy_atom.hpp", line 103 inlined at "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/atom/copy_atom.hpp", line 124
//## File "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/atom/copy_atom.hpp", line 124 inlined at "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/algorithm/copy.hpp", line 211
//## File "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/algorithm/copy.hpp", line 211 inlined at "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/algorithm/copy.hpp", line 412
//## File "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/algorithm/copy.hpp", line 412 inlined at "/data/youkaichao/data/vllm_flash_attn/hopper/epilogue_fwd.hpp", line 265
//## File "/data/youkaichao/data/vllm_flash_attn/hopper/epilogue_fwd.hpp", line 265 inlined at "/data/youkaichao/data/vllm_flash_attn/hopper/flash_fwd_kernel_sm90.h", line 454
//## File "/data/youkaichao/data/vllm_flash_attn/hopper/flash_fwd_kernel_sm90.h", line 454 inlined at "/data/youkaichao/data/vllm_flash_attn/hopper/utils.h", line 41
//## File "/data/youkaichao/data/vllm_flash_attn/hopper/utils.h", line 41 inlined at "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cutlass/device_kernel.h", line 122
//## File "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cutlass/device_kernel.h", line 122
/*7eebf5e9eb80*/ STSM.16.M88.4 [R13], R4 ;
/*7eebf5e9eb90*/ MOV R34, R26 ;
在这种情况下,有问题的代码是
注意力内核中一行被污染的代码。
有问题的源代码调用了一些 CUTLASS 函数,并且包含它的函数也被上层调用者内联。在这种情况下,cuda-gdb 无法正确关联行。实际上,它没有显示错误位置附近的任何行信息。即使它显示了正确的行,它也只显示最后一个内联帧,即 File "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/arch/copy_sm90.hpp", line 93 inlined at "/data/youkaichao/data/vllm_flash_attn/csrc/cutlass/include/cute/arch/util.hpp", line 158——CUTLASS 函数的内部内联展开,对调试底层问题仍然没有帮助。
使用上面概述的方法,我们可以揭示源代码的完整内联链,并仔细检查每个帧以确定哪一行导致了错误。
警告:为了最大限度地利用 CUDA 核心转储,行信息至关重要。建议使用 export NVCC_PREPEND_FLAGS='-lineinfo' 环境变量进行编译,因为这可以透明地应用于所有编译的内核,而无需修改编译脚本。但是,这种透明性意味着如果您使用编译缓存机制(例如 ccache),它可能会忽略该标志并重用以前编译的结果而无需实际编译。从源代码编译时,请确保禁用编译缓存机制。如果您使用即时编译,请查阅您的即时编译工具的文档,了解如何添加行信息。
结论
这篇博客文章介绍了两种用于 CUDA 内核的高级调试技术。第一种技术使用用户触发的核心转储来识别挂起的内核,而第二种技术通过利用编译后的二进制文件中嵌入的行信息将复杂内核追溯到其源代码。这些技术是调试 CUDA 内核中复杂问题(尤其是非法内存访问问题)的强大工具。通过两者结合使用,我们最近成功调试了 CUTLASS MLA 注意力后端中一个难以重现的棘手挂起,该问题实际上源于上游 CUTLASS 代码示例,并已在 v4.3.0 中修复。
vLLM 项目旨在为每个人提供简单、快速、经济的 LLM 服务,而可访问的调试是这一使命的重要方面。我们将在未来继续分享更多调试技巧和技术,共同构建一个强大的 LLM 推理生态系统。要分享您与 vLLM 的故事或用法,请在博客文章存储库提交 PR。
致谢
我们要感谢 NVIDIA 的 Ze Long 和 Sandarbh Jain 提供的有益讨论。Moonshot AI 的 Chao Hong 帮助提供了激励性示例。Red Hat 的 Lucas Wilkinson 帮助润色了草稿。