Issue
As a minimal example, say I have a tensor of the form:
[[ 1. 0. 3. ]
[ 7. 5. 6. ]
[ 0. 0. 0. ]
[ 0. 11. 1. ]
[ 0. 0. 0. ]
[ 0. 0. 0. ]
[13. 14. 16.5]]
Is there a way (natively in TensorFlow) to impute the all-zero rows so that each one takes the values of the last preceding row that is not all zeros? I.e.:
[[ 1. 0. 3. ]
[ 7. 5. 6. ]
[ 7. 5. 6. ]
[ 0. 11. 1. ]
[ 0. 11. 1. ]
[ 0. 11. 1. ]
[13. 14. 16.5]]
I thought about using tf.tensor_scatter_nd_update, but with no success.
Solution
We can use tf.gather(a, indices) to get the above output. The indices need to be [0, 1, 1, 3, 3, 3, 6], which can be obtained with the following code:
import tensorflow as tf
# a is the input tensor from the question, shape (7, 3)
mask = tf.cast(tf.reduce_any(tf.not_equal(a, 0), axis=1), tf.float32)
# [1., 1., 0., 1., 0., 0., 1.] 1 where the row has any non-zero entry
# (testing the row sum instead would misclassify rows such as [1., -1., 0.])
mask_range = mask * tf.range(a.shape[0], dtype=tf.float32)
# [0., 1., 0., 3., 0., 0., 6.] apply mask on range()
indices = tf.cast(tf.scan(lambda acc, x: tf.maximum(acc, x), mask_range, initializer=tf.reduce_min(mask_range)), tf.int32)
# cumulative max [0, 1, 1, 3, 3, 3, 6]
tf.gather(a, indices)
[[ 1. ,  0. ,  3. ],
 [ 7. ,  5. ,  6. ],
 [ 7. ,  5. ,  6. ],
 [ 0. , 11. ,  1. ],
 [ 0. , 11. ,  1. ],
 [ 0. , 11. ,  1. ],
 [13. , 14. , 16.5]]
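To make the index logic easier to follow, here is a minimal plain-Python sketch of the same idea, with no TensorFlow involved. The helper name forward_fill_indices is hypothetical and used only for illustration. It records, for each row, the index of the most recent row that is not all zeros, which is what the masked range followed by a cumulative maximum computes above. It assumes that all-zero rows appearing before any non-zero row should fall back to index 0, matching the behaviour of the scan with a zero initializer.

```python
def forward_fill_indices(rows):
    """For each row, return the index of the latest row that is not all zeros."""
    indices = []
    last = 0  # assumption: leading all-zero rows fall back to row 0
    for i, row in enumerate(rows):
        if any(v != 0 for v in row):
            last = i  # this row has data, so later zero rows copy from it
        indices.append(last)
    return indices


a = [
    [1.0, 0.0, 3.0],
    [7.0, 5.0, 6.0],
    [0.0, 0.0, 0.0],
    [0.0, 11.0, 1.0],
    [0.0, 0.0, 0.0],
    [0.0, 0.0, 0.0],
    [13.0, 14.0, 16.5],
]

indices = forward_fill_indices(a)    # [0, 1, 1, 3, 3, 3, 6]
filled = [a[i] for i in indices]     # plain-Python stand-in for tf.gather(a, indices)
```

Because the check is "any entry is non-zero" rather than "the row sum is non-zero", a row such as [1.0, -1.0] still counts as a data row here.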
Answered By - V.M