1. 程式人生 > 實用技巧 >YOLOv3訓練自己的檢測模型

YOLOv3訓練自己的檢測模型


YOLOv3訓練自己的檢測模型

目標檢測迷途小書童10個月前 (12-16)0評論

軟硬體環境

  • Intel(R) Xeon(R) CPU E5-1607 v4 @ 3.10GHz
  • ubuntu 18.04 64bit
  • GTX 1070Ti 8G/32G
  • darknet git version
  • cuda 8.0
  • opencv 3.4.3
  • miniconda with python 3.7.1

前言

先說說我這的具體情況,需要檢測的物件是老鼠,手上已經有的資料是圖片以及圖片中老鼠的座標位置(xy,widthheight)。要做的就是利用這些資訊,通過YOLOv3訓練出老鼠的檢測器,應用到實際的場景中去。

VOC資料集的組織結構

檢測模型的訓練依照VOC資料集的訓練方法進行。首先來看看VOC資料集訓練資料夾的目錄結構

其中,

  • Annotations: 這裡存放所有的xml檔案, 它的檔案格式如下
<annotation>
    <folder>VOC2007</folder>
    <filename>1548339112.jpg</filename>
    <size>
        <width>1920</width>
        <height>1080</height>
        <depth>3</depth>
    </size>
    <object>
        <name>mouse</name>
        <pose>Unspecified</pose>
        <truncated>0</truncated>
        <difficult>0</difficult>
        <bndbox>
            <xmin>522</xmin>
            <ymin>629</ymin>
            <xmax>588</xmax>
            <ymax>699</ymax>
        </bndbox>
    </object>
</annotation>

size標籤下的widthheight是指圖片的大小,object下的xminyminxmaxymax則是物體的座標資訊,如果有多個物體,object標籤對應會有多個。

  • ImageSets: 關注Main資料夾下的train.txttrainval.txtval.txttest.txttrain.txt必須,其它可以不要,它們的格式都是一樣的,記錄的是圖片的檔名,不帶副檔名,如下
1548339040
1548339112
1548339126
1548339138
1548339675
1548339682
1548339690
1548339698
1548339706
1548339712
1548339864
1548339870
1548339874
1548339880
1548339887
1548340230
  • JPEGImages: 訓練圖片的儲存位置

  • labels: 這個資料夾下的內容可以通過指令碼生成,一張圖片對應一個txt檔案,它的內容如下

0 0.2921875 0.6217592592592592 0.021875 0.05092592592592593

老鼠檢測模型的訓練步驟

既然已經有了圖片已及老鼠對應的座標資訊,所以手動標註這一步就可以省略掉。整體的訓練應該分成以下幾步

  • 根據座標資訊生成xml標註檔案,一張圖片對應一個xml檔案
  • 分別生成包含圖片檔名資訊的train.txtval.txttrainval.txttest.txt
  • 生成labels資料夾下的txt檔案
  • 修改配置檔案data/voc.namescfg/voc.datacfg/yolov3-voc.cfg
  • 開始訓練

訓練過程

生成Annotations下的xml檔案

由於手頭已經有了具體的座標資訊了,我把它們儲存到了mysql資料庫中,然後利用相應的圖片及座標資訊生成對應的xml檔案,程式碼存放在https://code.xugaoxiang.com/longjingtech/YOLOv3XmlGenerator

生成ImageSets/Main下的txt檔案

train.txt為例,其它的都一樣,要處理之前,可以把訓練的、校驗的、測試的圖片分別放在不同的資料夾下,這樣可以大大方便指令碼處理,具體情形需要你自行修改

# -*- coding: utf-8 -*-


"""
@author: Xu Gaoxiang
@license: Apache V2
@email: [email protected]
@site: https://www.xugaoxiang.com
@software: PyCharm
@file: mainTxtGenerator.py
@time: 2019/1/25 17:57
"""

import os

with open('train.txt', 'a') as f:
    source_folder = 'VOC2007/JPEGImages'

    file_list = os.listdir(source_folder)

    for file_obj in file_list:
        print('file: {}'.format(file_obj))
        file_name, file_extend = os.path.splitext(file_obj)
        f.write(file_name + '\n')

生成labeltxt檔案

darknet工程下scripts目錄下有個voc_label.py檔案,我們通過修改它來實現,需要將它移動到darknet下執行

import xml.etree.ElementTree as ET
import pickle
import os
from os import listdir, getcwd
from os.path import join

sets=[('2007', 'train'), ('2007', 'val'), ('2007', 'test')]

classes = ["mouse"]


def convert(size, box):
    dw = 1./(size[0])
    dh = 1./(size[1])
    x = (box[0] + box[1])/2.0 - 1
    y = (box[2] + box[3])/2.0 - 1
    w = box[1] - box[0]
    h = box[3] - box[2]
    x = x*dw
    w = w*dw
    y = y*dh
    h = h*dh
    return (x,y,w,h)

def convert_annotation(year, image_id):
    in_file = open('VOCdevkit/VOC%s/Annotations/%s.xml'%(year, image_id))
    out_file = open('VOCdevkit/VOC%s/labels/%s.txt'%(year, image_id), 'w')
    tree=ET.parse(in_file)
    root = tree.getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)

    for obj in root.iter('object'):
        difficult = obj.find('difficult').text
        cls = obj.find('name').text
        if cls not in classes or int(difficult)==1:
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
        bb = convert((w,h), b)
        out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')

wd = getcwd()

for year, image_set in sets:
    if not os.path.exists('VOCdevkit/VOC%s/labels/'%(year)):
        os.makedirs('VOCdevkit/VOC%s/labels/'%(year))
    image_ids = open('VOCdevkit/VOC%s/ImageSets/Main/%s.txt'%(year, image_set)).read().strip().split()
    list_file = open('%s_%s.txt'%(year, image_set), 'w')
    for image_id in image_ids:
        list_file.write('%s/VOCdevkit/VOC%s/JPEGImages/%s.jpg\n'%(wd, year, image_id))
        convert_annotation(year, image_id)
    list_file.close()

os.system("cat 2007_train.txt 2007_val.txt > train.txt")

除此以外,在darknet根下生成了2007_train.txt等檔案,內容是圖片的完整路徑

修改配置檔案

data/voc.names存放的是檢測物件的名稱,如我這裡的mouse,需要檢測幾個就寫幾個

cfg/voc.data內容如下,因為我只檢測老鼠,所以classes=1,其它路徑自行修改

classes= 1
train  = /home/longjing/Work/yolo3/darknet/2007_train.txt
valid  = /home/longjing/Work/yolo3/darknet/2007_val.txt
names = data/voc.names
backup = backup_mouse

cfg/yolov3-voc.cfg主要修改classes,根據自己的硬體情況調整batchsubdivisions的值

開始訓練

使用如下命令進行訓練

cd darknet
wget https://pjreddie.com/media/files/darknet53.conv.74
./darknet detector train cfg/voc.data cfg/yolov3-voc.cfg darknet53.conv.74

在這裡我碰到了darknet: ./src/parser.c:312: parse_yolo: Assertion l.outputs == params.inputs failed.的錯誤,解決的方法是修改cfg/yolov3-voc.cfg中的filters,將其值改為18,這個參考值來自網路,計算方法3*(classes+5)

小結

一般情況下,都是拿到包含某種待檢測物件的圖片,然後需要根據圖片進行訓練得到檢測模型。這樣的話,就需要手動標註,得到相應的座標資訊,生成xml檔案,GUI標註工具labelImg就是幹這樣的事情。我上面的xml生成器其實就是乾的labelImg的活。

參考資料