Search code examples
Tags: c++, templates, constant-folding

Avoid switching on types to allow constant folding


I am trying to design a class hierarchy that lets me implement placeholders for processor registers, together with operations on them. It should also allow constants to be folded at run time. For the sake of simplicity I'll only look at one operation, multiplication. Placeholders and constants should be accessible uniformly, i.e. they should share a common base class.

The code below defines the following classes:

class A: Base class for placeholders (registers) and constants

class B: Placeholder for a register (its structure holds the register's name)

class C: Base of all constants

class CI: int constant

class CF: float constant

#include <iostream>
#include <memory>
#include <cassert>

// Root of the operand hierarchy: both register placeholders and constants
// derive from A, so they can be handled uniformly through Aptr.
// The virtual destructor makes deletion through an A* / shared_ptr<A> safe
// and makes A polymorphic so dynamic_pointer_cast works on it.
class A {
public:
  virtual ~A() = default;
};

// Register placeholder. It currently adds nothing beyond A.
class B : public A
{
};

// Common base of all constant operands. mul() uses a cast to C to tell
// constants apart from other A-derived operands.
class C : public A
{
};

/// Integer constant.
/// Type_t names the value type; the typed mul_const template feeds it to
/// promote<> to pick the result type of a product.
class CI : public C {
public:
  typedef int Type_t;
  /// Returns the constant's value (currently hard-coded to 1).
  /// const: reading a constant's value must not require a mutable object.
  int getValue() const { return 1; }
};

/// Floating-point constant.
/// Type_t names the value type; the typed mul_const template feeds it to
/// promote<> to pick the result type of a product.
class CF : public C {
public:
  typedef float Type_t;
  /// Returns the constant's value (currently hard-coded to 1.1).
  /// Uses a float literal to avoid an implicit double -> float narrowing.
  float getValue() const { return 1.1f; }
};




// Shared-ownership handles for every class in the operand hierarchy.
using Aptr  = std::shared_ptr<A>;
using Bptr  = std::shared_ptr<B>;
using Cptr  = std::shared_ptr<C>;
using CIptr = std::shared_ptr<CI>;
using CFptr = std::shared_ptr<CF>;


// Result type of multiplying a T by a T2. The primary template is left
// empty on purpose, so an unsupported pair has no Type_t and the typed
// mul_const template drops out of overload resolution.
template<class T, class T2> struct promote {};

// Any float operand promotes the product to float; int * int stays int.
template<> struct promote<int,   int  > { using Type_t = int;   };
template<> struct promote<int,   float> { using Type_t = float; };
template<> struct promote<float, int  > { using Type_t = float; };
template<> struct promote<float, float> { using Type_t = float; };


/// Multiplies two typed constants and returns the product in the promoted
/// type, also writing it to std::cout.
/// @tparam T1, T2  smart-pointer types (e.g. CIptr / CFptr) whose
///                 element_type provides Type_t and getValue().
/// @param c1, c2   the constants to multiply; must be non-null.
/// @return c1->getValue() * c2->getValue(), converted to
///         promote<T1::element_type::Type_t, T2::element_type::Type_t>::Type_t.
template<class T1, class T2>
typename promote<typename T1::element_type::Type_t,
         typename T2::element_type::Type_t>::Type_t
mul_const( const T1& c1 , const T2& c2 )
{
  // Evaluate the product once. This calls getValue() once per operand,
  // and the printed value is exactly the one that is returned.
  const auto product = c1->getValue() * c2->getValue();
  std::cout << product << '\n';
  return product;
}



/// Checked downcast of an operand handle.
/// @return a shared_ptr<Target> that shares ownership with pA. It is empty
///         if pA is empty or the object is not a Target.
template<typename Target>
std::shared_ptr<Target> get(const Aptr& pA)
{
  std::shared_ptr<Target> casted = std::dynamic_pointer_cast<Target>(pA);
  return casted;
}


// Wrap a folded constant back into an operand handle.
// NOTE(review): the value argument is currently ignored and a bare A is
// returned, so the folded result is not preserved. TODO: store the value.
Aptr create_A(float /*f*/) { Aptr wrapped = std::make_shared<A>(); return wrapped; }
Aptr create_A(int /*i*/)   { Aptr wrapped = std::make_shared<A>(); return wrapped; }


/// Folds the product of two constant operands of unknown concrete type.
/// Each operand is downcast once to CI or CF. The typed mul_const template
/// computes the promoted product, and create_A wraps the result.
/// @pre cp1 and cp2 are non-null and each points to a CI or a CF.
/// @return the wrapped product. If the precondition is violated, the assert
///         fires in debug builds and nullptr is returned with NDEBUG.
Aptr mul_const( const Cptr& cp1 , const Cptr& cp2 )
{
  // Downcast each operand once instead of re-casting in every branch.
  const CIptr i1 = get<CI>(cp1);
  const CFptr f1 = get<CF>(cp1);
  const CIptr i2 = get<CI>(cp2);
  const CFptr f2 = get<CF>(cp2);

  // Exact-match template overloads of mul_const are chosen here, not this
  // function, because the arguments are CIptr / CFptr rather than Cptr.
  if (i1 && f2) return create_A( mul_const(i1, f2) );
  if (f1 && i2) return create_A( mul_const(f1, i2) );
  if (i1 && i2) return create_A( mul_const(i1, i2) );
  if (f1 && f2) return create_A( mul_const(f1, f2) );

  assert(!"oops");
  // With NDEBUG the assert disappears. Flowing off the end of a non-void
  // function is undefined behaviour, so return an explicit empty result.
  return nullptr;
}



/// Multiplies two operands.
/// Only constant * constant (constant folding) is implemented so far; it is
/// delegated to mul_const(const Cptr&, const Cptr&).
/// @return the folded result, or nullptr when either operand is not a
///         constant (e.g. a register placeholder). TODO: emit a real
///         multiply for that case.
Aptr mul( const Aptr& pA1, const Aptr& pA2 )
{
  const Cptr c1 = get<C>(pA1);
  const Cptr c2 = get<C>(pA2);
  if (c1 && c2)
    return mul_const(c1, c2);

  // Flowing off the end of a non-void function is undefined behaviour, so
  // make the "not implemented" case an explicit empty result.
  return nullptr;
}


// Demo: multiply a float constant by an int constant through the
// type-erased mul() entry point.
int main()
{
  // make_shared avoids a naked new. The object and its control block share
  // one allocation.
  Aptr pA1 = std::make_shared<CF>();
  Aptr pA2 = std::make_shared<CI>();

  Aptr result = mul( pA1, pA2 );
}

The problem I am having with the above code is the function Aptr mul_const( const Cptr& cp1 , const Cptr& cp2 ). It basically switches on the types of every possible combination of constant types. It works, but I would like to know whether this can be done more elegantly.


Solution

  • I guess you could do what the compiler does, and convert the other parameter to float when one is float. You'll probably need a new function to do the conversion and an "isFloat" (or "isInt") query. I'm not convinced it gives you that much benefit, really...

    // Add two virtual member functions here:
    // Abstract base of all constants. The two pure virtuals let generic code
    // ask whether a constant holds an int and read any constant as a float.
    class C : public A {
    public:
        virtual bool isInt() = 0;        // true if the constant is an int
        virtual float getAsFloat() = 0;  // value converted to float
    };
    

    Then implement:

    class CI : public C {
    public:
      typedef int Type_t;
      int getValue() {return 1;}
      float getAsFloat() { return getValue(); }
      bool isInt() { return true; }
    };
    
    class CF : public C {
    public:
      typedef float Type_t;
      float getValue() {return 1.1;}
      float getAsFloat() { return getValue(); }
      bool isInt() { return false; }
    };
    

    Now, your mul_const becomes:

    // Folds cp1 * cp2. int * int stays an int product; any other combination
    // is computed in float via getAsFloat(). The product is printed and then
    // wrapped with create_A.
    // Precondition: both operands are non-null.
    Aptr mul_const( const Cptr& cp1 , const Cptr& cp2 )
    {
      assert(cp1 && cp2);

      // cp1/cp2 are shared_ptrs, so members are reached with '->'.
      if (cp1->isInt() && cp2->isInt())
      {
         const CIptr c1 = get<CI>(cp1);
         const CIptr c2 = get<CI>(cp2);
         const int product = c1->getValue() * c2->getValue();
         std::cout << product << "\n";
         // A shared_ptr cannot be constructed from an int; wrap via create_A.
         return create_A(product);
      }

      // Mixed or float operands: multiply cp1 by cp2, both promoted to float.
      const float product = cp1->getAsFloat() * cp2->getAsFloat();
      std::cout << product << "\n";
      return create_A(product);
    }
    

    [And I think a few template parts can be deleted... ]