Issue
I tried to use torch.normal, but it fails with an error saying that all elements of std must be >= 0.0. How can I fix this?
import torch
import torch.nn as nn

b = 32
n_s = 10
dim = 64
slots_mu = nn.Parameter(torch.randn(1, 1, dim))
slots_log_sigma = nn.Parameter(torch.randn(1, 1, dim))
mu = slots_mu.expand(b, n_s, -1)
sigma = slots_log_sigma.expand(b, n_s, -1)
slots = torch.normal(mu, sigma)
It raised the following error:
---> 10 slots = torch.normal(mu, sigma)
RuntimeError: normal expects all elements of std >= 0.0
Solution
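torch.normal requires every element of std to be non-negative, because a standard deviation measures spread and cannot be negative by definition; please take a look at this answer. torch.randn samples from a standard normal distribution, so roughly half of the values in slots_log_sigma are negative, and they are passed to torch.normal unchanged as sigma.
To solve your problem, converting std to its absolute value can help: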
slots_log_sigma = nn.Parameter(abs(torch.randn(1, 1, dim)))
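Note that taking the absolute value only makes the initial values non-negative; once the parameter is trained, gradient updates can push some entries below zero again. Since the parameter is named slots_log_sigma, one alternative is to treat it as the logarithm of the standard deviation and exponentiate it before sampling. This is an assumption about the intent behind the name, not something stated in the question. A minimal sketch under that assumption:

```python
import torch
import torch.nn as nn

b, n_s, dim = 32, 10, 64

slots_mu = nn.Parameter(torch.randn(1, 1, dim))
# Assumption: this parameter stores log(sigma), so it may hold any real value.
slots_log_sigma = nn.Parameter(torch.randn(1, 1, dim))

mu = slots_mu.expand(b, n_s, -1)
# exp() maps any real number to a strictly positive one, so std > 0 always holds.
sigma = slots_log_sigma.exp().expand(b, n_s, -1)

slots = torch.normal(mu, sigma)

# Reparameterized form (mu + sigma * noise), a common choice when
# gradients should reach slots_mu and slots_log_sigma.
slots_rsample = mu + sigma * torch.randn_like(mu)
```

With this version the stored parameter can go negative during training without breaking torch.normal, because only its exponential is used as the standard deviation.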
Answered By - CuCaRot