將mnist訓練的caffemodel生成動態連結庫DLL
阿新 • • 發佈:2021-10-03
在專案程式中經常看到動態連結庫,非常好奇,想自己實現一下,於是乎嘗試一波。就因為這種好奇,每天都被bug所困擾。。。
1. 訓練caffemodel
在windows環境下搭建caffe無果,轉投Ubuntu。。。
使用 caffe/examples/mnist 中的檔案;若複製到新建的資料夾,記得修改檔案中的路徑。下面為 train.sh:
#!/usr/bin/env sh
# Train LeNet on MNIST using the solver copied from caffe/examples/mnist.
# Both paths can be overridden by exporting CAFFE_ROOT / SOLVER; any extra
# command-line arguments are forwarded to `caffe train`.
set -e

CAFFE_ROOT="${CAFFE_ROOT:-/home/fish/caffe}"
SOLVER="${SOLVER:-/home/fish/STUDY/lenet_solver.prototxt}"

"$CAFFE_ROOT/build/tools/caffe" train --solver="$SOLVER" "$@"
2.使用cv::dnn裡的API載入model,輸入圖片,進行測試(可跳過)
根據文章 https://blog.csdn.net/sushiqian/article/details/78555891 修改模型檔案。MNIST 的數字是黑底白字,若測試圖片為白底黑字,需先用 bitwise_not 反相。
#include #include <opencv2/opencv.hpp> #include <opencv2/dnn.hpp> using namespace std; using namespace cv; using namespace cv::dnn; /* Find best class for the blob (i. e. class with maximal probability)*/ static void getMaxClass(const Mat& probBlob, int* classId, double* classProb) { Mat probMat = probBlob.reshape(1, 1); Point classNumber; minMaxLoc(probMat, NULL, classProb, NULL, &classNumber); *classId = classNumber.x; } int main(int argc, char* argv[]) { string modelTxt = "C:\\Users\\ATWER\\Desktop\\lenet_train_test.prototxt"; string modelBin = "C:\\Users\\ATWER\\Desktop\\lenet_iter_10000.caffemodel"; string imgFileName = "C:\\Users\\ATWER\\Desktop\\9.png"; //read image Mat imgSrc = imread(imgFileName); if (imgSrc.empty()) { cout << "Failed to read image " << imgFileName << endl; exit(-1); } Mat img; cvtColor(imgSrc, img, COLOR_BGR2GRAY); //LeNet accepts 28*28 gray image resize(img, img, Size(28, 28)); bitwise_not(img, img); img /= 255; //transfer image(1*28*28) to blob data with 4 dimensions(1*1*28*28) Mat inputBlob = dnn::blobFromImage(img); dnn::Net net; try { net = dnn::readNetFromCaffe(modelTxt, modelBin); } catch (cv::Exception& ee) { cerr << "Exception: " << ee.what() << endl; if (net.empty()) { cout << "Can't load the network by using the flowing files:" << endl; cout << "modelTxt: " << modelTxt << endl; cout << "modelBin: " << modelBin << endl; exit(-1); } } Mat pred; net.setInput(inputBlob, "data");//set the network input, "data" is the name of the input layer pred = net.forward("prob");//compute output, "prob" is the name of the output layer cout << pred << endl; int classId; double classProb; getMaxClass(pred, &classId, &classProb); cout << "Best Class: " << classId << endl; cout << "Probability: " << classProb * 100 << "%" << endl; }
3. 建立動態連結庫
參考https://blog.csdn.net/qq_30139555/article/details/103621955
class.h
#include #include <opencv2/opencv.hpp> #include <opencv2/dnn/dnn.hpp> using namespace std; using namespace cv; using namespace cv::dnn; extern "C" _declspec(dllexport) void Classfication(char* imgpath, char* result);
在此處卡得最久。原本我寫的是 Classfication(string imgpath, string result),生成 DLL 時沒問題,但從 C# 呼叫時總是出現 System.AccessViolationException: 嘗試讀取或寫入受保護的記憶體。後來發現匯出函式的參數必須寫成指標(char*)的形式:std::string 是 C++ 類別,C# 的 P/Invoke 無法封送它,只能傳遞 C 相容的型別。
class.cpp
#include #include <opencv2/opencv.hpp> #include <opencv2/dnn/dnn.hpp> #include "class.h" using namespace std; using namespace cv; using namespace cv::dnn; /* Find best class for the blob (i. e. class with maximal probability) */ static void getMaxClass(const Mat& probBlob, int* classId, double* classProb) { Mat probMat = probBlob.reshape(1, 1); Point classNumber; minMaxLoc(probMat, NULL, classProb, NULL, &classNumber); *classId = classNumber.x; } void Classfication(char* imgpath, char* result) { string res = ""; string modelTxt = "C:\\Users\\ATWER\\Desktop\\lenet_train_test.prototxt"; string modelBin = "C:\\Users\\ATWER\\Desktop\\lenet_iter_10000.caffemodel"; //string imgFileName = "C:\\Users\\ATWER\\Desktop\\9.png"; string imgFileName = imgpath; //read image Mat imgSrc = imread(imgFileName); if (imgSrc.empty()) { cout << "Failed to read image " << imgFileName << endl; exit(-1); } Mat img; cvtColor(imgSrc, img, COLOR_BGR2GRAY); //LeNet accepts 28*28 gray image resize(img, img, Size(28, 28)); bitwise_not(img, img); img /= 255; //transfer image(1*28*28) to blob data with 4 dimensions(1*1*28*28) Mat inputBlob = dnn::blobFromImage(img); dnn::Net net; try { net = dnn::readNetFromCaffe(modelTxt, modelBin); } catch (cv::Exception& ee) { cerr << "Exception: " << ee.what() << endl; if (net.empty()) { cout << "Can't load the network by using the flowing files:" << endl; cout << "modelTxt: " << modelTxt << endl; cout << "modelBin: " << modelBin << endl; exit(-1); } } Mat pred; net.setInput(inputBlob, "data");//set the network input, "data" is the name of the input layer pred = net.forward("prob");//compute output, "prob" is the name of the output layer int classId;
double classProb;
getMaxClass(pred, &classId, &classProb); res += to_string(classId); res += '|'; res += to_string(classProb); strcpy_s(result, 15, res.c_str()); }
4. 呼叫動態連結庫
根據資料的長度申請非託管空間參考:https://blog.csdn.net/xiaoyong_net/article/details/50178021
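文中說:“一定要加1,否則後面是亂碼,原因未找到”。原因是 C 字串以 '\0'(NUL)結尾,而不是 '\n'。Marshal.PtrToStringAnsi 與 C++ 端都會一直讀到第一個 '\0' 為止,沒有遇到 '\0' 就會繼續讀下去,於是出現亂碼。.Length 沒有計入結尾的 '\0',+1 的空間就是用來存放它的。相應地,AllocHGlobal 申請的空間本身也必須是 Length + 1 個位元組,否則多寫入的那 1 個位元組會越界。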
using System;
using System.Runtime.InteropServices;

namespace Test
{
    class Program
    {
        // The native function is declared extern "C" with the default (cdecl)
        // calling convention; state it explicitly so x86 builds bind correctly too.
        [DllImport("E:/c++project/caffedll/x64/Debug/caffedll.dll", EntryPoint = "Classfication",
                   CallingConvention = CallingConvention.Cdecl)]
        private static extern void Classfication(IntPtr imgpath, IntPtr result);

        // Copies strData into a new unmanaged, NUL-terminated ANSI buffer.
        // ANSI matches both the native char* path and Marshal.PtrToStringAnsi
        // (Encoding.Default is UTF-8 on .NET Core). Free with Marshal.FreeHGlobal.
        private static IntPtr mallocIntptr(string strData)
        {
            return Marshal.StringToHGlobalAnsi(strData);
        }

        // Allocates a zero-filled unmanaged buffer of `length` usable bytes plus
        // one extra byte so the contents are always NUL-terminated.
        // Free with Marshal.FreeHGlobal.
        private static IntPtr mallocIntptr(int length)
        {
            // The allocation must cover every byte that is written; the original
            // allocated `length` but wrote `length + 1` zeros (heap overflow).
            IntPtr m_ptr = Marshal.AllocHGlobal(length + 1);
            Byte[] btZero = new Byte[length + 1];
            Marshal.Copy(btZero, 0, m_ptr, btZero.Length);
            return m_ptr;
        }

        static void Main(string[] args)
        {
            string s = args.Length > 0 ? args[0] : "C:\\Users\\ATWER\\Desktop\\9.png";

            IntPtr ptrFileName = mallocIntptr(s);
            IntPtr res = mallocIntptr(50);
            try
            {
                Classfication(ptrFileName, res);
                string result = Marshal.PtrToStringAnsi(res);
                string[] a = result.Split('|');
                if (a.Length < 2)
                {
                    Console.WriteLine("unexpected result: " + result);
                    return;
                }
                Console.WriteLine("class:" + a[0] + "\n" + "score:" + a[1]);
            }
            finally
            {
                // Both buffers are unmanaged; release them even if the call throws.
                Marshal.FreeHGlobal(ptrFileName);
                Marshal.FreeHGlobal(res);
            }
        }
    }
}