MindSpore載入圖資料集
載入圖資料集
- MindSpore提供的
mindspore.dataset
模組可以幫助使用者構建資料集物件,分批次地讀取文字資料。
圖的概念
-
通常一個圖(graph)
G
是由一系列的節點(vertices)V
以及邊(eage)E
組成的,每條邊都連線著圖中的兩個節點,用公式可表述為:G = F(V, E)
,簡單的圖如下所示。 -
圖中包含節點V = {a, b, c, d},和邊E = {(a, b), (b, c), (c, d), (d, b)},針對圖中的連線關係通常需藉助數學的方式進行描述,如常用的基於鄰接矩陣的方式,用於描述上述圖連線關係的矩陣C如下,其中a、 b、c、d對應為第1、2、 3、4個節點。
資料集下載和轉換
(1) 資料集介紹
-
常用的圖資料集包含Cora、Citeseer、PubMed等
-
原始資料集可以從ucsc網站進行下載,
-
github提供的預處理後的資料集,GCN等公開使用
-
Cora資料集主體部分(
cora.content
)-
2708條樣本(節點),每條樣本描述1篇科學論文的資訊,論文都屬於7個類別中的一個。每條樣本資料包含三部分,依次為論文編號、論文的詞向量(一個1433位的二進位制)、論文的類別;
-
引用資料集部分(
cora.cites
)包含5429行(邊),每行包含兩個論文編號,表示第二篇論文對第一篇論文進行了引用。
-
資料集下載:下載預處理後的cora資料集目錄如下:
.
└── cora
├── ind.cora.allx
├── ind.cora.ally
├── ind.cora.graph
├── ind.cora.test.index
├── ind.cora.tx
├── ind.cora.ty
├── ind.cora.x
├── ind.cora.y
├── trans.cora.graph
├── trans.cora.tx
├── trans.cora.ty
├── trans.cora.x
└── trans.cora.y
(2)資料集下載
以下示例程式碼將cora資料集下載並解壓到指定位置。
!mkdir -p ./cora
!git clone https://github.com/kimiyoung/planetoid
!cp planetoid/data/*.cora.* ./cora
!rm -rf planetoid
(3)資料集格式轉換
- 資料集格式轉換:將資料集轉換為MindRecord格式,可藉助models倉庫提供的轉換指令碼進行轉換,生成的MindRecord檔案在
./cora_mindrecord
路徑下。
!git clone https://gitee.com/mindspore/models.git
SRC_PATH = "./cora"
MINDRECORD_PATH = "./cora_mindrecord"
!rm -rf $MINDRECORD_PATH
!mkdir $MINDRECORD_PATH
!python models/utils/graph_to_mindrecord/writer.py --mindrecord_script cora --mindrecord_file "$MINDRECORD_PATH/cora_mr" --mindrecord_partitions 1 --mindrecord_header_size_by_bit 18 --mindrecord_page_size_by_bit 20 --graph_api_args "$SRC_PATH"
- 報錯,但命令列可以
- 改: 環境切換 沒得搞定啊
!source activate py37_ms16
!python models/utils/graph_to_mindrecord/writer.py --mindrecord_script cora --mindrecord_file "$MINDRECORD_PATH/cora_mr" --mindrecord_partitions 1 --mindrecord_header_size_by_bit 18 --mindrecord_page_size_by_bit 20 --graph_api_args "$SRC_PATH"
- 乖乖命令列試試。看來預設環境沒有ms不行?
source activate py37_ms16
python models/utils/graph_to_mindrecord/writer.py --mindrecord_script cora --mindrecord_file "./cora_mindrecord/cora_mr" --mindrecord_partitions 1 --mindrecord_header_size_by_bit 18 --mindrecord_page_size_by_bit 20 --graph_api_args "./cora"
載入資料集
-
MindSpore目前支援載入文字領域常用的經典資料集和多種資料儲存格式下的資料集,使用者也可以通過構建自定義資料集類實現自定義方式的資料載入。
-
下面演示使用
MindSpore.dataset
模組中的MindDataset
類載入上述已轉換成mindrecord格式的cora資料集。
(1)配置資料集目錄,建立資料集物件。
import mindspore.dataset as ds
import numpy as np
data_file = "./cora_mindrecord/cora_mr"
dataset = ds.GraphData(data_file)
(2)訪問對應的介面,獲取圖資訊及特性、標籤內容。
# 檢視圖中結構資訊
graph = dataset.graph_info()
print("graph info:", graph)
# 獲取所有的節點資訊
nodes = dataset.get_all_nodes(0)
nodes_list = nodes.tolist()
print("node shape:", len(nodes_list))
# 獲取特徵和標籤資訊,總共2708條資料
# 每條資料中特徵資訊是用於描述論文i,長度為1433的二進位制表示,標籤資訊指的是論文所屬的種類
raw_tensor = dataset.get_node_feature(nodes_list, [1, 2])
features, labels = raw_tensor[0], raw_tensor[1]
print("features shape:", features.shape)
print("labels shape:", labels.shape)
print("labels:", labels)
資料處理
- MindSpore目前支援的資料處理運算元及其詳細使用方法。下面構建pipeline,對節點進行取樣等操作。
(1)獲取節點的鄰居節點,構造鄰接矩陣。
neighbor = dataset.get_all_neighbors(nodes_list, 0)
# neighbor的第一列是node_id,第二列到最後一列儲存的是第一列的鄰居節點,如果不存在這麼多,則用-1補齊。
print("neighbor:\n", neighbor)
(2)依據節點的鄰居節點資訊,構造鄰接矩陣。
nodes_num = labels.shape[0]
node_map = {node_id: index for index, node_id in enumerate(nodes_list)}
adj = np.zeros([nodes_num, nodes_num], dtype=np.float32)
for index, value in np.ndenumerate(neighbor):
# neighbor的第一列是node_id,第二列到最後一列儲存的是第一列的鄰居節點,如果不存在這麼多,則用-1補齊。
if value >= 0 and index[1] > 0:
adj[node_map[neighbor[index[0], 0]], node_map[value]] = 1
print("adj:\n", adj)
(3)節點取樣,支援常見的多次跳躍取樣與隨機遊走取樣方法等。
- 多跳鄰接點取樣如(a)圖所示,當次取樣的節點將作為下次取樣的起始點;隨機遊走方式如(b)圖所示,隨機選擇一條路徑依次遍歷相鄰的節點,對應圖中則選擇了從Vi到Vj的遊走路徑。
# 基於多次跳躍進行節點取樣
neighbor = dataset.get_sampled_neighbors(nodes_list[0:21], [2], [0])
print("neighbor:\n", neighbor)
# 基於隨機遊走進行節點取樣
meta_path = [0]
walks = dataset.random_walk(nodes_list[0:21], meta_path)
print("walks:\n", walks)
(4)通過節點獲取邊/通過邊獲取節點。
# 通過邊獲取節點
part_edges = dataset.get_all_edges(0)[:10]
nodes = dataset.get_nodes_from_edges(part_edges)
print("part edges:", part_edges)
print("nodes:", nodes)
# 通過節點獲取邊
# nodes_pair_list = [(0, 1), (1, 2), (1, 3), (1, 4)]
# edges = dataset.get_edges_from_nodes(nodes_pair_list)
# print("edges:", edges)