1. 程式人生 > 實用技巧 >【深度學習實戰】Pytorch Geometric實踐——利用Pytorch搭建GNN

【深度學習實戰】Pytorch Geometric實踐——利用Pytorch搭建GNN

(8條訊息) 【深度學習實戰】Pytorch Geometric實踐——利用Pytorch搭建GNN_喵木木的部落格-CSDN部落格

1. 安裝

首先,我們先查一下我們的pytorch的版本。要求至少安裝 PyTorch 1.2.0 版本:

python -c "import torch; print(torch.__version__)"
  • 1

接著,查詢對應pytorch安裝的CUDA的版本:

python -c "import torch; print(torch.version.cuda)"
  • 1

然後,安裝Pytorch geometry的軟體包。需要注意的是,這裡的${CUDA}

是前面查詢到的CUDA的版本(cpu, cu92, cu101, cu102)${TORCH}是前面查到的pytorch的版本。(建議將pytorch升級到最新版本再進行安裝)

pip install torch-scatter==latest+${CUDA} -f https://pytorch-geometric.com/whl/torch-${TORCH}.html
pip install torch-sparse==latest+${CUDA} -f https://pytorch-geometric.com/whl/torch-${TORCH}.html
pip install torch-cluster==latest+${CUDA} -f https://pytorch-geometric.com/whl/torch-${TORCH}.html
pip install torch-spline-conv==latest+${CUDA} -f https://pytorch-geometric.com/whl/torch-${TORCH}.html
pip install torch-geometric
  • 1
  • 2
  • 3
  • 4
  • 5

比如我這裡查到Pytorch的版本是1.5.1(按照官網的教程,pytorch版本為1.5.0或者1.5.1的按照1.5.0來安裝),CUDA的版本是10.2,那麼我的安裝語句如下:

pip install torch-scatter==latest+cu102 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
pip install torch-sparse==latest+cu102 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
pip install torch-cluster==latest+cu102 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
pip install torch-spline-conv==latest+cu102 -f https://pytorch-geometric.com/whl/torch-1.5.0.html
pip install torch-geometric
  • 1
  • 2
  • 3
  • 4
  • 5

2. 基本概念介紹

2.1 Data Handling of Graphs 圖形資料處理

圖(Graph)是描述實體(節點)和關係(邊)的資料模型。在Pytorch Geometric中,圖被看作是torch_geometric.data.Data的例項,並擁有以下屬性:

屬性描述
data.x 節點特徵,維度是[num_nodes, num_node_features]
data.edge_index 維度是[2, num_edges],描述圖中節點的關聯關係,每一列對應的兩個元素,分別是邊的起點和重點。資料型別是torch.long。需要注意的是,data.edge_index是定義邊的節點的張量(tensor),而不是節點的列表(list)。
data.edge_attr 邊的特徵矩陣,維度是[num_edges, num_edge_features]
data.y 訓練目標(維度可以是任意的)。對於節點相關的任務,維度為[num_nodes, *];對於圖相關的任務,維度為[1,*]
data.position 節點位置矩陣(Node position matrix),維度為[num_nodes, num_dimensions]

下面是一個簡單的例子:

首先匯入需要的包:

import torch
from torch_geometric.data import Data
  • 1
  • 2


比如上圖所示的圖結構,我們首先定義節點特徵向量:

x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
  • 1

接著定義邊,下面兩種定義方式是等價的。第二種方式可能更符合我們的閱讀習慣,但是需要注意的是此時應當增加一個edge_index=edge_index.t().contiguous()的操作。此外,由於是無向圖,雖然只有兩條邊,但是我們需要四組關係說明來描述邊的兩個方向。

## 法1
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
data = Data(x=x, edge_index=edge_index)

## 法2
edge_index = torch.tensor([[0, 1],
                           [1, 0],
                           [1, 2],
                           [2, 1]], dtype=torch.long)
data = Data(x=x, edge_index=edge_index.t().contiguous())
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11

可以得到:

同時,Data物件提供了一些很實用的函式:

print('data\'s keys: {}'.format(data.keys))
print('-'*5)
for key, item in data:
    print("{} found in data".format(key))
print('-'*5)  
print('Does data has attribute \'edge_attr\'? {}'.format('edge_attr' in data))
print('data has {} nodes'.format(data.num_nodes))
print('data has {} edges'.format(data.num_edges))
print('The nodes in data have {} feature(s)'.format(data.num_node_features))
print('Does data contains isolated nodes? {}'.format(data.contains_isolated_nodes()))
print('Does data contains self loops? {}'.format(data.contains_self_loops()))
print('is data directed? {}'.format(data.is_directed()))
print(data['x'])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

輸出:

data's keys: ['x', 'edge_index']
-----
edge_index found in data
x found in data
-----
Does data has attribute 'edge_attr'? False
data has 3 nodes
data has 4 edges
The nodes in data have 1 feature(s)
Does data contains isolated nodes? False
Does data contains self loops? False
is data directed? False
tensor([[-1.],
        [ 0.],
        [ 1.]])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15

同樣可以在GPU上執行data:

device = torch.device('cuda')
data = data.to(device)
  • 1
  • 2

2.2 Common Benchmark Datasets 常見的基準資料集

PyTorch Geometric提供很多基準資料集,包括

想要使用這些資料集,只要進行初始化,資料就會自動下載。比如我們要使用ENZYMES資料集(該資料集包括600張圖,有6個類別):

from torch_geometric.datasets import TUDataset
dataset = TUDataset(root='.\data\ENZYMES', name='ENZYMES')
  • 1
  • 2

程式就會自動執行下載:

Downloading http://ls11-www.cs.tu-dortmund.de/people/morris/graphkerneldatasets/ENZYMES.zip
Extracting data\ENZYMES\ENZYMES\ENZYMES.zip
Processing...
Done!
  • 1
  • 2
  • 3
  • 4

我們可以看一下這個資料集的一些屬性:

print(dataset)
print(len(dataset))
print(dataset.num_classes)
print(dataset.num_node_features)
  • 1
  • 2
  • 3
  • 4

輸出:

我們可以看下其中一張圖的結構:

data = dataset[14]
print(data)
print(data.is_undirected())
  • 1
  • 2
  • 3

輸出:

  • 1
  • 2

我們可以看到資料集中的第一個圖包含36個節點,每個節點有3個特徵。圖中有128/2 = 64條無向邊,圖被分類為“1”類。在將資料集分為訓練集和測試集之前,可以呼叫dataset = dataset.shuffle()將資料集進行隨機打亂。這個語句和下面這段程式是等價的:

perm = torch.randperm(len(dataset))
dataset = dataset[perm]
  • 1
  • 2

我們再來看硬外一個數據集Cora

from torch_geometric.datasets import Planetoid

dataset = Planetoid(root='.\data\Cora', name='Cora')
data = dataset[0]

print(data)
print(data.is_undirected())
print(data.train_mask.sum().item())
print(data.val_mask.sum().item())
print(data.test_mask.sum().item())

print(len(dataset))
print(dataset.num_classes)
print(dataset.num_node_features)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14

輸出:

  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8

可以看到,前面的資料集針對的是“網路分類”的任務,而這個資料集針對的是“節點分類”的任務。每個節點又1433個特徵,被分為7類。這個圖是一個無向圖,共有10556/2=5278條邊,共有2708個節點。這裡有三個需要注意的引數:

  • train_mask——指明訓練集中的節點(可以看到,在這個資料集中,訓練集裡有140個節點)
  • val_mask——指明驗證集中的節點(可以看到,在這個資料集中,驗證集裡有500個節點)
  • test_mask——指明測試集中的節點(可以看到,在這個資料集中,測試集裡有1000個節點)

2.3 Mini-batches

神經網路通常以批處理的方式進行訓練。在pytorch中,通常用資料載入器DataLoader來進行批處理。

from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader
from torch_scatter import scatter_mean

dataset = TUDataset(root='.\data\ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for data in loader:
    print(data)
    print(data.num_graphs)
    x = scatter_mean(data.x, data.batch, dim=0)
    print(x.size())
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12

張圖。

這個scatter函式實質上是對節點的一個整合,節點根據batch的標籤(按照圖)來進行整合,下面這張官方文件中的圖可以很好地說明scatter函式的作用:

2.4 Data Transforms 資料轉換

torch_geometric.transforms.Compose提供了資料轉換的方法,可以方便使用者將資料轉換成既定的格式或者用於資料的預處理。在之前使用torchvision處理影象時,也會用到資料轉換的相關方法,將圖片轉換成畫素矩陣,這裡的資料轉換就類似torchvision在影象上的處理。

2.5 Learning Methods on Graphs——the first graph neural network 搭建我們的第一個圖神經網路

下面我們來嘗試著搭建我們的第一圖神經網路。關於圖神經網路,可以看一下這篇部落格——GRAPH CONVOLUTIONAL NETWORKS

資料集準備

我們使用的是Cora資料集。

from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='./data/Cora', name='Cora')
print(dataset)
  • 1
  • 2
  • 3

輸出:

搭建網路模型

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = GCNConv(dataset.num_node_features, 16)
        self.conv2 = GCNConv(16, dataset.num_classes)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)

        return F.log_softmax(x, dim=1)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19

模型的結構包含兩個GCNConv層,選擇ReLU作為非線性函式,最後通過softmax輸出分類結果。

模型訓練和驗證

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()

model.eval()
_, pred = model(data).max(dim=1)
correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
acc = correct / int(data.test_mask.sum())
print('Accuracy: {:.4f}'.format(acc))
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

輸出:

3. CREATING MESSAGE PASSING NETWORKS 建立訊息傳遞網路

將卷積神經網路中的“卷積運算元”應用到圖上面,核心在於neighborhood aggregation機制,或者說是message passing的機制。Aggregate Neighbours,核心思想在於基於區域性網路連線來生成Node embeddings(Generate node embeddings based on local network neighborhoods)。如下面這個圖:

例如圖中節點A的embedding決定於其鄰居節點

{ B , C , D } \{B,C,D\}

{B,C,D},而這些節點又受到它們各自的鄰居節點的影響。圖中的“黑箱”可以看成是整合其鄰居節點資訊的操作,它有一個很重要的屬性——其操作應該是順序(order invariant)無關的,如求和、求平均、求最大值這樣的操作,可以採用神經網路來獲取。這樣順序無關的聚合函式符合網路節點無序性的特徵,當我們對網路節點進行重新編號時,我們的模型照樣可以使用。

那麼,對於每個節點來說,它的計算圖就由其鄰居節點的數量來決定——


模型的深度可以自己定義(Model can be of arbitrary depth):

  • Nodes have embeddings at each layer
  • Layer-0節點

3.1 Message passing 基本類

PyTorch Geometric 提供了基本類——MessagePassing,可以實現上述的圖神經網路,來實現訊息傳遞或訊息聚集(which helps in creating such kinds of message passing graph neural networks by automatically taking care of message propagation. )

MessagePassing類有三個引數:

  • aggr (string, optional)——指定採用的置換不變函式,預設是add,可以定義為addmeanmaxNone
  • **flow (string, optional) **——指定資訊傳遞的反向,預設是source_to_target,還可以設定為target_to_source
  • **node_dim (int, optional) **——The axis along which to propagate. 預設是-2。

同時,MessagePassing提供了一些比較實用的方法:

  • MessagePassing.propagate(edge_index, size=None, **kwargs)
  1. ii​全部設定為1。在pytorch geometric裡面,是利用edge_index來實現。如果是有權圖,則新增的自迴圈邊以fill_value作為權。該方法最後返回兩個值——`edge_index, edge_weight``。
import torch
from torch_geometric.utils import add_self_loops, degree

x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)

print("original edge_index ")
print(edge_index)

edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
print("new edge_index")
print(edge_index)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13

最後輸出:

original edge_index 
tensor([[0, 1, 1, 2],
        [1, 0, 2, 1]])
new edge_index
tensor([[0, 1, 1, 2, 0, 1, 2],
        [1, 0, 2, 1, 0, 1, 2]])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  1. Linearly transform node feature matrix.第二步是對節點的特徵矩陣進行線性變換。主要通過一個線性層torch.nn.Linear實現。

  2. Compute normalization coefficients.第三步是對變換後的節點特徵進行標準化。節點的度可以通過torch_geometric.utils.degree實現。

import torch
from torch_geometric.utils import add_self_loops, degree

x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)

print("original edge_index ")
print(edge_index)

edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
print("new edge_index")
print(edge_index)

row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
print(deg)
deg_inv_sqrt = deg.pow(-0.5)
print(deg_inv_sqrt)
print(deg_inv_sqrt[row])
print(deg_inv_sqrt[col])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21

最後輸出:

original edge_index 
tensor([[0, 1, 1, 2],
        [1, 0, 2, 1]])
new edge_index
tensor([[0, 1, 1, 2, 0, 1, 2],
        [1, 0, 2, 1, 0, 1, 2]])
tensor([2., 3., 2.])
tensor([0.7071, 0.5774, 0.7071])
tensor([0.7071, 0.5774, 0.5774, 0.7071, 0.7071, 0.5774, 0.7071])
tensor([0.5774, 0.7071, 0.7071, 0.5774, 0.7071, 0.5774, 0.7071])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  1. Sum up neighboring node features (“add” aggregation).

前面三步是message passing之前的預操作,第四、第五步可以採用MessagePassing類裡面的方法完成。

完整的程式碼如下:

import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree

class GCNConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(GCNConv, self).__init__(aggr='add')  # "Add" aggregation (Step 5).
        self.lin = torch.nn.Linear(in_channels, out_channels)

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 1: Add self-loops to the adjacency matrix.
        edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

        # Step 2: Linearly transform node feature matrix.
        x = self.lin(x)

        # Step 3: Compute normalization.
        row, col = edge_index
        deg = degree(col, x.size(0), dtype=x.dtype)
        deg_inv_sqrt = deg.pow(-0.5)
        norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

        # Step 4-5: Start propagating messages.
        return self.propagate(edge_index, x=x, norm=norm)

    def message(self, x_j, norm):
        # x_j has shape [E, out_channels]

        # Step 4: Normalize node features.
        return norm.view(-1, 1) * x_j
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33

我們建立的這個神經網路模型GCNConv繼承於基礎類MessagePassing,並且採用求和函式作為

□ \square

□函式,通過super(GCNConv, self).__init__(aggr='add')來初始化。在完成1-3步之後,呼叫MessagePassing中的propagate()方法來完成4-5步,進行資訊傳播。message函式用於對節點的鄰居節點的資訊進行標準化。

我們可以通過一個案例來感受一下這個模型的輸入和輸出。

x = torch.tensor(torch.rand(3,2), dtype=torch.float)
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
conv = GCNConv(2, 4)                         
  • 1
  • 2
  • 3
  • 4


設有上圖所示的網路,網路中有三個節點,每個節點有2個特徵值。並構建神經網路conv = GCNConv(2, 4)。下面是程式執行的每一步輸出的結果:

x is
tensor([[0.1819, 0.1848],
        [0.8479, 0.1754],
        [0.7511, 0.9781]])

----Step 1: Add self-loops to the adjacency matrix.----
tensor([[0, 1, 1, 2, 0, 1, 2],
        [1, 0, 2, 1, 0, 1, 2]])

----Step 2: Linearly transform node feature matrix.----
linear weight is 
Parameter containing:
tensor([[-0.6532, -0.3349],
        [ 0.5238, -0.5996],
        [-0.6279, -0.5872],
        [-0.4064,  0.5893]], requires_grad=True)
linear bias is
Parameter containing:
tensor([ 0.5966, -0.4339,  0.0263,  0.1577], requires_grad=True)
transformed x is
tensor([[ 0.4160, -0.4494, -0.1964,  0.1927],
        [-0.0159, -0.0949, -0.6090, -0.0835],
        [-0.2215, -0.6270, -1.0196,  0.4289]], grad_fn=<AddmmBackward>)

----Step 3: Compute normalization.----
tensor([0.4082, 0.4082, 0.4082, 0.4082, 0.5000, 0.3333, 0.5000])

----Step 4-5: Start propagating messages.----
tensor([[ 0.2015, -0.2635, -0.3468,  0.0623],
        [ 0.0741, -0.4711, -0.6994,  0.2260],
        [-0.1172, -0.3522, -0.7584,  0.1804]], grad_fn=<ScatterAddBackward>)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31

3.3 Edge Convolution 邊卷積層的實現

邊卷積層的數學定義如下:

x i ( k ) = max ⁡ j ∈ N ( i ) h Θ ( x i ( k − 1 ) , x j ( k − 1 ) − x i ( k − 1 ) ) x_i^{(k)}=\max_{j \in N(i)} h_{\Theta}(x_i^{(k-1)},x_j^{(k-1)}-x_i^{(k-1)})

x

Θ

為多層感知機,類似於GCN,邊卷積層同樣繼承于于基礎類MessagePassing,不同在於採用max函式作為

□ \square

□函式。

邊卷積層的主要理論來自於論文Dynamic Graph CNN for Learning on Point Clouds,這篇文章提出一種邊卷積(EdgeConv)操作,來完成點雲中點與點之間關係的建模,使得網路能夠更好地學習區域性和全域性特徵。具體可以看這兩篇部落格:【深度學習——點雲】DGCNN(EdgeConv)論文筆記:DGCNN(EdgeConv)

import torch
from torch.nn import Sequential as Seq, Linear, ReLU
from torch_geometric.nn import MessagePassing

class EdgeConv(MessagePassing):
    def __init__(self, in_channels, out_channels):
        super(EdgeConv, self).__init__(aggr='max') #  "Max" aggregation.
        self.mlp = Seq(Linear(2 * in_channels, out_channels),
                       ReLU(),
                       Linear(out_channels, out_channels))

    def forward(self, x, edge_index):
        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        return self.propagate(edge_index, x=x)

    def message(self, x_i, x_j):
        # x_i has shape [E, in_channels]
        # x_j has shape [E, in_channels]

        tmp = torch.cat([x_i, x_j - x_i], dim=1)  # tmp has shape [E, 2 * in_channels]
        return self.mlp(tmp)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23

邊緣卷積實際上是一種動態卷積,它使用特徵空間中的最近鄰重新計算每一層的圖。PyTorch geometry附帶一個GPU加速的批處理k-NN圖形生成方法——torch_geometric.n .pool.knn_graph()

from torch_geometric.nn import knn_graph

class DynamicEdgeConv(EdgeConv):
    def __init__(self, in_channels, out_channels, k=6):
        super(DynamicEdgeConv, self).__init__(in_channels, out_channels)
        self.k = k

    def forward(self, x, batch=None):
        edge_index = knn_graph(x, self.k, batch, loop=False, flow=self.flow)
        return super(DynamicEdgeConv, self).forward(x, edge_index)
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10

4. 建立自己的資料集

PyTorch Geometric提供了兩個抽象類——torch_geometric.data.Datasettorch_geometric.data.InMemoryDataset。前者適用於不能一次性放進記憶體中的大資料集,後者適用於可以全部放進記憶體中的小資料集。

4.1 “In Memory Datasets”的建立

torch_geometric.data.InMemoryDataset有四個可選引數:

  • root (string, optional)——儲存資料集的根目錄。每個資料集都傳遞一個根資料夾,該根資料夾指示資料集應該儲存在何處。將根資料夾分成兩個資料夾:未處理過的資料集被儲存在raw_dir目錄下;已處理的資料集被儲存在processed_dir目錄下。
  • transform (callable, optional)
  • pre_transform (callable, optional)
  • pre_filter (callable, optional)

建立In Memory Datasets,需要用到四個基本的方法:

  1. raw_file_names()——返回一個包含所有未處理過的資料檔案的檔名的列表。
  2. processed_file_names()——返回一個包含所有處理過的資料檔案的檔名的列表。
  3. download()——下載資料到raw_dir目錄下。
  4. process()——對資料的處理函式,是核心的函式之一。

下面是官方文件給出的一個示例:

import torch
from torch_geometric.data import InMemoryDataset


class MyOwnDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super(MyOwnDataset, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        # Download to `self.raw_dir`.

    def process(self):
        # Read data into huge `Data` list.
        data_list = [...]

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32

4.2 建立更大的資料集

對於無法全部放進記憶體中的大資料集,可以使用torch_geometric.data.Datasettorch_geometric.data.Dataset的引數和torch_geometric.data.InMemoryDataset的一致。常用的方法如下:

  1. len()——獲取資料集中的資料量。
  2. get(idx)——獲取索引為idx的資料物件。

下面是官方文件給出的一個示例:

import os.path as osp

import torch
from torch_geometric.data import Dataset


class MyOwnDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super(MyOwnDataset, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        return ['data_1.pt', 'data_2.pt', ...]

    def download(self):
        # Download to `self.raw_dir`.

    def process(self):
        i = 0
        for raw_path in self.raw_paths:
            # Read data from `raw_path`.
            data = Data(...)

            if self.pre_filter is not None and not self.pre_filter(data):
                continue

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            torch.save(data, osp.join(self.processed_dir, 'data_{}.pt'.format(i)))
            i += 1

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))
        return data
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42