python, tensorflow, tensorflow-probability

performing many gradient-based optimizations in parallel with TensorFlow


I have a model which requires solving a system of ODEs with tfp.math.ode.BDF, and I would like to find the individual least-squares fits of this model to n > 1000 datasets. That is to say, if my model has m parameters then at the end of the optimization process I will have an n by m tensor of best-fit parameter values.

What would be the best way to perform this optimization in parallel? At this point I'm planning to define an objective function that adds up the n individual sums of squared residuals, and then to use tfp.optimizer.lbfgs_minimize to find the best-fit values of the combined n×m parameters.
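
For concreteness, here is a rough sketch of what I have in mind, with a toy quadratic loss standing in for the ODE-based residuals (the names objective, value_and_grad and targets, and the toy sizes, are just placeholders):

    import tensorflow as tf
    import tensorflow_probability as tfp
    
    n, m = 4, 2                          # toy sizes; in my case n > 1000
    targets = tf.random.normal([n, m])   # placeholder "true" parameters per dataset
    
    def objective(theta_flat):
      # Reshape the flat n*m vector into one parameter row per dataset.
      theta = tf.reshape(theta_flat, [n, m])
      # A toy quadratic stands in for the ODE-based sum of squared residuals.
      per_dataset_loss = tf.reduce_sum((theta - targets) ** 2, axis=-1)
      return tf.reduce_sum(per_dataset_loss)   # combined scalar objective
    
    @tf.function
    def value_and_grad(theta_flat):
      return tfp.math.value_and_gradient(objective, theta_flat)
    
    result = tfp.optimizer.lbfgs_minimize(
        value_and_grad,
        initial_position=tf.zeros(n * m),
        max_iterations=500)
    
    best_fit = tf.reshape(result.position, [n, m])   # the n-by-m tensor of best fits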


Solution

  • This won't be super helpful, and I don't know what ODEs are, but if I were attempting to optimize in parallel ANY set of n datasets against n models (be it linear regression, deep learning, or any other model that's optimized via gradients), I'd simply instantiate n copies of the model, forward-prop each of the n datasets through its corresponding model and, critically, sum up the losses. Automatic differentiation will take care of the rest.

    The n models do not interact with each other, other than through the final summation of the losses, so autodiff will keep the gradients isolated as well. I haven't used L-BFGS in a while, but IIRC it's like SGD except that it also folds in second-order (curvature) information; it builds that approximation from the history of first-order gradients, so it still only needs ordinary gradients of the summed loss.
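
    As a quick sanity check of that claim, here is a toy snippet (plain scalars instead of real models) showing that the gradient of the summed loss with respect to each variable equals the gradient of that variable's own loss:

    import tensorflow as tf
    
    # Two independent "models", each reduced to a single trainable scalar.
    w1 = tf.Variable(1.0)
    w2 = tf.Variable(2.0)
    
    with tf.GradientTape() as tape:
      loss1 = (w1 - 3.0) ** 2   # depends only on w1
      loss2 = (w2 + 1.0) ** 2   # depends only on w2
      total = loss1 + loss2     # the summed loss
    
    g1, g2 = tape.gradient(total, [w1, w2])
    print(g1.numpy(), g2.numpy())  # -4.0 and 6.0, same as d(loss1)/dw1 and d(loss2)/dw2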

    Here is a minimal example of constructing a single TensorFlow model that contains n=2 component models. Each component model contains m=2 weights (namely, the slope m and the intercept b from the y=mx+b equation). You train on n datasets which have a 1-to-1 relationship with the n models. You will get n by m weights out of it.

    import tensorflow as tf
    
    
    def get_model():
      """Create linear regression model of the form y=mx+b."""
      model = tf.keras.Sequential()
      model.add(tf.keras.layers.Dense(1, input_shape=(1,)))
      return model
    
    # Set up two different datasets.
    x1 = tf.range(start=-1., limit=1., delta=0.001)
    y1 = x1 * .1 + .2
    
    x2 = tf.range(start=-1., limit=1., delta=0.001)
    y2 = x2 * .3 + .4
    
    
    def create_parallel_model():
      """Creates a single model that holds multiple independent models.

      Returns:
        parallel_model: The single parallel model holding multiple models.
        model1: First component model.
        model2: Second component model.
      """
      model1 = get_model()
      model2 = get_model()
    
      inp1 = tf.keras.Input((1,), name='inp1')
      inp2 = tf.keras.Input((1,), name='inp2')
    
      outp1 = model1(inp1)
      outp2 = model2(inp2)
    
      parallel_model = tf.keras.Model([inp1, inp2], [outp1, outp2])
    
      return parallel_model, model1, model2
    
    
    pmodel, model1, model2 = create_parallel_model()
    pmodel.compile('sgd', 'mse')
    pmodel.fit(x=[x1, x2], y=[y1, y2], epochs=100)
    
    print("First model's dense layer, m and b should be 0.1 and 0.2.")
    print(model1.layers[0].get_weights())
    
    print("Second model's dense layer, m and b should be 0.3 and 0.4.")
    print(model2.layers[0].get_weights())
    

    Your output should look like this:

    First model's dense layer, m and b should be 0.1 and 0.2.
    [array([[0.10000037]], dtype=float32), array([0.20000002], dtype=float32)]
    Second model's dense layer, m and b should be 0.3 and 0.4.
    [array([[0.30000147]], dtype=float32), array([0.39999998], dtype=float32)]
    

    Another way to do this is to simply train one model and dataset at a time within a single process, but run multiple processes in parallel. If you use TF's dynamic memory allocation, TF won't hog all of the GPU RAM, and NVIDIA GPUs do support multi-processing (https://docs.nvidia.com/deploy/mps/index.html).

    # Enable dynamic memory allocation so TF only grabs GPU memory as it needs it.
    physical_devices = tf.config.list_physical_devices('GPU')
    for gpu_instance in physical_devices:
        tf.config.experimental.set_memory_growth(gpu_instance, True)
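
    Here is a rough sketch of that multi-process route using Python's multiprocessing module (the worker body is a placeholder; how you build the model, load dataset dataset_index and return its m fitted parameters is up to you):

    import multiprocessing as mp
    
    def fit_one(dataset_index):
      # Import TF inside the worker so each process gets its own TF/CUDA context.
      import tensorflow as tf
      for gpu in tf.config.list_physical_devices('GPU'):
        tf.config.experimental.set_memory_growth(gpu, True)
      # ... build the model, load dataset dataset_index, fit it, and
      # return its m best-fit parameters ...
      return dataset_index  # placeholder return value
    
    if __name__ == '__main__':
      # 'spawn' avoids forking a process that has already initialized CUDA.
      ctx = mp.get_context('spawn')
      with ctx.Pool(processes=4) as pool:
        results = pool.map(fit_one, range(8))  # e.g. 8 datasets, 4 at a time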