博客
关于我
[LibTorch] Tensor 转为 Mat
阅读量:224 次
发布时间:2019-02-28

本文共 3402 字,大约阅读时间需要 11 分钟。

文章目录

Torch 框架搭建的神经网络,输出结果的类型一般是 CUDAFloatType。现在想把结果转为 cv::Mat,并通过 imwrite 输出图片。

单通道

神经网络输出的 tensor 的通道数为 1。

#include 
#include
/* 初始化模型 */torch::jit::script::Module net = torch::jit::load("../models/poolnet.pt");net.to(device);net.eval();/* 输入 CUDAFloatType { 1, 3, img.rows, img.cols } */cv::Mat img = cv::imread(fileName);cv::cvtColor(img, img, cv::COLOR_BGR2RGB);auto img_tensor = torch::from_blob(img.data, { 1, img.rows, img.cols, 3}, torch::TensorOptions().dtype(torch::kByte)).to(device);img_tensor = img_tensor.permute({ 0,3,1,2});img_tensor = img_tensor.toType(torch::kFloat);/* 输出 CUDAFloatType { 1, 1, img.rows, img.cols }*/torch::NoGradGuard no_grad;auto result = net.forward({ img_tensor}).toTensor();/* CUDAFloatType { 1, 1, img.rows, img.cols } -> CPUByteType { img.rows, img.cols } */result = result.squeeze().sigmoid().mul(255.0).toType(torch::kByte).to(torch::kCPU);/* CPUByteType { img.rows, img.cols } -> cv::Mat */cv::Mat img_C1;img_C1.create(cv::Size(img.cols, img.rows), CV_8UC1);memcpy(img_C1.data, result.data_ptr(), result.numel() * sizeof(torch::kByte));cv::imwrite(saveName, img_C1);

三通道

神经网络输出的 tensor 的通道数为 3。

#include 
#include
/* 初始化模型 */torch::jit::script::Module net = torch::jit::load("../models/poolnet2.pt");net.to(device);net.eval();/* 输入 CUDAFloatType { 1, 3, img.rows, img.cols } */cv::Mat img = cv::imread(fileName);cv::cvtColor(img, img, cv::COLOR_BGR2RGB);auto img_tensor = torch::from_blob(img.data, { 1, img.rows, img.cols, 3}, torch::TensorOptions().dtype(torch::kByte)).to(device);img_tensor = img_tensor.permute({ 0,3,1,2});img_tensor = img_tensor.toType(torch::kFloat);/* 输出 CUDAFloatType { 1, 3, img.rows, img.cols }*/torch::NoGradGuard no_grad;auto result = net.forward({ img_tensor}).toTensor();/* CUDAFloatType { 1, 3, img.rows, img.cols } -> CPUByteType { img.rows, img.cols, 3 } */result = result.squeeze().sigmoid().mul(255.0).toType(torch::kByte).permute({ 1,2,0}).to(torch::kCPU);/* CPUByteType { img.rows, img.cols, 3 } -> cv::Mat */cv::Mat img_C3;img_C3.create(cv::Size(img.cols, img.rows), CV_8UC3);memcpy(img_C3.data, result.data_ptr(), result.numel() * sizeof(torch::kByte));cv::imwrite(saveName, img_C3);

CUDA2CPU 问题

把 tensor 从 CUDA 上转移到 CPU,操作是

result = result.to(torch::kCPU);

至于为什么要把 tensor 从 CUDA 转移到 CPU 上,是因为 memcpy 函数只能操作 CPU 上的数据。不过也可以不使用 memcpy,使用 CUDA 库提供的 cudaMemcpy 函数,但需要包含头文件 cuda_runtime.h,以单通道为例:

#include 
#include
#include
.../* 输出 CUDAFloatType { 1, 1, img.rows, img.cols }*/torch::NoGradGuard no_grad;auto result = net.forward({ img_tensor}).toTensor();/* CUDAFloatType { 1, 1, img.rows, img.cols } -> CUDAByteType { img.rows, img.cols } */result = result.squeeze().sigmoid().mul(255.0).toType(torch::kByte);/* CUDAByteType { img.rows, img.cols } -> cv::Mat */cv::Mat img_C1;img_C1.create(cv::Size(img.cols, img.rows), CV_8UC1);/* cudaMemcpyDeviceToHost 表示复制方向是从 Device(GPU) 到 Host(CPU) */cudaMemcpy(pred.data, out.data_ptr(), out.numel(), cudaMemcpyDeviceToHost);cv::imwrite(saveName, img_C1);

这两种方式需要的时间是差不多的。转移操作花费的时间与 神经网络深度输出大小 有关。

如果神经网络很深,就以经典网络骨架 resnet50 为例,输出大小为 { 1, 2048, 32, 32 },那么转移操作在 GTX 1080 Ti + CUDA 11.2 机器上(下同)大约是 130 ms。

再以目前使用的 poolnet 为例,网络比较深,输出大小为 { 500, 500 },转移操作大约需要 150 ms。如果把输出大小改为 { 256, 256 },时间缩短为 84 ms。

如果声明一个类型为 CUDAFloatType,大小为 { 256, 256 } 的 tensor,将其直接转移到 CPU 上,花费的时间为 0 ms。

转载地址:http://goep.baihongyu.com/

你可能感兴趣的文章
mysql -存储过程
查看>>
mysql /*! 50100 ... */ 条件编译
查看>>
mudbox卸载/完美解决安装失败/如何彻底卸载清除干净mudbox各种残留注册表和文件的方法...
查看>>
mysql 1264_关于mysql 出现 1264 Out of range value for column 错误的解决办法
查看>>
mysql 1593_Linux高可用(HA)之MySQL主从复制中出现1593错误码的低级错误
查看>>
mysql 5.6 修改端口_mysql5.6.24怎么修改端口号
查看>>
MySQL 8.0 恢复孤立文件每表ibd文件
查看>>
MySQL 8.0开始Group by不再排序
查看>>
mysql ansi nulls_SET ANSI_NULLS ON SET QUOTED_IDENTIFIER ON 什么意思
查看>>
multi swiper bug solution
查看>>
MySQL Binlog 日志监听与 Spring 集成实战
查看>>
MySQL binlog三种模式
查看>>
multi-angle cosine and sines
查看>>
Mysql Can't connect to MySQL server
查看>>
mysql case when 乱码_Mysql CASE WHEN 用法
查看>>
Multicast1
查看>>
mysql client library_MySQL数据库之zabbix3.x安装出现“configure: error: Not found mysqlclient library”的解决办法...
查看>>
MySQL Cluster 7.0.36 发布
查看>>
Multimodal Unsupervised Image-to-Image Translation多通道无监督图像翻译
查看>>
MySQL Cluster与MGR集群实战
查看>>