-
Notifications
You must be signed in to change notification settings - Fork 44
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
TraversableBase intrusive interface #288
base: master
Are you sure you want to change the base?
TraversableBase intrusive interface #288
Conversation
29392b0
to
33ed806
Compare
f2a0e24
to
a4632c1
Compare
b0a1d57
to
5293619
Compare
5293619
to
33dfe2a
Compare
31fb70f
to
ca2680d
Compare
ca2680d
to
4cc7be1
Compare
4cc7be1
to
4a7a773
Compare
@merlinND, @wjakob This PR should be ready for review. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for this concise PR! The code looks good to me overall, but I'm not so familiar with this part of the codebase so @wjakob will probably need to take a look as well.
} else if constexpr (is_detected_v<detail::det_traverse_1_cb_ro, Value *>) { | ||
value.traverse_1_cb_ro(payload, fn); | ||
} else { | ||
// static_assert(false, "Failed to traverse field!"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should the user be warned about this, or it's always fine to ignore?
In the latter case, you could remove the else
block and add a comment to say explicitly that we silently return if we don't know how to traverse.
@@ -15,6 +15,7 @@ | |||
|
|||
#pragma once | |||
|
|||
#include <type_traits> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
#include <type_traits> | |
#include <type_traits> | |
@@ -54,6 +54,9 @@ | |||
#include <drjit/math.h> | |||
#include <drjit-core/python.h> | |||
#include <nanobind/stl/array.h> | |||
#include <nanobind/intrusive/counter.h> | |||
#include "nanobind/nanobind.h" | |||
#include "drjit/traversable_base.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sort the imports
@@ -41,7 +42,7 @@ enum class CudaTextureFormat : uint32_t { | |||
Float16 = 1, /// Half precision storage format | |||
}; | |||
|
|||
template <typename _Storage, size_t Dimension> class Texture { | |||
template <typename _Storage, size_t Dimension> class Texture : TraversableBase { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there no other DrJit class that needed to be made traversable?
} else if constexpr (is_detected_v<detail::det_traverse_1_cb_rw, Value *>) { | ||
value.traverse_1_cb_rw(payload, fn); | ||
} else { | ||
// static_assert(false, "Failed to traverse field!"); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as above
} | ||
|
||
/// Macro to generate traverse_1_cb_ro and traverse_1_cb_rw methods for each | ||
/// member in the list. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Expand the docstring to explain to users what they should do with this, which fields should be included or now, etc.
Add a warning to tell them what can happen if they fail to include a member in the list.
Also this information should be repeated in the frozen functions documentation, since it's a dangerous pitfall for users who have custom plugins / classes implemented.
#if defined(_MSC_VER) | ||
#define DRJIT_EXPORT __declspec(dllexport) | ||
#else | ||
#define DRJIT_EXPORT __attribute__((visibility("default"))) | ||
#endif |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Isn't this already defined somewhere else?
extern void traverse(const char *op, TraverseCallback &callback, nb::handle h, | ||
bool rw = false); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Document the rw
parameter, explain why it's needed and what the user should pass.
@@ -57,7 +57,7 @@ struct TraverseCallback { | |||
// Type-erased form which is needed in some cases to traverse into opaque | |||
// C++ code. This one just gets called with Jit/AD variable indices, an | |||
// associated Python/ instance/type is not available. | |||
virtual void operator()(uint64_t index); | |||
virtual uint64_t operator()(uint64_t index); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Document the return value.
cb(h, nb::cpp_function([&](uint64_t index) { tc(index); })); | ||
} else if (nb::object cb = get_traverse_cb_rw(tp); cb.is_valid() && rw) { | ||
cb(h, nb::cpp_function([&](uint64_t index) { return tc(index); })); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the user passed rw = true
but only get_traverse_cb_ro(tp)
is valid, not get_traverse_cb_rw(tp)
, is that an error? Should we throw?
This PR adds a
TraversableBase
class with thetraverse_1_cb_ro
andtraverse_1_cb_rw
functions and implements thenanobind::intrusive_base
class.The class can be implemented for any sub-class that should be traversable, and the functions can be generated by the
DR_TRAVERSE
macros.It also changes the
TraverseCallback::operator()
function, which now returns auint64_t
. This is required to implement amake_opaque
function, that is able to traverseC++
objects.