1. 程式人生 > 其它 >numpy高維陣列獲取top-K

numpy高維陣列獲取top-K

技術標籤:Python

文章目錄

前言

理論知識請自行翻閱numpy的argpartition和partition方法的實現原理,該文章僅僅包含使用和效率驗證。此外,numpy版本需要>=1.8.0。

正文

不廢話了,直接放程式碼,一看就懂,看不懂再說,自己跑一下就知道。

import numpy as np

def get_sorted_top_k(array, top_k=1, axis=-1, reverse=False):
    """
    多維陣列排序
    Args:
        array: 多維陣列
        top_k: 取數
        axis: 軸維度
        reverse: 是否倒序

    Returns:
        top_sorted_scores: 值
        top_sorted_indexes: 位置
    """
if reverse: # argpartition分割槽排序,在給定軸上找到最小的值對應的idx,partition同理找對應的值 # kth表示在前的較小值的個數,帶來的問題是排序後的結果兩個分割槽間是仍然是無序的 # kth絕對值越小,分割槽排序效果越明顯 axis_length = array.shape[axis] partition_index = np.take(np.argpartition(array, kth=-top_k, axis=axis), range
(axis_length - top_k, axis_length), axis) else: partition_index = np.take(np.argpartition(array, kth=top_k, axis=axis), range(0, top_k), axis) top_scores = np.take_along_axis(array, partition_index, axis) # 分割槽後重新排序 sorted_index = np.argsort(top_scores, axis=axis) if reverse:
sorted_index = np.flip(sorted_index, axis=axis) top_sorted_scores = np.take_along_axis(top_scores, sorted_index, axis) top_sorted_indexes = np.take_along_axis(partition_index, sorted_index, axis) return top_sorted_scores, top_sorted_indexes if __name__ == "__main__": import time from sklearn.metrics.pairwise import cosine_similarity x = np.random.rand(10, 128) y = np.random.rand(1000000, 128) z = cosine_similarity(x, y) start_time = time.time() sorted_index_1 = get_sorted_top_k(z, top_k=3, axis=1, reverse=True)[1] print(time.time() - start_time) start_time = time.time() sorted_index_2 = np.flip(np.argsort(z, axis=1)[:, -3:], axis=1) print(time.time() - start_time) print((sorted_index_1 == sorted_index_2).all())

後記

不吹比的說一句,這段程式碼看著perfect好吧,效率提升不少。