
How can I initialize an object in C++ using a variable number of fields from a MATLAB Data API struct?


I am building a MATLAB MEX function using the MATLAB C++ Data API. My MEX function takes as input a struct whose fields vary in size, type and name. The exact makeup of this struct can vary and is defined outside the scope of the program, but I know all the possible combinations of fields ahead of time. I will call this parameter Grid; it is passed to a helper function.

In this helper function, I would like to create an instance of a derived class whose type corresponds to the specific combination of fields in Grid. The idea is that I can extract the fields of Grid and use them to construct an instance of the correct derived class. I would like to achieve this without rewriting my code every time I add a new derived class for a different combination of fields. How could I do this? I am also open to alternative approaches and strategies.

For example, Grid might be defined in the MATLAB environment like this:

Grid = struct('XPSF',X(:),'ZPSF',Z(:),'xe',Trans.ElementPos(:,1)*wvlToM,'TXDelay',TXdelay(:,8));

It is then handled by the MEX function and passed to the helper function, which is declared as:

void extractFields(matlab::data::StructArray& Grid);

Currently, Grid can also contain a single value in place of XPSF or ZPSF, and I anticipate adding other fields to Grid in the future. For each of these possible combinations, I have a derived class with some unique overridden functions:

class Beamform {
    public:
    //Constructor    
    Beamform();
    
    virtual ~Beamform() {}
    template <typename T> int sgn(T val) { return (T(0) < val) - (val < T(0)); }
    virtual void calcIDXT(...);
};

class bfmPlaneWave : public Beamform 
{
    public:
    double theta;
    Eigen::VectorXd xe, XPSF, ZPSF, dTX;
    
    template<typename Derived>
    bfmPlaneWave(double& angle, ...);

    // member function templates cannot be virtual, so the override is not a template
    void calcIDXT(...) override;
};

class bfmXRflMTX : public Beamform {
    public:
    double theta, zCoord;
    Eigen::VectorXd xe, XPSFin, XPSFout, dTX;

    template<typename Derived>
    bfmXRflMTX(double& angle, ...);

    void calcIDXT(...) override;
};

class bfmZRflMTX : public Beamform {
    public:
    double theta, xCoord;
    Eigen::VectorXd xe, ZPSFin, ZPSFout, dTX;

    template<typename Derived>
    bfmZRflMTX(double& angle, ...);

    void calcIDXT(...) override;
};

Solution

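  • I would start by declaring a common construction pattern: each derived class gets a static fromGrid function that returns nullptr when Grid is not compatible with it. Something like this: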

    class Beamform
    {
    public:
      virtual void calcIDXT(...) = 0;
      virtual ~Beamform() = default;
    };
    
    class bfmPlaneWave: public Beamform
    {
    public:
      /** Return nullptr if not compatible */
      static bfmPlaneWave* fromGrid(matlab::data::StructArray&);
      virtual void calcIDXT(...) override;
    };
    
    class bfmXRflMTX: public Beamform
    {
    public:
      /** Return nullptr if not compatible */
      static bfmXRflMTX* fromGrid(matlab::data::StructArray&);
      virtual void calcIDXT(...) override;
    };
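To make the fromGrid convention concrete without the MATLAB headers, here is a minimal runnable sketch. It assumes a hypothetical stand-in GridFields (a std::map from field name to std::vector<double>) in place of matlab::data::StructArray, and invented compatibility rules based on the question: vector XPSF and ZPSF select bfmPlaneWave, while a single-value ZPSF selects bfmXRflMTX. The name() method and the numel helper are also hypothetical and only stand in for calcIDXT(...) and a dimension check.

```cpp
#include <cstddef>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for matlab::data::StructArray: field name -> values.
using GridFields = std::map<std::string, std::vector<double>>;

// Number of elements stored in a field, or 0 if the field is absent.
inline std::size_t numel(const GridFields& grid, const std::string& field)
{
  auto it = grid.find(field);
  return it == grid.end() ? 0 : it->second.size();
}

class Beamform
{
public:
  virtual ~Beamform() = default;
  virtual const char* name() const = 0;  // stands in for calcIDXT(...)
};

class bfmPlaneWave : public Beamform
{
public:
  /** Return nullptr if not compatible (assumed rule: XPSF and ZPSF are vectors) */
  static bfmPlaneWave* fromGrid(const GridFields& grid)
  {
    if (numel(grid, "XPSF") > 1 && numel(grid, "ZPSF") > 1 && numel(grid, "xe") > 0)
      return new bfmPlaneWave;
    return nullptr;
  }
  const char* name() const override { return "bfmPlaneWave"; }
};

class bfmXRflMTX : public Beamform
{
public:
  /** Return nullptr if not compatible (assumed rule: ZPSF is a single value) */
  static bfmXRflMTX* fromGrid(const GridFields& grid)
  {
    if (numel(grid, "XPSF") > 1 && numel(grid, "ZPSF") == 1 && numel(grid, "xe") > 0)
      return new bfmXRflMTX;
    return nullptr;
  }
  const char* name() const override { return "bfmXRflMTX"; }
};
```

With the real Data API, the same checks would presumably look at the struct's field names (StructArray::getFieldNames()) and at each field's number of elements instead of querying a map.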
    

    Then you could have a simple, central factory function that you extend as required:

    /**
     * Option 1: Use a central dispatch function which can be updated with a
     * simple two-liner
     */
    std::unique_ptr<Beamform> beamformForGrid(matlab::data::StructArray& Grid)
    {
      if(Beamform* beamform = bfmPlaneWave::fromGrid(Grid))
        return std::unique_ptr<Beamform>(beamform);
      if(Beamform* beamform = bfmXRflMTX::fromGrid(Grid))
        return std::unique_ptr<Beamform>(beamform);
      // insert more here
      return nullptr;
    }
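If even a two-liner per class feels like too much editing, one variation of Option 1 is to let a C++17 fold expression generate the if-chain from a list of types, so adding a beamformer means appending one type to that list. This is a sketch under assumptions: GridFields is a hypothetical stand-in for matlab::data::StructArray, the compatibility rules are toy rules, and firstCompatible is a name made up for this example.

```cpp
#include <map>
#include <memory>
#include <string>
#include <vector>

// Hypothetical stand-in for matlab::data::StructArray: field name -> values.
using GridFields = std::map<std::string, std::vector<double>>;

struct Beamform
{
  virtual ~Beamform() = default;
  virtual std::string name() const = 0;
};

// Toy implementations following the fromGrid convention (nullptr = not compatible).
struct bfmPlaneWave : Beamform
{
  static bfmPlaneWave* fromGrid(const GridFields& g)
  { return g.count("XPSF") && g.count("ZPSF") ? new bfmPlaneWave : nullptr; }
  std::string name() const override { return "bfmPlaneWave"; }
};

struct bfmXRflMTX : Beamform
{
  static bfmXRflMTX* fromGrid(const GridFields& g)
  { return g.count("XPSF") ? new bfmXRflMTX : nullptr; }
  std::string name() const override { return "bfmXRflMTX"; }
};

// Tries each type left to right. The unary fold over || short-circuits,
// so evaluation stops at the first fromGrid that returns a non-null pointer.
template <class... Impls>
std::unique_ptr<Beamform> firstCompatible(const GridFields& g)
{
  std::unique_ptr<Beamform> result;
  (void)((result.reset(Impls::fromGrid(g)), result != nullptr) || ...);
  return result;
}

// The only line to change when adding a beamformer: extend the type list.
inline std::unique_ptr<Beamform> beamformForGrid(const GridFields& g)
{
  return firstCompatible<bfmPlaneWave, bfmXRflMTX>(g);
}
```

Order matters here: types are tried from left to right, so the more specific implementation (bfmPlaneWave, which needs both XPSF and ZPSF) must come before the looser one.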
    

    However, if I understand you correctly, this is exactly what you want to avoid. In that case you could use a central registry and global constructors. Global constructors (the constructors of global variables) run when the DLL or shared library containing them is loaded, and a MEX file is such a shared library. This is similar to how, for example, CppUnit registers its unit tests.

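    #include <algorithm>  // std::find
    #include <memory>     // std::unique_ptr
    #include <mutex>      // std::mutex, std::lock_guard
    #include <vector>
    
    class AbstractBeamformFactory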
    {
    public: 
      virtual ~AbstractBeamformFactory() = default;
      /** Return nullptr if not compatible */
      virtual Beamform* fromGrid(matlab::data::StructArray&) = 0;
    };
    /**
     * Registers existing factories
     *
     * Follows the singleton pattern.
     * Yes, it is frowned upon, but if it works, it works.
     */
    class BeamformRegistry
    {
      /**
       * Protects the list of factories
       *
       * A bit overkill seeing how neither Matlab nor global constructors are
       * particularly multithreaded, but better safe than sorry
       */
      mutable std::mutex mutex;
      std::vector<AbstractBeamformFactory*> factories;
    public:
      /**
       * Retrieves singleton instance
       * 
       * This isn't a global variable because we need to use it in other
       * global constructors and we can't force a specific order between
       * global constructors
       */
      static BeamformRegistry& globalInstance();
    
      void add(AbstractBeamformFactory* factory)
      {
        std::lock_guard<std::mutex> lock(mutex);
        factories.push_back(factory);
      }
      void remove(AbstractBeamformFactory* factory)
      {
        std::lock_guard<std::mutex> lock(mutex);
        auto it = std::find(factories.begin(), factories.end(), factory);
        if(it != factories.end())  // erase(end()) would be undefined behavior
          factories.erase(it);
      }
      std::unique_ptr<Beamform> beamformForGrid(matlab::data::StructArray& Grid) const
      {
        std::lock_guard<std::mutex> lock(mutex);
        for(AbstractBeamformFactory* factory: factories)
          if(Beamform* beamform = factory->fromGrid(Grid))
            return std::unique_ptr<Beamform>(beamform);
        return nullptr;
      }
    };
    /**
     * Implements AbstractBeamformFactory for a specific type of beamformer
     *
     * Create a global variable of this type in order to add it to the global
     * BeamformRegistry
     */
    template<class BeamformImplementation>
    class BeamformFactory: public AbstractBeamformFactory
    {
      bool registered;
    public:
      explicit BeamformFactory(bool registerGlobal=true)
        : registered(registerGlobal)
      {
        /* don't move this to the base class to avoid issues around
         * half-initialized objects in the registry
         */
        if(registerGlobal)
          BeamformRegistry::globalInstance().add(this);
      }
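      // a copy would try to unregister a pointer that was never added
      BeamformFactory(const BeamformFactory&) = delete;
      BeamformFactory& operator=(const BeamformFactory&) = delete;
      virtual ~BeamformFactory()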
      {
        if(registered)
          BeamformRegistry::globalInstance().remove(this);
      }
      virtual Beamform* fromGrid(matlab::data::StructArray& Grid) override
      { return BeamformImplementation::fromGrid(Grid); }
    };
    
    /* in CPP files */
    BeamformRegistry& BeamformRegistry::globalInstance()
    {
      static BeamformRegistry instance;
      return instance;
    }
    /*
     * Make global variables to add entries to registry.
     * These can be scattered across different cpp files
     */
    BeamformFactory<bfmPlaneWave> planeWaveFactoryInstance;
    BeamformFactory<bfmXRflMTX> XRflMTXFactoryInstance;
    

    Now you can simply call BeamformRegistry::globalInstance().beamformForGrid(Grid) to try all registered beamform implementations. To add a new implementation, you just define another factory instance in any of your cpp files.
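Below is a condensed, runnable single-file version of the registry, so the self-registration can be checked without MATLAB. Assumptions: GridFields is a hypothetical stand-in for matlab::data::StructArray, the vector-vs-single-value ZPSF rules are invented for illustration, the registerGlobal flag is dropped for brevity, and numel is a hypothetical helper.

```cpp
#include <algorithm>
#include <cstddef>
#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <vector>

// Hypothetical stand-in for matlab::data::StructArray: field name -> values.
using GridFields = std::map<std::string, std::vector<double>>;

inline std::size_t numel(const GridFields& g, const std::string& field)
{ auto it = g.find(field); return it == g.end() ? 0 : it->second.size(); }

struct Beamform
{
  virtual ~Beamform() = default;
  virtual std::string name() const = 0;
};

struct AbstractBeamformFactory
{
  virtual ~AbstractBeamformFactory() = default;
  virtual Beamform* fromGrid(const GridFields&) = 0;  // nullptr if not compatible
};

class BeamformRegistry
{
  mutable std::mutex mutex;
  std::vector<AbstractBeamformFactory*> factories;
public:
  // Constructed on first use, so other global constructors can call it safely.
  static BeamformRegistry& globalInstance()
  {
    static BeamformRegistry instance;
    return instance;
  }
  void add(AbstractBeamformFactory* f)
  {
    std::lock_guard<std::mutex> lock(mutex);
    factories.push_back(f);
  }
  void remove(AbstractBeamformFactory* f)
  {
    std::lock_guard<std::mutex> lock(mutex);
    auto it = std::find(factories.begin(), factories.end(), f);
    if (it != factories.end())
      factories.erase(it);
  }
  // Asks each registered factory in turn; the first non-null result wins.
  std::unique_ptr<Beamform> beamformForGrid(const GridFields& g) const
  {
    std::lock_guard<std::mutex> lock(mutex);
    for (AbstractBeamformFactory* f : factories)
      if (Beamform* b = f->fromGrid(g))
        return std::unique_ptr<Beamform>(b);
    return nullptr;
  }
};

// Registers itself on construction and unregisters on destruction.
template <class Impl>
struct BeamformFactory : AbstractBeamformFactory
{
  BeamformFactory() { BeamformRegistry::globalInstance().add(this); }
  ~BeamformFactory() override { BeamformRegistry::globalInstance().remove(this); }
  BeamformFactory(const BeamformFactory&) = delete;
  BeamformFactory& operator=(const BeamformFactory&) = delete;
  Beamform* fromGrid(const GridFields& g) override { return Impl::fromGrid(g); }
};

// Toy implementations (assumed rules: vector ZPSF vs. single-value ZPSF).
struct bfmPlaneWave : Beamform
{
  static bfmPlaneWave* fromGrid(const GridFields& g)
  { return numel(g, "ZPSF") > 1 ? new bfmPlaneWave : nullptr; }
  std::string name() const override { return "bfmPlaneWave"; }
};

struct bfmXRflMTX : Beamform
{
  static bfmXRflMTX* fromGrid(const GridFields& g)
  { return numel(g, "ZPSF") == 1 ? new bfmXRflMTX : nullptr; }
  std::string name() const override { return "bfmXRflMTX"; }
};

// Global factory instances; in a real project these could live in separate .cpp files.
BeamformFactory<bfmPlaneWave> planeWaveFactoryInstance;
BeamformFactory<bfmXRflMTX> XRflMTXFactoryInstance;
```

Inside extractFields, the lookup would then be a single call such as `auto beamformer = BeamformRegistry::globalInstance().beamformForGrid(Grid);`, followed by a nullptr check for the case where no registered class matches Grid.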

    One thing I'm unsure about is how this interacts with MEX: when does MATLAB load its extensions? If this only happens lazily, the global constructors may not run early enough. It is worth checking with a few print statements.
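One way to run that check is sketched below as an ordinary program: a global object records that its constructor ran, and the entry point later checks that record. Recording instead of printing avoids depending on where output from a global constructor ends up. LoadProbe, eventLog and entryPoint are hypothetical names; in a real MEX file the check would go in the gateway (MexFunction::operator() in the C++ MEX API) and the result would be reported to MATLAB.

```cpp
#include <string>
#include <vector>

// Messages collected while the program runs; constructed on first use.
std::vector<std::string>& eventLog()
{
  static std::vector<std::string> log;
  return log;
}

// A global object: its constructor runs during dynamic initialization, which
// mainstream toolchains perform when the executable or shared library is loaded.
struct LoadProbe
{
  LoadProbe() { eventLog().push_back("global constructor"); }
};
LoadProbe loadProbeInstance;

// Hypothetical stand-in for the MEX entry point.
// Returns true if the global constructor ran before this call.
bool entryPoint()
{
  bool constructorRanFirst =
      !eventLog().empty() && eventLog().front() == "global constructor";
  eventLog().push_back("entry point");
  return constructorRanFirst;
}
```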