[Tensorflow]基於slim框架下inception模型的植物識別
阿新 • • 發佈:2019-01-10
1.資料獲取
python指令碼根據關鍵字爬取對應的圖片
#!/usr/bin/env python
# encoding: utf-8
"""Crawl Baidu image search and save the results as training images.

Python 2 script (urllib2). For every keyword listed in data/flower.txt the
images are stored as data/my_data/<keyword>/<keyword>_<n>.jpg, i.e. one
sub-directory per class under my_data/, the layout the TFRecord conversion
script reads.
"""
import os
import re
import sys
import urllib
import urllib2

reload(sys)
sys.setdefaultencoding("utf-8")

# Results per search request (the ``rn`` query parameter); also the maximum
# number of images taken from each response.
PAGE_SIZE = 60

USER_AGENT = ("Mozilla/5.0 (Windows NT 10.0; WOW64) AppleWebKit/537.36 "
              "(KHTML, like Gecko) Chrome/45.0.2454.101 Safari/537.36")


def _download_image(src, headers, path, index):
    """Fetch one image URL and save it to ``path``; return True on success.

    The output file is created only after the whole response was read, so a
    failed request does not leave an empty .jpg behind (previously the file
    was opened before downloading).
    """
    print("downloading No.%d" % index)
    try:
        req = urllib2.Request(src, headers=headers)
        # 3-second timeout so one dead host cannot stall the whole crawl.
        data = urllib2.urlopen(req, timeout=3).read()
    except Exception:  # best effort: skip this image, keep crawling
        print("No.%d error:" % index)
        return False
    with open(path, 'wb') as out:
        out.write(data)
    return True


def img_spider(name_file, pages=2, out_root='data/my_data'):
    """Download up to ``pages * PAGE_SIZE`` images for every keyword.

    Args:
        name_file: UTF-8 text file with one search keyword per line; blank
            lines are ignored.
        pages: number of result pages requested per keyword (previously a
            hard-coded ``range(2)``).
        out_root: directory that receives one sub-directory per keyword.
    """
    headers = {'User-Agent': USER_AGENT}

    # Read the keyword list; each keyword becomes one class folder.
    with open(name_file) as f:
        name_list = [line.rstrip().decode('utf-8') for line in f]

    for name in name_list:
        if not name:
            continue
        save_dir = os.path.join(out_root, name)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        for page in range(pages):
            # ``pn`` is the zero-based result offset. The old code used
            # (page + 1) * 60 and therefore never fetched the first 60 hits.
            offset = page * PAGE_SIZE
            # Percent-encode the UTF-8 keyword (spaces, Chinese characters).
            url = ("http://image.baidu.com/search/avatarjson"
                   "?tn=resultjsonavatarnew&ie=utf-8&word=" +
                   urllib.quote(name.encode('utf-8'), safe='') +
                   "&cg=girl&rn=%d&pn=%d" % (PAGE_SIZE, offset))
            try:
                req = urllib2.Request(url, headers=headers)
                # Bounded wait so a stalled search request cannot hang forever.
                page_data = urllib2.urlopen(req, timeout=10).read()
            except Exception as exc:  # skip this page, keep crawling
                print("%s error: %s" % (name, exc))
                continue
            # The response is JSON (not the HTML seen in a browser), so the
            # original image URLs are taken from its "objURL" fields.
            img_srcs = re.findall('"objURL":"(.*?)"', page_data, re.S)
            print("%s %d" % (name, len(img_srcs)))

            for j, src in enumerate(img_srcs[:PAGE_SIZE], 1):
                index = offset + j
                path = os.path.join(save_dir, '%s_%d.jpg' % (name, index))
                _download_image(src, headers, path, index)


# Entry point: read the keyword file and start crawling.
if __name__ == '__main__':
    img_spider("data/flower.txt")
2.準備資料
3.下載slim 和inception v4模型
4.修改slim/datasets/download_and_convert_flowers.py 5處
5.生成tfrecord檔案# Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== r"""Downloads and converts Flowers data to TFRecords of TF-Example protos. This module downloads the Flowers data, uncompresses it, reads the files that make up the Flowers data and creates two TFRecord datasets: one for train and one for test. Each TFRecord dataset is comprised of a set of TF-Example protocol buffers, each of which contain a single image and label. The script should take about a minute to run. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import math import os import random import sys import tensorflow as tf from datasets import dataset_utils # The URL where the Flowers data can be downloaded. _DATA_URL = 'http://download.tensorflow.org/example_images/flower_photos.tgz' # The number of images in the validation set. _NUM_VALIDATION = 180 #修改驗證集數量,一般為資料集的1/10 # Seed for repeatability. _RANDOM_SEED = 0 # The number of shards per dataset split. _NUM_SHARDS = 2 #修改tfrecord個數,每個tfrecord 1024張左右圖片 class ImageReader(object): """Helper class that provides TensorFlow image coding utilities.""" def __init__(self): # Initializes function that decodes RGB JPEG data. 
self._decode_jpeg_data = tf.placeholder(dtype=tf.string) self._decode_jpeg = tf.image.decode_jpeg(self._decode_jpeg_data, channels=3) def read_image_dims(self, sess, image_data): image = self.decode_jpeg(sess, image_data) return image.shape[0], image.shape[1] def decode_jpeg(self, sess, image_data): image = sess.run(self._decode_jpeg, feed_dict={self._decode_jpeg_data: image_data}) assert len(image.shape) == 3 assert image.shape[2] == 3 return image def _get_filenames_and_classes(dataset_dir): """Returns a list of filenames and inferred class names. Args: dataset_dir: A directory containing a set of subdirectories representing class names. Each subdirectory should contain PNG or JPG encoded images. Returns: A list of image file paths, relative to `dataset_dir` and the list of subdirectories, representing class names. """ flower_root = os.path.join(dataset_dir, 'my_data') #修改資料集路徑 directories = [] class_names = [] for filename in os.listdir(flower_root): path = os.path.join(flower_root, filename) if os.path.isdir(path): directories.append(path) class_names.append(filename) photo_filenames = [] for directory in directories: for filename in os.listdir(directory): path = os.path.join(directory, filename) photo_filenames.append(path) return photo_filenames, sorted(class_names) def _get_dataset_filename(dataset_dir, split_name, shard_id): output_filename = 'flowers_%s_%05d-of-%05d.tfrecord' % ( split_name, shard_id, _NUM_SHARDS) return os.path.join(dataset_dir, output_filename) def _convert_dataset(split_name, filenames, class_names_to_ids, dataset_dir): """Converts the given filenames to a TFRecord dataset. Args: split_name: The name of the dataset, either 'train' or 'validation'. filenames: A list of absolute paths to png or jpg images. class_names_to_ids: A dictionary from class names (strings) to ids (integers). dataset_dir: The directory where the converted datasets are stored. 
""" assert split_name in ['train', 'validation'] num_per_shard = int(math.ceil(len(filenames) / float(_NUM_SHARDS))) with tf.Graph().as_default(): image_reader = ImageReader() with tf.Session('') as sess: for shard_id in range(_NUM_SHARDS): output_filename = _get_dataset_filename( dataset_dir, split_name, shard_id) with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer: start_ndx = shard_id * num_per_shard end_ndx = min((shard_id+1) * num_per_shard, len(filenames)) for i in range(start_ndx, end_ndx): sys.stdout.write('\r>> Converting image %d/%d shard %d' % ( i+1, len(filenames), shard_id)) sys.stdout.flush() # Read the filename: image_data = tf.gfile.FastGFile(filenames[i], 'rb').read() height, width = image_reader.read_image_dims(sess, image_data) class_name = os.path.basename(os.path.dirname(filenames[i])) class_id = class_names_to_ids[class_name] example = dataset_utils.image_to_tfexample( image_data, b'jpg', height, width, class_id) tfrecord_writer.write(example.SerializeToString()) sys.stdout.write('\n') sys.stdout.flush() def _clean_up_temporary_files(dataset_dir): """Removes temporary files used to create the dataset. Args: dataset_dir: The directory where the temporary files are stored. """ filename = _DATA_URL.split('/')[-1] filepath = os.path.join(dataset_dir, filename) tf.gfile.Remove(filepath) tmp_dir = os.path.join(dataset_dir, 'flower_photos') tf.gfile.DeleteRecursively(tmp_dir) def _dataset_exists(dataset_dir): for split_name in ['train', 'validation']: for shard_id in range(_NUM_SHARDS): output_filename = _get_dataset_filename( dataset_dir, split_name, shard_id) if not tf.gfile.Exists(output_filename): return False return True def run(dataset_dir): """Runs the download and conversion operation. Args: dataset_dir: The dataset directory where the dataset is stored. """ if not tf.gfile.Exists(dataset_dir): tf.gfile.MakeDirs(dataset_dir) if _dataset_exists(dataset_dir): print('Dataset files already exist. 
Exiting without re-creating them.') return # dataset_utils.download_and_uncompress_tarball(_DATA_URL, dataset_dir) 註釋此句 photo_filenames, class_names = _get_filenames_and_classes(dataset_dir) class_names_to_ids = dict(zip(class_names, range(len(class_names)))) # Divide into train and test: random.seed(_RANDOM_SEED) random.shuffle(photo_filenames) training_filenames = photo_filenames[_NUM_VALIDATION:] validation_filenames = photo_filenames[:_NUM_VALIDATION] # First, convert the training and validation sets. _convert_dataset('train', training_filenames, class_names_to_ids, dataset_dir) _convert_dataset('validation', validation_filenames, class_names_to_ids, dataset_dir) # Finally, write the labels file: labels_to_class_names = dict(zip(range(len(class_names)), class_names)) dataset_utils.write_label_file(labels_to_class_names, dataset_dir) # _clean_up_temporary_files(dataset_dir) 註釋此句 print('\nFinished converting the Flowers dataset!')
python download_and_convert_data.py \
  --dataset_name=flowers \
  --dataset_dir=/media/han/code/data/
記錄轉換輸出中訓練集和驗證集的圖片數量(原文截圖中以紅框標出),下一步修改 flowers.py 的 SPLITS_TO_SIZES 時要用到。
6.修改flowers資料來源,slim/datasets/flowers.py 2處
#coding=utf-8 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Provides data for the flowers dataset. The dataset scripts used to create the dataset can be found at: tensorflow/models/slim/datasets/download_and_convert_flowers.py """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import tensorflow as tf from datasets import dataset_utils slim = tf.contrib.slim _FILE_PATTERN = 'flowers_%s_*.tfrecord' SPLITS_TO_SIZES = {'train': 1400, 'validation': 180} #修改訓練集和驗證集圖片數量 _NUM_CLASSES = 15 #修改標籤數量 _ITEMS_TO_DESCRIPTIONS = { 'image': 'A color image of varying size.', 'label': 'A single integer between 0 and 4', } def get_split(split_name, dataset_dir, file_pattern=None, reader=None): """Gets a dataset tuple with instructions for reading flowers. Args: split_name: A train/validation split name. dataset_dir: The base directory of the dataset sources. file_pattern: The file pattern to use when matching the dataset sources. It is assumed that the pattern contains a '%s' string so that the split name can be inserted. reader: The TensorFlow reader type. Returns: A `Dataset` namedtuple. Raises: ValueError: if `split_name` is not a valid train/validation split. """ if split_name not in SPLITS_TO_SIZES: raise ValueError('split name %s was not recognized.' 
% split_name) if not file_pattern: file_pattern = _FILE_PATTERN file_pattern = os.path.join(dataset_dir, file_pattern % split_name) # Allowing None in the signature so that dataset_factory can use the default. if reader is None: reader = tf.TFRecordReader keys_to_features = { 'image/encoded': tf.FixedLenFeature((), tf.string, default_value=''), 'image/format': tf.FixedLenFeature((), tf.string, default_value='png'), 'image/class/label': tf.FixedLenFeature( [], tf.int64, default_value=tf.zeros([], dtype=tf.int64)), } items_to_handlers = { 'image': slim.tfexample_decoder.Image(), 'label': slim.tfexample_decoder.Tensor('image/class/label'), }
如果標籤是中文,修改slim/datasets/dataset_utils.py
# -*- coding: utf-8 -*-
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains utilities for downloading and converting datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import sys
if sys.version_info[0] < 3:
  # Python 2 only: make implicit str<->unicode conversions use UTF-8 so
  # Chinese class names survive. `reload` and `sys.setdefaultencoding` do not
  # exist on Python 3 (where str is already Unicode), so the unguarded calls
  # raised NameError there.
  reload(sys)  # noqa: F821 -- Python 2 builtin
  sys.setdefaultencoding("utf-8")
import tarfile
from six.moves import urllib
import tensorflow as tf
# Default file name, inside the dataset directory, of the "<label>:<class name>"
# map written by write_label_file() and parsed by read_label_file().
LABELS_FILENAME = 'labels.txt'
def int64_feature(values):
  """Wrap one or more integers in a `tf.train.Feature` (int64 list).

  Args:
    values: A scalar or a list/tuple of values; a scalar is wrapped in a
      one-element list first.

  Returns:
    a TF-Feature.
  """
  values = values if isinstance(values, (tuple, list)) else [values]
  int64_list = tf.train.Int64List(value=values)
  return tf.train.Feature(int64_list=int64_list)
def bytes_feature(values):
  """Wrap a single byte string in a `tf.train.Feature` (bytes list).

  Args:
    values: A string. It is always wrapped as a one-element list, so pass a
      single value rather than a list.

  Returns:
    a TF-Feature.
  """
  bytes_list = tf.train.BytesList(value=[values])
  return tf.train.Feature(bytes_list=bytes_list)
def image_to_tfexample(image_data, image_format, height, width, class_id):
  """Pack one encoded image and its metadata into a `tf.train.Example`.

  Args:
    image_data: the encoded image bytes.
    image_format: format tag stored alongside the bytes (e.g. b'jpg').
    height: image height.
    width: image width.
    class_id: integer class label.

  Returns:
    A `tf.train.Example` with the features image/encoded, image/format,
    image/class/label, image/height and image/width.
  """
  feature_map = {
      'image/encoded': bytes_feature(image_data),
      'image/format': bytes_feature(image_format),
      'image/class/label': int64_feature(class_id),
      'image/height': int64_feature(height),
      'image/width': int64_feature(width),
  }
  return tf.train.Example(features=tf.train.Features(feature=feature_map))
def download_and_uncompress_tarball(tarball_url, dataset_dir):
  """Downloads `tarball_url` (unless already present) and extracts it.

  The archive is saved as <dataset_dir>/<last URL path component>. If that
  file already exists, the download is skipped and the existing archive is
  extracted.

  Args:
    tarball_url: The URL of a gzip-compressed tarball file.
    dataset_dir: The directory where the archive is stored and extracted.
  """
  filename = tarball_url.split('/')[-1]
  filepath = os.path.join(dataset_dir, filename)

  def _progress(count, block_size, total_size):
    sys.stdout.write('\r>> Downloading %s %.1f%%' % (
        filename, float(count * block_size) / float(total_size) * 100.0))
    sys.stdout.flush()

  # The existence check previously tested an undefined name (`file_path`)
  # and raised NameError on every call.
  if not os.path.exists(filepath):
    filepath, _ = urllib.request.urlretrieve(tarball_url, filepath, _progress)
    print()
    statinfo = os.stat(filepath)
    print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
  # Context manager so the archive handle is closed after extraction.
  with tarfile.open(filepath, 'r:gz') as tar:
    tar.extractall(dataset_dir)
def write_label_file(labels_to_class_names, dataset_dir,
                     filename=LABELS_FILENAME):
  """Write the label map as one "<label>:<class name>" line per entry.

  Args:
    labels_to_class_names: A map of (integer) labels to class names.
    dataset_dir: The directory in which the labels file should be written.
    filename: The filename where the class names are written.
  """
  labels_filename = os.path.join(dataset_dir, filename)
  with tf.gfile.Open(labels_filename, 'w') as f:
    for label, class_name in labels_to_class_names.items():
      f.write('%d:%s\n' % (label, class_name))
def has_labels(dataset_dir, filename=LABELS_FILENAME):
  """Tell whether `dataset_dir` already contains the label-map file.

  Args:
    dataset_dir: The directory in which the labels file is looked up.
    filename: The filename where the class names are written.

  Returns:
    `True` if the labels file exists and `False` otherwise.
  """
  labels_path = os.path.join(dataset_dir, filename)
  return tf.gfile.Exists(labels_path)
def read_label_file(dataset_dir, filename=LABELS_FILENAME):
  """Reads the labels file and returns a mapping from ID to class name.

  Each non-empty line has the form "<integer label>:<class name>"; only the
  first ':' separates the two parts.

  Args:
    dataset_dir: The directory in which the labels file is found.
    filename: The filename where the class names are written.

  Returns:
    A map from a label (integer) to class name.

  Raises:
    ValueError: if a non-empty line has no ':' or a non-integer label.
  """
  labels_filename = os.path.join(dataset_dir, filename)
  with tf.gfile.Open(labels_filename, 'rb') as f:
    # Decode explicitly as UTF-8: class names may be Chinese, and a bare
    # .decode() uses the interpreter's default codec (ASCII on Python 2
    # unless sys.setdefaultencoding has been patched).
    text = f.read().decode('utf-8')

  labels_to_class_names = {}
  # splitlines() also drops '\r' from files saved with Windows line endings.
  for line in text.splitlines():
    if not line:
      continue
    label, class_name = line.split(':', 1)
    labels_to_class_names[int(label)] = class_name
  return labels_to_class_names
7.開始訓練
slim下執行下面指令碼訓練資料
# Comments cannot follow a trailing "\": "\ #..." escapes the space instead of
# the newline and ends the command early. Paths used below:
#   --dataset_dir      directory holding the TFRecord files
#   --checkpoint_path  the downloaded inception_v4.ckpt
#   --train_dir        where training checkpoints and summaries are written
python -u train_image_classifier.py \
  --dataset_name=flowers \
  --dataset_dir=/media/han/code/data \
  --checkpoint_path=/media/Work/inception_v4.ckpt \
  --model_name=inception_v4 \
  --checkpoint_exclude_scopes=InceptionV4/Logits,InceptionV4/AuxLogits/Aux_logits \
  --trainable_scopes=InceptionV4/Logits,InceptionV4/AuxLogits/Aux_logits \
  --train_dir=/media/han/code/my_train \
  --learning_rate=0.001 \
  --learning_rate_decay_factor=0.76 \
  --num_epochs_per_decay=50 \
  --moving_average_decay=0.9999 \
  --optimizer=adam \
  --ignore_missing_vars=True \
  --batch_size=32
開始訓練,生成ckpt檔案
python -u eval_image_classifier.py \
  --model_name=inception_v4 \
  --checkpoint_path=/media/han/code/my_train \
  --dataset_name=flowers \
  --dataset_dir=/media/han/code/data \
  --dataset_split_name=train \
  --num_examples=1328 \
  --batch_size=32 \
  --eval_dir=/media/han/code/my_eval
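8.驗證:執行eval_validation.sh(內容即上面的eval_image_classifier.py命令)。注意上面的命令評估的是訓練集(--dataset_split_name=train);要評估驗證集,需改為validation,並把--num_examples設為驗證集圖片數。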
9.tensorboard視覺化
tensorboard --logdir="/media/han/code/my_train"
10.模型匯出
修改slim/export_inference_graph.py
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Saves out a GraphDef containing the architecture of the model.
To use it, run something like this, with a model name defined by slim:
bazel build tensorflow_models/slim:export_inference_graph
bazel-bin/tensorflow_models/slim/export_inference_graph \
--model_name=inception_v3 --output_file=/tmp/inception_v3_inf_graph.pb
If you then want to use the resulting model with your own or pretrained
checkpoints as part of a mobile model, you can run freeze_graph to get a graph
def with the variables inlined as constants using:
bazel build tensorflow/python/tools:freeze_graph
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=/tmp/inception_v3_inf_graph.pb \
--input_checkpoint=/tmp/checkpoints/inception_v3.ckpt \
--input_binary=true --output_graph=/tmp/frozen_inception_v3.pb \
--output_node_names=InceptionV3/Predictions/Reshape_1
The output node names will vary depending on the model, but you can inspect and
estimate them using the summarize_graph tool:
bazel build tensorflow/tools/graph_transforms:summarize_graph
bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \
--in_graph=/tmp/inception_v3_inf_graph.pb
To run the resulting graph in C++, you can look at the label_image sample code:
bazel build tensorflow/examples/label_image:label_image
bazel-bin/tensorflow/examples/label_image/label_image \
--image=${HOME}/Pictures/flowers.jpg \
--input_layer=input \
--output_layer=InceptionV3/Predictions/Reshape_1 \
--graph=/tmp/frozen_inception_v3.pb \
--labels=/tmp/imagenet_slim_labels.txt \
--input_mean=0 \
--input_std=255 \
--logtostderr
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.platform import gfile
from datasets import dataset_factory
from preprocessing import preprocessing_factory
from nets import nets_factory
slim = tf.contrib.slim

# Short alias for the flag-definition helpers used below.
_flags = tf.app.flags

_flags.DEFINE_string(
    'model_name', 'inception_v3', 'The name of the architecture to save.')

_flags.DEFINE_boolean(
    'is_training', False,
    'Whether to save out a training-focused version of the model.')

_flags.DEFINE_integer(
    'default_image_size', 224,
    'The image size to use if the model does not define it.')

_flags.DEFINE_string(
    'dataset_name', 'imagenet',
    'The name of the dataset to use with the model.')

_flags.DEFINE_integer(
    'labels_offset', 0,
    'An offset for the labels in the dataset. This flag is primarily used to '
    'evaluate the VGG and ResNet architectures which do not use a background '
    'class for the ImageNet dataset.')

_flags.DEFINE_string(
    'output_file', '', 'Where to save the resulting file to.')

_flags.DEFINE_string(
    'dataset_dir', '', 'Directory to save intermediate dataset files to')

FLAGS = _flags.FLAGS
def main(_):
  """Build the inference graph for --model_name and write its GraphDef.

  Unlike the stock exporter (float image placeholder), the graph's input is a
  string placeholder named 'input' that takes encoded JPEG bytes; they are
  decoded and run through the model's eval-mode preprocessing inside the
  graph.

  Raises:
    ValueError: if --output_file is not given.
  """
  if not FLAGS.output_file:
    raise ValueError('You must supply the path to save to with --output_file')
  tf.logging.set_verbosity(tf.logging.INFO)

  with tf.Graph().as_default() as graph:
    # The dataset is only consulted for its number of classes.
    dataset = dataset_factory.get_dataset(
        FLAGS.dataset_name, 'validation', FLAGS.dataset_dir)
    # Preprocessing is looked up by model name, always in eval mode.
    image_preprocessing_fn = preprocessing_factory.get_preprocessing(
        FLAGS.model_name, is_training=False)
    network_fn = nets_factory.get_network_fn(
        FLAGS.model_name,
        num_classes=(dataset.num_classes - FLAGS.labels_offset),
        is_training=FLAGS.is_training)
    image_size = getattr(network_fn, 'default_image_size',
                         FLAGS.default_image_size)

    jpeg_bytes = tf.placeholder(name='input', dtype=tf.string)
    decoded = tf.image.decode_jpeg(jpeg_bytes, channels=3)
    preprocessed = image_preprocessing_fn(decoded, image_size, image_size)
    # Add the batch dimension expected by the network (batch of one).
    network_fn(tf.expand_dims(preprocessed, 0))

    graph_def = graph.as_graph_def()
    with gfile.GFile(FLAGS.output_file, 'wb') as f:
      f.write(graph_def.SerializeToString())
if __name__ == '__main__':
  # Parse the flags defined above, then invoke main() (passed explicitly).
  tf.app.run(main=main)
執行export.sh生成my_inception_v4.pb
python -u export_inference_graph.py \
  --dataset_name=flowers \
  --dataset_dir=/media/han/code/data/ \
  --model_name=inception_v4 \
  --output_file=./my_inception_v4.pb
執行freeze.sh生成my_inception_v4_freeze.pb和my_inception_v4_freeze.label
# Run freeze_graph as a module instead of via a hard-coded python2.7
# dist-packages path. The flag is --output_node_names (plural).
# Replace model.ckpt-1835 with the latest checkpoint step in my_train.
python -u -m tensorflow.python.tools.freeze_graph \
  --input_graph=my_inception_v4.pb \
  --input_checkpoint=/media/han/code/my_train/model.ckpt-1835 \
  --output_graph=./my_inception_v4_freeze.pb \
  --input_binary=True \
  --output_node_names=InceptionV4/Logits/Predictions
cp /media/han/code/data/labels.txt ./my_inception_v4_freeze.label
11.WEB
server.py:用Flask提供網頁,上傳圖片後顯示識別結果
# coding=utf-8
import os
import sys
if sys.version_info[0] < 3:
    # Python 2 only: make implicit str<->unicode conversions use UTF-8 (the
    # labels and page text are Chinese). `reload` / `setdefaultencoding` do
    # not exist on Python 3, where the unguarded calls raised NameError.
    reload(sys)  # noqa: F821 -- Python 2 builtin
    sys.setdefaultencoding("utf-8")
import time
from flask import request, send_from_directory
from flask import Flask, request, redirect, url_for
import uuid
import tensorflow as tf
from classify_image import run_inference_on_image
# Upload file extensions accepted by allowed_files() (defined once; it used to
# be assigned twice with the same value).
ALLOWED_EXTENSIONS = set(['jpg', 'JPG', 'jpeg', 'JPEG', 'png'])

FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_string('model_dir', '', """Path to graph_def pb, """)
tf.app.flags.DEFINE_string('model_name', 'my_inception_v4_freeze.pb', '')
tf.app.flags.DEFINE_string('label_file', 'my_inception_v4_freeze.label', '')
tf.app.flags.DEFINE_string('upload_folder', '/tmp/', '')
tf.app.flags.DEFINE_integer('num_top_predictions', 5,
                            """Display this many predictions.""")
# Integer default (was the string '5001'). The old help text claimed a
# default of port 80, which this flag never had.
tf.app.flags.DEFINE_integer('port', 5001,
                            'Port the server listens on (default 5001).')
tf.app.flags.DEFINE_boolean('debug', False, '')

UPLOAD_FOLDER = FLAGS.upload_folder

app = Flask(__name__)
# Serve uploaded images from UPLOAD_FOLDER; inference() links them as
# /static/<name>. Uses the public `static_folder` property instead of the
# private `_static_folder` attribute.
app.static_folder = UPLOAD_FOLDER
def allowed_files(filename, allowed_extensions=None):
    """Return True if ``filename`` has an extension in the allowed set.

    The comparison is case-insensitive. The old exact match accepted "JPG"
    but rejected "PNG" because only some upper-case spellings were listed.

    Args:
        filename: client-supplied upload name; empty or None is rejected.
        allowed_extensions: extensions to accept, without the dot. Defaults
            to the module-level ALLOWED_EXTENSIONS.
    """
    if not filename:
        return False
    if allowed_extensions is None:
        allowed_extensions = ALLOWED_EXTENSIONS
    allowed = set(ext.lower() for ext in allowed_extensions)
    _, dot, ext = filename.rpartition('.')
    return bool(dot) and ext.lower() in allowed
def rename_filename(old_file_name):
    """Return a fresh random file name that keeps ``old_file_name``'s extension.

    Uploads are stored under this generated name rather than the client's.
    uuid4 is used instead of uuid1: uuid1 embeds the host's MAC address and a
    timestamp, and the stored name is sent back to the client in the image
    URL of the result page.
    """
    _, ext = os.path.splitext(os.path.basename(old_file_name))
    return str(uuid.uuid4()) + ext
def inference(file_name):
    """Classify the image at ``file_name`` and render the result as HTML.

    Returns an <img> tag pointing at /static/<basename of file_name>, followed
    by one "name (score:...)" line per (node_id, name) pair returned by
    run_inference_on_image. Returns "" if that call raises; the exception is
    printed.
    """
    try:
        predictions, top_k, top_names = run_inference_on_image(
            file_name, model_file=FLAGS.model_name)
        print(predictions)
    except Exception as ex:
        print(ex)
        return ""

    image_url = '/static/%s' % os.path.basename(file_name)
    header_html = '<img src="%s"></img><p>' % image_url
    score_lines = ['%s (score:%.5f)<BR>' % (human_name, predictions[node_id])
                   for node_id, human_name in zip(top_k, top_names)]
    return header_html + ''.join(score_lines) + '<BR>'
@app.route("/", methods=['GET', 'POST'])
def root():
    """Show the upload form; on POST, save the image and append predictions.

    Returns:
        The form HTML, followed by the HTML produced by inference() when an
        allowed image file was uploaded.
    """
    result = """
<!doctype html>
<title>臨時測試用</title>
<h1>來喂一張照片吧</h1>
<form action="" method=post enctype=multipart/form-data>
<p><input type=file name=file value='選擇圖片'>
<input type=submit value='上傳'>
</form>
<p>%s</p>
""" % "<br>"
    if request.method != 'POST':
        return result

    # .get() instead of ['file']: a POST without a file part now simply
    # shows the form again instead of failing with a KeyError.
    upload = request.files.get('file')
    if not upload or not allowed_files(upload.filename):
        return result

    # --upload_folder (e.g. /tmp/upload) may not exist yet, so create it on
    # first use. The app runs threaded, so tolerate a concurrent creator.
    if not os.path.isdir(UPLOAD_FOLDER):
        try:
            os.makedirs(UPLOAD_FOLDER)
        except OSError:
            if not os.path.isdir(UPLOAD_FOLDER):
                raise

    file_path = os.path.join(UPLOAD_FOLDER, rename_filename(upload.filename))
    upload.save(file_path)
    print('file saved to %s' % file_path)
    return result + inference(file_path)
if __name__ == "__main__":
    # Listen on all interfaces; threaded=True lets Flask handle requests
    # concurrently.
    port = FLAGS.port
    print('listening on port %d' % port)
    app.run(host='0.0.0.0', port=port, debug=FLAGS.debug, threaded=True)
執行server.sh
python -u server.py \
  --upload_folder=/tmp/upload \
  --label_file=my_inception_v4_freeze.label \
  --model_name=my_inception_v4_freeze.pb
注:上述預測在幾乎未經訓練下給出
還有以下可改進之處:
1.改進資料生成形式,無需每次修改程式碼
2.改進資料來源形式,不使用TFRECORD形式
3.改進訓練部分,對LR進行Exponential Decay
4.改進訓練部分,使各層可以使用不同的LR進行訓練
5.改進驗證部分,使其一次執行,連續驗證
6.改進驗證部分,使其不佔用全部顯存(GPU記憶體)
7.改進預測部分,使其可以對目錄進行預測
8.改進Server,使其無需每次重新建立計算圖