Issue
Here's my question: I'm trying to filter an image based on the values of two coordinates. I can do this easily with a for loop:
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(-1, 1, 101)
y = np.linspace(-1, 1, 51)
X, Y = np.meshgrid(x, y)
Z = np.exp(X**2 + Y**2)
newZ = np.zeros(Z.shape)
for i, y_i in enumerate(y):
    for j, x_j in enumerate(x):
        if x_j > 0 and y_i > 0:
            newZ[i, j] = Z[i, j]
        else:
            newZ[i, j] = 0
plt.contourf(x, y, newZ)
plt.show()
but I'm pretty sure there should be a way to do it by indexing (as it should be faster) like:
Z = Z[y>0, x>0]
which doesn't work (IndexError: Shape mismatch)
I assume I could do this with a mask using masked array perhaps (it seems they are doing something like that here), but I wonder if there is a simple one-liner in normal numpy that I can't seem to figure out. Thanks
Solution
Edit: you have to use NumPy broadcasting:
m = (y[:, None] > 0) | (x > 0)
newZ = np.where(m, Z, 0)
# OR
m = (y[:, None] > 0) | (x > 0)
newZ = np.zeros(Z.shape)
newZ[m] = Z[m]
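Note that the loop in the question keeps a value only when both coordinates are positive (`x_j > 0 and y_i > 0`), whereas the mask above combines the two conditions with `|` and so keeps a value when either one is positive. If the goal is to reproduce the loop exactly, a minimal sketch would be to swap `|` for `&`. The names `m_and` and `loopZ` are only illustrative and are not part of the original answer:

```python
import numpy as np

x = np.linspace(-1, 1, 101)
y = np.linspace(-1, 1, 51)
X, Y = np.meshgrid(x, y)
Z = np.exp(X**2 + Y**2)

# y[:, None] has shape (51, 1) and x has shape (101,), so the two
# comparisons broadcast to a (51, 101) boolean mask, the same shape as Z.
m_and = (y[:, None] > 0) & (x > 0)
newZ = np.where(m_and, Z, 0)

# Reference result from the double loop in the question.
loopZ = np.zeros(Z.shape)
for i, y_i in enumerate(y):
    for j, x_j in enumerate(x):
        if x_j > 0 and y_i > 0:
            loopZ[i, j] = Z[i, j]

# Compare the vectorised result with the loop result.
print(np.array_equal(newZ, loopZ))
```

The parentheses around each comparison matter, since `&` and `|` bind more tightly than `>` in Python.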
Demo:
x = np.linspace(-1, 1, 11)
y = np.linspace(-1, 1, 6)
Z = np.arange(len(y)*len(x)).reshape(len(y), len(x))
m = (y[:, None] > 0) | (x > 0)
newZ = np.zeros(Z.shape, dtype=int)
newZ[m] = Z[m]
Output:
>>> x
array([-1. , -0.8, -0.6, -0.4, -0.2, 0. , 0.2, 0.4, 0.6, 0.8, 1. ])
>>> y
array([-1. , -0.6, -0.2, 0.2, 0.6, 1. ])
>>> Z
array([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
[11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
[22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32],
[33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43],
[44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54],
[55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65]])
>>> newZ
array([[ 0, 0, 0, 0, 0, 0, 6, 7, 8, 9, 10],
[ 0, 0, 0, 0, 0, 0, 17, 18, 19, 20, 21],
[ 0, 0, 0, 0, 0, 0, 28, 29, 30, 31, 32],
[33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43],
[44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54],
[55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65]])
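As for why `Z[y>0, x>0]` fails: as far as I understand, NumPy treats the two boolean arrays as a pair of index arrays to be matched element by element, not as "these rows" crossed with "these columns", so their lengths have to agree. When the region you want is a rectangular block (rows where `y > 0` crossed with columns where `x > 0`), `np.ix_` builds that cross product for you. A small sketch on the demo arrays, with `block` as an illustrative name:

```python
import numpy as np

x = np.linspace(-1, 1, 11)
y = np.linspace(-1, 1, 6)
Z = np.arange(len(y) * len(x)).reshape(len(y), len(x))

# np.ix_ converts the two 1-D boolean arrays into an open mesh of
# integer indices: rows where y > 0 crossed with columns where x > 0.
block = np.ix_(y > 0, x > 0)

newZ = np.zeros_like(Z)
newZ[block] = Z[block]   # copy only the selected block, rest stays 0
```

This only covers the "both conditions" case, since an "either condition" region is not a single rectangular block; for that the broadcast mask is the simpler route.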
With masked array:
>>> np.ma.masked_array(Z, ~m, fill_value=0)
masked_array(
data=[[--, --, --, --, --, --, 6, 7, 8, 9, 10],
[--, --, --, --, --, --, 17, 18, 19, 20, 21],
[--, --, --, --, --, --, 28, 29, 30, 31, 32],
[33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43],
[44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54],
[55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65]],
mask=[[ True, True, True, True, True, True, False, False, False,
False, False],
[ True, True, True, True, True, True, False, False, False,
False, False],
[ True, True, True, True, True, True, False, False, False,
False, False],
[False, False, False, False, False, False, False, False, False,
False, False],
[False, False, False, False, False, False, False, False, False,
False, False],
[False, False, False, False, False, False, False, False, False,
False, False]],
fill_value=0)
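If you go the masked-array route but still need an ordinary array at the end, `filled()` should give you one, using the `fill_value` set above. A short sketch on the same demo data (`mZ` and `plainZ` are illustrative names):

```python
import numpy as np

x = np.linspace(-1, 1, 11)
y = np.linspace(-1, 1, 6)
Z = np.arange(len(y) * len(x)).reshape(len(y), len(x))
m = (y[:, None] > 0) | (x > 0)

# In a masked array, True in the mask means "hidden", hence ~m.
mZ = np.ma.masked_array(Z, ~m, fill_value=0)

# filled() returns a regular ndarray with the masked entries
# replaced by fill_value (0 here).
plainZ = mZ.filled()
```

The result has the same values as `np.where(m, Z, 0)`, so the masked array is mainly worth it when you want to keep track of which entries were hidden.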
IIUC, you are looking for np.where:
newZ = np.where(Z>0, Z, 0)
Same as:
m = Z>0
newZ = np.zeros(Z.shape)
newZ[m] = Z[m]
Alternative with np.clip:
newZ = np.clip(Z, 0, np.inf)
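Keep in mind that these variants filter on the values of `Z` itself, not on the `x` and `y` coordinates, so they only fit if that is what you are after. A quick check that they agree, using made-up values and with `np.maximum` thrown in as one more equivalent spelling:

```python
import numpy as np

# Illustrative data only, chosen to contain negative entries.
Z = np.array([[-2.0, 1.0], [3.0, -4.0]])

a = np.where(Z > 0, Z, 0)      # keep positive values, zero out the rest
b = np.clip(Z, 0, np.inf)      # clamp everything below 0 up to 0
c = np.maximum(Z, 0)           # element-wise maximum with 0

print(np.array_equal(a, b) and np.array_equal(a, c))
```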
Answered By - Corralien