1. 程式人生 > 實用技巧 >DGL學習(三): 訊息傳遞教程

DGL學習(三): 訊息傳遞教程

在本節中,我們將不同級別的訊息傳遞API與PageRank一起使用。 在DGL中,訊息傳遞和功能轉換是使用者定義的函式(UDF)。

PageRank 演算法:

在PageRank的每次迭代中,每個節點(網頁)首先將其PageRank值均勻地分散到其下游節點。 每個節點的新PageRank值是通過彙總從其鄰居收到的PageRank值來計算的,然後通過阻尼因子(damping factor)進行調整:

生成一個隨機圖, 兩點之間有邊的概率為 P:

import networkx as nx
import matplotlib.pyplot as plt
import torch
import
dgl N = 100 P = 0.1
DAMP = 0.8
g = nx.erdos_renyi_graph(N, P) g = dgl.DGLGraph(g)
src = list(range(1,51));dst = [0]*50 # 使用list批量新增
g.add_edges(src, dst)
print(g.number_of_edges()) print(g.number_of_nodes()) nx.draw(g.to_networkx(), node_size=50, node_color=[[.5, .5, .5,]])
plt.show()

在pagerank 中, 初始化每個節點初始值為 1/N, 將節點的出度作為節點的特徵。

## pv 演算法初始值
g.ndata['pv'] = torch.ones(N) / N
g.ndata['deg'] = g.out_degrees(g.nodes()).float()

定義訊息函式,該函式將每個節點的PageRank值除以其出度,然後將結果作為訊息傳遞給其鄰居。

在DGL中,訊息函式是針對邊的,表示為Edge UDF。 Edge UDF接受單個引數edges。 它具有三個成員src,dst和data,用於訪問源節點特徵,目標節點特徵和邊特徵。實現pv演算法僅需從src中取特徵。

def pagerank_message_func(edges):
    return
{'pv': edges.src['pv'] / edges.src['deg']}

定義reduce函式,該函式從其mailbox中聚合訊息和刪除訊息,並計算其新的PageRank值。

reduce函式是針對節點的,表示為 Node UDF。 Node UDF接受單個引數nodes,nodes具有兩個成員mailbox和data。 data包含節點特徵,mailbox包含所有傳入訊息特徵,這些功能沿第二維堆疊(dim = 1引數)。

def pagerank_reduce_func(nodes):
    msgs = torch.sum(nodes.mailbox['pv'], dim=1)
    pv = (1 - DAMP) / N + DAMP * msgs
    return {'pv' : pv}

註冊訊息函式和規約函式, 之後DGL呼叫它。 pagerank_naive是page_rank的簡單實現。

# 註冊訊息函式和歸約函式,稍後DGL將呼叫它。
g.register_message_func(pagerank_message_func)
g.register_reduce_func(pagerank_reduce_func)

def pagerank_naive(g):
    # Phase #1: send out messages along all edges.
    for u, v in zip(*g.edges()):
        g.send((u, v))
    # Phase #2: receive messages to compute new PageRank values.
    for v in g.nodes():
        g.recv(v)

# 迭代10輪
for k in range(10):
    pagerank_naive(g)

print(g.ndata['pv'])
tensor([0.0446, 0.0107, 0.0087, 0.0102, 0.0085, 0.0130, 0.0091, 0.0059, 0.0079,
        0.0088, 0.0082, 0.0087, 0.0098, 0.0087, 0.0100, 0.0092, 0.0065, 0.0168,
        0.0064, 0.0106, 0.0098, 0.0117, 0.0077, 0.0113, 0.0111, 0.0100, 0.0077,
        0.0051, 0.0084, 0.0070, 0.0048, 0.0163, 0.0102, 0.0084, 0.0098, 0.0127,
        0.0101, 0.0091, 0.0091, 0.0083, 0.0088, 0.0095, 0.0132, 0.0106, 0.0057,
        0.0099, 0.0068, 0.0106, 0.0098, 0.0068, 0.0140, 0.0087, 0.0083, 0.0120,
        0.0107, 0.0109, 0.0072, 0.0090, 0.0069, 0.0124, 0.0094, 0.0106, 0.0071,
        0.0093, 0.0070, 0.0059, 0.0068, 0.0162, 0.0082, 0.0129, 0.0063, 0.0134,
        0.0116, 0.0095, 0.0107, 0.0147, 0.0085, 0.0099, 0.0084, 0.0069, 0.0112,
        0.0120, 0.0076, 0.0105, 0.0125, 0.0091, 0.0063, 0.0085, 0.0051, 0.0102,
        0.0116, 0.0070, 0.0120, 0.0094, 0.0156, 0.0159, 0.0096, 0.0125, 0.0065,
        0.0107])
View Code