Skip to content

Commit

Permalink
Changed traverse function, allowing assigment of c++ objects
Browse files Browse the repository at this point in the history
  • Loading branch information
DoeringChristian committed Dec 2, 2024
1 parent 33dfe2a commit 9f42220
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 25 deletions.
21 changes: 12 additions & 9 deletions src/python/apply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -595,11 +595,12 @@ struct recursion_guard {
~recursion_guard() { recursion_level--; }
};

void TraverseCallback::operator()(uint64_t) { }
uint64_t TraverseCallback::operator()(uint64_t) { return 0; }
void TraverseCallback::traverse_unknown(nb::handle) { }

/// Invoke the given callback on leaf elements of the pytree 'h'
void traverse(const char *op, TraverseCallback &tc, nb::handle h) {
void traverse(const char *op, TraverseCallback &tc, nb::handle h,
bool rw) {
nb::handle tp = h.type();
recursion_guard guard;

Expand All @@ -614,30 +615,32 @@ void traverse(const char *op, TraverseCallback &tc, nb::handle h) {
len = s.len(inst_ptr(h));

for (Py_ssize_t i = 0; i < len; ++i)
traverse(op, tc, nb::steal(s.item(h.ptr(), i)));
traverse(op, tc, nb::steal(s.item(h.ptr(), i)), rw);
} else {
tc(h);
}
} else if (tp.is(&PyTuple_Type)) {
for (nb::handle h2 : nb::borrow<nb::tuple>(h))
traverse(op, tc, h2);
traverse(op, tc, h2, rw);
} else if (tp.is(&PyList_Type)) {
for (nb::handle h2 : nb::borrow<nb::list>(h))
traverse(op, tc, h2);
traverse(op, tc, h2, rw);
} else if (tp.is(&PyDict_Type)) {
for (nb::handle h2 : nb::borrow<nb::dict>(h).values())
traverse(op, tc, h2);
traverse(op, tc, h2, rw);
} else {
if (nb::dict ds = get_drjit_struct(tp); ds.is_valid()) {
for (auto [k, v] : ds)
traverse(op, tc, nb::getattr(h, k));
traverse(op, tc, nb::getattr(h, k), rw);
} else if (nb::object df = get_dataclass_fields(tp); df.is_valid()) {
for (nb::handle field : df) {
nb::object k = field.attr(DR_STR(name));
traverse(op, tc, nb::getattr(h, k));
traverse(op, tc, nb::getattr(h, k), rw);
}
} else if (nb::object cb = get_traverse_cb_ro(tp); cb.is_valid()) {
} else if (nb::object cb = get_traverse_cb_ro(tp); cb.is_valid() && !rw) {
cb(h, nb::cpp_function([&](uint64_t index) { tc(index); }));
} else if (nb::object cb = get_traverse_cb_rw(tp); cb.is_valid() && rw) {
cb(h, nb::cpp_function([&](uint64_t index) { return tc(index); }));
} else {
tc.traverse_unknown(h);
}
Expand Down
6 changes: 3 additions & 3 deletions src/python/apply.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ struct TraverseCallback {
// Type-erased form which is needed in some cases to traverse into opaque
// C++ code. This one just gets called with Jit/AD variable indices, an
// associated Python/ instance/type is not available.
virtual void operator()(uint64_t index);
virtual uint64_t operator()(uint64_t index);

// Traverse an unknown object
virtual void traverse_unknown(nb::handle h);
Expand Down Expand Up @@ -93,8 +93,8 @@ struct TransformPairCallback {
};

/// Invoke the given callback on leaf elements of the pytree 'h'
extern void traverse(const char *op, TraverseCallback &callback,
nb::handle h);
extern void traverse(const char *op, TraverseCallback &callback, nb::handle h,
bool rw = false);

/// Parallel traversal of two compatible pytrees 'h1' and 'h2'
extern void traverse_pair(const char *op, TraversePairCallback &callback,
Expand Down
15 changes: 8 additions & 7 deletions src/python/detail.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,11 @@ void collect_indices(nb::handle h, dr::vector<uint64_t> &indices, bool inc_ref)
operator()(index_fn(inst_ptr(h)));
}

void operator()(uint64_t index) override {
uint64_t operator()(uint64_t index) override {
if (inc_ref)
ad_var_inc_ref(index);
result.push_back(index);
return 0;
}
};

Expand Down Expand Up @@ -288,7 +289,7 @@ void traverse_py_cb_ro_impl(nb::handle self, nb::callable c) {
if (index_fn)
operator()(index_fn(inst_ptr(h)));
}
void operator()(uint64_t index) override { m_callback(index); }
uint64_t operator()(uint64_t index) override { m_callback(index); return 0; }
nb::callable m_callback;

PyTraverseCallback(nb::callable c) : m_callback(c) {}
Expand All @@ -304,11 +305,11 @@ void traverse_py_cb_ro_impl(nb::handle self, nb::callable c) {
}

void traverse_py_cb_rw_impl(nb::handle self, nb::callable c) {
struct PyTraverseCallback : TransformCallback {
void operator()(nb::handle h1, nb::handle h2) override {
const ArraySupplement &s = supp(h1.type());
struct PyTraverseCallback : TraverseCallback {
void operator()(nb::handle h) override {
const ArraySupplement &s = supp(h.type());
if (s.index)
s.init_index(operator()(s.index(inst_ptr(h1))), inst_ptr(h2));
s.reset_index(operator()(s.index(inst_ptr(h))), inst_ptr(h));
}
uint64_t operator()(uint64_t index) override {
return nb::cast<uint64_t>(m_callback(index));
Expand All @@ -323,7 +324,7 @@ void traverse_py_cb_rw_impl(nb::handle self, nb::callable c) {
auto dict = nb::borrow<nb::dict>(nb::getattr(self, "__dict__"));

for (auto value : dict.values()) {
transform("traverse_py_cb_rw", traverse_cb, value);
traverse("traverse_py_cb_rw", traverse_cb, value);
}
}

Expand Down
8 changes: 2 additions & 6 deletions src/python/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,8 @@ static void make_opaque(nb::handle h) {
if (rv)
result = true;

if (index != index_new) {
nb::object tmp = nb::inst_alloc(tp);
s.init_index(index_new, inst_ptr(tmp));
nb::inst_mark_ready(tmp);
nb::inst_replace_move(h, tmp);
}
if (index != index_new)
s.reset_index(index_new, inst_ptr(h));

ad_var_dec_ref(index_new);
}
Expand Down

0 comments on commit 9f42220

Please sign in to comment.