Search code examples
openmdao

Optimizing Distributed I/O with serial output


I am having trouble understanding how to optimize a distributed component with a serial output. This is my attempt, based on an example problem from the OpenMDAO docs:

import numpy as np

import openmdao.api as om
from openmdao.utils.array_utils import evenly_distrib_idxs
from openmdao.utils.mpi import MPI


class MixedDistrib2(om.ExplicitComponent):
    """Component with distributed and serial inputs and outputs.

    ``out_dist`` is computed element-wise from this rank's slice of
    ``in_dist``.  ``out_serial`` needs the sum of ``sqrt(in_dist)`` over the
    entries held on *every* rank, so that sum is reduced across the
    component's communicator when running on more than one process.
    """

    def setup(self):
        # Distributed Input
        self.add_input('in_dist', shape_by_conn=True, distributed=True)
        # Serial Input
        self.add_input('in_serial', val=1)
        # Distributed Output
        self.add_output('out_dist', copy_shape='in_dist', distributed=True)
        # Serial Output
        self.add_output('out_serial', copy_shape='in_serial')
        # NOTE: a matrix-based partial of the serial output 'out_serial' with
        # respect to the distributed input 'in_dist' is rejected by OpenMDAO
        # ("only supported using the matrix free API").
        #self.declare_partials('*','*', method='cs')

    def compute(self, inputs, outputs):
        x = inputs['in_dist']
        y = inputs['in_serial']
        # "Computationally Intensive" operation that we wish to parallelize.
        f_x = x**2 - 2.0*x + 4.0
        # These operations are repeated on all procs.
        f_y = y ** 0.5
        g_y = y**2 + 3.0*y - 5.0
        # Compute square root of our portion of the distributed input.
        g_x = x ** 0.5
        # Distributed output
        outputs['out_dist'] = f_x + f_y
        # Serial output.  Decide whether to reduce using the same
        # communicator the reduction runs on (self.comm).  The module-level
        # COMM_WORLD differs from it when the component lives in a
        # sub-communicator.
        if MPI and self.comm.size > 1:
            # We need to gather the summed values to compute the total sum over all procs.
            local_sum = np.array(np.sum(g_x))
            total_sum = local_sum.copy()
            self.comm.Allreduce(local_sum, total_sum, op=MPI.SUM)
            outputs['out_serial'] = g_y * total_sum
        else:
            # Recommended to make sure your code can run in serial too, for testing.
            outputs['out_serial'] = g_y * np.sum(g_x)

# Split the `size` entries of the distributed vector across the ranks;
# sizes[rank] is the number of entries owned by this rank.
size = 7
if MPI:
    comm = MPI.COMM_WORLD
    rank = comm.rank
    sizes, offsets = evenly_distrib_idxs(comm.size, size)
else:
    # When running in serial, the entire variable is on rank 0.
    rank = 0
    sizes = {rank : size}
    offsets = {rank : 0}

prob = om.Problem()
model = prob.model

# Create a distributed source for the distributed input.
ivc = om.IndepVarComp()
ivc.add_output('x_dist', np.zeros(sizes[rank]), distributed=True)
ivc.add_output('x_serial', val=1)

model.add_subsystem("indep", ivc)
model.add_subsystem("D1", MixedDistrib2())
model.add_subsystem('con_cmp1', om.ExecComp('con1 = y**2'), promotes=['con1', 'y'])

model.connect('indep.x_dist', 'D1.in_dist')
model.connect('indep.x_serial', ['D1.in_serial','y'])

prob.driver = om.ScipyOptimizeDriver()
prob.driver.options['optimizer'] = 'SLSQP'

model.add_design_var('indep.x_serial', lower=5, upper=10)
model.add_constraint('con1', upper=90)

model.add_objective('D1.out_serial')

prob.setup(force_alloc_complex=True)
#prob.setup()

# Set initial values of distributed variable.  Each rank owns only
# sizes[rank] entries (see the 'x_dist' output above), so the local value
# must have exactly that length.  A hard-coded 7-element list only fits
# when everything runs on a single process.
x_dist_init = np.ones(sizes[rank])
prob.set_val('indep.x_dist', x_dist_init)

# Set initial values of serial variable.
prob.set_val('indep.x_serial', 10)

#prob.run_model()

prob.run_driver()
print('x_dist', prob.get_val('indep.x_dist', get_remote=True))
print('x_serial', prob.get_val('indep.x_serial'))
print('Obj', prob.get_val('D1.out_serial'))

The problem is with defining partials using 'fd' or 'cs': I cannot declare partials of the serial output with respect to the distributed input. So I used prob.setup(force_alloc_complex=True) to use complex step, but that gives me this warning: DerivativesWarning:Constraints or objectives [('D1.out_serial', inds=[0])] cannot be impacted by the design variables of the problem. I understand the warning appears because the total derivative is 0, but I don't understand why it is 0. Clearly the total derivative should not be 0 here. I guess this is because I didn't explicitly call declare_partials in the component. I tried removing the distributed components and running it again with declare_partials, and that works correctly (code below).

import numpy as np

import openmdao.api as om


class MixedDistrib2(om.ExplicitComponent):
    """Serial reference version of the mixed component (no MPI involved).

    Computes ``out_serial = (in_serial**2 + 3*in_serial - 5) * sum(sqrt(in_dist))``
    with every partial derivative approximated by complex step.
    """

    def setup(self):
        # The whole seven-entry vector lives on this one component.
        self.add_input('in_dist', np.zeros(7))
        self.add_input('in_serial', val=1)

        self.add_output('out_serial', val=0)
        # Complex-step approximation for every (output, input) pair.
        self.declare_partials('*', '*', method='cs')

    def compute(self, inputs, outputs):
        vec = inputs['in_dist']
        scal = inputs['in_serial']

        poly = scal**2 + 3.0*scal - 5.0
        root_total = np.sum(vec ** 0.5)
        outputs['out_serial'] = poly * root_total

# Serial reference run: the same objective and constraint as the
# distributed version, but MixedDistrib2 here has no distributed variables
# and declares complex-step partials.
prob = om.Problem()
model = prob.model

# Promote D1's inputs so that 'in_serial' is shared with the constraint
# component.  No IndepVarComp is added; the promoted inputs are set
# directly with set_val below.
model.add_subsystem("D1", MixedDistrib2(), promotes_inputs=['in_dist', 'in_serial'], promotes_outputs=['out_serial'])
model.add_subsystem('con_cmp1', om.ExecComp('con1 = in_serial**2'), promotes=['con1', 'in_serial'])

prob.driver = om.ScipyOptimizeDriver()
prob.driver.options['optimizer'] = 'SLSQP'

# Only the serial input is optimized; in_dist stays at its set value.
model.add_design_var('in_serial', lower=5, upper=10)
model.add_constraint('con1', upper=90)

model.add_objective('out_serial')

prob.setup(force_alloc_complex=True)

# in_dist was declared with 7 entries (np.zeros(7)), so 7 values are given here.
prob.set_val('in_dist', [1,1,1,1,1,1,1])
prob.set_val('in_serial', 10)

# Evaluate once and check the total derivatives before optimizing.
prob.run_model()
prob.check_totals()

prob.run_driver()

print('x_dist', prob.get_val('in_dist', get_remote=True))
print('x_serial', prob.get_val('in_serial'))
print('Obj', prob.get_val('out_serial'))

What I am trying to understand is:

  1. How do I use 'fd' or 'cs' in a distributed component with a serial output?
  2. What does prob.setup(force_alloc_complex=True) mean? Doesn't it force complex step to be used in all the components in the problem? If so, why does the total derivative become 0?

Solution

  • When I run your code in OpenMDAO V3.11.0 (after uncommenting the declare_partials call), I get the following error:

    RuntimeError: 'D1' <class MixedDistrib2>: component has defined partial ('out_serial', 'in_dist') which is a serial output wrt a distributed input. This is only supported using the matrix free API.
    

    As the error indicates, you can't use the matrix-based API for derivatives in this situation. The reasons why are a bit subtle and probably outside the scope of what needs to be dealt with to answer your question here. It boils down to OpenMDAO not knowing what kind of distributed operations are being done in the compute, and having no way to manage those details when you propagate things in reverse.

    So you need to use the matrix-free derivative APIs in this situation. When you use the matrix-free APIs you DO NOT declare any partials, because you don't want OpenMDAO to allocate any memory for you to store partials in (and you wouldn't use that memory even if it did).

    I've coded them for your example here, but I need to note a few important details:

    1. Your example has a distributed IVC, but as of OpenMDAO V3.11.0 you can't get total derivatives with respect to distributed design variables. I assume you just made it that way for your simple test case, but if your real problem is set up this way, take note and don't do it that way. Instead, make the IVC serial and use src_indices to distribute the correct parts to each proc.
    2. In the example below, the derivatives are correct. However, there seems to be a bug in the check_partials output when running in parallel: the reverse-mode partials look like they are off by a factor of the comm size. This will have to be fixed in a later release.
    3. I only did the derivatives for out_serial. out_dist will work similarly and is left as an exercise for the reader :)
    4. You'll notice that I duplicated some code in the compute and compute_jacvec_product methods. You can abstract this duplicate code out into its own method (or call compute from within compute_jacvec_product by providing your own output dictionary). However, you might be asking why the duplicate calculation is needed at all. Why can't you store the values from the compute call? The answer is, in large part, that OpenMDAO does not guarantee that compute is always called before compute_jacvec_product. I'll also point out that this kind of code duplication is very AD-like: any AD code has the same kind of duplication built in, even though you don't see it.
    import numpy as np
    
    import openmdao.api as om
    from openmdao.utils.array_utils import evenly_distrib_idxs
    from openmdao.utils.mpi import MPI
    
    
    class MixedDistrib2(om.ExplicitComponent):
        """Mixed distributed/serial component with matrix-free derivatives.

        Outputs:
          out_dist   -- distributed: x**2 - 2x + 4 + sqrt(y) on this rank's slice.
          out_serial -- serial: (y**2 + 3y - 5) * S, where S is the sum of
                        sqrt(x) over the in_dist entries held on *all* ranks.

        Derivatives are provided through compute_jacvec_product instead of
        declare_partials, because OpenMDAO rejects matrix-based partials of a
        serial output with respect to a distributed input.  Only the
        out_serial derivatives are implemented.
        """

        def setup(self):
            # Distributed Input
            self.add_input('in_dist', shape_by_conn=True, distributed=True)
            # Serial Input
            self.add_input('in_serial', val=1)
            # Distributed Output
            self.add_output('out_dist', copy_shape='in_dist', distributed=True)
            # Serial Output
            self.add_output('out_serial', copy_shape='in_serial')
            # Deliberately no declare_partials call: with the matrix-free API
            # there is no partial-derivative storage to allocate.

        def _sum_over_procs(self, local_value):
            """Return ``local_value`` summed over every rank of ``self.comm``.

            Without MPI, or on a single process, the local value is returned
            unchanged so the component also runs (and can be tested) in serial.
            """
            local_value = np.asarray(local_value)
            if MPI and self.comm.size > 1:
                total = np.zeros_like(local_value)
                self.comm.Allreduce(local_value, total, op=MPI.SUM)
                return total
            return local_value

        def compute(self, inputs, outputs):
            x = inputs['in_dist']
            y = inputs['in_serial']
            # "Computationally Intensive" operation that we wish to parallelize.
            f_x = x**2 - 2.0*x + 4.0
            # These operations are repeated on all procs.
            f_y = y ** 0.5
            g_y = y**2 + 3.0*y - 5.0
            # Compute square root of our portion of the distributed input.
            g_x = x ** 0.5
            # Distributed output: purely local, no communication needed.
            outputs['out_dist'] = f_x + f_y
            # Serial output: needs the sum over all procs, not just this slice.
            outputs['out_serial'] = g_y * self._sum_over_procs(np.sum(g_x))

        def compute_jacvec_product(self, inputs, d_inputs, d_outputs, mode):
            """Apply the out_serial Jacobian ('fwd') or its transpose ('rev').

            With out_serial = g_y(y) * S and S = sum over all ranks of sqrt(x_i):
              d out_serial / d y   = (2y + 3) * S
              d out_serial / d x_i = g_y * 0.5 / sqrt(x_i)
            """
            if 'out_serial' not in d_outputs:
                return

            x = inputs['in_dist']
            y = inputs['in_serial']

            # Recomputed here because OpenMDAO does not guarantee that compute
            # runs before compute_jacvec_product.
            g_y = y**2 + 3.0*y - 5.0
            total_sum = self._sum_over_procs(np.sum(x ** 0.5))

            d_serial_d_y = (2.0*y + 3.0) * total_sum   # same shape as in_serial
            d_serial_d_x = g_y * 0.5 * x**-0.5         # this rank's slice of in_dist

            if mode == 'fwd':
                if 'in_dist' in d_inputs:
                    # Each rank only sees its own slice of the seed, but the
                    # serial output depends on every slice, so sum the local
                    # contributions across ranks.
                    d_outputs['out_serial'] += self._sum_over_procs(
                        np.sum(d_serial_d_x * d_inputs['in_dist']))
                if 'in_serial' in d_inputs:
                    # Same (2y + 3) * S factor as the reverse branch below.
                    d_outputs['out_serial'] += d_serial_d_y * d_inputs['in_serial']
            elif mode == 'rev':
                if 'in_dist' in d_inputs:
                    d_inputs['in_dist'] += d_serial_d_x * d_outputs['out_serial']
                if 'in_serial' in d_inputs:
                    d_inputs['in_serial'] += d_serial_d_y * d_outputs['out_serial']
    
    # Split the `size` entries of the distributed vector across the ranks;
    # sizes[rank] is the number of entries owned by this rank.
    size = 7
    if MPI:
        comm = MPI.COMM_WORLD
        rank = comm.rank
        sizes, offsets = evenly_distrib_idxs(comm.size, size)
    else:
        # When running in serial, the entire variable is on rank 0.
        rank = 0
        sizes = {rank : size}
        offsets = {rank : 0}
    
    prob = om.Problem()
    model = prob.model
    
    # Create a distributed source for the distributed input.
    # x_dist is NOT a design variable below (only x_serial is).
    ivc = om.IndepVarComp()
    ivc.add_output('x_dist', np.zeros(sizes[rank]), distributed=True)
    ivc.add_output('x_serial', val=1)
    
    model.add_subsystem("indep", ivc)
    model.add_subsystem("D1", MixedDistrib2())
    model.add_subsystem('con_cmp1', om.ExecComp('con1 = y**2'), promotes=['con1', 'y'])
    
    # x_serial feeds both D1 and the promoted constraint input 'y'.
    model.connect('indep.x_dist', 'D1.in_dist')
    model.connect('indep.x_serial', ['D1.in_serial','y'])
    
    prob.driver = om.ScipyOptimizeDriver()
    prob.driver.options['optimizer'] = 'SLSQP'
    
    model.add_design_var('indep.x_serial', lower=5, upper=10)
    model.add_constraint('con1', upper=90)
    
    model.add_objective('D1.out_serial')
    
    prob.setup(force_alloc_complex=True)
    #prob.setup()
    
    # Set initial values of distributed variable.
    # Each rank sets only its local sizes[rank] entries.
    x_dist_init = np.ones(sizes[rank])
    prob.set_val('indep.x_dist', x_dist_init)
    
    # Set initial values of serial variable.
    prob.set_val('indep.x_serial', 10)
    
    prob.run_model()
    
    # Check the hand-coded compute_jacvec_product derivatives; the
    # optimization run below is left commented out.
    prob.check_partials()
    
    # prob.run_driver()
    print('x_dist', prob.get_val('indep.x_dist', get_remote=True))
    print('x_serial', prob.get_val('indep.x_serial'))
    print('Obj', prob.get_val('D1.out_serial'))