1. 程式人生 > >【深度學習】聚焦機制DRAM(Deep Recurrent Attention Model)演算法詳解

【深度學習】聚焦機制DRAM(Deep Recurrent Attention Model)演算法詳解

Ba, Jimmy, Volodymyr Mnih, and Koray Kavukcuoglu. “Multiple object recognition with visual attention.” arXiv preprint arXiv:1412.7755 (2014).

思想

三位作者均來自於風頭正勁的Google DeepMind,三作Koray Kavukcuoglu在AlphaGo的Nature論文中榜上有名。

本文執行的任務相對簡單:從圖片中識別長度、位置未知的手寫數字串。但包含了當今神經網路的諸多熱點方向,包括:

  • 聚焦機制(Attention):每次只看輸入的一小部分,諸次移動觀察範圍。
  • 迴圈神經網路(Recurrent NN):在每一次移動和輸出之間建立記憶
  • 增強學習(Reinforcement learning):在訓練過程中,根據不可導的反饋,從當前位置產生探索性的取樣。

本文和前一篇文章中介紹的RAM(Recurrent Visual Attention Model)演算法極為相似,但是更側重數學推導。建議先閱讀這篇部落格中的解讀。
對於增強學習沒概念的同學,也可以參考這篇部落格:Torch中的增強學習層

模型

核心資料

X: 輸入影象
n: 步驟序號,共有N個步驟,每次檢視影象一小部分。
ln: 第n步檢視的影象位置。整數型別xy座標,影象中心為(0,0),影象邊緣對應的座標為系統超引數,決定搜尋粒度。
x

n: 第n步觀察到的影象內容,稱為glimpse。是以ln為中心,尺寸相同,縮放和範圍等差的影象金字塔。
這裡寫圖片描述
特別要注意的是:xn沒法對ln求導。

子網路

整個系統由若干部分組成,執行不同功能。系統的組成部件都稱為網路。

系統中變數繁多,不必急於看全圖,順序推導即可。

Glimpse網路

輸入:當前位置ln,當前影象塊xn
輸出:當前觀察的資訊gn

形式

gn=Gimage(xn|Wimage)Gloc(ln|Wloc)

GimageGloc是兩個網路,其引數為WimageWloc。分別把影象(what)和位置(where)編碼成統一維度的資訊,進行點乘。

作用:通過小範圍觀測,提取紋理和位置資訊。

條件號後面的W表示某網路引數,此後不再贅述。

Recurrent網路

輸入:當前觀察資訊gn,上一步狀態r1n1,r2n1
輸出:當前的兩個迴圈狀態r1n,r2n

形式

r1n=Rrecur(gn,r1n1|Wr1) r2n=Rrecur(r1n,r2n1|Wr2)

兩個狀態使用相同的網路Rrecur進行估計,只是輸入不同。由於存在兩層迴圈狀態,所以本文演算法稱為Deep RAM。

作用:通過小範圍觀測,更新網路迴圈狀態

Emission網路

輸入:當前第二級迴圈狀態r2n
輸出:下一步建議的觀察位置ln+1

形式

l^n+1=E(r2n|We)

注意,這個給出的l^是一個“建議”,下一步的真正位置可能圍繞這個建議有所偏差。

作用:利用系統迴圈狀態,決定觀測位置。

在RAM演算法中,這部分稱為locator。

Classification網路

輸入:當前第一級迴圈狀態r1n
輸出:類標y

形式

P(y|I)=O(r1n|Wo)

出現概率P的原因是:網路輸出是一個softmax層。
不一定每一步都有輸出,可以設定每K步輸出一個類標,即K次觀察能夠決定一個字母。

作用:從系統迴圈狀態估計分類結果。

在RAM演算法中,這部分稱為Agent。

Context網路

輸入:縮小後的原始影象Icoarse
輸出:第二級迴圈初始狀態

相關推薦

深度學習聚焦機制DRAM(Deep Recurrent Attention Model)演算法

Ba, Jimmy, Volodymyr Mnih, and Koray Kavukcuoglu. “Multiple object recognition with visual attention.” arXiv preprint arXiv:1412

深度學習Deep Learning必備之必背知識點

這篇文章有哪些需要背誦的內容: 1、張量、計算圖、會話     神經網路:用張量表示資料,用計算圖搭建神經網路,用會話執行計算圖,優化線上的權重(引數),得到模型。      張量:標量(單個)、向量(1維)、矩陣(2維)、張量(n維) 2、前向傳播     網路的

深度學習詞的向量化表示

model ref res font 技術 訓練 lin 挖掘 body 如果要一句話概括詞向量的用處,就是提供了一種數學化的方法,把自然語言這種符號信息轉化為向量形式的數字信息。這樣就把自然語言理解的問題要轉化為機器學習的問題。 其中最常用的詞向量模型無非是 one-h

深度學習批歸一化(Batch Normalization)

學習 src 試用 其中 put min 平移 深度 優化方法 BN是由Google於2015年提出,這是一個深度神經網絡訓練的技巧,它不僅可以加快了模型的收斂速度,而且更重要的是在一定程度緩解了深層網絡中“梯度彌散”的問題,從而使得訓練深層網絡模型更加容易和穩定。所以目前

深度學習常用的模型評估指標

是我 初學者 cnblogs 沒有 線下 均衡 顯示 總數 效果 “沒有測量,就沒有科學。”這是科學家門捷列夫的名言。在計算機科學中,特別是在機器學習的領域,對模型的測量和評估同樣至關重要。只有選擇與問題相匹配的評估方法,我們才能夠快速的發現在模型選擇和訓練過程中可能出現的

深度學習吳恩達網易公開課練習(class2 week1 task2 task3)

公開課 網易公開課 blog 校驗 過擬合 limit 函數 its cos 正則化 定義:正則化就是在計算損失函數時,在損失函數後添加權重相關的正則項。 作用:減少過擬合現象 正則化有多種,有L1範式,L2範式等。一種常用的正則化公式 \[J_{regularized}

深度學習深入理解ReLU(Rectifie Linear Units)激活函數

appdata 稀疏編碼 去掉 ren lock per 作用 開始 href 論文參考:Deep Sparse Rectifier Neural Networks (很有趣的一篇paper) Part 0:傳統激活函數、腦神經元激活頻率研究、稀疏激活性

深度學習一文讀懂機器學習常用損失函數(Loss Function)

back and 們的 wiki 導出 歐氏距離 classes 自變量 關於 最近太忙已經好久沒有寫博客了,今天整理分享一篇關於損失函數的文章吧,以前對損失函數的理解不夠深入,沒有真正理解每個損失函數的特點以及應用範圍,如果文中有任何錯誤,請各位朋友指教,謝謝~

深度學習ubuntu16.04下安裝opencv3.4.0

form 線程 ubunt con sudo ive tbb 依賴包 復制代碼 1、首先安裝一些編譯工具 # 安裝編譯工具 sudo apt-get install build-essential # 安裝依賴包 sudo apt-get install cmake

深度學習Pytorch 學習筆記

chang www. ans 如何 ret == 筆記 etc finished 目錄 Pytorch Leture 05: Linear Rregression in the Pytorch Way Logistic Regression 邏輯回歸 - 二分類 Lect

深度學習Semantic Segmentation 語義分割

翻譯自 A 2017 Guide to Semantic Segmentation with Deep Learning What exactly is semantic segmentation? 對圖片的每個畫素都做分類。 較為重要的語義分割資料集有:VOC2

深度學習Drop out

來源:Dropout: A Simple Way to Prevent Neural Networks from Overfitting 1. 原理 在每個訓練批次的前向傳播中,以概率p保留部分神經元。目的是:簡化神經網路的複雜度,降低過擬合風險。 根據保留概率p計算一個概率向量r

深度學習Tensorboard 視覺化好幫手2

轉自https://morvanzhou.github.io/tutorials/machine-learning/tensorflow/4-2-tensorboard2/ 目錄 要點  製作輸入源  在 layer 中為 Weights, biases 設定變化

深度學習Tensorboard 視覺化好幫手1

轉自https://morvanzhou.github.io/tutorials/machine-learning/tensorflow/4-1-tensorboard1/ 注意: 本節內容會用到瀏覽器, 而且與 tensorboard 相容的瀏覽器是 “Google Chrome”.

深度學習Tensorflow函式

  目錄 tf.truncated_normal tf.random_normal tf.nn.conv2d tf.nn.max_pool tf.reshape tf.nn.softmax tf.reduce_sum tf.reduce_max,tf.r

深度學習Tensorflow——CNN 卷積神經網路 2

轉自https://morvanzhou.github.io/tutorials/machine-learning/tensorflow/5-05-CNN3/ 目錄 圖片處理  建立卷積層  建立全連線層  選優化方法  完整程式碼

深度學習Tensorflow——CNN 卷積神經網路 1

轉自https://morvanzhou.github.io/tutorials/machine-learning/tensorflow/5-04-CNN2/ 這一次我們會說道 CNN 程式碼中怎麼定義 Convolutional 的層和怎樣進行 pooling. 基於上一次卷積神經網路的介

深度學習三維點雲資料集總結

點雲資料集總結 三維點雲資料,三維深度學習 1.ShapeNet ShapeNet是一個豐富標註的大規模點雲資料集,其中包含了55中常見的物品類別和513000個三維模型。 2.ShapeNetSem 這是一個小的資料庫,包含了270類的12000個物

深度學習ResNet解讀及程式碼實現

簡介 ResNet是何凱明大神在2015年提出的一種網路結構,獲得了ILSVRC-2015分類任務的第一名,同時在ImageNet detection,ImageNet localization,COCO detection和COCO segmentation等任務中均獲得了第一名,在當

深度學習GoogLeNet系列解讀 —— Inception v4

目錄 GoogLeNet系列解讀 Inception v1 Inception v2 Inception v3 Inception v4 簡介 在介紹Inception v4之前,首先說明一下Inception v4沒有使用殘差學習的思想。大部分小夥伴對Inc