零侵入、极简适配!飞桨CINN实现类CUDA硬件“即插即用”

张开发
2026/4/4 10:22:52 15 分钟阅读
零侵入、极简适配!飞桨CINN实现类CUDA硬件“即插即用”
简介继飞桨框架3.1版本推出“插件式 CUDA兼容类硬件接入方案”飞桨实现插件式硬件图接入方案模型推理加速2.2倍实现运行时Runtime与算子Kernel的高效复用后飞桨成功打通了PaddleCustomDevice仓硬件使用神经网络编译器CINNCompiler Infrastructure for Neural Networks的链路。新方案通过C接口隔离与动态链接机制实现了CINN编译器与硬件底层编译工具链的彻底解耦。CUDA兼容类硬件如天数、沐曦等仅需实现标准化接口即可“即插即用”地接入CINN充分享受飞桨框架编译器带来的图优化与算子自动生成等核心技术红利即彻底告别手写底层代码不仅实现模型端到端训练提速27.4%更在科学计算的微分方程求解场景下取得了比PyTorch快115%的压倒性优势统一多场景自动编译加速——支持动态shape场景一套架构搞定训推需求。飞桨多硬件统一适配方案01整体架构概览分层解耦动态挂载本方案在架构设计上延续了飞桨CustomDevice插件式接入的核心思想将整个链路严格划分为“CinnCustomDevicePlugin 框架接口层”、“C_CinnInterface C接口桥接层”、与“C_DeviceInterface厂商插件层”。CINN编译器实现CUDA兼容类硬件插件式接入方案02核心实现框架与插件的“双向奔赴”为了实现上述架构飞桨核心仓与PaddleCustomDevice硬件插件仓进行了清晰的分工与接口对齐。01框架侧Paddle Repo定义标准与抽象封装框架侧通过抽象接口定义了标准的设备交互规范消除了对任何特定硬件后端的依赖。核心接口定义(C_CinnInterface)定义了三大类跨语言C接口Compiler Toolchain编译工具链包含compile()用于调用外部编译器get_runtime_source()获取设备端运行时源码。Runtime Strategy运行时策略包含module_load/unload()用于加载和卸载编译后的模块get_kernel_address()寻址函数以及launch_kernel()启动执行。Compile Strategy编译优化提供apply_custom_pass()允许厂商应用自定义的IR中间表示优化Pass。C抽象接口与桥接层在custom_device_backend_api.h中定义了 CustomCompilerToolchain 和 CustomRuntimeStrategy 等高层抽象类。通过 CinnCustomDevicePlugin 获取插件实例并在桥接层将框架的 C 调用安全地转换为对 C 接口的调用。02插件侧PaddleCustomDevice Repo动态注册与能力映射硬件厂商以类CUDA硬件为例只需在自定义插件仓中实现对应的逻辑并通过动态库的形式供飞桨加载。插件入口动态注册在InitPlugin()中不仅填充设备管理、显存、通信等基础接口还通过调用InitCinnInterface()将底层实现挂载到C_CinnInterface上。编译器与运行时实现编译阶段通过实现CustomDeviceCompile调用厂商自研编译器如天数的IXCC,沐曦的MXCC将源码编译为二进制格式。同时注入设备端需要的Runtime Source如特定的math函数或reduce原语。运行阶段完美兼容CUDA Driver API范式。例如CustomDeviceModuleLoad底层直接映射为加载编译模块类似cuModuleLoadCustomDeviceLaunchKernel映射为执行类似cuLaunchKernel。03一览全景模型执行的调用链路当用户模型在支持CINN的环境下运行时一条清晰的调用链路被瞬间激活数据与指令在框架与硬件之间高效流转01编译期(CINN Compiler)首先是编译期的深度优化与底层代码生成。飞桨框架的前端在完成基础的图优化后会首先进入由硬件主导的深度定制阶段。此时框架会调用compile_strategy-ApplyCustomPass触发名为apply_custom_pass的C接口。通过该接口系统动态加载并执行硬件厂商自定义的IR Pass中间表示优化与Schedule调度编排策略充分针对特定硬件的寄存器与访存层级进行极致的算力压榨。在经历层层定制优化后系统进入Codegen代码生成环节将优化后的中间表示转化为结构化的底层代码。最后框架获取对应的编译工具链并调用Compile方法进而触发compile C接口直接唤醒硬件厂商自研的底层编译器例如天数的IXCC和沐曦的MXCC由它负责将这些代码最终翻译成目标硬件能懂的高效二进制机器码。02加载期紧接着进入加载期。机器码生成后框架的运行时策略模块会主动调用LoadModule方法这会向下触发module_load这一C接口。该接口的作用是指示底层驱动将刚刚编译好的二进制模块妥善载入设备的显存空间其底层行为完全对标并兼容CUDA中的cuModuleLoad操作。03执行期(Engine)最后是真正的执行期。当模型运行到具体的计算节点时飞桨的执行器会先从加载好的模块中精准寻址提取出对应的函数指针。随后框架调用LaunchKernel方法该调用最终被映射到底层的launch_kernelC接口上在物理硬件层面上真正拉起计算内核的执行如同执行cuLaunchKernel。04适配成本与代码量对比基于本方案对于CINN编译器的硬件适配研发门槛和代码开发量实现了显著降低。以往采用内置侵入式方案需要直接修改飞桨核心代码仓初次跑通往往需要在核心框架内硬编码数千行C代码。而目前沐曦、天数智芯等厂商已采用全新插件式接入方案硬件后端与飞桨主仓解耦。厂商仅需利用自有编译器如天数智芯的IXCC、沐曦的MXCC在PaddleCustomDevice独立仓中实现数十行纯C接口的映射挂载和几百行对应实现即可沿用CINN的基础Schedule、Pass和Codegen逻辑激活CINN基本访存密集型算子融合能力。相比之下该方案能极大降低编译器底层的技术壁垒大幅加速新硬件生态落地飞桨。在真实的工程落地中底层编译适配工作实现了从“月级”到“周级”的跨越。如今以天数智芯和沐曦为例无需侵入式修改Paddle仓仅需不到一周即可完成基本接入。以沐曦为例/work/PaddleCustomDevice/backends/metax_gpu/cinn/cinn_interface.cc中#include cinn_interface.hnamespace paddle {namespace custom_device {namespace metax {// // 外部函数声明 (External Function Declarations)// 这些函数需要在对应的子目录文件中实现 (.cc)// // --- 来自 compiler/compiler.cc ---// 负责调用 mxcc 将 CINN 生成的源代码编译为二进制extern C_Status MetaxCompile(void* dev_ptr, const char* code, char* out_path, size_t len);// 负责提供沐曦 GPU 运行时的基础源码 (类似 cuda_device_runtime.cu)extern const char* MetaxGetRuntimeSource(void* dev_ptr);// --- 来自 runtime/cinn_runtime.cc ---// 负责加载编译好的二进制模块 (.mx / .so)extern C_Status MetaxModuleLoad(void* dev_ptr, const char* path, void** mod_out);// 负责卸载模块extern C_Status MetaxModuleUnload(void* dev_ptr, void* module_handle);// 负责从模块中查找核函数地址extern C_Status MetaxGetKernelAddress(void* dev_ptr, void* module_handle, const char* func_name, void** func_out);// 负责启动核函数 (Launch Kernel)extern C_Status MetaxLaunchKernel(void* dev_ptr, void* func_ptr, void** args, int num_args, int gx, int gy, int gz, int bx, int by, int bz, int shm, void* stream);// --- 来自 passes/pass_manager.cc ---// 负责应用自定义的图优化 Pass由框架在遇到未知 Pass 名时回调extern C_Status MetaxApplyCustomPass(void* dev_ptr, const char* pass_name, void* ir_module);// 返回 MetaX GPU 的有序 Pass 执行列表extern C_Status MetaxQueryPassPipeline(void* dev_ptr, char pass_names[][128], int* count);// // 接口初始化实现 (Interface Initialization)// static C_CinnInterface metax_cinn_impl;void InitCinnInterface(C_DeviceInterface* device_interface) { // 1. 安全起见先清零 std::memset(metax_cinn_impl, 0, sizeof(C_CinnInterface)); // 2. 设置结构体大小 (用于版本校验) metax_cinn_impl.size sizeof(C_CinnInterface); // 3. 设置上下文指针 (可选) // 如果你的实现需要全局状态可以指向一个结构体否则设为 nullptr metax_cinn_impl.dev_ptr nullptr; // 4. 挂载 Compiler Toolchain 接口 metax_cinn_impl.compile MetaxCompile; metax_cinn_impl.get_runtime_source MetaxGetRuntimeSource; // 5. 挂载 Runtime Strategy 接口 metax_cinn_impl.module_load MetaxModuleLoad; metax_cinn_impl.module_unload MetaxModuleUnload; metax_cinn_impl.get_kernel_address MetaxGetKernelAddress; metax_cinn_impl.launch_kernel MetaxLaunchKernel; // 6. 挂载 Compile Strategy 接口 metax_cinn_impl.apply_custom_pass MetaxApplyCustomPass; metax_cinn_impl.query_pass_pipeline MetaxQueryPassPipeline; // 7. 将填好的表挂载到 Paddle 主设备接口上 if (device_interface) { device_interface-cinn_interface metax_cinn_impl; VLOG(3) [MetaX] CINN Interface initialized successfully.; } else { std::cerr [MetaX] Error: device_interface is null during CINN init. std::endl; }}} // namespace metax} // namespace custom_device} // namespace paddle通过插件式方案硬件厂商可以“免费”继承飞桨沉淀多年的前沿技术——无论是复杂的计算图优化、访存密集型算子融合还是未来针对大模型生成的动态Shape优化。彻底省去了深入理解飞桨CINN前后端复杂IR机制的工作只需专注打磨自身的闭源编译器引擎。飞桨将持续深化这套异构计算底座让每一片国产芯片都能以极低的门槛接入顶级AI生态共同释放大模型时代的澎湃算力总结CINN编译器插件式接入方案的落地标志着飞桨在多硬件适配的深度与广度上迈出了坚实的一步。这种设计不仅解耦了Paddle核心与厂商底层的实现细节更让国产硬件生态能够以极低的门槛接入飞桨编译器的核心能力体系共同为开发者带来极致的性能体验。未来飞桨将继续携手广大硬件生态伙伴基于此方案探索更深度的动态Shape优化与算子定制为开发者提供更高效、更灵活的AI基础设施。关注【飞桨PaddlePaddle】公众号获取更多技术内容~

更多文章