
A First Classification Learning Example (Classification)

We will use a simple example to illustrate a classification learning problem.

First, we need to download a dataset.

We use the MNIST dataset, a collection of handwritten digit images. It contains 70,000 images in total: 60,000 training images and 10,000 test images.
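
If the deprecated download helper used below cannot fetch the files, the same data can also be loaded through the Keras loader bundled with recent TensorFlow releases. The following is only a minimal sketch of that alternative (the reshaping and one-hot encoding mimic what input_data.read_data_sets does; the variable names are my own):

import numpy as np
import tensorflow as tf

# Load MNIST as NumPy arrays: images are 28x28 uint8, labels are integers 0-9.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Flatten each image into a 784-dimensional float vector in [0, 1],
# matching the [None, 784] placeholder used later in this post.
x_train = x_train.reshape(-1, 784).astype(np.float32) / 255.0
x_test = x_test.reshape(-1, 784).astype(np.float32) / 255.0

# Turn the integer labels into one-hot vectors of length 10,
# matching the [None, 10] placeholder.
y_train = np.eye(10, dtype=np.float32)[y_train]
y_test = np.eye(10, dtype=np.float32)[y_test]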

For the output layer we use the softmax activation function, which is commonly used for classification problems.
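
Softmax turns a vector of raw scores into a probability distribution that sums to 1, which is why it suits classification. A tiny NumPy sketch (illustration only, not part of the tutorial code):

import numpy as np

def softmax(z):
	# Subtracting the maximum is a standard trick for numerical stability; it does not change the result.
	e = np.exp(z - np.max(z))
	return e / e.sum()

scores = np.array([2.0, 1.0, 0.1])
print(softmax(scores))  # roughly [0.659, 0.242, 0.099], and the entries sum to 1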

For the loss function we use cross-entropy (cross_entropy).
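
Cross-entropy measures how close the predicted probability distribution is to the true one-hot label: the better the match, the smaller the value, and it reaches zero when the two are identical. A small NumPy sketch of the formula used later, -sum(y * log(p)) (illustration only):

import numpy as np

def cross_entropy(y, p):
	# y is the one-hot true label, p is the predicted probability distribution.
	return -np.sum(y * np.log(p))

y_true = np.array([0.0, 1.0, 0.0])        # the true class is 1
good_pred = np.array([0.05, 0.90, 0.05])  # confident and correct
bad_pred = np.array([0.70, 0.20, 0.10])   # confident and wrong

print(cross_entropy(y_true, good_pred))   # about 0.105 (small loss)
print(cross_entropy(y_true, bad_pred))    # about 1.609 (large loss)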

For training we again use gradient descent, with the cross-entropy loss guiding the direction of the updates.
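
Gradient descent repeatedly nudges each parameter a small step against the gradient of the loss; minimize(cross_entropy) asks TensorFlow to apply exactly this update to every trainable variable. A toy sketch of the update rule w = w - learning_rate * gradient, on a made-up one-parameter loss:

# Minimize loss(w) = (w - 3)^2, whose gradient is 2 * (w - 3).
w = 0.0
learning_rate = 0.1
for step in range(20):
	grad = 2.0 * (w - 3.0)
	w = w - learning_rate * grad
print(w)  # close to 3, the minimizer of the loss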

Since the training set contains a great many images, each training step uses a batch of only 100 of them. Every 20 training steps we print the accuracy, i.e. the fraction of test images for which the model's predicted digit matches the true digit.
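
The accuracy computation in compute_accuracy below boils down to comparing the index of the largest predicted probability with the position of the 1 in the one-hot label, then averaging the matches. A minimal NumPy sketch with made-up numbers:

import numpy as np

y_pred = np.array([[0.1, 0.8, 0.1],   # predicted class 1
                   [0.6, 0.3, 0.1]])  # predicted class 0
y_true = np.array([[0.0, 1.0, 0.0],   # true class 1 -> correct
                   [0.0, 0.0, 1.0]])  # true class 2 -> wrong

correct = np.argmax(y_pred, axis=1) == np.argmax(y_true, axis=1)
print(correct.mean())  # 0.5: one of the two predictions is right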

The code is as follows:

import tensorflow as tf
import numpy as np
import os
from tensorflow.examples.tutorials.mnist import input_data

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
# Prepare the dataset (MNIST, a handwritten digit dataset)
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)


# If the dataset is not already present, this step downloads it and creates a MNIST_data folder in the directory of the .py file
# Note: depending on where you run this, downloading the data may require a proxy/VPN

def add_layer(inputs, in_size, out_size, activation_function=None):
	# Weights are initialized from a normal distribution; biases start at 0.1
	Weights = tf.Variable(tf.random_normal([in_size, out_size]))
	biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
	Wx_plus_b = tf.matmul(inputs, Weights) + biases
	if activation_function is None:
		outputs = Wx_plus_b
	# When activation_function is None there is no activation, so the layer is purely linear
	else:
		outputs = activation_function(Wx_plus_b)
	# Otherwise Wx_plus_b is passed through the activation function
	return outputs


def compute_accuracy(v_xs, v_ys):
	global prediction
	# global lets this function use the global variable prediction defined below
	y_pre = sess.run(prediction, feed_dict={xs: v_xs})
	# Feed the images v_xs through the network to get the predicted probabilities
	correct_prediction = tf.equal(tf.argmax(y_pre, 1), tf.argmax(v_ys, 1))
	# Compare the predicted class (argmax of y_pre) with the true class (argmax of v_ys)
	accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
	# Cast the booleans to floats and average them to get the accuracy
	result = sess.run(accuracy, feed_dict={xs: v_xs, ys: v_ys})
	# Run the graph to get the result, a fraction between 0 and 1
	return result


xs = tf.placeholder(tf.float32, [None, 784])
# None means the number of samples is not fixed; 784 because each image has 28 x 28 = 784 pixels
ys = tf.placeholder(tf.float32, [None, 10])
# Each image shows one of the digits 0 to 9, so there are 10 outputs

prediction = add_layer(xs, 784, 10, activation_function=tf.nn.softmax)
# Use add_layer to build the output layer: 784 input features, 10 output features, softmax activation
# The softmax activation is commonly used for classification

# For a classification model, the loss (the objective to minimize) is the cross-entropy
# Cross-entropy measures how close the prediction is to the true label; it is zero when they are identical
cross_entropy = tf.reduce_mean(-tf.reduce_sum(ys * tf.log(prediction), reduction_indices=[1]))
# (tf.log(prediction) can hit log(0) in theory; tf.nn.softmax_cross_entropy_with_logits is the more robust alternative.)
# Define the training step: gradient descent with learning rate 0.5 (usually < 1); minimize(cross_entropy) tells it to make cross_entropy smaller
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
# Create the session and initialize the network's variables
sess = tf.Session()
sess.run(tf.global_variables_initializer())

for i in range(1000):
	batch_xs, batch_ys = mnist.train.next_batch(100)
	# Each training step uses a batch of only 100 images from the training set
	sess.run(train_step, feed_dict={xs: batch_xs, ys: batch_ys})
	if i % 20 == 0:
		# Print the accuracy every 20 training steps
		print(compute_accuracy(mnist.test.images, mnist.test.labels))
# The model is trained on mnist.train and its accuracy is evaluated on mnist.test
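
Note that this code targets the TensorFlow 1.x API. On TensorFlow 2.x, tf.placeholder, tf.Session and the other graph-mode calls are only reachable through the compatibility module, and the tensorflow.examples.tutorials.mnist helper is no longer bundled, so the data would have to be loaded another way (for example with the Keras loader sketched earlier). A hedged sketch of the usual compatibility shim:

import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()  # restores graph mode so placeholders, sessions, etc. behave as in 1.x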

Running the code produces the following output:

WARNING:tensorflow:From C:/Users/1234/Desktop/test/src/test.py:9: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
WARNING:tensorflow:From C:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
Instructions for updating:
Please write your own downloading logic.
WARNING:tensorflow:From C:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data\train-images-idx3-ubyte.gz
Extracting MNIST_data\train-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.data to implement this functionality.
WARNING:tensorflow:From C:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting MNIST_data\t10k-images-idx3-ubyte.gz
Extracting MNIST_data\t10k-labels-idx1-ubyte.gz
WARNING:tensorflow:From C:\ProgramData\Anaconda3\envs\tensorflow\lib\site-packages\tensorflow\contrib\learn\python\learn\datasets\mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
0.1657
0.4756
0.6107
0.6827
0.7226
0.7455
0.7492
0.7783
0.7833
0.8
0.8049
0.8107
0.8198
0.8172
0.8242
0.8309
0.8324
0.8384
0.8416
0.8422
0.8428
0.846
0.8481
0.8528
0.8513
0.8523
0.8539
0.8557
0.8597
0.8627
0.8585
0.861
0.8628
0.8667
0.8654
0.8664
0.8687
0.8659
0.8686
0.8686
0.8714
0.8701
0.8727
0.8714
0.8753
0.8742
0.8713
0.8777
0.874
0.8754

Process finished with exit code 0

The warnings at the beginning generally do not affect the result and can be ignored. Occasionally, though, the messages will say that some of the files could not be downloaded; in that case switch to a different proxy and run the script again.

As you can see, the accuracy of our model keeps improving as training progresses.