A First Look at TensorFlow: Learning the MNIST Dataset
阿新 • Published: 2019-01-09
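MNIST is a dataset of handwritten digits 0 to 9. It is commonly used as a test case in machine learning, the equivalent of a Hello World program.
This program first downloads the data from the web, then trains with gradient descent to minimize the cross entropy over the samples, and finally reports the accuracy of the trained model.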
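Definition of cross entropy:
$$H_{y'}(y) = -\sum_i y'_i \log(y_i)$$
Here y is the probability distribution we predict and y' is the actual distribution (the one-hot label).
This metric measures the gap between what the model has learned and the true labels.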
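To make the definition concrete, here is a minimal sketch in plain Python (standard library only, not TensorFlow) that evaluates minus the sum of y'_i log(y_i) for a single example, the same quantity the program computes with `-tf.reduce_sum(y_*tf.log(y))`. The function name `cross_entropy` and the sample distributions are illustrative assumptions, not part of the original program.

```python
import math

def cross_entropy(y_true, y_pred):
    """Return -sum_i y'_i * log(y_i) for two discrete distributions."""
    # Terms with y'_i == 0 contribute nothing, so skip them (this also avoids log(0)).
    return -sum(t * math.log(p) for t, p in zip(y_true, y_pred) if t > 0)

# One-hot label for the digit 2 (hypothetical example values).
label = [0.0, 0.0, 1.0] + [0.0] * 7

# A confident, correct prediction: 0.91 on class 2, 0.01 on each other class.
confident = [0.01] * 10
confident[2] = 0.91

# A prediction that knows nothing: every digit equally likely.
uniform = [0.1] * 10

print(cross_entropy(label, confident))  # -log(0.91), about 0.094
print(cross_entropy(label, uniform))    # -log(0.1),  about 2.303
```

The better the prediction matches the label, the closer the value gets to zero, which is why training tries to minimize it.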
    import tensorflow.examples.tutorials.mnist.input_data as input_data
    import tensorflow as tf
    # initialize: download MNIST if needed and load it with one-hot labels
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    x = tf.placeholder(tf.float32, [None, 784])   # flattened 28x28 input images
    W = tf.Variable(tf.zeros([784, 10]))          # weights, learned during training
    b = tf.Variable(tf.zeros([10]))               # biases, learned during training
    y = tf.nn.softmax(tf.matmul(x, W) + b)        # predicted distribution
    y_ = tf.placeholder("float", [None, 10])      # true one-hot labels
    cross_entropy = -tf.reduce_sum(y_ * tf.log(y))
    train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy)

    # train
    # initialize_all_variables is deprecated; tf.global_variables_initializer() replaces it
    init = tf.initialize_all_variables()
    sess = tf.Session()
    sess.run(init)
    for i in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

    # predict: compare the most likely class with the true class on the test set
    correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
    print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
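The program first sets up the model with placeholders and variables.
A placeholder is generally used to feed data in, while a Variable generally holds a quantity that is updated during learning.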
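As a rough illustration of what the graph computes once data is fed in, the sketch below reimplements the forward pass `softmax(x W + b)` in plain Python for one input vector. The helper names `softmax` and `predict` and the pixel values are assumptions made for this example; only the shapes (784 inputs, 10 classes) and the zero initialization come from the program.

```python
import math

def softmax(logits):
    """Turn a list of scores into probabilities that sum to 1."""
    m = max(logits)                           # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def predict(x, W, b):
    """Compute softmax(x W + b) for a single input vector x."""
    logits = [sum(x[i] * W[i][k] for i in range(len(x))) + b[k]
              for k in range(len(b))]
    return softmax(logits)

# Zero-initialized parameters, mirroring tf.zeros([784, 10]) and tf.zeros([10]).
W = [[0.0] * 10 for _ in range(784)]
b = [0.0] * 10

# Hypothetical pixel values standing in for one flattened image.
x = [0.5] * 784

probs = predict(x, W, b)
print(probs[0])  # 0.1
```

With all parameters at zero every digit gets probability 0.1, so before training the model is no better than guessing. Here `x` plays the role of the placeholder (data supplied from outside), while `W` and `b` play the role of the Variables (values that training will change).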
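Next, a Session is started to run the training. In each of the 1000 iterations, the program draws a random batch of 100 images from the training set and performs one gradient descent step on it.
The run reached an accuracy of 91.49%.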
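What a single training step does can be sketched by hand. The plain-Python code below is an illustration, not TensorFlow's implementation: it applies the standard gradient of softmax plus cross entropy (predicted probability minus label, per class) with the same learning rate of 0.01 and the same summed loss as the program. The toy data with 2 features and 3 classes, and the names `forward`, `batch_loss` and `sgd_step`, are assumptions chosen to keep the example small; unlike the real program, it reuses one batch instead of drawing a new one each step.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    s = sum(exps)
    return [e / s for e in exps]

def forward(x, W, b):
    """Predicted class probabilities softmax(x W + b) for one example."""
    return softmax([sum(xi * W[i][k] for i, xi in enumerate(x)) + b[k]
                    for k in range(len(b))])

def batch_loss(batch_x, batch_y, W, b):
    """Cross entropy summed over the batch, like -tf.reduce_sum(y_ * tf.log(y))."""
    return -sum(t * math.log(p)
                for x, y_true in zip(batch_x, batch_y)
                for t, p in zip(y_true, forward(x, W, b)) if t > 0)

def sgd_step(batch_x, batch_y, W, b, lr=0.01):
    """Apply one in-place gradient descent update to W and b."""
    n_in, n_out = len(W), len(b)
    grad_W = [[0.0] * n_out for _ in range(n_in)]
    grad_b = [0.0] * n_out
    for x, y_true in zip(batch_x, batch_y):
        y = forward(x, W, b)
        for k in range(n_out):
            delta = y[k] - y_true[k]      # d(loss)/d(logit_k) for softmax + cross entropy
            grad_b[k] += delta
            for i in range(n_in):
                grad_W[i][k] += x[i] * delta
    for k in range(n_out):
        b[k] -= lr * grad_b[k]
        for i in range(n_in):
            W[i][k] -= lr * grad_W[i][k]

# Toy batch (hypothetical): 2 features and 3 classes instead of 784 pixels and 10 digits.
batch_x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
batch_y = [[1, 0, 0], [0, 1, 0], [0, 0, 1]]
W = [[0.0] * 3 for _ in range(2)]
b = [0.0] * 3

before = batch_loss(batch_x, batch_y, W, b)
for _ in range(1000):
    sgd_step(batch_x, batch_y, W, b)
after = batch_loss(batch_x, batch_y, W, b)
print(before, after)  # the loss after training is lower than before
```

In the real program `GradientDescentOptimizer(0.01).minimize(cross_entropy)` derives these gradients automatically, so none of this bookkeeping has to be written by hand.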
    /home/jcole/PycharmProjects/MNIST-HelloWorld/venv/bin/python /home/jcole/PycharmProjects/MNIST-HelloWorld/venv/MNIST.py
    WARNING:tensorflow:From /home/jcole/PycharmProjects/MNIST-HelloWorld/venv/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:198: retry (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
    Instructions for updating:
    Use the retry module or similar alternatives.
    WARNING:tensorflow:From /home/jcole/PycharmProjects/MNIST-HelloWorld/venv/MNIST.py:4: read_data_sets (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
    Instructions for updating:
    Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
    WARNING:tensorflow:From /home/jcole/PycharmProjects/MNIST-HelloWorld/venv/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:260: maybe_download (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
    Instructions for updating:
    Please write your own downloading logic.
    WARNING:tensorflow:From /home/jcole/PycharmProjects/MNIST-HelloWorld/venv/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/base.py:219: retry.<locals>.wrap.<locals>.wrapped_fn (from tensorflow.contrib.learn.python.learn.datasets.base) is deprecated and will be removed in a future version.
    Instructions for updating:
    Please use urllib or similar directly.
    Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
    Extracting MNIST_data/train-images-idx3-ubyte.gz
    WARNING:tensorflow:From /home/jcole/PycharmProjects/MNIST-HelloWorld/venv/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:262: extract_images (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
    Instructions for updating:
    Please use tf.data to implement this functionality.
    WARNING:tensorflow:From /home/jcole/PycharmProjects/MNIST-HelloWorld/venv/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:267: extract_labels (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
    Instructions for updating:
    Please use tf.data to implement this functionality.
    WARNING:tensorflow:From /home/jcole/PycharmProjects/MNIST-HelloWorld/venv/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:110: dense_to_one_hot (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
    Instructions for updating:
    Please use tf.one_hot on tensors.
    Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
    Extracting MNIST_data/train-labels-idx1-ubyte.gz
    Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
    Extracting MNIST_data/t10k-images-idx3-ubyte.gz
    Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
    WARNING:tensorflow:From /home/jcole/PycharmProjects/MNIST-HelloWorld/venv/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/datasets/mnist.py:290: DataSet.__init__ (from tensorflow.contrib.learn.python.learn.datasets.mnist) is deprecated and will be removed in a future version.
    Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
    Instructions for updating:
    Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
    WARNING:tensorflow:From /home/jcole/PycharmProjects/MNIST-HelloWorld/venv/lib/python3.5/site-packages/tensorflow/python/util/tf_should_use.py:118: initialize_all_variables (from tensorflow.python.ops.variables) is deprecated and will be removed after 2017-03-02.
    Instructions for updating:
    Use `tf.global_variables_initializer` instead.
    2018-04-23 16:27:25.917321: I tensorflow/core/platform/cpu_feature_guard.cc:140] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
    0.9149
    Process finished with exit code 0
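The 0.9149 near the end of the output is the fraction of test images whose most probable predicted digit equals the true digit. A minimal plain-Python sketch of that calculation, mirroring `tf.argmax`, `tf.equal` and `tf.reduce_mean` from the program, might look like this; the helper names and the four sample predictions are made up for illustration.

```python
def argmax(values):
    """Index of the largest entry (the first one if there is a tie)."""
    return max(range(len(values)), key=lambda i: values[i])

def accuracy(predictions, labels):
    """Fraction of rows whose predicted class matches the one-hot label."""
    hits = [argmax(p) == argmax(t) for p, t in zip(predictions, labels)]
    return sum(hits) / len(hits)

# Hypothetical model outputs for four examples with three classes each.
predictions = [[0.7, 0.2, 0.1],
               [0.1, 0.8, 0.1],
               [0.3, 0.4, 0.3],   # predicts class 1, but the label is class 0
               [0.2, 0.2, 0.6]]
labels = [[1, 0, 0],
          [0, 1, 0],
          [1, 0, 0],
          [0, 0, 1]]

print(accuracy(predictions, labels))  # 0.75
```

Three of the four toy predictions are correct, giving 0.75; the program performs the same comparison over the whole MNIST test set.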