Issue
I need to sort the rows of each 2D matrix in a batch by the value in its first column.
Original batch of matrices (3D tensor):
torch.tensor([[[2, 0],
               [0, 1],
               [1, 2]],
              [[1, 2],
               [0, 0],
               [2, 1]]])
Desired tensor:
torch.tensor([[[0, 1],
               [1, 2],
               [2, 0]],
              [[0, 0],
               [1, 2],
               [2, 1]]])
I already know how to sort a single matrix from the batch, and another answer solves the problem with a for loop, which does not run in parallel. How can I sort the whole batch at once?
Solution
This may look a bit confusing at first, but it makes sense:
(my_tensor[:,torch.argsort(my_tensor[:,:,0], dim=1)])\
[torch.arange(len(my_tensor)),torch.arange(len(my_tensor))]
In the first line, you compute the sorting indices with torch.argsort and apply them to my_tensor, which gives a tensor of shape (2, 2, 3, 2): every matrix in the batch reordered by every matrix's sort order. Since you want each matrix sorted only by its own first column, you are only interested in the diagonal of the first two dimensions, which the second line of code extracts by indexing.
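To make the shapes concrete, here is a minimal runnable sketch of the same two steps on the example from the question. The intermediate names (order, all_pairs, batch, sorted_diag, sorted_direct) are only illustrative and do not appear in the answer. The last line is an alternative that is not part of the original answer: it indexes the batch and row dimensions together, so it should give the same result without building the intermediate (2, 2, 3, 2) tensor.

```python
import torch

my_tensor = torch.tensor([[[2, 0], [0, 1], [1, 2]],
                          [[1, 2], [0, 0], [2, 1]]])

# Row order of each matrix, sorted by its first column: shape (2, 3)
order = torch.argsort(my_tensor[:, :, 0], dim=1)

# Indexing dim 1 with a (2, 3) index pairs every matrix with every ordering:
# all_pairs[i, j] is matrix i reordered by the sort order of matrix j
all_pairs = my_tensor[:, order]            # shape (2, 2, 3, 2)

# Keep only matrix i reordered by its own ordering i (the "diagonal")
batch = torch.arange(len(my_tensor))
sorted_diag = all_pairs[batch, batch]      # shape (2, 3, 2)

# Alternative (an assumption, not from the answer): broadcast a (2, 1) batch
# index against the (2, 3) row index to pick the rows directly
sorted_direct = my_tensor[batch[:, None], order]   # shape (2, 3, 2)

print(sorted_diag)
```

Both sorted_diag and sorted_direct should match the desired tensor from the question. For large batches the second form may be preferable, since the intermediate tensor in the first form grows with the square of the batch size.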
Answered By - Salvatore Daniele Bianco