PyG-使用networkx對Graph進行視覺化
阿新 • • 發佈:2021-02-11
程式碼
方法一
根據networkx的文件: https://networkx.github.io/documentation/networkx-1.10/reference/generated/networkx.drawing.nx_pylab.draw_networkx.html
我們可以寫出來一個非常簡單的例子,如下(程式碼可以左右滑動):
import networkx as nx import matplotlib.pyplot as plt G = nx.Graph() edge_index = [(1, 2), (1, 3), (2, 3), (3, 4)] G.add_edges_from(edge_index) nx.draw(G) plt.show()
執行程式之後,可以得到下面的圖,(偷了一個懶,沒有加label之類的資訊)
這個例子給我們的啟發:我們可以將PyG得到的edge_index轉成numpy的格式,然後傳給nx,下面是根據這個寫的一個函式:
在PyG中,邊的表示放在了edge_index中,由一個二維的矩陣構成,edge_index[0]表示節點edge_index[1]表示另一個節點。
def draw(edge_index, name=None): G = nx.Graph(node_size=15, font_size=8) src = edge_index[0].cpu().numpy() dst = edge_index[1].cpu().numpy() edgelist = zip(src, dst) for i, j in edgelist: G.add_edge(i, j) plt.figure(figsize=(20, 14)) # 設定畫布的大小 nx.draw_networkx(G) plt.savefig('{}.png'.format(name if name else'path'))
注:該方法可以用於模型中的forward函式,用於分析cov,pool等操作
再寫一個與上面思想一致,可以直接執行的一個例子
from torch_geometric.datasets import KarateClub import networkx as nx import matplotlib.pyplot as plt dataset = KarateClub() edge, x, y = dataset[0] # edge, x, y 每個維度都為2,其中第一維度是name,第二個維度是data # x表示的是結點,y表示的標籤,edge表示的連邊, 由兩個維度的tensor構成 x_np = x[1].numpy() y_np = y[1].numpy() g = nx.Graph() name, edgeinfo = edge src = edgeinfo[0].numpy() dst = edgeinfo[1].numpy() edgelist = zip(src, dst) for i, j in edgelist: g.add_edge(i, j) nx.draw(g) plt.savefig('test.png') plt.show()
方法二
其實,torch_geometric.utils中已經帶有to_networkx的函式可以直接將格式為torch_geometric.data.Data 的資料轉換為networkx.DiGraph的格式,該格式可以直接networkx處理,但是我們提前要得到torch_geometric.data.Data的資料格式
import networkx as nx
from torch_geometric.utils.convert import to_networkx
def draw(Data):
G = to_networkx(Data)
nx.draw(G)
plt.savefig("path.png")
plt.show()
注:上面這個一般可以用於在model訓練載入資料之前資料的分析,比如下面的例子
for i, data in enumerate(train_loader):
draw(data)
data = data.to(args.device)
out = model(data)
loss = F.nll_loss(out, data.y)
print("Training loss:{}".format(loss.item()))
loss.backward()
optimizer.step()
optimizer.zero_grad()
上面的函式是在graph classification進行分析的一段程式碼,可以把batch size的設定為1,那麼for迴圈中得到就是一個graph的資料,在把資料feed給模型之前,我們可以通過該方法分析一下原始的資料是什麼樣子的。