Issue
I would like tf.math.bincount to return the maximum (or minimum) weight per bin instead of the sum of the weights. Currently it works like this:
import tensorflow as tf

values = tf.constant([1, 1, 2, 3, 2, 4, 4, 5])
weights = tf.constant([1, 5, 0, 1, 0, 5, 4, 5])
tf.math.bincount(values, weights=weights)  # [0 6 0 1 9 5]
Instead, I would like the max (or min) of the weights that fall into the same bin. For max, the expected result is:
[0 5 0 1 5 5]
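To pin down the requested semantics, here is a plain-Python reference that can be used to check any TensorFlow version against. It is a sketch, not TensorFlow code: the name bincount_reduce is hypothetical, and empty bins default to 0 to match the expected output above.

```python
from typing import Callable, List, Sequence


def bincount_reduce(values: Sequence[int], weights: Sequence[int],
                    reduce: Callable[[List[int]], int] = max) -> List[int]:
    """Hypothetical reference: reduce the weights landing in each bin; empty bins -> 0."""
    num_bins = max(values) + 1
    buckets: List[List[int]] = [[] for _ in range(num_bins)]
    for v, w in zip(values, weights):
        buckets[v].append(w)  # collect every weight that falls into bin v
    return [reduce(b) if b else 0 for b in buckets]


values = [1, 1, 2, 3, 2, 4, 4, 5]
weights = [1, 5, 0, 1, 0, 5, 4, 5]
print(bincount_reduce(values, weights))       # [0, 5, 0, 1, 5, 5]
print(bincount_reduce(values, weights, min))  # [0, 1, 0, 1, 4, 5]
print(bincount_reduce(values, weights, sum))  # [0, 6, 0, 1, 9, 5]
```

Passing sum reproduces what tf.math.bincount returns with weights, which is a handy sanity check for the reference itself.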
Solution
It requires some finessing, but you can accomplish this as follows:
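import tensorflow as tf

def bincount_with_max_weight(values: tf.Tensor, weights: tf.Tensor) -> tf.Tensor:
    # One bin per integer in [0, max(values)]
    _range = tf.range(tf.reduce_max(values) + 1)
    # For each bin x: gather the weights where values == x, take their max,
    # and clamp empty bins (which reduce to INT_MIN) to 0
    return tf.map_fn(lambda x: tf.maximum(
        tf.reduce_max(tf.gather(weights, tf.where(tf.equal(values, x)))), 0), _range)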
The output for the example case is:
[0 5 0 1 5 5]
Breaking it down, the first line builds one bin index for every integer from 0 up to the largest entry in values:
_range = tf.range(tf.reduce_max(values) + 1)
In the second line, tf.map_fn computes the maximum weight for each element x of _range: tf.where(tf.equal(values, x)) returns the indices at which values equals x, tf.gather retrieves the weights at those indices, and tf.reduce_max takes their maximum.
The tf.maximum call handles bins whose index does not occur in values: tf.reduce_max over an empty tensor returns the smallest value representable by the dtype. In the example, 0 does not occur in values, so without tf.maximum the output for bin 0 would be INT_MIN:
[-2147483648 5 0 1 5 5]
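To make the breakdown concrete, the sketch below (assuming TensorFlow 2.x with eager execution) traces the per-bin computation on the example data for bin 4, which has two matching positions, and for bin 0, which has none.

```python
import tensorflow as tf

values = tf.constant([1, 1, 2, 3, 2, 4, 4, 5])
weights = tf.constant([1, 5, 0, 1, 0, 5, 4, 5])

# Bin 4: tf.where returns the matching positions as an int64 tensor of shape (2, 1)
idx_4 = tf.where(tf.equal(values, 4))  # [[5], [6]]
picked_4 = tf.gather(weights, idx_4)   # [[5], [4]]
max_4 = tf.reduce_max(picked_4)        # 5

# Bin 0: no matches, so the gathered tensor is empty with shape (0, 1)
idx_0 = tf.where(tf.equal(values, 0))
picked_0 = tf.gather(weights, idx_0)
max_0 = tf.reduce_max(picked_0)        # -2147483648, i.e. tf.int32.min
clamped_0 = tf.maximum(max_0, 0)       # 0
```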
The clamp can also be applied once to the final result tensor instead of per element:
def bincount_with_max_weight(values: tf.Tensor, weights: tf.Tensor) -> tf.Tensor:
    _range = tf.range(tf.reduce_max(values) + 1)
    result = tf.map_fn(lambda x:
                       tf.reduce_max(tf.gather(weights, tf.where(tf.equal(values, x)))), _range)
    return tf.maximum(result, 0)
Note that this does not work with negative weights, because tf.maximum would also replace a genuinely negative maximum with 0. In that case, use tf.where to replace only the sentinel that tf.reduce_max returns for empty bins (tf.int32.min in this example; the same idea applies to other numeric dtypes, using that dtype's empty-reduction value) instead of applying tf.maximum:
def bincount_with_max_weight(values: tf.Tensor, weights: tf.Tensor) -> tf.Tensor:
    _range = tf.range(tf.reduce_max(values) + 1)
    result = tf.map_fn(lambda x:
                       tf.reduce_max(tf.gather(weights, tf.where(tf.equal(values, x)))), _range)
    return tf.where(tf.equal(result, tf.int32.min), 0, result)
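A quick check with made-up negative weights (assuming TensorFlow 2.x in eager mode) shows the difference between the two post-processing steps: clamping turns the genuine maximum of bin 1, which is -2, into 0, while replacing only the sentinel keeps it.

```python
import tensorflow as tf

# Made-up data: bin 1 holds only negative weights, bins 0 and 2 are empty
values = tf.constant([1, 1, 3])
weights = tf.constant([-4, -2, 7])

bins = tf.range(tf.reduce_max(values) + 1)  # [0, 1, 2, 3]
raw = tf.map_fn(
    lambda x: tf.reduce_max(tf.gather(weights, tf.where(tf.equal(values, x)))), bins)
# raw holds INT_MIN for the empty bins 0 and 2

clamped = tf.maximum(raw, 0)                             # [0, 0, 0, 7]: -2 is lost
fixed = tf.where(tf.equal(raw, tf.int32.min), 0, raw)    # [0, -2, 0, 7]
```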
Update
To handle 2D tensors, an outer tf.map_fn can apply the per-row computation to each (values, weights) row pair in the batch:
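from typing import Optional

import tensorflow as tf

def bincount_with_max_weight(values: tf.Tensor, weights: tf.Tensor, axis: Optional[int] = None) -> tf.Tensor:
    # One bin per integer up to the global maximum, so every row gets the same length
    _range = tf.range(tf.reduce_max(values) + 1)

    def mapping_function(x: tf.Tensor, _values: tf.Tensor, _weights: tf.Tensor) -> tf.Tensor:
        # Max weight at positions where _values == x (tf.int32.min if there are none)
        return tf.reduce_max(tf.gather(_weights, tf.where(tf.equal(_values, x))))

    if axis == -1:
        # Outer map over (values row, weights row) pairs, inner map over the bins
        result = tf.map_fn(lambda pair: tf.map_fn(lambda x: mapping_function(x, *pair), _range),
                           (values, weights), fn_output_signature=tf.int32)
    else:
        result = tf.map_fn(lambda x: mapping_function(x, values, weights), _range)
    # Replace the empty-bin sentinel with 0
    return tf.where(tf.equal(result, tf.int32.min), 0, result)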
For the 2D example provided:
values = tf.constant([[1, 1, 2, 3], [2, 1, 4, 5]])
weights = tf.constant([[1, 5, 0, 1], [0, 5, 4, 5]])
print(bincount_with_max_weight(values, weights, axis=-1))
The output is:
tf.Tensor(
[[0 5 0 1 0 0]
[0 5 0 0 4 5]], shape=(2, 6), dtype=int32)
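This implementation generalizes the original approach: with axis=-1, each row is binned independently (all rows share the same number of bins, set by the largest value in the whole tensor); if axis is omitted, or given any value other than -1, it computes the 1D result.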
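As an alternative sketch, not part of the original answer, TensorFlow's segment reductions can replace the nested tf.map_fn calls: tf.math.unsorted_segment_max (or tf.math.unsorted_segment_min for the min variant) reduces all weights that share an id in a single vectorized op and, per its documentation, fills empty segments with the dtype's lowest (respectively largest) value. For 2D input, the helper offsets each row's ids by row * num_bins so that rows land in disjoint segments. The function name bincount_reduce_segment and its reduce parameter are hypothetical; the sketch assumes non-negative integer values and TensorFlow 2.x.

```python
import tensorflow as tf


def bincount_reduce_segment(values: tf.Tensor, weights: tf.Tensor, reduce: str = "max") -> tf.Tensor:
    """Hypothetical helper: per-bin max/min of weights for 1D or 2D (per-row) input; empty bins -> 0."""
    if reduce == "max":
        segment_fn = tf.math.unsorted_segment_max
        empty = weights.dtype.min  # value written into empty segments by segment max
    else:
        segment_fn = tf.math.unsorted_segment_min
        empty = weights.dtype.max  # value written into empty segments by segment min
    num_bins = tf.reduce_max(values) + 1
    if values.shape.rank == 2:
        # Shift row i's ids by i * num_bins so each row gets its own block of segments
        num_rows = tf.shape(values)[0]
        offsets = tf.range(num_rows)[:, tf.newaxis] * num_bins
        ids = tf.reshape(values + offsets, [-1])
        flat = segment_fn(tf.reshape(weights, [-1]), ids, num_segments=num_rows * num_bins)
        result = tf.reshape(flat, tf.stack([num_rows, num_bins]))
    else:
        result = segment_fn(weights, values, num_segments=num_bins)
    # Map the empty-segment sentinel to 0, as in the tf.map_fn version
    return tf.where(tf.equal(result, empty), tf.zeros_like(result), result)


values_2d = tf.constant([[1, 1, 2, 3], [2, 1, 4, 5]])
weights_2d = tf.constant([[1, 5, 0, 1], [0, 5, 4, 5]])
print(bincount_reduce_segment(values_2d, weights_2d))
```

Like the tf.map_fn version, this treats a real weight that happens to equal the sentinel as an empty bin.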
Answered By - danielcahall