cuda-oxide:hello-constant 拆解 08——mir-lower:dialect conversion driver

拆 hello-constant 系列第八站。dialect-mir + dialect-nvvm 的混合 IR 整体降级到 dialect-llvm,靠的是 pliron 的 DialectConversion 框架——你声明'我能转换什么 op、什么类型、怎么转换',driver 自动遍历 module。本文讲清楚 trait 分发机制、跟 Step 7 字符串 match 的对比、replace_operation 的关键作用、full conversion 模式,以及降级完之后 IR 真实长什么样。

📚 系列 cuda-oxide · 第 13 篇

上一篇(块尾控制流 + intrinsic dispatch)完成了 dialect-mir + dialect-nvvm 的混合 IR。这一篇看 mir-lower 把它整体降级到 dialect-llvm——一个全新方言,跟 LLVM IR 同构。重点是 pliron 的 DialectConversion 框架怎么用、为什么这里改用 trait 分发而不是字符串 match,以及降级完之后 IR 真实长什么样。

0. 几个名词先说清楚

缩写 / 术语英文全称中文含义
dialectdialect(方言)方言MLIR / pliron 概念,一组相关 op + type + attribute 的集合
loweringlowering(降级)降级 / 下沉把高层 IR 翻译成低层 IR 的过程,信息逐步丢失但更接近硬件
DialectConversionDialectConversion(方言转换)方言转换框架pliron 提供的 driver,你声明转换规则,它负责遍历 IR 并应用
op interfaceop interface(操作接口)op 接口pliron 概念,类似 Rust trait——一组 op 可以共同实现的方法
op_castop cast(操作类型转换)op 类型转换把一个 op 当成某个 interface 的实现来调用,类似 dyn Trait
signlesssignless(无符号性)无符号性LLVM IR 风格:整数类型没有”有符号 / 无符号”标签,有符号性由操作决定
partial / full conversionpartial / full conversion部分 / 完全转换框架两种模式:partial 允许部分 op 没转换,full 要求全部转完

1. 输入和输出

输入:dialect-mir + dialect-nvvm 混合的 module


输出:dialect-llvm 唯一方言的 module

具体看一对例子(后面会展开):

                                       
输入(dialect-mir / dialect-nvvm 混合):       输出(dialect-llvm):
  v11 = nvvm.read_ptx_xxx ()           ──►   v49 = llvm.inline_asm
  mir.store (v2, v11)                  ──►   llvm.store *v2 <- v49
  mir.goto ^bb1                        ──►   llvm.br ^bb1

module 的结构不变(block 划分、控制流、参数布局都保留),只是每个具体 op 换成了 dialect-llvm 里对应的形态。

2. pliron 的 DialectConversion 框架

mir-lower 没写 for op in module.walk() { ... } 这种循环。而是实现 pliron 提供的 DialectConversion trait,把 module 交给框架的 driver 自动遍历:

pub fn lower_mir_to_llvm(ctx: &mut Context, module_op: Ptr<Operation>) -> Result<()> {
    let mut conversion = MirToLlvmConversionDriver {
        shared_globals: HashMap::new(),
        device_globals: HashMap::new(),
        dynamic_smem_alignments: HashMap::new(),
    };
    apply_dialect_conversion(ctx, &mut conversion, module_op)   // ← driver 接管遍历
}

DialectConversion trait 三个核心方法:

impl DialectConversion for MirToLlvmConversionDriver {
    fn can_convert_op(&self, ctx, op) -> bool {
        is_mir_or_nvvm_op(ctx, op)                  // ① 我能转换哪些 op
    }
    
    fn can_convert_type(&self, ctx, ty) -> bool {
        is_signed_or_unsigned_int(ty) || ...        // ② 我能转换哪些类型
    }
    
    fn convert_type(&mut self, ctx, ty) -> Result<Ptr<TypeObj>> {
        convert_type(ctx, ty)                       // ③ 类型怎么转
    }
    
    fn rewrite(&mut self, ctx, rewriter, op, operands_info) -> Result<()> {
        // ④ op 怎么转
        let opid = Operation::get_opid(op, ctx);
        // ... 分发到具体的 convert_* 函数
    }
}

apply_dialect_conversion driver 干的事:

  1. 遍历 module 里的所有 op
  2. 对每个 op 调 can_convert_op 决定要不要处理
  3. 处理时调 rewrite,我们生成新 op 替换旧的
  4. 全部转换完后清理:所有 dialect-mir / dialect-nvvm 的 op 应该都不剩了

这是 pliron 跟 MLIR 共享的设计——“dialect conversion pass” 抽象。

3. rewrite 的双层分发

rewrite 方法用了两层分发机制:

fn rewrite(&mut self, ctx, rewriter, op, operands_info) -> Result<()> {
    let opid = Operation::get_opid(op, ctx);
    
    // ── 第一层:特殊 case(需要 driver 状态的 op) ──
    if opid == MirFuncOp::get_opid_static() {
        return convert_func(ctx, rewriter, op, ..., &mut self.shared_globals, ...);
    }
    if opid == MirSharedAllocOp::get_opid_static() {
        return convert_shared_alloc_dc(ctx, rewriter, op, ..., &mut self.shared_globals);
    }
    // ... 几个其它特殊 case ...
    
    // ── 第二层:通用 op_cast 分发 ──
    let op_obj = Operation::get_op_dyn(op, ctx);
    let Some(converter) = op_cast::<dyn MirToLlvmConversion>(op_obj.as_ref()) else {
        return error!("Unsupported MIR/NVVM op for lowering");
    };
    converter.convert(ctx, rewriter, operands_info)
}

第一层(特殊 case):少数 op 需要 driver 上下文里的可变状态(shared_globals 这种全局表),手动 if 分支处理。

第二层(op_cast 分发):绝大多数 op 通过 MirToLlvmConversion 这个 op interface(操作接口,pliron 的 trait 概念)自己声明 lowering 逻辑。op_cast::<dyn MirToLlvmConversion> 把当前 op 强转成 interface trait object,调它的 convert 方法。

类比 Rust 的 dyn Trait:

Rust 概念pliron 对应
trait MirToLlvmConversion { fn convert(...) }op interface 声明
impl MirToLlvmConversion for MyOp { ... }per-op interface 实现
let trait_obj: &dyn MirToLlvmConversion = &op;op_cast::<dyn MirToLlvmConversion>(op)

好处:新增 op 只需要 impl MirToLlvmConversion for NewOp { ... },不用动 rewrite 函数——开闭原则(对扩展开放、对修改关闭)。

4. MirToLlvmConversion interface 在哪里实现

crates/mir-lower/src/convert/interface_impls.rs 集中放了所有 op 的 interface 实现,一个文件几百个 impl 块:

#[op_interface_impl]
impl MirToLlvmConversion for MirAllocaOp {
    fn convert(&self, ctx, rewriter, operands_info) -> Result<()> {
        // mir.alloca → llvm.alloca
        super::convert::ops::memory::convert_alloca(...)
    }
}

#[op_interface_impl]
impl MirToLlvmConversion for MirStoreOp {
    fn convert(&self, ...) -> Result<()> {
        // mir.store → llvm.store
        super::convert::ops::memory::convert_store(...)
    }
}

#[op_interface_impl]
impl MirToLlvmConversion for ReadPtxSregTidXOp {
    fn convert(&self, ctx, rewriter, operands_info) -> Result<()> {
        // nvvm.read_ptx_sreg_tid_x → llvm.call @llvm.nvvm.read.ptx.sreg.tid.x()
        super::intrinsics::basic::convert_read_tid_x(...)
    }
}

#[op_interface_impl]
impl MirToLlvmConversion for ReadPtxXXXOp {                  // ← 你写的这个
    fn convert(&self, ctx, rewriter, operands_info) -> Result<()> {
        // nvvm.read_ptx_xxx → llvm.inline_asm "mov.u32 $0, 42;" "=r"
        super::intrinsics::basic::convert_ptx_xxx_i32(...)
    }
}

每个 impl 通常只是一行——调一个 convert/intrinsics/*.rs 里的具体 helper 函数。这种分层让 interface_impls.rs 保持”调度表”形态,真正逻辑在专用文件。

5. 两种分发方式的对比

Step 7(importer)那边的 intrinsic dispatch:

match name {       // ⬅ 一个大 match,FQDN 字符串 match
    "cuda_device::xxx" | "cuda_device::thread::xxx" => { ... emit ReadPtxXXXOp ... }
    "cuda_device::threadIdx_x"                       => { ... emit ReadPtxSregTidXOp ... }
    // ...
}

Step 8(mir-lower)这边:

impl MirToLlvmConversion for ReadPtxXXXOp { fn convert(...) { ... } }     // op 类型分发
impl MirToLlvmConversion for ReadPtxSregTidXOp { fn convert(...) { ... } }
// ...

为什么 Step 7 用字符串 match,Step 8 用 trait 分发?

阶段输入分发依据
Step 7(importer)Rust 函数调用字符串 FQDN——函数名是源码层概念,没法挂到 op 上
Step 8(lower)pliron opop 类型——每个 op 都是独立 struct,直接挂 trait impl

Step 7 的输入还没”op 化”,Step 8 已经是 op 化的。到 Step 8 这一层,“每个 op 自己负责怎么 lower”是更自然的设计

6. 类型转换:signed / unsigned → signless

can_convert_type 这一段:

fn can_convert_type(&self, ctx, ty) -> bool {
    if let Some(int_ty) = ty.downcast_ref::<IntegerType>() {
        return int_ty.signedness() != Signedness::Signless;
    }
    type_impls::<dyn MirConvertibleType>(&**ty_ref)
}

为什么要这个转换?LLVM IR 没有”signed / unsigned”概念,只有”i32”、“i64” 等位宽。有符号性是在操作里体现的(add nsw(no signed wrap,无有符号溢出) vs add nuw(no unsigned wrap,无无符号溢出))。

dialect-mir 里 i32u32 是两个不同类型(带 signedness 信息),lower 到 dialect-llvm 时都要变成同一个 signless i32convert_type 就做这个 normalize(规范化):

dialect-mir:    si32(有符号) | ui32(无符号)
                     │              │
                     ▼              ▼
dialect-llvm:        i32 (signless) ← 同一个类型

这跟 LLVM IR 本身的设计哲学一致——把”有符号 / 无符号”信息留给操作而不是类型。

7. 降级完之后真实长什么样

跑一遍 hello_constant,降级完的 dialect-llvm IR 实际输出长这样:

7.1 hello_kernel(简单版,最直观)

llvm.func @hello_kernel: llvm.func <llvm.void (i32, llvm.ptr addrspace(0)) variadic = false>
  [gpu_kernel: "true"] 
{
  ^entry_block7v1(v84: i32, v85: llvm.ptr addrspace(0)):
    llvm.br ^block5v1(v84, v85)

  ^block5v1(v40: i32, v41: llvm.ptr addrspace(0)):
    v86 = llvm.constant <1: i32> : i32
    v87 = llvm.add v40, v86 <{nsw=false,nuw=false}>: i32
    llvm.store *v41 <- v87
    llvm.return 
}

这就是 mem2reg + lowering 的极致结果——*out = ins + 1 干干净净:常量 1、加法、存储、返回。10 个原本的 alloca 槽全消失了

注意:

  • llvm.constant / llvm.add / llvm.store / llvm.return 都是 dialect-llvm 的标准 op
  • <{nsw=false,nuw=false}> 是 add 的属性(nsw / nuw 是 LLVM IR 标准 flag)
  • llvm.ptr addrspace(0) 表示 0 号地址空间(默认空间)的指针——后面 NVPTX backend 会推断成 .global.shared
  • [gpu_kernel: "true"] 这个 attribute 从 dialect-mir 一路保留下来——后面 llvm-export 看这个标记决定输出 .entry 还是 .func

7.2 hello_constant(复杂版,含 inline asm + printf)

hello_constant 降级完的关键片段:

^block2v1(v0: llvm.ptr addrspace(0)):
    v49 = llvm.inline_asm                        ← thread::xxx() 变成内联汇编
    llvm.br ^block3v1()

^block3v1():
    v50 = llvm.zext <nneg=false> v49 to i64       ← u32 零扩展到 i64
    v51 = llvm.undef : llvm.struct<{ i64 }>       ← 创建一个未定义的 struct
    v52 = llvm.insert_value v51[0], v50           ← 把 v50 塞进 struct 第 0 个字段
    ... 9 条 llvm.constant 各加一个字符 ...
    v62 = llvm.undef : llvm.array[9 x i8]
    v63..v71 = llvm.insert_value (拼成字符串 "%lld <> \0")
    v73 = llvm.alloca [llvm.array[9 x i8] x 1] : llvm.ptr   ← 在栈上分配字符串空间
    llvm.store *v73 <- v71                         ← 把字符串写进去
    ...
    v82 = llvm.call @vprintf(v78, v81)             ← 调 vprintf
    llvm.br ^block4v1()

^block4v1():
    v83 = llvm.constant <42: i32>                  ← 常量 42
    llvm.store *v0 <- v83                          ← *out = 42  ★
    llvm.return 

观察:

现象解释
v49 = llvm.inline_asmnvvm.read_ptx_xxx 经你之前自己写的 convert_ptx_xxx_i32 替换
llvm.zext ... to i64Rust 端 xxx as i64 的 cast,从 u32 零扩展(zero extend)到 i64
llvm.undef + llvm.insert_value在栈上构造 struct / array,SSA 风格
llvm.alloca [9 x i8]字符串 buffer 的栈空间(printf 需要传指针)
llvm.call @vprintf(...)nvvm.vprintf 被降级成对外部函数 @vprintf 的调用
*out = 42 三条 opconstant + store + return,跟最早 hello_kernel 一样简洁
addrspace(0) 到处出现所有指针都还在通用地址空间,后面 export 阶段会插 addrspace cast

8. nvvm.read_ptx_xxx → llvm.inline_asm 的完整链路

你之前自己写过 convert_ptx_xxx_i32,再看一下它在 driver 里被调的链路:

apply_dialect_conversion 遍历 module


遇到 v11 = nvvm.read_ptx_xxx

        ├──▶ can_convert_op() = true (nvvm dialect)


rewrite()

        ├──▶ opid != MirFuncOp / MirSharedAllocOp / ...
        │    (不走第一层特殊 case)


op_cast::<dyn MirToLlvmConversion>(op)


ReadPtxXXXOp 的 MirToLlvmConversion::convert


convert_ptx_xxx_i32(ctx, rewriter, op, operands_info)

        ├──▶ 创建 llvm.inline_asm op
        ├──▶ rewriter.replace_operation(ctx, op, asm_op)
        │    (这一步把所有用 v11 的下游 op 重定向到新 op)
        └──▶ 完成

rewriter.replace_operation 是关键——它不只是删 + 插,而是会自动重定向所有使用旧 op 结果的下游 op。这就是为什么之前用 erase_operation 会 panic:“Operation with use(s) being erased”——因为下游还在用结果,必须 replace 而不能 erase。

9. 部分转换 vs 完全转换

pliron 的 driver 支持两种模式:

模式含义
Partial conversion(部分转换)允许部分 op 没转换(dialect-mir 和 dialect-llvm 可共存)
Full conversion(完全转换)所有 mir / nvvm op 必须全部转换完,否则报错

cuda-oxide 用 full conversion,所以 rewrite 里如果遇到没 MirToLlvmConversion impl 的 op,直接报错:

let Some(converter) = op_cast::<dyn MirToLlvmConversion>(op_obj.as_ref()) else {
    return pliron::input_err!(
        loc,
        "Unsupported MIR/NVVM op for lowering: {}",
        Operation::get_opid(op, ctx)
    );
};

这就是你新加一个 NVVM op 时,必须配套加 MirToLlvmConversion impl 的原因(之前踩过这个流程)。漏掉就报这个 “Unsupported MIR/NVVM op” 错。

10. PHASE 7 日志

实际运行时,你会看到这条 info 级日志:

INFO mir_lower: [PHASE 7/9] mir-lower::lower_mir_to_llvm — dialect-mir → dialect-llvm
=== Verifying dialect-llvm module ===
dialect-llvm verification successful ✓

从打印 [PHASE 7/9] 到看到 verification successful ✓,中间就是这一篇讲的全部:apply_dialect_conversion 遍历完 module、所有 op 被替换、verifier 检查每个 dialect-llvm op 都合法。

接下来下一站会进入 llvm-export,把这堆 dialect-llvm op 序列化成文本 .ll 文件,再调 llc 编成 .ptx

11. 一句话总结

mir-lower 实现 pliron 的 DialectConversion trait,让 apply_dialect_conversion driver 自动遍历 module。每个 mir / nvvm op 通过 op_cast::<dyn MirToLlvmConversion>(op 类型分发,类似 Rust 的 dyn Trait)调用自己的 convert 实现,绝大多数翻成对应的 llvm op,nvvm intrinsic 翻成 llvm.call @llvm.nvvm.*llvm.inline_asm。整个 module 的 block 结构和控制流都保留,只是每条 op 换了方言。降级完之后 hello_kernel 简化到 4 行,*out = 42 也只剩 constant + store + return 三条 op——干净到可以直接吐成 LLVM IR 文本。

系列上一篇: cuda-oxide:hello-constant 拆解 07——块尾控制流 + intrinsic dispatch

评论区
评论功能即将上线, 敬请期待。