从 Rust 写 inline PTX 模板

会读 PTX 之后,下一步是从 Rust 里写自己的 PTX——这是实现自定义 GPU intrinsic 的必备技能。本文讲清楚 inline asm 的'模板 + 约束串'两段对应关系、临时寄存器作用域、谓词输出 selp 套路、format! 动态生成的陷阱,以及一个'读 shared memory 比对'的完整合成例子。

📚 系列 compiler · 第 7 篇

上一篇讲了怎么读 PTX。读懂之后下一步是自己写 PTX——通过 inline assembly(内联汇编)把一段 PTX 字符串嵌进 Rust 代码,实现自定义 GPU intrinsic。本文从最简单的 mov.u32 $0, 42; 开始,逐步加复杂度。

0. 几个名词先说清楚

缩写 / 术语英文全称中文含义
inline asminline assembly内联汇编在高级语言里直接嵌一段汇编字符串,编译器照搬到输出
clobberclobber list破坏列表内联汇编告诉编译器”我会破坏这些寄存器 / 内存”,别假设它们没变
selpselect if predicate按谓词选择PTX 指令,根据谓词在两个值里选一个
constraintconstraint string约束串告诉编译器每个操作数槽位是输出 / 输入、放哪类寄存器
convergentconvergent收敛性LLVM 属性,标记”这条指令需要同 warp 所有线程协同执行,优化器别拆”
predicate registerpredicate register谓词寄存器PTX 里类型为 .pred 的 1-bit 寄存器,用于条件执行

1. 为什么需要从 Rust 写 PTX

LLVM 自带的 NVVM intrinsic 覆盖大部分常用 GPU 操作(thread index、内存屏障、原子操作等),但总有覆盖不到的:

  • 新硬件指令(Hopper TMA、Blackwell tcgen05 等)LLVM 还没跟上
  • 特定优化形态(用一条 selp 代替 if-else 分支)
  • 跟 cuda-oxide 这类自定义后端集成,把 Rust 函数映射到精确的 PTX 序列

这时候从 Rust 写 inline PTX 模板是唯一办法——编译器自己不会发的指令,你用字符串塞进去。

2. 最简单的一段 inline PTX

cuda-oxide 里写一个返回常量 42 的自定义 intrinsic,核心调用是:

inline_asm_convergent(
    ctx, rewriter,
    i32_ty.into(),              // 返回 i32
    vec![],                     // 无输入
    "mov.u32 $0, 42;",          // ← PTX 字符串模板
    "=r",                       // ← 约束串
)

这段 inline asm 会被 LLVM IR 原封不动嵌进生成的 .ll 文件:

%v = call i32 asm sideeffect "mov.u32 $0, 42;", "=r"() #convergent

然后 NVPTX backend 把它塞进 PTX:

// begin inline asm
mov.u32 %r1, 42;
// end inline asm

记住三个角色:

角色例子
模板字符串"mov.u32 $0, 42;"——$0 $1 ... 是占位符
约束串"=r"——告诉编译器每个占位符是输出还是输入、放哪类寄存器
sideeffectLLVM 自动加,表示有副作用,优化器别动

3. 模板基本规则

  • $0 $1 ... 是占位符,按约束串里的出现顺序映射到操作数
  • 多条 PTX 指令用 ; 分隔,全写在同一个字符串里
  • 临时寄存器作用域用 { ... } 包起来
  • Rust 字符串可以用 \ 续行;如果用 format! 动态生成,PTX 里的 {} 要写成 {{}}

4. 多指令串联

把多条 PTX 用 ; 接起来。一个真实例子(简化):

let asm_template =
    "{ .reg .pred p; mbarrier.test_wait.shared.b64 p, [$1], $2; selp.b32 $0, 1, 0, p; }";

拆开看是三条 PTX:

{ .reg .pred p;                                  // 局部声明 1-bit 谓词寄存器 p
  mbarrier.test_wait.shared.b64 p, [$1], $2;     // 主指令,结果写到 p
  selp.b32 $0, 1, 0, p; }                        // p 为真 → $0=1,否则 $0=0

约束串 "=r,l,l,~{memory}":

  • =r$0 输出 32-bit 寄存器
  • l,l$1$2 输入 64-bit 寄存器
  • ~{memory} → clobber 内存

5. 局部作用域 { ... }

PTX 里 { ... }指令块(block scope):

  • 块内声明的临时寄存器只在块里可见,不污染外层
  • 编译器调度时不会把外面的指令塞进块中间

凡是模板里要用自己声明的临时寄存器,就要包一层 {}

6. 临时寄存器声明 .reg

语法:.reg .类型 名字;

PTX 类型含义
.pred1-bit 谓词(布尔)
.b32 / .b64无类型 32 / 64 位
.b128128 位(Hopper 用)
.f32 / .f64浮点
.u32 / .s32有符号 / 无符号整数(带类型语义)

举一个真实例子(Hopper 上拼一个 128-bit 临时值):

let asm_template = format!(
    "{{ .reg .b128 %resp; mov.b128 %resp, {{$1, $2}}; \
     clusterlaunchcontrol.query_cancel.get_first_ctaid::{}.b32.b128 $0, %resp; }}",
    dim
);

注意三个细节:

  • {{ 是 Rust format!{ 转义,实际 PTX 里就是 {
  • 行尾 \Rust 字符串续行,不是 PTX 语法
  • {$1, $2} 是 PTX 的 vector / pair 语法,把两个 32-bit 拼成 64-bit 对

7. 谓词输出 → 整数(selp 模式)

PTX 的 .pred 寄存器不能直接当 i32 返出去,固定套路:

selp.b32 $0, 1, 0, p;
  • selp = select-if-predicate
  • 谓词 p 为真选第一个值 1,为假选 0
  • 写到 32-bit 输出 $0

调用方在 Rust 端再做 trunc_to_i1 截回 i1,变成 bool

8. 占位符 _ — 显式丢弃输出

PTX 某些指令有返回值但你不想要,用 _ 显式丢弃:

let asm_template = "mbarrier.arrive.release.cluster.shared::cluster.b64 _, [$0];";

mbarrier.arrive 本来有 64-bit 返回值,但我们不需要 → 用 _ 丢弃,约束串里也就不用为它分配输出寄存器。

9. 动态生成 — format!

模板里有 Rust 编译期才知道的常量(dim、shape、addrspace 编号、cluster rank),就用 format!

陷阱:PTX 自带的 {} 要写成 {{}} 来逃出 format! 的占位符。

let asm_template = format!(
    "ld.shared.u{}.b{} $0, [$1];",
    width_in_bits, width_in_bits
);

10. 约束串速查

字符类型备注
r32-bit regi32 / u32 / f32(PTX .b32)
l64-bit regi64 / u64 / 指针
h16-bit regi16 / f16
f32-bit float显式 .f32
d64-bit float显式 .f64
n立即数必须 compile-time const
= 前缀输出单独写在前面
+ 前缀输入兼输出(read-modify-write)同一 SSA 值进出
~{memory}clobber memory写内存就要加
~{$reg}clobber 具体寄存器

排列顺序:输出(=)在前 → 输入 → clobbers$0 对应第一个输出,$1 对应第二个输出或第一个输入,依次类推。

11. 合成例子:读 shared memory 并比对

要写一个”读 shared memory,跟期望值比对,返回 0/1”的 intrinsic:

let asm_template = "{
    .reg .b32 v;
    .reg .pred p;
    ld.shared.u32 v, [$1];
    setp.eq.u32 p, v, $2;
    selp.b32 $0, 1, 0, p;
}";

inline_asm_convergent(
    ctx, rewriter,
    i32_ty.into(),              // $0 输出
    vec![ptr, expected],        // $1, $2 输入
    asm_template,
    "=r,l,r,~{memory}",         // 输出 32-bit, 输入 64-bit ptr + 32-bit 值, clobber memory
)

逐句拆:

PTX含义
.reg .b32 v;32-bit 临时寄存器,存读到的值
.reg .pred p;1-bit 谓词,存比较结果
ld.shared.u32 v, [$1];从 shared memory $1 加载 32-bit 到 v
setp.eq.u32 p, v, $2;p = (v == $2)
selp.b32 $0, 1, 0, p;$0 = p ? 1 : 0

12. 多输出例子

PTX 一条指令可以同时写多个寄存器,比如 mul.wide.u32 同时拿 lo / hi 32-bit。约束串里多个 = 依次对应:

let asm_template = "mul.wide.u32 {$0, $1}, $2, $3;";

inline_asm_convergent(
    ctx, rewriter,
    /* tuple/struct type 或者用 InlineAsmCallOp 支持多结果 */,
    vec![a, b],
    asm_template,
    "=r,=r,r,r",
)

注:cuda-oxide 的 InlineAsmOp(单结果)不支持多输出,多输出要用 InlineAsmCallOp——支持 result tuple 和 tied operands(input/output 绑定),适合 wgmma 这种”输入兼输出”的硬核场景。

13. 调试技巧

PTX 写错了 / 约束串对不上模板,错误不会在 mir-lower 阶段冒出来——一般要等到 llc 编译 LLVM IR → PTX 时才报。所以排查顺序:

  1. 先把生成的 .ll(textual LLVM IR)dump 出来肉眼看 call ... asm "..." "..." () 这一行,确认模板和约束都对
  2. cargo oxide run,看 .ptx 文件里那段 inline asm 长成什么样
  3. cuobjdump --dump-sass 看 SASS,可以确认 driver JIT 没拒绝

LLVM IR 里 inline asm 那一行长这样:

%v = call i32 asm sideeffect "mov.u32 $0, 42;", "=r"() #convergent

sideeffect 是 LLVM 给 inline asm 默认加的(除非显式标 pure),表示有副作用,优化器别动。

14. 坑速记(写 inline PTX 时)

解决
模板里 $N 跟约束串对不上数一下:输出 = 在前,然后输入,最后 clobber——按顺序映射 $0 $1 $2 ...
format! 模板里 PTX 的 {} 出现意外PTX 的 {} 要双写 {{}} 转义
PTX 语法错误编译时不报一般要等 llc 阶段才暴露——先 dump LLVM IR 肉眼检查
bar.sync 被优化器搬位置 → 死锁inline asm 要带 convergent 属性,告诉 LLVM”warp 同步语义,别拆”
selp 返回值是 i32 但 Rust 想要 bool调用方加一步 trunc_to_i1,把 32-bit 截到 1-bit
.pred 寄存器想直接返出去不能,必须先 selp 转成 i32 再 truncate

15. 一句话总结

从 Rust 写 inline PTX 模板的核心是两段字符串的对应——模板字符串里 $0 $1 ... 是占位符,约束串里 =r l 是类型指示。临时寄存器必须包在 { ... } 里并用 .reg .类型 名字; 声明。谓词 .pred 不能直接当 i32 返出去,固定走 selp.b32 $0, 1, 0, p; 套路。format! 拼模板时 PTX 的 {} 要双写 {{}} 转义。写错了一般要等 llc 阶段才报,先 dump LLVM IR 肉眼检查最高效。掌握这套,加新硬件指令(TMA、tcgen05 等)、写自定义 GPU intrinsic 都是同一个套路。

系列上一篇: PTX 入门:读懂 NVIDIA GPU 的虚拟汇编

评论区
评论功能即将上线, 敬请期待。