
I'm rewriting a C++ FEM library in which I used expression templates. In the previous version I was able to write, for example:

BilinearForm s(mesh, Vh, Vh, [&gp, &mesh](const TrialFunction& u, const TestFunction& v) -> double
{
    return mesh.integrate(gp*dot(grad(u), grad(v)) );
});

This would automatically evaluate the expression for every pair of trial and test functions in Vh and return a sparse matrix.

The code for that was rather heavy and cumbersome, so I wanted to rewrite it, taking inspiration from the Wikipedia article on expression templates. It is fairly straightforward for functions that are defined on a single element:

template<typename E1, typename E2>
class ElementWiseScalarProd : public ElementWiseScalarExpression< ElementWiseScalarProd<E1, E2> >
{
public:
    ElementWiseScalarProd(const E1& lhs, const E2& rhs) : m_lhs(lhs), m_rhs(rhs) {}
public:
    inline double operator[] (const size_t k) const { return m_lhs[k] * m_rhs[k]; }
public:
    inline bool containsTrial() const { return m_lhs.containsTrial() or m_rhs.containsTrial(); }
    inline bool containsTest()  const { return m_lhs.containsTest() or m_rhs.containsTest(); }
public:
    inline const Element* getElement() const { assert(m_lhs.getElement() == m_rhs.getElement()); return m_lhs.getElement(); }
private:
    const E1& m_lhs;
    const E2& m_rhs;
};
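To make the single-element case concrete, here is a small self-contained version of it. The leaf type ElementValues and the operator* overload are hypothetical stand-ins, not the library's API; ElementWiseScalarExpression is reduced to a bare CRTP base, and containsTrial/containsTest/getElement are left out. Storing references is safe here only because both operands are named objects that outlive the product.

```cpp
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Hypothetical CRTP base playing the role of ElementWiseScalarExpression.
template <typename E>
class ElementWiseScalarExpression {
public:
    // Static dispatch to the derived expression; no virtual call involved.
    double operator[](std::size_t k) const { return static_cast<const E&>(*this)[k]; }
};

// Hypothetical leaf: values of a function at the quadrature points of one element.
class ElementValues : public ElementWiseScalarExpression<ElementValues> {
public:
    explicit ElementValues(std::vector<double> v) : m_values(std::move(v)) {}
    double operator[](std::size_t k) const { return m_values[k]; }
private:
    std::vector<double> m_values;
};

template <typename E1, typename E2>
class ElementWiseScalarProd
    : public ElementWiseScalarExpression<ElementWiseScalarProd<E1, E2>> {
public:
    ElementWiseScalarProd(const E1& lhs, const E2& rhs) : m_lhs(lhs), m_rhs(rhs) {}
    // Evaluated lazily, one quadrature point at a time.
    double operator[](std::size_t k) const { return m_lhs[k] * m_rhs[k]; }
private:
    const E1& m_lhs;  // only valid while the operands outlive the product
    const E2& m_rhs;
};

// Hypothetical operator: E1 and E2 are deduced from the CRTP base classes.
template <typename E1, typename E2>
ElementWiseScalarProd<E1, E2> operator*(const ElementWiseScalarExpression<E1>& lhs,
                                        const ElementWiseScalarExpression<E2>& rhs) {
    return ElementWiseScalarProd<E1, E2>(static_cast<const E1&>(lhs),
                                         static_cast<const E2&>(rhs));
}

double demo_single_element() {
    ElementValues u({1.0, 2.0, 3.0});
    ElementValues v({4.0, 5.0, 6.0});
    auto uv = u * v;               // u and v are named, so the stored references stay valid
    return uv[0] + uv[1] + uv[2];  // 4 + 10 + 18
}
```

The same pattern breaks as soon as an operand is a temporary, because the product keeps a reference to an object that is destroyed at the end of the full expression.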

However, when I want to multiply a function that is defined on the whole mesh, things become a bit trickier: operator[] now returns an element-wise expression, that is an array or a slice of an array, rather than a scalar.

template<typename E1, typename E2>
class FunctionProd : public FunctionExpression< FunctionProd<E1, E2> >
{
public:
    typedef ElementWiseScalarProd<typename E1::ElementWiseType, typename E2::ElementWiseType> ElementWiseType;
public:
    FunctionProd(const E1& lhs, const E2& rhs) : m_lhs(lhs), m_rhs(rhs) {}
public:
    inline const ElementWiseType operator[] (const size_t e) const { return ElementWiseType(m_lhs[e], m_rhs[e]); }
public:
    inline const Mesh* getMesh() const { assert(m_lhs.getMesh() == m_rhs.getMesh()); return m_lhs.getMesh(); }
private:
    const E1& m_lhs;
    const E2& m_rhs;
};

It seems that, if I were to multiply a function defined on an element and a function defined on the whole mesh, FunctionProd::operator[] would have to return a reference: m_lhs[e] and m_rhs[e] return temporaries, and ElementWiseScalarProd only stores references to its operands. But that would mean I need to store the objects it creates somewhere, wouldn't it? Is there a way to circumvent that?

Thanks in advance

Edit: A similar question is answered here: Nesting of subexpressions in expression templates

Answer

You need to account for the value category of your inner expressions.

You can do that with a trait to distinguish lvalues and rvalues, and a bunch of perfect forwarding.
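The trait relies on how a forwarding reference deduces its type parameter: for a parameter `T&&`, `T` becomes a reference type (`X&` or `const X&`) when the argument is an lvalue, and plain `X` when it is an rvalue. The sketch below checks this at compile time with a placeholder type Expr; Expr, named_expr, deduced_as_lvalue and forwarded_as_rvalue are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <type_traits>
#include <utility>

struct Expr {};  // hypothetical placeholder for any expression type

// For a forwarding reference T&&, T is deduced as a reference type
// for an lvalue argument and as plain Expr for an rvalue argument.
template <typename T>
constexpr bool deduced_as_lvalue(T&&) {
    return std::is_lvalue_reference<T>::value;
}

// std::forward<T>(x) yields an lvalue when T is a reference type and an
// rvalue (T&&) otherwise, so the caller's value category is restored.
template <typename T>
constexpr bool forwarded_as_rvalue(T&& x) {
    return std::is_rvalue_reference<decltype(std::forward<T>(x))>::value;
}

constexpr Expr named_expr{};

static_assert(deduced_as_lvalue(named_expr), "lvalue: T = const Expr&");
static_assert(!deduced_as_lvalue(Expr{}), "rvalue: T = Expr");
static_assert(!forwarded_as_rvalue(named_expr), "forwarded as an lvalue");
static_assert(forwarded_as_rvalue(Expr{}), "forwarded as an rvalue");
```

The `T&` and `T&&`/primary specializations of expression_holder key off exactly this difference in `T`.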

#include <utility>  // std::forward

template <typename T>
struct expression_holder {
    using type = T;
};

template <typename T>
struct expression_holder<const T> : expression_holder<T> {
};

template <typename T>
struct expression_holder<T &> {
    using type = const T &;
};

template <typename T>
struct expression_holder<T &&> {
    using type = T;
};

template <typename T>
using expression_holder_t = typename expression_holder<T>::type;

template<typename E1, typename E2>
class ElementWiseScalarProd : public ElementWiseScalarExpression< ElementWiseScalarProd<E1, E2> >
{
public:
    ElementWiseScalarProd(E1&& lhs, E2&& rhs) : m_lhs(std::forward<E1>(lhs)), m_rhs(std::forward<E2>(rhs)) {}
...
private:
    expression_holder_t<E1> m_lhs;
    expression_holder_t<E2> m_rhs;
};

template<typename E1, typename E2> 
ElementWiseScalarProd<E1, E2> scalar_prod(E1 && lhs, E2 && rhs) {
    return ElementWiseScalarProd<E1, E2>(std::forward<E1>(lhs), std::forward<E2>(rhs));
}
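Putting the pieces together, below is a self-contained sketch that should compile as C++11 or later. The class body elided as `...` above is replaced by an assumed element-wise operator[], the CRTP base is dropped for brevity, and ElementValues, InnerProd and demo_nested are hypothetical names. The point is the nested call: the temporary returned by the inner scalar_prod is held by value, while the named operands are held by const reference.

```cpp
#include <cassert>
#include <cstddef>
#include <type_traits>
#include <utility>
#include <vector>

// The trait from the answer: lvalues held by const reference, rvalues by value.
template <typename T> struct expression_holder { using type = T; };
template <typename T> struct expression_holder<const T> : expression_holder<T> {};
template <typename T> struct expression_holder<T&> { using type = const T&; };
template <typename T> struct expression_holder<T&&> { using type = T; };
template <typename T> using expression_holder_t = typename expression_holder<T>::type;

// Hypothetical leaf: one function's values at the quadrature points of an element.
struct ElementValues {
    std::vector<double> values;
    double operator[](std::size_t k) const { return values[k]; }
};

template <typename E1, typename E2>
class ElementWiseScalarProd {
public:
    ElementWiseScalarProd(E1&& lhs, E2&& rhs)
        : m_lhs(std::forward<E1>(lhs)), m_rhs(std::forward<E2>(rhs)) {}
    // Assumed element-wise evaluation (not shown in the answer).
    double operator[](std::size_t k) const { return m_lhs[k] * m_rhs[k]; }
private:
    expression_holder_t<E1> m_lhs;
    expression_holder_t<E2> m_rhs;
};

template <typename E1, typename E2>
ElementWiseScalarProd<E1, E2> scalar_prod(E1&& lhs, E2&& rhs) {
    return ElementWiseScalarProd<E1, E2>(std::forward<E1>(lhs), std::forward<E2>(rhs));
}

// What gets stored for each kind of operand:
static_assert(std::is_same<expression_holder_t<ElementValues&>, const ElementValues&>::value,
              "a named operand is held by const reference");
using InnerProd = ElementWiseScalarProd<ElementValues&, ElementValues&>;
static_assert(std::is_same<expression_holder_t<InnerProd>, InnerProd>::value,
              "a temporary sub-expression is held by value");

double demo_nested() {
    ElementValues u{{1.0, 2.0, 3.0}};
    ElementValues v{{4.0, 5.0, 6.0}};
    ElementValues w{{2.0, 2.0, 2.0}};
    // scalar_prod(u, v) is a temporary, so the outer product owns a moved copy of it;
    // u, v and w are only referenced and must outlive uvw.
    auto uvw = scalar_prod(scalar_prod(u, v), w);
    return uvw[0] + uvw[1] + uvw[2];  // 8 + 20 + 36
}
```

Copying a sub-expression is cheap here because it only contains references to the leaves; the leaves themselves are never copied.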
Caleth
  • Thank you! This is a bit beyond my C++ level. Can you recommend something I can read to understand what perfect forwarding does? – Alexandre Hoffmann Jan 13 '22 at 10:25
  • @AlexandreHoffmann https://en.cppreference.com/w/cpp/utility/forward. I guess this use isn't technically perfect forwarding, as we are adding `const` to all the lvalues – Caleth Jan 13 '22 at 10:29
  • I tried to adapt the code from the Wikipedia article using the holder, but I'm not sure how the operator should be written. `template VecSum operator+(VecExpression&& u, VecExpression&& v) { return VecSum(std::forward(*static_cast(&u)), std::forward(*static_cast(&v))); } ` doesn't do the trick – Alexandre Hoffmann Jan 13 '22 at 11:38
  • You really need C++20 for it to be nicely one overload: `template requires std::derived_from> && std::derived_from> VecSum operator+(E1&& lhs, E2&& rhs) { return VecSum(std::forward(lhs), std::forward(rhs)); }`. Otherwise you can have overloads for each of the combinations: `VecSum operator+(const VecExpression& u, const VecExpression& v); VecSum operator+(VecExpression&& u, VecExpression&& v);` and the two in-between – Caleth Jan 13 '22 at 13:23
  • That is, you have to have `T&&` for a type parameter `T`, not just a type involving a type parameter, to capture the value category from the call site. – Caleth Jan 13 '22 at 13:25
  • Thanks, I think I found a C++11 equivalent to your solution https://stackoverflow.com/questions/56619644/nesting-of-subexpressions-in-expression-templates – Alexandre Hoffmann Jan 13 '22 at 14:13
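The comments mention either a C++20 requires-clause or four separate overloads. A third option, swapped in here rather than taken from the thread, is a single forwarding-reference operator+ constrained with std::enable_if, which works in C++11. It borrows the VecExpression/VecSum names from the Wikipedia article mentioned in the comments; Vec, bare_t, is_vec_expression and demo_sum are hypothetical helpers for this sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <type_traits>
#include <utility>
#include <vector>

template <typename E>
struct VecExpression {};  // CRTP tag base; evaluation lives in the derived classes

// Strip references and cv-qualifiers (what C++20's std::remove_cvref_t does).
template <typename E>
using bare_t = typename std::remove_cv<typename std::remove_reference<E>::type>::type;

// C++11 stand-in for the std::derived_from constraint from the comment.
template <typename E>
using is_vec_expression = std::is_base_of<VecExpression<bare_t<E>>, bare_t<E>>;

template <typename T> struct expression_holder { using type = T; };
template <typename T> struct expression_holder<const T> : expression_holder<T> {};
template <typename T> struct expression_holder<T&> { using type = const T&; };
template <typename T> struct expression_holder<T&&> { using type = T; };
template <typename T> using expression_holder_t = typename expression_holder<T>::type;

class Vec : public VecExpression<Vec> {
public:
    Vec(std::initializer_list<double> init) : m_data(init) {}
    double operator[](std::size_t i) const { return m_data[i]; }
private:
    std::vector<double> m_data;
};

template <typename E1, typename E2>
class VecSum : public VecExpression<VecSum<E1, E2>> {
public:
    VecSum(E1&& u, E2&& v) : m_u(std::forward<E1>(u)), m_v(std::forward<E2>(v)) {}
    double operator[](std::size_t i) const { return m_u[i] + m_v[i]; }
private:
    expression_holder_t<E1> m_u;
    expression_holder_t<E2> m_v;
};

// One overload: E1&& and E2&& are forwarding references, so E1/E2 record the
// value category at the call site; enable_if drops it for non-expressions.
template <typename E1, typename E2,
          typename = typename std::enable_if<is_vec_expression<E1>::value &&
                                             is_vec_expression<E2>::value>::type>
VecSum<E1, E2> operator+(E1&& u, E2&& v) {
    return VecSum<E1, E2>(std::forward<E1>(u), std::forward<E2>(v));
}

double demo_sum() {
    Vec a{1.0, 2.0}, b{10.0, 20.0}, c{100.0, 200.0};
    auto s = a + b + c;  // the temporary a + b is held by value inside s
    return s[0] + s[1];  // 111 + 222
}
```

As in the comment, the parameters must be exactly `E1&&` and `E2&&` for a template parameter; a parameter like `VecExpression<E1>&&` is a plain rvalue reference and will not bind to lvalues.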