caffe python 圖片訓練識別 例項
本文主要是使用caffe python做圖片識別的示例包括訓練資料lmdb生成,訓練,以及模型測試,主要內容如下:
-
訓練,驗證資料lmdb生成,主要包括:樣本的預處理 (直方圖均衡化,resize),訓練樣本以及驗證樣本的lmdb的生成,以及mean_file mean.binaryproto生成
-
訓練驗證資料準備完成之後,就是模型的訓練
-
得到訓練模型之後,一般會進行本地測試以及從資料庫獲取url測試然後將結果寫到資料庫中
先上個程式碼的框架圖,說明見圖片(下面會有詳細的講解):
下面給出最終的識別結果:
注:本文做影象分類的時候大概是在2016年,第一個分類模型用的是
Alexnet
這個模型現在基本不怎麼用了。一般用的是googlenet v2版本
。
而且caffe的
model zoo
https://github.com/BVLC/caffe/wiki/Model-Zoo中有不少新的模型,比如Towards Principled Design of Deep Convolutional Networks: Introducing SimpNet
感興趣的可以多多嘗試下。
1. 訓練,驗證資料lmdb生成
-
對圖片進行預處理包括直方圖均衡化(Histogram equalization)以及resize到指定的大小,並生成lmdb格式,圖片以及對於的標籤(label)
-
按照一定的比例生成,訓練樣本lmdb以及驗證樣本lmdb,以及mean_file mean.binaryproto
-
在測試的時候,我們往往是從資料庫中讀取url以及id資訊,然後將url轉化為cv2 可以處理的圖片樣式,因此我們還要實現將url轉化cv2可以處理的圖片
1.1 圖片進行預處理包括直方圖均衡化,url->cv2 image 格式
下面通過程式碼來講解(檔案: utils->img_process.py):
# _*_coding:utf-8 _*_
import cv2
import urllib
import numpy as np
IMG_HEIGHT = 227
IMG_WIDTH = 227
# 對圖片做直方圖均衡化處理
def pre_process_img (img, img_height=IMG_HEIGHT, img_width=IMG_WIDTH):
# firstly histogram equalization
img[:, :, 0] = cv2.equalizeHist(img[:, :, 0])
img[:, :, 1] = cv2.equalizeHist(img[:, :, 1])
img[:, :, 2] = cv2.equalizeHist(img[:, :, 2])
# resize image to size
img = cv2.resize(img, (img_width, img_height), interpolation=cv2.INTER_CUBIC)
return img
# 通過圖片url將其轉化為cv2可以處理的形式
def get_cv_img__from_url(url):
"""
read image from url to cv codec
:param url:
:return:
"""
try:
url_response = urllib.urlopen(url)
img_array = np.array(bytearray(url_response.read()), dtype=np.uint8)
img = cv2.imdecode(img_array, -1)
return img
except Exception, e:
print e
return None
if __name__ == '__main__':
url = 'http://www.sanyarb.com.cn/images/attachement/jpg/site2/20161009/A121475977636942_change_ljx6a9_b.jpg'
img = get_cv_img__from_url(url)
cv2.imshow("zhan lang", img)
img = pre_process_img(img)
cv2.imshow("pre_process_img", img)
cv2.waitKey()
pass
下面是下載網上的圖片,然後對其進行直方圖均衡化以及resize的執行的結果:
1.2 圖片按照一定的比例生成訓練樣本以及驗證樣本lmdb]
# _*_coding:utf-8 _*_
import sys
sys.path.insert(0, '../../caffe_train_test/')
import os
import glob
import random
import numpy as np
import cv2
import caffe
from caffe.proto import caffe_pb2
import lmdb
from utils.img_process import *
# 根據圖片和標籤轉化為對應的lmdb格式
def make_datum(img, label):
# image is numpy.ndarray format. BGR instead of RGB
return caffe_pb2.Datum(
channels=3,
width=IMG_HEIGHT,
height=IMG_WIDTH,
label=label,
data=np.rollaxis(img, 2).tostring())
# 建立lmdb的基類
class GenerateLmdb(object):
def __init__(self, img_path):
"""
img_path -> multiple calss directory
like, class_1, class_2, class_3....
each class has corresponding class image like class_1_1.png
:param img_path:
"""
# get all the images in different class directory
# 獲取到多有的圖片列表
self.img_lst = glob.glob(os.path.join(img_path, '*', '*.png'))
print 'input_img list num is %s' % len(self.img_lst)
# shuffle all the images
# 需要對列表亂序
random.shuffle(self.img_lst)
# 根據標籤,比例生成訓練lmdb以及驗證lmdb
def generate_lmdb(self, label_lst, percentage, train_path, validation_path):
"""
label_lst like ['class_1', 'class_2', 'class_3', .....]
percentage like is 5 (4/5) then 80% be train image, (1/5) 20% be validation image
train_path like that '/data/train/train_lmdb'
validation_path like '/data/train/validation_lmdb'
"""
print 'now generate train lmdb'
self._generate_lmdb(label_lst, percentage, True, train_path)
print 'now generate validation lmdb'
self._generate_lmdb(label_lst, percentage, False, validation_path)
print '\n generate all images'
def _generate_lmdb(self, label_lst, percentage, b_train, input_path):
"""
b_train is True means to generate train lmdb, or validation lmdb
"""
output_db = lmdb.open(input_path, map_size=int(1e12))
with output_db.begin(write=True) as in_txn:
for idx, img_path in enumerate(self.img_lst):
# create train data
if b_train:
# !=0 means validation data then skip loop
if idx % percentage != 0:
continue
# create validation data
else:
# ==0 means train data then skip
if idx % percentage == 0:
continue
img = cv2.imread(img_path, cv2.IMREAD_COLOR)
img = pre_process_img(img)
# path like that '../../class_1/0001.png'
# so img_path.split('/')[-2] -> class_1
label = label_lst.index(img_path.split('/')[-2])
datum = make_datum(img, label)
in_txn.put('{:0>5d}'.format(idx), datum.SerializeToString())
print '{:0>5d}'.format(idx) + '->label: ', label, " " + img_path
output_db.close()
def get_label_lst_by_dir(f_dir):
"""
f_dir like 'home/user/class', sub dir 'class_1', 'class_2'...'class_n'
:return: ['class_1', 'class_2'...'class_n']
"""
return os.listdir(f_dir)
if __name__ == '__main__':
img_path = '../../ad_train/'
cl = GenerateLmdb(img_path)
train_lmdb = '/data6/light/storm_1_1/images/ad_train_py/input_data/train_lmdb'
validation_lmdb = '/data6/light/storm_1_1/images/ad_train_py/input_data/validation_lmdb'
# 刪除原有的lmdb檔案
os.system('rm -rf ' + train_lmdb)
os.system('rm -rf ' + validation_lmdb)
input_path = '/data6/light/storm_1_1/images/ad_train/'
label_lst = get_label_lst_by_dir(input_path)
print 'label_lst is: %s' % ', '.join(label_lst)
# (1/10)10% to be validation data, 90% to be train data
# 1/10的檔案為驗證lmdb, 9/10為訓練lmdb
percentage = 10
cl.generate_lmdb(label_lst, percentage, train_lmdb, validation_lmdb)
pass
下面是實踐的執行截圖(這個程式碼好早前就運行了,這次寫bolg做了一些處理)下面是一個三分類的目錄(前面做過十幾中的分類,這裡寫bolg,做了簡化)
類別標籤是: ad_text(文字廣告), ad_web(網頁廣告),others(其他類)
類別目錄如下:
下面是輸出的label列表:
下面是執行 python create_lmdb.py
的部分日誌結果(為了簡便做了很多處理)
下面是最終生成的lmdb檔案:
到此我們生成了,caffe訓練需要的lmdb檔案
1.3 mean_file mean.binaryproto
# _*_ coding:utf-8
import os
# 生成,生成mean_binaryproto檔案的字串命令
def get_mean_cmd(mean_tool_path, train_lmdb_path, mean_binaryproto_path):
# create train command
return '%s -backend=lmdb %s %s ' % (mean_tool_path, train_lmdb_path, mean_binaryproto_path)
if __name__ == '__main__':
# caffe mean 工具的路徑
mean_tool_path = '/home/ubuntu/caffe/build/tools/compute_image_mean'
train_lmdb_path = '/home/xiongyu/input/train_lmdb'
mean_binaryproto_path = '/home/xiongyu/input/mean.binaryproto'
cmd = get_mean_cmd(mean_tool_path, train_lmdb_path, mean_binaryproto_path)
print cmd
# 執行生成命令
os.system(cmd)
cmd合成的字串
實際生成的結果
2. caffe中模型的配置檔案的定義以及說明
2.1 訓練模型定義
caffe中模型的定義,主要是修改 caffe Alexnet 訓練檔案train_val.prototxt
。主要修改mean_file mean.binaryproto,source train lmdb 路徑,
由於這個示例主要講的是3分類,因此還要修改num_output為3(記得修改對應的 部署檔案
)
2.2 部署檔案
2.3 訓練執行引數檔案
net: "/data6/light/storm_1_1/images/ad_train_py/caffe_model/caffenet_train_val_1.prototxt"
test_iter: 1000
# 每1000次做一次驗證
test_interval: 1000
base_lr: 0.001
lr_policy: "step"
gamma: 0.1
stepsize: 2500
display: 50
# 最大迭代次數
max_iter: 30000
momentum: 0.9
# 權重衰減因子
weight_decay: 0.0005
# 每訓練6000次生成一次模型快照
snapshot: 5000
# 模型快照字首
snapshot_prefix: "/data6/light/storm_1_1/images/ad_train_py/caffe_model/caffe_model_1"
# GPU模式
solver_mode: GPU
下面看下最終生成的模型檔案(檔案太大刪除了很多,只保留一個執行時的)
3. 訓練驗證資料準備完成之後,就是模型的訓練
程式碼類似與mean 檔案的生成,這裡就不解釋了
command |& tee out.log
, 將結果輸出到標準輸出流以及out.log檔案中
# _*_ coding:utf-8
import os
def get_train_cmd(caffe_path, solver_path, log_path):
# create train command
return '%s train --solver %s |& tee %s ' % (caffe_path, solver_path, log_path)
if __name__ == '__main__':
caffe_path = "/home/xiongyu/caffe/build/tools/caffe"
solver_path = "/home/xiongyu/caffe_models/caffe_model_1/solver_1.prototxt"
log_path = "/home/xiongyu/caffe_models/caffe_model_1/model_1_train.log"
train = get_train_cmd(caffe_path, solver_path, log_path)
print train
# use caffe to train model
os.system(train)
pass
下面是訓練時的部分截圖:
4. 本地測試以及從資料庫獲取url測試然後將結果寫到資料庫中
4.1 測試基類檔案predict_base.py
為了保證程式碼的模組性,測試的便捷性,這個基類提供給測試本地檔案以及資料庫檔案呼叫
# _*_coding:utf-8 _*_
import sys
sys.path.insert(0, '../../caffe_train_test/')
import os
import glob
import cv2
import caffe
import lmdb
import numpy as np
from caffe.proto import caffe_pb2
from utils.img_process import *
class CaffePredict(object):
def __init__(self, b_gpu, mean_path, deploy_path, model_path):
# cpu或者是gpu模式
if b_gpu:
caffe.set_mode_gpu()
else:
caffe.set_mode_cpu()
mean_blob = caffe_pb2.BlobProto()
with open(mean_path) as f:
mean_blob.ParseFromString(f.read())
mean_array = np.asarray(mean_blob.data, dtype=np.float32).\
reshape((mean_blob.channels, mean_blob.height, mean_blob.width))
self.net = caffe.Net(deploy_path, model_path, caffe.TEST)
# Define image transformers
self.transformer = caffe.io.Transformer({'data': self.net.blobs['data'].data.shape})
self.transformer.set_mean('data', mean_array)
# puts the channel as the first dimention
self.transformer.set_transpose('data', (2, 0, 1))
# predict只需要輸入cv2 image格式圖片即可
def predict(self, img):
img = pre_process_img(img)
self.net.blobs['data'].data[...] = self.transformer.preprocess('data', img)
out = self.net.forward()
pred_probas = out['prob']
# predict result
ret_lst = [round(f, 4) for f in pred_probas[0].tolist()]
return ret_lst
# 獲取預設的caffe模型
def get_default_caffe_predict():
# Read model architecture and trained model's weights
mean_path = "/data6/light/storm_1_1/images/ad_train_py/input_data/mean.binaryproto"
deploy_path = "/data6/light/storm_1_1/images/ad_train_py/caffe_model/caffenet_deploy_1.prototxt"
model_path = "/data6/light/storm_1_1/images/ad_train_py/caffe_model/caffe_model_1_iter_10000.caffemodel"
b_gpu = True
caffe_predict = CaffePredict(b_gpu, mean_path, deploy_path, model_path)
return caffe_predict
if __name__ == '__main__':
# 使用預設的模型識別
caffe_predict = get_default_caffe_predict()
img_path = '/data6/light/storm_1_1/images/ad_train_py/test_data/0.png'
img = cv2.imread(img_path, cv2.IMREAD_COLOR)
print caffe_predict.predict(img)
pass
識別一張圖片,執行結果如下:
4.2 測試本地目錄所有圖片檔案
predict_from_local.py
讀取目錄下的所有檔案,並輸出識別結果
import sys
sys.path.insert(0, '../../caffe_train_test/')
from predict_base import CaffePredict, get_default_caffe_predict
import glob
import cv2
def get_img_lst(img_dir):
"""
img_dir: /data6/light/storm_1_1/images/ad_train_py/test_data/
lots of images like '0.jpg, 1.jpg ......'
"""
return glob.glob(img_dir + "*.png")
def predict_all():
path = '/data6/light/storm_1_1/images/ad_train_py/test_data/'
img_lst = get_img_lst(path)
caffe_predict = get_default_caffe_predict()
for path in img_lst:
try:
img = cv2.imread(path, cv2.IMREAD_COLOR)
# caffe_predict.predict is not thread safe,so can't be used in multiple thread
# python is dummy multiple threads
ret_lst = caffe_predict.predict(img)
print path, ret_lst
except Exception, e:
print e
if __name__ == '__main__':
predict_all()
pass
執行結果如下:
4.3 測試資料庫所有圖片檔案
當然在實際的執行中我們往往測試幾十萬張圖片,一般上傳到伺服器也很麻煩(圖片要下載下來,然後打包在sz到linux目錄,這樣很麻煩而且,打包檔案太大的話上傳到伺服器往往報錯)。所以我們一般在資料庫上面讀取url然後識別,在把識別的結果寫回到資料庫,例如這樣:
# _*_ coding:utf-8 _*_
import sys
sys.path.insert(0, '../../caffe_train_test/')
from utils.DbBse import DbService, get_default_db
from utils.img_process import get_cv_img__from_url
from predict_base import CaffePredict, get_default_caffe_predict
def predict_from_db():
"""
get all the url and id from database and
then predict, write predict result to database
:return:
"""
db = get_default_db()
# [(1, 'http://xxx.1.jpg'), (2, 'http://xxx.2.jpg).....]
url_id_lst = db.get_ad_info()
print 'url_id_lst length is %s: ' % len(url_id_lst)
print 'url_id_lst first is', url_id_lst[0]
caffe_predict = get_default_caffe_predict()
for item in url_id_lst:
img = get_cv_img__from_url(item[1])
if img is None:
continue
ret_lst = caffe_predict.predict(img)
# item[0] is id
ret_lst.append(item[0])
# write result to database
print item[1], ret_lst
db.update_ad_info(ret_lst)
if __name__ == '__main__':
predict_from_db()
pass
下面是執行結果:
本文主要參考了下面這兩個英文bolg,同時做了大量修改,主要是分享給使用caffe做圖片學習需要的人: