CUDA Graph 是 CUDA 10.0 引入的一项功能,用于 捕获、表示和重放由多个 CUDA 操作(如 kernel 启动、内存拷贝、事件、流等待等)组成的 执行依赖图。其主要目的是:

  • 减少 CPU 开销:避免每次重复提交相同操作序列时的驱动开销。
  • 提升性能可预测性:通过静态图结构实现更稳定的执行时序。
  • 支持复杂依赖关系:显式表达操作间的依赖,便于优化调度。

一、CUDA Graph 中可以包含哪些内容?

一个 CUDA Graph 节点(cudaGraphNode_t)可以表示以下操作类型:

节点类型 说明
Kernel 节点 表示一次 kernel 启动(cudaGraphAddKernelNode
Memcpy 节点 Host↔Device、Device↔Device、Host↔Host 的内存拷贝(cudaGraphAddMemcpyNode
Memset 节点 设备内存填充(cudaGraphAddMemsetNode
Host 节点 在主机上执行一个回调函数(cudaGraphAddHostNode
Event 节点 记录或等待 CUDA 事件(cudaGraphAddEventRecordNode, cudaGraphAddEventWaitNode
空节点(Empty) 用于构建依赖结构,无实际操作(cudaGraphAddEmptyNode
子图节点(CUDA 12+) 嵌套图(较少用)

二、如何将 kernel 启动配置和参数传入 Graph?

每个 kernel 节点通过 cudaKernelNodeParams 结构体定义,包含:

  • 函数指针(func
  • 网格和块尺寸(gridDim, blockDim
  • 共享内存大小(sharedMemBytes
  • kernel 参数(通过 kernelParamsextra 传入,类似 cudaLaunchKernel

注意:参数必须在图构建时提供完整副本(值语义),不能依赖外部指针在运行时变化(除非使用 CUDA Graph 的“可更新节点”机制)。


三、完整示例:包含 kernel、host-device 拷贝、host 回调

以下是一个完整示例,展示:

  1. 从 host 拷贝数据到 device
  2. 启动一个 kernel
  3. 从 device 拷贝结果回 host
  4. 执行一个 host 回调(打印结果)
#include <cuda_runtime.h>
#include <iostream>
#include <vector>

__global__ void addKernel(float* a, float* b, float* c, int n) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n) c[idx] = a[idx] + b[idx];
}

// Host callback function
void CUDART_CB hostCallback(void* data) {
    float* result = static_cast<float*>(data);
    std::cout << "Host callback: result[0] = " << result[0] << std::endl;
}

int main() {
    const int N = 1024;
    const size_t bytes = N * sizeof(float);

    // Host data
    std::vector<float> h_a(N, 1.0f);
    std::vector<float> h_b(N, 2.0f);
    std::vector<float> h_c(N, 0.0f);

    // Device pointers
    float *d_a, *d_b, *d_c;
    cudaMalloc(&d_a, bytes);
    cudaMalloc(&d_b, bytes);
    cudaMalloc(&d_c, bytes);

    // Create graph
    cudaGraph_t graph;
    cudaGraphCreate(&graph, 0);

    // 1. Memcpy H2D for a
    cudaMemcpy3DParms copyA = {0};
    copyA.srcPtr = make_cudaPitchedPtr(h_a.data(), bytes, N, 1);
    copyA.dstPtr = make_cudaPitchedPtr(d_a, bytes, N, 1);
    copyA.extent = make_cudaExtent(bytes, 1, 1);
    copyA.kind = cudaMemcpyHostToDevice;

    cudaGraphNode_t memcpyH2D_a;
    cudaGraphAddMemcpyNode(&memcpyH2D_a, graph, nullptr, 0, &copyA);

    // 2. Memcpy H2D for b
    cudaMemcpy3DParms copyB = {0};
    copyB.srcPtr = make_cudaPitchedPtr(h_b.data(), bytes, N, 1);
    copyB.dstPtr = make_cudaPitchedPtr(d_b, bytes, N, 1);
    copyB.extent = make_cudaExtent(bytes, 1, 1);
    copyB.kind = cudaMemcpyHostToDevice;

    cudaGraphNode_t memcpyH2D_b;
    cudaGraphAddMemcpyNode(&memcpyH2D_b, graph, nullptr, 0, &copyB);

    // 3. Kernel node
    dim3 grid((N + 255) / 256), block(256);
    void* kernelArgs[] = {&d_a, &d_b, &d_c, &N};
    cudaKernelNodeParams kernelParams = {0};
    kernelParams.func = (void*)addKernel;
    kernelParams.gridDim = grid;
    kernelParams.blockDim = block;
    kernelParams.sharedMemBytes = 0;
    kernelParams.kernelParams = kernelArgs;
    kernelParams.extra = nullptr;

    cudaGraphNode_t kernelNode;
    cudaGraphAddKernelNode(&kernelNode, graph, nullptr, 0, &kernelParams);

    // Add dependencies: memcpy must finish before kernel
    cudaGraphAddDependencies(graph, &memcpyH2D_a, &kernelNode, 1);
    cudaGraphAddDependencies(graph, &memcpyH2D_b, &kernelNode, 1);

    // 4. Memcpy D2H for c
    cudaMemcpy3DParms copyC = {0};
    copyC.srcPtr = make_cudaPitchedPtr(d_c, bytes, N, 1);
    copyC.dstPtr = make_cudaPitchedPtr(h_c.data(), bytes, N, 1);
    copyC.extent = make_cudaExtent(bytes, 1, 1);
    copyC.kind = cudaMemcpyDeviceToHost;

    cudaGraphNode_t memcpyD2H_c;
    cudaGraphAddMemcpyNode(&memcpyD2H_c, graph, &kernelNode, 1, &copyC);

    // 5. Host callback node
    cudaHostNodeParams hostParams = {0};
    hostParams.fn = hostCallback;
    hostParams.userData = h_c.data();  // 注意:必须确保 h_c 在回调时仍有效!

    cudaGraphNode_t hostNode;
    cudaGraphAddHostNode(&hostNode, graph, &memcpyD2H_c, 1, &hostParams);

    // Instantiate and launch
    cudaGraphExec_t graphExec;
    cudaGraphInstantiate(&graphExec, graph, nullptr, nullptr, 0);

    // Launch multiple times
    for (int i = 0; i < 3; ++i) {
        cudaGraphLaunch(graphExec, 0);
        cudaDeviceSynchronize(); // 确保 host 回调执行完毕
    }

    // Cleanup
    cudaGraphExecDestroy(graphExec);
    cudaGraphDestroy(graph);
    cudaFree(d_a);
    cudaFree(d_b);
    cudaFree(d_c);

    return 0;
}

四、关键注意事项

  1. 参数生命周期

    • kernelArgs 中的指针(如 &d_a)在 cudaGraphAddKernelNode 调用时被拷贝值,所以 d_a 本身可以后续变化,但 d_a 指向的设备内存必须在图执行时有效。
    • Host 回调中的 userData(如 h_c.data())必须在图执行期间保持有效(不能是局部变量临时地址)。
  2. 依赖关系

    • 使用 cudaGraphAddDependencies 显式建立节点间依赖。
    • 也可以在 cudaGraphAddXxxNodepDependencies 参数中直接指定前驱节点。
  3. 性能

    • 图一旦实例化(cudaGraphInstantiate),就无法修改结构(但 CUDA 11.1+ 支持可更新图)。
  4. 错误检查

    • 实际代码中应检查所有 CUDA API 返回值(为简洁省略)。

总结

CUDA Graph 允许你将多个异构操作(kernel、memcpy、host 回调等)打包成一个可高效重放的执行图。通过 cudaKernelNodeParams 传入 kernel 的配置和参数,并通过依赖关系精确控制执行顺序。这对于需要重复执行相同计算流程的高性能应用(如 CFD 时间步推进)非常有用。

Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐