Issue
In the following code, I want to compare the loss (data type at::Tensor) with a lossThreshold (data type double). I want to convert loss to double before making that comparison. How do I do it?
#include <cstdlib>        // EXIT_FAILURE, EXIT_SUCCESS
#include <torch/torch.h>

int main() {
    auto const input1(torch::randn({28*28}));
    auto const input2(torch::randn({28*28}));
    double const lossThreshold{0.05};
    auto const loss{torch::nn::functional::mse_loss(input1, input2)}; // this returns an at::Tensor datatype
    // loss is still an at::Tensor here, not a double
    return loss > lossThreshold ? EXIT_FAILURE : EXIT_SUCCESS;
}
Solution
Thanks to GitHub Copilot, which recommended this solution. I guess I should leave my job now. :(
The solution is to use the item<T>() template function as follows:
#include <cstdlib>        // EXIT_FAILURE, EXIT_SUCCESS
#include <torch/torch.h>

int main() {
    auto const input1(torch::randn({28*28})); // at::Tensor
    auto const input2(torch::randn({28*28})); // at::Tensor
    double const lossThreshold{0.05}; // double
    // item<double>() extracts the single value held by the at::Tensor as a double
    auto const loss{torch::nn::functional::mse_loss(input1, input2).item<double>()};
    return loss > lossThreshold ? EXIT_FAILURE : EXIT_SUCCESS;
}
Answered By - Raashid