Issue
I want to get surprisal values from logit outputs from PyTorch, using log base 2.
One way to do this, given a logits tensor, is:
probs = nn.functional.softmax(logits, dim=2)
surprisals = -torch.log2(probs)
However, PyTorch provides a function that combines log and softmax, which is faster than the above:
surprisals = -nn.functional.log_softmax(logits, dim=2)
But this seems to return values in base e, which I don't want. Is there a function like log_softmax, but which uses base 2? I have tried log2_softmax and log_softmax2, neither of which seems to work, and I haven't had any luck finding documentation online.
Solution
How about just using the fact that logarithm bases can be changed with the identity

log2(x) = ln(x) / ln(2)

where ln(softmax(x)) is exactly what F.log_softmax() is giving you. All you need to do is
surprisals = -(1 / torch.log(torch.tensor(2.0))) * nn.functional.log_softmax(logits, dim=2)
It's just a scalar multiplication, so it carries hardly any performance penalty. (Note that torch.log expects a tensor argument, hence the torch.tensor(2.0); a plain math.log(2) would work just as well.)
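For completeness, here is a minimal, self-contained sketch checking that the rescaled log_softmax matches the explicit softmax-plus-log2 computation. The shape of logits below is a placeholder assumption; any tensor whose vocabulary axis is dim=2 works the same way.

import math
import torch
import torch.nn as nn

# Placeholder logits of shape (batch, sequence, vocabulary) -- assumed for illustration.
logits = torch.randn(2, 5, 10)

# Explicit two-step computation: softmax, then base-2 log.
reference = -torch.log2(nn.functional.softmax(logits, dim=2))

# Fused computation: natural-log log_softmax, rescaled from base e to base 2.
surprisals = -(1 / math.log(2)) * nn.functional.log_softmax(logits, dim=2)

print(torch.allclose(reference, surprisals, atol=1e-5))  # prints: True

The factor 1 / math.log(2) is computed once on the host, so the only extra work beyond log_softmax is a single elementwise multiply.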
Answered By - ayandas