1. 程式人生 > >TensorFlow學習筆記之五——原始碼分析之最近演算法

TensorFlow學習筆記之五——原始碼分析之最近演算法

import numpy as np
import tensorflow as tf

# Import MINST data
import input_data
mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
#這裡主要是匯入資料,資料通過input_data.py已經下載到/tmp/data/目錄之下了,這裡下載資料的時候,需要提前用瀏覽器嘗試是否可以開啟
#http://yann.lecun.com/exdb/mnist/,如果打不開,下載資料階段會報錯。而且一旦資料下載中斷,需要將之前下載的未完成的資料清空,重新
#進行下載,否則會出現CRC Check錯誤。read_data_sets是input_data.py裡面的一個函式,主要是將資料解壓之後,放到對應的位置。
# In this example, we limit mnist data
Xtr, Ytr = mnist.train.next_batch(5000) #5000 for training (nn candidates)
Xte, Yte = mnist.test.next_batch(200) #200 for testing
#mnist.train.next_batch,其中train和next_batch都是在input_data.py裡定義好的資料項和函式。此處主要是取得一定數量的資料。

# Reshape images to 1D
Xtr = np.reshape(Xtr, newshape=(-1, 28*28))
Xte = np.reshape(Xte, newshape=(-1, 28*28))
#將二維的影象資料一維化,利於後面的相加操作。
# tf Graph Input
xtr = tf.placeholder("float", [None, 784])
xte = tf.placeholder("float", [784])
#設立兩個空的型別,並沒有給具體的資料。這也是為了基於這兩個型別,去實現部分的graph。

# Nearest Neighbor calculation using L1 Distance
# Calculate L1 Distance
distance = tf.reduce_sum(tf.abs(tf.add(xtr, tf.neg(xte))), reduction_indices=1)
# Predict: Get min distance index (Nearest neighbor)
pred = tf.arg_min(distance, 0)
#最近鄰居演算法,算最近的距離的鄰居,並且獲取該鄰居的下標,這裡只是基於空的型別,實現的graph,並未進行真實的計算。
accuracy = 0.
# Initializing the variables
init = tf.initialize_all_variables()
#初始化所有的變數和未分配數值的佔位符,這個過程是所有程式中必須做的,否則可能會讀出隨機數值。
# Launch the graph
with tf.Session() as sess:
    sess.run(init)

    # loop over test data
    for i in range(len(Xte)):
        # Get nearest neighbor
        nn_index = sess.run(pred, feed_dict={xtr: Xtr, xte: Xte[i,:]})
        # Get nearest neighbor class label and compare it to its true label
        print "Test", i, "Prediction:", np.argmax(Ytr[nn_index]), "True Class:", np.argmax(Yte[i])
        # Calculate accuracy
        if np.argmax(Ytr[nn_index]) == np.argmax(Yte[i]):
            accuracy += 1./len(Xte)
    print "Done!"
    print "Accuracy:", accuracy
#for迴圈迭代計算每一個測試資料的預測值,並且和真正的值進行對比,並計算精確度。該演算法比較經典的是不需要提前訓練,直接在測試階段進行識別。



相關API:

tf.reduce_sum(input_tensor, reduction_indices=None, keep_dims=False, name=None)

Computes the sum of elements across dimensions of a tensor.

Reduces input_tensor along the dimensions given in reduction_indices. Unless keep_dims is true, the rank of the tensor is reduced by 1 for each entry in reduction_indices

. If keep_dims is true, the reduced dimensions are retained with length 1.

If reduction_indices has no entries, all dimensions are reduced, and a tensor with a single element is returned.

For example:

# 'x' is [[1, 1, 1]
#         [1, 1, 1]]
tf.reduce_sum(x) ==> 6
tf.reduce_sum(x, 0) ==> [2, 2, 2]
tf.reduce_sum(x, 1
) ==> [3, 3] tf.reduce_sum(x, 1, keep_dims=True) ==> [[3], [3]] tf.reduce_sum(x, [0, 1]) ==> 6
Args:
  • input_tensor: The tensor to reduce. Should have numeric type.
  • reduction_indices: The dimensions to reduce. If None (the default), reduces all dimensions.
  • keep_dims: If true, retains reduced dimensions with length 1.
  • name: A name for the operation (optional).
Returns:

The reduced tensor.

點評:這個API主要是降維使用,在這個例子中,將測試圖片和所有圖片相加後的二維矩陣,降為每個圖片只有一個最終結果的一維矩陣。

相關推薦

TensorFlow學習筆記——原始碼分析最近演算法

import numpy as np import tensorflow as tf # Import MINST data import input_data mnist = input_data.read_data_sets("/tmp/data/", one_hot=

Spring 學習筆記)IOC零註解配置(用註解代替applicationContext.xml配置檔案)

有了這個東西開發方便很多,不用寫xml那些配置嘍。 package org.spring.exampleAOP; import org.springframework.context.annotation.ComponentScan; import org.springframework.co

python3.5《機器學習實戰》學習筆記):決策樹演算法實戰預測隱形眼鏡型別

一、使用決策樹預測隱形眼鏡型別 在上一篇文章中,我們學習了決策樹演算法,接下來,讓我們通過一個例子講解決策樹如何預測患者需要佩戴的隱形眼鏡型別。 隱形眼鏡資料集是非常著名的資料集,它包含了很多患者眼部狀況的觀察條件以及醫生推薦的隱形眼鏡型別。隱形眼鏡

Java NIO學習筆記:結合原始碼分析+Reactor模式

Java NIO和IO的主要區別 下表總結了Java NIO和IO之間的主要差別,我會更詳細地描述表中每部分的差異。 IO                           NIO 面向流                     面向緩衝 阻塞IO

google機器學習框架tensorflow學習筆記

Pandas簡介 pandas  是一種列存資料分析 API。它是用於處理和分析輸入資料的強大工具,很多機器學習框架都支援將  pandas  資料結構作為輸入。 雖然全方位介紹  pandas  API 會佔據很

Tensorflow學習筆記)——結構化模型及Skip-gram模型的實現

一、結構化模型 結構化我們的模型,可以方便我們Debug和良好的視覺化。一般我們的模型都是由以下兩步構成,第一步是構建計算圖,第二步是執行計算圖。 Assemble Graph Define placeholders for Inp

tensorflow學習筆記):TensorFlow變數共享和資料讀取

  這一節我們提及了三個內容:變數共享、執行緒和佇列和資料讀取,這些都是TensorFlow官方指導中的內容。會在程式中經常遇到所以放在一起進行敘述。前面都是再利用已有的資料進行tensorflow的學習,這一節我們要學習怎麼從檔案中讀取我們需要的各類資料。

jav學習筆記-String原始碼分析

java中用String類表示字串,是lang包裡面使用頻率很高的一個類,今天我們就來深入原始碼解析。事例和特性均基於java8版本。 基礎知識 String內部使用char[]陣列實現,是不可變類。 public final class Stri

TensorFlow學習筆記)—— MNIST —— 資料下載,讀取

MNIST資料下載 本教程的目標是展示如何下載用於手寫數字分類問題所要用到的(經典)MNIST資料集。 教程 檔案 本教程需要使用以下檔案: 檔案 目的 下載用於訓練和測試的MNIST資料集的原始碼 備註: input_data.py

tensorflow學習筆記(十): variable scope

variable scope tensorflow 為了更好的管理變數,提供了variable scope機制 官方解釋: Variable scope object to carry defaults to provide to get_variable

TensorFlow學習筆記原始碼分析(3)---- retrain.py

"""簡單呼叫Inception V3架構模型的學習在tensorboard顯示了摘要。 這個例子展示瞭如何採取一個Inception V3架構模型訓練ImageNet影象和訓練新的頂層,可以識別其他類的影象。 每個影象裡,頂層接收作為輸入的一個2048維向量。這

TensorFlow學習筆記原始碼分析(1)----最近演算法nearest_neighbor

import numpy as np import tensorflow as tf # Import MINST data import input_data mnist = input_data.read_data_sets("/tmp/data/", one_hot

Unity3DMecanim動畫系統學習筆記):Animator Controller

浮點 key 發現 菜單 融合 stat mon 好的 project 簡介 Animator Controller在Unity中是作為一種單獨的配置文件存在的文件類型,其後綴為controller,Animator Controller包含了以下幾種功能: 可以對

Tensorflow學習筆記池化

Tensorflow學習筆記之池化 在深度學習網路中,經常會遇到池化操作,並且往往是在卷積之後,池化操作的意義是降低卷積層輸出特徵向量的維度,並且通過不同的池化方法使不同維度的卷積層輸出結果得到相同維度的特徵向量結果。 1、一般池化 池化過程作用於不重疊區域 我們定義池化視窗的大小為s

Tensorflow學習筆記tf.nn.relu

Tensorflow學習筆記之tf.nn.relu 關於Tensorflow的學習筆記大部分為其他部落格或者書籍轉載,只為督促自己學習。 線性整流函式(Rectified Linear Unit,ReLU),又稱修正線性單元。其定義如下圖,在橫座標的右側,ReLU函式為線性函式。在橫座標

Tensorflow學習筆記tf.layers.conv2d

Tensorflow學習筆記 關於Tensorflow的學習筆記大部分為其他部落格或者書籍轉載,只為督促自己學習。 conv2d(inputs, filters, kernel_size, strides=(1, 1), padding='valid', d

嵌入式核心及驅動開發學習筆記) 編寫字元驅動步驟總結

1,實現模組載入和解除安裝入口函式         module_init(chr_dev_init);         module_exit(chr_dev_exit);

TensorFlow學習筆記--[tf.clip_by_global_norm,tf.clip_by_value,tf.clip_by_norm等的區別]

以下這些函式可以用於解決梯度消失或梯度爆炸問題上。 1. tf.clip_by_value tf.clip_by_value( t, clip_value_min, clip_value_max, name=None ) 輸入一個張量t,把t中的每一個元素的值都

springMVC原始碼學習addFlashAttribute原始碼分析

本文主要從falshMap初始化,存,取,消毀來進行原始碼分析,springmvc版本4.3.18。關於使用及驗證請參考另一篇https://www.cnblogs.com/pu20065226/p/10032048.html 1.初始化和呼叫,首先是入springMVC 入口webmvc包中org.spr

go 原始碼學習---Tail 原始碼分析

已經有兩個月沒有寫部落格了,也有好幾個月沒有看go相關的內容了,由於工作原因最近在做java以及大資料相關的內容,導致最近工作較忙,部落格停止了更新,正好想撿起之前go的東西,所以找了一個原始碼學習 這個也是之前用go寫日誌收集的時候用到的一個包 :github.com/hpcloud/tail, 這次就學