
yolov3_tensorflow implementation and modification notes

Tags: neural networks and deep learning

yolov3_tensorflow

Environment setup

Written from memory, so this is only an outline of the key points. I mostly followed the author's instructions and only note the places that need extra attention.
Original project: YunYang1994/tensorflow-yolov3.
Create a new environment:

conda create --name yolov3_tf python=3.6 ipykernel -y
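Then activate it before installing anything into it:

conda activate yolov3_tf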

Then install the libraries listed in the project's requirements file.
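For reference, the install step from the project root looks roughly like this (in my copy of the repo the requirements file sits under docs/; adjust the path if yours differs):

cd tensorflow-yolov3
pip install -r ./docs/requirements.txt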
When running the code afterwards, import tensorflow as tf fails with:
Failed to load the native TensorFlow runtime

Fix: upgrade the TensorFlow version. The problem showed up with the GPU build, so run:

pip install tensorflow-gpu==1.13.1
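A quick way to confirm the runtime now loads and that the GPU is visible (tf.test.is_gpu_available is a TF 1.x API):

python -c "import tensorflow as tf; print(tf.__version__); print(tf.test.is_gpu_available())"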

A few other small libraries are also needed; install each one as the error messages prompt (here via the Douban PyPI mirror):

pip install -i http://pypi.douban.com/simple/ --trusted-host pypi.douban.com <module_name>

Modifying config.py and related files

Keep everything else as close to the original as possible; what has to change is the trained-model entry.
If the trained checkpoint is located at:

/home/wxd/tensorflow-yolov3-master/checkpoint/yolov3_test_loss=3.9814.ckpt-11

then make the following change in the TEST options section:

__C.TEST.WEIGHT_FILE            = "./checkpoint/yolov3_test_loss=3.9814.ckpt-11"
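Note that WEIGHT_FILE is a checkpoint prefix rather than a single file; TensorFlow stores the actual .index and .data-* files under that prefix. A quick sanity check (my own sketch, not part of the repo) is to list the variables saved under the prefix:

import tensorflow as tf

# Checkpoint prefix as configured above; the .index/.data files share this prefix on disk.
ckpt_prefix = "./checkpoint/yolov3_test_loss=3.9814.ckpt-11"
for name, shape in tf.train.list_variables(ckpt_prefix):
    print(name, shape)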

Because config.py contains the following definition:

__C.YOLO.CLASSES                = "./data/classes/voc.names"

the contents of ./data/classes/voc.names have to be changed to the classes of your own dataset (a quick sanity check for this file is sketched after the list):

missing_hole
mouse_bite
open_circuit
short
spur
spurious_copper
dust
scratch
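Each class name goes on its own line. The file can be sanity-checked by reading it the same way, one stripped line per class (this is my own sketch, not the project's utility function):

with open("./data/classes/voc.names") as f:
    class_names = [line.strip() for line in f if line.strip()]
print(len(class_names), class_names)  # should report 8 classes for the list above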

The class list in voc_annotation.py also has to be changed (the classes list near the top of the code below). This script converts the VOC-style XML annotations into the txt format YOLOv3 expects; the original function did not work for my data, so here is the modified version that does:

import os
import argparse
import xml.etree.ElementTree as ET

def convert_voc_annotation(data_path, data_type, anno_path, use_difficult_bbox=True):

    classes = ['missing_hole','mouse_bite','open_circuit','short','spur','spurious_copper','dust','scratch']
    
    path = os.path.join(data_path, 'ImageSets', 'Main')#wxd
    if not os.path.exists(path):
        os.makedirs(path)
    
    img_inds_file = os.path.join(data_path, 'ImageSets', 'Main', data_type + '.txt')
    if not os.path.exists(img_inds_file):
        # Create an empty index file if it does not exist yet.
        open(img_inds_file, 'w').close()
    
    '''
    with open(img_inds_file, 'r') as f:
        txt = f.readlines()
        image_inds = [line.strip() for line in txt]
        print(image_inds)
    '''
    # Build the image ID list directly from JPEGImages instead of the Main/*.txt split files.
    imagelist = os.listdir(os.path.join(data_path, 'JPEGImages'))
    # Strip the file extension to get the bare image IDs.
    imagelist = [os.path.splitext(name)[0] for name in imagelist]
    with open(anno_path, 'a') as f:
        print('opened annotation file:', anno_path)
        for image_ind in imagelist:  # originally iterated over image_inds
            image_path = os.path.join(data_path, 'JPEGImages', image_ind + '.jpg')
            annotation = image_path
            print('annotation is:',annotation)
            label_path = os.path.join(data_path, 'Annotations', image_ind + '.xml')
            root = ET.parse(label_path).getroot()
            objects = root.findall('object')
            for obj in objects:
                '''
                difficult = obj.find('difficult').text.strip()
                if (not use_difficult_bbox) and(int(difficult) == 1):
                    continue
                '''
                bbox = obj.find('bndbox')
                class_ind = classes.index(obj.find('name').text.lower().strip())
                xmin = bbox.find('xmin').text.strip()
                xmax = bbox.find('xmax').text.strip()
                ymin = bbox.find('ymin').text.strip()
                ymax = bbox.find('ymax').text.strip()
                annotation += ' ' + ','.join([xmin, ymin, xmax, ymax, str(class_ind)])
            print(annotation)
            f.write(annotation + "\n")
    return len(imagelist)#image_inds


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", default="/home/wxd/jupyter/FPN_Tensorflow-master/data")
    parser.add_argument("--train_annotation", default="./data/dataset/voc_train.txt")
    parser.add_argument("--test_annotation",  default="./data/dataset/voc_test.txt")
    flags = parser.parse_args()

    #if os.path.exists(flags.train_annotation):os.remove(flags.train_annotation)
    #if os.path.exists(flags.test_annotation):os.remove(flags.test_annotation)

    num1 = convert_voc_annotation(os.path.join(flags.data_path, 'VOC_PCB_test'), 'voc_test', flags.train_annotation, False)
    #num2 = convert_voc_annotation(os.path.join(flags.data_path, 'train/VOCdevkit/VOC2012'), 'trainval', flags.train_annotation, False)
    #num3 = convert_voc_annotation(os.path.join(flags.data_path, 'VOC_PCB_test'),  'test', flags.test_annotation, False)
    print('=> The number of images for train is: %d' % num1)
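With the defaults above, the script can be run directly; each line of the resulting txt file is the image path followed by space-separated boxes in the form xmin,ymin,xmax,ymax,class_ind. The data_path below is just my own directory and will differ for other setups:

python3 voc_annotation.py --data_path /home/wxd/jupyter/FPN_Tensorflow-master/data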


train

Extract the original pretrained weights:

cd checkpoint
tar -xvf yolov3_coco.tar.gz

Convert the weights:

python3 convert_weight.py --train_from_coco

Start training:

python3 train.py
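To confirm the GPU is actually being used while the job runs, I keep an eye on it in another terminal:

watch -n 1 nvidia-smi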

Training takes quite a long time.
[Screenshot after finishing one training epoch]

My own dataset images are 600*600; is there anything else that needs to be modified for that?

test