Issue
I have some code that finds the indices of the elements in the "input" list that match any element in the "values" list. The indices are returned in the same order as the "values" list.
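import numpy as np
input = [1, 2, 8, 7, 3, 4, 6, 5, 9]
values = [4, 8, 3]
# Compare every value (rows) against every input element (columns)
match_index_lst, match_index_values = np.where(np.array(input) == np.array(values)[:,None])
# Order the matching input indices by the position of their value in "values"
output_indice_lst = match_index_values[np.argsort(match_index_lst)]
# [5, 2, 4]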
My question is whether it's possible to efficiently (using vectorized operations) extend this code to specific multidimensional lists. Currently the input list is of dimension c, but I will have inputs with dimensions of [a, b, c]. So instead of something like:
input = [1, 2, 8, 7, 3, 4, 6, 5, 9]
values = [4, 8, 3]
# output: [5, 2, 4]
I will have something like:
input = [[[[ 0.31, 1.56, 1.58, 0.16, 0.22, 0.54, 0.98, 0.35 ]],
[[ 0.77, 2.62, 0.44, 0.08, 0.76, 0.87, 0.88, 0.51 ]]],
[[[ 1.14, 0.48, 1.09, 0.93, 0.47, 0.13, 0.75, 0.19 ]],
[[ 1.15, 0.17, 2.33, 0.46, 0.30, 2.60, 0.79, 1.07 ]]]]
values = [[[[ 0.54, 1.58 ]],
[[ 0.77, 0.88 ]]],
[[[ 0.48, 1.09 ]],
[[ 2.60, 2.33 ]]]]
# output: [[[[ 5, 2 ]],
# [[ 0, 6 ]]],
#
# [[[ 1, 2 ]],
# [[ 5, 2 ]]]]
My specific example is of size (2, 2, 8), but it could be any (a, b, c) size.
I've tried flattening it and then performing operations on it, but I just can't seem to get the order right after unflattening it, and then getting the output properly formatted is also a nightmare. I can see how it would be pretty easy to implement with for loops, but I want to keep that as a last resort as speed is crucial.
Solution
I think your idea of using the flattened index was right. Here is what that would look like:
import numpy as np
# Same data as in the question, shape (2, 2, 1, 8)
input = np.array([[[[ 0.31, 1.56, 1.58, 0.16, 0.22, 0.54, 0.98, 0.35 ]],
                   [[ 0.77, 2.62, 0.44, 0.08, 0.76, 0.87, 0.88, 0.51 ]]],
                  [[[ 1.14, 0.48, 1.09, 0.93, 0.47, 0.13, 0.75, 0.19 ]],
                   [[ 1.15, 0.17, 2.33, 0.46, 0.30, 2.60, 0.79, 1.07 ]]]])
values = np.array([[[[ 0.54, 1.58 ]],
                    [[ 0.77, 0.88 ]]],
                   [[[ 0.48, 1.09 ]],
                    [[ 2.60, 2.33 ]]]])
# Indices that would sort the flattened input
sort_idx = np.argsort(input.ravel())
# Locate each value in sorted order, then map back to flat input indices
output_flat = sort_idx[np.searchsorted(input.ravel(), values.ravel(), sorter=sort_idx)]
# Convert flat indices to N-d indices and keep only the last axis
output = np.unravel_index(output_flat.reshape(values.shape), input.shape)[-1]
print(output)
Which prints:
[[[[5 2]]
[[0 6]]]
[[[1 2]]
[[5 2]]]]
(With the input copied exactly from the question, this matches your reference output. It is still an open question what should happen if a value is not in the input: np.searchsorted then returns the position of a neighbouring element instead of reporting a miss.)
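On the question of values that are not in the input: np.searchsorted returns an insertion position rather than signalling a miss, so a missing value silently yields the index of some other element. Below is a minimal sketch of one way to detect this, assuming exact equality is the intended match criterion. The helper name find_last_axis_index and the -1 sentinel are illustrative choices, not part of the approach above.

```python
import numpy as np

def find_last_axis_index(data, values, missing=-1):
    """Hypothetical helper: look up each entry of `values` in the flattened
    `data` and return its index along the last axis, or `missing` if absent."""
    data = np.asarray(data)
    values = np.asarray(values)
    flat = data.ravel()
    flat_values = values.ravel()
    sort_idx = np.argsort(flat)
    # Insertion positions in sorted order; a value larger than every element
    # gets position flat.size, so clip before using it as an index.
    pos = np.searchsorted(flat, flat_values, sorter=sort_idx)
    pos = np.clip(pos, 0, flat.size - 1)
    flat_idx = sort_idx[pos]
    # A hit is genuine only if the element found equals the value searched for.
    found = flat[flat_idx] == flat_values
    last_axis = np.unravel_index(flat_idx, data.shape)[-1]
    return np.where(found, last_axis, missing).reshape(values.shape)

# Small made-up example: 0.99 and 3.00 do not occur in `data`
data = np.array([[0.31, 1.56, 1.58],
                 [0.77, 2.62, 0.44]])
values = np.array([[1.58, 0.99],
                   [0.44, 3.00]])
result = find_last_axis_index(data, values)
print(result)
```

Entries equal to the sentinel mark values that were not found, so they can be filtered out or reported before the indices are used.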
The key missing parts were np.unravel_index() and reshaping to values.shape.
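Because the search above runs over the whole flattened array, it relies on each value being unique across all rows. If the same number can occur in several rows, a row-wise comparison that generalizes the broadcasting from the question may be safer. The following is only a sketch, assuming that every value occurs in its own row and that exact equality is the right test. It builds a temporary boolean array of shape (..., k, c), so memory use grows with both the number of values and the row length. The function name match_along_last_axis is illustrative.

```python
import numpy as np

def match_along_last_axis(data, values):
    """Hypothetical helper: for each row along the last axis, return the
    index in `data` of each entry of the corresponding row of `values`."""
    data = np.asarray(data)
    values = np.asarray(values)
    # values[..., :, None] has shape (..., k, 1) and data[..., None, :] has
    # shape (..., 1, c); broadcasting compares each value with its own row only.
    hits = values[..., :, None] == data[..., None, :]
    # argmax gives the position of the first True along the last axis.
    # Caveat: it also returns 0 when a value has no match at all.
    return hits.argmax(axis=-1)

# Made-up example in which 8 and 1 occur in both rows
data = np.array([[[1, 2, 8, 7]],
                 [[8, 3, 1, 4]]])      # shape (2, 1, 4)
values = np.array([[[8, 1]],
                   [[4, 8]]])          # shape (2, 1, 2)
result = match_along_last_axis(data, values)
print(result)
```

Each value is matched only against its own row here, so the 8 in the second row resolves to index 0 rather than to the position of the 8 in the first row.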
I hope this helps!
Answered By - Axel Donath