1. 程式人生 > >Python中numpy的where()函式

Python中numpy的where()函式

第一種用法

np.where(conditions,x,y)

if (conditions成立):

  陣列變x

else:

  陣列變y

In [61]: x=np.random.randn(4,4)

In [62]: x
Out[62]:
array([[ 1.2256504 ,  0.81650419, -0.06063334, -0.37776736],
       [-0.21559056, -0.77642262,  0.48999826,  0.04118671],
       [ 0.22457745,  0.90930544,  1.75082994,  0.95332844],
       [-2.21076019,  0.32498938, -1.51440206, -1.795866  ]])

In [63]: print(np.where(x>0,2,-2))
[[ 2  2 -2 -2]
 [-2 -2  2  2]
 [ 2  2  2  2]
 [-2  2 -2 -2]]

#如果是一維,相當於[xv if c else yv for (c,xv,yv) in zip(condition,x,y)]

In [64]: %paste
xarr = np.array([1.1,1.2,1.3,1.4,1.5])
yarr = np.array([2.1,2.2,2.3,2.4,2.5])
zarr = np.array([True,False,True,True,False])
result = [(x if c else y)
          for x,y,c in zip(xarr,yarr,zarr)]
print(result)

## -- End pasted text --
[1.1, 2.2, 1.3, 1.4, 2.5]

In [66]: %paste
result = np.where(zarr,xarr,yarr)
print(result)

## -- End pasted text --
[1.1 2.2 1.3 1.4 2.5]

第二種用法

where(conditions) 

相當於給出陣列的下標

In [68]: x=np.arange(16)

In [69]: x
Out[69]: array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15])

In [70]: print(x[np.where(x>5)])
[ 6  7  8  9 10 11 12 13 14 15]