1. 程式人生 > >Tensorflow之MNIST手寫數字識別:分類問題(1)

Tensorflow之MNIST手寫數字識別:分類問題(1)

一、MNIST資料集讀取

one hot 獨熱編碼
獨熱編碼是一種稀疏向量,其中:一個向量設為1,其他元素均設為0.獨熱編碼常用於表示擁有有限個可能值的字串或識別符號
優點:   1、將離散特徵的取值擴充套件到了歐式空間,離散特徵的某個取值就對應歐式空間的某個點
    2、機器學習演算法中,特徵之間距離的計算或相似度的常用計算方法都是基於歐式空間的
    3、將離散型特徵使用one_hot編碼,會讓特徵之間的距離計算更加合理

import tensorflow as tf
 #MNIST資料集讀取
import tensorflow.examples.tutorials.mnist.input_data as input_data
mnist 
= input_data.read_data_sets("MNIST_data/",one_hot=True)

###輸出結果###
#若不成功可手動到相關網站下載之後新增到資料夾中
#Extracting MNIST_data/train-images-idx3-ubyte.gz
#Extracting MNIST_data/train-labels-idx1-ubyte.gz
#Extracting MNIST_data/t10k-images-idx3-ubyte.gz
#Extracting MNIST_data/t10k-labels-idx1-ubyte.gz

二、瞭解MNIST手寫數字識別資料集

#瞭解MNIST手寫數字識別資料集
print('訓練集 train 數量:',mnist.train.num_examples,
      ',驗證集 validation 數量:',mnist.validation.num_examples,
      ',測試集 test 數量:',mnist.test.num_examples)

###輸出結果###
#訓練集 train 數量: 55000 ,驗證集 validation 數量: 5000 ,測試集 test 數量: 10000
print(' train images shape:
',mnist.train.images.shape, 'labels shape:',mnist.train.labels.shape) ###輸出### #train images shape: (55000, 784) labels shape: (55000, 10) #28*28=784,10分類One Hot編碼

 三、視覺化image

#視覺化image
import matplotlib.pyplot as plt

def plot_image(image):
    plt.imshow(image.reshape(28,28),cmap='binary')
    plt.show()
plot_image(mnist.train.images[1])
輸出結果:

 

 
 
 
 
#進一步瞭解reshape()
import numpy as np
int_array = np.array([i for i in range(64)])
print(int_array)
輸出結果:
[ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24
 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49
 50 51 52 53 54 55 56 57 58 59 60 61 62 63]
int_array.reshape(8,8)
輸出結果:
array([[ 0,  1,  2,  3,  4,  5,  6,  7],
       [ 8,  9, 10, 11, 12, 13, 14, 15],
       [16, 17, 18, 19, 20, 21, 22, 23],
       [24, 25, 26, 27, 28, 29, 30, 31],
       [32, 33, 34, 35, 36, 37, 38, 39],
       [40, 41, 42, 43, 44, 45, 46, 47],
       [48, 49, 50, 51, 52, 53, 54, 55],
       [56, 57, 58, 59, 60, 61, 62, 63]])
#行優先,逐列排列
int_array.reshape(4,16)
輸出結果:
array([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15],
       [16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31],
       [32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47],
       [48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63]])

 

plt.imshow(mnist.train.images[20000].reshape(14,56),cmap='binary')
plt.show()
輸出結果:

 

 

 

四、資料讀取

1.採用獨熱編碼,標籤資料內容並不是直接輸出值,而是輸出編碼

#標籤資料與獨熱編碼,
#內容並不是直接輸出值,而是輸出編碼
mnist.train.labels[1]
輸出結果:
array([ 0.,  0.,  0.,  1.,  0.,  0.,  0.,  0.,  0.,  0.])
#非one_hot編碼的標籤值
mnist_no_one_hot = input_data.read_data_sets("MNIST_data/",one_hot=False)
print(mnist_no_one_hot.train.labels[0:10])      #onr_hot = False,直接返回值
輸出結果:
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
[7 3 4 6 1 8 1 0 9 8]

2.讀取驗證集資料

#讀取驗證集資料
print('validation images:',mnist.validation.images.shape,'labels:',mnist.validation.labels.shape)    
輸出:
validation images: (5000, 784) labels: (5000, 10)

3.讀取測試機資料

#讀取測試機資料
print('tast images:',mnist.test.images.shape,'labels:',mnist.test.labels.shape)
輸出結果:
tast images: (10000, 784) labels: (10000, 10)

 

4.一次批量讀取多條資料

#一次批量讀取多條資料
batch_image_xs,batch_labels_ys = mnist.train.next_batch(batch_size=10)        #next_batch()實現內部會對資料集先做shuffle
print(mnist.train.labels[0:10])
print("\n")
print(batch_labels_ys)
輸出結果:
[[ 0.  0.  0.  1.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  1.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  1.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  1.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  1.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  1.]
 [ 0.  0.  0.  0.  1.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.  1.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  1.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  1.]]


[[ 0.  0.  0.  0.  0.  0.  1.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  1.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  1.  0.  0.]
 [ 0.  1.  0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.  0.  1.]
 [ 0.  0.  0.  1.  0.  0.  0.  0.  0.  0.]
 [ 0.  1.  0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  0.  1.  0.]
 [ 1.  0.  0.  0.  0.  0.  0.  0.  0.  0.]
 [ 0.  0.  0.  0.  0.  0.  0.  1.  0.  0.]]

5.argmax()用法

       argmax返回的是最大數的索引

import numpy as np
np.array(mnist.train.labels[1])
np.argmax(mnist.train.labels[1])     #argmax返回的是最大數的索引
#argmax詳解
arr1 = np.array([1,3,2,5,7,0])
arr2 = np.array([[1,2,3],[3,2,1],[4,7,2],[8,3,2]])
print("arr1=",arr1)
print("arr2=",arr2)

argmax_1 = tf.argmax(arr1)
argmax_20 = tf.argmax(arr2,0)      #指定第二個引數為0,按第一維(行)的元素取值,即同列的每一行取值   以行為基準,每列取最大值的下標
argmax_21 = tf.argmax(arr2,1)       #指定第二個引數為1,則第二維(列)的元素取值,即同行的每一列取值   以列為基準,每行取最大值的下標
argmax_22 = tf.argmax(arr2,-1)     #指定第二個引數為-1,則第最後維的元素取值

with tf.Session() as sess:
    print(argmax_1.eval())
    print(argmax_20.eval())
    print(argmax_21.eval())
    print(argmax_22.eval())
輸出結果:
arr1= [1 3 2 5 7 0]
arr2= [[1 2 3]
 [3 2 1]
 [4 7 2]
 [8 3 2]]
4
[3 2 0]
[2 0 1 0]
[2 0 1 0]

五、視覺化

#定義視覺化函式
import matplotlib.pyplot as plt
import numpy as np
def plot_images_labels_prediction(images,labels,prediction,index,num=10):  #引數: 圖形列表,標籤列表,預測值列表,從第index個開始顯示,預設一次顯示10幅
    fig = plt.gcf()             #獲取當前圖表,Get Current Figure
    fig.set_size_inches(10,12)    #1英寸等於2.45cm
    if num > 25 :      #最多顯示25個子圖
        num = 25
    for i in range(0,num):
        ax = plt.subplot(5,5,i+1)   #獲取當前要處理的子圖
        ax.imshow(np.reshape(images[index],(28,28)), cmap = 'binary')              #顯示第index個影象
        title = "labels="+str(np.argmax(labels[index]))              #構建該圖上要顯示的title資訊
        if len(prediction)>0:
            title += ",predict="+str(prediction[index])
            
        ax.set_title(title,fontsize=10)    #顯示圖上的title資訊
        ax.set_xticks([])           #不顯示座標軸
        ax.set_yticks([])
        index += 1
    plt.show()
#視覺化預測結果
# plot_images_labels_prediction(mnist.test.images,mnist.test.labels,prediction_result,10,10)

plot_images_labels_prediction(mnist.test.images,mnist.test.labels,prediction_result,10,25)

六、評估與應用

#評估模型
#完成訓練後,在測試集上評估模型的準確率
accu_test = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
print("Test Accuracy:",accu_test)
#完成訓練後,在驗證集上評估模型的準確率
accu_validation = sess.run(accuracy,feed_dict={x:mnist.validation.images,y:mnist.validation.labels})
print("Test Accuracy:",accu_validation)
#完成訓練後,在訓練集上評估模型的準確率
accu_train = sess.run(accuracy,feed_dict={x:mnist.train.images,y:mnist.train.labels})
print("Test Accuracy:",accu_train)
#應用模型
#在建立模型並進行訓練後,若認為準確率可以接受,則可以使用此模型進行預測
#由於pred預測結果是one_hot編碼格式,所以需要轉換成0~9數字
prediction_result = sess.run(tf.argmax(pred,1),feed_dict={x:mnist.test.images})

#檢視預測結果中的前10項
prediction_result[0:10]

七、tf.random_normal()介紹

#tf.random_normal()介紹
norm = tf.random_normal([100])    #生成100個隨機數
with tf.Session() as sess:
    norm_data = norm.eval()
print(norm_data[:10])

import matplotlib.pyplot as plt
plt.hist(norm_data)
plt.show()
輸出結果:
[-1.20503342 -0.40912333  1.02314627  0.91239542 -0.44498116  1.46095467
  1.71958613 -0.02297023 -0.04446657 -1.58943892]