博客
关于我
[LibTorch] Tensor 转为 Mat
阅读量:224 次
发布时间:2019-02-28

本文共 3402 字,大约阅读时间需要 11 分钟。

文章目录

Torch 框架搭建的神经网络,输出结果的类型一般是 CUDAFloatType。现在想把结果转为 cv::Mat,并通过 imwrite 输出图片。

单通道

神经网络输出的 tensor 的通道数为 1。

#include 
#include
/* 初始化模型 */torch::jit::script::Module net = torch::jit::load("../models/poolnet.pt");net.to(device);net.eval();/* 输入 CUDAFloatType { 1, 3, img.rows, img.cols } */cv::Mat img = cv::imread(fileName);cv::cvtColor(img, img, cv::COLOR_BGR2RGB);auto img_tensor = torch::from_blob(img.data, { 1, img.rows, img.cols, 3}, torch::TensorOptions().dtype(torch::kByte)).to(device);img_tensor = img_tensor.permute({ 0,3,1,2});img_tensor = img_tensor.toType(torch::kFloat);/* 输出 CUDAFloatType { 1, 1, img.rows, img.cols }*/torch::NoGradGuard no_grad;auto result = net.forward({ img_tensor}).toTensor();/* CUDAFloatType { 1, 1, img.rows, img.cols } -> CPUByteType { img.rows, img.cols } */result = result.squeeze().sigmoid().mul(255.0).toType(torch::kByte).to(torch::kCPU);/* CPUByteType { img.rows, img.cols } -> cv::Mat */cv::Mat img_C1;img_C1.create(cv::Size(img.cols, img.rows), CV_8UC1);memcpy(img_C1.data, result.data_ptr(), result.numel() * sizeof(torch::kByte));cv::imwrite(saveName, img_C1);

三通道

神经网络输出的 tensor 的通道数为 3。

#include 
#include
/* 初始化模型 */torch::jit::script::Module net = torch::jit::load("../models/poolnet2.pt");net.to(device);net.eval();/* 输入 CUDAFloatType { 1, 3, img.rows, img.cols } */cv::Mat img = cv::imread(fileName);cv::cvtColor(img, img, cv::COLOR_BGR2RGB);auto img_tensor = torch::from_blob(img.data, { 1, img.rows, img.cols, 3}, torch::TensorOptions().dtype(torch::kByte)).to(device);img_tensor = img_tensor.permute({ 0,3,1,2});img_tensor = img_tensor.toType(torch::kFloat);/* 输出 CUDAFloatType { 1, 3, img.rows, img.cols }*/torch::NoGradGuard no_grad;auto result = net.forward({ img_tensor}).toTensor();/* CUDAFloatType { 1, 3, img.rows, img.cols } -> CPUByteType { img.rows, img.cols, 3 } */result = result.squeeze().sigmoid().mul(255.0).toType(torch::kByte).permute({ 1,2,0}).to(torch::kCPU);/* CPUByteType { img.rows, img.cols, 3 } -> cv::Mat */cv::Mat img_C3;img_C3.create(cv::Size(img.cols, img.rows), CV_8UC3);memcpy(img_C3.data, result.data_ptr(), result.numel() * sizeof(torch::kByte));cv::imwrite(saveName, img_C3);

CUDA2CPU 问题

把 tensor 从 CUDA 上转移到 CPU,操作是

result = result.to(torch::kCPU);

至于为什么要把 tensor 从 CUDA 转移到 CPU 上,是因为 memcpy 函数只能操作 CPU 上的数据。不过也可以不使用 memcpy,使用 CUDA 库提供的 cudaMemcpy 函数,但需要包含头文件 cuda_runtime.h,以单通道为例:

#include 
#include
#include
.../* 输出 CUDAFloatType { 1, 1, img.rows, img.cols }*/torch::NoGradGuard no_grad;auto result = net.forward({ img_tensor}).toTensor();/* CUDAFloatType { 1, 1, img.rows, img.cols } -> CUDAByteType { img.rows, img.cols } */result = result.squeeze().sigmoid().mul(255.0).toType(torch::kByte);/* CUDAByteType { img.rows, img.cols } -> cv::Mat */cv::Mat img_C1;img_C1.create(cv::Size(img.cols, img.rows), CV_8UC1);/* cudaMemcpyDeviceToHost 表示复制方向是从 Device(GPU) 到 Host(CPU) */cudaMemcpy(pred.data, out.data_ptr(), out.numel(), cudaMemcpyDeviceToHost);cv::imwrite(saveName, img_C1);

这两种方式需要的时间是差不多的。转移操作花费的时间与 神经网络深度输出大小 有关。

如果神经网络很深,就以经典网络骨架 resnet50 为例,输出大小为 { 1, 2048, 32, 32 },那么转移操作在 GTX 1080 Ti + CUDA 11.2 机器上(下同)大约是 130 ms。

再以目前使用的 poolnet 为例,网络比较深,输出大小为 { 500, 500 },转移操作大约需要 150 ms。如果把输出大小改为 { 256, 256 },时间缩短为 84 ms。

如果声明一个类型为 CUDAFloatType,大小为 { 256, 256 } 的 tensor,将其直接转移到 CPU 上,花费的时间为 0 ms。

转载地址:http://goep.baihongyu.com/

你可能感兴趣的文章
mysql 多字段删除重复数据,保留最小id数据
查看>>
MySQL 多表联合查询:UNION 和 JOIN 分析
查看>>
MySQL 大数据量快速插入方法和语句优化
查看>>
mysql 如何给SQL添加索引
查看>>
mysql 字段区分大小写
查看>>
mysql 字段合并问题(group_concat)
查看>>
mysql 字段类型类型
查看>>
MySQL 字符串截取函数,字段截取,字符串截取
查看>>
MySQL 存储引擎
查看>>
mysql 存储过程 注入_mysql 视图 事务 存储过程 SQL注入
查看>>
MySQL 存储过程参数:in、out、inout
查看>>
mysql 存储过程每隔一段时间执行一次
查看>>
mysql 存在update不存在insert
查看>>
Mysql 学习总结(86)—— Mysql 的 JSON 数据类型正确使用姿势
查看>>
Mysql 学习总结(87)—— Mysql 执行计划(Explain)再总结
查看>>
Mysql 学习总结(88)—— Mysql 官方为什么不推荐用雪花 id 和 uuid 做 MySQL 主键
查看>>
Mysql 学习总结(89)—— Mysql 库表容量统计
查看>>
mysql 实现主从复制/主从同步
查看>>
mysql 审核_审核MySQL数据库上的登录
查看>>
mysql 导入 sql 文件时 ERROR 1046 (3D000) no database selected 错误的解决
查看>>