【pytorch扩展】CUDA自定义pytorch算子(简单demo入手)

Pytorch作为一款优秀的AI开发平台,提供了完备的自定义算子的规范。我们用torch开发时,经常会因为现有算子的不足限制我们idea的迸发。于是,CUDA/C++自定义pytorch算子是不得不磕了。

今天通过一个小实验来梳理自定义pytorch算子都需要做哪些准备。比如,我们做一个张量加法。
vim test_add.py

from add import sum_double_op
import torch
import time

class Timer:
    def __init__(self, op_name):
        self.begin_time = 0
        self.end_time = 0
        self.op_name = op_name

    def __enter__(self):
        torch.cuda.synchronize()
        self.begin_time = time.time()

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.cuda.synchronize()
        self.end_time = time.time()
        print(f"Average time cost of {self.op_name} is {(self.end_time - self.begin_time) * 1000:.4f} ms")


if __name__ == '__main__':
    n = 1000000
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    tensor1 = torch.ones(n, dtype=torch.float32, device=device, requires_grad=True)
    tensor2 = torch.ones(n, dtype=torch.float32, device=device, requires_grad=True)
    with Timer("sum_double"):
        ans = sum_double_op(tensor1, tensor2)

这里的"sum_double_op"就是我们用CUDA写的算子。那这个可以直接调用,并且可以传递梯度的算子,需要怎么做呢?


众所周知,CUDA/C++都是编译性语言,编译以后再调用会比python这种解释性语言更快。所以,我们需要对CUDA有一个编译过程。这个编译过程用setuptools来实现(可以pip安装)。
先vim setup.py

from setuptools import find_packages, setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

setup(
    name='myAdd',
    packages=find_packages(),
    version='0.1.0',
    author='muzhan',
    ext_modules=[
        CUDAExtension(
            'sum_double',
            ['./add/add.cpp',
             './add/add_cuda.cu',]
        ),
    ],
    cmdclass={
        'build_ext': BuildExtension
    }
)

直接“python setup.py install”即可完成cuda算子的编译和安装。等等,你的add.cpp和add_cuda.cu还没呢?
vim add_cuda.cu

#include <cstdio>
#define THREADS_PER_BLOCK 256
#define WARP_SIZE 32
#define DIVUP(m, n) ((m + n - 1) / n)


__global__ void two_sum_kernel(const float* a, const float* b, float * c, int n){
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < n){
        c[idx] = a[idx] + b[idx];
    }
}


void two_sum_launcher(const float* a, const float* b, float* c, int n){
    dim3 blockSize(DIVUP(n, THREADS_PER_BLOCK));
    dim3 threadSize(THREADS_PER_BLOCK);
    two_sum_kernel<<<blockSize, threadSize>>>(a, b, c, n);
}

vim add.cpp

#include <torch/extension.h>
#include <torch/serialize/tensor.h>

#define CHECK_CUDA(x) \
  TORCH_CHECK(x.type().is_cuda(), #x, " must be a CUDAtensor ")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x, " must be contiguous ")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)


void two_sum_launcher(const float* a, const float* b, float* c, int n);


void two_sum_gpu(at::Tensor a_tensor, at::Tensor b_tensor, at::Tensor c_tensor){
    CHECK_INPUT(a_tensor);
    CHECK_INPUT(b_tensor);
    CHECK_INPUT(c_tensor);

    const float* a = a_tensor.data_ptr<float>();
    const float* b = b_tensor.data_ptr<float>();
    float* c = c_tensor.data_ptr<float>();
    int n = a_tensor.size(0);
    two_sum_launcher(a, b, c, n);
}


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &two_sum_gpu, "sum two arrays (CUDA)");
}

我们看一下文件结构:

.
├── add
│   ├── add.cpp
│   ├── add_cuda.cu
│   ├── __init__.py
│   └── sum.py
├── README.md
├── setup.py
└── test_add.py

有了add.cpp和add_cuda.cu以后,我们就可以用"python setup.py install"来进行编译和安装了。编译和安装以后,我们需要用python类封装一下:

vim __init__.py
from .sum import *

vim sum.py

from torch.autograd import Function
import sum_double


class SumDouble(Function):

    @staticmethod
    def forward(ctx, array1, array2):
        """sum_double function forward.
        Args:
            array1 (torch.Tensor): [n,]
            array2 (torch.Tensor): [n,]
        
        Returns:
            ans (torch.Tensor): [n,]
        """
        array1 = array1.float()
        array2 = array2.float()
        ans = array1.new_zeros(array1.shape)
        sum_double.forward(array1.contiguous(), array2.contiguous(), ans)

        # ctx.mark_non_differentiable(ans) # if the function is no need for backpropogation

        return ans

    @staticmethod
    def backward(ctx, g_out):
        # return None, None   # if the function is no need for backpropogation

        g_in1 = g_out.clone()
        g_in2 = g_out.clone()
        return g_in1, g_in2


sum_double_op = SumDouble.apply

最后,直接

python test_add.py

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/772276.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

SAPUI5基础知识10 - i18与国际化

1. 背景 i18n 是 “internationalization” 的缩写&#xff0c;其中的 18 是 “internationalization” 这个单词中间的字符数。i18n 是一种让应用程序支持多种语言的方法&#xff0c;也就是我们通常所说的国际化。 在SAPUI5中&#xff0c;i18n主要通过使用资源模型&#xff…

Matplotlib 文本

可以使用 xlabel、ylabel、text向图中添加文本 mu, sigma 100, 15 x mu sigma * np.random.randn(10000)# the histogram of the data n, bins, patches plt.hist(x, 50, densityTrue, facecolorg, alpha0.75)plt.xlabel(Smarts) plt.ylabel(Probability) plt.title(Histo…

【️讲解下Laravel为什么会成为最优雅的PHP框架?】

&#x1f3a5;博主&#xff1a;程序员不想YY啊 &#x1f4ab;CSDN优质创作者&#xff0c;CSDN实力新星&#xff0c;CSDN博客专家 &#x1f917;点赞&#x1f388;收藏⭐再看&#x1f4ab;养成习惯 ✨希望本文对您有所裨益&#xff0c;如有不足之处&#xff0c;欢迎在评论区提出…

540. 有序数组中的单一元素(中等)

540. 有序数组中的单一元素 1. 题目描述2.详细题解3.代码实现3.1 Python3.2 Java 1. 题目描述 题目中转&#xff1a;540. 有序数组中的单一元素 2.详细题解 方法一&#xff1a;若不限定时间复杂度&#xff0c;则扫描遍历一遍即可找到仅出现一次的数&#xff0c;具体实现见Pyth…

Maven Archetype 自定义项目模板:高效开发的最佳实践

文章目录 前言一、Maven Archetype二、创建自定义 Maven Archetype三、定制 Archetype 模板四、手动创建 Archetype 模板项目五、FAQ5.1 如何删除自定义的模板5.2 是否可以在模板中使用空文件夹 六、小结推荐阅读 前言 在软件开发中&#xff0c;标准化和快速初始化项目结构能够…

什么是JSON ,ajax和json关系

一. JSON 1 JSON概述 JavaScript对象文本表示形式&#xff08;JavaScript Object Notation : js对象简写) json是js对象 json是目前 前后端数据交互的主要格式之一 * java对象表示形式User user new User();user.setUsername("后羿");user.setAge(23);user.setSex…

开发国际短剧系统的策略解析

一、明确项目目标和需求 1、功能需求&#xff1a;确定系统应具备的基本功能&#xff0c;如用户注册、登录、浏览短剧、评论、分享、个性化推荐等。 2、性能需求&#xff1a;确保系统能够承受高并发访问&#xff0c;保证视频流畅播放&#xff0c;减少卡顿和延迟。 3、跨文化传播…

中序遍历的两种实现——二叉树专题复习

递归实现&#xff1a; /*** Definition for a binary tree node.* public class TreeNode {* int val;* TreeNode left;* TreeNode right;* TreeNode() {}* TreeNode(int val) { this.val val; }* TreeNode(int val, TreeNode left, TreeNode right)…

【算法】(C语言):堆排序

堆&#xff08;二叉树的应用&#xff09;&#xff1a; 完全二叉树。最大堆&#xff1a;每个节点比子树所有节点的数值都大&#xff0c;根节点是最大值。父子索引号关系&#xff08;根节点为0&#xff09;&#xff1a;&#xff08;向上&#xff09;子节点x&#xff0c;父节点(x…

命令行升级ubuntu版本过程中出现的grub问题 解决

1、问题描述 使用命令行升级ubuntu18到20版本后&#xff0c;系统提示重启&#xff0c;使用reboot命令重启后&#xff0c;不显示服务器ip&#xff0c;或是显示但无法ssh远程连接服务器了&#xff0c;使用屏幕连接服务器后发现出现grub问题。 2、问题经过 命令行输入如下升级u…

【虚拟机】虚拟机网络无法访问问题【已解决】

【虚拟机】虚拟机无法上网问题【已解决】 问题探究解决方法法1&#xff1a;查看相关“网络服务”是否处于正常启动状态法2&#xff1a;重启网络法3&#xff1a;重新安装VMWare法4&#xff1a;使用NAT模式&#xff0c;每次打开win7都没连上网的解决办法 问题探究 安装了很多个虚…

Objection 对命令的批量操作

假定现在需要对好多不同的类进行批量hook&#xff0c;逐个hook非常繁琐&#xff0c;那么可以要将这些hook的类放到一个文件里&#xff0c;并且在这些类的前面加上hook命令&#xff0c;内容如下 使用如下命令执行该文件中的命令 objection -g 测试 explore -c d:/hookData/toHoo…

如何从腾讯云迁移到AWS

随着跨境出海潮不断扩大&#xff0c;企业越来越意识到将工作负载迁移到海外节点的必要性&#xff0c;以获取更多功能、灵活性和性能。然而&#xff0c;顺利迁移业务主机并确保业务稳定访问是一项具有挑战性的任务。在此挑战中&#xff0c;借助AWS迁移工具和迁移流程的强大支持&…

docker 安装 禅道

docker pull hub.zentao.net/app/zentao:20.1.1 sudo docker network create --subnet172.172.172.0/24 zentaonet 使用 8087端口号访问 使用禅道mysql 映射到3307 sudo docker run \ --name zentao2 \ -p 8087:80 \ -p 3307:3306 \ --networkzentaonet \ --ip 172.172.172.…

WIN32核心编程 - 进程操作(一) 进程基础 - 创建进程 - 进程句柄

公开视频 -> 链接点击跳转公开课程博客首页 -> 链接点击跳转博客主页 目录 进程基础 进程的定义与概念 进程的组成 创建进程 可执行文件 CreateProces 执行流程 GetStartupInfo 进程终止 进程句柄 创建进程 打开进程 进程提权 内核模拟 回溯对象 自身进…

有哪些好用的eHR人事系统?国内外HR软件选型指南分享

在人力资源管理信息化这个问题上&#xff0c;不同行业的企业对人力资源管理软件的需求侧重点不一样&#xff0c;并且通常企业规模决定了企业需求的强烈程度&#xff0c;以及能花在这个软件采购上的预算。 首先需要对公司需要人力资源软件的目的和基本需求加以明确。你为什么想用…

软件测试必问必背面试题

01 软件测试理论部分 1.1 测试概念 1. 请你分别介绍一下单元测试、集成测试、系统测试、验收测试、回归测试 单元测试&#xff1a;完成最小的软件设计单元&#xff08;模块&#xff09;的验证工作&#xff0c;目标是确保模块被正确的编码集成测试&#xff1a;通过测试发现与…

【Linux】探索网络编程:TCP/UDP协议解析与Socket应用实例

文章目录 前言&#xff1a;1. 预备知识1.1 理解源IP地址和目的IP地址1.2 认识端口号1.3 理解"端口号"和"进程ID"1.4 理解源端口号和目的端口号1.5 认识TCP协议1.6 认识UDP协议1.6 TCP vs UDP 可靠性1.7 网络字节序 2. socket 编程接口2.1 socket 常见API2.…

为了SourceInsight从Linux回到Windows

什么是SourceInsight 现在上网搜索这个软件&#xff0c;大多数说他是一个代码阅读软件&#xff1b;但是在官方的说法里面&#xff0c;这是一款支持多语言的编辑器。大概长这样&#xff1a; 看起来十分老旧是吧&#xff0c;但是他其实他已经是第四代了哈哈哈。其实这个软件是我…

LeetCode 全排列

思路&#xff1a;这是一道暴力搜索问题&#xff0c;我们需要列出答案的所有可能组合。 题目给我们一个数组&#xff0c;我们很容易想到的做法是将数组中的元素进行排列&#xff0c;如何区分已选中和未选中的元素&#xff0c;容易想到的是建立一个标记数组&#xff0c;已经选中的…