Tags: pytorch, differential-equations, autograd, computation-graph

PyTorch "Double backward error" occurs when using the package DeepXDE with a trainable variable included in the initial condition


I would appreciate any ideas on how to debug the following PyTorch error. Here is an example adapted from the official DeepXDE example on inverse modeling of the reaction-diffusion system:

"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch, paddle"""
import deepxde as dde
import numpy as np
import torch

def gen_traindata():
    data = np.load("reaction.npz")
    t, x, ca, cb = data["t"], data["x"], data["Ca"], data["Cb"]
    X, T = np.meshgrid(x, t)
    X = np.reshape(X, (-1, 1))
    T = np.reshape(T, (-1, 1))
    Ca = np.reshape(ca, (-1, 1))
    Cb = np.reshape(cb, (-1, 1))
    return np.hstack((X, T)), Ca, Cb


kf = dde.Variable(0.05)
D = dde.Variable(1.0)

def pde(x, y):
    ca, cb = y[:, 0:1], y[:, 1:2]
    dca_t = dde.grad.jacobian(y, x, i=0, j=1)
    dca_xx = dde.grad.hessian(y, x, component=0, i=0, j=0)
    dcb_t = dde.grad.jacobian(y, x, i=1, j=1)
    dcb_xx = dde.grad.hessian(y, x, component=1, i=0, j=0)
    eq_a = dca_t - 1e-3 * D * dca_xx + kf * ca * cb ** 2
    eq_b = dcb_t - 1e-3 * D * dcb_xx + 2 * kf * ca * cb ** 2
    return [eq_a, eq_b]


def fun_bc(x):
    return 1 - x[:, 0:1]

###### main changes start from here ##########

# def fun_init(x):
#     return np.exp(-20 * x[:, 0:1])

x0 = dde.Variable(0.2)
def fun_init(x):
    return torch.exp(-10 * (dde.backend.as_tensor(x[:, 0:1]) - x0)**2)

###### main changes end here ##########


geom = dde.geometry.Interval(0, 1)
timedomain = dde.geometry.TimeDomain(0, 10)
geomtime = dde.geometry.GeometryXTime(geom, timedomain)

bc_a = dde.icbc.DirichletBC(
    geomtime, fun_bc, lambda _, on_boundary: on_boundary, component=0
)
bc_b = dde.icbc.DirichletBC(
    geomtime, fun_bc, lambda _, on_boundary: on_boundary, component=1
)
ic1 = dde.icbc.IC(geomtime, fun_init, lambda _,
                  on_initial: on_initial, component=0)
ic2 = dde.icbc.IC(geomtime, fun_init, lambda _,
                  on_initial: on_initial, component=1)

observe_x, Ca, Cb = gen_traindata()

# simplification
print(observe_x.shape, Ca.shape, Cb.shape)
observe_x = observe_x[::20, :]
Ca = Ca[::20, :]
Cb = Cb[::20, :]
#--

observe_y1 = dde.icbc.PointSetBC(observe_x, Ca, component=0)
observe_y2 = dde.icbc.PointSetBC(observe_x, Cb, component=1)

data = dde.data.TimePDE(
    geomtime,
    pde,
    [bc_a, bc_b, ic1, ic2, observe_y1, observe_y2],
    num_domain=1000,
    num_boundary=100,
    num_initial=100,
    # anchors=observe_x,
    num_test=1000,
)
net = dde.nn.FNN([2] + [20] * 2 + [2], "tanh", "Glorot uniform")

model = dde.Model(data, net)
model.compile("adam", lr=0.001, external_trainable_variables=[kf, D, x0])
variable = dde.callbacks.VariableValue([kf, D, x0], period=100)
losshistory, train_state = model.train(iterations=200, callbacks=[variable])
# dde.saveplot(losshistory, train_state, issave=True, isplot=True)

I ran this example on an Ubuntu 22.04 machine with Python 3.10, DeepXDE 1.11.0, and torch 2.2.2+cu121. The error message I got is:

Traceback (most recent call last):
  File "~/.../test-unknowninit.py", line 87, in <module>
    losshistory, train_state = model.train(iterations=200, callbacks=[variable])
  File "~/miniforge3/envs/de/lib/python3.10/site-packages/deepxde/utils/internal.py", line 22, in wrapper
    result = f(*args, **kwargs)
  File "~/miniforge3/envs/de/lib/python3.10/site-packages/deepxde/model.py", line 650, in train
    self._train_sgd(iterations, display_every)
  File "~/miniforge3/envs/de/lib/python3.10/site-packages/deepxde/model.py", line 668, in _train_sgd
    self._train_step(
  File "~/miniforge3/envs/de/lib/python3.10/site-packages/deepxde/model.py", line 562, in _train_step
    self.train_step(inputs, targets, auxiliary_vars)
  File "~/miniforge3/envs/de/lib/python3.10/site-packages/deepxde/model.py", line 362, in train_step
    self.opt.step(closure)
  File "~/miniforge3/envs/de/lib/python3.10/site-packages/torch/optim/optimizer.py", line 385, in wrapper
    out = func(*args, **kwargs)
  File "~/miniforge3/envs/de/lib/python3.10/site-packages/torch/optim/optimizer.py", line 76, in _use_grad
    ret = func(self, *args, **kwargs)
  File "~/miniforge3/envs/de/lib/python3.10/site-packages/torch/optim/adam.py", line 146, in step
    loss = closure()
  File "~/miniforge3/envs/de/lib/python3.10/site-packages/deepxde/model.py", line 359, in closure
    total_loss.backward()
  File "~/miniforge3/envs/de/lib/python3.10/site-packages/torch/_tensor.py", line 522, in backward
    torch.autograd.backward(
  File "~/miniforge3/envs/de/lib/python3.10/site-packages/torch/autograd/__init__.py", line 266, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

Specifically, the error occurs at total_loss.backward() during the second iteration, inside train_step (in deepxde/model.py).
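To see why the error only appears in the second iteration, here is a minimal pure-PyTorch sketch, independent of DeepXDE, that caches a tensor computed from a trainable parameter and reuses it across two optimizer steps. The dictionary `cache` and the function `fun_init_cached` are hypothetical stand-ins for the kind of caching described in the Solution below, not DeepXDE's actual implementation; `x0` is assumed to behave like `dde.Variable(0.2)` under the PyTorch backend.

```python
import torch

# Assumed stand-in for dde.Variable(0.2): a scalar tensor that requires grad.
x0 = torch.tensor(0.2, requires_grad=True)
x = torch.linspace(0, 1, 5).reshape(-1, 1)
opt = torch.optim.Adam([x0], lr=1e-3)

cache = {}  # hypothetical cache keyed on the identity of the input

def fun_init_cached(x):
    # The result is computed once and then reused, graph and all.
    key = id(x)
    if key not in cache:
        cache[key] = torch.exp(-10 * (x - x0) ** 2)
    return cache[key]

def train_step():
    opt.zero_grad()
    loss = fun_init_cached(x).sum()
    loss.backward()  # frees the saved tensors of the graph behind the cached value
    opt.step()

second_step_error = None
train_step()  # first iteration succeeds
try:
    train_step()  # second iteration backpropagates through the freed graph again
except RuntimeError as err:
    second_step_error = err
print(second_step_error)
```

The first step works because the cached tensor's graph is still intact; the second step fails with the same "backward through the graph a second time" message as in the traceback above.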

Some details about the code: the main difference from the official example is that I include a trainable variable x0 in fun_init, which is used to create the initial conditions ic1 and ic2. Specifically, the function is changed from

def fun_init(x):
    return np.exp(-20 * x[:, 0:1])

to

def fun_init(x):
    return torch.exp(-10 * (dde.backend.as_tensor(x[:, 0:1]) - x0)**2)

which assumes that the initial distribution is centered at an unknown position x0. Since the input of fun_init is a NumPy array, I replaced np with torch and wrapped the input in dde.backend.as_tensor so that the computation is compatible with the torch.Tensor variable x0.
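One way to check that the modified fun_init is differentiable with respect to x0 is to call it once outside DeepXDE and backpropagate. This sketch assumes that `dde.Variable(0.2)` behaves like a scalar torch tensor with `requires_grad=True` under the PyTorch backend, and uses `torch.as_tensor` in place of `dde.backend.as_tensor`; the sample points are made up for illustration.

```python
import numpy as np
import torch

# Assumed stand-in for dde.Variable(0.2) under the PyTorch backend.
x0 = torch.tensor(0.2, requires_grad=True)

def fun_init(x):
    # x arrives as a NumPy array; convert the spatial column to a float32 tensor.
    xt = torch.as_tensor(x[:, 0:1], dtype=torch.float32)
    return torch.exp(-10 * (xt - x0) ** 2)

x = np.linspace(0, 1, 11).reshape(-1, 1)  # hypothetical sample points
u0 = fun_init(x)
u0.sum().backward()  # a single backward pass through a freshly built graph
print(u0.shape, x0.grad)
```

A single call works and produces a gradient for x0, which suggests the function itself is fine; the failure comes from how its output is reused across iterations.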


Solution

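  • The problem stems from DeepXDE reusing cached initial/boundary condition values instead of recomputing them at every iteration. Because the cached value returned by fun_init was computed from x0, it still refers to the autograd graph from the call that first computed it; the first total_loss.backward() frees that graph's saved tensors, so the next backward pass through it fails. A detailed explanation has been posted in the DeepXDE discussions.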
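The general principle behind a fix is to rebuild the x0-dependent part of the initial condition on every iteration and to cache only values that carry no autograd graph. The following pure-PyTorch sketch illustrates that principle; the target data centered at 0.5, the learning rate, and the three-step loop are hypothetical, and it does not show how to apply the change inside DeepXDE itself (see the DeepXDE discussion mentioned above for that).

```python
import numpy as np
import torch

x0 = torch.tensor(0.2, requires_grad=True)  # assumed stand-in for dde.Variable(0.2)
opt = torch.optim.Adam([x0], lr=1e-2)

x_np = np.linspace(0, 1, 11).reshape(-1, 1)  # hypothetical initial-condition points
target = np.exp(-10 * (x_np - 0.5) ** 2)      # hypothetical data centered at 0.5

# Safe to cache: these tensors do not depend on x0 and carry no autograd graph.
x_t = torch.as_tensor(x_np, dtype=torch.float32)
target_t = torch.as_tensor(target, dtype=torch.float32)

def fun_init(x):
    # Rebuilt on every call, so each iteration gets a fresh graph through x0.
    return torch.exp(-10 * (x - x0) ** 2)

losses = []
for _ in range(3):
    opt.zero_grad()
    loss = ((fun_init(x_t) - target_t) ** 2).mean()
    loss.backward()  # only frees this iteration's graph
    opt.step()
    losses.append(loss.item())
print(losses, x0.item())
```

Every iteration now backpropagates through a newly built graph, so training proceeds past the second step and x0 is updated.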