Tags: c++, templates, eigen, eigen3

Vector of different types of template classes?


The code relevant to my problem is below. I'm trying to write a neural net using Eigen, and I want to implement its layers with Eigen's tensors, but I'm not sure how. Eigen's Tensor class takes two required template arguments: the scalar type and an int giving the number of dimensions. The only scalar type I'm going to use is double, but each Layer takes an input Tensor and returns an output Tensor, possibly with a different number of dimensions, so the Layer class needs template parameters for those two numbers. However, that prevents me from having a std::vector of Layers. Is there any way around this?

As you can probably tell, the Layer class is abstract because other classes are going to inherit from it (which is why I want to be able to put them in a std::vector). I've looked at Boost's variant class, but I'm not sure I can use it: I don't think I can explicitly list every possible type of Layer I might use before running the program, and I don't know whether templates could generate that list automatically.

#include <unsupported/Eigen/CXX11/Tensor>
#include <vector>

template<int inputDims, int outputDims>
class Layer{
public:
    virtual ~Layer() = default;  // defined, so derived layers can actually be destroyed (and linked)
    virtual Eigen::Tensor<double,outputDims> fire(Eigen::Tensor<double,inputDims>) = 0;
    virtual Eigen::Tensor<double,outputDims> derivative(Eigen::Tensor<double,inputDims>) = 0;
};

std::vector<Layer> v; //Doesn't compile: Layer is a template, not a type

Solution

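  • Something like this? Derive every Layer<inputDims, outputDims> from one common, non-template base class, keep (smart) pointers to that base in the vector, and use dynamic_cast to recover the concrete layer type when you need it: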

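    #include <iostream>
    #include <memory>
    #include <vector>

    class BaseLayer 
    {
    public:
        // Virtual destructor: makes BaseLayer polymorphic (required for dynamic_cast)
        // and lets derived layers be deleted through a BaseLayer pointer.
        virtual ~BaseLayer() = default;
    };
    
    template<int inputDims, int outputDims>
    class Layer : public BaseLayer
    {
    public:
        // Stand-ins for the real tensor operations; they only print the template arguments.
        void fire() { std::cout << inputDims << ":" << outputDims << std::endl; }
        void derivative() { std::cout << inputDims << ":" << outputDims << std::endl; }
    };
    
    
    int main()
    {
        // Each Layer<i, o> is a different type, but all of them are BaseLayers.
        std::vector<std::unique_ptr<BaseLayer>> v;
    
        v.push_back(std::make_unique<Layer<10, 20>>());
        // dynamic_cast yields nullptr if the element is not actually a Layer<10, 20>.
        if (auto *l1 = dynamic_cast<Layer<10, 20> *>(v[0].get()))
            l1->fire();
    
        v.push_back(std::make_unique<Layer<3, 4>>());
        if (auto *l2 = dynamic_cast<Layer<3, 4> *>(v[1].get()))
            l2->derivative();
    
        return 0;  // the unique_ptrs delete both layers here, so nothing leaks
    }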
    

    Prints:

    10:20
    3:4