1. 程式人生 > 其它 >(轉)Python 運算子過載2

(轉)Python 運算子過載2

原文:https://zhuanlan.zhihu.com/p/358748722

前言

運算子過載這個語言特性其實一直備受爭議,鑑於太多 C++ 程式設計師濫用這個特性,Java 之父 James Gosling 很乾脆的決定不為 Java 提供運算子過載功能。但另一方面,正確的使用運算子過載確實能提高程式碼的可讀性和靈活性。為此,Python 施加了一些限制,在靈活性、可用性和安全性之間做到了平衡。主要包括:

  • 不能過載內建型別的運算子
  • 不能新建運算子,只能過載現有的
  • is、and、or 和 not 運算子不能過載(但位運算子 &、\| 和 ~ 可以)

Python 的運算子過載非常方便,只需要重寫對應的特殊方法。在上面一節我們已經介紹瞭如何過載一個向量類的 "+" 和 "==" 運算子,實現還算簡單,接下來我們考慮一個更復雜的情形:不只限於二維向量相加的 Vector 類,以引入 Python 運算子過載更全面的知識點。

改進版的 Vector

考慮到高維向量的應用場景,我們應當支援不同維度向量的相加操作,並且為低維向量的缺失值做預設添 0 處理,這也是一些統計分析應用的常用缺失值處理方式。基於此,首先要確定的便是,Vector 類的建構函式不再只接收固定數量和位置的引數,而應當接收可變引數。

通常情況下,Python 函式接收可變引數有兩種處理方式。一種是接收不定長引數,即*args,這樣我們就可以用類似Vector(1, 2)Vector(1, 2, 3)的方式來初始化不同維數的向量類。在這種情況下,函式會將不定長引數打包成名為args的元組進行處理,當然能滿足迭代的需求。雖然這種方式看上去很直觀,但考慮到向量類從功能上講也是一個序列類,而 Python 中的內建序列型別的構造方法基本上都是接收可迭代物件(Iterable)作為引數,考慮到一致性我們也採取這種形式,並且通過重寫__repr__

輸出更直觀的向量類的數學表示形式。

class Vector:
    def __init__(self, components: Iterable):
        self._components = array('i', components)

    def __repr__(self):
        return str(tuple(self._components))

為了方便之後對向量分量的處理,將其儲存在一個數組中,第一個引數 ‘i’ 標明這是一個整型陣列。這樣做還有一個好處就是,保證了向量序列的不可變性,這一點同 Python 內建型別不可變列表 tuple 類似。如此定義後,我們可以這樣例項化 Vector 類:

>>> from vector import Vector
>>> Vector([1, 2])
(1, 2)
>>> Vector((1, 2, 3))
(1, 2, 3)
>>> Vector(range(4))
(0, 1, 2, 3)

由於 Vector 類接收 Iterable 物件作為構造引數,而任何實現了__iter__方法的類都會被繫結為 Iterable 的子類,所以可以傳入 list、tuple 和 range 等可迭代物件。

接下來,過載 Vector 類的加號運算子,為了滿足之前所說的低維向量預設添 0 處理,我們引入迭代工具包下的zip_longest方法,它可以接收多個可迭代物件,將其打包成一個個的元組,如zip_longest(p, q, ...) --> (p[0], q[0]), (p[1], q[1]), ...。同時關鍵字引數 fillvalue 可以指定填充的預設值。但在這之前,由於zip_longest引數必須是可迭代物件,我們還需要為 Vector 類實現__iter__方法。

class Vector:
    def __iter__(self):
        return iter(self._components)

    def __add__(self, other):
        pairs = itertools.zip_longest(self, other, fillvalue=0)
        return Vector(a + b for a, b in pairs)

__add__的實現邏輯很簡單,按位相加返回一個新的 Vector 物件,在構造 Vector 物件時使用到了生成器表示式,而生成器 Generator 是 Iterable 的子類,所以也符合構造引數的要求。

為了驗證效果,還需要過載==運算子,考慮到兩個向量維度可能不同,首先要對維度,也就是向量分量的個數進行比較,為此需要重寫__len__方法。其次是進行按位比較,內建的 zip 函式可以將兩個迭代物件打包從而同時進行遍歷。

class Vector:
    def __len__(self):
        return len(self._components)

    def __eq__(self, other):
        return len(self) == len(other) and all(a == b for a, b in zip(self, other))

最佳實踐:用 zip 函式同時遍歷兩個迭代器。《Effective Python》的第 11 條提到了這一點。在 Python 中經常會遇到需要平行地迭代兩個序列的情況。一般的做法是,寫一個 for 迴圈對一個序列進行迭代,然後想辦法獲得其索引,通過索引訪問第二個序列的對應元素。常見的做法是藉助 enumerate 函式,通過for index, item in enumerate(items)的方式獲取索引。現在有一種更優雅的寫法,使用內建的 zip 函式,它可以將兩個及以上的迭代器封裝成生成器,這個生成器能在每次迭代時從每個迭代器中取出下一個值構成元組,再結合元組拆包就能達到平行取值的目的,如上述程式碼中的for a, b in zip(self, other)。顯然,這種方式可讀性更高。但如果待遍歷序列不等長,zip 函式會提前終止,這可能導致意外的結果。所以在不確定序列是否等長的條件下,可以考慮使用 itertools 模組中過的zip_longest函式。

至此,過載的 "+" 和 "==" 運算子初步完成了,可以編寫測試用例進行驗證了,作為本系列第一個比較全面的測試類,我將在文末貼出完整的測試程式碼,這裡先在控制檯演示過載之後的效果。

>>> v1 = Vector([1, 2])
>>> v1 == (1, 2)
True
>>> v1 + Vector((1, 1))
(2, 3)
>>> v1 + [1, 1]
(2, 3)
>>> v1 + (1, 1, 1)
(2, 3, 1)

由於__add__方法中的 other 只要求是可迭代物件而沒有型別限制,所以過載的加號運算子不止可以對兩個 Vector 例項進行相加,也支援 Vector 例項與一個可迭代物件相加,不管是 list、tuple 還是其他 Iterable 型別。但需要注意的是,可迭代物件必須作為第二個運算元,也就是 "+" 右側的運算元。理解這一點並不難,因為我們只實現了 Vector 的__add__方法,而 Python 的內建型別類可不明白怎麼對加上一個向量進行處理,比如下面報錯提示的 tuple 類。

>>> (1, 1) + v1
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: can only concatenate tuple (not "Vector") to tuple

反向運算子

那麼有什麼方法,不需要重寫 tuple 類中的__add__方法(顯然這種方式也不合理),也能使過載的加號運算子支援(1, 1) + v1呢?答案是有的,在此之前,不得不提到 Python 的運算子分派機制。

對於中綴運算子,Python 提供了特殊的分派機制。對於表示式a + b,直譯器會執行以下幾步操作:

  1. 如果 a 有__add__方法且不返回 NotImplemented,呼叫a.__add__(b)
  2. 如果 a 沒有__add__方法或呼叫返回 NotImplemented,檢查 b 有沒有__radd__方法,如果有且不返回 NotImplemented,呼叫b.__radd__(a)
  3. 如果 b 沒有__radd__方法或呼叫返回 NotImplemented,丟擲 TypeError。

注:NotImplemented 是 Python 內建的特殊單例值,如果運算子特殊方法不能處理給定的運算元,那麼要把它返回給直譯器。

如果將__add__稱為正向方法,那麼__radd__就可以稱為__add__方法的反向方法,或者右向方法,這個方法的作用是支援運算元從右至左進行計算。因此,為了支援(1, 1) + v1,我們需要定義 Vector 類的反向方法。而反向方法只需要委託給已經定義好的__add__方法。

class Vector:
    def __add__(self, other):
        try:
            pairs = itertools.zip_longest(self, other, fillvalue=0)
            return Vector(a + b for a, b in pairs)
        except TypeError:
            return NotImplemented

    def __radd__(self, other):
        return self + other

__radd__通常就是這麼簡單,由於直譯器呼叫的是b.__radd__(a),而這裡的 b 即 v1 是一個 Vector 例項,能夠與一個元組相加,所以這時(1, 1) + v1不會再報錯。同時,還對__add__方法做了些修改:捕獲 TypeError 異常並返回 NotImplemented。這也是一種過載中綴運算子時的最佳實踐,丟擲異常將導致算符分派機制終止,而丟擲 NotImplemented 則會讓直譯器再嘗試呼叫反向運算子方法。當運算子左右運算元是不同型別時,反向方法也許能夠正常運算。

現在,驗證過載的反向運算子:

>>> v1 = Vector([1, 2])
>>> (1, 1) + v1
(2, 3)
>>> [1, 1, 1] + v1
(2, 3, 1)

比較運算子

對於比較運算子,正向和反向呼叫使用的是同一系列方法,只不過對調了引數。注意是同一系列而不是同一方法。例如,對 "==" 來說,正向呼叫是a.__eq__(b),那麼反向呼叫就是b.__eq__(a);而對 ">" 來說,正向a.__gt__(b)的反向呼叫是b.__lt__(a)

如果正向呼叫左運算元的__eq__方法返回 NotImplemented,Python 直譯器會去嘗試反向呼叫右運算元的__eq__方法,若右運算元也返回 NotImplemented,直譯器不會丟擲 TypeError 異常,而是會比較物件的 ID 作最後一搏。

對元組和 Vector 例項比較的具體步驟如下:

  1. 嘗試呼叫 tuple 的__eq__方法,由於 tuple 不認識 Vector 類,返回 NotImplemented;
  2. 嘗試呼叫 Vector 的__eq__方法,返回 True。
>>> (1, 2) == Vector([1, 2])
True

另外,對於 "!=" 運算子,Python 3 的最佳實踐是隻實現__eq__方法而不實現它,因為從 object 繼承來的__ne__方法會對__eq__返回的結果取反。而 Python 2 則不同,過載 "==" 的同時也應過載 "!=" 運算子。Python 之父 Guido 曾提到這是 Python 2 的一個設計缺陷且已在 Python 3 中修復了。

就地運算子

增量賦值運算子,也稱就地運算子,如 "+=",有兩種運算方式。對於不可變型別來說,a += b的作用與a = a + b完全一致,增量賦值不會修改不可變目標,而是新建例項,然後重新繫結,也就是說運算前後的 a 不是同一物件。對於不可變型別,這是預期的行為。

而對於實現了就地運算子方法,如__iadd__,的可變型別來說,a += b會呼叫該方法就地修改左運算元,而不是建立一個新的物件。這一點,Python 的內建型別,不可變的 tuple 和可變的 list 就可以很好的說明。

>>> t = (1, 2)
>>> id(t)
4359598592
>>> t += (3,)
>>> id(t)
4359584960
>>> l = [1, 2]
>>> id(l)
4360054336
>>> l += [3, 4]
>>> id(l)
4360054336

閱讀原始碼你會發現,list 類 實現了__iadd__方法而 tuple 類沒有實現。對 list 而言,"+=" 就地運算子的邏輯與其extend()方法相同,將一個可迭代物件的元素依次追加到當前列表的末尾。而對 tuple 而言,即使沒有定義__iadd__方法,使用 "+=" 也會委託給__add__方法進行運算返回一個新的 tuple 物件。

從設計層面考慮,Vector 應當與元組一致,被設計成不可變型別,即每次對向量進行運算後生成一個新的向量。站在函數語言程式設計的角度,這種設計無副作用(不在函式內部修改傳入引數狀態),從而避免一些難以預料的問題。因此對於不可變型別,一定不能實現就地特殊方法。對 Vector 使用 "+=" 運算子會呼叫現有的__add__方法生成一個新的 Vector 例項。v1 += (1, 1)v1 = v1 + (1, 1)行為一致。

>>> v1 = Vector([1, 2])
>>> id(v1)
4360163280
>>> v1 += (1, 1)
>>> v1
(2, 3)
>>> id(v1)
4359691376

附錄:程式碼

vector.py

import itertools
from array import array
from collections.abc import Iterable


class Vector:
    def __init__(self, components: Iterable):
        self._components = array('i', components)

    def __iter__(self):
        return iter(self._components)

    def __len__(self):
        return len(self._components)

    def __repr__(self):
        return str(tuple(self._components))

    def __eq__(self, other):
        return len(self) == len(other) and all(a == b for a, b in zip(self, other))

    def __add__(self, other):
        try:
            pairs = itertools.zip_longest(self, other, fillvalue=0)
            return Vector(a + b for a, b in pairs)
        except TypeError:
            return NotImplemented

    def __radd__(self, other):
        return self + other

vector_test.py

from vector import Vector


class TestVector:
    def test_should_compare_two_vectors_with_override_compare_operators(self):
        v1 = Vector([1, 2])
        v2 = Vector((1, 2))
        v3 = Vector([2, 3])
        v4 = Vector([2, 3, 4])

        assert v1 == v2
        assert v3 != v2
        assert v4 != v3
        assert (1, 2) == v2
        assert v2 == [1, 2]

    def test_should_add_two_same_dimension_vectors_with_override_add_operator(self):
        v1 = Vector([1, 2])
        v2 = Vector((1, 3))
        result = Vector([2, 5])

        assert result == v1 + v2

    def test_should_add_two_different_dimension_vectors_with_override_add_operator(self):
        v1 = Vector([1, 2])
        v2 = Vector((1, 1, 1))
        result = Vector([2, 3, 1])

        assert result == v1 + v2

    def test_should_add_vector_and_iterable_with_override_add_operator(self):
        v1 = Vector([1, 2])

        assert v1 + (1, 1) == (2, 3)
        assert v1 + [1, 1, 1] == (2, 3, 1)

    def test_should_add_iterable_and_vector_with_override_radd_method(self):
        v1 = Vector([1, 2])

        assert (1, 1) + v1 == (2, 3)
        assert [1, 1, 1] + v1 == (2, 3, 1)

    def test_should_create_new_vector_when_use_incremental_add_operator(self):
        v1 = Vector([1, 2])
        id1 = id(v1)
        v1 += (1, 1)

        assert id(v1) != id1

附錄:常見可過載運算子

一元運算子

二元運算子

比較運算子

技術連結