Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft of Python bindings for SmallMatrix objects #767

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,5 @@ target_sources(pyImpactX
ReferenceParticle.cpp
transformation.cpp
WakeConvolution.cpp
SmallMatrix.cpp
)
72 changes: 72 additions & 0 deletions src/python/SmallMatrix.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/* Copyright 2021-2023 The ImpactX Community
*
* Authors: Ryan Sandberg, Axel Huebl
* License: BSD-3-Clause-LBNL
*/
#include "pyImpactX.H"
#include <AMReX_SmallMatrix.H>

namespace py = pybind11;

namespace pybind11 {
namespace detail {

template <typename T, int NRows, int NCols, amrex::Order ORDER, int StartIndex>
struct pybind11::detail::type_caster<amrex::SmallMatrix<T, NRows, NCols, ORDER, StartIndex>> {
public:
PYBIND11_TYPE_CASTER(amrex::SmallMatrix<T, NRows, NCols, ORDER, StartIndex>,
_("SmallMatrix[") + py::detail::make_caster<T>::name() + _("]"));

// Conversion from Python to C++
bool load(handle src, bool) {
// Ensure we have a numpy array
py::array_t<T> arr = py::cast<py::array_t<T>>(src);
py::buffer_info buf = arr.request();

// Check dimensions and shape
if (buf.ndim != 2) {
throw std::runtime_error("SmallMatrix requires a 2D array.");
}
if (buf.shape[0] != NRows || buf.shape[1] != NCols) {
throw std::runtime_error("SmallMatrix array shape must match NRows x NCols.");
}

// Create a SmallMatrix and copy data
amrex::SmallMatrix<T, NRows, NCols, ORDER, StartIndex> mat;
T* ptr = static_cast<T*>(buf.ptr);
for (int i = 0; i < NRows * NCols; ++i) {
mat.m_mat[i] = ptr[i];
}

value = mat;
return true;
}

// Conversion from C++ to Python
static handle cast(const amrex::SmallMatrix<T, NRows, NCols, ORDER, StartIndex>& src,
return_value_policy /* policy */, handle /* parent */) {
py::array_t<T> arr({NRows, NCols});
py::buffer_info buf = arr.request();
T* ptr = static_cast<T*>(buf.ptr);
for (int i = 0; i < NRows * NCols; ++i) {
ptr[i] = src.m_mat[i];
}
return arr.release();
}
};

} // namespace detail
} // namespace pybind11


PYBIND11_MODULE(example, m) {
// You can now just bind constructors and methods normally without defining conversion code:
py::class_<amrex::SmallMatrix<double, 6, 6>>(m, "SmallMatrix6x6")
.def(py::init<>()) // Default init
.def("as_array", [](const amrex::SmallMatrix<double, 6, 6>& mat) {
return mat; // Will use type_caster to return a numpy array
});

// Now Python functions expecting a SmallMatrix<double,6,6> can pass a numpy array directly:
// def some_func(mat: SmallMatrix6x6): ...
}
1 change: 1 addition & 0 deletions src/python/pyImpactX.H
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include <pybind11/functional.h>
#include <pybind11/numpy.h>

#include <particles/elements/All.H>

Expand Down
Loading