拆解 07 讲了 intrinsic dispatch 的机制。这一篇是实战:从零写一个自定义 intrinsic
xxx(),返回常量42,最终在 PTX 里变成mov.u32 %r1, 42;。涉及 5 个 crate、两个高频踩坑、一个常被低估的convergent属性。
0. 几个名词先说清楚
| 缩写 / 术语 | 英文全称 | 中文 | 含义 |
|---|---|---|---|
| intrinsic | intrinsic | 编译器内建函数 | 看起来像函数但没有真正函数体,编译时被替换成特定指令 |
| 桩函数 | stub function | 桩函数 | 函数体写成 unreachable!() 等占位代码,backend 通过名字识别后替换 |
| inline asm | inline assembly | 内联汇编 | 在高级语言里直接嵌一段汇编字符串 |
| convergent | convergent | 收敛 | LLVM 属性,告诉优化器”这条指令必须 warp 内同步执行,不准搬位置” |
| constraint | constraint string | 约束串 | 内联汇编里告诉编译器每个操作数槽位是输出/输入、放哪类寄存器 |
| clobber | clobber list | 破坏列表 | 告诉编译器这段 asm 会破坏哪些寄存器/内存,别假设旧值还活着 |
| FQDN | Fully Qualified Domain Name | 完全限定路径名 | 含完整 crate / module 路径的函数名 |
replace_operation | rewriter API | 替换操作 | 把旧 op 换成新 op,同时自动重定向所有 use |
1. 任务
在 cuda-oxide 里加一个函数 thread::xxx(),Rust 端这样调:
let xxx = thread::xxx(); // xxx == 42
最终 PTX 里期望看到:
mov.u32 %r1, 42;
2. 五阶段管线总览
cuda-oxide 的 intrinsic 实现要动五个 crate:
cuda-device → mir-importer → dialect-nvvm → mir-lower → llvm-export
桩函数 识别调用 IR op lowering textual LLVM IR
| 阶段 | crate | 改什么 |
|---|---|---|
| 1 | cuda-device | 加 #[inline(never)] 桩函数 |
| 2 | dialect-nvvm | 定义新 op、实现 Verify、注册 |
| 3 | mir-importer | 在 try_dispatch_intrinsic 里匹配 FQDN |
| 4 | mir-lower | 实现 MirToLlvmConversion,写 lowering 函数 |
| 5 | llvm-export | 通常不用动(已有通用路径) |
下面每一步给最小代码。
3. Stage 1:写桩函数
crates/cuda-device/src/thread.rs:
#[inline(never)]
pub fn xxx() -> u32 {
// 最终被 backend 替换为:mov.u32 $0, 42;
unreachable!("xxx called outside CUDA kernel context")
}
#[inline(never)] 是关键——rustc 不能 inline 这个函数,否则 MIR 里看不到对 xxx 的 Call,后面 dispatcher 没办法识别。
函数体永远不会被执行,因为 backend 在翻译 MIR 时直接替换整个 call,根本不进函数。
4. Stage 2:定义 NVVM op
crates/dialect-nvvm/src/ops/thread.rs:
#[pliron_op(
name = "nvvm.read_ptx_xxx",
format,
interfaces = [NOpdsInterface<0>, NResultsInterface<1>],
)]
pub struct ReadPtxXXXOp;
impl ReadPtxXXXOp {
pub fn new(op: Ptr<Operation>) -> Self {
ReadPtxXXXOp { op }
}
}
impl Verify for ReadPtxXXXOp {
fn verify(&self, ctx: &Context) -> Result<(), Error> {
let op = &*self.get_operation().deref(ctx);
let res = op.get_result(0);
let ty_obj = res.get_type(ctx).deref(ctx);
let int_ty = ty_obj.downcast_ref::<IntegerType>()
.ok_or_else(|| ...)?;
if int_ty.width() != 32 {
return verify_err!(op.loc(), "result must be 32-bit integer");
}
Ok(())
}
}
pub(super) fn register(ctx: &mut Context) {
ReadPtxXXXOp::register(ctx);
// ...其它已有 op...
}
三件事:
| 干啥 | 实现 |
|---|---|
| 声明 op | #[pliron_op] 宏 + 名字 + 0 个操作数 + 1 个结果 |
| 类型校验 | Verify impl 检查 result 是 32-bit integer |
| 注册 | 在 register() 里加一行,让 pliron 知道这个 op 存在 |
5. Stage 3:importer 分发
crates/mir-importer/src/translator/terminator/mod.rs,在 try_dispatch_intrinsic 的大 match 里加一个分支:
"cuda_device::xxx" | "cuda_device::thread::xxx" => {
Ok(Some(helpers::emit_nvvm_intrinsic(
ctx,
ReadPtxXXXOp::get_concrete_op_info(),
destination, target, block_ptr, prev_op,
value_map, block_map, loc,
)?))
}
两个 case 都要加:cuda_device::xxx(用户 use cuda_device::xxx 后裸调)和 cuda_device::thread::xxx(完整路径调)。
emit_nvvm_intrinsic 是个 helper,它创建 NVVM op + 绑定结果到 destination SSA value + 发一条 branch 到 return target block。整个 Call 终结符被替换,桩函数体根本不会被翻译。
6. Stage 4:lowering 到 LLVM inline asm
两处改动。
6.1 注册 conversion interface
crates/mir-lower/src/convert/interface_impls.rs:
use dialect_nvvm::ops::ReadPtxXXXOp;
#[op_interface_impl]
impl MirToLlvmConversion for ReadPtxXXXOp {
fn convert(&self, ctx, rewriter, operands_info) -> Result<()> {
super::intrinsics::basic::convert_ptx_xxx_i32(
ctx, rewriter, self.get_operation(), operands_info,
)
}
}
6.2 写 lowering 函数
crates/mir-lower/src/convert/intrinsics/basic.rs:
pub(crate) fn convert_ptx_xxx_i32(
ctx: &mut Context,
rewriter: &mut DialectConversionRewriter,
op: Ptr<Operation>,
_operands_info: &OperandsInfo,
) -> Result<()> {
let i32_ty = IntegerType::get(ctx, 32, Signedness::Signless);
let asm_op = inline_asm_convergent(
ctx, rewriter,
i32_ty.into(),
vec![],
"mov.u32 $0, 42;", // PTX 模板
"=r", // 约束串:$0 是输出,32-bit 寄存器
);
rewriter.replace_operation(ctx, op, asm_op);
Ok(())
}
写到这里听起来很简单。但实际上这 11 行代码我第一次写错了两个地方——分别就是下面两个踩坑案例。
7. Stage 5:导出
不用改。llvm-export 已经有通用路径处理 InlineAsmOp。
8. 踩坑 1:=$0 是错的,应该是 =r
第一次写约束串时,我以为占位符 $0 在两个地方都能用——汇编模板里用 $0 指代操作数槽位,约束串里也用 $0 标记输出。完全错了。
PTX 内联汇编约束串沿用 GCC / LLVM 那套语法:
| 出现位置 | 用什么 |
|---|---|
汇编模板("mov.u32 $0, 42;") | $0 $1 ... 指代操作数槽位 |
约束串("=r") | =r 才是合法语法,标明该槽位是 32-bit 输出寄存器 |
写错版本:
"=$0" // ✗ 不合法,LLVM 在后续 verify 时报错
正确版本:
"=r" // ✓ 输出到 32-bit 寄存器
如果你之前看过 compiler-07 写 inline PTX 模板 里的约束串速查表,这个坑会绕过。
9. 踩坑 2:erase_operation vs replace_operation
第二次写错的是替换 op 的 API。
错误代码(导致 pliron 直接 panic):
pub fn convert_ptx_xxx_i32(/*...*/) -> Result<()> {
let i32_ty = IntegerType::get(ctx, 32, Signedness::Signless);
inline_asm_convergent(
ctx, rewriter,
i32_ty.into(),
vec![],
"mov.u32 $0, 42;",
"=r",
);
rewriter.erase_operation(ctx, op); // ✗ 这里出事
Ok(())
}
跑起来直接 panic:
panicked at pliron/operation.rs:526:9: Operation with use(s) being erased
at convert_ptx_xxx_i32 in basic.rs
9.1 根因
xxx() 返回 u32,被 gpu_printf! 用了——也就是说原 op 的结果有下游消费者。
我调 erase_operation 想”删掉旧 op”,但 pliron 的不变量是:
要 erase 一个 op,它的所有 result 必须先没有任何 use(
Operation::erase入口处的 assert)。
直接 erase 会破坏 SSA 完整性,所以 pliron 拒绝。
9.2 正确做法
用 replace_operation 而不是 erase_operation:
let asm_op = inline_asm_convergent(...);
rewriter.replace_operation(ctx, op, asm_op);
replace_operation 干两件事:
- 把所有
op.result的 use 重定向到asm_op.result - 删掉原 op
这是个原子操作——SSA 完整性始终保持。
9.3 经验法则
| 原 op 的结果状态 | 该用哪个 API |
|---|---|
| 有结果且下游引用了 | rewriter.replace_operation(ctx, op, new_op) |
| 无结果(void)或纯副作用 | rewriter.erase_operation(ctx, op) |
对照参考:
convert_threadfence_block:membar.cta是 void → 用erase_operationconvert_cluster_sreg:返回 i32 → 用replace_operation
10. 深度讲解:inline_asm_convergent
我们用的 helper 函数,签名:
pub fn inline_asm_convergent(
ctx: &mut Context,
rewriter: &mut DialectConversionRewriter,
result_ty: Ptr<TypeObj>, // PTX 输出类型;无输出用 VoidType
inputs: Vec<Value>, // 输入 SSA 值列表
asm_template: &str, // PTX 模板,$0 $1 ... 是占位符
constraints: &str, // 约束串
) -> Ptr<Operation>
底层调 InlineAsmOp::new_convergent,等价于 InlineAsmOp::new(...) 加上 inline_asm_convergent = true 的 BoolAttr。
10.1 三种典型调用
有返回值,无输入(本次):
inline_asm_convergent(
ctx, rewriter,
i32_ty.into(),
vec![],
"mov.u32 $0, 42;",
"=r",
)
无返回值(block 屏障):
inline_asm_convergent(
ctx, rewriter,
void_ty.into(),
vec![],
"membar.cta;",
"~{memory}",
)
有输入有输出(从 shared memory 读):
inline_asm_convergent(
ctx, rewriter,
i32_ty.into(),
vec![ptr_value],
"ld.shared.u32 $0, [$1];",
"=r,r",
)
10.2 返回值的处理
inline_asm_convergent 返回 Ptr<Operation>,它不会替调用方处理原 MIR op——必须自己二选一调 replace_operation 或 erase_operation,这就是上面踩的坑。
11. convergent 到底是什么——正确性开关,不是性能开关
这是个特别容易低估的概念。
GPU 上,一个 warp 内的 32 个线程是锁步执行的。LLVM 优化器(GVN、CSE、loop unswitching、code motion)有时会把指令搬位置——这对普通 CPU 指令没事,但对需要 warp 内所有线程一起执行的指令是灾难。
11.1 典型的 convergent 指令
bar.sync(block 内同步屏障)shfl.sync(warp 内寄存器互换)vote.sync、mbarrier.arrivewgmma(Hopper 异步矩阵乘)
11.2 反例:不加 convergent 会死锁
if (tid % 2 == 0) {
__syncthreads(); // 偶数线程进
} else {
__syncthreads(); // 奇数线程进
}
优化器可能(它实际就会)把两个 __syncthreads() 合并提到 if 外面——它觉得反正都是无副作用的”同一条指令”。
结果:偶数线程的 bar 等不到奇数线程的 bar,整个 block 死锁。
打上 convergent 属性后,LLVM 保证:“不准搬,不准跨控制流合并,不准跨控制流复制。“
11.3 关键认知
convergent 不是性能开关,是正确性开关。
凡是依赖 warp 同步语义的 PTX(屏障、shuffle、vote、async commit/wait),lowering 时必须带 convergent。否则你的 kernel 会偶发死锁或算错——而且因为优化器是非确定性的,bug 可能调试三天也复现不出来。
我们这次 mov.u32 $0, 42 严格说不需要 convergent(它跟 warp 同步无关),但用 inline_asm_convergent 包装是保守且无害的选择。
12. 端到端验证
加上 xxx() 之后,在 example 里调一下:
#[kernel]
pub unsafe fn hello_constant(out: *mut i32) {
let xxx = thread::xxx();
gpu_printf!("thread xxx: {}", xxx);
unsafe { *out = 42 };
}
跑:
cargo oxide run hello_constant
输出会看到 256 个线程各打印一次 thread xxx: 42,因为 for_num_elems(1) 实际启动 256 个线程(for_num_elems(n) 是元素数不是线程数,见 host runtime 那篇 里的陷阱说明)。
13. 整条链路串起来
源码 let xxx = thread::xxx();
│
▼ rustc 桩函数留作 Call,不 inline
MIR Call cuda_device::thread::xxx
│
▼ mir-importer (try_dispatch_intrinsic)
dialect-nvvm ReadPtxXXXOp Verify 检查 1 result × i32
│
▼ mir-lower (convert_ptx_xxx_i32)
dialect-llvm llvm.inline_asm "mov.u32 $0, 42;" "=r" + convergent
│
▼ llvm-export
LLVM IR call i32 asm "mov.u32 $0, 42;", "=r"()
│
▼ llc / NVPTX
PTX mov.u32 %r1, 42;
│
▼ CUDA driver JIT + launch
GPU 每个线程拿到 42
14. 复盘要点
- 桩函数必须
#[inline(never)],否则 rustc 把它 inline 掉,MIR 里看不到 Call,dispatcher 没机会识别 - inline_asm 约束串里没有
$0,只有=r这种——$0是模板占位符,约束串用=r/l/n/~{memory}标类型 - lowering 替换原 op 时:有 use 用
replace_operation,无 use 用erase_operation——pliron 在Operation::erase里 assert,错了直接 panic - convergent 不是性能开关,是正确性开关——GPU 同步语义的 PTX 一定要带
15. 一句话总结
在 cuda-oxide 里加一个自定义 PTX intrinsic 要动五个 crate:cuda-device 写桩函数 + dialect-nvvm 定义 op + mir-importer 注册 FQDN dispatch + mir-lower 写 inline asm lowering + llvm-export 通常自动处理。两个高频踩坑是约束串语法(
=$0不合法,应该=r)和 op 替换 API(有下游 use 必须用replace_operation而不是erase_operation)。convergent 属性是正确性而不是性能开关——凡是 warp 同步语义的 PTX 都必须带,否则可能死锁。
系列上一篇: cuda-oxide:同一段代码在七层 IR 里长什么样