-
Notifications
You must be signed in to change notification settings - Fork 44
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
TraversableBase intrusive interface #288
base: master
Are you sure you want to change the base?
Changes from all commits
798740b
00857c1
5038256
fa974fa
d0fc270
4efa918
4a7a773
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,6 +15,7 @@ | |
|
||
#pragma once | ||
|
||
#include <type_traits> | ||
#define DRJIT_STRUCT_NODEF(Name, ...) \ | ||
Name(const Name &) = default; \ | ||
Name(Name &&) = default; \ | ||
|
@@ -140,6 +141,18 @@ namespace detail { | |
using det_traverse_1_cb_rw = | ||
decltype(T(nullptr)->traverse_1_cb_rw(nullptr, nullptr)); | ||
|
||
template <typename T> | ||
using det_get = decltype(std::declval<T&>().get()); | ||
|
||
template <typename T> | ||
using det_const_get = decltype(std::declval<const T &>().get()); | ||
|
||
template<typename T> | ||
using det_begin = decltype(std::declval<T &>().begin()); | ||
|
||
template<typename T> | ||
using det_end = decltype(std::declval<T &>().begin()); | ||
|
||
inline drjit::string get_label(const char *s, size_t i) { | ||
auto skip = [](char c) { | ||
return c == ' ' || c == '\r' || c == '\n' || c == '\t' || c == ','; | ||
|
@@ -198,6 +211,19 @@ void traverse_1_fn_ro(const Value &value, void *payload, void (*fn)(void *, uint | |
is_detected_v<detail::det_traverse_1_cb_ro, Value>) { | ||
if (value) | ||
value->traverse_1_cb_ro(payload, fn); | ||
|
||
} else if constexpr (is_detected_v<detail::det_begin, Value> && | ||
is_detected_v<detail::det_end, Value>) { | ||
for (auto elem : value) { | ||
traverse_1_fn_ro(elem, payload, fn); | ||
} | ||
} else if constexpr (is_detected_v<detail::det_const_get, Value>) { | ||
const auto *tmp = value.get(); | ||
traverse_1_fn_ro(tmp, payload, fn); | ||
} else if constexpr (is_detected_v<detail::det_traverse_1_cb_ro, Value *>) { | ||
value.traverse_1_cb_ro(payload, fn); | ||
} else { | ||
// static_assert(false, "Failed to traverse field!"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should the user be warned about this, or it's always fine to ignore? |
||
} | ||
} | ||
|
||
|
@@ -220,6 +246,18 @@ void traverse_1_fn_rw(Value &value, void *payload, uint64_t (*fn)(void *, uint64 | |
is_detected_v<detail::det_traverse_1_cb_rw, Value>) { | ||
if (value) | ||
value->traverse_1_cb_rw(payload, fn); | ||
} else if constexpr (is_detected_v<detail::det_begin, Value> && | ||
is_detected_v<detail::det_end, Value>) { | ||
for (auto elem : value) { | ||
traverse_1_fn_rw(elem, payload, fn); | ||
} | ||
} else if constexpr (is_detected_v<detail::det_get, Value>) { | ||
auto *tmp = value.get(); | ||
traverse_1_fn_rw(tmp, payload, fn); | ||
} else if constexpr (is_detected_v<detail::det_traverse_1_cb_rw, Value *>) { | ||
value.traverse_1_cb_rw(payload, fn); | ||
} else { | ||
// static_assert(false, "Failed to traverse field!"); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above |
||
} | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -54,6 +54,9 @@ | |
#include <drjit/math.h> | ||
#include <drjit-core/python.h> | ||
#include <nanobind/stl/array.h> | ||
#include <nanobind/intrusive/counter.h> | ||
#include "nanobind/nanobind.h" | ||
#include "drjit/traversable_base.h" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sort the imports |
||
|
||
NAMESPACE_BEGIN(drjit) | ||
struct ArrayBinding; | ||
|
@@ -1076,4 +1079,35 @@ template <typename T, typename... Args> auto& bind_traverse(nanobind::class_<T, | |
return cls; | ||
} | ||
|
||
inline void traverse_py_cb_ro(const TraversableBase *base, void *payload, | ||
void (*fn)(void *, uint64_t)) { | ||
namespace nb = nanobind; | ||
nb::handle self = base->self_py(); | ||
if (!self) | ||
return; | ||
|
||
auto detail = nb::module_::import_("drjit.detail"); | ||
nb::callable traverse_py_cb_ro = | ||
nb::borrow<nb::callable>(nb::getattr(detail, "traverse_py_cb_ro")); | ||
|
||
traverse_py_cb_ro( | ||
self, nb::cpp_function([&](uint64_t index) { fn(payload, index); })); | ||
} | ||
inline void traverse_py_cb_rw(TraversableBase *base, void *payload, | ||
uint64_t (*fn)(void *, uint64_t)) { | ||
|
||
namespace nb = nanobind; | ||
nb::handle self = base->self_py(); | ||
if (!self) | ||
return; | ||
|
||
auto detail = nb::module_::import_("drjit.detail"); | ||
nb::callable traverse_py_cb_rw = | ||
nb::borrow<nb::callable>(nb::getattr(detail, "traverse_py_cb_rw")); | ||
|
||
traverse_py_cb_rw(self, nb::cpp_function([&](uint64_t index) { | ||
return fn(payload, index); | ||
})); | ||
} | ||
|
||
NAMESPACE_END(drjit) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
#include <drjit/idiv.h> | ||
#include <drjit/jit.h> | ||
#include <drjit/tensor.h> | ||
#include "drjit/traversable_base.h" | ||
|
||
#pragma once | ||
|
||
|
@@ -41,7 +42,7 @@ enum class CudaTextureFormat : uint32_t { | |
Float16 = 1, /// Half precision storage format | ||
}; | ||
|
||
template <typename _Storage, size_t Dimension> class Texture { | ||
template <typename _Storage, size_t Dimension> class Texture : TraversableBase { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there no other DrJit class that needed to be made traversable? |
||
public: | ||
static constexpr bool IsCUDA = is_cuda_v<_Storage>; | ||
static constexpr bool IsDiff = is_diff_v<_Storage>; | ||
|
@@ -1386,6 +1387,9 @@ template <typename _Storage, size_t Dimension> class Texture { | |
WrapMode m_wrap_mode; | ||
bool m_use_accel = false; | ||
mutable bool m_migrated = false; | ||
|
||
DR_TRAVERSE_CB(drjit::TraversableBase, m_value, m_shape_opaque, | ||
m_inv_resolution); | ||
}; | ||
|
||
NAMESPACE_END(drjit) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
#pragma once | ||
|
||
#include "array_traverse.h" | ||
#include "drjit-core/macros.h" | ||
#include "nanobind/intrusive/counter.h" | ||
#include "nanobind/intrusive/ref.h" | ||
#include <drjit-core/jit.h> | ||
#include <drjit/map.h> | ||
|
||
NAMESPACE_BEGIN(drjit) | ||
|
||
/// Interface for traversing C++ objects. | ||
struct TraversableBase : nanobind::intrusive_base { | ||
virtual void traverse_1_cb_ro(void *, void (*)(void *, uint64_t)) const = 0; | ||
virtual void traverse_1_cb_rw(void *, uint64_t (*)(void *, uint64_t)) = 0; | ||
}; | ||
|
||
/// Macro for generating call to traverse_1_fn_ro for a class member | ||
#define DR_TRAVERSE_MEMBER_RO(member) \ | ||
drjit::log_member_open(false, #member); \ | ||
drjit::traverse_1_fn_ro(member, payload, fn); \ | ||
drjit::log_member_close(); | ||
/// Macro for generating call to traverse_1_fn_rw for a class member | ||
#define DR_TRAVERSE_MEMBER_RW(member) \ | ||
drjit::log_member_open(true, #member); \ | ||
drjit::traverse_1_fn_rw(member, payload, fn); \ | ||
drjit::log_member_close(); | ||
|
||
inline void log_member_open(bool rw, const char *member) { | ||
jit_log(LogLevel::Debug, "%s%s{", rw ? "rw " : "ro ", member); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You could guard the logging with a
|
||
} | ||
|
||
inline void log_member_close() { jit_log(LogLevel::Debug, "}"); } | ||
|
||
#define DR_TRAVERSE_CB_RO(Base, ...) \ | ||
void traverse_1_cb_ro(void *payload, void (*fn)(void *, uint64_t)) \ | ||
const override { \ | ||
if constexpr (!std::is_same_v<Base, drjit::TraversableBase>) \ | ||
Base::traverse_1_cb_ro(payload, fn); \ | ||
DRJIT_MAP(DR_TRAVERSE_MEMBER_RO, __VA_ARGS__) \ | ||
} | ||
|
||
#define DR_TRAVERSE_CB_RW(Base, ...) \ | ||
void traverse_1_cb_rw(void *payload, uint64_t (*fn)(void *, uint64_t)) \ | ||
override { \ | ||
if constexpr (!std::is_same_v<Base, drjit::TraversableBase>) \ | ||
Base::traverse_1_cb_rw(payload, fn); \ | ||
DRJIT_MAP(DR_TRAVERSE_MEMBER_RW, __VA_ARGS__) \ | ||
} | ||
|
||
/// Macro to generate traverse_1_cb_ro and traverse_1_cb_rw methods for each | ||
/// member in the list. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Expand the docstring to explain to users what they should do with this, which fields should be included or now, etc. Also this information should be repeated in the frozen functions documentation, since it's a dangerous pitfall for users who have custom plugins / classes implemented. |
||
#define DR_TRAVERSE_CB(Base, ...) \ | ||
public: \ | ||
DR_TRAVERSE_CB_RO(Base, __VA_ARGS__) \ | ||
DR_TRAVERSE_CB_RW(Base, __VA_ARGS__) | ||
|
||
#define DR_TRAMPOLINE_TRAVERSE_CB(Base) \ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Missing the docstring. |
||
public: \ | ||
void traverse_1_cb_ro(void *payload, void (*fn)(void *, uint64_t)) \ | ||
const override { \ | ||
if constexpr (!std ::is_same_v<Base, drjit ::TraversableBase>) \ | ||
Base ::traverse_1_cb_ro(payload, fn); \ | ||
drjit::traverse_py_cb_ro(this, payload, fn); \ | ||
} \ | ||
void traverse_1_cb_rw(void *payload, uint64_t (*fn)(void *, uint64_t)) \ | ||
override { \ | ||
if constexpr (!std ::is_same_v<Base, drjit ::TraversableBase>) \ | ||
Base ::traverse_1_cb_rw(payload, fn); \ | ||
drjit::traverse_py_cb_rw(this, payload, fn); \ | ||
} | ||
|
||
#if defined(_MSC_VER) | ||
#define DRJIT_EXPORT __declspec(dllexport) | ||
#else | ||
#define DRJIT_EXPORT __attribute__((visibility("default"))) | ||
#endif | ||
Comment on lines
+73
to
+77
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Isn't this already defined somewhere else? |
||
|
||
NAMESPACE_END(drjit) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -595,11 +595,12 @@ struct recursion_guard { | |
~recursion_guard() { recursion_level--; } | ||
}; | ||
|
||
void TraverseCallback::operator()(uint64_t) { } | ||
uint64_t TraverseCallback::operator()(uint64_t) { return 0; } | ||
void TraverseCallback::traverse_unknown(nb::handle) { } | ||
|
||
/// Invoke the given callback on leaf elements of the pytree 'h' | ||
void traverse(const char *op, TraverseCallback &tc, nb::handle h) { | ||
void traverse(const char *op, TraverseCallback &tc, nb::handle h, | ||
bool rw) { | ||
nb::handle tp = h.type(); | ||
recursion_guard guard; | ||
|
||
|
@@ -614,30 +615,32 @@ void traverse(const char *op, TraverseCallback &tc, nb::handle h) { | |
len = s.len(inst_ptr(h)); | ||
|
||
for (Py_ssize_t i = 0; i < len; ++i) | ||
traverse(op, tc, nb::steal(s.item(h.ptr(), i))); | ||
traverse(op, tc, nb::steal(s.item(h.ptr(), i)), rw); | ||
} else { | ||
tc(h); | ||
} | ||
} else if (tp.is(&PyTuple_Type)) { | ||
for (nb::handle h2 : nb::borrow<nb::tuple>(h)) | ||
traverse(op, tc, h2); | ||
traverse(op, tc, h2, rw); | ||
} else if (tp.is(&PyList_Type)) { | ||
for (nb::handle h2 : nb::borrow<nb::list>(h)) | ||
traverse(op, tc, h2); | ||
traverse(op, tc, h2, rw); | ||
} else if (tp.is(&PyDict_Type)) { | ||
for (nb::handle h2 : nb::borrow<nb::dict>(h).values()) | ||
traverse(op, tc, h2); | ||
traverse(op, tc, h2, rw); | ||
} else { | ||
if (nb::dict ds = get_drjit_struct(tp); ds.is_valid()) { | ||
for (auto [k, v] : ds) | ||
traverse(op, tc, nb::getattr(h, k)); | ||
traverse(op, tc, nb::getattr(h, k), rw); | ||
} else if (nb::object df = get_dataclass_fields(tp); df.is_valid()) { | ||
for (nb::handle field : df) { | ||
nb::object k = field.attr(DR_STR(name)); | ||
traverse(op, tc, nb::getattr(h, k)); | ||
traverse(op, tc, nb::getattr(h, k), rw); | ||
} | ||
} else if (nb::object cb = get_traverse_cb_ro(tp); cb.is_valid()) { | ||
} else if (nb::object cb = get_traverse_cb_ro(tp); cb.is_valid() && !rw) { | ||
cb(h, nb::cpp_function([&](uint64_t index) { tc(index); })); | ||
} else if (nb::object cb = get_traverse_cb_rw(tp); cb.is_valid() && rw) { | ||
cb(h, nb::cpp_function([&](uint64_t index) { return tc(index); })); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the user passed |
||
} else { | ||
tc.traverse_unknown(h); | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -57,7 +57,7 @@ struct TraverseCallback { | |
// Type-erased form which is needed in some cases to traverse into opaque | ||
// C++ code. This one just gets called with Jit/AD variable indices, an | ||
// associated Python/ instance/type is not available. | ||
virtual void operator()(uint64_t index); | ||
virtual uint64_t operator()(uint64_t index); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Document the return value. |
||
|
||
// Traverse an unknown object | ||
virtual void traverse_unknown(nb::handle h); | ||
|
@@ -93,8 +93,8 @@ struct TransformPairCallback { | |
}; | ||
|
||
/// Invoke the given callback on leaf elements of the pytree 'h' | ||
extern void traverse(const char *op, TraverseCallback &callback, | ||
nb::handle h); | ||
extern void traverse(const char *op, TraverseCallback &callback, nb::handle h, | ||
bool rw = false); | ||
Comment on lines
+96
to
+97
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Document the |
||
|
||
/// Parallel traversal of two compatible pytrees 'h1' and 'h2' | ||
extern void traverse_pair(const char *op, TraversePairCallback &callback, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.