We are writing an expression template library to handle operations on values with a sparse gradient vector (first-order automatic differentiation). I am trying to figure out how to nest sub-expressions either by reference or by value, depending on whether the expressions are temporaries or not.
We have a class Scalar which contains a value and a sparse gradient vector. We use expression templates (like Eigen) to avoid constructing and allocating too many temporary Scalar objects. Thus the class Scalar inherits from ScalarBase<Scalar> (CRTP).
A binary operation (e.g. +, *) between objects of type ScalarBase<Left> and ScalarBase<Right> returns a ScalarBinaryOp<Left, Right, BinaryOp> object, which inherits from ScalarBase<ScalarBinaryOp<Left, Right, BinaryOp>>:
template< typename Left, typename Right >
ScalarBinaryOp< Left, Right, BinaryAdditionOp > operator+(
const ScalarBase< Left >& left, const ScalarBase< Right >& right )
{
return ScalarBinaryOp< Left, Right, BinaryAdditionOp >( static_cast< const Left& >( left ),
static_cast< const Right& >( right ), BinaryAdditionOp{} );
}
ScalarBinaryOp must hold either a value of, or a reference to, the operand objects of type Left and Right. The type of the holder is defined by template specialization of RefTypeSelector<Expression>::Type.
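Currently this is always a const reference. It works for our current test cases, but holding a reference to a temporary subexpression does not seem correct or safe: the temporary is destroyed at the end of the full-expression, so an expression object that outlives it (for example one stored with auto e = x + f(y);) is left with a dangling reference.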
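A minimal, self-contained sketch of the current scheme, in which RefTypeSelector always yields a const reference. Assumptions: Scalar stores only a double (the sparse gradient is omitted) and value() is a hypothetical evaluation function standing in for the real one. It behaves correctly as long as the whole expression is evaluated within a single full-expression:

```cpp
#include <cassert>

template <typename Derived>
struct ScalarBase {};  // CRTP base

// Assumption: a Scalar holding only a double; the sparse gradient is omitted.
struct Scalar : ScalarBase<Scalar> {
    explicit Scalar(double v) : v_(v) {}
    double value() const { return v_; }  // hypothetical evaluation function
    double v_;
};

// Current policy: every operand is held by const reference.
template <typename Expression>
struct RefTypeSelector { typedef const Expression& Type; };

struct BinaryAdditionOp {
    double operator()(double a, double b) const { return a + b; }
};

template <typename Left, typename Right, typename Op>
class ScalarBinaryOp : public ScalarBase<ScalarBinaryOp<Left, Right, Op> > {
public:
    ScalarBinaryOp(const Left& left, const Right& right, const Op& op)
        : left_(left), right_(right), op_(op) {}
    double value() const { return op_(left_.value(), right_.value()); }
private:
    typename RefTypeSelector<Left>::Type left_;    // always a const reference
    typename RefTypeSelector<Right>::Type right_;  // always a const reference
    Op op_;
};

template <typename Left, typename Right>
ScalarBinaryOp<Left, Right, BinaryAdditionOp> operator+(
    const ScalarBase<Left>& left, const ScalarBase<Right>& right) {
    return ScalarBinaryOp<Left, Right, BinaryAdditionOp>(
        static_cast<const Left&>(left), static_cast<const Right&>(right),
        BinaryAdditionOp());
}
```

Storing the expression instead, as in auto e = x + Scalar(2.0); followed by e.value(), would read the destroyed temporary through right_, which is undefined behavior; that is exactly the case in question.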
Obviously we also do not want a Scalar object, which contains the sparse gradient vector, to be copied. If x and y are Scalars, the expression x+y should hold const references to x and y. However, if f is a function from Scalar to Scalar, x+f(y) should hold a const reference to x and the value of f(y).
Hence I would like to pass along the information about whether subexpressions are temporaries. I can add it to the expression's template parameters:
template< typename Left, typename Right, typename BinaryOp, bool LeftIsTemporary, bool RightIsTemporary > class ScalarBinaryOp;
and to RefTypeSelector:
RefTypeSelector< Expression, ExpressionIsTemporary >::Type
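A sketch of what such a two-parameter RefTypeSelector could look like, using partial specialization on the boolean. Scalar here is an empty stand-in, not the real class:

```cpp
#include <type_traits>

struct Scalar {};  // empty stand-in for the real Scalar class

// Primary template: a non-temporary operand is held by const reference.
template <typename Expression, bool ExpressionIsTemporary>
struct RefTypeSelector {
    typedef const Expression& Type;
};

// Partial specialization: a temporary operand is held by value.
template <typename Expression>
struct RefTypeSelector<Expression, true> {
    typedef Expression Type;
};
```

The selector itself is cheap; the cost of this design lies in the operator overloads that have to supply the two flags.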
But then I would need to define four overloads for every binary operator:
template< typename Left, typename Right >
ScalarBinaryOp< Left, Right, BinaryAdditionOp, false, false > operator+(
const ScalarBase< Left >& left, const ScalarBase< Right >& right );
template< typename Left, typename Right >
ScalarBinaryOp< Left, Right, BinaryAdditionOp, false, true > operator+(
const ScalarBase< Left >& left, ScalarBase< Right >&& right );
template< typename Left, typename Right >
ScalarBinaryOp< Left, Right, BinaryAdditionOp, true, false > operator+(
ScalarBase< Left >&& left, const ScalarBase< Right >& right );
template< typename Left, typename Right >
ScalarBinaryOp< Left, Right, BinaryAdditionOp, true, true > operator+(
ScalarBase< Left >&& left, ScalarBase< Right >&& right );
I would prefer to achieve this with perfect forwarding, but I do not know how to do it here. First, I cannot use plain "universal references" (forwarding references), because an unconstrained operator+( Left&&, Right&& ) matches almost any pair of arguments. It might be possible to combine universal references with SFINAE to accept only certain parameter types, but I am not sure this is the way to go. I would also like to know whether I could encode whether Left and Right were originally lvalues or rvalues in the types Left and Right that parameterize ScalarBinaryOp, instead of using the two additional boolean parameters, and how to retrieve that information.
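The language rule that makes this possible: when a function template parameter has the form T&& (a forwarding reference), an lvalue argument of type X deduces T = X&, while an rvalue argument deduces T = X. T itself therefore already records the value category, and std::is_lvalue_reference<T> reads it back. A small sketch with an empty stand-in Scalar and hypothetical helpers DeducedAs, deduce, was_lvalue and make_scalar:

```cpp
#include <cassert>
#include <type_traits>
#include <utility>

struct Scalar {};  // empty stand-in for the real Scalar class

// Hypothetical tag type that records what T was deduced as.
template <typename T>
struct DeducedAs {};

// T&& is a forwarding reference: T = Scalar& for lvalues, T = Scalar for rvalues.
template <typename T>
DeducedAs<T> deduce(T&&) { return DeducedAs<T>(); }

// Recovers the value category of the argument from the deduced T.
template <typename T>
bool was_lvalue(T&&) { return std::is_lvalue_reference<T>::value; }

Scalar make_scalar() { return Scalar(); }
```

An operator+ taking Left&& and Right&& can therefore pass Left and Right straight through as template arguments of ScalarBinaryOp and inspect them with std::is_lvalue_reference, with no extra boolean parameters.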
I have to support gcc 4.8.5, which is mostly C++11 compliant.
Update 2019/08/15: implementation
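#include <type_traits>

// Expr is deduced from a forwarding reference: an lvalue operand yields
// T& and is held by const reference, an rvalue operand yields T and is
// held by value.
template < typename Expr >
class RefTypeSelector
{
private:
using Expr1 = typename std::decay<Expr>::type;
public:
using Type = typename std::conditional<std::is_lvalue_reference<Expr>::value, const Expr1&, Expr1>::type;
};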
template< typename Left, typename Right, typename Op >
class ScalarBinaryOp : public ScalarBase< ScalarBinaryOp< Left, Right, Op > >
{
public:
template <typename L, typename R>
ScalarBinaryOp( L&& left, R&& right, const Op& op )
: left_( std::forward<L>(left) )
, right_( std::forward<R>(right) )
, ...
{
...
}
...
private:
/** LHS expression */
typename RefTypeSelector< Left >::Type left_;
/** RHS expression */
typename RefTypeSelector< Right >::Type right_;
...
};
template< typename Left, typename Right,
typename Left1 = typename std::decay<Left>::type,
typename Right1 = typename std::decay<Right>::type,
typename std::enable_if<std::is_base_of<ScalarBase<Left1>, Left1>::value,int>::type = 0,
typename std::enable_if<std::is_base_of<ScalarBase<Right1>, Right1>::value,int>::type = 0 >
ScalarBinaryOp< Left, Right, BinaryAdditionOp > operator+(
Left&& left, Right&& right )
{
return ScalarBinaryOp< Left, Right, BinaryAdditionOp >( std::forward<Left>( left ),
std::forward<Right>( right ), BinaryAdditionOp{} );
}
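A self-contained C++11 version of this implementation, with hypothetical stand-ins where the code above elides details. Assumptions: Scalar holds only a double plus a static copy counter (instead of the sparse gradient), value() is a hypothetical evaluation function, and f is a stand-in for a function returning a Scalar by value:

```cpp
#include <cassert>
#include <type_traits>
#include <utility>

template <typename Derived>
struct ScalarBase {};  // CRTP base

// Assumption: a Scalar with only a value and a copy counter, no gradient.
struct Scalar : ScalarBase<Scalar> {
    static int copies;                        // number of copy constructions
    explicit Scalar(double v) : v_(v) {}
    Scalar(const Scalar& o) : v_(o.v_) { ++copies; }
    Scalar(Scalar&& o) : v_(o.v_) {}          // moves are not counted
    double value() const { return v_; }
    double v_;
};
int Scalar::copies = 0;

// Expr = T& (lvalue operand) -> const T&; Expr = T (rvalue operand) -> T.
template <typename Expr>
class RefTypeSelector {
    using Expr1 = typename std::decay<Expr>::type;
public:
    using Type = typename std::conditional<std::is_lvalue_reference<Expr>::value,
                                           const Expr1&, Expr1>::type;
};

struct BinaryAdditionOp {
    double operator()(double a, double b) const { return a + b; }
};

template <typename Left, typename Right, typename Op>
class ScalarBinaryOp : public ScalarBase<ScalarBinaryOp<Left, Right, Op> > {
public:
    template <typename L, typename R>
    ScalarBinaryOp(L&& left, R&& right, const Op& op)
        : left_(std::forward<L>(left)), right_(std::forward<R>(right)), op_(op) {}
    double value() const { return op_(left_.value(), right_.value()); }
private:
    typename RefTypeSelector<Left>::Type left_;    // LHS expression
    typename RefTypeSelector<Right>::Type right_;  // RHS expression
    Op op_;
};

// Forwarding references, constrained by SFINAE to ScalarBase-derived operands.
template <typename Left, typename Right,
          typename Left1 = typename std::decay<Left>::type,
          typename Right1 = typename std::decay<Right>::type,
          typename std::enable_if<std::is_base_of<ScalarBase<Left1>, Left1>::value, int>::type = 0,
          typename std::enable_if<std::is_base_of<ScalarBase<Right1>, Right1>::value, int>::type = 0>
ScalarBinaryOp<Left, Right, BinaryAdditionOp> operator+(Left&& left, Right&& right) {
    return ScalarBinaryOp<Left, Right, BinaryAdditionOp>(
        std::forward<Left>(left), std::forward<Right>(right), BinaryAdditionOp());
}

// Hypothetical stand-in for a function from Scalar to Scalar.
Scalar f(const Scalar& s) { return Scalar(2.0 * s.value()); }
```

With this scheme, x + y stores two const references, x + f(y) moves the result of f(y) into the expression, and a temporary subexpression such as (x + y) is moved into its parent by value, so no Scalar is ever copied.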
You can encode the lvalue/rvalue information into the Left and Right types. For example:
template<class Left, class Right>
ScalarBinaryOp<Left&&, Right&&> operator+(
ScalarBase<Left>&& left, ScalarBase<Right>&& right)
{
return ...;
}
with ScalarBinaryOp being something like this:
template<class L, class R>
struct ScalarBinaryOp
{
using Left = std::remove_reference_t<L>;
using Right = std::remove_reference_t<R>;
using My_left = std::conditional_t<
std::is_rvalue_reference_v<L>, Left, const Left&>;
using My_right = std::conditional_t<
std::is_rvalue_reference_v<R>, Right, const Right&>;
...
My_left left_;
My_right right_;
};
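Because the code must build with gcc 4.8.5 in C++11 mode, the C++14 _t aliases and C++17 _v variable templates used above are not available there. A sketch of the same storage selection spelled with C++11 traits, with My_right selecting between Right and const Right& (Scalar is an empty stand-in):

```cpp
#include <type_traits>

struct Scalar {};  // empty stand-in for the real Scalar class

template <class L, class R>
struct ScalarBinaryOp {
    using Left = typename std::remove_reference<L>::type;
    using Right = typename std::remove_reference<R>::type;
    // An operand spelled T&& was an rvalue: store it by value.
    // Anything else was an lvalue: store it by const reference.
    using My_left = typename std::conditional<
        std::is_rvalue_reference<L>::value, Left, const Left&>::type;
    using My_right = typename std::conditional<
        std::is_rvalue_reference<R>::value, Right, const Right&>::type;
    My_left left_;
    My_right right_;
};
```

This convention marks rvalues as T&&, whereas a forwarding reference deduces plain T for an rvalue. An operator+ built on forwarding references would therefore pass Left&& and Right&& as template arguments: reference collapsing turns Scalar& && into Scalar& and leaves Scalar&& for rvalues, which keeps the two conventions in agreement.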
Alternatively, you can be explicit and store everything by value except Scalars. To store a Scalar by value, wrap it in a wrapper class:
x + Value_wrapper(f(y))
The wrapper is simple:
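// Holds an rvalue reference to a temporary Scalar. The temporary lives until
// the end of the full-expression, which is long enough for ScalarBinaryOp to
// move it into its by-value member.
struct Value_wrapper : Base<Value_wrapper> {
Value_wrapper(Scalar&& scalar) : scalar_(std::move(scalar)) {}
// Moves from the referenced temporary, so convert only once.
operator Scalar() const {
return std::move(scalar_);
}
Scalar&& scalar_;
};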
RefTypeSelector has a specialization for Value_wrapper:
template<> struct RefTypeSelector<Value_wrapper> {
using Type = Scalar;
};
The binary operator definition remains the same:
template<class Left, class Right>
ScalarBinaryOp<Left, Right> operator+(const Base<Left>& left, const Base<Right>& right) {
return {static_cast<const Left&>(left), static_cast<const Right&>(right)};
}
Complete example: https://godbolt.org/z/sJ3NfG
(I used some C++17 features above only to simplify the notation.)
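A self-contained C++11 sketch of the wrapper alternative, assuming the "everything by value except Scalars" policy is implemented as a default RefTypeSelector that stores by value plus specializations for Scalar and Value_wrapper. Scalar again holds only a double and a copy counter, the addition is hard-coded, and f is a hypothetical stand-in; this is not taken from the linked example:

```cpp
#include <cassert>
#include <utility>

template <class Derived>
struct Base {};  // CRTP base

// Assumption: a Scalar with only a value and a copy counter, no gradient.
struct Scalar : Base<Scalar> {
    static int copies;                        // number of copy constructions
    explicit Scalar(double v) : v_(v) {}
    Scalar(const Scalar& o) : v_(o.v_) { ++copies; }
    Scalar(Scalar&& o) : v_(o.v_) {}          // moves are not counted
    double value() const { return v_; }
    double v_;
};
int Scalar::copies = 0;

// Explicit opt-in to by-value storage for a temporary Scalar.
struct Value_wrapper : Base<Value_wrapper> {
    Value_wrapper(Scalar&& scalar) : scalar_(std::move(scalar)) {}
    operator Scalar() const { return std::move(scalar_); }  // convert only once
    Scalar&& scalar_;
};

// Policy: expression nodes by value, Scalars by const reference,
// and a Value_wrapper unwrapped into a Scalar held by value.
template <class E> struct RefTypeSelector { using Type = E; };
template <> struct RefTypeSelector<Scalar> { using Type = const Scalar&; };
template <> struct RefTypeSelector<Value_wrapper> { using Type = Scalar; };

// Addition only, for brevity.
template <class Left, class Right>
struct ScalarBinaryOp : Base<ScalarBinaryOp<Left, Right> > {
    ScalarBinaryOp(const Left& l, const Right& r) : left_(l), right_(r) {}
    double value() const { return left_.value() + right_.value(); }
    typename RefTypeSelector<Left>::Type left_;
    typename RefTypeSelector<Right>::Type right_;
};

template <class Left, class Right>
ScalarBinaryOp<Left, Right> operator+(const Base<Left>& left, const Base<Right>& right) {
    return {static_cast<const Left&>(left), static_cast<const Right&>(right)};
}

// Hypothetical stand-in for a function from Scalar to Scalar.
Scalar f(const Scalar& s) { return Scalar(2.0 * s.value()); }
```

The trade-off compared with forwarding references is that the caller must remember to write Value_wrapper(f(y)); a bare x + f(y) still stores a const reference to the temporary.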