Search code examples
tensorflow2.0, ode, tensorflow-probability

Solve ODE in tensorflow with tensor inputs


I am trying to solve many instances of the same ODE across different constants.

Here is my code:

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

class SimpleODEModule(tf.Module):
    """Wrap a BDF solve of ``ode_system`` in a ``tf.Module``.

    Calling the module integrates ``ode_system`` with
    ``tfp.math.ode.BDF`` and returns the solver's ``states``.
    """

    def __init__(self, name=None):
        super().__init__(name=name)

    def __call__(self, t_initial, x_initial, solution_times, parameters):
        """Solve the ODE and return the states at ``solution_times``.

        ``parameters`` is watched by a gradient tape and handed to
        ``ode_system`` through the solver's ``constants`` keyword.
        """
        with tf.GradientTape() as tape:
            tape.watch(parameters)
            solver = tfp.math.ode.BDF()
            solution = solver.solve(
                self.ode_system,
                t_initial,
                x_initial,
                solution_times,
                constants={'parameters': parameters},
            )
            # The gradient is requested while the tape is still open and
            # its value is not stored or returned.
            tape.gradient(solution.states, parameters)
        return solution.states

    def ode_system(self, t, x, parameters):
        """Return dx/dt = b * exp(a * t) + a * x.

        ``a`` and ``b`` are read from columns 0 and 1 of ``parameters``.
        The computed derivative is also printed.
        """
        a, b = parameters[:, 0], parameters[:, 1]
        forcing = tf.multiply(b, tf.exp(tf.multiply(a, t)))
        growth = tf.multiply(a, x)
        dx = tf.add(forcing, growth)
        print(dx)
        return dx

# One row of (a, b) constants per ODE instance to solve.
constants = tf.constant([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], dtype=tf.float32)
# Initial time and initial state: zeros of shape (1, number of rows).
t_initial = tf.zeros((1, constants.shape[0]), dtype=tf.float32)
x_initial = tf.zeros((1, constants.shape[0]), dtype=tf.float32)
# One solution time of 1.0 per row of constants.
solution_times = tf.ones((constants.shape[0],), dtype=tf.float32)

simple_ode = SimpleODEModule()

# This call raises an error from deep inside tfp.math.ode:
# The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
simple_ode(t_initial, x_initial, solution_times, constants)

# Calling the right-hand side directly is reported to work; it evaluates
# dx at the initial time and state for each set of constants.
simple_ode.ode_system(t_initial, x_initial, constants)

I am new to tensorflow, so I imagine I am not creating the correctly shaped tensors somewhere. I would expect this to "just work", iterating over the dimensions of the tensors to solve the ODE multiple times for each set of constants. Any help is appreciated.


Solution

  • I found a solution, although I am not sure it is the best one. Instead of subclassing tf.Module, I subclassed tf.keras.layers.Layer and used tf.map_fn to solve the ODE separately for each row of constants, and it "just worked". Here is the change in the code:

    class ODELayer(tf.keras.layers.Layer):
        """Keras layer that runs a BDF solve for each element of its input.

        ``ode_system`` is the right-hand-side callable passed to the solver.
        ``num_outputs`` is stored on the layer but is not read by the
        methods defined here.
        """

        def __init__(self, num_outputs, ode_system):
            super().__init__()
            self.num_outputs = num_outputs
            self.ode_system = ode_system

        def call(self, input_tensor):
            """Map ``solve_ode`` over ``input_tensor`` with ``tf.map_fn``."""
            return tf.map_fn(self.solve_ode, input_tensor)

        def solve_ode(self, parameters):
            """Solve the ODE for one set of ``parameters`` and return the states.

            The solve uses initial time 0.0, initial state 0.0 and the
            single solution time 1.0.
            """
            with tf.GradientTape() as tape:
                tape.watch(parameters)
                solver = tfp.math.ode.BDF()
                solution = solver.solve(
                    self.ode_system,
                    0.0,
                    0.0,
                    [1.0],
                    constants={'parameters': parameters},
                )
                # The gradient is requested inside the tape context and
                # its value is not stored or returned.
                tape.gradient(solution.states, parameters)
            return solution.states
        
    def simple_ode(t, x, parameters):
        """Return dx/dt = b * exp(a * t) + a * x.

        ``a`` is ``parameters[0]`` and ``b`` is ``parameters[1]``.
        """
        a, b = parameters[0], parameters[1]
        forcing = tf.multiply(b, tf.exp(tf.multiply(a, t)))
        return tf.add(forcing, tf.multiply(a, x))
    

    Thanks to anyone who looked at this or attempted a solution.