使用 libtorch 1.7.0 + CUDA 10.1 進行 UNet 模型部署
阿新 • 發佈:2020-11-03
#include <iostream> #include <memory> #include <string> #include <torch/script.h> #include <opencv2/opencv.hpp> #include <opencv2/core/core.hpp> #include <opencv2/imgproc/imgproc.hpp> #include "opencv2/imgproc/types_c.h" using namespace std; using namespace cv; torch::Tensor unet_data_preprocess(Mat &image, float scale=1) { cv::cvtColor(image, image, CV_BGR2RGB); int w = image.cols; int h = image.rows; int newW = int(scale * w); int newH = int(scale * h); Mat img_processed; cv::resize(image, img_processed, cv::Size(newW, newH)); //cv::imshow("img_processed", img_processed); //cv::waitKey(0); torch::Tensor imgtransform; imgtransform = torch::from_blob(img_processed.data, {1,newH,newW,3}, torch::kByte); imgtransform = imgtransform.permute({0,3,1,2 }); imgtransform = imgtransform.to(torch::kFloat); imgtransform = imgtransform.div(255.0); return imgtransform; } int main() { //Load model. torch::jit::script::Module unet_module; try { unet_module = torch::jit::load("G:\\liu_projects\\unet_cpp\\Unet_package\\model\\traced_unet_model.pt"); } catch (const c10::Error& e) { std::cerr << "error loading the model!"; return -1; } torch::Device device(torch::kCUDA); unet_module.to(device); unet_module.eval(); std::cout << "model loaded on cuda!\n"; //prepare image tensor. //std::vector<torch::jit::IValue> inputs; cv::Mat image; image = cv::imread("G:\\liu_projects\\unet_cpp\\Unet_package\\test_imgs\\test.jpg"); torch::Tensor img_tensor=unet_data_preprocess(image,0.5); img_tensor = img_tensor.to(device); //forward. 
at::Tensor output = unet_module.forward({ img_tensor }).toTensor(); at::Tensor probs=torch::sigmoid(output); probs = probs.squeeze(0).detach().permute({ 1, 2, 0 }); cout<<"probs size: "<< probs.sizes() <<endl; probs = probs > 0.5; probs = probs.mul(255).clamp(0, 255).to(torch::kU8); probs = probs.to(torch::kCPU); cv::Mat resultImg(640, 959, CV_8UC1); // copy the data from out_tensor to resultImg std::memcpy((void*)resultImg.data, probs.data_ptr(), sizeof(torch::kU8) * probs.numel()); cv::imshow("resultImg", resultImg); cv::waitKey(0); return 0; }
參考文章:
https://www.cnblogs.com/yanghailin/p/12901586.html (libtorch 常用api函式示例(史上最全、最詳細))
https://pytorch.apachecn.org/docs/1.4/30.html (記住:模型要以 CPU 版本儲存!)
https://blog.csdn.net/juluwangriyue/article/details/108360320 (libtorch tensor轉mat)