cuda-oxide:从零加一个自定义 PTX intrinsic——五阶段管线 + 两个真实踩坑

走一遍真实流程:在 cuda-oxide 里加一个返回常量 42 的自定义 intrinsic xxx(),最终在 PTX 中通过内联汇编 mov.u32 %r1, 42; 实现。五个 crate 要改、两个错误版本一定会踩、convergent 属性是正确性而不是性能开关。把整套实战拆给你看。

📚 系列 cuda-oxide · 第 18 篇

拆解 07 讲了 intrinsic dispatch 的机制。这一篇是实战:从零写一个自定义 intrinsic xxx(),返回常量 42,最终在 PTX 里变成 mov.u32 %r1, 42;。涉及 5 个 crate、两个高频踩坑、一个常被低估的 convergent 属性。

0. 几个名词先说清楚

缩写 / 术语英文全称中文含义
intrinsicintrinsic编译器内建函数看起来像函数但没有真正函数体,编译时被替换成特定指令
桩函数stub function桩函数函数体写成 unreachable!() 等占位代码,backend 通过名字识别后替换
inline asminline assembly内联汇编在高级语言里直接嵌一段汇编字符串
convergentconvergent收敛LLVM 属性,告诉优化器”这条指令必须 warp 内同步执行,不准搬位置”
constraintconstraint string约束串内联汇编里告诉编译器每个操作数槽位是输出/输入、放哪类寄存器
clobberclobber list破坏列表告诉编译器这段 asm 会破坏哪些寄存器/内存,别假设旧值还活着
FQDNFully Qualified Domain Name完全限定路径名含完整 crate / module 路径的函数名
replace_operationrewriter API替换操作把旧 op 换成新 op,同时自动重定向所有 use

1. 任务

在 cuda-oxide 里加一个函数 thread::xxx(),Rust 端这样调:

let xxx = thread::xxx();    // xxx == 42

最终 PTX 里期望看到:

mov.u32 %r1, 42;

2. 五阶段管线总览

cuda-oxide 的 intrinsic 实现要动五个 crate:

cuda-device  →  mir-importer  →  dialect-nvvm  →  mir-lower  →  llvm-export
   桩函数         识别调用           IR op           lowering        textual LLVM IR
阶段crate改什么
1cuda-device#[inline(never)] 桩函数
2dialect-nvvm定义新 op、实现 Verify、注册
3mir-importertry_dispatch_intrinsic 里匹配 FQDN
4mir-lower实现 MirToLlvmConversion,写 lowering 函数
5llvm-export通常不用动(已有通用路径)

下面每一步给最小代码。

3. Stage 1:写桩函数

crates/cuda-device/src/thread.rs:

#[inline(never)]
pub fn xxx() -> u32 {
    // 最终被 backend 替换为:mov.u32 $0, 42;
    unreachable!("xxx called outside CUDA kernel context")
}

#[inline(never)] 是关键——rustc 不能 inline 这个函数,否则 MIR 里看不到对 xxx 的 Call,后面 dispatcher 没办法识别。

函数体永远不会被执行,因为 backend 在翻译 MIR 时直接替换整个 call,根本不进函数。

4. Stage 2:定义 NVVM op

crates/dialect-nvvm/src/ops/thread.rs:

#[pliron_op(
    name = "nvvm.read_ptx_xxx",
    format,
    interfaces = [NOpdsInterface<0>, NResultsInterface<1>],
)]
pub struct ReadPtxXXXOp;

impl ReadPtxXXXOp {
    pub fn new(op: Ptr<Operation>) -> Self {
        ReadPtxXXXOp { op }
    }
}

impl Verify for ReadPtxXXXOp {
    fn verify(&self, ctx: &Context) -> Result<(), Error> {
        let op = &*self.get_operation().deref(ctx);
        let res = op.get_result(0);
        let ty_obj = res.get_type(ctx).deref(ctx);

        let int_ty = ty_obj.downcast_ref::<IntegerType>()
            .ok_or_else(|| ...)?;
        if int_ty.width() != 32 {
            return verify_err!(op.loc(), "result must be 32-bit integer");
        }
        Ok(())
    }
}

pub(super) fn register(ctx: &mut Context) {
    ReadPtxXXXOp::register(ctx);
    // ...其它已有 op...
}

三件事:

干啥实现
声明 op#[pliron_op] 宏 + 名字 + 0 个操作数 + 1 个结果
类型校验Verify impl 检查 result 是 32-bit integer
注册register() 里加一行,让 pliron 知道这个 op 存在

5. Stage 3:importer 分发

crates/mir-importer/src/translator/terminator/mod.rs,在 try_dispatch_intrinsic 的大 match 里加一个分支:

"cuda_device::xxx" | "cuda_device::thread::xxx" => {
    Ok(Some(helpers::emit_nvvm_intrinsic(
        ctx,
        ReadPtxXXXOp::get_concrete_op_info(),
        destination, target, block_ptr, prev_op,
        value_map, block_map, loc,
    )?))
}

两个 case 都要加:cuda_device::xxx(用户 use cuda_device::xxx 后裸调)和 cuda_device::thread::xxx(完整路径调)。

emit_nvvm_intrinsic 是个 helper,它创建 NVVM op + 绑定结果到 destination SSA value + 发一条 branch 到 return target block。整个 Call 终结符被替换,桩函数体根本不会被翻译

6. Stage 4:lowering 到 LLVM inline asm

两处改动。

6.1 注册 conversion interface

crates/mir-lower/src/convert/interface_impls.rs:

use dialect_nvvm::ops::ReadPtxXXXOp;

#[op_interface_impl]
impl MirToLlvmConversion for ReadPtxXXXOp {
    fn convert(&self, ctx, rewriter, operands_info) -> Result<()> {
        super::intrinsics::basic::convert_ptx_xxx_i32(
            ctx, rewriter, self.get_operation(), operands_info,
        )
    }
}

6.2 写 lowering 函数

crates/mir-lower/src/convert/intrinsics/basic.rs:

pub(crate) fn convert_ptx_xxx_i32(
    ctx: &mut Context,
    rewriter: &mut DialectConversionRewriter,
    op: Ptr<Operation>,
    _operands_info: &OperandsInfo,
) -> Result<()> {
    let i32_ty = IntegerType::get(ctx, 32, Signedness::Signless);
    let asm_op = inline_asm_convergent(
        ctx, rewriter,
        i32_ty.into(),
        vec![],
        "mov.u32 $0, 42;",     // PTX 模板
        "=r",                  // 约束串:$0 是输出,32-bit 寄存器
    );
    rewriter.replace_operation(ctx, op, asm_op);
    Ok(())
}

写到这里听起来很简单。但实际上这 11 行代码我第一次写错了两个地方——分别就是下面两个踩坑案例。

7. Stage 5:导出

不用改。llvm-export 已经有通用路径处理 InlineAsmOp

8. 踩坑 1:=$0 是错的,应该是 =r

第一次写约束串时,我以为占位符 $0 在两个地方都能用——汇编模板里用 $0 指代操作数槽位,约束串里也用 $0 标记输出。完全错了

PTX 内联汇编约束串沿用 GCC / LLVM 那套语法:

出现位置用什么
汇编模板("mov.u32 $0, 42;")$0 $1 ... 指代操作数槽位
约束串("=r")=r 才是合法语法,标明该槽位是 32-bit 输出寄存器

写错版本:

"=$0"   // ✗ 不合法,LLVM 在后续 verify 时报错

正确版本:

"=r"    // ✓ 输出到 32-bit 寄存器

如果你之前看过 compiler-07 写 inline PTX 模板 里的约束串速查表,这个坑会绕过。

9. 踩坑 2:erase_operation vs replace_operation

第二次写错的是替换 op 的 API。

错误代码(导致 pliron 直接 panic):

pub fn convert_ptx_xxx_i32(/*...*/) -> Result<()> {
    let i32_ty = IntegerType::get(ctx, 32, Signedness::Signless);

    inline_asm_convergent(
        ctx, rewriter,
        i32_ty.into(),
        vec![],
        "mov.u32 $0, 42;",
        "=r",
    );

    rewriter.erase_operation(ctx, op);   // ✗ 这里出事
    Ok(())
}

跑起来直接 panic:

panicked at pliron/operation.rs:526:9: Operation with use(s) being erased
   at convert_ptx_xxx_i32 in basic.rs

9.1 根因

xxx() 返回 u32,被 gpu_printf! 用了——也就是说原 op 的结果有下游消费者

我调 erase_operation 想”删掉旧 op”,但 pliron 的不变量是:

要 erase 一个 op,它的所有 result 必须先没有任何 use(Operation::erase 入口处的 assert)。

直接 erase 会破坏 SSA 完整性,所以 pliron 拒绝。

9.2 正确做法

replace_operation 而不是 erase_operation:

let asm_op = inline_asm_convergent(...);
rewriter.replace_operation(ctx, op, asm_op);

replace_operation 干两件事:

  1. 把所有 op.result 的 use 重定向到 asm_op.result
  2. 删掉原 op

这是个原子操作——SSA 完整性始终保持。

9.3 经验法则

原 op 的结果状态该用哪个 API
有结果且下游引用了rewriter.replace_operation(ctx, op, new_op)
无结果(void)或纯副作用rewriter.erase_operation(ctx, op)

对照参考:

  • convert_threadfence_block:membar.cta 是 void → 用 erase_operation
  • convert_cluster_sreg:返回 i32 → 用 replace_operation

10. 深度讲解:inline_asm_convergent

我们用的 helper 函数,签名:

pub fn inline_asm_convergent(
    ctx: &mut Context,
    rewriter: &mut DialectConversionRewriter,
    result_ty: Ptr<TypeObj>,    // PTX 输出类型;无输出用 VoidType
    inputs: Vec<Value>,         // 输入 SSA 值列表
    asm_template: &str,         // PTX 模板,$0 $1 ... 是占位符
    constraints: &str,          // 约束串
) -> Ptr<Operation>

底层调 InlineAsmOp::new_convergent,等价于 InlineAsmOp::new(...) 加上 inline_asm_convergent = true 的 BoolAttr。

10.1 三种典型调用

有返回值,无输入(本次):

inline_asm_convergent(
    ctx, rewriter,
    i32_ty.into(),
    vec![],
    "mov.u32 $0, 42;",
    "=r",
)

无返回值(block 屏障):

inline_asm_convergent(
    ctx, rewriter,
    void_ty.into(),
    vec![],
    "membar.cta;",
    "~{memory}",
)

有输入有输出(从 shared memory 读):

inline_asm_convergent(
    ctx, rewriter,
    i32_ty.into(),
    vec![ptr_value],
    "ld.shared.u32 $0, [$1];",
    "=r,r",
)

10.2 返回值的处理

inline_asm_convergent 返回 Ptr<Operation>,它不会替调用方处理原 MIR op——必须自己二选一调 replace_operationerase_operation,这就是上面踩的坑。

11. convergent 到底是什么——正确性开关,不是性能开关

这是个特别容易低估的概念。

GPU 上,一个 warp 内的 32 个线程是锁步执行的。LLVM 优化器(GVN、CSE、loop unswitching、code motion)有时会把指令搬位置——这对普通 CPU 指令没事,但对需要 warp 内所有线程一起执行的指令是灾难。

11.1 典型的 convergent 指令

  • bar.sync(block 内同步屏障)
  • shfl.sync(warp 内寄存器互换)
  • vote.syncmbarrier.arrive
  • wgmma(Hopper 异步矩阵乘)

11.2 反例:不加 convergent 会死锁

if (tid % 2 == 0) {
    __syncthreads();   // 偶数线程进
} else {
    __syncthreads();   // 奇数线程进
}

优化器可能(它实际就会)把两个 __syncthreads() 合并提到 if 外面——它觉得反正都是无副作用的”同一条指令”。

结果:偶数线程的 bar 等不到奇数线程的 bar,整个 block 死锁

打上 convergent 属性后,LLVM 保证:“不准搬,不准跨控制流合并,不准跨控制流复制。“

11.3 关键认知

convergent 不是性能开关,是正确性开关。

凡是依赖 warp 同步语义的 PTX(屏障、shuffle、vote、async commit/wait),lowering 时必须带 convergent。否则你的 kernel 会偶发死锁或算错——而且因为优化器是非确定性的,bug 可能调试三天也复现不出来。

我们这次 mov.u32 $0, 42 严格说不需要 convergent(它跟 warp 同步无关),但用 inline_asm_convergent 包装是保守且无害的选择。

12. 端到端验证

加上 xxx() 之后,在 example 里调一下:

#[kernel]
pub unsafe fn hello_constant(out: *mut i32) {
    let xxx = thread::xxx();
    gpu_printf!("thread xxx: {}", xxx);
    unsafe { *out = 42 };
}

跑:

cargo oxide run hello_constant

输出会看到 256 个线程各打印一次 thread xxx: 42,因为 for_num_elems(1) 实际启动 256 个线程(for_num_elems(n) 是元素数不是线程数,见 host runtime 那篇 里的陷阱说明)。

13. 整条链路串起来

源码  let xxx = thread::xxx();

  ▼ rustc                       桩函数留作 Call,不 inline
MIR Call cuda_device::thread::xxx

  ▼ mir-importer (try_dispatch_intrinsic)
dialect-nvvm  ReadPtxXXXOp       Verify 检查 1 result × i32

  ▼ mir-lower (convert_ptx_xxx_i32)
dialect-llvm  llvm.inline_asm    "mov.u32 $0, 42;" "=r" + convergent

  ▼ llvm-export
LLVM IR       call i32 asm "mov.u32 $0, 42;", "=r"()

  ▼ llc / NVPTX
PTX           mov.u32 %r1, 42;

  ▼ CUDA driver JIT + launch
GPU           每个线程拿到 42

14. 复盘要点

  1. 桩函数必须 #[inline(never)],否则 rustc 把它 inline 掉,MIR 里看不到 Call,dispatcher 没机会识别
  2. inline_asm 约束串里没有 $0,只有 =r 这种——$0 是模板占位符,约束串用 =r/l/n/~{memory} 标类型
  3. lowering 替换原 op 时:有 use 用 replace_operation,无 use 用 erase_operation——pliron 在 Operation::erase 里 assert,错了直接 panic
  4. convergent 不是性能开关,是正确性开关——GPU 同步语义的 PTX 一定要带

15. 一句话总结

在 cuda-oxide 里加一个自定义 PTX intrinsic 要动五个 crate:cuda-device 写桩函数 + dialect-nvvm 定义 op + mir-importer 注册 FQDN dispatch + mir-lower 写 inline asm lowering + llvm-export 通常自动处理。两个高频踩坑是约束串语法(=$0 不合法,应该 =r)和 op 替换 API(有下游 use 必须用 replace_operation 而不是 erase_operation)。convergent 属性是正确性而不是性能开关——凡是 warp 同步语义的 PTX 都必须带,否则可能死锁。

系列上一篇: cuda-oxide:同一段代码在七层 IR 里长什么样

评论区
评论功能即将上线, 敬请期待。