Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TraversableBase intrusive interface #288

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions include/drjit/array_traverse.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

#pragma once

#include <type_traits>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#include <type_traits>
#include <type_traits>

#define DRJIT_STRUCT_NODEF(Name, ...) \
Name(const Name &) = default; \
Name(Name &&) = default; \
Expand Down Expand Up @@ -140,6 +141,18 @@ namespace detail {
using det_traverse_1_cb_rw =
decltype(T(nullptr)->traverse_1_cb_rw(nullptr, nullptr));

template <typename T>
using det_get = decltype(std::declval<T&>().get());

template <typename T>
using det_const_get = decltype(std::declval<const T &>().get());

template<typename T>
using det_begin = decltype(std::declval<T &>().begin());

template<typename T>
using det_end = decltype(std::declval<T &>().begin());

inline drjit::string get_label(const char *s, size_t i) {
auto skip = [](char c) {
return c == ' ' || c == '\r' || c == '\n' || c == '\t' || c == ',';
Expand Down Expand Up @@ -198,6 +211,19 @@ void traverse_1_fn_ro(const Value &value, void *payload, void (*fn)(void *, uint
is_detected_v<detail::det_traverse_1_cb_ro, Value>) {
if (value)
value->traverse_1_cb_ro(payload, fn);

} else if constexpr (is_detected_v<detail::det_begin, Value> &&
is_detected_v<detail::det_end, Value>) {
for (auto elem : value) {
traverse_1_fn_ro(elem, payload, fn);
}
} else if constexpr (is_detected_v<detail::det_const_get, Value>) {
const auto *tmp = value.get();
traverse_1_fn_ro(tmp, payload, fn);
} else if constexpr (is_detected_v<detail::det_traverse_1_cb_ro, Value *>) {
value.traverse_1_cb_ro(payload, fn);
} else {
// static_assert(false, "Failed to traverse field!");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the user be warned about this, or it's always fine to ignore?
In the latter case, you could remove the else block and add a comment to say explicitly that we silently return if we don't know how to traverse.

}
}

Expand All @@ -220,6 +246,18 @@ void traverse_1_fn_rw(Value &value, void *payload, uint64_t (*fn)(void *, uint64
is_detected_v<detail::det_traverse_1_cb_rw, Value>) {
if (value)
value->traverse_1_cb_rw(payload, fn);
} else if constexpr (is_detected_v<detail::det_begin, Value> &&
is_detected_v<detail::det_end, Value>) {
for (auto elem : value) {
traverse_1_fn_rw(elem, payload, fn);
}
} else if constexpr (is_detected_v<detail::det_get, Value>) {
auto *tmp = value.get();
traverse_1_fn_rw(tmp, payload, fn);
} else if constexpr (is_detected_v<detail::det_traverse_1_cb_rw, Value *>) {
value.traverse_1_cb_rw(payload, fn);
} else {
// static_assert(false, "Failed to traverse field!");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above

}
}

Expand Down
34 changes: 34 additions & 0 deletions include/drjit/python.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,9 @@
#include <drjit/math.h>
#include <drjit-core/python.h>
#include <nanobind/stl/array.h>
#include <nanobind/intrusive/counter.h>
#include "nanobind/nanobind.h"
#include "drjit/traversable_base.h"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sort the imports


NAMESPACE_BEGIN(drjit)
struct ArrayBinding;
Expand Down Expand Up @@ -1076,4 +1079,35 @@ template <typename T, typename... Args> auto& bind_traverse(nanobind::class_<T,
return cls;
}

inline void traverse_py_cb_ro(const TraversableBase *base, void *payload,
void (*fn)(void *, uint64_t)) {
namespace nb = nanobind;
nb::handle self = base->self_py();
if (!self)
return;

auto detail = nb::module_::import_("drjit.detail");
nb::callable traverse_py_cb_ro =
nb::borrow<nb::callable>(nb::getattr(detail, "traverse_py_cb_ro"));

traverse_py_cb_ro(
self, nb::cpp_function([&](uint64_t index) { fn(payload, index); }));
}
inline void traverse_py_cb_rw(TraversableBase *base, void *payload,
uint64_t (*fn)(void *, uint64_t)) {

namespace nb = nanobind;
nb::handle self = base->self_py();
if (!self)
return;

auto detail = nb::module_::import_("drjit.detail");
nb::callable traverse_py_cb_rw =
nb::borrow<nb::callable>(nb::getattr(detail, "traverse_py_cb_rw"));

traverse_py_cb_rw(self, nb::cpp_function([&](uint64_t index) {
return fn(payload, index);
}));
}

NAMESPACE_END(drjit)
6 changes: 5 additions & 1 deletion include/drjit/texture.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <drjit/idiv.h>
#include <drjit/jit.h>
#include <drjit/tensor.h>
#include "drjit/traversable_base.h"

#pragma once

Expand All @@ -41,7 +42,7 @@ enum class CudaTextureFormat : uint32_t {
Float16 = 1, /// Half precision storage format
};

template <typename _Storage, size_t Dimension> class Texture {
template <typename _Storage, size_t Dimension> class Texture : TraversableBase {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there no other DrJit class that needed to be made traversable?

public:
static constexpr bool IsCUDA = is_cuda_v<_Storage>;
static constexpr bool IsDiff = is_diff_v<_Storage>;
Expand Down Expand Up @@ -1386,6 +1387,9 @@ template <typename _Storage, size_t Dimension> class Texture {
WrapMode m_wrap_mode;
bool m_use_accel = false;
mutable bool m_migrated = false;

DR_TRAVERSE_CB(drjit::TraversableBase, m_value, m_shape_opaque,
m_inv_resolution);
};

NAMESPACE_END(drjit)
79 changes: 79 additions & 0 deletions include/drjit/traversable_base.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
#pragma once

#include "array_traverse.h"
#include "drjit-core/macros.h"
#include "nanobind/intrusive/counter.h"
#include "nanobind/intrusive/ref.h"
#include <drjit-core/jit.h>
#include <drjit/map.h>

NAMESPACE_BEGIN(drjit)

/// Interface for traversing C++ objects.
struct TraversableBase : nanobind::intrusive_base {
virtual void traverse_1_cb_ro(void *, void (*)(void *, uint64_t)) const = 0;
virtual void traverse_1_cb_rw(void *, uint64_t (*)(void *, uint64_t)) = 0;
};

/// Macro for generating call to traverse_1_fn_ro for a class member
#define DR_TRAVERSE_MEMBER_RO(member) \
drjit::log_member_open(false, #member); \
drjit::traverse_1_fn_ro(member, payload, fn); \
drjit::log_member_close();
/// Macro for generating call to traverse_1_fn_rw for a class member
#define DR_TRAVERSE_MEMBER_RW(member) \
drjit::log_member_open(true, #member); \
drjit::traverse_1_fn_rw(member, payload, fn); \
drjit::log_member_close();

inline void log_member_open(bool rw, const char *member) {
jit_log(LogLevel::Debug, "%s%s{", rw ? "rw " : "ro ", member);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could guard the logging with a define at the top of the file, something like:

// #define DR_LOG_TRAVERSALS

...

#if defined(DR_LOG_TRAVERSALS)
    jit_log(...)
#endif

}

inline void log_member_close() { jit_log(LogLevel::Debug, "}"); }

#define DR_TRAVERSE_CB_RO(Base, ...) \
void traverse_1_cb_ro(void *payload, void (*fn)(void *, uint64_t)) \
const override { \
if constexpr (!std::is_same_v<Base, drjit::TraversableBase>) \
Base::traverse_1_cb_ro(payload, fn); \
DRJIT_MAP(DR_TRAVERSE_MEMBER_RO, __VA_ARGS__) \
}

#define DR_TRAVERSE_CB_RW(Base, ...) \
void traverse_1_cb_rw(void *payload, uint64_t (*fn)(void *, uint64_t)) \
override { \
if constexpr (!std::is_same_v<Base, drjit::TraversableBase>) \
Base::traverse_1_cb_rw(payload, fn); \
DRJIT_MAP(DR_TRAVERSE_MEMBER_RW, __VA_ARGS__) \
}

/// Macro to generate traverse_1_cb_ro and traverse_1_cb_rw methods for each
/// member in the list.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Expand the docstring to explain to users what they should do with this, which fields should be included or now, etc.
Add a warning to tell them what can happen if they fail to include a member in the list.

Also this information should be repeated in the frozen functions documentation, since it's a dangerous pitfall for users who have custom plugins / classes implemented.

#define DR_TRAVERSE_CB(Base, ...) \
public: \
DR_TRAVERSE_CB_RO(Base, __VA_ARGS__) \
DR_TRAVERSE_CB_RW(Base, __VA_ARGS__)

#define DR_TRAMPOLINE_TRAVERSE_CB(Base) \
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing the docstring.

public: \
void traverse_1_cb_ro(void *payload, void (*fn)(void *, uint64_t)) \
const override { \
if constexpr (!std ::is_same_v<Base, drjit ::TraversableBase>) \
Base ::traverse_1_cb_ro(payload, fn); \
drjit::traverse_py_cb_ro(this, payload, fn); \
} \
void traverse_1_cb_rw(void *payload, uint64_t (*fn)(void *, uint64_t)) \
override { \
if constexpr (!std ::is_same_v<Base, drjit ::TraversableBase>) \
Base ::traverse_1_cb_rw(payload, fn); \
drjit::traverse_py_cb_rw(this, payload, fn); \
}

#if defined(_MSC_VER)
#define DRJIT_EXPORT __declspec(dllexport)
#else
#define DRJIT_EXPORT __attribute__((visibility("default")))
#endif
Comment on lines +73 to +77
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this already defined somewhere else?


NAMESPACE_END(drjit)
21 changes: 12 additions & 9 deletions src/python/apply.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -595,11 +595,12 @@ struct recursion_guard {
~recursion_guard() { recursion_level--; }
};

void TraverseCallback::operator()(uint64_t) { }
uint64_t TraverseCallback::operator()(uint64_t) { return 0; }
void TraverseCallback::traverse_unknown(nb::handle) { }

/// Invoke the given callback on leaf elements of the pytree 'h'
void traverse(const char *op, TraverseCallback &tc, nb::handle h) {
void traverse(const char *op, TraverseCallback &tc, nb::handle h,
bool rw) {
nb::handle tp = h.type();
recursion_guard guard;

Expand All @@ -614,30 +615,32 @@ void traverse(const char *op, TraverseCallback &tc, nb::handle h) {
len = s.len(inst_ptr(h));

for (Py_ssize_t i = 0; i < len; ++i)
traverse(op, tc, nb::steal(s.item(h.ptr(), i)));
traverse(op, tc, nb::steal(s.item(h.ptr(), i)), rw);
} else {
tc(h);
}
} else if (tp.is(&PyTuple_Type)) {
for (nb::handle h2 : nb::borrow<nb::tuple>(h))
traverse(op, tc, h2);
traverse(op, tc, h2, rw);
} else if (tp.is(&PyList_Type)) {
for (nb::handle h2 : nb::borrow<nb::list>(h))
traverse(op, tc, h2);
traverse(op, tc, h2, rw);
} else if (tp.is(&PyDict_Type)) {
for (nb::handle h2 : nb::borrow<nb::dict>(h).values())
traverse(op, tc, h2);
traverse(op, tc, h2, rw);
} else {
if (nb::dict ds = get_drjit_struct(tp); ds.is_valid()) {
for (auto [k, v] : ds)
traverse(op, tc, nb::getattr(h, k));
traverse(op, tc, nb::getattr(h, k), rw);
} else if (nb::object df = get_dataclass_fields(tp); df.is_valid()) {
for (nb::handle field : df) {
nb::object k = field.attr(DR_STR(name));
traverse(op, tc, nb::getattr(h, k));
traverse(op, tc, nb::getattr(h, k), rw);
}
} else if (nb::object cb = get_traverse_cb_ro(tp); cb.is_valid()) {
} else if (nb::object cb = get_traverse_cb_ro(tp); cb.is_valid() && !rw) {
cb(h, nb::cpp_function([&](uint64_t index) { tc(index); }));
} else if (nb::object cb = get_traverse_cb_rw(tp); cb.is_valid() && rw) {
cb(h, nb::cpp_function([&](uint64_t index) { return tc(index); }));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the user passed rw = true but only get_traverse_cb_ro(tp) is valid, not get_traverse_cb_rw(tp), is that an error? Should we throw?

} else {
tc.traverse_unknown(h);
}
Expand Down
6 changes: 3 additions & 3 deletions src/python/apply.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ struct TraverseCallback {
// Type-erased form which is needed in some cases to traverse into opaque
// C++ code. This one just gets called with Jit/AD variable indices, an
// associated Python/ instance/type is not available.
virtual void operator()(uint64_t index);
virtual uint64_t operator()(uint64_t index);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Document the return value.


// Traverse an unknown object
virtual void traverse_unknown(nb::handle h);
Expand Down Expand Up @@ -93,8 +93,8 @@ struct TransformPairCallback {
};

/// Invoke the given callback on leaf elements of the pytree 'h'
extern void traverse(const char *op, TraverseCallback &callback,
nb::handle h);
extern void traverse(const char *op, TraverseCallback &callback, nb::handle h,
bool rw = false);
Comment on lines +96 to +97
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Document the rw parameter, explain why it's needed and what the user should pass.


/// Parallel traversal of two compatible pytrees 'h1' and 'h2'
extern void traverse_pair(const char *op, TraversePairCallback &callback,
Expand Down
50 changes: 49 additions & 1 deletion src/python/detail.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,11 @@ void collect_indices(nb::handle h, dr::vector<uint64_t> &indices, bool inc_ref)
operator()(index_fn(inst_ptr(h)));
}

void operator()(uint64_t index) override {
uint64_t operator()(uint64_t index) override {
if (inc_ref)
ad_var_inc_ref(index);
result.push_back(index);
return 0;
}
};

Expand Down Expand Up @@ -281,6 +282,51 @@ bool leak_warnings() {
return nb::leak_warnings() || jit_leak_warnings() || ad_leak_warnings();
}

void traverse_py_cb_ro_impl(nb::handle self, nb::callable c) {
struct PyTraverseCallback : TraverseCallback {
void operator()(nb::handle h) override {
auto index_fn = supp(h.type()).index;
if (index_fn)
operator()(index_fn(inst_ptr(h)));
}
uint64_t operator()(uint64_t index) override { m_callback(index); return 0; }
nb::callable m_callback;

PyTraverseCallback(nb::callable c) : m_callback(c) {}
};

PyTraverseCallback traverse_cb(std::move(c));

auto dict = nb::borrow<nb::dict>(nb::getattr(self, "__dict__"));

for (auto value : dict.values()) {
traverse("traverse_py_cb_ro", traverse_cb, value);
}
}

void traverse_py_cb_rw_impl(nb::handle self, nb::callable c) {
struct PyTraverseCallback : TraverseCallback {
void operator()(nb::handle h) override {
const ArraySupplement &s = supp(h.type());
if (s.index)
s.reset_index(operator()(s.index(inst_ptr(h))), inst_ptr(h));
}
uint64_t operator()(uint64_t index) override {
return nb::cast<uint64_t>(m_callback(index));
}
nb::callable m_callback;

PyTraverseCallback(nb::callable c) : m_callback(c) {}
};

PyTraverseCallback traverse_cb(std::move(c));

auto dict = nb::borrow<nb::dict>(nb::getattr(self, "__dict__"));

for (auto value : dict.values()) {
traverse("traverse_py_cb_rw", traverse_cb, value);
}
}

void export_detail(nb::module_ &) {
nb::module_ d = nb::module_::import_("drjit.detail");
Expand Down Expand Up @@ -344,6 +390,8 @@ void export_detail(nb::module_ &) {

d.def("leak_warnings", &leak_warnings, doc_leak_warnings);
d.def("set_leak_warnings", &set_leak_warnings, doc_set_leak_warnings);
d.def("traverse_py_cb_ro", &traverse_py_cb_ro_impl);
d.def("traverse_py_cb_rw", traverse_py_cb_rw_impl);

trace_func_handle = d.attr("trace_func");
}
8 changes: 2 additions & 6 deletions src/python/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,8 @@ static void make_opaque(nb::handle h) {
if (rv)
result = true;

if (index != index_new) {
nb::object tmp = nb::inst_alloc(tp);
s.init_index(index_new, inst_ptr(tmp));
nb::inst_mark_ready(tmp);
nb::inst_replace_move(h, tmp);
}
if (index != index_new)
s.reset_index(index_new, inst_ptr(h));

ad_var_dec_ref(index_new);
}
Expand Down
Loading