Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TraversableBase intrusive interface #288

Open
wants to merge 7 commits into
base: master
Choose a base branch
from

Conversation

DoeringChristian
Copy link
Contributor

@DoeringChristian DoeringChristian commented Sep 25, 2024

This PR adds a TraversableBase class with the traverse_1_cb_ro and traverse_1_cb_rw functions and implements the nanobind::intrusive_base class.
The class can be implemented for any sub-class that should be traversable, and the functions can be generated by the DR_TRAVERSE macros.
It also changes the TraverseCallback::operator() function, which now returns a uint64_t. This is required to implement a make_opaque function, that is able to traverse C++ objects.

@DoeringChristian DoeringChristian marked this pull request as draft September 27, 2024 08:44
@DoeringChristian DoeringChristian force-pushed the traversable-base branch 2 times, most recently from f2a0e24 to a4632c1 Compare September 27, 2024 10:59
@DoeringChristian DoeringChristian force-pushed the traversable-base branch 2 times, most recently from b0a1d57 to 5293619 Compare October 18, 2024 09:11
@DoeringChristian DoeringChristian force-pushed the traversable-base branch 2 times, most recently from 31fb70f to ca2680d Compare December 2, 2024 10:25
@DoeringChristian DoeringChristian marked this pull request as ready for review December 5, 2024 19:35
@DoeringChristian
Copy link
Contributor Author

@merlinND, @wjakob This PR should be ready for review.
Note, there is one change from what we discussed last Friday.
I had to add a boolean to the argument of the traverse function, indicating whether traverse_cb_ro or traverse_cb_rw should be called. We cannot use the output of the callback to determine if the variable index should be overwritten, since overwriting it with $0$ might be desired (there is a test that fails in that case).

Copy link
Member

@merlinND merlinND left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi @DoeringChristian,

Thanks for this concise PR! The code looks good to me overall, but I'm not so familiar with this part of the codebase so @wjakob will probably need to take a look as well.

} else if constexpr (is_detected_v<detail::det_traverse_1_cb_ro, Value *>) {
value.traverse_1_cb_ro(payload, fn);
} else {
// static_assert(false, "Failed to traverse field!");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the user be warned about this, or it's always fine to ignore?
In the latter case, you could remove the else block and add a comment to say explicitly that we silently return if we don't know how to traverse.

@@ -15,6 +15,7 @@

#pragma once

#include <type_traits>
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#include <type_traits>
#include <type_traits>

@@ -54,6 +54,9 @@
#include <drjit/math.h>
#include <drjit-core/python.h>
#include <nanobind/stl/array.h>
#include <nanobind/intrusive/counter.h>
#include "nanobind/nanobind.h"
#include "drjit/traversable_base.h"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sort the imports

@@ -41,7 +42,7 @@ enum class CudaTextureFormat : uint32_t {
Float16 = 1, /// Half precision storage format
};

template <typename _Storage, size_t Dimension> class Texture {
template <typename _Storage, size_t Dimension> class Texture : TraversableBase {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there no other DrJit class that needed to be made traversable?

} else if constexpr (is_detected_v<detail::det_traverse_1_cb_rw, Value *>) {
value.traverse_1_cb_rw(payload, fn);
} else {
// static_assert(false, "Failed to traverse field!");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above

}

/// Macro to generate traverse_1_cb_ro and traverse_1_cb_rw methods for each
/// member in the list.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Expand the docstring to explain to users what they should do with this, which fields should be included or now, etc.
Add a warning to tell them what can happen if they fail to include a member in the list.

Also this information should be repeated in the frozen functions documentation, since it's a dangerous pitfall for users who have custom plugins / classes implemented.

Comment on lines +73 to +77
#if defined(_MSC_VER)
#define DRJIT_EXPORT __declspec(dllexport)
#else
#define DRJIT_EXPORT __attribute__((visibility("default")))
#endif
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't this already defined somewhere else?

Comment on lines +96 to +97
extern void traverse(const char *op, TraverseCallback &callback, nb::handle h,
bool rw = false);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Document the rw parameter, explain why it's needed and what the user should pass.

@@ -57,7 +57,7 @@ struct TraverseCallback {
// Type-erased form which is needed in some cases to traverse into opaque
// C++ code. This one just gets called with Jit/AD variable indices, an
// associated Python/ instance/type is not available.
virtual void operator()(uint64_t index);
virtual uint64_t operator()(uint64_t index);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Document the return value.

cb(h, nb::cpp_function([&](uint64_t index) { tc(index); }));
} else if (nb::object cb = get_traverse_cb_rw(tp); cb.is_valid() && rw) {
cb(h, nb::cpp_function([&](uint64_t index) { return tc(index); }));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the user passed rw = true but only get_traverse_cb_ro(tp) is valid, not get_traverse_cb_rw(tp), is that an error? Should we throw?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants