第一次看 PTX 代码很容易被吓到——
.global.param这些前缀、add.s32这种带类型后缀的指令、%r1 %rd2这种神秘寄存器命名。但 PTX 其实比 x86 简单得多——指令格式统一、寄存器无限、没有微架构包袱。本文给你一张完整的”读 PTX”地图。下一篇(从 Rust 写 inline PTX 模板)讲怎么自己写 PTX。
0. 几个名词先说清楚
| 缩写 / 术语 | 英文全称 | 中文 | 含义 |
|---|---|---|---|
| ISA | Instruction Set Architecture | 指令集架构 | 一个处理器能执行的所有指令的集合 |
| PTX | Parallel Thread Execution | 并行线程执行 | NVIDIA GPU 的虚拟 ISA,文本格式 |
| SASS | Streaming ASSembler | 流式汇编 | NVIDIA GPU 的真实机器码,每代 GPU 不同 |
| virtual ISA | virtual ISA | 虚拟 ISA | 不直接对应物理硬件,由 driver / JIT 翻译成真实机器码的中间层 |
| state space | state space | 状态空间 / 地址空间 | PTX 概念:每个变量 / 指针属于一个特定的”存储空间”(reg / global / shared / local 等) |
| predicate | predicate | 谓词 | PTX 的 1-bit 布尔类型,用于条件执行 |
| JIT | Just-In-Time compilation | 即时编译 | CUDA driver 在 cuModuleLoad 时把 PTX 编成 SASS |
1. PTX 是什么:虚拟 ISA
你写的 CUDA C++ / Rust kernel
│
▼ nvcc / cuda-oxide / Triton 等
hello.ptx (PTX 文本,虚拟 ISA)
│
▼ CUDA driver JIT(在 cuModuleLoad 时)
hello SASS (真实 GPU 机器码)
│
▼ 在 GPU SM 上执行
PTX 类似 Java 字节码——不直接对应任何物理芯片,由 CUDA driver 在加载时编译成当前 GPU 的真实指令(SASS)。
好处:同一份 PTX 可以跑在不同代 GPU 上(向前兼容)。 代价:首次加载有 JIT 开销(几十 ms,有缓存)。
2. PTX 比 x86 简单
PTX 的每条指令都长一个样:
opcode [.modifier1] [.modifier2] ... .type dst, src1, src2, ...;
举例 add.s32 %r1, %r2, %r3;:
add告诉你做什么.s32告诉你类型(signed 32-bit)%r1, %r2, %r3顺序固定:dst 在前,src 在后
跟 x86 比少了什么:
| 项 | x86 | PTX |
|---|---|---|
| 寄存器数量 | 固定 16 个(RAX 等) + 复杂别名(EAX/AX/AL) | 无限(%r1 %r2 %r3 ...,JIT 时映射到物理寄存器) |
| 标志位 | 隐式(CF/ZF/SF/OF 等) | 显式谓词 .pred |
| 指令编码 | 1-15 字节变长,寻址模式爆炸 | 定长 + 单一格式 |
| ABI 包袱 | 32 年历史包袱(实模式、x87、MMX、SSE…) | 无 |
跟 LLVM IR 比多了什么:
| 项 | LLVM IR | PTX |
|---|---|---|
| 寄存器 | SSA value(%v0 %v1,每个值只能定义一次) | 物理寄存器抽象(%r1 可以反复写) |
| 地址空间 | addrspace(N) 修饰指针 | 显式状态空间前缀(.global .shared .local) |
| 类型 | 操作隐含类型(add nsw i32) | 类型后缀显式标在 opcode 上(add.s32) |
3. State Spaces:PTX 最重要的概念
PTX 把”存哪里”作为头等公民——每个变量/指针都属于一个特定的状态空间,每条 load / store 都要指明:
| State space | 含义 | 谁能访问 | 速度 |
|---|---|---|---|
.reg | 寄存器 | 单线程 | 最快 |
.param | kernel 参数(host 传入) | 所有线程只读 | 缓存优化 |
.shared | block 内共享内存(SRAM) | block 内所有线程 | 快 |
.global | 全局显存(cudaMalloc 来的) | 所有线程,所有 grid | 慢,需要合并访问 |
.local | 线程本地(物理上在 global DRAM) | 单线程 | 慢,尽量避免 |
.const | 常量内存 | 所有线程只读 | 缓存优化 |
.tex | 纹理内存 | 只读,硬件缓存 | 特殊用途 |
记住这张表就掌握了 PTX 一半的语义。.local 大量出现 = 编译器把寄存器溢出到内存了,通常是性能警告。
4. 指令格式的修饰符
opcode [.modifier1] [.modifier2] ... .type dst, src1, src2, ...;
| 修饰符类别 | 例子 |
|---|---|
| state space | .global .shared .param |
| 类型 | .u32 .s32 .f32 .b64 |
| 行为修饰 | .wide(乘法宽化).lo .hi(取低/高半) |
| 同步语义 | .release .acquire .relaxed |
| 谓词执行 | @p @!p 前缀(条件执行) |
类型后缀字母
| 字母 | 含义 |
|---|---|
b | bitwise(无符号无类型) |
u | unsigned int |
s | signed int |
f | float |
pred | 谓词(1 bit) |
add.b32 和 add.u32 在加法上等价,但 add.s32 在溢出语义上不同(signed overflow 是 undefined behavior)。mul.lo.u32 取低 32 位,mul.wide.u32 输出 64 位。
5. 寄存器命名约定
PTX 寄存器是无限的——编译器随便声明,JIT 时映射到 GPU 真实寄存器(每个 SM 通常 65536 个 32-bit reg)。
.reg .b32 %r<3>; // 声明 %r1, %r2(<3> 表示 0..3 范围)
.reg .b64 %rd<3>; // 声明 %rd1, %rd2
命名约定:
| 前缀 | 位宽 | 全称 |
|---|---|---|
%rh | 8-bit | half-byte |
%rs | 16-bit | short |
%r | 32-bit | (default) |
%rd | 64-bit | double |
%f | 32-bit float | float |
%fd | 64-bit float | float double |
6. 指令速查(认这些覆盖 90%)
6.1 算术
| 指令 | 含义 |
|---|---|
add.s32 / .u32 / .f32 | 加 |
sub.s32 | 减 |
mul.lo.s32 | 32-bit 乘,取低 32 位 |
mul.wide.u32 | 32×32 → 64-bit |
mad.lo.s32 d, a, b, c | d = a*b + c(乘加融合) |
fma.rn.f32 | 浮点乘加(IEEE 单舍入) |
div.approx.f32 | 浮点除(近似快速版) |
6.2 内存
| 指令 | 含义 |
|---|---|
ld.global.b32 r, [addr] | 全局读 |
st.global.b32 [addr], r | 全局写 |
ld.shared.b32 r, [addr] | 共享内存读 |
st.shared.b32 [addr], r | 共享内存写 |
ld.param.b32 r, [param] | 读 kernel 参数 |
ld.const.b32 r, [addr] | 读常量内存 |
cvta.to.global.u64 d, s | 通用地址 → global |
cvta.global.u64 d, s | global → 通用地址 |
6.3 比较 + 控制流
setp.eq.u32 p, %r1, %r2; // p = (%r1 == %r2),p 是 .pred 寄存器
@p bra LABEL; // 如果 p 为真,跳到 LABEL
@!p bra ELSE_LABEL; // 如果 p 为假
bra UNCOND_LABEL; // 无条件跳
setp 的比较谓词:eq ne lt le gt ge(int 用),lt.f32 le.f32 等(float 用)。
6.4 同步
| 指令 | 含义 |
|---|---|
bar.sync 0 | block 屏障(对应 CUDA C++ 的 __syncthreads()) |
membar.cta | block 范围内存屏障 |
membar.gl | device 范围内存屏障 |
membar.sys | system 范围内存屏障 |
shfl.sync.b32 | warp 内寄存器互换 |
6.5 特殊寄存器(thread / block 索引)
mov.u32 %r, %tid.x; // threadIdx.x
mov.u32 %r, %tid.y;
mov.u32 %r, %ctaid.x; // blockIdx.x (CTA = Cooperative Thread Array = block)
mov.u32 %r, %ntid.x; // blockDim.x
mov.u32 %r, %nctaid.x; // gridDim.x
6.6 类型转换
| 指令 | 含义 |
|---|---|
cvt.u32.u64 d, s | u64 → u32(截断) |
cvt.u64.u32 d, s | u32 → u64(零扩展) |
cvt.s32.f32 d, s | float → int |
cvt.f32.s32 d, s | int → float |
cvt.rn.f32.f64 d, s | f64 → f32(最近舍入) |
selp.b32 d, a, b, p | d = p ? a : b(谓词选择) |
7. 实战:5 行 PTX 逐行拆
Rust 源码就一行:*out = ins + 1;,编出来的 PTX 5 条指令 + ret:
ld.param.b32 %r1, [hello_kernel_param_0]; // 从 param 空间读 ins
ld.param.b64 %rd2, [hello_kernel_param_1]; // 从 param 空间读 out(指针)
cvta.to.global.u64 %rd1, %rd2; // 通用指针 → global 指针
add.s32 %r2, %r1, 1; // %r2 = %r1 + 1
st.global.b32 [%rd1], %r2; // 写 *out
ret;
逐条拆
ld.param.b32 %r1, [hello_kernel_param_0]
| 字段 | 含义 |
|---|---|
ld | load |
.param | 从 parameter space 读 |
.b32 | 32-bit |
%r1 | 目的寄存器 |
[...] | 源地址(中括号 = 解引用) |
ld.param.b64 %rd2, [hello_kernel_param_1] —— 同理,但 64-bit 指针。
cvta.to.global.u64 %rd1, %rd2
cvta = ConVerT Address。.to.global 表示地址空间转换:host 传进来的指针是通用地址(generic),要先转成 .global 才能用 st.global 写。Pascal+ JIT 通常会优化掉这条,但 PTX 文本必须显式写。
add.s32 %r2, %r1, 1
%r2 = %r1 + 1,操作数顺序 dst, src1, src2。.s32 = signed 32-bit。
st.global.b32 [%rd1], %r2
| 字段 | 含义 |
|---|---|
st | store |
.global | 写到 global memory |
.b32 | 32-bit |
[%rd1] | 目的地址 |
%r2 | 要写的值 |
ret 函数返回。
8. 学习节奏建议
- 先把第 7 节这 5 行 PTX 能默写出来 —— 算是 PTX 的”hello world”
- 改源码,看 PTX 怎么变 —— 把
ins + 1改成ins * 2,看add.s32变成mul.lo.s32 - 加一个
threadIdx.x,看mov.u32 %r, %tid.x怎么出现的 - 改成读 shared memory(声明一个
__shared__数组),认.shared这个 state space - 会读 PTX 之后,开始学怎么写自己的 PTX——见下一篇 从 Rust 写 inline PTX 模板
每一步只引入 1-2 个新指令,写下来比读十倍文档管用。
9. 进阶资源
- PTX ISA 官方文档(字典用,不要通读):https://docs.nvidia.com/cuda/parallel-thread-execution/
- Compiler Explorer(godbolt.org):选 nvcc,左边写 CUDA C++,右边看 PTX。改一行看输出,最快反馈循环
- CUDA Binary Utilities(
cuobjdump --dump-sass看真实 SASS):https://docs.nvidia.com/cuda/cuda-binary-utilities/
10. 坑速记(读 PTX 时)
| 坑 | 解决 |
|---|---|
.local 出现一大堆 store/load | 栈溢出到内存了,看是不是寄存器用太多 / 取了栈变量地址 |
cvta.to.global 一直去不掉 | Pascal 之前 SASS JIT 不优化,Pascal+(sm_60+)才会去掉 |
selp.b32 ... p 这种模式频繁出现 | 这是把 .pred 转 i32 的固定套路 |
mul.lo 和 mul.wide 搞混 | .lo 同位宽,.wide 宽化(32×32→64) |
大量带 @p @!p 前缀的指令 | 谓词执行,相当于”如果 p 为真才执行这条” |
11. 一句话总结
PTX 是 NVIDIA GPU 的虚拟 ISA,文本格式,由 CUDA driver 在加载时 JIT 编成真正的 GPU 机器码(SASS)。它比 x86 简单——指令格式统一(
opcode.type dst, src1, src2)、寄存器无限、没有微架构包袱;比 LLVM IR 多了状态空间(.global/.shared/.param等)和显式类型后缀。state space 是最重要的概念,记住七个空间就掌握了一半语义。从一段 5 行的 hello_kernel PTX 开始,改源码看输出变化,比读文档管用十倍。