After reading about how to solve an ODE with neural networks, following the paper Neural Ordinary Differential Equations and the blog that uses the JAX library, I tried to do the same thing with "plain" PyTorch, but found one point rather "obscure": how to properly take the partial derivative of a function (in this case the model) with respect to one of its input parameters.
To summarize the problem at hand, as shown in [2]: the goal is to solve the ODE y' = -2*x*y with the initial condition y(x=0) = 1 on the domain -2 <= x <= 2. Instead of using finite differences, the solution is replaced by a NN, y(x) = NN(x), with a single hidden layer of 10 nodes.
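For reference, the loss minimized by the code below can be written out as follows. This is a restatement of the ODE function in that code, with θ the network parameters and N = 401 grid points x_i; the exact solution follows from separating variables, dy/y = -2x dx, together with y(0) = 1:

```latex
\mathcal{L}(\theta) = \frac{1}{N}\sum_{i=1}^{N}\Big(\mathrm{NN}_\theta'(x_i) + 2\,x_i\,\mathrm{NN}_\theta(x_i)\Big)^2 + \big(\mathrm{NN}_\theta(0) - 1\big)^2,
\qquad y_{\text{exact}}(x) = e^{-x^2}.
```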
I managed to (more or less) replicate the blog with the following code:
```python
import torch
import torch.nn as nn
from torch import optim
import matplotlib.pyplot as plt
import numpy as np

# Define the NN model to solve the problem
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.lin1 = nn.Linear(1, 10)
        self.lin2 = nn.Linear(10, 1)

    def forward(self, x):
        x = torch.sigmoid(self.lin1(x))
        x = torch.sigmoid(self.lin2(x))
        return x

model = Model()

# Define loss_function from the Ordinary differential equation to solve
def ODE(x, y):
    # dy/dx at every grid point; create_graph=True so the loss can be backpropagated
    dydx, = torch.autograd.grad(y, x,
                                grad_outputs=y.data.new(y.shape).fill_(1),
                                create_graph=True, retain_graph=True)
    eq = dydx + 2. * x * y                   # y' = - 2x*y
    ic = model(torch.tensor([0.])) - 1.      # y(x=0) = 1
    return torch.mean(eq**2) + ic**2

loss_func = ODE

# Define the optimization
# opt = optim.SGD(model.parameters(), lr=0.1, momentum=0.99, nesterov=True)  # Equivalent to blog
opt = optim.Adam(model.parameters(), lr=0.1, amsgrad=True)  # Got faster convergence with Adam using amsgrad

# Define reference grid
x_data = torch.linspace(-2.0, 2.0, 401, requires_grad=True)
x_data = x_data.view(401, 1)  # reshaping the tensor

# Iterative learning
epochs = 1000
for epoch in range(epochs):
    opt.zero_grad()
    y_trial = model(x_data)
    loss = loss_func(x_data, y_trial)
    loss.backward()
    opt.step()

    if epoch % 100 == 0:
        print('epoch {}, loss {}'.format(epoch, loss.item()))

# Plot Results
x_plot = x_data.detach().numpy()
y_plot = model(x_data).detach().numpy()  # evaluate the trained model (y_data was undefined)
plt.plot(x_plot, np.exp(-x_plot**2), label='exact')
plt.plot(x_plot, y_plot, label='approx')
plt.legend()
plt.show()
```
From here I get the results shown in the figure.

[Figure: exact solution vs. NN approximation]
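As a quick sanity check of the residual used in ODE(x, y), one can feed it the known exact solution exp(-x^2) instead of the network output; the residual should then vanish up to floating-point rounding. This is only a sketch that reuses the question's grid, not part of the original code:

```python
import torch

# Same grid as in the question: 401 points on [-2, 2], shaped (401, 1)
x = torch.linspace(-2.0, 2.0, 401, requires_grad=True).view(401, 1)

# Exact solution of y' = -2*x*y with y(0) = 1
y = torch.exp(-x**2)

# dy/dx via autograd, computed the same way as inside ODE(x, y)
dydx, = torch.autograd.grad(y, x, grad_outputs=torch.ones_like(y), create_graph=True)

residual = dydx + 2. * x * y  # should be ~0 everywhere
print(residual.abs().max().item())
```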
The problem is that in the definition of the ODE functional, instead of passing (x, y), I would prefer to pass something like (x, fun) (where fun is my model), so that the partial derivative and specific evaluations of the model can each be done with a single call. So, something like:
```python
def ODE(x, fun):
    dydx = "grad of fun w.r.t x as a function"  # pseudocode: this is what I am looking for
    eq = dydx(x) + 2. * x * fun(x)              # y' = - 2x*y
    ic = fun(torch.tensor([0.])) - 1.           # y(x=0) = 1
    return torch.mean(eq**2) + ic**2
```
Any ideas? Thanks in advance
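One possible way to get this "derivative as a function" interface with plain torch.autograd is a small wrapper that returns a callable. The sketch below is an assumption, not a PyTorch API: `derivative` is a hypothetical helper name, and it assumes `fun` acts pointwise on a batch of shape (N, 1), so that grad_outputs=ones yields each sample's own derivative. Nesting it gives higher derivatives, and a derivative that does not depend on x comes back as zeros instead of an error:

```python
import torch

def derivative(fun):
    """Return a function x -> d fun(x) / dx (hypothetical helper, pointwise fun assumed)."""
    def dfun(x):
        # Track x if the caller did not; when dfun is nested inside another
        # derivative(), x already requires grad and is reused, which keeps the
        # graph connected for higher-order derivatives.
        if not x.requires_grad:
            x = x.detach().requires_grad_(True)
        y = fun(x)
        if not y.requires_grad:
            # Nothing differentiable in y, so its derivative w.r.t. x is zero
            return torch.zeros_like(x)
        (dydx,) = torch.autograd.grad(
            y, x, grad_outputs=torch.ones_like(y),
            create_graph=True, allow_unused=True)
        # allow_unused=True yields None when y does not depend on x
        return torch.zeros_like(x) if dydx is None else dydx
    return dfun


def ODE(x, fun):
    dydx = derivative(fun)                  # y' as a callable
    eq = dydx(x) + 2. * x * fun(x)          # y' = - 2x*y
    ic = fun(torch.tensor([0.])) - 1.       # y(x=0) = 1
    return torch.mean(eq**2) + ic**2
```

In the training loop the loss would then be called as `loss_func(x_data, model)`; because the derivative is built with create_graph=True, `loss.backward()` still reaches the model parameters.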
EDIT:
After some trials I found a way to pass the model as an input, but ran into another strange behavior. The new problem is to solve the ODE y'' = -2 with the BCs y(x=-2) = -1 and y(x=2) = 1, for which the analytical solution is y(x) = -x^2 + x/2 + 4.
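The analytical solution above can be checked by integrating twice and fitting the two boundary conditions:

```latex
\begin{aligned}
y'' = -2 \;&\Rightarrow\; y(x) = -x^2 + c_1 x + c_0,\\
y(-2) = -4 - 2c_1 + c_0 = -1,\quad y(2) = -4 + 2c_1 + c_0 = 1 \;&\Rightarrow\; c_1 = \tfrac{1}{2},\; c_0 = 4,\\
y(x) = -x^2 + \tfrac{x}{2} + 4,\qquad \max_{[-2,2]} y = y\!\left(\tfrac{1}{4}\right) = \tfrac{65}{16},\quad \min_{[-2,2]} y &= y(-2) = -1.
\end{aligned}
```

So on the domain the solution takes values between -1 and 65/16 = 4.0625.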
Let's modify the previous code slightly:
```python
import torch
import torch.nn as nn
from torch import optim
import matplotlib.pyplot as plt
import numpy as np

# Define the NN model to solve the equation
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.lin1 = nn.Linear(1, 10)
        self.lin2 = nn.Linear(10, 1)

    def forward(self, x):
        y = torch.sigmoid(self.lin1(x))
        z = torch.sigmoid(self.lin2(y))
        return z

model = Model()

# Define loss_function from the Ordinary differential equation to solve
def ODE(x, fun):
    y = fun(x)
    # First derivative dy/dx, kept in the graph so it can be differentiated again
    dydx = torch.autograd.grad(y, x,
                               grad_outputs=y.data.new(y.shape).fill_(1),
                               create_graph=True, retain_graph=True)[0]
    # Second derivative d2y/dx2
    d2ydx2 = torch.autograd.grad(dydx, x,
                                 grad_outputs=dydx.data.new(dydx.shape).fill_(1),
                                 create_graph=True, retain_graph=True)[0]
    eq = d2ydx2 + torch.tensor([2.])                       # y'' = - 2
    bc1 = fun(torch.tensor([-2.])) - torch.tensor([-1.])   # y(x=-2) = -1
    bc2 = fun(torch.tensor([ 2.])) - torch.tensor([ 1.])   # y(x= 2) = 1
    return torch.mean(eq**2) + bc1**2 + bc2**2

loss_func = ODE
# The training loop now calls loss_func(x_data, model) instead of loss_func(x_data, y_trial)
```
So here I passed the model as an argument and managed to differentiate twice... so far so good. BUT, for this case the sigmoid function is not only unnecessary, it also gives a result that is far from the analytical one.
If I change the NN to:
```python
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.lin1 = nn.Linear(1, 1)
        self.lin2 = nn.Linear(1, 1)

    def forward(self, x):
        y = self.lin1(x)
        z = self.lin2(y)
        return z
```
In this case I would expect the optimization of a double pass through two linear functions to retrieve a second-order function... but instead I get the error:
RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.
Adding allow_unused=True to the definition of dydx doesn't solve the problem, and adding it to d2ydx2 makes d2ydx2 None.
Is there something wrong with the layers as they are?
Quick Solution:

Add allow_unused=True to the .grad functions. So, change
```python
dydx = torch.autograd.grad(
    y, x,
    grad_outputs=y.data.new(y.shape).fill_(1),
    create_graph=True, retain_graph=True)[0]
d2ydx2 = torch.autograd.grad(dydx, x, grad_outputs=dydx.data.new(
    dydx.shape).fill_(1), create_graph=True, retain_graph=True)[0]
```
To
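```python
dydx = torch.autograd.grad(
    y, x,
    grad_outputs=y.data.new(y.shape).fill_(1),
    create_graph=True, retain_graph=True, allow_unused=True)[0]
d2ydx2 = torch.autograd.grad(dydx, x, grad_outputs=dydx.data.new(
    dydx.shape).fill_(1), create_graph=True, retain_graph=True, allow_unused=True)[0]
# With allow_unused=True, d2ydx2 is None when dydx does not depend on x
# (purely linear model); treat that as a zero second derivative so that
# `eq = d2ydx2 + torch.tensor([2.])` does not fail on None
if d2ydx2 is None:
    d2ydx2 = torch.zeros_like(x)
```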
More explanation:
See what allow_unused does (from the torch.autograd.grad docstring):
allow_unused (bool, optional): If ``False``, specifying inputs that were not
used when computing outputs (and therefore their grad is always zero)
is an error. Defaults to ``False``.
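A minimal illustration of that documented behavior, using two hypothetical tensors `x` and `w` where only `w` enters the computation. The output is rebuilt by a small function for each call, because the first (failing) call frees the graph:

```python
import torch

x = torch.tensor([2.0], requires_grad=True)
w = torch.tensor([3.0], requires_grad=True)

def f():
    return (w ** 2).sum()  # uses w only; x plays no part in the computation

try:
    torch.autograd.grad(f(), [x, w])  # allow_unused=False (default)
except RuntimeError as err:
    print("RuntimeError:", err)

gx, gw = torch.autograd.grad(f(), [x, w], allow_unused=True)
print(gx, gw)  # None tensor([6.])
```

Note that the unused input comes back as None rather than as a zero tensor, which is why d2ydx2 has to be checked for None before it is used in the loss.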
So, if you try to differentiate w.r.t. a variable that is not being used to compute the value, you get an error. Also, note that the error only occurs when the network consists of linear layers only.

This is because with linear layers alone you have y = W2*(W1*x + b1) + b2 = W*x + b, with W = W2*W1 and b = W2*b1 + b2, so dy/dx is not a function of x: it is simply W. So when you try to differentiate dy/dx w.r.t. x, autograd throws an error. This error goes away as soon as you use a sigmoid, because then dy/dx is a function of x. To avoid the error, either make sure dy/dx is a function of x, or use allow_unused=True (and treat the returned None as a zero second derivative).
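To illustrate the explanation above, a sketch comparing two architectures. The first half checks that the question's purely linear network has dy/dx = W2 @ W1 at every x, so the second derivative is unused (None). The second half swaps in a model with a sigmoid hidden layer and a linear output layer (`SigmoidHiddenModel` is a hypothetical name, an assumed alternative rather than the question's model): there dy/dx depends on x, so y'' is defined, and the output is no longer squashed into (0, 1), which matters here because the analytical solution of y'' = -2 ranges from -1 to 65/16 on [-2, 2] while a sigmoid output layer can only produce values in (0, 1):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.linspace(-2.0, 2.0, 5, requires_grad=True).view(5, 1)

# The purely linear network from the question: Linear(1,1) -> Linear(1,1)
linear_net = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1))
W1, W2 = linear_net[0].weight, linear_net[1].weight

y_lin = linear_net(x)
dydx_lin, = torch.autograd.grad(y_lin, x, torch.ones_like(y_lin), create_graph=True)
# Same value W2 @ W1 at every x: the two layers collapse into one affine map
print(torch.allclose(dydx_lin, (W2 @ W1).expand_as(dydx_lin)))  # True
d2_lin, = torch.autograd.grad(dydx_lin, x, torch.ones_like(dydx_lin), allow_unused=True)
print(d2_lin)  # None: dy/dx does not depend on x

# Sigmoid hidden layer, linear output (assumed alternative architecture)
class SigmoidHiddenModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(1, 10)
        self.lin2 = nn.Linear(10, 1)

    def forward(self, x):
        # No sigmoid on the output, so values outside (0, 1) are reachable
        return self.lin2(torch.sigmoid(self.lin1(x)))

model = SigmoidHiddenModel()
y_sig = model(x)
dydx_sig, = torch.autograd.grad(y_sig, x, torch.ones_like(y_sig), create_graph=True)
d2_sig, = torch.autograd.grad(dydx_sig, x, torch.ones_like(dydx_sig), create_graph=True)
print(d2_sig.shape)  # torch.Size([5, 1])
```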