I have implemented a simple fold function in C++ that accepts a lambda and can fold multiple vectors at the same time, at compile time. I am wondering whether it could be simplified in some manner. I have provided both a recursive version and an iterative-style (tail-recursive) version, and I am unsure which should have better performance: https://godbolt.org/z/39pW81
Performance optimizations are also welcome; in that regard, is either of the two approaches faster?
/// Recursive left-fold helper that builds the result from index I downwards.
///
/// Computes
///   func(... func(func(id, head[0], tail[0]...), head[1], tail[1]...) ..., head[I], tail[I]...)
/// by first folding indices 0..I-1 recursively and then combining element I.
///
/// @tparam I    Index of the last element to fold in (callers start at N-1).
/// @param func  Callable invoked as func(accumulator, head[i], tail[i]...).
/// @param id    Initial accumulator value.
/// @param head  First vector.
/// @param tail  Further vectors, indexed in lockstep with head.
/// @return The accumulated value, or a copy of id when I is negative
///         (i.e. there is nothing to fold, as for an empty vector).
template<int I, typename type_identity, typename type_head, int N, typename ...type_tail, int ...N_tail, typename Function>
constexpr auto foldHelperR(Function&& func, const type_identity& id, const tvecn<type_head, N>& head, const tvecn<type_tail, N_tail>&... tail)
{
    if constexpr (I < 0)
    {
        // No elements to combine: the fold of nothing is the identity.
        // Without this guard the I == -1 case would read head[0].
        return id;
    }
    else if constexpr (I > 0)
    {
        // func is invoked again once the recursive call returns, so it is
        // handed down as an lvalue instead of being forwarded.
        return func(foldHelperR<I-1>(func, id, head, tail...), head[I], tail[I]...);
    }
    else
    {
        // Base case: combine the identity with the first element of each vector.
        return func(id, head[0], tail[0]...);
    }
}
/// Accumulator-passing (tail-recursive) left-fold helper.
///
/// Combines element I of every vector into the accumulator and passes the
/// new accumulator on to the instantiation for index I+1, until the last
/// index N-1 has been consumed.
///
/// @tparam I    Index of the element to fold in next (callers start at 0).
/// @param func  Callable invoked as func(accumulator, head[i], tail[i]...).
/// @param id    Accumulator so far; taken by value so each step can receive
///              the temporary produced by the previous func call.
/// @param head  First vector.
/// @param tail  Further vectors, indexed in lockstep with head.
/// @return The accumulated value, or id unchanged when no element is left
///         at index I (I >= N, e.g. an empty vector).
template<int I, typename type_identity, typename type_head, int N, typename ...type_tail, int ...N_tail, typename Function>
constexpr auto foldHelperI(Function&& func, const type_identity id, const tvecn<type_head, N>& head, const tvecn<type_tail, N_tail>&... tail)
{
    if constexpr (I >= N)
    {
        // Nothing left to fold. Without this guard N == 0 would read head[-1].
        return id;
    }
    else if constexpr (I < N-1)
    {
        // std::forward only rebinds the reference here; the callable is not
        // moved from, so invoking func in the same argument list is safe.
        return foldHelperI<I+1>(std::forward<Function>(func), func(id, head[I], tail[I]...), head, tail...);
    }
    else
    {
        // I == N-1: fold in the last element and finish.
        return func(id, head[N-1], tail[N-1]...);
    }
}
/// Left fold over one or more equally sized vectors in lockstep.
///
/// Starting from id, repeatedly evaluates
///   acc = func(acc, head[i], tail[i]...)
/// for i = 0 .. N_head-1 in increasing order and returns the final acc.
///
/// @param func  Callable taking the accumulator followed by one element
///              from each vector.
/// @param id    Initial accumulator value; returned (as a copy) when the
///              vectors have no elements.
/// @param head  First vector; its size determines the number of steps.
/// @param tail  Further vectors; their sizes must equal that of head.
/// @return The final accumulator value.
///
/// NOTE(review): Function is normally deduced from func, so the default
/// template argument appears to be unused; it is kept to leave the
/// signature untouched.
template<typename type_identity, typename type_head, int N_head, typename ...type_tail, int ...N_tail, typename Function = void (const type_identity&, const type_head&, const type_tail&...)>
constexpr auto fold(Function&& func, const type_identity& id, const tvecn<type_head, N_head>& head, const tvecn<type_tail, N_tail>&... tail)
{
    static_assert(std::is_invocable_v<Function, const type_identity&, const type_head&, const type_tail &...>,
                  "The function cannot be invoked with these zip arguments (possibly wrong argument count).");
    static_assert(all_equal_v<N_head, N_tail...>, "Vector sizes must match.");

    if constexpr (N_head <= 0)
    {
        // Empty input: there is no element to index, so the result is the identity.
        return id;
    }
    else
    {
        // foldHelperR<N_head-1>(...) evaluates the same left fold recursively
        // and could be used here instead.
        return foldHelperI<0>(std::forward<Function>(func), id, head, tail...);
    }
}
int main()
{
tvecn<int,3> a(1,2,3);
return fold([](auto x, auto y, auto z) {return x+y+z;}, 0, a, a);
}