1. 程式人生 > >opencv----呼叫TensorFlow模型

opencv----呼叫TensorFlow模型

OpenCV 基於Inception模型影象分類

原創: gloomyfish  OpenCV學堂  4月21日

Network in Network(NIN)

要介紹Inception網路結構首先應該介紹一下NIN(Network in Network)網路模型,2014年新加坡國立大學發表了一篇關於計算機視覺影象分類的論文,提到採用了一種新的網路結構NIN實現影象分類,該論文的第二作者顏水成畢業於北京大學數學系,現任360人工智慧研究院院長與首席科學家。NIN主要思想是認為CNN網路中卷積濾波是基於線性濾波器實現的,抽象能力不夠,所以一般是用一大堆filter把所有特徵都找出來,但是這樣就導致網路引數過大,論文作者提出通過MLP(多個權重階層組成+一個非線性啟用函式)對輸入區域通過MLP產生一個輸出feature map,然後繼續滑動MLP視窗,對比如下:

這樣做有兩個好處,

  1. MLP可以共享引數,減少引數總數

  2. 對每個區域性感受野神經元實現更加複雜計算,提升能力

論文中提到NIN網路完整結構如下:

包含了三個MLP卷積層與一個全域性池化層。

前方高能預警,乾貨在後面!

Inception v1

受到這篇文章的影響與啟發,谷歌在2014也提出一個新的網路模型結構Inception網路也就是大家熟知v1網路,其主要貢獻在於實現了NIN網路層數的增加,並且在訓練各個網路時候為了提高收斂,考慮中間層的輸出與最終分類錯誤。只是中間層不同,最初inception網路的中間層為:

後來發現3x3與5x5的卷積計算耗時很長,而且輸出導致卷積厚度增加,如果層數過度將導致卷積網路不可控制,於是就在3x3與5x5的卷積之前分別加上1x1的卷積做降維,修改後的結構如下:

最終得到v1版本的網路結構如下:

Inception v2 and Inception v3

於是在v1的基礎上作者繼續工作,加入了BN層,對大於3x3的卷積用一系列小的卷積進行替代,比如7x7可以被1x7與7x1替代兩個小卷積核,5x5可以被1x5與5x1兩個小卷積核替代,這樣就得到Inception v2的版本。於是作者繼續對此網路結構各種優化調整,最終又得到了Inception v3版本

Inception v4

Inception v4一個最大的改動就是引入了殘差網路結構,對原有的網路結構進行優化,得到v1與v2的殘差版本網路結構,最終得到一個更加優化的v4模型,完整的v4結構:

對應的Block A、B、C結構如下:Inception-A

Inception-B

Inception-C

v1模型加殘差網路結構

OpenCV DNN模組中使用Inception模型

  1. 下載Inception預訓練網路模型

  2. 使用OpenCV DNN模組相關API載入模型

  3. 執行Inception網路實現影象分類 完整的程式碼實現如下:

  1. #include <opencv2/opencv.hpp>

  2. #include <opencv2/dnn.hpp>

  3. #include <iostream>

  4. /******************************************************

  5. *

  6. * 作者:賈志剛

  7. * QQ: 57558865

  8. * OpenCV DNN 完整視訊教程:

  9. * http://edu.51cto.com/course/11516.html

  10. *

  11. ********************************************************/

  12. using namespace cv;

  13. using namespace cv::dnn;

  14. using namespace std;

  15. String labels_txt_file = "D:/android/opencv_tutorial/data/models/inception5h/imagenet_comp_graph_label_strings.txt";

  16. String tf_pb_file = "D:/android/opencv_tutorial/data/models/inception5h/tensorflow_inception_graph.pb";

  17. vector<String> readClassNames();

  18. int main(int argc, char** argv) {

  19.    Mat src = imread("D:/vcprojects/images/twocat.png");

  20.    if (src.empty()) {

  21.        printf("could not load image...\n");

  22.        return -1;

  23.    }

  24.    namedWindow("input", CV_WINDOW_AUTOSIZE);

  25.    imshow("input", src);

  26.    vector<String> labels = readClassNames();

  27.    Mat rgb;

  28.    cvtColor(src, rgb, COLOR_BGR2RGB);

  29.    int w = 224;

  30.    int h = 224;

  31.    // 載入網路

  32.    Net net = readNetFromTensorflow(tf_pb_file);

  33.    if (net.empty()) {

  34.        printf("read caffe model data failure...\n");

  35.        return -1;

  36.    }

  37.    Mat inputBlob = blobFromImage(src, 1.0f, Size(224, 224), Scalar(), true, false);

  38.    inputBlob -= 117.0; // 均值

  39.    // 執行影象分類

  40.    Mat prob;

  41.    net.setInput(inputBlob, "input");

  42.    prob = net.forward("softmax2");

  43.    // 得到最可能分類輸出

  44.    Mat probMat = prob.reshape(1, 1);

  45.    Point classNumber;

  46.    double classProb;

  47.    minMaxLoc(probMat, NULL, &classProb, NULL, &classNumber);

  48.    int classidx = classNumber.x;

  49.    printf("\n current image classification : %s, possible : %.2f", labels.at(classidx).c_str(), classProb);

  50.    // 顯示文字

  51.    putText(src, labels.at(classidx), Point(20, 20), FONT_HERSHEY_SIMPLEX, 1.0, Scalar(0, 0, 255), 2, 8);

  52.    imshow("Image Classification", src);

  53.    imwrite("D:/result.png", src);

  54.    waitKey(0);

  55.    return 0;

  56. }

  57. std::vector<String> readClassNames()

  58. {

  59.    std::vector<String> classNames;

  60.    std::ifstream fp(labels_txt_file);

  61.    if (!fp.is_open())

  62.    {

  63.        printf("could not open file...\n");

  64.        exit(-1);

  65.    }

  66.    std::string name;

  67.    while (!fp.eof())

  68.    {

  69.        std::getline(fp, name);

  70.        if (name.length())

  71.            classNames.push_back(name);

  72.    }

  73.    fp.close();

  74.    return classNames;

  75. }

輸入原圖:

測試結果:

關鍵是速度很快,比VGG快N多,基本秒出結果!