Issue
I have two tensors, tensor a and tensor b.
I want to find every position in tensor a whose value also appears in tensor b.
For example:
a = torch.Tensor([1,2,2,3,4,4,4,5])
b = torch.Tensor([1,2,4])
I want the positions of 1, 2 and 4 in tensor a. I can do this with the following code:
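import torch
a = torch.Tensor([1,2,2,3,4,4,4,5])
b = torch.Tensor([1,2,4])
# Start with an all-False mask the same shape as a
mask = torch.zeros(a.shape, dtype=torch.bool)
print(mask)
for e in b:
    # Mark every position in a that equals the current value of b
    mask |= (a == e)
print(mask)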
How can I do this without a for loop?
Solution
Update:
As @zaydh kindly pointed out in the comments, since PyTorch 1.10, torch.isin() (like many other NumPy equivalents) is available, so you can simply do:
torch.isin(a, b)
which gives you:
tensor([ True, True, True, False, True, True, True, False])
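Since the question asks for indexes rather than a mask, here is a minimal sketch (assuming PyTorch 1.10 or later) that builds the mask with torch.isin and then converts it to integer positions with torch.nonzero. The variable name indices is just illustrative.

```python
import torch

# Example tensors from the question (torch.Tensor creates float tensors).
a = torch.Tensor([1, 2, 2, 3, 4, 4, 4, 5])
b = torch.Tensor([1, 2, 4])

# Boolean mask: True wherever an element of `a` also occurs in `b`.
mask = torch.isin(a, b)

# Integer positions of the matching elements in `a`.
indices = torch.nonzero(mask, as_tuple=True)[0]

print(mask)
print(indices)  # tensor([0, 1, 2, 4, 5, 6])
```

If you only need to select the matching values themselves, a[mask] does that directly without computing indices.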
Old answer:
Is this what you want? Using NumPy:
import numpy as np
np.in1d(a.numpy(), b.numpy())
which results in:
array([ True, True, True, False, True, True, True, False])
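If you are on a PyTorch version older than 1.10 and want to stay in PyTorch rather than round-trip through NumPy, broadcasting is one loop-free alternative. This is a swapped-in technique, not part of the answer above: it compares every element of a with every element of b, so it uses memory proportional to len(a) * len(b) and is best suited to small b.

```python
import torch

a = torch.Tensor([1, 2, 2, 3, 4, 4, 4, 5])
b = torch.Tensor([1, 2, 4])

# a.unsqueeze(1) has shape (8, 1) and b has shape (3,), so the comparison
# broadcasts to an (8, 3) table where entry [i, j] is (a[i] == b[j]).
matches = a.unsqueeze(1) == b

# An element of `a` is kept if it equals at least one element of `b`.
mask = matches.any(dim=1)

print(mask)
```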
Answered By - Hossein