diff --git a/src/python/apply.cpp b/src/python/apply.cpp index 3216e61e..f8d6ea39 100644 --- a/src/python/apply.cpp +++ b/src/python/apply.cpp @@ -595,11 +595,12 @@ struct recursion_guard { ~recursion_guard() { recursion_level--; } }; -void TraverseCallback::operator()(uint64_t) { } +uint64_t TraverseCallback::operator()(uint64_t) { return 0; } void TraverseCallback::traverse_unknown(nb::handle) { } /// Invoke the given callback on leaf elements of the pytree 'h' -void traverse(const char *op, TraverseCallback &tc, nb::handle h) { +void traverse(const char *op, TraverseCallback &tc, nb::handle h, + bool rw) { nb::handle tp = h.type(); recursion_guard guard; @@ -614,30 +615,32 @@ void traverse(const char *op, TraverseCallback &tc, nb::handle h) { len = s.len(inst_ptr(h)); for (Py_ssize_t i = 0; i < len; ++i) - traverse(op, tc, nb::steal(s.item(h.ptr(), i))); + traverse(op, tc, nb::steal(s.item(h.ptr(), i)), rw); } else { tc(h); } } else if (tp.is(&PyTuple_Type)) { for (nb::handle h2 : nb::borrow(h)) - traverse(op, tc, h2); + traverse(op, tc, h2, rw); } else if (tp.is(&PyList_Type)) { for (nb::handle h2 : nb::borrow(h)) - traverse(op, tc, h2); + traverse(op, tc, h2, rw); } else if (tp.is(&PyDict_Type)) { for (nb::handle h2 : nb::borrow(h).values()) - traverse(op, tc, h2); + traverse(op, tc, h2, rw); } else { if (nb::dict ds = get_drjit_struct(tp); ds.is_valid()) { for (auto [k, v] : ds) - traverse(op, tc, nb::getattr(h, k)); + traverse(op, tc, nb::getattr(h, k), rw); } else if (nb::object df = get_dataclass_fields(tp); df.is_valid()) { for (nb::handle field : df) { nb::object k = field.attr(DR_STR(name)); - traverse(op, tc, nb::getattr(h, k)); + traverse(op, tc, nb::getattr(h, k), rw); } - } else if (nb::object cb = get_traverse_cb_ro(tp); cb.is_valid()) { + } else if (nb::object cb = get_traverse_cb_ro(tp); cb.is_valid() && !rw) { cb(h, nb::cpp_function([&](uint64_t index) { tc(index); })); + } else if (nb::object cb = get_traverse_cb_rw(tp); cb.is_valid() && rw) { + cb(h, nb::cpp_function([&](uint64_t index) { return tc(index); })); } else { tc.traverse_unknown(h); } diff --git a/src/python/apply.h b/src/python/apply.h index 8e57f6cb..df9e0c4b 100644 --- a/src/python/apply.h +++ b/src/python/apply.h @@ -57,7 +57,7 @@ struct TraverseCallback { // Type-erased form which is needed in some cases to traverse into opaque // C++ code. This one just gets called with Jit/AD variable indices, an // associated Python/ instance/type is not available. - virtual void operator()(uint64_t index); + virtual uint64_t operator()(uint64_t index); // Traverse an unknown object virtual void traverse_unknown(nb::handle h); @@ -93,8 +93,8 @@ struct TransformPairCallback { }; /// Invoke the given callback on leaf elements of the pytree 'h' -extern void traverse(const char *op, TraverseCallback &callback, - nb::handle h); +extern void traverse(const char *op, TraverseCallback &callback, nb::handle h, + bool rw = false); /// Parallel traversal of two compatible pytrees 'h1' and 'h2' extern void traverse_pair(const char *op, TraversePairCallback &callback, diff --git a/src/python/detail.cpp b/src/python/detail.cpp index 60150ba3..f929c3fe 100644 --- a/src/python/detail.cpp +++ b/src/python/detail.cpp @@ -114,10 +114,11 @@ void collect_indices(nb::handle h, dr::vector &indices, bool inc_ref) operator()(index_fn(inst_ptr(h))); } - void operator()(uint64_t index) override { + uint64_t operator()(uint64_t index) override { if (inc_ref) ad_var_inc_ref(index); result.push_back(index); + return 0; } }; @@ -288,7 +289,7 @@ void traverse_py_cb_ro_impl(nb::handle self, nb::callable c) { if (index_fn) operator()(index_fn(inst_ptr(h))); } - void operator()(uint64_t index) override { m_callback(index); } + uint64_t operator()(uint64_t index) override { m_callback(index); return 0; } nb::callable m_callback; PyTraverseCallback(nb::callable c) : m_callback(c) {} @@ -304,11 +305,11 @@ void traverse_py_cb_ro_impl(nb::handle self, nb::callable c) { } void traverse_py_cb_rw_impl(nb::handle self, nb::callable c) { - struct PyTraverseCallback : TransformCallback { - void operator()(nb::handle h1, nb::handle h2) override { - const ArraySupplement &s = supp(h1.type()); + struct PyTraverseCallback : TraverseCallback { + void operator()(nb::handle h) override { + const ArraySupplement &s = supp(h.type()); if (s.index) - s.init_index(operator()(s.index(inst_ptr(h1))), inst_ptr(h2)); + s.reset_index(operator()(s.index(inst_ptr(h))), inst_ptr(h)); } uint64_t operator()(uint64_t index) override { return nb::cast(m_callback(index)); @@ -323,7 +324,7 @@ void traverse_py_cb_rw_impl(nb::handle self, nb::callable c) { auto dict = nb::borrow(nb::getattr(self, "__dict__")); for (auto value : dict.values()) { - transform("traverse_py_cb_rw", traverse_cb, value); + traverse("traverse_py_cb_rw", traverse_cb, value); } } diff --git a/src/python/eval.cpp b/src/python/eval.cpp index 751b1d68..098dac0e 100644 --- a/src/python/eval.cpp +++ b/src/python/eval.cpp @@ -57,12 +57,8 @@ static void make_opaque(nb::handle h) { if (rv) result = true; - if (index != index_new) { - nb::object tmp = nb::inst_alloc(tp); - s.init_index(index_new, inst_ptr(tmp)); - nb::inst_mark_ready(tmp); - nb::inst_replace_move(h, tmp); - } + if (index != index_new) + s.reset_index(index_new, inst_ptr(h)); ad_var_dec_ref(index_new); }