Issue
Say you have the following array a:
>>> import numpy as np
>>> a = np.arange(27).reshape((3,3,3))
>>> a
array([[[ 0,  1,  2],
        [ 3,  4,  5],
        [ 6,  7,  8]],

       [[ 9, 10, 11],
        [12, 13, 14],
        [15, 16, 17]],

       [[18, 19, 20],
        [21, 22, 23],
        [24, 25, 26]]], dtype=int64)
And m, an array that specifies segment ids:
>>> m = np.linspace(start=0, stop=6, num=27).astype(int).reshape(a.shape)
>>> m
array([[[0, 0, 0],
        [0, 0, 1],
        [1, 1, 1]],

       [[2, 2, 2],
        [2, 3, 3],
        [3, 3, 3]],

       [[4, 4, 4],
        [4, 5, 5],
        [5, 5, 6]]])
When using JAX to compute, say, the sum of the scalars in a that share the same id in m, we can rely on jax.ops.segment_sum.
>>> import jax
>>> jax.ops.segment_sum(data=a.ravel(), segment_ids=m.ravel())
Array([10, 26, 42, 75, 78, 94, 26], dtype=int64)
Note that I had to resort to numpy.ndarray.ravel, since segment_sum expects segment_ids to indicate the segments of data along its leading axis.
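For reference, the flattening step can be sketched as below. This is a minimal, self-contained version that assumes the same a and m as above, with m written out literally. Passing num_segments is optional here. It is included on the assumption that the call might later be wrapped in jax.jit, where the output size has to be known ahead of time.

```python
import jax
import jax.numpy as jnp

# Same data as above; m is written out literally instead of via linspace.
a = jnp.arange(27).reshape((3, 3, 3))
m = jnp.array([[[0, 0, 0], [0, 0, 1], [1, 1, 1]],
               [[2, 2, 2], [2, 3, 3], [3, 3, 3]],
               [[4, 4, 4], [4, 5, 5], [5, 5, 6]]])

# segment_sum wants 1-D segment ids along the leading axis of data,
# so both arrays are flattened first. num_segments = largest id + 1.
sums = jax.ops.segment_sum(a.ravel(), m.ravel(), num_segments=7)
print(sums)
```

The result has one entry per segment id, so its shape is (7,) rather than the shape of a.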
Q1: Can you confirm there is no better approach, either with or without JAX?
Q2: How would one then build n, an array in which each id is replaced by the sum just computed for it? Note that I am not interested in non-vectorized approaches such as numpy.where.
>>> n
array([[[10, 10, 10],
        [10, 10, 26],
        [26, 26, 26]],

       [[42, 42, 42],
        [42, 75, 75],
        [75, 75, 75]],

       [[78, 78, 78],
        [78, 94, 94],
        [94, 94, 26]]], dtype=int64)
Solution
The segment_sum operation is somewhat more specialized than what you're asking about. In the case you describe, I would use ndarray.at directly:
import jax.numpy as jnp
sums = jnp.zeros(m.max() + 1).at[m].add(a)
print(sums[m])
[[[10. 10. 10.]
  [10. 10. 26.]
  [26. 26. 26.]]

 [[42. 42. 42.]
  [42. 75. 75.]
  [75. 75. 75.]]

 [[78. 78. 78.]
  [78. 94. 94.]
  [94. 94. 26.]]]
This will also work when the segments are non-adjacent.
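To illustrate the non-adjacent case, here is a small sketch under a few assumptions: the ids in m are made up for the example (each id recurs throughout the array), and the number of ids, num_ids, is assumed to be known up front. The accumulator is created with the dtype of a so that the result stays integer. The printed output above is floating point only because jnp.zeros defaults to a float dtype.

```python
import jax
import jax.numpy as jnp

a = jnp.arange(27).reshape((3, 3, 3))

# Hypothetical ids for illustration: 0, 1, 2, 3, 0, 1, 2, 3, ...
# so elements sharing an id are scattered, not contiguous.
m = jnp.arange(27).reshape((3, 3, 3)) % 4
num_ids = 4  # assumed known; equals m.max() + 1 here

# Scatter-add every element of a into the slot given by its id.
sums = jnp.zeros(num_ids, dtype=a.dtype).at[m].add(a)

# Indexing the per-id sums with m broadcasts them back to the shape of a.
n = sums[m]

# Cross-check against segment_sum on the flattened arrays.
check = jax.ops.segment_sum(a.ravel(), m.ravel(), num_segments=num_ids)
print(sums, bool((sums == check).all()))
```

Because m is used directly as an index, no ravel or reshape is needed, and n has the same shape and dtype as a.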
Answered By - jakevdp