From 44c82f0198d4c56cbcb22f905dc53f94c6083608 Mon Sep 17 00:00:00 2001 From: Axel Huebl Date: Sun, 5 Jan 2025 18:07:19 -0800 Subject: [PATCH] Refactor --- src/Base/SmallMatrix.H | 158 +++++++++++++++++++++++++-------------- src/Base/SmallMatrix.cpp | 13 +--- 2 files changed, 104 insertions(+), 67 deletions(-) diff --git a/src/Base/SmallMatrix.H b/src/Base/SmallMatrix.H index ca0a8a9b..3c8dd4d5 100644 --- a/src/Base/SmallMatrix.H +++ b/src/Base/SmallMatrix.H @@ -44,11 +44,6 @@ namespace get_value_type_t >; } -} - -namespace pyAMReX -{ - using namespace amrex; /** CPU: __array_interface__ v3 * @@ -58,12 +53,14 @@ namespace pyAMReX class T, int NRows, int NCols, - Order ORDER = Order::F, + amrex::Order ORDER = amrex::Order::F, int StartIndex = 0 > py::dict - array_interface (SmallMatrix const & m) + array_interface (amrex::SmallMatrix const & m) { + using namespace amrex; + auto d = py::dict(); // provide C index order for shape and strides auto shape = m.ordering == Order::F ? py::make_tuple( @@ -110,27 +107,26 @@ namespace pyAMReX return d; } - template< - class T, - int NRows, - int NCols, - Order ORDER = Order::F, - int StartIndex = 0 - > - void make_SmallMatrix(py::module &m, std::string typestr) + template + py::class_ + make_SmallMatrix_or_Vector (py::module &m, std::string typestr) { using namespace amrex; + using T = typename SM::value_type; using T_no_cv = std::remove_cv_t; + static constexpr int row_size = SM::row_size; + static constexpr int column_size = SM::column_size; + static constexpr Order ordering = SM::ordering; + static constexpr int starting_index = SM::starting_index; // dispatch simpler via: py::format_descriptor::format() naming // but note the _const suffix that might be needed auto const sm_name = std::string("SmallMatrix_") - .append(std::to_string(NRows)).append("x").append(std::to_string(NCols)) - .append("_").append(ORDER == Order::F ? "F" : "C") - .append("_SI").append(std::to_string(StartIndex)) + .append(std::to_string(row_size)).append("x").append(std::to_string(column_size)) + .append("_").append(ordering == Order::F ? "F" : "C") + .append("_SI").append(std::to_string(starting_index)) .append("_").append(typestr); - using SM = SmallMatrix; py::class_< SM > py_sm(m, sm_name.c_str()); py_sm .def("__repr__", @@ -177,7 +173,7 @@ namespace pyAMReX py::format_descriptor::format() + "' and received '" + buf.format + "'!"); - // TODO: check that strides are either exact or None in buf + // TODO: check that strides are either exact or None in buf (e.g., F or C contiguous) // TODO: transpose if SM order is not C? auto sm = std::make_unique< SM >(); @@ -197,7 +193,7 @@ namespace pyAMReX // CPU: __array_interface__ v3 // https://numpy.org/doc/stable/reference/arrays.interface.html .def_property_readonly("__array_interface__", [](SM const & sm) { - return pyAMReX::array_interface(sm); + return array_interface(sm); }) // CPU: __array_function__ interface (TODO) @@ -210,8 +206,9 @@ namespace pyAMReX // Nvidia GPUs: __cuda_array_interface__ v3 // https://numba.readthedocs.io/en/latest/cuda/cuda_array_interface.html - .def_property_readonly("__cuda_array_interface__", [](SM const & sm) { - auto d = pyAMReX::array_interface(sm); + .def_property_readonly("__cuda_array_interface__", [](SM const & sm) + { + auto d = array_interface(sm); // data: // Because the user of the interface may or may not be in the same context, the most common case is to use cuPointerGetAttribute with CU_POINTER_ATTRIBUTE_DEVICE_POINTER in the CUDA driver API (or the equivalent CUDA Runtime API) to retrieve a device pointer that is usable in the currently active context. @@ -239,6 +236,21 @@ namespace pyAMReX // https://github.com/dmlc/dlpack/blob/master/include/dlpack/dlpack.h // https://docs.cupy.dev/en/stable/user_guide/interoperability.html#dlpack-data-exchange-protocol + ; + + return py_sm; + } + + template + void add_matrix_methods (py::class_ & py_sm) + { + using T = typename SM::value_type; + using T_no_cv = std::remove_cv_t; + static constexpr int row_size = SM::row_size; + static constexpr int column_size = SM::column_size; + static constexpr int starting_index = SM::starting_index; + + py_sm .def("dot", &SM::dot) .def("prod", &SM::product) // NumPy name .def("set_val", &SM::setVal) @@ -254,8 +266,8 @@ namespace pyAMReX // getter .def("__getitem__", [](SM & sm, std::array const & key){ - if (key[0] < SM::starting_index || key[0] >= SM::row_size + SM::starting_index || - key[1] < SM::starting_index || key[1] >= SM::column_size + SM::starting_index) + if (key[0] < starting_index || key[0] >= row_size + starting_index || + key[1] < starting_index || key[1] >= column_size + starting_index) throw std::runtime_error( "Index out of bounds: [" + std::to_string(key[0]) + ", " + @@ -263,57 +275,93 @@ namespace pyAMReX return sm(key[0], key[1]); }) ; + // setter if constexpr (is_not_const()) { py_sm - .def("__setitem__", [](SM & sm, std::array const & key, T const value){ + .def("__setitem__", [](SM & sm, std::array const & key, T_no_cv const value){ if (key[0] < SM::starting_index || key[0] >= SM::row_size + SM::starting_index || key[1] < SM::starting_index || key[1] >= SM::column_size + SM::starting_index) + { throw std::runtime_error( "Index out of bounds: [" + std::to_string(key[0]) + ", " + std::to_string(key[1]) + "]"); + } sm(key[0], key[1]) = value; }) ; } // square matrix - if constexpr (NRows == NCols) + if constexpr (row_size == column_size) { py_sm - .def_static("identity", [](){ return SM::Identity(); }) - .def("trace", &SM::trace) - .def("transpose_in_place", &SM::transposeInPlace) + .def_static("identity", []() { return SM::Identity(); }) + .def("trace", [](SM & sm){ return sm.trace(); }) + .def("transpose_in_place", [](SM & sm){ return sm.transposeInPlace(); }) ; } + } - // vector - if constexpr (NRows == 1 || NCols == 1) - { - py_sm - .def("__getitem__", [](SM & sm, int key){ - if (key < SM::starting_index || key >= SM::column_size * SM::row_size + SM::starting_index) - throw std::runtime_error("Index out of bounds: " + std::to_string(key)); - return sm(key); - }) - .def("__setitem__", [](SM & sm, int key, T const value){ - if (key < SM::starting_index || key >= SM::column_size * SM::row_size + SM::starting_index) - throw std::runtime_error("Index out of bounds: " + std::to_string(key)); - sm(key) = value; - }) - ; - } else { - using SV = SmallMatrix; - using SRV = SmallMatrix; + template + void add_get_set_Vector (py::class_ &py_v) + { + using self = T_SV; + using T = typename T_SV::value_type; + using T_no_cv = std::remove_cv_t; - // operators for matrix-matrix & matrix-vector - py_sm - .def(py::self * py::self) - .def(py::self * SV()) - .def(SRV() * py::self) - ; - } + py_v + .def("__getitem__", [](self & sm, int key){ + if (key < self::starting_index || key >= self::column_size * self::row_size + self::starting_index) + throw std::runtime_error("Index out of bounds: " + std::to_string(key)); + return sm(key); + }) + .def("__setitem__", [](self & sm, int key, T_no_cv const value){ + if (key < self::starting_index || key >= self::column_size * self::row_size + self::starting_index) + throw std::runtime_error("Index out of bounds: " + std::to_string(key)); + sm(key) = value; + }) + ; + } +} + +namespace pyAMReX +{ + template< + class T, + int NRows, + int NCols, + amrex::Order ORDER = amrex::Order::F, + int StartIndex = 0 + > + void make_SmallMatrix (py::module &m, std::string typestr) + { + using namespace amrex; + + using SM = SmallMatrix; + using SV = SmallMatrix; + using SRV = SmallMatrix; + + py::class_ py_sm = make_SmallMatrix_or_Vector(m, typestr); + py::class_ py_sv = make_SmallMatrix_or_Vector(m, typestr); + py::class_ py_srv = make_SmallMatrix_or_Vector(m, typestr); + + // methods, getter, setter + add_matrix_methods(py_sm); + add_matrix_methods(py_sv); + add_matrix_methods(py_srv); + + // vector setter/getter + add_get_set_Vector(py_sv); + add_get_set_Vector(py_srv); + + // operators for matrix-matrix & matrix-vector + py_sm + .def(py::self * py::self) + .def(py::self * SV()) + .def(SRV() * py::self) + ; } } diff --git a/src/Base/SmallMatrix.cpp b/src/Base/SmallMatrix.cpp index cb8c886e..c33f373d 100644 --- a/src/Base/SmallMatrix.cpp +++ b/src/Base/SmallMatrix.cpp @@ -16,20 +16,9 @@ void init_SmallMatrix (py::module &m) { constexpr int NRows = 6; constexpr int NCols = 6; - constexpr Order ORDER = Order::F; + constexpr amrex::Order ORDER = amrex::Order::F; constexpr int StartIndex = 1; - // Vector - make_SmallMatrix< float, NRows, 1, ORDER, StartIndex >(m, "float"); - make_SmallMatrix< double, NRows, 1, ORDER, StartIndex >(m, "double"); - make_SmallMatrix< long double, NRows, 1, ORDER, StartIndex >(m, "longdouble"); - - // RowVector - make_SmallMatrix< float, 1, NCols, ORDER, StartIndex >(m, "float"); - make_SmallMatrix< double, 1, NCols, ORDER, StartIndex >(m, "double"); - make_SmallMatrix< long double, 1, NCols, ORDER, StartIndex >(m, "longdouble"); - - // Matrix make_SmallMatrix< float, NRows, NCols, ORDER, StartIndex >(m, "float"); make_SmallMatrix< double, NRows, NCols, ORDER, StartIndex >(m, "double"); make_SmallMatrix< long double, NRows, NCols, ORDER, StartIndex >(m, "longdouble");