Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace functools.partial with jax.tree_util.Partial #39

Closed
wants to merge 1 commit into from

Conversation

eadadi
Copy link

@eadadi eadadi commented Sep 26, 2024

Updated various buffer files to replace functools.partial with jax.tree_util.Partial for consistency and improved functionality.

The motivation is to be able to use jax transformations over buffers.

For example, before this patch, this wasn't working:

    buffer = fbx.make_trajectory_buffer(**cfg)
    buffer = jax.device_put(buffer, jax.devices("cpu")[0])

Updated various buffer files to replace functools.partial with jax.tree_util.Partial for consistency and improved functionality.
@CLAassistant
Copy link

CLAassistant commented Sep 26, 2024

CLA assistant check
All committers have signed the CLA.

@eadadi
Copy link
Author

eadadi commented Sep 26, 2024

the following linter test is incorrect. nameclass is uppercase

flashbax/buffers/mixer.py:22:2: N813 camelcase 'Partial' imported as lowercase 'partial'
flashbax/buffers/prioritised_trajectory_buffer.py:27:2: N813 camelcase 'Partial' imported as lowercase 'partial'
flashbax/buffers/trajectory_buffer.py:25:2: N813 camelcase 'Partial' imported as lowercase 'partial'
flashbax/buffers/trajectory_queue.py:19:2: N813 camelcase 'Partial' imported as lowercase 'partial'

@garymm
Copy link
Contributor

garymm commented Oct 2, 2024

I'm not a maintainer, but that linter seems correct to me. The original function name is CameCase, you're importing it as lowercase.

@sash-a
Copy link
Contributor

sash-a commented Oct 29, 2024

I'm confused with this one, is there a reason you'd want to put buffer on an accelerator, do you see speed ups over just putting the buffer state on the accelerator? Because buffer should just be a collection of functions?

@SimonDuToit
Copy link
Contributor

@eadadi Seeing as you haven't responded, Im closing this PR. If you feel this is a useful change, please reopen it.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants