博客
关于我
[LibTorch] Tensor 转为 Mat
阅读量:224 次
发布时间:2019-02-28

本文共 3402 字,大约阅读时间需要 11 分钟。

文章目录

Torch 框架搭建的神经网络,输出结果的类型一般是 CUDAFloatType。现在想把结果转为 cv::Mat,并通过 imwrite 输出图片。

单通道

神经网络输出的 tensor 的通道数为 1。

#include 
#include
/* 初始化模型 */torch::jit::script::Module net = torch::jit::load("../models/poolnet.pt");net.to(device);net.eval();/* 输入 CUDAFloatType { 1, 3, img.rows, img.cols } */cv::Mat img = cv::imread(fileName);cv::cvtColor(img, img, cv::COLOR_BGR2RGB);auto img_tensor = torch::from_blob(img.data, { 1, img.rows, img.cols, 3}, torch::TensorOptions().dtype(torch::kByte)).to(device);img_tensor = img_tensor.permute({ 0,3,1,2});img_tensor = img_tensor.toType(torch::kFloat);/* 输出 CUDAFloatType { 1, 1, img.rows, img.cols }*/torch::NoGradGuard no_grad;auto result = net.forward({ img_tensor}).toTensor();/* CUDAFloatType { 1, 1, img.rows, img.cols } -> CPUByteType { img.rows, img.cols } */result = result.squeeze().sigmoid().mul(255.0).toType(torch::kByte).to(torch::kCPU);/* CPUByteType { img.rows, img.cols } -> cv::Mat */cv::Mat img_C1;img_C1.create(cv::Size(img.cols, img.rows), CV_8UC1);memcpy(img_C1.data, result.data_ptr(), result.numel() * sizeof(torch::kByte));cv::imwrite(saveName, img_C1);

三通道

神经网络输出的 tensor 的通道数为 3。

#include 
#include
/* 初始化模型 */torch::jit::script::Module net = torch::jit::load("../models/poolnet2.pt");net.to(device);net.eval();/* 输入 CUDAFloatType { 1, 3, img.rows, img.cols } */cv::Mat img = cv::imread(fileName);cv::cvtColor(img, img, cv::COLOR_BGR2RGB);auto img_tensor = torch::from_blob(img.data, { 1, img.rows, img.cols, 3}, torch::TensorOptions().dtype(torch::kByte)).to(device);img_tensor = img_tensor.permute({ 0,3,1,2});img_tensor = img_tensor.toType(torch::kFloat);/* 输出 CUDAFloatType { 1, 3, img.rows, img.cols }*/torch::NoGradGuard no_grad;auto result = net.forward({ img_tensor}).toTensor();/* CUDAFloatType { 1, 3, img.rows, img.cols } -> CPUByteType { img.rows, img.cols, 3 } */result = result.squeeze().sigmoid().mul(255.0).toType(torch::kByte).permute({ 1,2,0}).to(torch::kCPU);/* CPUByteType { img.rows, img.cols, 3 } -> cv::Mat */cv::Mat img_C3;img_C3.create(cv::Size(img.cols, img.rows), CV_8UC3);memcpy(img_C3.data, result.data_ptr(), result.numel() * sizeof(torch::kByte));cv::imwrite(saveName, img_C3);

CUDA2CPU 问题

把 tensor 从 CUDA 上转移到 CPU,操作是

result = result.to(torch::kCPU);

至于为什么要把 tensor 从 CUDA 转移到 CPU 上,是因为 memcpy 函数只能操作 CPU 上的数据。不过也可以不使用 memcpy,使用 CUDA 库提供的 cudaMemcpy 函数,但需要包含头文件 cuda_runtime.h,以单通道为例:

#include 
#include
#include
.../* 输出 CUDAFloatType { 1, 1, img.rows, img.cols }*/torch::NoGradGuard no_grad;auto result = net.forward({ img_tensor}).toTensor();/* CUDAFloatType { 1, 1, img.rows, img.cols } -> CUDAByteType { img.rows, img.cols } */result = result.squeeze().sigmoid().mul(255.0).toType(torch::kByte);/* CUDAByteType { img.rows, img.cols } -> cv::Mat */cv::Mat img_C1;img_C1.create(cv::Size(img.cols, img.rows), CV_8UC1);/* cudaMemcpyDeviceToHost 表示复制方向是从 Device(GPU) 到 Host(CPU) */cudaMemcpy(pred.data, out.data_ptr(), out.numel(), cudaMemcpyDeviceToHost);cv::imwrite(saveName, img_C1);

这两种方式需要的时间是差不多的。转移操作花费的时间与 神经网络深度输出大小 有关。

如果神经网络很深,就以经典网络骨架 resnet50 为例,输出大小为 { 1, 2048, 32, 32 },那么转移操作在 GTX 1080 Ti + CUDA 11.2 机器上(下同)大约是 130 ms。

再以目前使用的 poolnet 为例,网络比较深,输出大小为 { 500, 500 },转移操作大约需要 150 ms。如果把输出大小改为 { 256, 256 },时间缩短为 84 ms。

如果声明一个类型为 CUDAFloatType,大小为 { 256, 256 } 的 tensor,将其直接转移到 CPU 上,花费的时间为 0 ms。

转载地址:http://goep.baihongyu.com/

你可能感兴趣的文章
Mysql中存储过程、存储函数、自定义函数、变量、流程控制语句、光标/游标、定义条件和处理程序的使用示例
查看>>
mysql中实现rownum,对结果进行排序
查看>>
mysql中对于数据库的基本操作
查看>>
Mysql中常用函数的使用示例
查看>>
MySql中怎样使用case-when实现判断查询结果返回
查看>>
Mysql中怎样使用update更新某列的数据减去指定值
查看>>
Mysql中怎样设置指定ip远程访问连接
查看>>
mysql中数据表的基本操作很难嘛,由这个实验来带你从头走一遍
查看>>
Mysql中文乱码问题完美解决方案
查看>>
mysql中的 +号 和 CONCAT(str1,str2,...)
查看>>
Mysql中的 IFNULL 函数的详解
查看>>
mysql中的collate关键字是什么意思?
查看>>
MySql中的concat()相关函数
查看>>
mysql中的concat函数,concat_ws函数,concat_group函数之间的区别
查看>>
MySQL中的count函数
查看>>
MySQL中的DB、DBMS、SQL
查看>>
MySQL中的DECIMAL类型:MYSQL_TYPE_DECIMAL与MYSQL_TYPE_NEWDECIMAL详解
查看>>
MySQL中的GROUP_CONCAT()函数详解与实战应用
查看>>
MySQL中的IO问题分析与优化
查看>>
MySQL中的ON DUPLICATE KEY UPDATE详解与应用
查看>>