Skip to content

Commit

Permalink
Refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
ax3l committed Jan 6, 2025
1 parent b23d94b commit d615266
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 65 deletions.
154 changes: 101 additions & 53 deletions src/Base/SmallMatrix.H
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,6 @@ namespace
get_value_type_t<T>
>;
}
}

namespace pyAMReX
{
using namespace amrex;

/** CPU: __array_interface__ v3
*
Expand All @@ -58,12 +53,14 @@ namespace pyAMReX
class T,
int NRows,
int NCols,
Order ORDER = Order::F,
amrex::Order ORDER = amrex::Order::F,
int StartIndex = 0
>
py::dict
array_interface (SmallMatrix<T, NRows, NCols, ORDER, StartIndex> const & m)
array_interface (amrex::SmallMatrix<T, NRows, NCols, ORDER, StartIndex> const & m)
{
using namespace amrex;

auto d = py::dict();
// provide C index order for shape and strides
auto shape = m.ordering == Order::F ? py::make_tuple(
Expand Down Expand Up @@ -110,27 +107,26 @@ namespace pyAMReX
return d;
}

template<
class T,
int NRows,
int NCols,
Order ORDER = Order::F,
int StartIndex = 0
>
void make_SmallMatrix(py::module &m, std::string typestr)
template<class SM>
py::class_<SM>
make_SmallMatrix_or_Vector (py::module &m, std::string typestr)
{
using namespace amrex;

using T = typename SM::value_type;
using T_no_cv = std::remove_cv_t<T>;
static constexpr int row_size = SM::row_size;
static constexpr int column_size = SM::column_size;
static constexpr Order ordering = SM::ordering;
static constexpr int starting_index = SM::starting_index;

// dispatch simpler via: py::format_descriptor<T>::format() naming
// but note the _const suffix that might be needed
auto const sm_name = std::string("SmallMatrix_")
.append(std::to_string(NRows)).append("x").append(std::to_string(NCols))
.append("_").append(ORDER == Order::F ? "F" : "C")
.append("_SI").append(std::to_string(StartIndex))
.append(std::to_string(row_size)).append("x").append(std::to_string(column_size))
.append("_").append(ordering == Order::F ? "F" : "C")
.append("_SI").append(std::to_string(starting_index))
.append("_").append(typestr);
using SM = SmallMatrix<T, NRows, NCols, ORDER, StartIndex>;
py::class_< SM > py_sm(m, sm_name.c_str());
py_sm
.def("__repr__",
Expand Down Expand Up @@ -177,7 +173,7 @@ namespace pyAMReX
py::format_descriptor<T_no_cv>::format() +
"' and received '" + buf.format + "'!");

// TODO: check that strides are either exact or None in buf
// TODO: check that strides are either exact or None in buf (e.g., F or C contiguous)
// TODO: transpose if SM order is not C?

auto sm = std::make_unique< SM >();
Expand All @@ -197,7 +193,7 @@ namespace pyAMReX
// CPU: __array_interface__ v3
// https://numpy.org/doc/stable/reference/arrays.interface.html
.def_property_readonly("__array_interface__", [](SM const & sm) {
return pyAMReX::array_interface(sm);
return array_interface(sm);
})

// CPU: __array_function__ interface (TODO)
Expand All @@ -210,8 +206,9 @@ namespace pyAMReX

// Nvidia GPUs: __cuda_array_interface__ v3
// https://numba.readthedocs.io/en/latest/cuda/cuda_array_interface.html
.def_property_readonly("__cuda_array_interface__", [](SM const & sm) {
auto d = pyAMReX::array_interface(sm);
.def_property_readonly("__cuda_array_interface__", [](SM const & sm)
{
auto d = array_interface(sm);

// data:
// Because the user of the interface may or may not be in the same context, the most common case is to use cuPointerGetAttribute with CU_POINTER_ATTRIBUTE_DEVICE_POINTER in the CUDA driver API (or the equivalent CUDA Runtime API) to retrieve a device pointer that is usable in the currently active context.
Expand Down Expand Up @@ -239,6 +236,21 @@ namespace pyAMReX
// https://github.com/dmlc/dlpack/blob/master/include/dlpack/dlpack.h
// https://docs.cupy.dev/en/stable/user_guide/interoperability.html#dlpack-data-exchange-protocol

;

return py_sm;
}

template<class SM>
void add_matrix_methods (py::class_<SM> & py_sm)
{
using T = typename SM::value_type;
using T_no_cv = std::remove_cv_t<T>;
static constexpr int row_size = SM::row_size;
static constexpr int column_size = SM::column_size;
static constexpr int starting_index = SM::starting_index;

py_sm
.def("dot", &SM::dot)
.def("prod", &SM::product) // NumPy name
.def("set_val", &SM::setVal)
Expand All @@ -254,66 +266,102 @@ namespace pyAMReX

// getter
.def("__getitem__", [](SM & sm, std::array<int, 2> const & key){
if (key[0] < SM::starting_index || key[0] >= SM::row_size + SM::starting_index ||
key[1] < SM::starting_index || key[1] >= SM::column_size + SM::starting_index)
if (key[0] < starting_index || key[0] >= row_size + starting_index ||
key[1] < starting_index || key[1] >= column_size + starting_index)
throw std::runtime_error(
"Index out of bounds: [" +
std::to_string(key[0]) + ", " +
std::to_string(key[1]) + "]");
return sm(key[0], key[1]);
})
;

// setter
if constexpr (is_not_const<T>())
{
py_sm
.def("__setitem__", [](SM & sm, std::array<int, 2> const & key, T const value){
.def("__setitem__", [](SM & sm, std::array<int, 2> const & key, T_no_cv const value){
if (key[0] < SM::starting_index || key[0] >= SM::row_size + SM::starting_index ||
key[1] < SM::starting_index || key[1] >= SM::column_size + SM::starting_index)
{
throw std::runtime_error(
"Index out of bounds: [" +
std::to_string(key[0]) + ", " +
std::to_string(key[1]) + "]");
}
sm(key[0], key[1]) = value;
})
;
}

// square matrix
if constexpr (NRows == NCols)
if constexpr (row_size == column_size)
{
py_sm
.def_static("identity", [](){ return SM::Identity(); })
.def_static("identity", []() { return SM::Identity(); })
.def("trace", &SM::trace)
.def("transpose_in_place", &SM::transposeInPlace)
;
}
}

// vector
if constexpr (NRows == 1 || NCols == 1)
{
py_sm
.def("__getitem__", [](SM & sm, int key){
if (key < SM::starting_index || key >= SM::column_size * SM::row_size + SM::starting_index)
throw std::runtime_error("Index out of bounds: " + std::to_string(key));
return sm(key);
})
.def("__setitem__", [](SM & sm, int key, T const value){
if (key < SM::starting_index || key >= SM::column_size * SM::row_size + SM::starting_index)
throw std::runtime_error("Index out of bounds: " + std::to_string(key));
sm(key) = value;
})
;
} else {
using SV = SmallMatrix<T, NRows, 1, Order::F, StartIndex>;
using SRV = SmallMatrix<T, 1, NCols, Order::F, StartIndex>;
template<class T_SV>
void add_get_set_Vector (py::class_<T_SV> &py_v)
{
using self = T_SV;
using T = typename T_SV::value_type;
using T_no_cv = std::remove_cv_t<T>;

// operators for matrix-matrix & matrix-vector
py_sm
.def(py::self * py::self)
.def(py::self * SV())
.def(SRV() * py::self)
;
}
py_v
.def("__getitem__", [](self & sm, int key){
if (key < self::starting_index || key >= self::column_size * self::row_size + self::starting_index)
throw std::runtime_error("Index out of bounds: " + std::to_string(key));
return sm(key);
})
.def("__setitem__", [](self & sm, int key, T_no_cv const value){
if (key < self::starting_index || key >= self::column_size * self::row_size + self::starting_index)
throw std::runtime_error("Index out of bounds: " + std::to_string(key));
sm(key) = value;
})
;
}
}

namespace pyAMReX
{
template<
class T,
int NRows,
int NCols,
amrex::Order ORDER = amrex::Order::F,
int StartIndex = 0
>
void make_SmallMatrix (py::module &m, std::string typestr)
{
using namespace amrex;

using SM = SmallMatrix<T, NRows, NCols, ORDER, StartIndex>;
using SV = SmallMatrix<T, NRows, 1, Order::F, StartIndex>;
using SRV = SmallMatrix<T, 1, NCols, Order::F, StartIndex>;

py::class_<SM> py_sm = make_SmallMatrix_or_Vector<SM>(m, typestr);
py::class_<SV> py_sv = make_SmallMatrix_or_Vector<SV>(m, typestr);
py::class_<SRV> py_srv = make_SmallMatrix_or_Vector<SRV>(m, typestr);

// methods, getter, setter
add_matrix_methods(py_sm);
add_matrix_methods(py_sv);
add_matrix_methods(py_srv);

// vector setter/getter
add_get_set_Vector(py_sv);
add_get_set_Vector(py_srv);

// operators for matrix-matrix & matrix-vector
py_sm
.def(py::self * py::self)
.def(py::self * SV())
.def(SRV() * py::self)
;
}
}
13 changes: 1 addition & 12 deletions src/Base/SmallMatrix.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,9 @@ void init_SmallMatrix (py::module &m)
{
constexpr int NRows = 6;
constexpr int NCols = 6;
constexpr Order ORDER = Order::F;
constexpr amrex::Order ORDER = amrex::Order::F;
constexpr int StartIndex = 1;

// Vector
make_SmallMatrix< float, NRows, 1, ORDER, StartIndex >(m, "float");
make_SmallMatrix< double, NRows, 1, ORDER, StartIndex >(m, "double");
make_SmallMatrix< long double, NRows, 1, ORDER, StartIndex >(m, "longdouble");

// RowVector
make_SmallMatrix< float, 1, NCols, ORDER, StartIndex >(m, "float");
make_SmallMatrix< double, 1, NCols, ORDER, StartIndex >(m, "double");
make_SmallMatrix< long double, 1, NCols, ORDER, StartIndex >(m, "longdouble");

// Matrix
make_SmallMatrix< float, NRows, NCols, ORDER, StartIndex >(m, "float");
make_SmallMatrix< double, NRows, NCols, ORDER, StartIndex >(m, "double");
make_SmallMatrix< long double, NRows, NCols, ORDER, StartIndex >(m, "longdouble");
Expand Down

0 comments on commit d615266

Please sign in to comment.