第一次读 MLIR 资料经常被一大堆方言名字砸晕——
linalg、tosa、affine、scf、memref、vector、arith、llvm等等。“高层方言”、“低层方言”两个词到处出现,但没人正式定义过它们的边界。其实判断标准只有一条:这个方言在表达”什么”,还是表达”怎么做”?
1. 一条判断标准
它离”人怎么想”近,还是离”机器怎么执行”近?
| 维度 | 高层方言 | 低层方言 |
|---|---|---|
| 关注点 | 表达意图(做什么) | 表达机制(怎么做) |
| 数据 | 值语义(tensor,不可变,像数学) | 引用语义(memref,有地址,有读写) |
| 控制流 | 隐式或结构化(整个张量运算) | 显式循环、分支、跳转 |
| 内存 | 没有”内存”概念 | 有 alloc、load、store |
| 粒度 | 粗(一个 op = 矩阵乘) | 细(一个 op = 一次加法) |
| 离硬件 | 远 | 近 |
2. 用同一个矩阵乘串起来:从一行变成几十行
最直观的方式:看同一个语义在不同层次的 MLIR 长什么样。
2.1 最高层(linalg / tosa):一条 op 表达整个矩阵乘
%C = linalg.matmul ins(%A, %B : tensor<128x256xf32>, tensor<256x64xf32>)
outs(%C0 : tensor<128x64xf32>) -> tensor<128x64xf32>
- 用
tensor类型,不可变、无地址(像数学里的矩阵) - 一条指令表达整个矩阵乘的语义
- 完全不关心循环顺序、tiling、向量化、内存布局——那是后端的事
2.2 中层(affine / scf + memref):显式循环 + 内存读写
scf.for %i = %c0 to %c128 step %c1 {
scf.for %j = %c0 to %c64 step %c1 {
scf.for %k = %c0 to %c256 step %c1 {
%a = memref.load %A[%i, %k] : memref<128x256xf32>
%b = memref.load %B[%k, %j] : memref<256x64xf32>
%c = memref.load %C[%i, %j] : memref<128x64xf32>
%mul = arith.mulf %a, %b : f32
%add = arith.addf %c, %mul : f32
memref.store %add, %C[%i, %j] : memref<128x64xf32>
}
}
}
memref是带地址的内存引用,可以 load/store- 循环结构显式写出来了
- 还是结构化的(
scf.for而非裸跳转)
2.3 低层(llvm 方言):接近 LLVM IR
llvm.br ^bb1(%c0 : i64)
^bb1(%i: i64):
%cond = llvm.icmp "slt" %i, %c128 : i64
llvm.cond_br %cond, ^bb2, ^bb3
^bb2:
...裸指针、getelementptr、load、store、跳转...
- 控制流退化成基本块 + 条件跳转
- 内存操作变成裸指针、
getelementptr - 几乎就是 LLVM IR 的 MLIR 封装,离机器码一步之遥
同一个矩阵乘,从一行变成几十上百行——这就是”降级(lowering)“。
3. 常见方言的层次定位
高层(贴近用户/算法)
├── tosa 张量算子集,接近神经网络层面
├── stablehlo XLA 的 IR,机器学习用
├── tf TensorFlow 算子
├── torch PyTorch 算子
├── linalg 结构化的张量操作(matmul、conv、generic)
中层(贴近通用编译)
├── tensor 张量值操作(没有内存)
├── memref 带内存的多维数组
├── affine 多面体模型友好的循环(可分析)
├── scf 结构化控制流(for / if / while)
├── vector SIMD 向量操作
├── arith 普通算术加减乘除
├── math 数学函数(sin、exp、log)
低层(贴近硬件)
├── gpu GPU 编程抽象(kernel、thread、block)
├── nvgpu/nvvm NVIDIA 专属
├── amdgpu/rocdl AMD 专属
├── spirv Vulkan / OpenCL 的 SPIR-V
├── llvm 直接对应 LLVM IR
4. 一眼区分的”信号”
看到一段 MLIR,这些信号能帮你判断它在哪一层。
4.1 “这是高层”的信号
- 用
tensor<...>类型(不可变值语义) - 一条 op 做”很多事”(
linalg.matmul、tosa.conv2d) - 没有显式
for循环 - 没有
load/store
4.2 “这是中层”的信号
- 用
memref<...>类型(有内存地址) - 有
scf.for/affine.for循环 - 有
memref.load/memref.store arith/math这种细粒度操作
4.3 “这是低层”的信号
- 大量
llvm.xxx操作 - 出现
^bb0、^bb1这种裸基本块标签 +llvm.br跳转 llvm.getelementptr、裸指针!llvm.ptr- GPU 方言的
gpu.thread_id、gpu.block_id之类
5. 为什么要分层——Progressive Lowering
MLIR 的杀手锏就是这种渐进式降级:
| 层次 | 最擅长的优化 | 为什么这一层才能做 |
|---|---|---|
| 高层 | 算子融合、layout 选择、常量折叠 | 粒度粗、信息全——能看到”这是个 conv 后面接 ReLU”,才能把它们融合 |
| 中层 | 循环 tiling、向量化、并行化、bufferize | 能看到循环结构——scf.for 显式写出来,才能做循环变换 |
| 低层 | 寄存器分配、指令选择 | 接近硬件——交给 LLVM,各 backend(x86 / GPU / RISC-V)各自玩 |
每一层都在自己最擅长的抽象层次做优化,然后把简化后的形式交给下一层。这就是 MLIR 比传统单层 IR(只有 LLVM IR)强大的地方——信息不会在一开始就被丢掉。
LLVM IR 那种”早期就退化成基本块 + 跳转”的设计,在做高层优化时是束手无策的——它根本不知道有”matmul”这个概念,所有结构都已经被翻译成 add / mul / br 了。MLIR 让你在合适的抽象层次做合适的优化。
6. 类比:pliron / cuda-oxide 的方言栈
pliron 是 Rust 写的 MLIR-like 框架,cuda-oxide 用它实现了类似的多方言栈:
| pliron 方言 | 对应 MLIR 层次 | 作用 |
|---|---|---|
dialect-mir | 中层(类比 scf + memref) | 保留 Rust 语义(move、borrow、drop) |
dialect-nvvm | 低层(类比 nvgpu / nvvm) | NVIDIA GPU 专属 intrinsic |
dialect-llvm | 低层(类比 llvm 方言) | 直接对应 LLVM IR |
整条 lowering 链路:MIR → dialect-mir → dialect-llvm → LLVM IR → PTX。这是个典型的 progressive lowering 例子——每一层降一点,只在合适的层做对应优化(比如 mem2reg 在 dialect-mir 上跑,bank conflict 检测如果做的话会在 dialect-nvvm 层)。
7. 学习建议
如果是从零学 MLIR,先重点掌握中层(scf、memref、arith、affine)。原因:
- 大多数教学例子和 pass 都在中层
- 它跟”传统编译器中间表示”心智模型最接近,Rust 程序员、C++ 程序员都能秒懂
- 理解了中层,往两头扩展最高效:往上看 linalg 怎么 lowering 到 scf,往下看 scf 怎么变成 llvm
直接从 linalg / tosa 入门容易迷失(太抽象,不知道这些 op 真正在算什么),从 llvm 方言入门又看不到 MLIR 的杀手锏(那就是 LLVM IR 而已,何必学 MLIR)。
8. 一句话总结
MLIR 方言的高低不看名字,看 op 表达的抽象层次:用
tensor和粗粒度算子的是高层,用memref和显式循环的是中层,用裸基本块和llvm.*的是低层。MLIR 的核心玩法就是从高层一步步降级到低层,每层做该层最擅长的优化。