Skip to content

Commit

Permalink
Added traversal tests
Browse files Browse the repository at this point in the history
  • Loading branch information
DoeringChristian committed Dec 2, 2024
1 parent 4efa918 commit 4a7a773
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 0 deletions.
78 changes: 78 additions & 0 deletions tests/custom_type_ext.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
#include <drjit/python.h>
#include <drjit/autodiff.h>
#include <drjit/packet.h>
#include <drjit/traversable_base.h>
#include <nanobind/nanobind.h>
#include <nanobind/trampoline.h>

namespace nb = nanobind;
namespace dr = drjit;
Expand Down Expand Up @@ -42,6 +45,51 @@ struct CustomHolder {
Value m_value;
};

class Object : public drjit::TraversableBase {
DR_TRAVERSE_CB(drjit::TraversableBase);
};

template <typename Value>
class CustomBase : public Object{
public:
CustomBase() : Object() {}

virtual Value &value() {
jit_raise("test");
};

DR_TRAVERSE_CB(Object);
};

template <typename Value>
class PyCustomBase : public CustomBase<Value>{
public:
using Base = CustomBase<Value>;
NB_TRAMPOLINE(Base, 1);

PyCustomBase() : Base() {}

Value &value() override { NB_OVERRIDE_PURE(value); }

DR_TRAMPOLINE_TRAVERSE_CB(Base);
};

template <typename Value>
class CustomA: public CustomBase<Value>{
public:
using Base = CustomBase<Value>;

CustomA() {}
CustomA(const Value &v) : m_value(v) {}

Value &value() override { return m_value; }

private:
Value m_value;

DR_TRAVERSE_CB(Base, m_value);
};


template <JitBackend Backend> void bind(nb::module_ &m) {
dr::ArrayBinding b;
Expand All @@ -64,12 +112,42 @@ template <JitBackend Backend> void bind(nb::module_ &m) {
.def(nb::init<Float>())
.def("value", &CustomFloatHolder::value, nanobind::rv_policy::reference);

using CustomBase = CustomBase<Float>;
using PyCustomBase = PyCustomBase<Float>;
using CustomA = CustomA<Float>;

auto object = nb::class_<Object>(
m, "Object",
nb::intrusive_ptr<Object>(
[](Object *o, PyObject *po) noexcept { o->set_self_py(po); }));

auto base = nb::class_<CustomBase, Object, PyCustomBase>(m, "CustomBase")
.def(nb::init())
.def("value", nb::overload_cast<>(&CustomBase::value));
jit_log(LogLevel::Debug, "binding base");

drjit::bind_traverse(base);

auto a = nb::class_<CustomA>(m, "CustomA").def(nb::init<Float>());

drjit::bind_traverse(a);

m.def("cpp_make_opaque",
[](CustomFloatHolder &holder) { dr::make_opaque(holder); }
);
}

NB_MODULE(custom_type_ext, m) {
nb::intrusive_init(
[](PyObject *o) noexcept {
nb::gil_scoped_acquire guard;
Py_INCREF(o);
},
[](PyObject *o) noexcept {
nb::gil_scoped_acquire guard;
Py_DECREF(o);
});

#if defined(DRJIT_ENABLE_LLVM)
nb::module_ llvm = m.def_submodule("llvm");
bind<JitBackend::LLVM>(llvm);
Expand Down
63 changes: 63 additions & 0 deletions tests/test_custom_type_ext.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import drjit as dr
import pytest

dr.set_log_level(dr.LogLevel.Info)

def get_pkg(t):
with dr.detail.scoped_rtld_deepbind():
Expand Down Expand Up @@ -69,3 +70,65 @@ def test03_cpp_make_opaque(t):

pkg.cpp_make_opaque(holder)
assert holder.value().state == dr.VarState.Evaluated


@pytest.test_arrays("float32,-diff,shape=(*),jit")
def test04_traverse_opaque(t):
# Tests that it is possible to traverse an opaque C++ object
pkg = get_pkg(t)
print(f"{dir(pkg)=}")
Float = t

v = dr.arange(Float, 10)

a = pkg.CustomA(v)
assert dr.detail.collect_indices(a) == [v.index]


@pytest.test_arrays("float32,-diff,shape=(*),jit")
def test05_traverse_py(t):
# Tests the implementation of `%raverse_py_cb_ro`,
# used for traversal of PyTrees inside of C++ objects
Float = t

v = dr.arange(Float, 10)

class PyClass:
def __init__(self, v) -> None:
self.v = v

c = PyClass(v)

result = []

def callback(index):
result.append(index)

dr.detail.traverse_py_cb_ro(c, callback)

assert result == [v.index]


@pytest.test_arrays("float32,-diff,shape=(*),jit")
def test06_trampoline_traversal(t):
# Tests that classes inhereting from trampoline classes are traversed
# automatically
pkg = get_pkg(t)
print(f"{dir(pkg)=}")
Float = t

v = dr.opaque(Float, 0, 3)

class B(pkg.CustomBase):
def __init__(self, v) -> None:
super().__init__()
self.v = v

def value(self):
return self.v

b = B(v)

b.value()

assert dr.detail.collect_indices(b) == [v.index]

0 comments on commit 4a7a773

Please sign in to comment.