YOLO_tensorflow-master: Run Notes and References — Saving and Running the Model
Download the YOLO_tensorflow-master code:
https://github.com/gliese581gg/YOLO_tensorflow
1. Model introduction
# YOLO_tensorflow
(Version 0.3, last updated 2017.02.21)

### 1. Introduction
This is a TensorFlow implementation of YOLO: Real-Time Object Detection.
For now it can only run predictions with the pretrained YOLO_small and YOLO_tiny networks.
I extracted the weight values from darknet's (.weight) files.
My code does not support training; use darknet for training.

### 2. Install
(1) Download the code.
(2) Download the YOLO weight file.
(3) Put 'YOLO_(version).ckpt' in the 'weight' folder of the downloaded code.
### 3. Usage
(1) Direct usage with default settings (display on console, show the output image, no output file writing):
python YOLO_(small or tiny)_tf.py -fromfile (input image filename)
(2) Direct usage with custom settings:
python YOLO_(small or tiny)_tf.py argvs
where argvs are:
-fromfile (input image filename) : input image file
-disp_console (0 or 1) : whether to display results on the terminal
-imshow (0 or 1) : whether to display the result image
-tofile_img (output image filename) : output image file
-tofile_txt (output txt filename) : output text file (contains class, x, y, w, h, probability)
(3) Import from other scripts:
import YOLO_(small or tiny)_tf
yolo = YOLO_(small or tiny)_tf.YOLO_TF()
yolo.disp_console = (True or False, default = True)
yolo.imshow = (True or False, default = True)
yolo.tofile_img = (output image filename)
yolo.tofile_txt = (output txt filename)
yolo.filewrite_img = (True or False, default = False)
yolo.filewrite_txt = (True or False, default = False)
yolo.detect_from_file(filename)
yolo.detect_from_cvmat(cvmat)
### 4. Requirements
- TensorFlow
- OpenCV2

### 5. Copyright
According to the LICENSE file of the original code:
- Neither I nor the original author holds any liability for damages.
- Do not use this commercially!

### 6. Changelog
2016/02/15 : First upload!
2016/02/16 : Added YOLO_tiny; fixed a bug that ignored one of a grid cell's boxes when both boxes detected valid objects.
2016/08/26 : Uploaded the weight file converter! (darknet weights -> tensorflow ckpt)
2. Using the model
When I ran YOLO, cv2.imread('./test/person.jpg') returned None. The fix is to put import cv2 at the very top of the script; if import cv2 is added after import YOLO_tiny_tf, imread still returns None and the image cannot be loaded. Below is the complete code I used to run YOLO:
```python
#encoding:utf-8
import cv2            # must be imported before YOLO_tiny_tf
import YOLO_tiny_tf

yolo = YOLO_tiny_tf.YOLO_TF()
yolo.disp_console = True
yolo.imshow = True
yolo.tofile_img = './test/ttt.jpg'
yolo.tofile_txt = './test/ttt.txt'
yolo.filewrite_img = True
yolo.filewrite_txt = True

filename = './test/person.jpg'
# read the image
#im = cv2.imread('./test/person.jpg')
yolo.detect_from_file(filename)
#yolo.detect_from_cvmat(im)
```
Run result:
3. Saving and running the model
(1) Name the input placeholder in YOLO_tiny_tf.py "input"; the code is as follows:
```python
def build_networks(self):
    if self.disp_console:
        print "Building YOLO_tiny graph..."
    self.x = tf.placeholder('float32', [None, 448, 448, 3], name="input")
```
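To confirm the renaming took effect, a quick sanity check (a sketch; it assumes YOLO_tiny_tf.py is importable and builds its graph and session in the constructor, as in this repo) is to enumerate the live graph's operations and look for the input node and the output node 19_fc:

```python
# Sketch: verify that the renamed input node and the output node exist in the graph.
from YOLO_tiny_tf import YOLO_TF

yolo = YOLO_TF()
names = set(op.name for op in yolo.sess.graph.get_operations())
for wanted in ("input", "19_fc"):
    print("%s: %s" % (wanted, "present" if wanted in names else "MISSING"))
```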
(2) Run save_graph to save the graph definition and the weights together, then freeze them into a single .pb file (the output node 19_fc is the network's final fully connected layer):
```python
import os
import cv2
import tensorflow as tf
import numpy as np
from tensorflow.python.framework import test_util
import freeze_graph
from YOLO_tiny_tf import YOLO_TF

def save_graph(sess, output_path, checkpoint, checkpoint_state_name,
               input_graph_name, output_graph_name):
    checkpoint_prefix = os.path.join(output_path, checkpoint)
    saver = tf.train.Saver(tf.all_variables())
    saver.save(sess, checkpoint_prefix, global_step=0,
               latest_filename=checkpoint_state_name)
    tf.train.write_graph(sess.graph.as_graph_def(), output_path, input_graph_name)
    # We save out the graph to disk, and then call the const conversion routine.
    input_graph_path = os.path.join(output_path, input_graph_name)
    input_saver_def_path = ""
    input_binary = False
    input_checkpoint_path = checkpoint_prefix + "-0"
    output_node_names = "19_fc"
    restore_op_name = "save/restore_all"
    filename_tensor_name = "save/Const:0"
    output_graph_path = os.path.join(output_path, output_graph_name)
    clear_devices = False
    freeze_graph.freeze_graph(input_graph_path, input_saver_def_path, input_binary,
                              input_checkpoint_path, output_node_names, restore_op_name,
                              filename_tensor_name, output_graph_path, clear_devices, "")

yolo = YOLO_TF()
#with open("weights/small_model.pb","wb") as f:
#    f.write(yolo.sess.graph.as_graph_def().SerializeToString())
save_graph(yolo.sess, "/home/acer/pbMake/yolo", "saved_checkpoint",
           "checkpoint_state", "yoloting_input_graph.pb", "yoloting_output_graph.pb")
```
(3) Load the yoloting_output_graph.pb saved above and run testing and detection:
```python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import cv2
import tensorflow as tf
import numpy as np

def iou(box1, box2):
    tb = min(box1[0] + 0.5 * box1[2], box2[0] + 0.5 * box2[2]) - \
         max(box1[0] - 0.5 * box1[2], box2[0] - 0.5 * box2[2])
    lr = min(box1[1] + 0.5 * box1[3], box2[1] + 0.5 * box2[3]) - \
         max(box1[1] - 0.5 * box1[3], box2[1] - 0.5 * box2[3])
    if tb < 0 or lr < 0:
        intersection = 0
    else:
        intersection = tb * lr
    return intersection / (box1[2] * box1[3] + box2[2] * box2[3] - intersection)

def interpret_output(output):
    alpha = 0.1
    threshold = 0.2
    iou_threshold = 0.5
    num_class = 20
    num_box = 2
    grid_size = 7
    classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car",
               "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike",
               "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor"]
    w_img = 640
    h_img = 480

    probs = np.zeros((7, 7, 2, 20))
    class_probs = np.reshape(output[0:980], (7, 7, 20))
    scales = np.reshape(output[980:1078], (7, 7, 2))
    boxes = np.reshape(output[1078:], (7, 7, 2, 4))
    offset = np.transpose(np.reshape(np.array([np.arange(7)] * 14), (2, 7, 7)), (1, 2, 0))

    boxes[:, :, :, 0] += offset
    boxes[:, :, :, 1] += np.transpose(offset, (1, 0, 2))
    boxes[:, :, :, 0:2] = boxes[:, :, :, 0:2] / 7.0
    boxes[:, :, :, 2] = np.multiply(boxes[:, :, :, 2], boxes[:, :, :, 2])
    boxes[:, :, :, 3] = np.multiply(boxes[:, :, :, 3], boxes[:, :, :, 3])

    boxes[:, :, :, 0] *= w_img
    boxes[:, :, :, 1] *= h_img
    boxes[:, :, :, 2] *= w_img
    boxes[:, :, :, 3] *= h_img

    for i in range(2):
        for j in range(20):
            probs[:, :, i, j] = np.multiply(class_probs[:, :, j], scales[:, :, i])

    filter_mat_probs = np.array(probs >= threshold, dtype='bool')
    filter_mat_boxes = np.nonzero(filter_mat_probs)
    boxes_filtered = boxes[filter_mat_boxes[0], filter_mat_boxes[1], filter_mat_boxes[2]]
    probs_filtered = probs[filter_mat_probs]
    classes_num_filtered = np.argmax(filter_mat_probs, axis=3)[
        filter_mat_boxes[0], filter_mat_boxes[1], filter_mat_boxes[2]]

    argsort = np.array(np.argsort(probs_filtered))[::-1]
    boxes_filtered = boxes_filtered[argsort]
    probs_filtered = probs_filtered[argsort]
    classes_num_filtered = classes_num_filtered[argsort]

    # non-maximum suppression
    for i in range(len(boxes_filtered)):
        if probs_filtered[i] == 0:
            continue
        for j in range(i + 1, len(boxes_filtered)):
            if iou(boxes_filtered[i], boxes_filtered[j]) > iou_threshold:
                probs_filtered[j] = 0.0

    filter_iou = np.array(probs_filtered > 0.0, dtype='bool')
    boxes_filtered = boxes_filtered[filter_iou]
    probs_filtered = probs_filtered[filter_iou]
    classes_num_filtered = classes_num_filtered[filter_iou]

    result = []
    for i in range(len(boxes_filtered)):
        result.append([classes[classes_num_filtered[i]], boxes_filtered[i][0],
                       boxes_filtered[i][1], boxes_filtered[i][2],
                       boxes_filtered[i][3], probs_filtered[i]])
    return result

def show_results(img, results):
    filewrite_img = False
    filewrite_txt = True
    img_cp = img.copy()
    if filewrite_txt:
        ftxt = open('./test/xsss.txt', 'w')
    for i in range(len(results)):
        x = int(results[i][1])
        y = int(results[i][2])
        w = int(results[i][3]) // 2
        h = int(results[i][4]) // 2
        cv2.rectangle(img_cp, (x - w, y - h), (x + w, y + h), (0, 255, 0), 2)
        cv2.rectangle(img_cp, (x - w, y - h - 20), (x + w, y - h), (125, 125, 125), -1)
        cv2.putText(img_cp, results[i][0] + ' : %.2f' % results[i][5],
                    (x - w + 5, y - h - 7), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1)
        ftxt.write(results[i][0] + ',' + str(x) + ',' + str(y) + ',' + str(w) + ',' +
                   str(h) + ',' + str(results[i][5]) + '\n')
    cv2.imwrite('./test/xlsld.jpg', img_cp)  # produces the expected result.

with tf.Graph().as_default():
    output_graph_def = tf.GraphDef()
    output_graph_path = '/home/acer/pbMake/yolo/yoloting_output_graph.pb'
    x = tf.placeholder('float32', [None, 448, 448, 3])
    with open(output_graph_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(output_graph_def, name="")
    with tf.Session() as sess:
        tf.initialize_all_variables().run()
        input_x = sess.graph.get_tensor_by_name("input:0")
        print(input_x)
        output = sess.graph.get_tensor_by_name("19_fc:0")
        print(output)
        filename = './test/person.jpg'
        img = cv2.imread(filename)
        h_img, w_img, _ = img.shape
        img_resized = cv2.resize(img, (448, 448))
        img_RGB = cv2.cvtColor(img_resized, cv2.COLOR_BGR2RGB)
        img_resized_np = np.asarray(img_RGB)
        inputs = np.zeros((1, 448, 448, 3), dtype='float32')
        inputs[0] = (img_resized_np / 255.0) * 2.0 - 1.0
        #input_node = sess.graph.get_operation_by_name("input")
        in_dict = {input_x: inputs}
        net_output = sess.run(output, {input_x: inputs})
        print("net_output", net_output)
        #net_output = sess.run(output_node, feed_dict=in_dict)
        result = interpret_output(net_output[0])
        show_results(img, result)
```
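The slicing in interpret_output follows YOLO v1's flat output layout: with a 7x7 grid, 2 boxes per cell, and 20 VOC classes, the 1470-dimensional output vector decomposes into 980 class probabilities, 98 box confidences, and 392 box coordinates. A quick arithmetic check:

```python
# YOLO v1 output layout: class probabilities, then box confidences, then box coordinates.
grid, boxes_per_cell, num_classes = 7, 2, 20
class_probs = grid * grid * num_classes      # 980 values -> output[0:980]
confidences = grid * grid * boxes_per_cell   # 98 values  -> output[980:1078]
coords = grid * grid * boxes_per_cell * 4    # 392 values -> output[1078:1470]
assert class_probs + confidences + coords == grid * grid * (num_classes + boxes_per_cell * 5) == 1470
```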
The output is identical to running the model directly, which confirms that freeze_graph folded the checkpoint's variables into graph constants correctly.
This series of posts will be updated over time and will mainly cover the following parts: 1. reading the YOLO source code under darknet; 2. porting YOLO to MXNet; 3. model compression and acceleration. I work during the day and can only write at night, so updates may be slow. The posts will not necessarily follow the order above strictly; I may also port the model to Caffe and do the compression and acceleration there.
I. Training
I used the VOC2007 dataset; the download commands are:
$ curl -O http://pjreddie.com/media/files/VOCtrainval_06-Nov-2007.tar
$ curl -O http://pjreddie.com/media/files/VOCtest_06-Nov-2007.tar
$ tar xf VOCtrainval_06-Nov-2007.tar
$ tar xf VOCtest_06-Nov-2007.tar
Run the following script to convert the .xml annotation files into .txt files, in the format YOLO's data loader expects during training:
```python
import xml.etree.ElementTree as ET
import pickle
import os
from os import listdir, getcwd
from os.path import join

sets = [('2007', 'train'), ('2007', 'val')]

classes = ["aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat",
           "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person",
           "pottedplant", "sheep", "sofa", "train", "tvmonitor"]

def convert(size, box):
    dw = 1./size[0]
    dh = 1./size[1]
    x = (box[0] + box[1])/2.0
    y = (box[2] + box[3])/2.0
    w = box[1] - box[0]
    h = box[3] - box[2]
    x = x*dw
    w = w*dw
    y = y*dh
    h = h*dh
    return (x, y, w, h)

def convert_annotation(year, image_id):
    in_file = open('VOCdevkit/VOC%s/Annotations/%s.xml' % (year, image_id))
    out_file = open('VOCdevkit/VOC%s/labels/%s.txt' % (year, image_id), 'w')
    tree = ET.parse(in_file)
    root = tree.getroot()
    size = root.find('size')
    w = int(size.find('width').text)
    h = int(size.find('height').text)
    for obj in root.iter('object'):
        difficult = obj.find('difficult').text
        cls = obj.find('name').text
        if cls not in classes or int(difficult) == 1:
            continue
        cls_id = classes.index(cls)
        xmlbox = obj.find('bndbox')
        b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text),
             float(xmlbox.find('ymin').text), float(xmlbox.find('ymax').text))
        bb = convert((w, h), b)
        out_file.write(str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n')

wd = getcwd()

for year, image_set in sets:
    if not os.path.exists('VOCdevkit/VOC%s/labels/' % year):
        os.makedirs('VOCdevkit/VOC%s/labels/' % year)
    image_ids = open('VOCdevkit/VOC%s/ImageSets/Main/%s.txt' % (year, image_set)).read().strip().split()
    list_file = open('%s_%s.txt' % (year, image_set), 'w')
    for image_id in image_ids:
        list_file.write('%s/VOCdevkit/VOC%s/JPEGImages/%s.jpg\n' % (wd, year, image_id))
        convert_annotation(year, image_id)
    list_file.close()
```
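As a worked example of what convert produces (the box numbers here are made up): a box with xmin=156, xmax=352, ymin=97, ymax=270 in a 640x480 image becomes a normalized (x_center, y_center, width, height) tuple, and each line written to the .txt label file is the class id followed by those four numbers:

```python
# Hypothetical box in a 640x480 image, given as (xmin, xmax, ymin, ymax):
print(convert((640, 480), (156.0, 352.0, 97.0, 270.0)))
# -> approximately (0.396875, 0.382292, 0.30625, 0.360417)
```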
In yolo.c, find the following code and modify it:

```c
// point this at the train.txt file you just generated
char *train_images = "/data/voc/train.txt";
// create a new folder and point this at it; the model files produced during training are saved here
char *backup_directory = "/home/pjreddie/backup/";
```
Run the following command to start training:
./darknet yolo train cfg/yolo.train.cfg extraction.conv.weights
II. Source code walkthrough
1. First, let's follow the training data flow, starting from the main function, which lives in darknet.c:
```c
//darknet.c
int main(int argc, char **argv)
{
    //test_resize("data/bad.jpg");
    //test_box();
    //test_convolutional_layer();
    if(argc < 2){
        fprintf(stderr, "usage: %s <function>\n", argv[0]);
        return 0;
    }
    gpu_index = find_int_arg(argc, argv, "-i", 0);
    if(find_arg(argc, argv, "-nogpu")) {
        gpu_index = -1;
    }

#ifndef GPU
    gpu_index = -1;
#else
    if(gpu_index >= 0){
        cuda_set_device(gpu_index);
    }
#endif

    if (0 == strcmp(argv[1], "average")){
        average(argc, argv);
    } else if (0 == strcmp(argv[1], "yolo")){
        // the first argument is "yolo", so we jump to run_yolo
        run_yolo(argc, argv);
    } else {
        fprintf(stderr, "Not an option: %s\n", argv[1]);
    }
    return 0;
}
```
```c
//yolo.c
void run_yolo(int argc, char **argv)
{
    char *prefix = find_char_arg(argc, argv, "-prefix", 0);
    float thresh = find_float_arg(argc, argv, "-thresh", .2);
    int cam_index = find_int_arg(argc, argv, "-c", 0);
    int frame_skip = find_int_arg(argc, argv, "-s", 0);
    if(argc < 4){
        fprintf(stderr, "usage: %s %s [train/test/valid] [cfg] [weights (optional)]\n", argv[0], argv[1]);
        return;
    }

    char *cfg = argv[3];
    char *weights = (argc > 4) ? argv[4] : 0;
    char *filename = (argc > 5) ? argv[5]: 0;
    if(0==strcmp(argv[2], "test")) test_yolo(cfg, weights, filename, thresh);
    // the second argument is "train", so we jump to train_yolo
    else if(0==strcmp(argv[2], "train")) train_yolo(cfg, weights);
    else if(0==strcmp(argv[2], "demo")) demo(cfg, weights, thresh, cam_index, filename,
        voc_names, 20, frame_skip, prefix);
}

void train_yolo(char *cfgfile, char *weightfile)
{
    char *train_images = "/data/voc/train.txt";
    char *backup_directory = "/home/pjreddie/backup/";
    /* srand seeds the pseudo-random number generator. srand and rand() are used
       together to produce a pseudo-random sequence: rand needs a seed before it
       can generate numbers, and with the same seed it produces the same sequence
       on every run. */
    srand(time(0));
    /* the third argument is cfg/yolo.train.cfg; basecfg() strips the directory
       and truncates the name at the first '.', so base points to the string "yolo" */
    char *base = basecfg(cfgfile);
    // prints "yolo"
    printf("%s\n", base);
    float avg_loss = -1;
    // parse the network architecture; this function is analyzed in detail below
    network net = parse_network_cfg(cfgfile);
    // load the pretrained weights; this function is analyzed in detail below
    if(wei
```