diff --git a/cmake/algebra-plugins-compiler-options-cpp.cmake b/cmake/algebra-plugins-compiler-options-cpp.cmake index 7a6b0d78..1385e6e7 100644 --- a/cmake/algebra-plugins-compiler-options-cpp.cmake +++ b/cmake/algebra-plugins-compiler-options-cpp.cmake @@ -28,7 +28,7 @@ elseif( "${CMAKE_CXX_COMPILER_ID}" MATCHES "MSVC" ) # Basic flags for all build modes. string( REGEX REPLACE "/W[0-9]" "" CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}" ) - algebra_add_flag( CMAKE_CXX_FLAGS "/W4" ) + algebra_add_flag( CMAKE_CXX_FLAGS "/W4 /bigobj" ) # Fail on warnings, if asked for that behaviour. if( ALGEBRA_PLUGINS_FAIL_ON_WARNINGS ) diff --git a/common/include/algebra/concepts.hpp b/common/include/algebra/concepts.hpp index 251a1f46..0374d218 100644 --- a/common/include/algebra/concepts.hpp +++ b/common/include/algebra/concepts.hpp @@ -88,6 +88,34 @@ concept column_matrix = matrix && (algebra::traits::columns == 1); template concept column_matrix3D = column_matrix && (algebra::traits::rows == 3); + +template +concept matrix_compatible = matrix&& matrix&& std::convertible_to< + algebra::traits::index_t, algebra::traits::index_t>&& + std::convertible_to, + algebra::traits::index_t>; + +template +concept matrix_multipliable = + matrix_compatible && + (algebra::traits::columns == + algebra::traits::rows)&&requires(algebra::traits::scalar_t sa, + algebra::traits::scalar_t sb) { + {(sa * sb) + (sa * sb)}; +}; + +template +concept matrix_multipliable_into = + matrix_multipliable&& matrix_compatible&& + matrix_compatible && + (algebra::traits::rows == algebra::traits::rows)&&( + algebra::traits::columns == + algebra::traits::columns< + MB>)&&requires(algebra::traits::scalar_t sa, + algebra::traits::scalar_t sb, + algebra::traits::scalar_t& sc) { + {sc += (sa * sb)}; +}; /// @} /// Transform concept diff --git a/frontend/array_cmath/include/algebra/array_cmath.hpp b/frontend/array_cmath/include/algebra/array_cmath.hpp index 0d787dfb..581ce4f3 100644 --- a/frontend/array_cmath/include/algebra/array_cmath.hpp +++ b/frontend/array_cmath/include/algebra/array_cmath.hpp @@ -93,6 +93,14 @@ using cmath::determinant; using cmath::inverse; using cmath::transpose; +using generic::math::set_inplace_product_left; +using generic::math::set_inplace_product_left_transpose; +using generic::math::set_inplace_product_right; +using generic::math::set_inplace_product_right_transpose; +using generic::math::set_product; +using generic::math::set_product_left_transpose; +using generic::math::set_product_right_transpose; + /// @} } // namespace matrix diff --git a/frontend/eigen_eigen/CMakeLists.txt b/frontend/eigen_eigen/CMakeLists.txt index cfcad057..8e0d8ce9 100644 --- a/frontend/eigen_eigen/CMakeLists.txt +++ b/frontend/eigen_eigen/CMakeLists.txt @@ -8,7 +8,7 @@ algebra_add_library( algebra_eigen_eigen eigen_eigen "include/algebra/eigen_eigen.hpp" ) target_link_libraries( algebra_eigen_eigen - INTERFACE algebra::common algebra::eigen_storage algebra::eigen_math + INTERFACE algebra::common algebra::eigen_storage algebra::eigen_math algebra::generic_math Eigen3::Eigen ) algebra_test_public_headers( algebra_eigen_eigen "algebra/eigen_eigen.hpp" ) diff --git a/frontend/eigen_eigen/include/algebra/eigen_eigen.hpp b/frontend/eigen_eigen/include/algebra/eigen_eigen.hpp index 328629a4..db285330 100644 --- a/frontend/eigen_eigen/include/algebra/eigen_eigen.hpp +++ b/frontend/eigen_eigen/include/algebra/eigen_eigen.hpp @@ -9,6 +9,7 @@ // Project include(s). #include "algebra/math/eigen.hpp" +#include "algebra/math/generic.hpp" #include "algebra/storage/eigen.hpp" // Eigen include(s). @@ -65,6 +66,14 @@ using eigen::math::set_zero; using eigen::math::transpose; using eigen::math::zero; +using generic::math::set_inplace_product_left; +using generic::math::set_inplace_product_left_transpose; +using generic::math::set_inplace_product_right; +using generic::math::set_inplace_product_right_transpose; +using generic::math::set_product; +using generic::math::set_product_left_transpose; +using generic::math::set_product_right_transpose; + /// @} } // namespace matrix diff --git a/frontend/eigen_generic/include/algebra/eigen_generic.hpp b/frontend/eigen_generic/include/algebra/eigen_generic.hpp index 6dfa680d..864e9996 100644 --- a/frontend/eigen_generic/include/algebra/eigen_generic.hpp +++ b/frontend/eigen_generic/include/algebra/eigen_generic.hpp @@ -71,6 +71,14 @@ using generic::math::set_zero; using generic::math::transpose; using generic::math::zero; +using generic::math::set_inplace_product_left; +using generic::math::set_inplace_product_left_transpose; +using generic::math::set_inplace_product_right; +using generic::math::set_inplace_product_right_transpose; +using generic::math::set_product; +using generic::math::set_product_left_transpose; +using generic::math::set_product_right_transpose; + /// @} } // namespace matrix diff --git a/frontend/fastor_fastor/CMakeLists.txt b/frontend/fastor_fastor/CMakeLists.txt index d2467ff2..ffec83b1 100644 --- a/frontend/fastor_fastor/CMakeLists.txt +++ b/frontend/fastor_fastor/CMakeLists.txt @@ -8,6 +8,6 @@ algebra_add_library( algebra_fastor_fastor fastor_fastor "include/algebra/fastor_fastor.hpp" ) target_link_libraries( algebra_fastor_fastor - INTERFACE algebra::common algebra::fastor_storage algebra::fastor_math ) + INTERFACE algebra::common algebra::fastor_storage algebra::fastor_math algebra::generic_math ) algebra_test_public_headers( algebra_fastor_fastor "algebra/fastor_fastor.hpp" ) diff --git a/frontend/fastor_fastor/include/algebra/fastor_fastor.hpp b/frontend/fastor_fastor/include/algebra/fastor_fastor.hpp index f3621774..34ba3714 100644 --- a/frontend/fastor_fastor/include/algebra/fastor_fastor.hpp +++ b/frontend/fastor_fastor/include/algebra/fastor_fastor.hpp @@ -9,6 +9,7 @@ // Project include(s). #include "algebra/math/fastor.hpp" +#include "algebra/math/generic.hpp" #include "algebra/storage/fastor.hpp" // Fastor include(s). @@ -67,6 +68,14 @@ using fastor::math::set_zero; using fastor::math::transpose; using fastor::math::zero; +using generic::math::set_inplace_product_left; +using generic::math::set_inplace_product_left_transpose; +using generic::math::set_inplace_product_right; +using generic::math::set_inplace_product_right_transpose; +using generic::math::set_product; +using generic::math::set_product_left_transpose; +using generic::math::set_product_right_transpose; + /// @} } // namespace matrix diff --git a/frontend/smatrix_generic/include/algebra/smatrix_generic.hpp b/frontend/smatrix_generic/include/algebra/smatrix_generic.hpp index 7cc40f83..f9a933fe 100644 --- a/frontend/smatrix_generic/include/algebra/smatrix_generic.hpp +++ b/frontend/smatrix_generic/include/algebra/smatrix_generic.hpp @@ -62,6 +62,14 @@ using generic::math::set_zero; using generic::math::transpose; using generic::math::zero; +using generic::math::set_inplace_product_left; +using generic::math::set_inplace_product_left_transpose; +using generic::math::set_inplace_product_right; +using generic::math::set_inplace_product_right_transpose; +using generic::math::set_product; +using generic::math::set_product_left_transpose; +using generic::math::set_product_right_transpose; + /// @} } // namespace matrix diff --git a/frontend/smatrix_smatrix/CMakeLists.txt b/frontend/smatrix_smatrix/CMakeLists.txt index 4eaef534..d9a061db 100644 --- a/frontend/smatrix_smatrix/CMakeLists.txt +++ b/frontend/smatrix_smatrix/CMakeLists.txt @@ -8,6 +8,6 @@ algebra_add_library( algebra_smatrix_smatrix smatrix_smatrix "include/algebra/smatrix_smatrix.hpp" ) target_link_libraries( algebra_smatrix_smatrix - INTERFACE algebra::common algebra::smatrix_storage algebra::smatrix_math ) + INTERFACE algebra::common algebra::smatrix_storage algebra::smatrix_math algebra::generic_math ) algebra_test_public_headers( algebra_smatrix_smatrix "algebra/smatrix_smatrix.hpp" ) diff --git a/frontend/smatrix_smatrix/include/algebra/smatrix_smatrix.hpp b/frontend/smatrix_smatrix/include/algebra/smatrix_smatrix.hpp index d2d0b099..f6946946 100644 --- a/frontend/smatrix_smatrix/include/algebra/smatrix_smatrix.hpp +++ b/frontend/smatrix_smatrix/include/algebra/smatrix_smatrix.hpp @@ -8,6 +8,7 @@ #pragma once // Project include(s). +#include "algebra/math/generic.hpp" #include "algebra/math/smatrix.hpp" #include "algebra/storage/smatrix.hpp" @@ -58,6 +59,14 @@ using smatrix::math::set_zero; using smatrix::math::transpose; using smatrix::math::zero; +using generic::math::set_inplace_product_left; +using generic::math::set_inplace_product_left_transpose; +using generic::math::set_inplace_product_right; +using generic::math::set_inplace_product_right_transpose; +using generic::math::set_product; +using generic::math::set_product_left_transpose; +using generic::math::set_product_right_transpose; + /// @} } // namespace matrix diff --git a/frontend/vc_aos/CMakeLists.txt b/frontend/vc_aos/CMakeLists.txt index cca2890a..f5dd4889 100644 --- a/frontend/vc_aos/CMakeLists.txt +++ b/frontend/vc_aos/CMakeLists.txt @@ -8,6 +8,6 @@ algebra_add_library( algebra_vc_aos vc_aos "include/algebra/vc_aos.hpp" ) target_link_libraries( algebra_vc_aos - INTERFACE algebra::common algebra::vc_aos_storage algebra::vc_aos_math ) + INTERFACE algebra::common algebra::vc_aos_storage algebra::vc_aos_math algebra::generic_math ) algebra_test_public_headers( algebra_vc_aos "algebra/vc_aos.hpp" ) diff --git a/frontend/vc_aos/include/algebra/vc_aos.hpp b/frontend/vc_aos/include/algebra/vc_aos.hpp index cfec4604..66947c85 100644 --- a/frontend/vc_aos/include/algebra/vc_aos.hpp +++ b/frontend/vc_aos/include/algebra/vc_aos.hpp @@ -8,6 +8,7 @@ #pragma once // Project include(s). +#include "algebra/math/generic.hpp" #include "algebra/math/vc_aos.hpp" #include "algebra/storage/vc_aos.hpp" @@ -63,6 +64,14 @@ using vc_aos::math::set_zero; using vc_aos::math::transpose; using vc_aos::math::zero; +using generic::math::set_inplace_product_left; +using generic::math::set_inplace_product_left_transpose; +using generic::math::set_inplace_product_right; +using generic::math::set_inplace_product_right_transpose; +using generic::math::set_product; +using generic::math::set_product_left_transpose; +using generic::math::set_product_right_transpose; + /// @} } // namespace matrix diff --git a/frontend/vc_aos_generic/include/algebra/vc_aos_generic.hpp b/frontend/vc_aos_generic/include/algebra/vc_aos_generic.hpp index d1a94a3d..f12ce1c2 100644 --- a/frontend/vc_aos_generic/include/algebra/vc_aos_generic.hpp +++ b/frontend/vc_aos_generic/include/algebra/vc_aos_generic.hpp @@ -70,6 +70,14 @@ using generic::math::set_zero; using generic::math::transpose; using generic::math::zero; +using generic::math::set_inplace_product_left; +using generic::math::set_inplace_product_left_transpose; +using generic::math::set_inplace_product_right; +using generic::math::set_inplace_product_right_transpose; +using generic::math::set_product; +using generic::math::set_product_left_transpose; +using generic::math::set_product_right_transpose; + /// @} } // namespace matrix diff --git a/frontend/vc_soa/CMakeLists.txt b/frontend/vc_soa/CMakeLists.txt index 097c4765..66432be3 100644 --- a/frontend/vc_soa/CMakeLists.txt +++ b/frontend/vc_soa/CMakeLists.txt @@ -8,7 +8,7 @@ algebra_add_library( algebra_vc_soa vc_soa "include/algebra/vc_soa.hpp" ) target_link_libraries( algebra_vc_soa - INTERFACE algebra::common algebra::vc_soa_storage algebra::vc_soa_math + INTERFACE algebra::common algebra::vc_soa_storage algebra::vc_soa_math algebra::generic_math algebra::vc_aos_math ) algebra_test_public_headers( algebra_vc_soa "algebra/vc_soa.hpp" ) diff --git a/frontend/vc_soa/include/algebra/vc_soa.hpp b/frontend/vc_soa/include/algebra/vc_soa.hpp index 0a20409e..09b5c0ef 100644 --- a/frontend/vc_soa/include/algebra/vc_soa.hpp +++ b/frontend/vc_soa/include/algebra/vc_soa.hpp @@ -8,6 +8,7 @@ #pragma once // Project include(s). +#include "algebra/math/generic.hpp" #include "algebra/math/impl/vc_aos_transform3.hpp" #include "algebra/math/vc_soa.hpp" #include "algebra/storage/vc_soa.hpp" @@ -71,6 +72,14 @@ using vc_soa::math::set_zero; using vc_soa::math::transpose; using vc_soa::math::zero; +using generic::math::set_inplace_product_left; +using generic::math::set_inplace_product_left_transpose; +using generic::math::set_inplace_product_right; +using generic::math::set_inplace_product_right_transpose; +using generic::math::set_product; +using generic::math::set_product_left_transpose; +using generic::math::set_product_right_transpose; + } // namespace matrix namespace vc_soa { diff --git a/frontend/vecmem_cmath/include/algebra/vecmem_cmath.hpp b/frontend/vecmem_cmath/include/algebra/vecmem_cmath.hpp index daf11920..c3d5a560 100644 --- a/frontend/vecmem_cmath/include/algebra/vecmem_cmath.hpp +++ b/frontend/vecmem_cmath/include/algebra/vecmem_cmath.hpp @@ -92,6 +92,14 @@ using generic::math::determinant; using generic::math::inverse; using generic::math::transpose; +using generic::math::set_inplace_product_left; +using generic::math::set_inplace_product_left_transpose; +using generic::math::set_inplace_product_right; +using generic::math::set_inplace_product_right_transpose; +using generic::math::set_product; +using generic::math::set_product_left_transpose; +using generic::math::set_product_right_transpose; + /// @} } // namespace matrix diff --git a/math/generic/include/algebra/math/impl/generic_matrix.hpp b/math/generic/include/algebra/math/impl/generic_matrix.hpp index 1bb4ce8d..706c1b14 100644 --- a/math/generic/include/algebra/math/impl/generic_matrix.hpp +++ b/math/generic/include/algebra/math/impl/generic_matrix.hpp @@ -83,6 +83,196 @@ ALGEBRA_HOST_DEVICE inline auto transpose(const M &m) { return ret; } +// Set matrix C to the product AB +template +ALGEBRA_HOST_DEVICE inline void +set_product(MC &C, const MA &A, const MB &B) requires( + algebra::concepts::matrix_multipliable_into) { + using index_t = algebra::traits::index_t; + using value_t = algebra::traits::value_t; + + for (index_t i = 0; i < algebra::traits::rows; ++i) { + for (index_t j = 0; j < algebra::traits::columns; ++j) { + value_t t = 0.f; + + for (index_t k = 0; k < algebra::traits::rows; ++k) { + t += algebra::traits::element_getter_t()(A, i, k) * + algebra::traits::element_getter_t()(B, k, j); + } + + algebra::traits::element_getter_t()(C, i, j) = t; + } + } +} + +// Set matrix C to the product A^TB +template +ALGEBRA_HOST_DEVICE inline void +set_product_left_transpose(MC &C, const MA &A, const MB &B) requires( + algebra::concepts::matrix_multipliable_into< + decltype(transpose(std::declval())), MB, MC>) { + using index_t = algebra::traits::index_t; + using value_t = algebra::traits::value_t; + + for (index_t i = 0; i < algebra::traits::rows; ++i) { + for (index_t j = 0; j < algebra::traits::columns; ++j) { + value_t t = 0.f; + + for (index_t k = 0; k < algebra::traits::rows; ++k) { + t += algebra::traits::element_getter_t()(A, k, i) * + algebra::traits::element_getter_t()(B, k, j); + } + + algebra::traits::element_getter_t()(C, i, j) = t; + } + } +} + +// Set matrix C to the product AB^T +template +ALGEBRA_HOST_DEVICE inline void +set_product_right_transpose(MC &C, const MA &A, const MB &B) requires( + algebra::concepts::matrix_multipliable_into< + MA, decltype(transpose(std::declval())), MC>) { + using index_t = algebra::traits::index_t; + using value_t = algebra::traits::value_t; + + for (index_t i = 0; i < algebra::traits::rows; ++i) { + for (index_t j = 0; j < algebra::traits::columns; ++j) { + value_t t = 0.f; + + for (index_t k = 0; k < algebra::traits::columns; ++k) { + t += algebra::traits::element_getter_t()(A, i, k) * + algebra::traits::element_getter_t()(B, j, k); + } + + algebra::traits::element_getter_t()(C, i, j) = t; + } + } +} + +// Set matrix A to the product AB in place +template +ALGEBRA_HOST_DEVICE inline void +set_inplace_product_right(MA &A, const MB &B) requires( + algebra::concepts::matrix_multipliable_into) { + using index_t = algebra::traits::index_t; + using value_t = algebra::traits::value_t; + + for (index_t i = 0; i < algebra::traits::rows; ++i) { + algebra::traits::get_matrix_t, value_t> + Q; + + for (index_t j = 0; j < algebra::traits::columns; ++j) { + algebra::traits::element_getter_t()(Q, 0, j) = + algebra::traits::element_getter_t()(A, i, j); + } + + for (index_t j = 0; j < algebra::traits::columns; ++j) { + value_t t = 0.f; + + for (index_t k = 0; k < algebra::traits::rows; ++k) { + t += algebra::traits::element_getter_t()(Q, 0, k) * + algebra::traits::element_getter_t()(B, k, j); + } + + algebra::traits::element_getter_t()(A, i, j) = t; + } + } +} + +// Set matrix A to the product BA in place +template +ALGEBRA_HOST_DEVICE inline void +set_inplace_product_left(MA &A, const MB &B) requires( + algebra::concepts::matrix_multipliable_into) { + using index_t = algebra::traits::index_t; + using value_t = algebra::traits::value_t; + + for (index_t j = 0; j < algebra::traits::columns; ++j) { + algebra::traits::get_matrix_t, value_t> + Q; + + for (index_t i = 0; i < algebra::traits::rows; ++i) { + algebra::traits::element_getter_t()(Q, 0, i) = + algebra::traits::element_getter_t()(A, i, j); + } + + for (index_t i = 0; i < algebra::traits::rows; ++i) { + value_t t = 0.f; + + for (index_t k = 0; k < algebra::traits::columns; ++k) { + t += algebra::traits::element_getter_t()(B, i, k) * + algebra::traits::element_getter_t()(Q, 0, k); + } + + algebra::traits::element_getter_t()(A, i, j) = t; + } + } +} + +// Set matrix A to the product AB^T in place +template +ALGEBRA_HOST_DEVICE inline void +set_inplace_product_right_transpose(MA &A, const MB &B) requires( + algebra::concepts::matrix_multipliable_into< + MA, decltype(transpose(std::declval())), MA>) { + using index_t = algebra::traits::index_t; + using value_t = algebra::traits::value_t; + + for (index_t i = 0; i < algebra::traits::rows; ++i) { + algebra::traits::get_matrix_t, value_t> + Q; + + for (index_t j = 0; j < algebra::traits::columns; ++j) { + algebra::traits::element_getter_t()(Q, 0, j) = + algebra::traits::element_getter_t()(A, i, j); + } + + for (index_t j = 0; j < algebra::traits::columns; ++j) { + value_t T = 0.f; + + for (index_t k = 0; k < algebra::traits::columns; ++k) { + T += algebra::traits::element_getter_t()(Q, 0, k) * + algebra::traits::element_getter_t()(B, j, k); + } + + algebra::traits::element_getter_t()(A, i, j) = T; + } + } +} + +// Set matrix A to the product B^TA in place +template +ALGEBRA_HOST_DEVICE inline void +set_inplace_product_left_transpose(MA &A, const MB &B) requires( + algebra::concepts::matrix_multipliable_into< + decltype(transpose(std::declval())), MA, MA>) { + using index_t = algebra::traits::index_t; + using value_t = algebra::traits::value_t; + + for (index_t j = 0; j < algebra::traits::columns; ++j) { + algebra::traits::get_matrix_t, value_t> + Q; + + for (index_t i = 0; i < algebra::traits::rows; ++i) { + algebra::traits::element_getter_t()(Q, 0, i) = + algebra::traits::element_getter_t()(A, i, j); + } + + for (index_t i = 0; i < algebra::traits::rows; ++i) { + value_t T = 0.f; + + for (index_t k = 0; k < algebra::traits::rows; ++k) { + T += algebra::traits::element_getter_t()(B, k, i) * + algebra::traits::element_getter_t()(Q, 0, k); + } + + algebra::traits::element_getter_t()(A, i, j) = T; + } + } +} + /// @returns the determinant of @param m template ALGEBRA_HOST_DEVICE inline algebra::traits::scalar_t determinant( diff --git a/tests/common/test_host_basics.hpp b/tests/common/test_host_basics.hpp index 72282212..1e6dbc3a 100644 --- a/tests/common/test_host_basics.hpp +++ b/tests/common/test_host_basics.hpp @@ -20,6 +20,7 @@ #include #include #include +#include /// Test case class, to be specialised for the different plugins - vectors template @@ -28,7 +29,240 @@ TYPED_TEST_SUITE_P(test_host_basics_vector); /// Test case class, to be specialised for the different plugins - matrices template -class test_host_basics_matrix : public testing::Test, public test_base {}; +class test_host_basics_matrix : public testing::Test, public test_base { + protected: + template + void test_matrix_ops_any_matrix() { + // Test the set_product method. + { + typename A::template matrix m1; + typename A::template matrix m2; + + for (std::size_t i = 0; i < ROWS; ++i) { + for (std::size_t j = 0; j < ROWS; ++j) { + algebra::getter::element(m1, i, j) = + static_cast>(i * ROWS + + j); + } + } + + for (std::size_t i = 0; i < ROWS; ++i) { + for (std::size_t j = 0; j < COLS; ++j) { + algebra::getter::element(m2, i, j) = + static_cast>(i * COLS + + j); + } + } + + { + typename A::template matrix r1 = m1 * m2; + typename A::template matrix r2; + algebra::matrix::set_product(r2, m1, m2); + + for (std::size_t i = 0; i < ROWS; ++i) { + for (std::size_t j = 0; j < COLS; ++j) { + ASSERT_NEAR(algebra::getter::element(r1, i, j), + algebra::getter::element(r2, i, j), this->m_epsilon); + } + } + } + } + + // Test the set_product_right_transpose method. + { + typename A::template matrix m1; + typename A::template matrix m2; + + for (std::size_t i = 0; i < ROWS; ++i) { + for (std::size_t j = 0; j < ROWS; ++j) { + algebra::getter::element(m1, i, j) = + static_cast>(i * ROWS + + j); + } + } + + for (std::size_t i = 0; i < COLS; ++i) { + for (std::size_t j = 0; j < ROWS; ++j) { + algebra::getter::element(m2, i, j) = + static_cast>(i * COLS + + j); + } + } + + { + typename A::template matrix r1 = + m1 * algebra::matrix::transpose(m2); + typename A::template matrix r2; + algebra::matrix::set_product_right_transpose(r2, m1, m2); + + for (std::size_t i = 0; i < ROWS; ++i) { + for (std::size_t j = 0; j < COLS; ++j) { + ASSERT_NEAR(algebra::getter::element(r1, i, j), + algebra::getter::element(r2, i, j), this->m_epsilon); + } + } + } + } + + // Test the set_product_left_transpose method. + { + typename A::template matrix m1; + typename A::template matrix m2; + + for (std::size_t i = 0; i < ROWS; ++i) { + for (std::size_t j = 0; j < ROWS; ++j) { + algebra::getter::element(m1, i, j) = + static_cast>(i * ROWS + + j); + } + } + + for (std::size_t i = 0; i < ROWS; ++i) { + for (std::size_t j = 0; j < COLS; ++j) { + algebra::getter::element(m2, i, j) = + static_cast>(i * COLS + + j); + } + } + + { + typename A::template matrix r1 = + algebra::matrix::transpose(m1) * m2; + typename A::template matrix r2; + algebra::matrix::set_product_left_transpose(r2, m1, m2); + + for (std::size_t i = 0; i < ROWS; ++i) { + for (std::size_t j = 0; j < COLS; ++j) { + ASSERT_NEAR(algebra::getter::element(r1, i, j), + algebra::getter::element(r2, i, j), this->m_epsilon); + } + } + } + } + } + + template + void test_matrix_ops_square_matrix() { + { + typename A::template matrix m1; + typename A::template matrix m2; + + for (std::size_t i = 0; i < N; ++i) { + for (std::size_t j = 0; j < N; ++j) { + algebra::getter::element(m1, i, j) = + static_cast>(i * N + j); + algebra::getter::element(m2, i, j) = + static_cast>( + -1 * (i * N + j) + 42); + } + } + + // Test the set_product method. + { + typename A::template matrix r1 = m1 * m2; + typename A::template matrix r2; + algebra::matrix::set_product(r2, m1, m2); + + for (std::size_t i = 0; i < N; ++i) { + for (std::size_t j = 0; j < N; ++j) { + ASSERT_NEAR(algebra::getter::element(r1, i, j), + algebra::getter::element(r2, i, j), this->m_epsilon); + } + } + } + + // Test the set_product_right_transpose method. + { + typename A::template matrix r1 = + m1 * algebra::matrix::transpose(m2); + typename A::template matrix r2; + algebra::matrix::set_product_right_transpose(r2, m1, m2); + + for (std::size_t i = 0; i < N; ++i) { + for (std::size_t j = 0; j < N; ++j) { + ASSERT_NEAR(algebra::getter::element(r1, i, j), + algebra::getter::element(r2, i, j), this->m_epsilon); + } + } + } + + // Test the set_product_left_transpose method. + { + typename A::template matrix r1 = + algebra::matrix::transpose(m1) * m2; + typename A::template matrix r2; + algebra::matrix::set_product_left_transpose(r2, m1, m2); + + for (std::size_t i = 0; i < N; ++i) { + for (std::size_t j = 0; j < N; ++j) { + ASSERT_NEAR(algebra::getter::element(r1, i, j), + algebra::getter::element(r2, i, j), this->m_epsilon); + } + } + } + + // Test the set_inplace_product_right method. + { + typename A::template matrix r1 = m1 * m2; + typename A::template matrix r2 = m1; + algebra::matrix::set_inplace_product_right(r2, m2); + + for (std::size_t i = 0; i < N; ++i) { + for (std::size_t j = 0; j < N; ++j) { + ASSERT_NEAR(algebra::getter::element(r1, i, j), + algebra::getter::element(r2, i, j), this->m_epsilon); + } + } + } + + // Test the set_inplace_product_left method. + { + typename A::template matrix r1 = m1 * m2; + typename A::template matrix r2 = m2; + algebra::matrix::set_inplace_product_left(r2, m1); + + for (std::size_t i = 0; i < N; ++i) { + for (std::size_t j = 0; j < N; ++j) { + ASSERT_NEAR(algebra::getter::element(r1, i, j), + algebra::getter::element(r2, i, j), this->m_epsilon); + } + } + } + + // Test the set_inplace_product_right_transpose method. + { + typename A::template matrix r1 = + m1 * algebra::matrix::transpose(m2); + typename A::template matrix r2 = m1; + algebra::matrix::set_inplace_product_right_transpose(r2, m2); + + for (std::size_t i = 0; i < N; ++i) { + for (std::size_t j = 0; j < N; ++j) { + ASSERT_NEAR(algebra::getter::element(r1, i, j), + algebra::getter::element(r2, i, j), this->m_epsilon); + } + } + } + + // Test the set_inplace_product_left_transpose method. + { + typename A::template matrix r1 = + algebra::matrix::transpose(m1) * m2; + typename A::template matrix r2 = m2; + algebra::matrix::set_inplace_product_left_transpose(r2, m1); + + for (std::size_t i = 0; i < N; ++i) { + for (std::size_t j = 0; j < N; ++j) { + ASSERT_NEAR(algebra::getter::element(r1, i, j), + algebra::getter::element(r2, i, j), this->m_epsilon); + } + } + } + } + + this->template test_matrix_ops_any_matrix(); + } +}; TYPED_TEST_SUITE_P(test_host_basics_matrix); /// Test case class, to be specialised for the different plugins - transforms @@ -170,6 +404,8 @@ TYPED_TEST_P(test_host_basics_vector, getter) { } TYPED_TEST_P(test_host_basics_matrix, matrix_2x3) { + static constexpr typename TypeParam::size_type ROWS = 2; + static constexpr typename TypeParam::size_type COLS = 3; using matrix_2x3_t = typename TypeParam::template matrix<2, 3>; @@ -224,12 +460,17 @@ TYPED_TEST_P(test_host_basics_matrix, matrix_2x3) { ASSERT_NEAR(v2[0], 14, this->m_epsilon); ASSERT_NEAR(v2[1], 32, this->m_epsilon); + + this->template test_matrix_ops_any_matrix(); } TYPED_TEST_P(test_host_basics_matrix, matrix_3x1) { // Print the linear algebra types of this backend using algebra::operator<<; + static constexpr typename TypeParam::size_type ROWS = 3; + static constexpr typename TypeParam::size_type COLS = 1; + // Cross product on vector3 and matrix<3,1> typename TypeParam::template matrix<3, 1> vF; algebra::getter::element(vF, 0, 0) = 5.f; @@ -248,6 +489,8 @@ TYPED_TEST_P(test_host_basics_matrix, matrix_3x1) { // Dot product on vector3 and matrix<3,1> auto dot = algebra::vector::dot(vG, vF); ASSERT_NEAR(dot, 0.f, this->m_epsilon); + + this->template test_matrix_ops_any_matrix(); } TYPED_TEST_P(test_host_basics_matrix, matrix_6x4) { @@ -366,9 +609,13 @@ TYPED_TEST_P(test_host_basics_matrix, matrix_6x4) { // Test printing std::cout << m << std::endl; + + this->template test_matrix_ops_any_matrix(); } TYPED_TEST_P(test_host_basics_matrix, matrix_3x3) { + static constexpr typename TypeParam::size_type N = 3; + { typename TypeParam::vector3 v = {10.f, 20.f, 30.f}; typename TypeParam::template matrix<3, 3> m33; @@ -425,9 +672,12 @@ TYPED_TEST_P(test_host_basics_matrix, matrix_3x3) { ASSERT_NEAR(algebra::getter::element(m33_inv, 2, 2), -10.f / 20.f, this->m_isclose); } + + this->template test_matrix_ops_square_matrix(); } TYPED_TEST_P(test_host_basics_matrix, matrix_2x2) { + static constexpr typename TypeParam::size_type N = 2; typename TypeParam::template matrix<2, 2> m22; algebra::getter::element(m22, 0, 0) = 4.f; @@ -449,9 +699,12 @@ TYPED_TEST_P(test_host_basics_matrix, matrix_2x2) { this->m_isclose); ASSERT_NEAR(algebra::getter::element(m22_inv, 1, 1), 4.f / 16.f, this->m_isclose); + + this->template test_matrix_ops_square_matrix(); } TYPED_TEST_P(test_host_basics_matrix, matrix_6x6) { + static constexpr typename TypeParam::size_type N = 6; // Test 6 X 6 big matrix determinant typename TypeParam::template matrix<6, 6> m66_big; @@ -584,6 +837,8 @@ TYPED_TEST_P(test_host_basics_matrix, matrix_6x6) { auto m66_small_det = algebra::matrix::determinant(m66_small); ASSERT_NEAR((m66_small_det - 4.30636e-11f) / 4.30636e-11f, 0.f, 2.f * this->m_isclose); + + this->template test_matrix_ops_square_matrix(); } TYPED_TEST_P(test_host_basics_matrix, matrix_small_mixed) {