DGL學習(三): 訊息傳遞教程
阿新 • • 發佈:2020-07-22
在本節中,我們將不同級別的訊息傳遞API與PageRank一起使用。 在DGL中,訊息傳遞和功能轉換是使用者定義的函式(UDF)。
PageRank 演算法:
在PageRank的每次迭代中,每個節點(網頁)首先將其PageRank值均勻地分散到其下游節點。 每個節點的新PageRank值是通過彙總從其鄰居收到的PageRank值來計算的,然後通過阻尼因子(damping factor)進行調整:
生成一個隨機圖, 兩點之間有邊的概率為 P:
import networkx as nx import matplotlib.pyplot as plt import torch importdgl N = 100 P = 0.1
DAMP = 0.8
g = nx.erdos_renyi_graph(N, P) g = dgl.DGLGraph(g)
src = list(range(1,51));dst = [0]*50 # 使用list批量新增
g.add_edges(src, dst)
print(g.number_of_edges()) print(g.number_of_nodes()) nx.draw(g.to_networkx(), node_size=50, node_color=[[.5, .5, .5,]])
plt.show()
在pagerank 中, 初始化每個節點初始值為 1/N, 將節點的出度作為節點的特徵。
## pv 演算法初始值 g.ndata['pv'] = torch.ones(N) / N g.ndata['deg'] = g.out_degrees(g.nodes()).float()
定義訊息函式,該函式將每個節點的PageRank值除以其出度,然後將結果作為訊息傳遞給其鄰居。
在DGL中,訊息函式是針對邊的,表示為Edge UDF。 Edge UDF接受單個引數edges。 它具有三個成員src,dst和data,用於訪問源節點特徵,目標節點特徵和邊特徵。實現pv演算法僅需從src中取特徵。
def pagerank_message_func(edges): return{'pv': edges.src['pv'] / edges.src['deg']}
定義reduce函式,該函式從其mailbox中聚合訊息和刪除訊息,並計算其新的PageRank值。
reduce函式是針對節點的,表示為 Node UDF。 Node UDF接受單個引數nodes,nodes具有兩個成員mailbox和data。 data包含節點特徵,mailbox包含所有傳入訊息特徵,這些功能沿第二維堆疊(dim = 1引數)。
def pagerank_reduce_func(nodes): msgs = torch.sum(nodes.mailbox['pv'], dim=1) pv = (1 - DAMP) / N + DAMP * msgs return {'pv' : pv}
註冊訊息函式和規約函式, 之後DGL呼叫它。 pagerank_naive是page_rank的簡單實現。
# 註冊訊息函式和歸約函式,稍後DGL將呼叫它。 g.register_message_func(pagerank_message_func) g.register_reduce_func(pagerank_reduce_func) def pagerank_naive(g): # Phase #1: send out messages along all edges. for u, v in zip(*g.edges()): g.send((u, v)) # Phase #2: receive messages to compute new PageRank values. for v in g.nodes(): g.recv(v) # 迭代10輪 for k in range(10): pagerank_naive(g) print(g.ndata['pv'])
tensor([0.0446, 0.0107, 0.0087, 0.0102, 0.0085, 0.0130, 0.0091, 0.0059, 0.0079, 0.0088, 0.0082, 0.0087, 0.0098, 0.0087, 0.0100, 0.0092, 0.0065, 0.0168, 0.0064, 0.0106, 0.0098, 0.0117, 0.0077, 0.0113, 0.0111, 0.0100, 0.0077, 0.0051, 0.0084, 0.0070, 0.0048, 0.0163, 0.0102, 0.0084, 0.0098, 0.0127, 0.0101, 0.0091, 0.0091, 0.0083, 0.0088, 0.0095, 0.0132, 0.0106, 0.0057, 0.0099, 0.0068, 0.0106, 0.0098, 0.0068, 0.0140, 0.0087, 0.0083, 0.0120, 0.0107, 0.0109, 0.0072, 0.0090, 0.0069, 0.0124, 0.0094, 0.0106, 0.0071, 0.0093, 0.0070, 0.0059, 0.0068, 0.0162, 0.0082, 0.0129, 0.0063, 0.0134, 0.0116, 0.0095, 0.0107, 0.0147, 0.0085, 0.0099, 0.0084, 0.0069, 0.0112, 0.0120, 0.0076, 0.0105, 0.0125, 0.0091, 0.0063, 0.0085, 0.0051, 0.0102, 0.0116, 0.0070, 0.0120, 0.0094, 0.0156, 0.0159, 0.0096, 0.0125, 0.0065, 0.0107])View Code