1. 程式人生 > >Python Ast介紹及應用

Python Ast介紹及應用

Abstract Syntax Trees即抽象語法樹。Ast是python原始碼到位元組碼的一種中間產物,藉助ast模組可以從語法樹的角度分析原始碼結構。此外,我們不僅可以修改和執行語法樹,還可以將Source生成的語法樹unparse成python原始碼。因此ast給python原始碼檢查、語法分析、修改程式碼以及程式碼除錯等留下了足夠的發揮空間。

1. AST簡介 

Python官方提供的CPython直譯器對python原始碼的處理過程如下:

  1. Parse source code into a parse tree (Parser/pgen.c)
  2. Transform parse tree into an Abstract Syntax Tree (Python/ast.c)
  3. Transform AST into a Control Flow Graph (Python/compile.c)
  4. Emit bytecode based on the Control Flow Graph (Python/compile.c)

即實際python程式碼的處理過程如下:

原始碼解析 --> 語法樹 --> 抽象語法樹(AST) --> 控制流程圖 --> 位元組碼

上述過程在python2.5之後被應用。python原始碼首先被解析成語法樹,隨後又轉換成抽象語法樹。在抽象語法樹中我們可以看到原始碼檔案中的python的語法結構。

大部分時間程式設計可能都不需要用到抽象語法樹,但是在特定的條件和需求的情況下,AST又有其特殊的方便性。

下面是一個抽象語法的簡單例項。

Module(body=[
    Print(
          dest=None,
          values=[BinOp( left=Num(n=1),op=Add(),right=Num(n=2))],
          nl=True,
 )])                                

2. 建立AST

2.1 Compile函式

先簡單瞭解一下compile函式。

compile(source, filename, mode[, flags[, dont_inherit]]) 

  • source -- 字串或者AST(Abstract Syntax Trees)物件。一般可將整個py檔案內容file.read()傳入。
  • filename -- 程式碼檔名稱,如果不是從檔案讀取程式碼則傳遞一些可辨認的值。
  • mode -- 指定編譯程式碼的種類。可以指定為 exec, eval, single。
  • flags -- 變數作用域,區域性名稱空間,如果被提供,可以是任何對映物件。
  • flags和dont_inherit是用來控制編譯原始碼時的標誌。
func_def = \
"""
def add(x, y):
    return x + y
print add(3, 5)
"""

使用Compile編譯並執行:

>>> cm = compile(func_def, '<string>', 'exec')
>>> exec cm
>>> 8

上面func_def經過compile編譯得到位元組碼,cm即code物件,True == isinstance(cm, types.CodeType)。

compile(source, filename, mode, ast.PyCF_ONLY_AST)  <==> ast.parse(source, filename='<unknown>', mode='exec')

2.2 生成ast

使用上面的func_def生成ast.

r_node = ast.parse(func_def)
print astunparse.dump(r_node)    # print ast.dump(r_node)

 下面是func_def對應的ast結構:

Module(body=[
    FunctionDef(
        name='add',
        args=arguments(
            args=[Name(id='x',ctx=Param()),Name(id='y',ctx=Param())],
            vararg=None,
            kwarg=None,
            defaults=[]),
        body=[Return(value=BinOp(
            left=Name(id='x',ctx=Load()),
            op=Add(),
            right=Name(id='y',ctx=Load())))],
        decorator_list=[]),
    Print(
        dest=None,
        values=[Call(
                func=Name(id='add',ctx=Load()),
                args=[Num(n=3),Num(n=5)],
                keywords=[],
                starargs=None,
                kwargs=None)],
        nl=True)
  ])

 除了ast.dump,有很多dump ast的第三方庫,如astunparse, codegen, unparse等。這些第三方庫不僅能夠以更好的方式展示出ast結構,還能夠將ast反向匯出python source程式碼。

module Python version "$Revision$"
{
  mod = Module(stmt* body)| Expression(expr body)

  stmt = FunctionDef(identifier name, arguments args, stmt* body, expr* decorator_list)
        | ClassDef(identifier name, expr* bases, stmt* body, expr* decorator_list)
        | Return(expr? value)
        | Print(expr? dest, expr* values, bool nl)| For(expr target, expr iter, stmt* body, stmt* orelse)

  expr = BoolOp(boolop op, expr* values)
       | BinOp(expr left, operator op, expr right)| Lambda(arguments args, expr body)| Dict(expr* keys, expr* values)| Num(object n) -- a number as a PyObject.
       | Str(string s) -- need to specify raw, unicode, etc?| Name(identifier id, expr_context ctx)
       | List(expr* elts, expr_context ctx) 
        -- col_offset is the byte offset in the utf8 string the parser uses
        attributes (int lineno, int col_offset)

  expr_context = Load | Store | Del | AugLoad | AugStore | Param
  boolop = And | Or 
  operator = Add | Sub | Mult | Div | Mod | Pow | LShift | RShift | BitOr | BitXor | BitAnd | FloorDiv
  arguments = (expr* args, identifier? vararg, identifier? kwarg, expr* defaults)
}
View Code

  上面是部分摘自官網的 Abstract Grammar,實際遍歷ast Node過程中根據Node的型別訪問其屬性。

 3. 遍歷AST

python提供了兩種方式來遍歷整個抽象語法樹。

3.1 ast.NodeTransfer

將func_def中的add函式中的加法運算改為減法,同時為函式實現新增呼叫日誌。

 1 class CodeVisitor(ast.NodeVisitor):
 2     def visit_BinOp(self, node):
 3         if isinstance(node.op, ast.Add):
 4             node.op = ast.Sub()
 5         self.generic_visit(node)
 6 
 7     def visit_FunctionDef(self, node):
 8         print 'Function Name:%s'% node.name
 9         self.generic_visit(node)
10         func_log_stmt = ast.Print(
11             dest = None,
12             values = [ast.Str(s = 'calling func: %s' % node.name, lineno = 0, col_offset = 0)],
13             nl = True,
14             lineno = 0,
15             col_offset = 0,
16         )
17         node.body.insert(0, func_log_stmt)
18 
19 r_node = ast.parse(func_def)
20 visitor = CodeVisitor()
21 visitor.visit(r_node)
22 # print astunparse.dump(r_node)
23 print astunparse.unparse(r_node)
24 exec compile(r_node, '<string>', 'exec')

 執行結果:

Function Name:add
def add(x, y):
    print 'calling func: add'
    return (x - y)
print add(3, 5)
calling func: add
-2

 3.2 ast.NodeTransfer

使用NodeVisitor主要是通過修改語法樹上節點的方式改變AST結構,NodeTransformer主要是替換ast中的節點。

既然func_def中定義的add已經被改成一個減函數了,那麼我們就徹底一點,把函式名和引數以及被呼叫的函式都在ast中改掉,並且將新增的函式呼叫log寫的更加複雜一些,爭取改的面目全非:-)

 1 class CodeTransformer(ast.NodeTransformer):
 2     def visit_BinOp(self, node):
 3         if isinstance(node.op, ast.Add):
 4             node.op = ast.Sub()
 5         self.generic_visit(node)
 6         return node
 7 
 8     def visit_FunctionDef(self, node):
 9         self.generic_visit(node)
10         if node.name == 'add':
11             node.name = 'sub'
12         args_num = len(node.args.args)
13         args = tuple([arg.id for arg in node.args.args])
14         func_log_stmt = ''.join(["print 'calling func: %s', " % node.name, "'args:'", ", %s" * args_num % args])
15         node.body.insert(0, ast.parse(func_log_stmt))
16         return node
17 
18     def visit_Name(self, node):
19         replace = {'add': 'sub', 'x': 'a', 'y': 'b'}
20         re_id = replace.get(node.id, None)
21         node.id = re_id or node.id
22 self.generic_visit(node) 23 return node 24 25 r_node = ast.parse(func_def) 26 transformer = CodeTransformer() 27 r_node = transformer.visit(r_node) 28 # print astunparse.dump(r_node) 29 source = astunparse.unparse(r_node) 30 print source 31 # exec compile(r_node, '<string>', 'exec') # 新加入的node func_log_stmt 缺少lineno和col_offset屬性 32 exec compile(source, '<string>', 'exec') 33 exec compile(ast.parse(source), '<string>', 'exec')

結果:

def sub(a, b):
    print 'calling func: sub', 'args:', a, b
    return (a - b)
print sub(3, 5)
calling func: sub args: 3 5
-2
calling func: sub args: 3 5
-2

 程式碼中能夠清楚的看到兩者的區別。這裡不再贅述。

 4.AST應用

AST模組實際程式設計中很少用到,但是作為一種原始碼輔助檢查手段是非常有意義的;語法檢查,除錯錯誤,特殊欄位檢測等。

上面通過為函式新增呼叫日誌的資訊是一種除錯python原始碼的一種方式,不過實際中我們是通過parse整個python檔案的方式遍歷修改原始碼。

4.1 漢字檢測

下面是中日韓字元的unicode編碼範圍

CJK Unified Ideographs 

  • Range: 4E00— 9FFF
  • Number of characters: 20992
  • Languages: chinese, japanese, korean, vietnamese

使用 unicode 範圍 \u4e00 - \u9fff 來判別漢字,注意這個範圍並不包含中文字元(e.g. u';' == u'\uff1b') .

下面是一個判斷字串中是否包含中文字元的一個類CNCheckHelper:

 1 class CNCheckHelper(object):
 2     # 待檢測文字可能的編碼方式列表
 3     VALID_ENCODING = ('utf-8', 'gbk')
 4 
 5     def _get_unicode_imp(self, value, idx = 0):
 6         if idx < len(self.VALID_ENCODING):
 7             try:
 8                 return value.decode(self.VALID_ENCODING[idx])
 9             except:
10                 return self._get_unicode_imp(value, idx + 1)
11 
12     def _get_unicode(self, from_str):
13         if isinstance(from_str, unicode):
14             return None
15         return self._get_unicode_imp(from_str)
16 
17     def is_any_chinese(self, check_str, is_strict = True):
18         unicode_str = self._get_unicode(check_str)
19         if unicode_str:
20             c_func = any if is_strict else all
21             return c_func(u'\u4e00' <= char <= u'\u9fff' for char in unicode_str)
22         return False

 介面is_any_chinese有兩種判斷模式,嚴格檢測只要包含中文字串就可以檢查出,非嚴格必須全部包含中文。

下面我們利用ast來遍歷原始檔的抽象語法樹,並檢測其中字串是否包含中文字元。

 1 class CodeCheck(ast.NodeVisitor):
 2 
 3     def __init__(self):
 4         self.cn_checker = CNCheckHelper()
 5 
 6     def visit_Str(self, node):
 7         self.generic_visit(node)
 8         # if node.s and any(u'\u4e00' <= char <= u'\u9fff' for char in node.s.decode('utf-8')):
 9         if self.cn_checker.is_any_chinese(node.s, True):
10             print 'line no: %d, column offset: %d, CN_Str: %s' % (node.lineno, node.col_offset, node.s)
11 
12 project_dir = './your_project/script'
13 for root, dirs, files in os.walk(project_dir):
14     print root, dirs, files
15     py_files = filter(lambda file: file.endswith('.py'), files)
16     checker = CodeCheck()
17     for file in py_files:
18         file_path = os.path.join(root, file)
19         print 'Checking: %s' % file_path
20         with open(file_path, 'r') as f:
21             root_node = ast.parse(f.read())
22             checker.visit(root_node)

 上面這個例子比較的簡單,但大概就是這個意思。

關於CPython直譯器執行原始碼的過程可以參考官網描述:PEP 339

4.2 Closure 檢查

一個函式中定義的函式或者lambda中引用了父函式中的local variable,並且當做返回值返回。特定場景下閉包是非常有用的,但是也很容易被誤用。

關於python閉包的概念可以參考我的另一篇文章:理解Python閉包概念

這裡簡單介紹一下如何藉助ast來檢測lambda中閉包的引用。程式碼如下:

 1 class LambdaCheck(ast.NodeVisitor):
 2 
 3     def __init__(self):
 4         self.illegal_args_list = []
 5         self._cur_file = None
 6         self._cur_lambda_args = []
 7 
 8     def set_cur_file(self, cur_file):
 9         assert os.path.isfile(cur_file), cur_file
10         self._cur_file = os.path.realpath(cur_file)
11 
12     def visit_Lambda(self, node):
13         """
14         lambda 閉包檢查原則:
15         只需檢測lambda expr body中args是否引用了lambda args list之外的引數
16         """
17         self._cur_lambda_args =[a.id for a in node.args.args]
18         print astunparse.unparse(node)
19         # print astunparse.dump(node)
20         self.get_lambda_body_args(node.body)
21         self.generic_visit(node)
22 
23     def record_args(self, name_node):
24         if isinstance(name_node, ast.Name) and name_node.id not in self._cur_lambda_args:
25             self.illegal_args_list.append((self._cur_file, 'line no:%s' % name_node.lineno, 'var:%s' % name_node.id))
26 
27     def _is_args(self, node):
28         if isinstance(node, ast.Name):
29             self.record_args(node)
30             return True
31         if isinstance(node, ast.Call):
32             map(self.record_args, node.args)
33             return True
34         return False
35 
36     def get_lambda_body_args(self, node):
37         if self._is_args(node): return
38         # for cnode in ast.walk(node):
39         for cnode in ast.iter_child_nodes(node):
40             if not self._is_args(cnode):
41                 self.get_lambda_body_args(cnode)

 遍歷工程檔案:

 1 project_dir = './your project/script'
 2 for root, dirs, files in os.walk(project_dir):
 3     py_files = filter(lambda file: file.endswith('.py'), files)
 4     checker = LambdaCheck()
 5     for file in py_files:
 6         file_path = os.path.join(root, file)
 7         checker.set_cur_file(file_path)
 8         with open(file_path, 'r') as f:
 9             root_node = ast.parse(f.read())
10             checker.visit(root_node)
11     res = '\n'.join([' ## '.join(info) for info in checker.illegal_args_list])
12     print res
View Code

 由於Lambda(arguments args, expr body)中的body expression可能非常複雜,上面的例子中僅僅處理了比較簡單的body expr。可根據自己工程特點修改和擴充套件檢查規則。為了更加一般化可以單獨寫一個visitor類來遍歷lambda節點。

Ast的應用不僅限於上面的例子,限於篇幅,先介紹到這裡。期待ast能幫助你解決一些比較棘手的問題。