1. 程式人生 > 其它 >無監督-DEEP GRAPH INFOMAX

無監督-DEEP GRAPH INFOMAX

無監督-DEEP GRAPH INFOMAX

標籤:圖神經網路、無監督

動機

  • 在真實世界中,圖的標籤是較少的,而現在圖神經的高效能主要依賴於有標籤的真是資料集
  • 在無監督中,隨機遊走犧牲了圖結構資訊和強調的是鄰域資訊,並且效能高度依賴於超引數的選擇

貢獻

  • 在無監督學習上,首次結合互資訊提出了一個圖節點表示學習方法-DGI
  • 該方法不依賴隨機遊走目標,並且使用與直推式學習和歸納學習
  • DGI 依賴於最大限度地擴大圖增強表示和目前提取到的圖資訊之間的互資訊

思想

符號定義

節點特徵集合: \(X \in \mathbb{R}^{N \times F}\),鄰接矩陣: \(A \in \mathbb{R}^{N \times N}\)

, 編碼器: \(\varepsilon~~~~ \mathbb{R}^{N \times F} \times \mathbb{R}^{N \times N} \rightarrow \mathbb{R}^{N \times F'}\),鑑別器(discriminator): \(D~~ ~~ \mathbb{R}^{F} \times \mathbb{R}^{F} \rightarrow \mathbb{R}\),腐蝕函式(corruption function): \(C~~~\mathbb{R}^{N \times F} \times \mathbb{R}^{N \times N} \rightarrow \mathbb{R}^{M \times F} \times \mathbb{R}^{M \times M}\)
,節點的表示 (patch representations): \(\overrightarrow{h_i}\) , 圖表示: \(\overrightarrow{s}\) .

核心

本質上利用大化區域性互資訊訓練一個模型 (編碼器) \(\varepsilon~~~~ \mathbb{R}^{N \times F} \times \mathbb{R}^{N \times N} \rightarrow \mathbb{R}^{N \times F'}\) ,其損失函式 (1) 所示,負例的來源是是通過 corruption function 得到

框架

對於一個原圖 \(G(X, A)\)

,首先利用 corruption function 得到一個破壞後的圖 \(G'(\widetilde{X}, \widetilde{A})\),拿這兩個圖利用編碼器 \(\varepsilon\) 進行編碼, \(H = \varepsilon(X, A) = \{\overrightarrow{h_1},\overrightarrow{h_2},...,\overrightarrow{h_N}\} ~~~~~ H = \varepsilon(\widetilde{X}, \widetilde{A}) = \{\overrightarrow{\widetilde{h_1}},\overrightarrow{\widetilde{h_2}},...,\overrightarrow{\widetilde{h_N}}\}\) , 對於原圖得到每個節點的表示利用一個讀出函式 (readout function) 得到整個圖的表示 \(\overrightarrow{s} = R(H)\) ,最後利用目標函式更新引數

步驟

  • 用 corruption function 進行取樣負樣例得到 \((\widetilde{X}, \widetilde{A}) \approx C(X,A)\)
  • 將原圖(正例)餵給編碼器獲得節點的表示 patch representations \(\overrightarrow{h_i}\), \(H = \varepsilon(X, A) = \{\overrightarrow{h_1},\overrightarrow{h_2},...,\overrightarrow{h_N}\}\)
  • 將破壞後的圖(負例)餵給編碼器獲得節點的表示 patch representations \(\overrightarrow{\widetilde{h_i}}\), \(H = \varepsilon(\widetilde{X}, \widetilde{A}) = \{\overrightarrow{\widetilde{h_1}},\overrightarrow{\widetilde{h_2}},...,\overrightarrow{\widetilde{h_N}}\}\)
  • 通過讀出函式 (readout function) 傳遞輸入圖的patch representations 來得到總的圖的表示 \(\overrightarrow{s} = R(H)\)
  • 通過應用梯度下降最大化 (1) 來更新 \(\varepsilon、R、D\) 的引數

損失函式

\[L = \frac{1}{N + M}(\sum_{i = 1} ^ {N} \mathbb{E}_{(X, A)}[\log D(\overrightarrow{h_i},\overrightarrow{s})] + \sum_{j = 1} ^ {M} [log(1 - D(\overrightarrow{\widetilde{h_j} },\overrightarrow{s}))]) ~~~~~~~~~(1) \]

實驗

直推式學習 (Inductive Learn)

GCN 傳播規則: \(\varepsilon(X, A) = \sigma(\hat{D}^{-\frac{1}{2}}\hat{A} \hat{D}^{-\frac{1}{2}}X\Theta)\)

其中, \(\hat{A} = A + I_N\) 代表加上自環的鄰接矩陣, \(\hat{D}\) 代表相應的度矩陣,滿足 \(\hat{D_{ii} = \sum_{j}\hat{A_{ij}}}\) 對於非線性啟用函式 \(\sigma\) ,選擇 PReLU(parametric ReLU)。\(\Theta \in R^{F \times F'}\) 是應用於每個節點的可學習線性變換。

對於 corruption function C ,直接採用 \(\widetilde{A} = A\),但是 \(\widetilde{X}\) 是由原本的特徵矩陣 \(X\) 經過隨機變換得到的。也就是說,損壞的圖(corrupted graph)由與原始圖完全相同的節點組成,但它們位於圖中的不同位置,因此將得到不同的鄰近表示。

歸納式學習 (Transductive Learn)

對於歸納學習,不再在編碼器中使用 GCN 更新規則(因為學習的濾波器依賴於固定的和已知的鄰接矩陣);相反,我們應用平均池( mean-pooling)傳播規則,GraphSAGE-GCN:\(MP(X,A) = \hat{D}^{-1}\hat{A}X\Theta\)