1. 程式人生 > >【Python】我是如何使計算時間提速25.6倍的

【Python】我是如何使計算時間提速25.6倍的

# 我是如何使計算時間提速25.6倍的 > 我的原始文件: 在顯著性目標檢測任務中有個重要的評價指標, E-measure, 需要使用在閉區間 `[0, 255]` 內連續變化的閾值對模型預測的灰度圖二值化. 直接的書寫方式就是使用 `for` 迴圈, 將對應的閾值送入指標得分計算函式中, 讓其計算分割後的預測結果和真值mask之間的統計相似度. 在顯著性目標檢測中, 另一個指標, F-measure, 同樣涉及到連續變化的閾值二值化處理, 但是該指標計算僅需要precision和recall, 這兩項實際上僅需要正陽性(TP)和假陽性(FP)元素數量, 以及總的正(T)樣本元素數量. T可以使用 `np.count_nonzero(gt)` 來計算, 而前兩項則可以直接利用累計直方圖的策略一次性得到所有的256個TP、FP數量對, 分別對應不同的閾值. 這樣就可以非常方便且快速的計算出來這一系列的指標結果. 這實際上是對於F-measure計算的一種非常有效的加速策略. 但是不同的是, E-measure的計算方式(需要減去對應二值圖的均值後進行計算)導致按照上面的這種針對變化閾值加速計算的策略並不容易變通, 至少我目前沒有這樣使用. 但是最後我找到了一種更加(相較於原始的 `for` 策略)高效的計算方式, 這裡簡單做一下思考和實驗重現的記錄. ## 選擇使用更合適的函式 雖然運算主要基於 `numpy` 的各種函式, 但是針對同一個目的不同的函式實現方式也是有明顯的速度差異的, 這裡簡單彙總下: ### 統計非零元素數量首選 `np.count_nonzero(array)` 我想到的針對二值圖的幾種不同的實現: ``` python import time import numpy as np # 快速統計numpy陣列的非零值建議使用np.count_nonzero,一個簡單的小實驗 def cal_nonzero(size): a = np.random.randn(size, size) a = a >
0 start = time.time() print(np.count_nonzero(a), time.time() - start) start = time.time() print(np.sum(a), time.time() - start) start = time.time() print(len(np.nonzero(a)[0]), time.time() - start) start = time.time() print(len(np.where(a)), time.time() - start) if __name__ == '__main__': cal_nonzero(1000) # 499950 6.723403930664062e-05 # 499950 0.0006949901580810547 # 499950 0.007088184356689453 ``` 可以看到, 最合適的是 `np.count_nonzero(array)` 了. ### 更快的交集計算方式 ``` python import time import numpy as np # 快速統計numpy陣列的非零值建議使用np.count_nonzero,一個簡單的小實驗 def cal_andnot(size): a = np.random.randn(size, size) b = np.random.randn(size, size) a = a >
0 b = b < 0 start = time.time() a_and_b_mul = a * b _a_and__b_mul = (1 - a) * (1 - b) print(time.time() - start) start = time.time() a_and_b_and = a & b _a_and__b_and = ~a & ~b print(time.time() - start) if __name__ == '__main__': cal_andnot(1000) # 0.0036919116973876953 # 0.0005502700805664062 ``` 可見, 對於bool陣列, numpy的位運算是要更快更有效的. 而且bool陣列可以直接用來索引矩陣即 `array[bool_array]` , 非常方便. ## 邏輯的改進 經過儘可能的挑選更加快速的計算函式之後, 目前速度受限的最大問題就是這個 `for` 迴圈中的256次矩陣運算了. 也就是這部分程式碼: ``` python ... def step(self, pred: np.ndarray, gt: np.ndarray): pred, gt = _prepare_data(pred=pred, gt=gt) self.all_fg = np.all(gt) self.all_bg = np.all(~gt) self.gt_size = gt.shape[0] * gt.shape[1] if self.changeable_ems is not None: changeable_ems = self.cal_changeable_em(pred, gt) self.changeable_ems.append(changeable_ems) adaptive_em = self.cal_adaptive_em(pred, gt) self.adaptive_ems.append(adaptive_em) def cal_adaptive_em(self, pred: np.ndarray, gt: np.ndarray) -> float: adaptive_threshold = _get_adaptive_threshold(pred, max_value=1) adaptive_em = self.cal_em_with_threshold(pred, gt, threshold=adaptive_threshold) return adaptive_em def cal_changeable_em(self, pred: np.ndarray, gt: np.ndarray) -> list: changeable_ems = [self.cal_em_with_threshold(pred, gt, threshold=th) for th in np.linspace(0, 1, 256)] return changeable_ems def cal_em_with_threshold(self, pred: np.ndarray, gt: np.ndarray, threshold: float) -> float: binarized_pred = pred >= threshold if self.all_bg: enhanced_matrix = 1 - binarized_pred elif self.all_fg: enhanced_matrix = binarized_pred else: enhanced_matrix = self.cal_enhanced_matrix(binarized_pred, gt) em = enhanced_matrix.sum() / (gt.shape[0] * gt.shape[1] - 1 + _EPS) return em def cal_enhanced_matrix(self, pred: np.ndarray, gt: np.ndarray) -> np.ndarray: demeaned_pred = pred - pred.mean() demeaned_gt = gt - gt.mean() align_matrix = 2 * (demeaned_gt * demeaned_pred) / (demeaned_gt ** 2 + demeaned_pred ** 2 + _EPS) enhanced_matrix = (align_matrix + 1) ** 2 / 4 return enhanced_matrix ... ``` 可以看到, 這裡對於每一個閾值都要計算一遍同樣的流程, 如果每次的計算都比較耗時的話, 那麼總體時間也就很難減下來. 所以需要探究如何降低這裡的 `cal_enhanced_matrix` 的耗時. 前面的嘗試都是在程式碼函式選擇層面的改進, 但是對於這裡, 這樣的思路已經很難產生明顯的效果了. 那麼我們就應該轉變思路了, 應該從計算流程本身上思考. 可以按照下面這一系列思考來引出最終的一種比較好的策略. * 這裡計算為什麼會那麼慢? + 因為涉及到了大量的矩陣元素級的運算, 例如元素級減法、加法、乘法、平方、除法. * 大量的元素級運算是否可以優化? + 必須可以:< * 如何優化元素級運算? + 尋找規律性、重複性的計算, 將其合併、消減, 可以聯想numpy的稀疏矩陣的思想. * 規律性、重複性的計算在哪裡? + 去均值實際上是對每個元素減去了相同的一個值, 如果被減數可以優化, 那麼這一步就可以被優化 + 元素乘法和平方涉及到兩部分, `demeaned_gt`和`demeaned_pred`, 如果這兩個可以被優化, 那麼這些運算就都可以被優化 + 這些元素運算的連鎖關係導致了只要我們優化了最初的`pred`和`gt`, 那麼整個流程就都可以被優化 * 如何優化`pred`和`gt`的表示? + 這裡需要從二者本身的屬性上入手 * 二者最大的特點是什麼? + 都是二值陣列, 只有0和1 * 那如何優化? + 實際上就借鑑了稀疏矩陣的思想, 既然存在大量的重複性, 那麼我們就將數值與位置解耦, 優化表示方式 * 如何解耦? + 以`gt`為例, 可以表示為0和1兩種資料, 其中0對應背景, 1對應前景, 0的數量表示背景面積, 1的數量表示前景面積 * 那如何使用該思想重構前面的計算呢? 到最後一個問題, 實際上核心策略已經出現, 就是"解耦", 將數值與位置解耦. 這裡需要具體分析下, 我們直接將 `pred` 和 `gt` 拆分成數值和數量, 是可以比較好的處理 `demeaned_*` 項的表示的, 也就是: ``` python # demeaned_pred = pred - pred.mean() # demeaned_gt = gt - gt.mean() pred_fg_numel = np.count_nonzero(binarized_pred) pred_bg_numel = self.gt_size - pred_fg_numel gt_fg_numel = np.count_nonzero(gt) gt_bg_numel = self.gt_size - gt_fg_numel mean_pred_value = pred_fg_numel / self.gt_size mean_gt_value = gt_fg_numel / self.gt_size demeaned_pred_fg_value = 1 - mean_pred_value demeaned_pred_bg_value = 0 - mean_pred_value demeaned_gt_fg_value = 1 - mean_gt_value demeaned_gt_bg_value = 0 - mean_gt_value ``` 接下來需要進一步優化後面的乘法和加法了, 因為這裡同時涉及到了同一位置的 `pred` 和 `gt` 的值, 這就需要注意了, 因為二者前景與背景對應關係並不明確, 這就得分情況考慮了. 總體而言, 包含四種情況, 就是: 1. pred: fg; gt: fg 1. pred: fg; gt: bg 1. pred: bg; gt: fg 1. pred: bg; gt: bg 而這些區域實際上是對前面初步解耦區域的進一步細化, 所以我們重新整理思路, 可以將整個流程構造如下: ``` python fg_fg_numel = np.count_nonzero(binarized_pred & gt) fg_bg_numel = np.count_nonzero(binarized_pred & ~gt) # bg_fg_numel = np.count_nonzero(~binarized_pred & gt) bg_fg_numel = self.gt_fg_numel - fg_fg_numel # bg_bg_numel = np.count_nonzero(~binarized_pred & ~gt) bg_bg_numel = self.gt_size - (fg_fg_numel + fg_bg_numel + bg_fg_numel) parts_numel = [fg_fg_numel, fg_bg_numel, bg_fg_numel, bg_bg_numel] mean_pred_value = (fg_fg_numel + fg_bg_numel) / self.gt_size mean_gt_value = self.gt_fg_numel / self.gt_size demeaned_pred_fg_value = 1 - mean_pred_value demeaned_pred_bg_value = 0 - mean_pred_value demeaned_gt_fg_value = 1 - mean_gt_value demeaned_gt_bg_value = 0 - mean_gt_value combinations = [(demeaned_pred_fg_value, demeaned_gt_fg_value), (demeaned_pred_fg_value, demeaned_gt_bg_value), (demeaned_pred_bg_value, demeaned_gt_fg_value), (demeaned_pred_bg_value, demeaned_gt_bg_value)] ``` 這裡忽略掉了一些不必要的計算, 能直接使用現有量就使用現有的量. 針對前面的這些解耦, 後面就可以比較簡單的書寫了: ``` python results_parts = [] for part_numel, combination in zip(parts_numel, combinations): # align_matrix = 2 * (demeaned_gt * demeaned_pred) / (demeaned_gt ** 2 + demeaned_pred ** 2 + _EPS) align_matrix_value = 2 * (combination[0] * combination[1]) / \ (combination[0] ** 2 + combination[1] ** 2 + _EPS) # enhanced_matrix = (align_matrix + 1) ** 2 / 4 enhanced_matrix_value = (align_matrix_value + 1) ** 2 / 4 results_parts.append(enhanced_matrix_value * part_numel) # enhanced_matrix = enhanced_matrix.sum() enhanced_matrix = sum(results_parts) ``` 由於不同區域元素結果一致, 而區域的面積也已知, 所以最終 `cal_em_with_threshold` 中的 `enhanced_matrix.sum()` 其實更適合放到 `cal_enhanced_matrix` 中, 可以一便計算出來. 為了儘可能重用現有變數, 我們其實反過來可以優化 `cal_em_with_threshold` : ``` python binarized_pred = pred >
= threshold if self.all_bg: enhanced_matrix = 1 - binarized_pred elif self.all_fg: enhanced_matrix = binarized_pred else: enhanced_matrix = self.cal_enhanced_matrix(binarized_pred, gt) em = enhanced_matrix.sum() / (gt.shape[0] * gt.shape[1] - 1 + _EPS) ``` 這裡的 `self.all_bg` 和 `self.all_fg` 實際上可以使用 `self.gt_fg_numel` 和 `self.gt_size` 表示, 也就是隻需計算一次 `np.count_nonzero(array)` 就可以了. 另外在 `cal_em_with_threshold` 中 `if` 的前兩個分支中, 需要將 `sum` 整合到各個分支內部(else分支已經被整合到了 `cal_enhanced_matrix` 方法中), `(1-binarized_pred).sum()` 和 `binarized_pred.sum()` 實際上就是表示背景畫素數量和前景畫素數量. 所以可以藉助於更快的 `np.count_nonzero(array)` , 從而改成如下形式: ``` python binarized_pred = pred >= threshold if self.gt_fg_numel == 0: binarized_pred_bg_numel = np.count_nonzero(~binarized_pred) enhanced_matrix_sum = binarized_pred_bg_numel elif self.gt_fg_numel == self.gt_size: binarized_pred_fg_numel = np.count_nonzero(binarized_pred) enhanced_matrix_sum = binarized_pred_fg_numel else: enhanced_matrix_sum = self.cal_enhanced_matrix(binarized_pred, gt) em = enhanced_matrix_sum / (self.gt_size - 1 + _EPS) ``` ## 效率對比 使用本地的845張灰度預測圖和二值mask真值資料進行測試比較, 總體時間對比如下: * 'base': 503.5014679431915s * 'best': 19.27734637260437s 雖然具體時間可能還受硬體限制, 但是相對快慢還是比較明顯的. 變為原來的19/504~=4%, 快了504/19~=26.5倍. 測試程式碼可見我的 `github` :