I am writing a matrix library for generic types using expression templates.
The basic matrix class is a template class Matrix <typename Scalar, int RowSize, int ColumnSize>
which inherits from MatrixXpr< Matrix<Scalar, RowSize, ColumnSize> >
where MatrixXpr is the parent class for the expression templates "MatrixSum", "MatrixProduct" etc.
For example:
template <typename Mat, typename Rix>
class MatrixProduct : public MatrixXpr< MatrixProduct<Mat,Rix> >
{
private:
const Mat& A_;
const Rix& B_;
public:
using value_type= std::common_type_t<typename Mat::value_type, typename Rix::value_type>;
MatrixProduct(const Mat& A, const Rix& B) : A_(A), B_(B) {}
value_type operator()(int i, int j) const {
value_type out{ 0 };
for (int k = 0; k < A_.Columns(); ++k) out += A_(i, k) * B_(k, j);
return out;
}
};
The * operator is then defined outside
template <typename Mat, typename Rix>
MatrixProduct<Mat, Rix> inline const operator*(const MatrixXpr<Mat>& A, const MatrixXpr<Rix>& B)
{
return MatrixProduct<Mat, Rix>(A, B);
}
Now I wish to implement also a Scalar*Matrix class. But I fail to define the correct value_type:
template <typename Scalar, typename Mat>
class ScalarMatrixProduct : public MatrixXpr< ScalarMatrixProduct<Scalar, Mat> >
{
private:
const Scalar& A_;
const Mat& B_;
public:
using value_type = std::common_type_t<typename Mat::value_type, typename Scalar>;
ScalarMatrixProduct(const Scalar& A, const Mat& B) : A_(A), B_(B) {}
value_type operator()(int i, int j) const {
return A_ * B_(i, j);
}
};
template <typename Scalar, typename Mat>
typename std::enable_if < (!is_matrix<Scalar>::value),
ScalarMatrixProduct<Scalar, Mat > >::type const operator*(const Scalar& A, const MatrixXpr<Mat>& B)
{
return ScalarMatrixProduct<Scalar, Mat>(A, B);
}
On Mac and Linux I get an compilation error of this sort:
template argument 2 is invalid 102 | using value_type = std::common_type_t<typename Mat::value_type, typename Scalar>;
Interestingly, it compiles on Windows.
Any hints for what's wrong would be helpful. Thanks in advance.
Complete example:
#include <type_traits>
#include <iostream>
#include <array>
#include <initializer_list>
///////////Expression Template Base Class for CRTP
template <class MatrixClass> struct MatrixXpr {
decltype(auto) operator()(int i, int j) const {
return static_cast<MatrixClass const&>(*this)(i, j);
}
operator MatrixClass& () {
return static_cast<MatrixClass&>(*this);
}
operator const MatrixClass& () const {
return static_cast<const MatrixClass&>(*this);
}
int Rows()
{
return static_cast<MatrixClass&>(*this).Rows();
}
int Columns()
{
return static_cast<MatrixClass&>(*this).Columns();
}
int Rows() const
{
return static_cast<const MatrixClass&>(*this).Rows();
}
int Columns() const
{
return static_cast<const MatrixClass&>(*this).Columns();
}
friend int Rows(const MatrixXpr& A)
{
return A.Rows();
}
friend int Columns(const MatrixXpr& A)
{
return A.Columns();
}
};
template <typename MatrixClass>
std::ostream& operator<<(std::ostream& os, const MatrixXpr<MatrixClass>& A)
{
for (int r = 0; r < Rows(A); ++r) {
os << '[';
for (int c = 0; c < Columns(A); ++c)
os << A(r, c) << (c + 1 < Columns(A) ? " " : "");
os << "]\n";
}
return os;
}
/////////// Matrix Product
template <typename Mat, typename Rix>
class MatrixProduct : public MatrixXpr< MatrixProduct<Mat, Rix> >
{
private:
const Mat& A_;
const Rix& B_;
public:
using value_type = std::common_type_t<typename Mat::value_type, typename Rix::value_type>;
MatrixProduct(const Mat& A, const Rix& B) : A_(A), B_(B)
{
std::cout << "MatrixMatrixProduct Constructor\n";
}
int Rows() const { return A_.Rows(); }
int Columns() const { return B_.Columns(); }
value_type operator()(int i, int j) const {
value_type out{ 0 };
for (int k = 0; k < A_.Columns(); ++k) out += A_(i, k) * B_(k, j);
return out;
}
};
/////////// Scalar Matrix Product
template <typename Scalar, typename Mat>
class ScalarMatrixProduct : public MatrixXpr< ScalarMatrixProduct<Scalar, Mat> >
{
private:
const Scalar& A_;
const Mat& B_;
public:
using value_type = std::common_type_t<typename Mat::value_type, typename Scalar>;
ScalarMatrixProduct(const Scalar& A, const Mat& B) : A_(A), B_(B) {
std::cout << "ScalarMatrixProduct Constructor\n";
}
int Rows() const { return B_.Rows(); }
int Columns() const { return B_.Columns(); }
value_type operator()(int i, int j) const {
return A_ * B_(i, j);
}
};
//The following two functions are Helpers for initializing an array.
//Source: https://stackoverflow.com/a/38934685/6176345
template<typename T, std::size_t N, std::size_t ...Ns>
std::array<T, N> make_array_impl(
std::initializer_list<T> list,
std::index_sequence<Ns...>)
{
return std::array<T, N>{ *(list.begin() + Ns) ... };
}
template<typename T, std::size_t N>
std::array<T, N> make_array(std::initializer_list<T> list) {
if (N > list.size())
throw std::out_of_range("Initializer list too small.");
return make_array_impl<T, N>(list, std::make_index_sequence<N>());
}
/////////// Matrix class
template <typename Scalar, int RowSize, int ColumnSize = RowSize>
class Matrix : public MatrixXpr< Matrix<Scalar, RowSize, ColumnSize> >
{
std::array<Scalar, RowSize* ColumnSize> data_;
public:
using value_type = Scalar;
const static int rows_ = RowSize;
const static int columns_ = ColumnSize;
int Rows() const { return rows_; }
int Columns() const { return columns_; }
Matrix() : data_{ Scalar(0) } {};
Matrix(const Matrix& other) = default;
Matrix(Matrix&& other) = default;
Matrix& operator=(const Matrix& other) = default;
Matrix& operator=(Matrix&& other) = default;
~Matrix() = default;
Matrix(std::initializer_list<Scalar> data) : data_(make_array<Scalar, RowSize* ColumnSize>(data)) {}
template <typename Source>
Matrix& operator=(const MatrixXpr<Source>& source)
{
for (int i = 0; i < rows_; ++i)
for (int j = 0; j < columns_; ++j)
data_[MatrixIndex(i, j)] = source(i, j);
return *this;
}
template <typename Source>
Matrix(const MatrixXpr<Source>& source)
{
for (int i = 0; i < rows_; ++i)
for (int j = 0; j < columns_; ++j)
data_[MatrixIndex(i, j)] = source(i, j);
}
Scalar& operator()(int i, int j) {
return data_[MatrixIndex(i, j)];
}
const Scalar& operator()(int i, int j) const {
return data_[MatrixIndex(i, j)];
}
private:
inline static int MatrixIndex(int i, int j)
{
return i * columns_ + j;
}
};
/////////// Multiplication operators
template <typename Mat, typename Rix>
MatrixProduct<Mat, Rix> inline const operator*(const MatrixXpr<Mat>& A, const MatrixXpr<Rix>& B)
{
std::cout << "Matrix Matrix Multiplication\n";
return MatrixProduct<Mat, Rix>(A, B);
}
template <typename Scalar, typename Mat>
typename std::enable_if_t<!std::is_base_of_v<MatrixXpr<Scalar>, Scalar>,
ScalarMatrixProduct<Scalar, Mat >> const operator*(const Scalar& A, const MatrixXpr<Mat>& B)
{
return ScalarMatrixProduct<Scalar, Mat>(A, B);
}
/////////// Failing example
int main()
{
Matrix<int, 2, 2> m = { 1,0,0,1 };
auto n = 3 * m;
std::cout << n;
std::cout << m * n;
//std::cout << n * m; // Error
return 0;
}
Edit:
The above code originally had two problems.
The first one is that my type checking failed to see which overload of the *operator was being used. The above implementation with std::is_base_of_v<MatrixXpr<Scalar>, Scalar>
fixed it and is is working correctly.
I do not know why this old code did not work. Here is the old version:
template <typename T>
struct is_matrix : std::false_type {};
template <typename T>
struct is_matrix<const T> : is_matrix<T> {};
template <typename MatrixClass>
struct is_matrix<MatrixXpr<MatrixClass> > : std::true_type {};
template <typename Scalar, typename Mat>
typename std::enable_if < (!is_matrix<Scalar>::value),
ScalarMatrixProduct<Scalar, Mat > >::type const operator*(const Scalar& A, const MatrixXpr<Mat>& B)
{
std::cout << "Scalar Matrix Multiplication\n";
return ScalarMatrixProduct<Scalar, Mat>(A, B);
}