1. 程式人生 > 其它 >PyG-使用networkx對Graph進行視覺化

PyG-使用networkx對Graph進行視覺化

程式碼

方法一

根據networkx的文件: https://networkx.github.io/documentation/networkx-1.10/reference/generated/networkx.drawing.nx_pylab.draw_networkx.html

我們可以寫出來一個非常簡單的例子,如下(程式碼可以左右滑動):

import networkx as nx
import matplotlib.pyplot as plt
G = nx.Graph()
edge_index = [(1, 2), (1, 3), (2, 3), (3, 4)]
G.add_edges_from(edge_index)
nx.draw(G)
plt.show()

執行程式之後,可以得到下面的圖,(偷了一個懶,沒有加label之類的資訊)

圖片

這個例子給我們的啟發:我們可以將PyG得到的edge_index轉成numpy的格式,然後傳給nx,下面是根據這個寫的一個函式:
在PyG中,邊的表示放在了edge_index中,由一個二維的矩陣構成,edge_index[0]表示節點edge_index[1]表示另一個節點

def draw(edge_index, name=None):
    G = nx.Graph(node_size=15, font_size=8)
    src = edge_index[0].cpu().numpy()
    dst = edge_index[1].cpu().numpy()
    edgelist = zip(src, dst)
    for i, j in edgelist:
        G.add_edge(i, j)
    plt.figure(figsize=(20, 14)) # 設定畫布的大小
    nx.draw_networkx(G)
    plt.savefig('{}.png'.format(name if name else'path'))

注:該方法可以用於模型中的forward函式,用於分析cov,pool等操作

再寫一個與上面思想一致,可以直接執行的一個例子

from torch_geometric.datasets import KarateClub
import networkx as nx
import matplotlib.pyplot as plt
dataset = KarateClub()
edge, x, y = dataset[0]
# edge, x, y 每個維度都為2,其中第一維度是name,第二個維度是data
# x表示的是結點,y表示的標籤,edge表示的連邊, 由兩個維度的tensor構成
x_np = x[1].numpy()
y_np = y[1].numpy()
g = nx.Graph()
name, edgeinfo = edge
src = edgeinfo[0].numpy()
dst = edgeinfo[1].numpy()
edgelist = zip(src, dst)
for i, j in edgelist:
    g.add_edge(i, j)
nx.draw(g)
plt.savefig('test.png')
plt.show()

方法二

其實,torch_geometric.utils中已經帶有to_networkx的函式可以直接將格式為torch_geometric.data.Data 的資料轉換為networkx.DiGraph的格式,該格式可以直接networkx處理,但是我們提前要得到torch_geometric.data.Data的資料格式

import networkx as nx
from torch_geometric.utils.convert import to_networkx
def draw(Data):
    G = to_networkx(Data)
    nx.draw(G)
    plt.savefig("path.png")
    plt.show()

注:上面這個一般可以用於在model訓練載入資料之前資料的分析,比如下面的例子

for i, data in enumerate(train_loader):
        draw(data)
        data = data.to(args.device)
        out = model(data)
        loss = F.nll_loss(out, data.y)
        print("Training loss:{}".format(loss.item()))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

上面的函式是在graph classification進行分析的一段程式碼,可以把batch size的設定為1,那麼for迴圈中得到就是一個graph的資料,在把資料feed給模型之前,我們可以通過該方法分析一下原始的資料是什麼樣子的。

參考

pyg手冊

nx畫圖手冊