diff --git a/tests/custom_type_ext.cpp b/tests/custom_type_ext.cpp index 50c936ee..ee60cf6f 100644 --- a/tests/custom_type_ext.cpp +++ b/tests/custom_type_ext.cpp @@ -1,6 +1,9 @@ #include #include #include +#include +#include +#include namespace nb = nanobind; namespace dr = drjit; @@ -42,6 +45,51 @@ struct CustomHolder { Value m_value; }; +class Object : public drjit::TraversableBase { + DR_TRAVERSE_CB(drjit::TraversableBase); +}; + +template +class CustomBase : public Object{ +public: + CustomBase() : Object() {} + + virtual Value &value() { + jit_raise("test"); + }; + + DR_TRAVERSE_CB(Object); +}; + +template +class PyCustomBase : public CustomBase{ +public: + using Base = CustomBase; + NB_TRAMPOLINE(Base, 1); + + PyCustomBase() : Base() {} + + Value &value() override { NB_OVERRIDE_PURE(value); } + + DR_TRAMPOLINE_TRAVERSE_CB(Base); +}; + +template +class CustomA: public CustomBase{ +public: + using Base = CustomBase; + + CustomA() {} + CustomA(const Value &v) : m_value(v) {} + + Value &value() override { return m_value; } + +private: + Value m_value; + + DR_TRAVERSE_CB(Base, m_value); +}; + template void bind(nb::module_ &m) { dr::ArrayBinding b; @@ -64,12 +112,42 @@ template void bind(nb::module_ &m) { .def(nb::init()) .def("value", &CustomFloatHolder::value, nanobind::rv_policy::reference); + using CustomBase = CustomBase; + using PyCustomBase = PyCustomBase; + using CustomA = CustomA; + + auto object = nb::class_( + m, "Object", + nb::intrusive_ptr( + [](Object *o, PyObject *po) noexcept { o->set_self_py(po); })); + + auto base = nb::class_(m, "CustomBase") + .def(nb::init()) + .def("value", nb::overload_cast<>(&CustomBase::value)); + jit_log(LogLevel::Debug, "binding base"); + + drjit::bind_traverse(base); + + auto a = nb::class_(m, "CustomA").def(nb::init()); + + drjit::bind_traverse(a); + m.def("cpp_make_opaque", [](CustomFloatHolder &holder) { dr::make_opaque(holder); } ); } NB_MODULE(custom_type_ext, m) { + nb::intrusive_init( + [](PyObject *o) noexcept { + nb::gil_scoped_acquire guard; + Py_INCREF(o); + }, + [](PyObject *o) noexcept { + nb::gil_scoped_acquire guard; + Py_DECREF(o); + }); + #if defined(DRJIT_ENABLE_LLVM) nb::module_ llvm = m.def_submodule("llvm"); bind(llvm); diff --git a/tests/test_custom_type_ext.py b/tests/test_custom_type_ext.py index 90c9b7a2..4ad4c8d6 100644 --- a/tests/test_custom_type_ext.py +++ b/tests/test_custom_type_ext.py @@ -1,6 +1,7 @@ import drjit as dr import pytest +dr.set_log_level(dr.LogLevel.Info) def get_pkg(t): with dr.detail.scoped_rtld_deepbind(): @@ -69,3 +70,65 @@ def test03_cpp_make_opaque(t): pkg.cpp_make_opaque(holder) assert holder.value().state == dr.VarState.Evaluated + + +@pytest.test_arrays("float32,-diff,shape=(*),jit") +def test04_traverse_opaque(t): + # Tests that it is possible to traverse an opaque C++ object + pkg = get_pkg(t) + print(f"{dir(pkg)=}") + Float = t + + v = dr.arange(Float, 10) + + a = pkg.CustomA(v) + assert dr.detail.collect_indices(a) == [v.index] + + +@pytest.test_arrays("float32,-diff,shape=(*),jit") +def test05_traverse_py(t): + # Tests the implementation of `%raverse_py_cb_ro`, + # used for traversal of PyTrees inside of C++ objects + Float = t + + v = dr.arange(Float, 10) + + class PyClass: + def __init__(self, v) -> None: + self.v = v + + c = PyClass(v) + + result = [] + + def callback(index): + result.append(index) + + dr.detail.traverse_py_cb_ro(c, callback) + + assert result == [v.index] + + +@pytest.test_arrays("float32,-diff,shape=(*),jit") +def test06_trampoline_traversal(t): + # Tests that classes inhereting from trampoline classes are traversed + # automatically + pkg = get_pkg(t) + print(f"{dir(pkg)=}") + Float = t + + v = dr.opaque(Float, 0, 3) + + class B(pkg.CustomBase): + def __init__(self, v) -> None: + super().__init__() + self.v = v + + def value(self): + return self.v + + b = B(v) + + b.value() + + assert dr.detail.collect_indices(b) == [v.index]