【深度學習實戰】Pytorch Geometric實踐——利用Pytorch搭建GNN
(8條訊息) 【深度學習實戰】Pytorch Geometric實踐——利用Pytorch搭建GNN_喵木木的部落格-CSDN部落格
1. 安裝
首先,我們先查一下我們的pytorch的版本。要求至少安裝 PyTorch 1.2.0 版本:
python -c "import torch; print(torch.__version__)"
- 1
接著,查詢對應pytorch安裝的CUDA的版本:
python -c "import torch; print(torch.version.cuda)"
- 1
然後,安裝Pytorch geometry的軟體包。需要注意的是,這裡的${CUDA}
(cpu, cu92, cu101, cu102)
,${TORCH}
是前面查到的pytorch的版本。(建議將pytorch升級到最新版本再進行安裝)
pip install torch-scatter==latest+${CUDA} -f https://pytorch-geometric.com/whl/torch-${TORCH}.html pip install torch-sparse==latest+${CUDA} -f https://pytorch-geometric.com/whl/torch-${TORCH}.html pip install torch-cluster==latest+${CUDA} -f https://pytorch-geometric.com/whl/torch-${TORCH}.html pip install torch-spline-conv==latest+${CUDA} -f https://pytorch-geometric.com/whl/torch-${TORCH}.html pip install torch-geometric
- 1
- 2
- 3
- 4
- 5
比如我這裡查到Pytorch的版本是1.5.1(按照官網的教程,pytorch版本為1.5.0或者1.5.1的按照1.5.0來安裝),CUDA的版本是10.2,那麼我的安裝語句如下:
pip install torch-scatter==latest+cu102 -f https://pytorch-geometric.com/whl/torch-1.5.0.html pip install torch-sparse==latest+cu102 -f https://pytorch-geometric.com/whl/torch-1.5.0.html pip install torch-cluster==latest+cu102 -f https://pytorch-geometric.com/whl/torch-1.5.0.html pip install torch-spline-conv==latest+cu102 -f https://pytorch-geometric.com/whl/torch-1.5.0.html pip install torch-geometric
- 1
- 2
- 3
- 4
- 5
2. 基本概念介紹
2.1 Data Handling of Graphs 圖形資料處理
圖(Graph)是描述實體(節點)和關係(邊)的資料模型。在Pytorch Geometric中,圖被看作是torch_geometric.data.Data的例項,並擁有以下屬性:
屬性 | 描述 |
---|---|
data.x |
節點特徵,維度是[num_nodes, num_node_features] 。 |
data.edge_index |
維度是[2, num_edges] ,描述圖中節點的關聯關係,每一列對應的兩個元素,分別是邊的起點和重點。資料型別是torch.long 。需要注意的是,data.edge_index 是定義邊的節點的張量(tensor),而不是節點的列表(list)。 |
data.edge_attr |
邊的特徵矩陣,維度是[num_edges, num_edge_features] |
data.y |
訓練目標(維度可以是任意的)。對於節點相關的任務,維度為[num_nodes, *] ;對於圖相關的任務,維度為[1,*] 。 |
data.position |
節點位置矩陣(Node position matrix),維度為[num_nodes, num_dimensions] 。 |
下面是一個簡單的例子:
首先匯入需要的包:
import torch
from torch_geometric.data import Data
- 1
- 2
比如上圖所示的圖結構,我們首先定義節點特徵向量:
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
- 1
接著定義邊,下面兩種定義方式是等價的。第二種方式可能更符合我們的閱讀習慣,但是需要注意的是此時應當增加一個edge_index=edge_index.t().contiguous()
的操作。此外,由於是無向圖,雖然只有兩條邊,但是我們需要四組關係說明來描述邊的兩個方向。
## 法1
edge_index = torch.tensor([[0, 1, 1, 2],
[1, 0, 2, 1]], dtype=torch.long)
data = Data(x=x, edge_index=edge_index)
## 法2
edge_index = torch.tensor([[0, 1],
[1, 0],
[1, 2],
[2, 1]], dtype=torch.long)
data = Data(x=x, edge_index=edge_index.t().contiguous())
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
可以得到:
同時,Data物件提供了一些很實用的函式:
print('data\'s keys: {}'.format(data.keys))
print('-'*5)
for key, item in data:
print("{} found in data".format(key))
print('-'*5)
print('Does data has attribute \'edge_attr\'? {}'.format('edge_attr' in data))
print('data has {} nodes'.format(data.num_nodes))
print('data has {} edges'.format(data.num_edges))
print('The nodes in data have {} feature(s)'.format(data.num_node_features))
print('Does data contains isolated nodes? {}'.format(data.contains_isolated_nodes()))
print('Does data contains self loops? {}'.format(data.contains_self_loops()))
print('is data directed? {}'.format(data.is_directed()))
print(data['x'])
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
輸出:
data's keys: ['x', 'edge_index']
-----
edge_index found in data
x found in data
-----
Does data has attribute 'edge_attr'? False
data has 3 nodes
data has 4 edges
The nodes in data have 1 feature(s)
Does data contains isolated nodes? False
Does data contains self loops? False
is data directed? False
tensor([[-1.],
[ 0.],
[ 1.]])
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
同樣可以在GPU上執行data:
device = torch.device('cuda')
data = data.to(device)
- 1
- 2
2.2 Common Benchmark Datasets 常見的基準資料集
PyTorch Geometric提供很多基準資料集,包括
- all Planetoid datasets (Cora, Citeseer, Pubmed)
- all graph classification datasets fromhttp://graphkernels.cs.tu-dortmund.deandtheir cleaned versions
- the QM7 and QM9 dataset
- a handful of 3D mesh/point cloud datasets like FAUST, ModelNet10/40 and ShapeNet
想要使用這些資料集,只要進行初始化,資料就會自動下載。比如我們要使用ENZYMES資料集(該資料集包括600張圖,有6個類別):
from torch_geometric.datasets import TUDataset
dataset = TUDataset(root='.\data\ENZYMES', name='ENZYMES')
- 1
- 2
程式就會自動執行下載:
Downloading http://ls11-www.cs.tu-dortmund.de/people/morris/graphkerneldatasets/ENZYMES.zip
Extracting data\ENZYMES\ENZYMES\ENZYMES.zip
Processing...
Done!
- 1
- 2
- 3
- 4
我們可以看一下這個資料集的一些屬性:
print(dataset)
print(len(dataset))
print(dataset.num_classes)
print(dataset.num_node_features)
- 1
- 2
- 3
- 4
輸出:
我們可以看下其中一張圖的結構:
data = dataset[14]
print(data)
print(data.is_undirected())
- 1
- 2
- 3
輸出:
- 1
- 2
我們可以看到資料集中的第一個圖包含36個節點,每個節點有3個特徵。圖中有128/2 = 64條無向邊,圖被分類為“1”類。在將資料集分為訓練集和測試集之前,可以呼叫dataset = dataset.shuffle()
將資料集進行隨機打亂。這個語句和下面這段程式是等價的:
perm = torch.randperm(len(dataset))
dataset = dataset[perm]
- 1
- 2
我們再來看硬外一個數據集Cora
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='.\data\Cora', name='Cora')
data = dataset[0]
print(data)
print(data.is_undirected())
print(data.train_mask.sum().item())
print(data.val_mask.sum().item())
print(data.test_mask.sum().item())
print(len(dataset))
print(dataset.num_classes)
print(dataset.num_node_features)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
輸出:
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
可以看到,前面的資料集針對的是“網路分類”的任務,而這個資料集針對的是“節點分類”的任務。每個節點又1433個特徵,被分為7類。這個圖是一個無向圖,共有10556/2=5278條邊,共有2708個節點。這裡有三個需要注意的引數:
train_mask
——指明訓練集中的節點(可以看到,在這個資料集中,訓練集裡有140個節點)val_mask
——指明驗證集中的節點(可以看到,在這個資料集中,驗證集裡有500個節點)test_mask
——指明測試集中的節點(可以看到,在這個資料集中,測試集裡有1000個節點)
2.3 Mini-batches
神經網路通常以批處理的方式進行訓練。在pytorch中,通常用資料載入器DataLoader
來進行批處理。
from torch_geometric.datasets import TUDataset
from torch_geometric.data import DataLoader
from torch_scatter import scatter_mean
dataset = TUDataset(root='.\data\ENZYMES', name='ENZYMES', use_node_attr=True)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
for data in loader:
print(data)
print(data.num_graphs)
x = scatter_mean(data.x, data.batch, dim=0)
print(x.size())
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
張圖。
這個scatter函式實質上是對節點的一個整合,節點根據batch的標籤(按照圖)來進行整合,下面這張官方文件中的圖可以很好地說明scatter函式的作用:
2.4 Data Transforms 資料轉換
torch_geometric.transforms.Compose提供了資料轉換的方法,可以方便使用者將資料轉換成既定的格式或者用於資料的預處理。在之前使用torchvision處理影象時,也會用到資料轉換的相關方法,將圖片轉換成畫素矩陣,這裡的資料轉換就類似torchvision在影象上的處理。
2.5 Learning Methods on Graphs——the first graph neural network 搭建我們的第一個圖神經網路
下面我們來嘗試著搭建我們的第一圖神經網路。關於圖神經網路,可以看一下這篇部落格——GRAPH CONVOLUTIONAL NETWORKS。
資料集準備
我們使用的是Cora資料集。
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='./data/Cora', name='Cora')
print(dataset)
- 1
- 2
- 3
輸出:
搭建網路模型
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = GCNConv(dataset.num_node_features, 16)
self.conv2 = GCNConv(16, dataset.num_classes)
def forward(self, data):
x, edge_index = data.x, data.edge_index
x = self.conv1(x, edge_index)
x = F.relu(x)
x = F.dropout(x, training=self.training)
x = self.conv2(x, edge_index)
return F.log_softmax(x, dim=1)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
模型的結構包含兩個GCNConv層,選擇ReLU作為非線性函式,最後通過softmax輸出分類結果。
模型訓練和驗證
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net().to(device)
data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)
model.train()
for epoch in range(200):
optimizer.zero_grad()
out = model(data)
loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
model.eval()
_, pred = model(data).max(dim=1)
correct = int(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item())
acc = correct / int(data.test_mask.sum())
print('Accuracy: {:.4f}'.format(acc))
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
輸出:
3. CREATING MESSAGE PASSING NETWORKS 建立訊息傳遞網路
將卷積神經網路中的“卷積運算元”應用到圖上面,核心在於neighborhood aggregation機制,或者說是message passing的機制。Aggregate Neighbours,核心思想在於基於區域性網路連線來生成Node embeddings(Generate node embeddings based on local network neighborhoods)。如下面這個圖:
例如圖中節點A的embedding決定於其鄰居節點
{ B , C , D } \{B,C,D\}
{B,C,D},而這些節點又受到它們各自的鄰居節點的影響。圖中的“黑箱”可以看成是整合其鄰居節點資訊的操作,它有一個很重要的屬性——其操作應該是順序(order invariant)無關的,如求和、求平均、求最大值這樣的操作,可以採用神經網路來獲取。這樣順序無關的聚合函式符合網路節點無序性的特徵,當我們對網路節點進行重新編號時,我們的模型照樣可以使用。
那麼,對於每個節點來說,它的計算圖就由其鄰居節點的數量來決定——
模型的深度可以自己定義(Model can be of arbitrary depth):
- Nodes have embeddings at each layer
- Layer-0節點
3.1 Message passing 基本類
PyTorch Geometric 提供了基本類——MessagePassing,可以實現上述的圖神經網路,來實現訊息傳遞或訊息聚集(which helps in creating such kinds of message passing graph neural networks by automatically taking care of message propagation. )
MessagePassing類有三個引數:
- aggr (string, optional)——指定採用的置換不變函式,預設是
add
,可以定義為add
、mean
、max
和None
。 - **flow (string, optional) **——指定資訊傳遞的反向,預設是
source_to_target
,還可以設定為target_to_source
。 - **node_dim (int, optional) **——The axis along which to propagate. 預設是-2。
同時,MessagePassing提供了一些比較實用的方法:
MessagePassing.propagate(edge_index, size=None, **kwargs)
- ii全部設定為1。在pytorch geometric裡面,是利用edge_index來實現。如果是有權圖,則新增的自迴圈邊以
fill_value
作為權。該方法最後返回兩個值——`edge_index, edge_weight``。
import torch
from torch_geometric.utils import add_self_loops, degree
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
edge_index = torch.tensor([[0, 1, 1, 2],
[1, 0, 2, 1]], dtype=torch.long)
print("original edge_index ")
print(edge_index)
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
print("new edge_index")
print(edge_index)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
最後輸出:
original edge_index
tensor([[0, 1, 1, 2],
[1, 0, 2, 1]])
new edge_index
tensor([[0, 1, 1, 2, 0, 1, 2],
[1, 0, 2, 1, 0, 1, 2]])
- 1
- 2
- 3
- 4
- 5
- 6
-
Linearly transform node feature matrix.第二步是對節點的特徵矩陣進行線性變換。主要通過一個線性層
torch.nn.Linear
實現。 -
Compute normalization coefficients.第三步是對變換後的節點特徵進行標準化。節點的度可以通過
torch_geometric.utils.degree
實現。
import torch
from torch_geometric.utils import add_self_loops, degree
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
edge_index = torch.tensor([[0, 1, 1, 2],
[1, 0, 2, 1]], dtype=torch.long)
print("original edge_index ")
print(edge_index)
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
print("new edge_index")
print(edge_index)
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
print(deg)
deg_inv_sqrt = deg.pow(-0.5)
print(deg_inv_sqrt)
print(deg_inv_sqrt[row])
print(deg_inv_sqrt[col])
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
最後輸出:
original edge_index
tensor([[0, 1, 1, 2],
[1, 0, 2, 1]])
new edge_index
tensor([[0, 1, 1, 2, 0, 1, 2],
[1, 0, 2, 1, 0, 1, 2]])
tensor([2., 3., 2.])
tensor([0.7071, 0.5774, 0.7071])
tensor([0.7071, 0.5774, 0.5774, 0.7071, 0.7071, 0.5774, 0.7071])
tensor([0.5774, 0.7071, 0.7071, 0.5774, 0.7071, 0.5774, 0.7071])
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
-
Sum up neighboring node features (“add” aggregation).
前面三步是message passing之前的預操作,第四、第五步可以採用MessagePassing類裡面的方法完成。
完整的程式碼如下:
import torch
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree
class GCNConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(GCNConv, self).__init__(aggr='add') # "Add" aggregation (Step 5).
self.lin = torch.nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
# x has shape [N, in_channels]
# edge_index has shape [2, E]
# Step 1: Add self-loops to the adjacency matrix.
edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
# Step 2: Linearly transform node feature matrix.
x = self.lin(x)
# Step 3: Compute normalization.
row, col = edge_index
deg = degree(col, x.size(0), dtype=x.dtype)
deg_inv_sqrt = deg.pow(-0.5)
norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]
# Step 4-5: Start propagating messages.
return self.propagate(edge_index, x=x, norm=norm)
def message(self, x_j, norm):
# x_j has shape [E, out_channels]
# Step 4: Normalize node features.
return norm.view(-1, 1) * x_j
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
我們建立的這個神經網路模型GCNConv
繼承於基礎類MessagePassing
,並且採用求和函式作為
□ \square
□函式,通過super(GCNConv, self).__init__(aggr='add')
來初始化。在完成1-3步之後,呼叫MessagePassing
中的propagate()
方法來完成4-5步,進行資訊傳播。message
函式用於對節點的鄰居節點的資訊進行標準化。
我們可以通過一個案例來感受一下這個模型的輸入和輸出。
x = torch.tensor(torch.rand(3,2), dtype=torch.float)
edge_index = torch.tensor([[0, 1, 1, 2],
[1, 0, 2, 1]], dtype=torch.long)
conv = GCNConv(2, 4)
- 1
- 2
- 3
- 4
設有上圖所示的網路,網路中有三個節點,每個節點有2個特徵值。並構建神經網路conv = GCNConv(2, 4)
。下面是程式執行的每一步輸出的結果:
x is
tensor([[0.1819, 0.1848],
[0.8479, 0.1754],
[0.7511, 0.9781]])
----Step 1: Add self-loops to the adjacency matrix.----
tensor([[0, 1, 1, 2, 0, 1, 2],
[1, 0, 2, 1, 0, 1, 2]])
----Step 2: Linearly transform node feature matrix.----
linear weight is
Parameter containing:
tensor([[-0.6532, -0.3349],
[ 0.5238, -0.5996],
[-0.6279, -0.5872],
[-0.4064, 0.5893]], requires_grad=True)
linear bias is
Parameter containing:
tensor([ 0.5966, -0.4339, 0.0263, 0.1577], requires_grad=True)
transformed x is
tensor([[ 0.4160, -0.4494, -0.1964, 0.1927],
[-0.0159, -0.0949, -0.6090, -0.0835],
[-0.2215, -0.6270, -1.0196, 0.4289]], grad_fn=<AddmmBackward>)
----Step 3: Compute normalization.----
tensor([0.4082, 0.4082, 0.4082, 0.4082, 0.5000, 0.3333, 0.5000])
----Step 4-5: Start propagating messages.----
tensor([[ 0.2015, -0.2635, -0.3468, 0.0623],
[ 0.0741, -0.4711, -0.6994, 0.2260],
[-0.1172, -0.3522, -0.7584, 0.1804]], grad_fn=<ScatterAddBackward>)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
3.3 Edge Convolution 邊卷積層的實現
邊卷積層的數學定義如下:
x i ( k ) = max j ∈ N ( i ) h Θ ( x i ( k − 1 ) , x j ( k − 1 ) − x i ( k − 1 ) ) x_i^{(k)}=\max_{j \in N(i)} h_{\Theta}(x_i^{(k-1)},x_j^{(k-1)}-x_i^{(k-1)})
x
Θ
為多層感知機,類似於GCN,邊卷積層同樣繼承于于基礎類MessagePassing
,不同在於採用max
函式作為
□ \square
□函式。
邊卷積層的主要理論來自於論文Dynamic Graph CNN for Learning on Point Clouds,這篇文章提出一種邊卷積(EdgeConv)操作,來完成點雲中點與點之間關係的建模,使得網路能夠更好地學習區域性和全域性特徵。具體可以看這兩篇部落格:【深度學習——點雲】DGCNN(EdgeConv)和論文筆記:DGCNN(EdgeConv)。
import torch
from torch.nn import Sequential as Seq, Linear, ReLU
from torch_geometric.nn import MessagePassing
class EdgeConv(MessagePassing):
def __init__(self, in_channels, out_channels):
super(EdgeConv, self).__init__(aggr='max') # "Max" aggregation.
self.mlp = Seq(Linear(2 * in_channels, out_channels),
ReLU(),
Linear(out_channels, out_channels))
def forward(self, x, edge_index):
# x has shape [N, in_channels]
# edge_index has shape [2, E]
return self.propagate(edge_index, x=x)
def message(self, x_i, x_j):
# x_i has shape [E, in_channels]
# x_j has shape [E, in_channels]
tmp = torch.cat([x_i, x_j - x_i], dim=1) # tmp has shape [E, 2 * in_channels]
return self.mlp(tmp)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
邊緣卷積實際上是一種動態卷積,它使用特徵空間中的最近鄰重新計算每一層的圖。PyTorch geometry附帶一個GPU加速的批處理k-NN圖形生成方法——torch_geometric.n .pool.knn_graph()
。
from torch_geometric.nn import knn_graph
class DynamicEdgeConv(EdgeConv):
def __init__(self, in_channels, out_channels, k=6):
super(DynamicEdgeConv, self).__init__(in_channels, out_channels)
self.k = k
def forward(self, x, batch=None):
edge_index = knn_graph(x, self.k, batch, loop=False, flow=self.flow)
return super(DynamicEdgeConv, self).forward(x, edge_index)
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
4. 建立自己的資料集
PyTorch Geometric提供了兩個抽象類——torch_geometric.data.Dataset
和torch_geometric.data.InMemoryDataset
。前者適用於不能一次性放進記憶體中的大資料集,後者適用於可以全部放進記憶體中的小資料集。
4.1 “In Memory Datasets”的建立
torch_geometric.data.InMemoryDataset
有四個可選引數:
- root (string, optional)——儲存資料集的根目錄。每個資料集都傳遞一個根資料夾,該根資料夾指示資料集應該儲存在何處。將根資料夾分成兩個資料夾:未處理過的資料集被儲存在raw_dir目錄下;已處理的資料集被儲存在processed_dir目錄下。
- transform (callable, optional)
- pre_transform (callable, optional)
- pre_filter (callable, optional)
建立In Memory Datasets,需要用到四個基本的方法:
raw_file_names()
——返回一個包含所有未處理過的資料檔案的檔名的列表。processed_file_names()
——返回一個包含所有處理過的資料檔案的檔名的列表。download()
——下載資料到raw_dir目錄下。process()
——對資料的處理函式,是核心的函式之一。
下面是官方文件給出的一個示例:
import torch
from torch_geometric.data import InMemoryDataset
class MyOwnDataset(InMemoryDataset):
def __init__(self, root, transform=None, pre_transform=None):
super(MyOwnDataset, self).__init__(root, transform, pre_transform)
self.data, self.slices = torch.load(self.processed_paths[0])
@property
def raw_file_names(self):
return ['some_file_1', 'some_file_2', ...]
@property
def processed_file_names(self):
return ['data.pt']
def download(self):
# Download to `self.raw_dir`.
def process(self):
# Read data into huge `Data` list.
data_list = [...]
if self.pre_filter is not None:
data_list = [data for data in data_list if self.pre_filter(data)]
if self.pre_transform is not None:
data_list = [self.pre_transform(data) for data in data_list]
data, slices = self.collate(data_list)
torch.save((data, slices), self.processed_paths[0])
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
4.2 建立更大的資料集
對於無法全部放進記憶體中的大資料集,可以使用torch_geometric.data.Dataset
。torch_geometric.data.Dataset
的引數和torch_geometric.data.InMemoryDataset
的一致。常用的方法如下:
len()
——獲取資料集中的資料量。get(idx)
——獲取索引為idx
的資料物件。
下面是官方文件給出的一個示例:
import os.path as osp
import torch
from torch_geometric.data import Dataset
class MyOwnDataset(Dataset):
def __init__(self, root, transform=None, pre_transform=None):
super(MyOwnDataset, self).__init__(root, transform, pre_transform)
@property
def raw_file_names(self):
return ['some_file_1', 'some_file_2', ...]
@property
def processed_file_names(self):
return ['data_1.pt', 'data_2.pt', ...]
def download(self):
# Download to `self.raw_dir`.
def process(self):
i = 0
for raw_path in self.raw_paths:
# Read data from `raw_path`.
data = Data(...)
if self.pre_filter is not None and not self.pre_filter(data):
continue
if self.pre_transform is not None:
data = self.pre_transform(data)
torch.save(data, osp.join(self.processed_dir, 'data_{}.pt'.format(i)))
i += 1
def len(self):
return len(self.processed_file_names)
def get(self, idx):
data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))
return data
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42