程式人生 > deepwalk原始碼解讀2: __main__.py

deepwalk原始碼解讀2: __main__.py

1. 準備工作

__main__.py 檔案是程式的入口,主要包含兩個函式:一個是 main(),用於定義並解析命令列引數;另一個是 process(args),根據解析後的引數載入圖、生成隨機遊走、訓練模型並輸出節點表示。檔案中還有一些用於日誌和除錯的程式碼片段,為了不影響對主流程的分析,我們將這兩部分程式碼註釋掉。

另一方面,如果想在 IDE 裡分析原始碼,不可能每一次都透過命令列執行,因此這裡使用一個小技巧:源程式在定義完命令列引數後,用 args = parser.parse_args() 接收實際執行時的命令列輸入,我們把這句程式碼替換為:

# Hard-coded arguments so the script can be run and stepped through inside an
# IDE without typing them on a terminal.
debug_argv = ("--input ../example_graphs/karate.adjlist "
              "--output ./output").split()
args = parser.parse_args(debug_argv)

這樣就可以脫離命令列終端,直接在IDE裡進行分析。

2. 詳細解讀

下面是程式碼的詳細解讀:

2.1 import

#! /usr/bin/env python
# -*- coding: utf-8 -*-

import random
import sys
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser, ArgumentTypeError

from gensim.models import Word2Vec

from deepwalk import graph
from deepwalk import walks as serialized_walks
from deepwalk.skipgram import Skipgram

2.2 main()

def main():
    parser = ArgumentParser("deepwalk",
                            formatter_class=ArgumentDefaultsHelpFormatter,
                            conflict_handler='resolve')
    parser.add_argument("--debug"
, dest="debug", action='store_true', default=False, help="drop a debugger if an exception is raised.") parser.add_argument('--format', default='adjlist', help='File format of input file') parser.add_argument('--input', nargs='?', required=True, help='Input graph file') """ 這句話的意思是,你的命令列裡必須有--input,至於說這個引數後面有沒有檔名字無所謂,如果有檔案路徑,那麼input就是 這個檔案;如果沒有指定檔案路徑,那麼input=None """ parser.add_argument("-l", "--log", dest="log", default="INFO", help="log verbosity level") parser.add_argument('--matfile-variable-name', default='network', help='variable name of adjacency matrix inside a .mat file.') parser.add_argument('--max-memory-data-size', default=1000000000, type=int, help='Size to start dumping walks to disk, instead of keeping them in memory.') parser.add_argument('--number-walks', default=10, type=int, help='Number of random walks to start at each node') """ number-walks:相當於你演算法裡面的$\gamma$ """ parser.add_argument('--output', required=True, help='Output representation file') parser.add_argument('--representation-size', default=64, type=int, help='Number of latent dimensions to learn for each node.') parser.add_argument('--seed', default=0, type=int, help='Seed for random walk generator.') parser.add_argument('--undirected', default=True, type=bool, help='Treat graph as undirected.') parser.add_argument('--vertex-freq-degree', default=False, action='store_true', help='Use vertex degree to estimate the frequency of nodes ' 'in the random walks. 
This option is faster than ' 'calculating the vocabulary.') """ 這個引數尚不知道具體是用來做什麼的 """ parser.add_argument('--walk-length', default=40, type=int, help='Length of the random walk started at each node') """ walk_length:相當於你演算法裡面的$\ell$ """ parser.add_argument('--window-size', default=5, type=int, help='Window size of skipgram model.') """ window_size:相當於你演算法裡面的context size $\omega$ """ parser.add_argument('--workers', default=1, type=int, help='Number of parallel processes.') # args = parser.parse_args() args= parser.parse_args("--input ../example_graphs/karate.adjlist " "--output ./output".split()) print(args) process(args)

2.3 process(args)

def _load_graph(args):
    """Load the graph named by ``args.input`` using the loader for ``args.format``.

    Reads ``args.format``, ``args.input``, ``args.undirected`` and, for the
    'mat' format, ``args.matfile_variable_name``.

    Raises:
        ValueError: if ``args.format`` is not 'adjlist', 'edgelist' or 'mat'.
    """
    if args.format == "adjlist":
        return graph.load_adjacencylist(args.input, undirected=args.undirected)
    if args.format == "edgelist":
        return graph.load_edgelist(args.input, undirected=args.undirected)
    if args.format == "mat":
        return graph.load_matfile(args.input,
                                  variable_name=args.matfile_variable_name,
                                  undirected=args.undirected)
    # ValueError subclasses Exception, so existing `except Exception` handlers
    # keep working while callers can now catch the specific failure.
    raise ValueError("Unknown file format: '%s'.  Valid formats: 'adjlist', 'edgelist', 'mat'"
                     % args.format)


def process(args):
    """Run DeepWalk end to end with the parsed command-line ``args``.

    The steps are:

    1. Load the graph.
    2. Generate ``args.number_walks`` random walks of length
       ``args.walk_length`` per node.
    3. Train a skip-gram model on the walks.
    4. Save the vectors to ``args.output`` via ``save_word2vec_format``.

    When walks*length reaches ``args.max_memory_data_size``, the walks are
    written to files next to ``args.output`` and streamed from there.
    Otherwise they are kept in memory.

    Raises:
        ValueError: if ``args.format`` is not a supported graph format.
    """
    G = _load_graph(args)

    num_nodes = len(G.nodes())
    print("Number of nodes: {}".format(num_nodes))

    num_walks = num_nodes * args.number_walks
    print("Number of walks: {}".format(num_walks))

    data_size = num_walks * args.walk_length
    print("Data size (walks*length): {}".format(data_size))

    if data_size < args.max_memory_data_size:
        print("Walking...")
        # NOTE(review): alpha=0 is presumably the random-restart probability
        # of graph.build_deepwalk_corpus -- confirm in graph.py.
        walks = graph.build_deepwalk_corpus(G, num_paths=args.number_walks,
                                            path_length=args.walk_length, alpha=0,
                                            rand=random.Random(args.seed))
        print("Training...")
        # sg=1 selects skip-gram, hs=1 hierarchical softmax, min_count=0 keeps
        # every vertex. NOTE: `size` is the gensim < 4.0 keyword (renamed
        # `vector_size` in gensim 4.0).
        model = Word2Vec(walks, size=args.representation_size, window=args.window_size,
                         min_count=0, sg=1, hs=1, workers=args.workers)
    else:
        print("Data size {} is larger than limit (max-memory-data-size: {}).  Dumping walks to disk.".format(
            data_size, args.max_memory_data_size))
        print("Walking...")

        walks_filebase = args.output + ".walks"
        walk_files = serialized_walks.write_walks_to_disk(G, walks_filebase, num_paths=args.number_walks,
                                                          path_length=args.walk_length, alpha=0,
                                                          rand=random.Random(args.seed),
                                                          num_workers=args.workers)

        print("Counting vertex frequency...")
        if not args.vertex_freq_degree:
            vertex_counts = serialized_walks.count_textfiles(walk_files, args.workers)
        else:
            # Use each vertex's degree as its frequency estimate. The original
            # called G.iterkeys(), which dict-like objects lack on Python 3;
            # G.nodes() is already relied on above to enumerate the vertices.
            vertex_counts = G.degree(nodes=G.nodes())

        print("Training...")
        walks_corpus = serialized_walks.WalksCorpus(walk_files)
        model = Skipgram(sentences=walks_corpus, vocabulary_counts=vertex_counts,
                         size=args.representation_size,
                         window=args.window_size, min_count=0, trim_rule=None, workers=args.workers)

    model.wv.save_word2vec_format(args.output)

2.4 程式入口

if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the exit status
    # (main() returns None, which sys.exit treats as status 0).
    exit_status = main()
    sys.exit(exit_status)

3. deepwalk檔案邏輯

在分析完主檔案之後,我們整理出了整個工程的呼叫邏輯,如下圖所示: