Issue
I have written some code with scipy.optimize.minimize using the L-BFGS-B algorithm, and now I want to implement the same optimization with PyTorch.
SciPy:
res = minimize(calc_cost, x_0, args=const_data, method='L-BFGS-B', jac=calc_grad)

def calc_cost(x, const_data):
    # do some calculations with array "calculation" as result
    return np.sum(np.square(calculation))  # this returns a scalar!

def calc_grad(x, const_data):
    # do some calculations which result in array "calculation"
    return np.ravel(calculation)  # in PyTorch this would be returned without ravel!
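For context, a minimal self-contained sketch of this SciPy pattern; the quadratic problem and the names A, b, and x_true are made up for illustration, since the question's actual calculations are not shown:

```python
import numpy as np
from scipy.optimize import minimize

# Hypothetical stand-in problem: minimize ||A @ x - b||^2 with an analytic gradient.
rng = np.random.default_rng(0)
A = rng.standard_normal((5, 3))
x_true = rng.standard_normal(3)
b = A @ x_true  # constructed so the minimum cost is (near) zero

def calc_cost(x, A, b):
    calculation = A @ x - b                # residual array
    return np.sum(np.square(calculation))  # scalar cost

def calc_grad(x, A, b):
    calculation = 2.0 * A.T @ (A @ x - b)  # gradient array
    return np.ravel(calculation)           # flat vector, same length as x

x_0 = np.zeros(3)
res = minimize(calc_cost, x_0, args=(A, b), method='L-BFGS-B', jac=calc_grad)
```

With an exact gradient on a convex quadratic, L-BFGS-B recovers x_true to high accuracy in a handful of iterations.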
Now in PyTorch I am following this example. However, I want to use my own gradient calculation. This results in the error RuntimeError: Mismatch in shape: grad_output[0] has a shape of torch.Size([3, 200, 200]) and output[0] has a shape of torch.Size([]).
I understand that the shape/size of my gradient should match that of the objective function (here a scalar), but this is not what I need (see above). How do I adapt the following code so that it performs the same calculations as the SciPy version?
optimizer = optim.LBFGS([x_0], history_size=10, max_iter=10, line_search_fn="strong_wolfe")
h_lbfgs = []
for i in range(10):
    optimizer.zero_grad()
    objective = calc_cost(x_0, const_data)
    objective.backward(gradient=calc_gradient(x_0, const_data))
    optimizer.step(lambda: calc_cost(x_0, const_data))
    h_lbfgs.append(objective.item())
I have had a look at the PyTorch docs already but don't quite understand how they apply here:
- https://pytorch.org/docs/stable/generated/torch.optim.LBFGS.html
- https://pytorch.org/docs/stable/optim.html#optimizer-step-closure
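For what it's worth, the shape rule behind that RuntimeError can be seen on a tiny example: Tensor.backward(gradient=...) expects the gradient to have the same shape as the tensor it is called on, so a [3, 200, 200] gradient cannot be fed into a 0-dimensional (scalar) objective:

```python
import torch

x = torch.ones(3, requires_grad=True)

scalar = (x ** 2).sum()        # 0-dim objective, like calc_cost
try:
    scalar.backward(gradient=torch.ones(3))  # wrong shape -> RuntimeError
except RuntimeError as e:
    print("Mismatch:", e)

vector = x ** 2                # non-scalar output, same shape as x
vector.backward(gradient=torch.ones_like(vector))  # shapes match -> OK
print(x.grad)                  # tensor([2., 2., 2.])
```

This is why the gradient array must be attached to a tensor of matching shape (here x itself) rather than to the scalar cost.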
Solution
The problem was that I was calling backward on the wrong "objective" tensor. What I am actually optimizing is the x_0 array, so I had to alter my code as follows:
for i in range(10):
    optimizer.zero_grad()
    x_0.backward(gradient=calc_gradient(x_0, const_data))
    objective = optimizer.step(lambda: calc_cost(x_0, const_data))
    h_lbfgs.append(objective.item())
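Putting it together, a self-contained sketch of the corrected loop; the quadratic cost, the target tensor, and the hand-written calc_gradient are made-up stand-ins for the question's actual calculations:

```python
import torch
from torch import optim

target = torch.tensor([1.0, -2.0, 3.0])   # hypothetical data
x_0 = torch.zeros(3, requires_grad=True)

def calc_cost(x, target):
    return torch.sum((x - target) ** 2)    # scalar objective

def calc_gradient(x, target):
    # externally computed (analytic) gradient, same shape as x
    return 2.0 * (x.detach() - target)

optimizer = optim.LBFGS([x_0], history_size=10, max_iter=10,
                        line_search_fn="strong_wolfe")
h_lbfgs = []
for i in range(10):
    optimizer.zero_grad()
    # seed x_0.grad with the hand-computed gradient; backward on the leaf
    # tensor x_0 simply accumulates this gradient into x_0.grad
    x_0.backward(gradient=calc_gradient(x_0, target))
    # step(closure) returns the loss evaluated at the start of the step
    objective = optimizer.step(lambda: calc_cost(x_0, target))
    h_lbfgs.append(objective.item())
```

Note that the closure only re-evaluates the cost, not the gradient, so the line search works with the gradient seeded before each step; the recorded losses still decrease monotonically for this convex problem.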
Answered By - X_841