Python中numpy的where()函式
阿新 • • 發佈:2019-01-06
第一種用法
np.where(conditions,x,y)
if (conditions成立):
陣列變x
else:
陣列變y
In [61]: x=np.random.randn(4,4) In [62]: x Out[62]: array([[ 1.2256504 , 0.81650419, -0.06063334, -0.37776736], [-0.21559056, -0.77642262, 0.48999826, 0.04118671], [ 0.22457745, 0.90930544, 1.75082994, 0.95332844], [-2.21076019, 0.32498938, -1.51440206, -1.795866 ]]) In [63]: print(np.where(x>0,2,-2)) [[ 2 2 -2 -2] [-2 -2 2 2] [ 2 2 2 2] [-2 2 -2 -2]] #如果是一維,相當於[xv if c else yv for (c,xv,yv) in zip(condition,x,y)] In [64]: %paste xarr = np.array([1.1,1.2,1.3,1.4,1.5]) yarr = np.array([2.1,2.2,2.3,2.4,2.5]) zarr = np.array([True,False,True,True,False]) result = [(x if c else y) for x,y,c in zip(xarr,yarr,zarr)] print(result) ## -- End pasted text -- [1.1, 2.2, 1.3, 1.4, 2.5] In [66]: %paste result = np.where(zarr,xarr,yarr) print(result) ## -- End pasted text -- [1.1 2.2 1.3 1.4 2.5]
第二種用法
where(conditions)
相當於給出陣列的下標
In [68]: x=np.arange(16)
In [69]: x
Out[69]: array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15])
In [70]: print(x[np.where(x>5)])
[ 6 7 8 9 10 11 12 13 14 15]