程式人生 > GraphSAGE 程式碼解析 - minibatch.py

GraphSAGE 程式碼解析 - minibatch.py

class EdgeMinibatchIterator

    """ This minibatch iterator iterates over batches of sampled edges or
    random pairs of co-occuring edges.

    G -- networkx graph
    id2idx -- dict mapping node ids to index in feature tensor
    placeholders -- tensorflow placeholders object
    context_pairs -- if not None, then a list of co-occurring node pairs (from random walks)
    batch_size -- size of the minibatches
    max_degree -- maximum size of the downsampled adjacency lists
    n2v_retrain -- signals that the iterator is being used to add new embeddings to an n2v model
    fixed_n2v -- signals that the iterator is being used to retrain n2v with only existing nodes as context
    
"""

def __init__(self, G, id2idx, placeholders, context_pairs=None, batch_size=100, max_degree=25, n2v_retrain=False, fixed_n2v=False, **kwargs) 中的具體內容介紹如下:

1 self.nodes = np.random.permutation(G.nodes())
2 # 函式shuffle與permutation都是對原來的陣列進行重新洗牌,即隨機打亂原來的元素順序
3 # shuffle直接在原來的陣列上進行操作,改變原來陣列的順序,無返回值
4 # permutation不直接在原來的陣列上進行操作,而是返回一個新的打亂順序的陣列,並不改變原來的陣列。
1 self.adj, self.deg = self.construct_adj()

這裡重點看construct_adj()函式。

 def construct_adj(self):
     """Build the fixed-width, neighbor-sampled adjacency table for training.

     Only training nodes contribute rows. Nodes whose attributes flag
     ``test`` or ``val`` are skipped, and edges flagged ``train_removed`` are
     ignored when collecting a node's neighbors.

     Returns:
         adj: array of shape ``(len(self.id2idx) + 1, self.max_degree)`` with
             numpy's default float dtype. Row ``self.id2idx[n]`` holds
             ``max_degree`` neighbor indices of node ``n``:
             - sampled without replacement when ``n`` has more kept neighbors
               than ``max_degree``;
             - sampled with replacement when it has fewer;
             - used unchanged when it has exactly ``max_degree``.
             Rows that are never written keep the fill value
             ``len(self.id2idx)``, which is the index of the extra last row.
             Unwritten rows are skipped nodes, nodes with no kept neighbors,
             and the extra row itself.
         deg: array of shape ``(len(self.id2idx),)``. It holds each node's
             number of kept (non-``train_removed``) neighbors before sampling,
             and 0 for skipped nodes.
     """
     num_nodes = len(self.id2idx)
     # One extra row: every unwritten slot holds ``num_nodes``, the index of
     # that extra row (presumably a padding/dummy node -- confirm against the
     # model code that consumes ``adj``).
     adj = num_nodes * np.ones((num_nodes + 1, self.max_degree))
     deg = np.zeros((num_nodes,))

     # ``nodes(data=True)`` yields (node, attr_dict) pairs on networkx 1.x,
     # 2.x and 3.x. The former ``G.node[...]`` accessor was removed in
     # networkx 2.4.
     for nodeid, node_attrs in self.G.nodes(data=True):
         if node_attrs['test'] or node_attrs['val']:
             continue
         neighbors = np.array([self.id2idx[neighbor]
                               for neighbor in self.G.neighbors(nodeid)
                               if not self.G[nodeid][neighbor]['train_removed']])
         deg[self.id2idx[nodeid]] = len(neighbors)

         if len(neighbors) == 0:
             continue
         # ``replace`` chooses sampling with or without replacement (it does
         # not modify ``neighbors`` in place). Too many neighbors are
         # subsampled without replacement. Too few are padded by sampling
         # with replacement.
         if len(neighbors) > self.max_degree:
             neighbors = np.random.choice(neighbors, self.max_degree,
                                          replace=False)
         elif len(neighbors) < self.max_degree:
             neighbors = np.random.choice(neighbors, self.max_degree,
                                          replace=True)
         adj[self.id2idx[nodeid], :] = neighbors
     return adj, deg

 

construct_test_adj() 函式與上面的 construct_adj() 的不同之處在於,它直接取得節點的所有鄰居,而無需根據 val/test/train_removed 進行篩選。

1 neighbors = np.array([self.id2idx[neighbor]
2                           for neighbor in self.G.neighbors(nodeid)])