PTX 入门:读懂 NVIDIA GPU 的虚拟汇编

PTX 是 NVIDIA GPU 的虚拟 ISA,介于 LLVM IR 和真正的 GPU 机器码 SASS 之间。它跟 x86 完全不同——指令格式统一、寄存器无限、地址空间显式。本文讲清楚 PTX 指令格式、state space(地址空间)、常用指令族,最后用一段 5 行的 hello_kernel PTX 逐行拆解。

📚 系列 compiler · 第 6 篇

第一次看 PTX 代码很容易被吓到——.global .param 这些前缀、add.s32 这种带类型后缀的指令、%r1 %rd2 这种神秘寄存器命名。但 PTX 其实比 x86 简单得多——指令格式统一、寄存器无限、没有微架构包袱。本文给你一张完整的”读 PTX”地图。下一篇(从 Rust 写 inline PTX 模板)讲怎么自己写 PTX。

0. 几个名词先说清楚

缩写 / 术语英文全称中文含义
ISAInstruction Set Architecture指令集架构一个处理器能执行的所有指令的集合
PTXParallel Thread Execution并行线程执行NVIDIA GPU 的虚拟 ISA,文本格式
SASSStreaming ASSembler流式汇编NVIDIA GPU 的真实机器码,每代 GPU 不同
virtual ISAvirtual ISA虚拟 ISA不直接对应物理硬件,由 driver / JIT 翻译成真实机器码的中间层
state spacestate space状态空间 / 地址空间PTX 概念:每个变量 / 指针属于一个特定的”存储空间”(reg / global / shared / local 等)
predicatepredicate谓词PTX 的 1-bit 布尔类型,用于条件执行
JITJust-In-Time compilation即时编译CUDA driver 在 cuModuleLoad 时把 PTX 编成 SASS

1. PTX 是什么:虚拟 ISA

你写的 CUDA C++ / Rust kernel

        ▼ nvcc / cuda-oxide / Triton 等
hello.ptx (PTX 文本,虚拟 ISA)

        ▼ CUDA driver JIT(在 cuModuleLoad 时)
hello SASS (真实 GPU 机器码)

        ▼ 在 GPU SM 上执行

PTX 类似 Java 字节码——不直接对应任何物理芯片,由 CUDA driver 在加载时编译成当前 GPU 的真实指令(SASS)。

好处:同一份 PTX 可以跑在不同代 GPU 上(向前兼容)。 代价:首次加载有 JIT 开销(几十 ms,有缓存)。

2. PTX 比 x86 简单

PTX 的每条指令都长一个样:

opcode  [.modifier1] [.modifier2] ... .type   dst, src1, src2, ...;

举例 add.s32 %r1, %r2, %r3;:

  • add 告诉你做什么
  • .s32 告诉你类型(signed 32-bit)
  • %r1, %r2, %r3 顺序固定:dst 在前,src 在后

跟 x86 比少了什么:

x86PTX
寄存器数量固定 16 个(RAX 等) + 复杂别名(EAX/AX/AL)无限(%r1 %r2 %r3 ...,JIT 时映射到物理寄存器)
标志位隐式(CF/ZF/SF/OF 等)显式谓词 .pred
指令编码1-15 字节变长,寻址模式爆炸定长 + 单一格式
ABI 包袱32 年历史包袱(实模式、x87、MMX、SSE…)

跟 LLVM IR 比多了什么:

LLVM IRPTX
寄存器SSA value(%v0 %v1,每个值只能定义一次)物理寄存器抽象(%r1 可以反复写)
地址空间addrspace(N) 修饰指针显式状态空间前缀(.global .shared .local)
类型操作隐含类型(add nsw i32)类型后缀显式标在 opcode 上(add.s32)

3. State Spaces:PTX 最重要的概念

PTX 把”存哪里”作为头等公民——每个变量/指针都属于一个特定的状态空间,每条 load / store 都要指明:

State space含义谁能访问速度
.reg寄存器单线程最快
.paramkernel 参数(host 传入)所有线程只读缓存优化
.sharedblock 内共享内存(SRAM)block 内所有线程
.global全局显存(cudaMalloc 来的)所有线程,所有 grid慢,需要合并访问
.local线程本地(物理上在 global DRAM)单线程慢,尽量避免
.const常量内存所有线程只读缓存优化
.tex纹理内存只读,硬件缓存特殊用途

记住这张表就掌握了 PTX 一半的语义.local 大量出现 = 编译器把寄存器溢出到内存了,通常是性能警告。

4. 指令格式的修饰符

opcode  [.modifier1] [.modifier2] ... .type   dst, src1, src2, ...;
修饰符类别例子
state space.global .shared .param
类型.u32 .s32 .f32 .b64
行为修饰.wide(乘法宽化).lo .hi(取低/高半)
同步语义.release .acquire .relaxed
谓词执行@p @!p 前缀(条件执行)

类型后缀字母

字母含义
bbitwise(无符号无类型)
uunsigned int
ssigned int
ffloat
pred谓词(1 bit)

add.b32add.u32 在加法上等价,但 add.s32 在溢出语义上不同(signed overflow 是 undefined behavior)。mul.lo.u32 取低 32 位,mul.wide.u32 输出 64 位。

5. 寄存器命名约定

PTX 寄存器是无限的——编译器随便声明,JIT 时映射到 GPU 真实寄存器(每个 SM 通常 65536 个 32-bit reg)。

.reg .b32 %r<3>;     // 声明 %r1, %r2(<3> 表示 0..3 范围)
.reg .b64 %rd<3>;    // 声明 %rd1, %rd2

命名约定:

前缀位宽全称
%rh8-bithalf-byte
%rs16-bitshort
%r32-bit(default)
%rd64-bitdouble
%f32-bit floatfloat
%fd64-bit floatfloat double

6. 指令速查(认这些覆盖 90%)

6.1 算术

指令含义
add.s32 / .u32 / .f32
sub.s32
mul.lo.s3232-bit 乘,取低 32 位
mul.wide.u3232×32 → 64-bit
mad.lo.s32 d, a, b, cd = a*b + c(乘加融合)
fma.rn.f32浮点乘加(IEEE 单舍入)
div.approx.f32浮点除(近似快速版)

6.2 内存

指令含义
ld.global.b32 r, [addr]全局读
st.global.b32 [addr], r全局写
ld.shared.b32 r, [addr]共享内存读
st.shared.b32 [addr], r共享内存写
ld.param.b32 r, [param]读 kernel 参数
ld.const.b32 r, [addr]读常量内存
cvta.to.global.u64 d, s通用地址 → global
cvta.global.u64 d, sglobal → 通用地址

6.3 比较 + 控制流

setp.eq.u32 p, %r1, %r2;   // p = (%r1 == %r2),p 是 .pred 寄存器
@p bra LABEL;              // 如果 p 为真,跳到 LABEL
@!p bra ELSE_LABEL;        // 如果 p 为假
bra UNCOND_LABEL;          // 无条件跳

setp 的比较谓词:eq ne lt le gt ge(int 用),lt.f32 le.f32 等(float 用)。

6.4 同步

指令含义
bar.sync 0block 屏障(对应 CUDA C++ 的 __syncthreads())
membar.ctablock 范围内存屏障
membar.gldevice 范围内存屏障
membar.syssystem 范围内存屏障
shfl.sync.b32warp 内寄存器互换

6.5 特殊寄存器(thread / block 索引)

mov.u32 %r, %tid.x;     // threadIdx.x
mov.u32 %r, %tid.y;
mov.u32 %r, %ctaid.x;   // blockIdx.x (CTA = Cooperative Thread Array = block)
mov.u32 %r, %ntid.x;    // blockDim.x
mov.u32 %r, %nctaid.x;  // gridDim.x

6.6 类型转换

指令含义
cvt.u32.u64 d, su64 → u32(截断)
cvt.u64.u32 d, su32 → u64(零扩展)
cvt.s32.f32 d, sfloat → int
cvt.f32.s32 d, sint → float
cvt.rn.f32.f64 d, sf64 → f32(最近舍入)
selp.b32 d, a, b, pd = p ? a : b(谓词选择)

7. 实战:5 行 PTX 逐行拆

Rust 源码就一行:*out = ins + 1;,编出来的 PTX 5 条指令 + ret:

ld.param.b32  %r1,  [hello_kernel_param_0];   // 从 param 空间读 ins
ld.param.b64  %rd2, [hello_kernel_param_1];   // 从 param 空间读 out(指针)
cvta.to.global.u64 %rd1, %rd2;                // 通用指针 → global 指针
add.s32       %r2,  %r1, 1;                   // %r2 = %r1 + 1
st.global.b32 [%rd1], %r2;                    // 写 *out
ret;

逐条拆

ld.param.b32 %r1, [hello_kernel_param_0]

字段含义
ldload
.param从 parameter space 读
.b3232-bit
%r1目的寄存器
[...]源地址(中括号 = 解引用)

ld.param.b64 %rd2, [hello_kernel_param_1] —— 同理,但 64-bit 指针。

cvta.to.global.u64 %rd1, %rd2

cvta = ConVerT Address。.to.global 表示地址空间转换:host 传进来的指针是通用地址(generic),要先转成 .global 才能用 st.global 写。Pascal+ JIT 通常会优化掉这条,但 PTX 文本必须显式写。

add.s32 %r2, %r1, 1

%r2 = %r1 + 1,操作数顺序 dst, src1, src2.s32 = signed 32-bit。

st.global.b32 [%rd1], %r2

字段含义
ststore
.global写到 global memory
.b3232-bit
[%rd1]目的地址
%r2要写的值

ret 函数返回。

8. 学习节奏建议

  1. 先把第 7 节这 5 行 PTX 能默写出来 —— 算是 PTX 的”hello world”
  2. 改源码,看 PTX 怎么变 —— 把 ins + 1 改成 ins * 2,看 add.s32 变成 mul.lo.s32
  3. 加一个 threadIdx.x,看 mov.u32 %r, %tid.x 怎么出现的
  4. 改成读 shared memory(声明一个 __shared__ 数组),认 .shared 这个 state space
  5. 会读 PTX 之后,开始学怎么写自己的 PTX——见下一篇 从 Rust 写 inline PTX 模板

每一步只引入 1-2 个新指令,写下来比读十倍文档管用

9. 进阶资源

10. 坑速记(读 PTX 时)

解决
.local 出现一大堆 store/load栈溢出到内存了,看是不是寄存器用太多 / 取了栈变量地址
cvta.to.global 一直去不掉Pascal 之前 SASS JIT 不优化,Pascal+(sm_60+)才会去掉
selp.b32 ... p 这种模式频繁出现这是把 .pred 转 i32 的固定套路
mul.lomul.wide 搞混.lo 同位宽,.wide 宽化(32×32→64)
大量带 @p @!p 前缀的指令谓词执行,相当于”如果 p 为真才执行这条”

11. 一句话总结

PTX 是 NVIDIA GPU 的虚拟 ISA,文本格式,由 CUDA driver 在加载时 JIT 编成真正的 GPU 机器码(SASS)。它比 x86 简单——指令格式统一(opcode.type dst, src1, src2)、寄存器无限、没有微架构包袱;比 LLVM IR 多了状态空间(.global / .shared / .param 等)和显式类型后缀。state space 是最重要的概念,记住七个空间就掌握了一半语义。从一段 5 行的 hello_kernel PTX 开始,改源码看输出变化,比读文档管用十倍。

系列上一篇: mem2reg:从 alloca-load-store 回到 SSA

评论区
评论功能即将上线, 敬请期待。