Issue
I am trying to slice signal_matrix by in_inds, but uint8 indexing seems to work differently. Can someone explain how it is supposed to work?
import torch

signal_matrix = torch.tensor(
    [[0, 0, 1, 1],
     [0, 0, 0, 0],
     [0, 0, 1, 1],
     [0, 1, 1, 0],
     [1, 0, 1, 0],
     [0, 0, 0, 0],
     [0, 0, 0, 0],
     [0, 0, 0, 0],
     [0, 0, 0, 0]], dtype=torch.uint8)
in_inds = torch.tensor(
    [[0, 2, 3],
     [1, 2, 4],
     [0, 0, 0]][::-1], dtype=torch.uint8)
out_inds = torch.tensor([5, 6, 7], dtype=torch.uint8)
op_inds = torch.tensor(
    [[1, 1, 1, 1],
     [0, 0, 0, 0],
     [2, 2, 2, 2]], dtype=torch.uint8)
in_signals = signal_matrix[in_inds]
IndexError: The shape of the mask [3, 3] at index 0 does not match the shape of the indexed tensor [9, 4] at index 0
I expected something like the following, which is what I get when I use int32 instead:
tensor([[[0, 0, 1, 1],
         [0, 0, 1, 1],
         [0, 0, 1, 1]],

        [[0, 0, 0, 0],
         [0, 0, 1, 1],
         [1, 0, 1, 0]],

        [[0, 0, 1, 1],
         [0, 0, 1, 1],
         [0, 1, 1, 0]]], dtype=torch.int32)
Solution
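Here's the deal:
When a uint8 tensor is used as an index, it is interpreted as a mask instead of a list of indices. That is, you get a new tensor containing the values of the original tensor at the locations where the mask is nonzero (for uint8, that means positive). A mask must also match the shape of the dimensions it indexes, which is why your [3, 3] mask on the [9, 4] signal_matrix raises the IndexError.
Note that using uint8 for indexing (which is actually masking) is deprecated and should not be done: use a bool tensor for masking and an integer (int64, i.e. long) tensor for indexing.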
Concretely:
import torch

data = torch.arange(9).reshape((3, 3))
# tensor([[0, 1, 2],
#         [3, 4, 5],
#         [6, 7, 8]])
in_inds = torch.tensor(
    [[0, 2, 3],
     [1, 2, 4],
     [0, 0, 0]][::-1], dtype=torch.uint8)
data[in_inds]  # tensor([3, 4, 5, 7, 8])
torch.equal(data[in_inds], data[in_inds > 0])  # True
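For comparison, here is a small sketch of the same masking written without the deprecated dtype, by converting the uint8 tensor with `.bool()`. It also shows the shape rule behind the question's IndexError: a 1-D mask over dim 0 must have length `data.shape[0]` and selects whole rows, while a mask whose shape does not match the leading dimensions fails. The `row_mask` and `bad_mask` tensors are made up purely for illustration.

```python
import torch

data = torch.arange(9).reshape(3, 3)
in_inds = torch.tensor(
    [[0, 2, 3],
     [1, 2, 4],
     [0, 0, 0]][::-1], dtype=torch.uint8)

# Non-deprecated masking: convert the uint8 tensor to bool (nonzero -> True).
mask = in_inds.bool()
selected = data[mask]  # elements where mask is True, flattened to 1-D

# Illustrative 1-D mask over dim 0: length must equal data.shape[0]; selects whole rows.
row_mask = torch.tensor([True, False, True])
rows = data[row_mask]

# Illustrative mismatched mask: [2, 2] does not match the leading dims [3, 3] of data,
# the same kind of mismatch as the [3, 3] mask on the [9, 4] signal_matrix.
bad_mask = torch.ones(2, 2, dtype=torch.bool)
try:
    data[bad_mask]
    raised = False
except IndexError:
    raised = True

print(selected)  # tensor([3, 4, 5, 7, 8])
```

Using bool for masks and long for indices keeps the two meanings apart, so a tensor of row numbers can no longer be silently treated as a mask.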
On the other hand, you can perform actual indexing by using an int or a long tensor as an index:
data[0]  # tensor([0, 1, 2])
data[2]  # tensor([6, 7, 8])
data[torch.tensor([2, 2, 0], dtype=torch.int)]
# tensor([[6, 7, 8],
#         [6, 7, 8],
#         [0, 1, 2]])
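Applied to the question, a minimal sketch of the fix is to cast `in_inds` to `torch.long` before indexing, so each entry is read as a row number of `signal_matrix` rather than as a mask value. Indexing dim 0 with a [3, 3] integer tensor gathers one length-4 row per entry, giving a [3, 3, 4] result with the same values as the expected output in the question (its dtype stays uint8 because `signal_matrix` is uint8). `index_select` on the flattened indices followed by a reshape is included as an equivalent alternative; the names `idx` and `alt` are just illustrative.

```python
import torch

signal_matrix = torch.tensor(
    [[0, 0, 1, 1],
     [0, 0, 0, 0],
     [0, 0, 1, 1],
     [0, 1, 1, 0],
     [1, 0, 1, 0],
     [0, 0, 0, 0],
     [0, 0, 0, 0],
     [0, 0, 0, 0],
     [0, 0, 0, 0]], dtype=torch.uint8)
in_inds = torch.tensor(
    [[0, 2, 3],
     [1, 2, 4],
     [0, 0, 0]][::-1], dtype=torch.uint8)

# Cast to long so the tensor is treated as row indices, not as a mask.
idx = in_inds.long()

# Integer indexing on dim 0: result shape is idx.shape + signal_matrix.shape[1:].
in_signals = signal_matrix[idx]

# Equivalent: index_select takes a 1-D index, so flatten it and reshape afterwards.
alt = signal_matrix.index_select(0, idx.flatten()).reshape(*idx.shape, -1)

print(in_signals.shape)  # torch.Size([3, 3, 4])
```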
Answered By - Yakov Dan