Issue
I am trying to extract the unique values in each row of a matrix and return them in the same matrix, with repeated values set to, say, 0. For example, I would like to transform
torch.Tensor([[1, 2, 3, 4, 3, 3, 4],
              [1, 6, 3, 5, 3, 5, 4]])
to
torch.Tensor([[1, 2, 3, 4, 0, 0, 0],
              [1, 6, 3, 5, 0, 0, 4]])
or
torch.Tensor([[1, 2, 3, 4, 0, 0, 0],
              [1, 6, 3, 5, 4, 0, 0]])
That is, the order within each row does not matter. I have tried torch.unique(), and its documentation says that the dimension along which unique values are taken can be specified with the parameter dim. However, it doesn't seem to work for this case.
I've tried:
output = torch.unique(torch.Tensor([[4, 2, 52, 2, 2], [5, 2, 6, 6, 5]]), dim=1)
output
Which gives
tensor([[ 2., 2., 2., 4., 52.],
[ 2., 5., 6., 5., 6.]])
Does anyone have a particular fix for this? If possible, I'm trying to avoid for loops.
Solution
Admittedly, the unique function can be confusing without proper examples and explanations.
The dim parameter specifies the dimension along which unique operates, and each slice taken along that dimension is compared as a whole. For a 2D matrix, dim=0 compares entire rows and returns the unique rows, while dim=1 compares entire columns and returns the unique columns.
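A small sketch contrasting the two settings of dim on a 2D tensor. The tensor x is a made-up example, not from the question; it is built so that rows 0 and 2 are identical and columns 0 and 2 are identical.

```python
import torch

# Hypothetical example: row 0 == row 2, and column 0 == column 2
x = torch.tensor([[1, 2, 1],
                  [3, 4, 3],
                  [1, 2, 1]])

# dim=0: each whole row is one element; duplicate rows are dropped, the rest sorted
unique_rows = torch.unique(x, dim=0)
print(unique_rows.tolist())  # [[1, 2, 1], [3, 4, 3]]

# dim=1: each whole column is one element; duplicate columns are dropped, the rest sorted
unique_cols = torch.unique(x, dim=1)
print(unique_cols.tolist())  # [[1, 2], [3, 4], [1, 2]]
```

In neither case are duplicates removed from inside a single row or column, which is why dim=1 does not give the per-row behaviour the question is after.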
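For example, consider the 4x4 matrix below with dim=1. As the output shows, unique does not work row by row: it treats each column as a single element, drops duplicate columns (the second and fourth columns are identical here), and sorts the remaining columns.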
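Notice that 11 still appears twice in the first and last rows of the result. Because whole columns are compared, nothing is removed from inside a row, so every row keeps the same length and the result stays a rectangular matrix. NumPy's np.unique with an axis argument behaves the same way.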
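To see why a single call cannot return per-row unique values as one matrix, here is a loop-based sketch (the approach the question hopes to avoid) applied to the first tensor from the question. Each row yields a different number of unique values, so the results do not fit into one rectangular tensor without padding.

```python
import torch

# First example tensor from the question
t = torch.tensor([[1, 2, 3, 4, 3, 3, 4],
                  [1, 6, 3, 5, 3, 5, 4]])

# Unique values of each row computed separately (sorted by default)
per_row = [torch.unique(row).tolist() for row in t]
print(per_row)  # [[1, 2, 3, 4], [1, 3, 4, 5, 6]]

# The rows have different lengths, so they cannot be stacked into one tensor as-is
print([len(r) for r in per_row])  # [4, 5]
```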
However, if you do not specify dim, torch flattens the tensor first and applies unique to all of its elements, so you get a 1D tensor of the unique values.
import torch
m = torch.Tensor([
[11, 11, 12, 11],
[13, 11, 12, 11],
[16, 11, 12, 11],
[11, 11, 12, 11]
])
output, indices = torch.unique(m, sorted=True, return_inverse=True, dim=1)
print("Ori \n{}".format(m.numpy()))
print("Sorted \n{}".format(output.numpy()))
print("Indices \n{}".format(indices.numpy()))
# without specifying dimension
output, indices = torch.unique(m, sorted=True, return_inverse=True)
print("Sorted (no dim) \n{}".format(output.numpy()))
Result (dim=1)
Ori
[[11. 11. 12. 11.]
[13. 11. 12. 11.]
[16. 11. 12. 11.]
[11. 11. 12. 11.]]
Sorted
[[11. 11. 12.]
[11. 13. 12.]
[11. 16. 12.]
[11. 11. 12.]]
Indices
[1 0 2 0]
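The Indices output is the inverse mapping: entry j gives the column of output that corresponds to column j of m. As a quick check, this sketch indexes the unique columns with those indices and rebuilds the original matrix.

```python
import torch

m = torch.Tensor([
    [11, 11, 12, 11],
    [13, 11, 12, 11],
    [16, 11, 12, 11],
    [11, 11, 12, 11],
])
output, indices = torch.unique(m, sorted=True, return_inverse=True, dim=1)

# Column j of m equals column indices[j] of output
rebuilt = output[:, indices]
print(torch.equal(rebuilt, m))  # True
```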
Result (no dimension)
Sorted (no dim)
[11. 12. 13. 16.]
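None of the calls above produces the per-row result the question asks for. One loop-free alternative, sketched here with a sort instead of torch.unique (the function name zero_repeats_per_row is hypothetical), stable-sorts each row, flags every value equal to its left neighbour, maps those flags back to the original positions, and sets them to 0. It assumes 0 never occurs as a real value (otherwise the placeholder is ambiguous) and a PyTorch version whose torch.sort accepts stable=True.

```python
import torch


def zero_repeats_per_row(x: torch.Tensor) -> torch.Tensor:
    """Keep the first occurrence of each value in every row; set repeats to 0."""
    # Stable sort keeps equal values in their original left-to-right order
    sorted_vals, order = torch.sort(x, dim=1, stable=True)

    # In sorted order, a value equal to its left neighbour is a repeat
    repeat_sorted = torch.zeros_like(sorted_vals, dtype=torch.bool)
    repeat_sorted[:, 1:] = sorted_vals[:, 1:] == sorted_vals[:, :-1]

    # argsort of a permutation is its inverse: for each original column,
    # it gives that column's position in sorted order
    inverse = torch.argsort(order, dim=1)
    repeat = repeat_sorted.gather(1, inverse)

    return x.masked_fill(repeat, 0)


t = torch.Tensor([[1, 2, 3, 4, 3, 3, 4],
                  [1, 6, 3, 5, 3, 5, 4]])
print(zero_repeats_per_row(t).tolist())
# [[1.0, 2.0, 3.0, 4.0, 0.0, 0.0, 0.0], [1.0, 6.0, 3.0, 5.0, 0.0, 0.0, 4.0]]
```

This keeps each surviving value in place, which matches the first expected output in the question. The sort costs O(n log n) per row, but the whole operation stays vectorized across rows.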
Answered By - Rex Low