Issue
Suppose I have the following two torch.Tensors:
x = torch.tensor([0,0,0,1,1,2,2,2,2], dtype=torch.int64)
y = torch.tensor([0,2], dtype=torch.int64)
I want to somehow filter x such that only the values that are in y remain:
x_filtered = torch.tensor([0,0,0,2,2,2,2])
For another example, if y = torch.tensor([0,1]), then x_filtered = torch.tensor([0,0,0,1,1]). Both x and y are always 1D and int64. y is always sorted and, if it makes things simpler, we can assume that x is always sorted as well.
I tried to think of various ways to do this without using loops, but failed. I cannot really use loops because in my use case x has millions of elements and y has tens of thousands. Any help is appreciated.
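Because y is guaranteed to be sorted, one loop-free alternative (a swapped-in technique, not part of the original answer) is a binary search with torch.searchsorted: for each element of x, find where it would be inserted into y, then check whether y already holds that value at that position. The helper name filter_sorted is hypothetical; this is a sketch assuming y is 1D, sorted ascending, and has the same dtype as x.

```python
import torch

def filter_sorted(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: keep the elements of x that occur in y.

    Assumes y is 1D and sorted ascending (x does not need to be sorted).
    """
    if y.numel() == 0:
        # Nothing can match an empty y; return an empty tensor of x's dtype.
        return x[:0]
    # For each element of x, the leftmost index in y where it could be inserted.
    idx = torch.searchsorted(y, x)
    # Values larger than every element of y get index len(y); clamp so indexing is safe.
    idx = idx.clamp(max=y.numel() - 1)
    # An element of x is present in y exactly when y holds that same value at idx.
    mask = y[idx] == x
    return x[mask]

x = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2], dtype=torch.int64)
print(filter_sorted(x, torch.tensor([0, 2])))  # tensor([0, 0, 0, 2, 2, 2, 2])
print(filter_sorted(x, torch.tensor([0, 1])))  # tensor([0, 0, 0, 1, 1])
```

This does about len(x) · log(len(y)) comparisons and only allocates tensors the size of x, so it avoids building a len(x) × len(y) comparison matrix.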
I just realised that what I need is the torch equivalent of numpy.in1d.
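For reference, a minimal sketch comparing the two libraries on the first example, assuming both NumPy and PyTorch are installed. np.in1d is NumPy's older 1D-only function, and recent NumPy versions recommend np.isin over it; torch.isin produces the same boolean mask working directly on tensors.

```python
import numpy as np
import torch

x = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2], dtype=torch.int64)
y = torch.tensor([0, 2], dtype=torch.int64)

# NumPy: boolean mask, True where an element of x occurs in y.
np_mask = np.isin(x.numpy(), y.numpy())

# PyTorch: the same mask, computed on the tensors without converting to NumPy.
torch_mask = torch.isin(x, y)

print(torch_mask.tolist())  # [True, True, True, False, False, True, True, True, True]
print(x[torch_mask])        # tensor([0, 0, 0, 2, 2, 2, 2])
```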
Solution
To filter the tensor as your task requires, use the torch.isin function, which returns a boolean mask that you can index x with. It is used as follows:
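import torch
# x here also contains a 3, which is not in y and should be dropped
x = torch.tensor([0,0,0,1,1,2,2,2,2,3], dtype=torch.int64)
y = torch.tensor([0,2], dtype=torch.int64)
# torch.isin(x, y) is True wherever an element of x occurs in y
c = x[torch.isin(x, y)]
print(c)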
Running this code prints tensor([0, 0, 0, 2, 2, 2, 2]): the 1s and the 3 are removed because they do not occur in y.
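A short sketch checking the same approach against the question's second example (y = torch.tensor([0,1])), and showing the invert keyword of torch.isin, which keeps the complement (the values of x not found in y) instead.

```python
import torch

x = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2, 2], dtype=torch.int64)
y = torch.tensor([0, 1], dtype=torch.int64)

# Keep only the values of x that occur in y.
kept = x[torch.isin(x, y)]
print(kept)     # tensor([0, 0, 0, 1, 1])

# invert=True flips the membership test, keeping the values NOT in y.
dropped = x[torch.isin(x, y, invert=True)]
print(dropped)  # tensor([2, 2, 2, 2])
```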
Answered By - Rishabh Pandit