非常精簡的Mnist分類,基於tensorflow框架
阿新 • • 發佈:2018-11-12
一、介紹
基於tensorflow框架實現的Mnist資料分類。程式碼主要包括網路結構的搭建,訓練超引數的匯入和儲存,損失函式地繪製等。不足之處是在網路結尾沒用使用softmax函式,而直接使用了tanh輸出了分類結果。下面請看程式碼的詳細介紹
二、程式碼
- 匯入必要的包檔案,需要的包我直接通過pycharm匯入的,能匯入的原因是採用了anaconda3底下的python.exe,新建工程的時候,從外部匯入
# 需要使用到的包檔案 import tensorflow as tf import numpy as np import matplotlib.pyplot as plt import seaborn as sns from tensorflow.contrib.learn.python.learn.datasets.mnist import read_data_sets import argparse import os # 加上這一句能夠使Plot繪製出來的圖更精美 sns.set_style("whitegrid")
- 訓練引數設定,詳細介紹請看程式碼註釋,主要採用了argparse,該模組的好處是直接可以在執行時修改引數,比如:python main.py --data_dir= "**"
parser = argparse.ArgumentParser(description="Network for image classification") parser.add_argument('--data_dir', default='tem/data', help='Directory for training data') # Mnist資料集存放位置 parser.add_argument('--result_dir', default='tem/result') # 訓練結果的存放 parser.add_argument('--model_dir', default='model/', help='the place of saving networks parameters') #訓練引數的存放地址 parser.add_argument('--batch_size', default=32) parser.add_argument('--print_loss', default=10) # 每隔10次迭代列印損失值 parser.add_argument('--plot_loss', default=100) # 每隔100次迭代繪製損失函式曲線 parser.add_argument('--learning_rate', default=0.001, type=float) # 學習率,不易設定過大 parser.add_argument('--n_iterations', default=10000, type=int) # 迭代次數 args = parser.parse_args() # 將--*的*傳遞給arg,呼叫時直接使用args.data_dir這樣的結構
- 網路結構搭建
w_init = tf.random_normal_initializer(stddev=0.01) # 權重w初始化,標準差為0.01,平均值0 def network(x): # 啟用函式都為relu,除了輸出 layers1 = tf.layers.conv2d(x, 32, 3, 1, padding='same', activation=tf.nn.relu, kernel_initializer=w_init) # 32個卷積核,3x3卷積核大小,步長為1,padding為'same',即輸出大小為input/stride,向上取整 layers2 = tf.layers.conv2d(layers1, 62, 3, 1, padding='same', activation=tf.nn.relu, kernel_initializer=w_init) layers2_flatten = tf.contrib.layers.flatten(layers2) # 將layers2的輸出"磨平",降低相關維度,以供全連線層工作 layers3 = tf.layers.dense(layers2_flatten, 200, activation=tf.nn.relu, kernel_initializer=w_init) # 200為全連線層單元個數,其它的痛卷積函式類似 output = tf.layers.dense(layers3, 10, activation=tf.nn.tanh, kernel_initializer=w_init) # 使用tanh作為輸出,比sigmoid好,因為sigmoid是有0項,不利於網路訓練 return output
- 訓練網路,詳細介紹看註釋
def training():
input_x = tf.placeholder(tf.float32, [None, 28, 28, 1]) # 放置佔位矩陣
label_y = tf.placeholder(tf.float32, [None, 10])
output_y = network(input_x) # 前向傳播
loss = tf.reduce_sum(tf.square(label_y-output_y)) # 計算同便籤損失
optimizer = tf.train.AdamOptimizer(learning_rate=args.learning_rate).minimize(loss) # 使用Adam優化
init_all_v = tf.global_variables_initializer() # 張量初始化函式
sess = tf.InteractiveSession()
sess.run(init_all_v) # 實行張量初始化
saver = load_model(sess) # 匯入之前訓練過的引數,如果沒有則打印出來
mnist = read_data_sets(args.data_dir, one_hot=True) # 往指定目錄讀取Mnist資料集
print('start training')
plot_loss = [] # 損失值快取
for i in range(args.n_iterations):
batch_x, batch_y = mnist.train.next_batch(args.batch_size) # 讀取Batch_size
batch_x = batch_x.reshape([args.batch_size, 28, 28, 1]) # 維度匹配
y = np.zeros([args.batch_size, 10]) # 下面的操作是因為我讀到的標籤是6,8,9直接對應的圖片的數字,所以將這些數字向量化,以便訓練
for j in range(args.batch_size):
k = batch_y[j].astype(np.int)
y[j, k] = 1.
batch_y = y
d_loss, _ = sess.run([loss, optimizer], feed_dict={input_x:batch_x, label_y:batch_y}) # 執行
plot_loss.append(d_loss)
if i % args.print_loss == 0 and i > 0:
print('Iteration is : %d, Loss is: %f' % (i, d_loss)) # 列印損失
if i % args.plot_loss == 0 and i > 0: # 繪圖
plt.figure(figsize=(6*1.1618, 6))
plt.plot(range(len(plot_loss)), plot_loss)
plt.xlabel('iteration times')
plt.ylabel('lost')
plt.show()
if i % 500 == 0 and i > 0:
save_model(saver, sess, i)
- 模組的匯入與儲存
def save_model(saver, sess, step): # 儲存模組
saver.save(sess, os.path.join(args.model_dir, "classification"), global_step=step)
def load_model(sess): # 匯入模組
saver = tf.train.Saver()
checkpoint = tf.train.get_checkpoint_state(args.model_dir)
if checkpoint and checkpoint.model_checkpoint_path:
saver.restore(sess, checkpoint.model_checkpoint_path)
print("Successfully loaded:", checkpoint.model_checkpoint_path)
else:
print("Could not find any old weights!")
return saver
- 主函式
def main(_):
training()
if __name__ == "__main__":
tf.app.run()
從上往下黏貼就行,貼到IDE下就可以執行,還可以列印損失函式
鬼知道為什麼下降這麼快,,,