1. 程式人生 > 實用技巧 >DGL學習(四): 圖分類教程

DGL學習(四): 圖分類教程

本節中我們將使用DGL批處理多個大小和形狀可變的圖形。

使用包含如下8種類型圖的資料集。

from dgl.data import MiniGCDataset
import matplotlib.pyplot as plt
import networkx as nx
dataset = MiniGCDataset(80, 10, 20) ## 產生80個樣本, 每個樣本的節點數位於 [10,20]之間
graph, label = dataset[10]
fig, ax = plt.subplots()
nx.draw(graph.to_networkx(), ax=ax)
ax.set_title(
'Class: {:d}'.format(label)) plt.show()

影象由於張量大小一致,很容易就可以進行批量學習。圖如何進行批量學習?

圖批量學習主要有以下兩個挑戰。

1. 圖是稀疏的。 2. 不同圖中的節點數和邊數是不同的。

為了解決這個問題,DGL提供了dgl.batch() 進行批處理。 他的想法是將一批圖視為一張大圖,大圖裡面有多個不相連的連通分量嗎,如下所示。

定義collate函式,從給定的Graph和label對列表中形成一個mini-batch。返回值依然是一個DGLGraph 和 label組成的tensor, 這樣做DGL能夠並行處理邊和節點,大大提高了效率。

import dgl
import torch

def collate(samples):
    # The input `samples` is a list of pairs
    #  (graph, label).
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(labels)

整個演算法的流程框架如下:

在一個batch的graph中,執行訊息傳遞和GraphConv,使得節點與其他節點進行通訊。 訊息傳遞後,根據節點(邊)的屬性計算一個張量作為graph representation。 此步驟被稱為readout或aggregation。 最後,將輸入graphrepresentation到分類器g中進行預測。

模型結構: 輸入特徵是節點的入度,通過兩層圖卷積之後,將圖中所有節點的輸出拼接起來,作為圖的表示向量,再通過一個全連線神經網路進行分類。

from dgl.data import MiniGCDataset
import dgl
import torch
from torch.utils.data import DataLoader
from dgl.nn.pytorch import GraphConv
import torch.nn as nn
import torch.nn.functional as F


class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super(Classifier, self).__init__()
        self.conv1 = GraphConv(in_dim, hidden_dim)
        self.conv2 = GraphConv(hidden_dim, hidden_dim)
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g):
        # 使用節點的入度作為初始特徵
        h = g.in_degrees().view(-1,1).float()
        h = F.relu(self.conv1(g, h))
        h = F.relu(self.conv2(g, h))
        g.ndata['h'] = h ## 節點特徵經過兩層卷積的輸出
        hg = dgl.mean_nodes(g, 'h') # 圖的特徵是所有節點特徵的均值
        y = self.classify(hg)
        return y

訓練模型:

## 訓練模型
trainset = MiniGCDataset(320, 10, 20) ## 產生80個樣本, 每個樣本的節點數位於 [10,20]之間
testset = MiniGCDataset(80, 10, 20)

def collate(samples):
    # The input `samples` is a list of pairs
    #  (graph, label).
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(labels)

data_loader = DataLoader(trainset, batch_size=32, shuffle=True,collate_fn=collate)

model = Classifier(1, 256, trainset.num_classes)
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr = 0.001)

model.train()

epoch_losses = []
for epoch in range(80):
    epoch_loss = 0
    for iter, (bg, label) in enumerate(data_loader):
        prediction = model(bg)
        loss = loss_func(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss+=loss.detach().item()
    epoch_loss /= (iter+1)

    print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
    epoch_losses.append(epoch_loss)

plt.title('cross entropy averaged over minibatches')
plt.plot(epoch_losses)
plt.show()

測試模型:

model.eval()
# Convert a list of tuples to two lists
test_X, test_Y = map(list, zip(*testset))
test_bg = dgl.batch(test_X)
test_Y = torch.tensor(test_Y).float().view(-1, 1)
probs_Y = torch.softmax(model(test_bg), 1)
sampled_Y = torch.multinomial(probs_Y, 1)
argmax_Y = torch.max(probs_Y, 1)[1].view(-1, 1)
print('Accuracy of sampled predictions on the test set: {:.4f}%'.format(
    (test_Y == sampled_Y.float()).sum().item() / len(test_Y) * 100))
print('Accuracy of argmax predictions on the test set: {:4f}%'.format(
    (test_Y == argmax_Y.float()).sum().item() / len(test_Y) * 100))