Remove redundant sum() rules #1453

Merged
merged 9 commits into from
Jan 4, 2025

ToucheSir
Member

The pullback is non-differentiable, which messes with nested AD (#1450). It's also not clear to me why this rule still exists when ChainRules has a seemingly GPU-compatible one. Let's see what CI says.

PR Checklist

  • Tests are added
  • Documentation, if applicable

@ToucheSir ToucheSir added the CUDA (All things GPU) and ChainRules (adjoint -> rrule, and further integration) labels Sep 1, 2023
@ToucheSir ToucheSir changed the base branch from master to bc/ci-noise September 1, 2023 20:58
@ToucheSir ToucheSir closed this Sep 1, 2023
@ToucheSir ToucheSir reopened this Sep 1, 2023
@mcabbott
Member

mcabbott commented Sep 2, 2023

I thought this existed in order to opt-out of the Zygote rule for sum which makes a FillArray.

julia> gradient(sum, [2.0, 3.0])
(Fill(1.0, 2),)

We could delete that too, it saves one copy sometimes but rarely matters in real code, and causes problems.

@ToucheSir ToucheSir changed the base branch from bc/ci-noise to master September 4, 2023 23:44
@ToucheSir
Member Author

ToucheSir commented Sep 5, 2023

Deleting that rule fixes all but one test:

@test g isa Dict{Int, Int}

Not sure how best to fix it. Perhaps we could generalize the rule in Zygote.jl/src/lib/array.jl (lines 340 to 342 in 6129613),

@adjoint function sum(xs::AbstractArray{Bool}; dims = :)
  sum(xs, dims = dims), Δ -> (nothing,)
end

to work on all Integers and convert it to a rrule(::ZygoteRuleConfig, ...) for future-proofing at the same time?

@mcabbott
Member

mcabbott commented Sep 5, 2023

We could certainly delete the rule for bool arrays, as there's one here:

https://github.com/JuliaDiff/ChainRules.jl/blob/ba52ec89ddd97a07e79cc35a9fa39019915d203b/src/rulesets/Base/nondiff.jl#L80

IDK what the issue with that Dict test is.

(Considering integers to be differentiable was a mistake, IMO, but a breaking change to fix that, here or in CR.)

@ToucheSir
Member Author

> IDK what the issue with that Dict test is.

The old rule was arguably wrong, because it was passing through the gradient for the summed value without doing any form of projection. If this were a scalar function, asking to differentiate wrt an integer argument would return a float gradient. So in my mind the test is actually capturing incorrect and inconsistent behaviour of the current rule. If we agree on that, I'll just tweak the test and we'll be back on green CI (minus known AbstractFFT failures).
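To make the projection point concrete, here is a minimal pure-Julia sketch. It is not Zygote's or ChainRules' actual code: the helpers `project_like`, `sum_pullback_unprojected` and `sum_pullback_projected` are hypothetical, and the sketch assumes the convention described above, that the gradient with respect to an integer is a float.

```julia
# Hypothetical stand-in for a projection step (an assumption for illustration):
# tangents of integers are treated as floats.
project_like(x::Integer, dx::Real) = float(dx)
project_like(x::AbstractFloat, dx::Real) = oftype(x, dx)

# Pullback of `sum` that passes the incoming gradient through unchanged.
sum_pullback_unprojected(xs, Δ) = fill(Δ, size(xs))

# Pullback of `sum` that projects each entry based on the matching input element.
sum_pullback_projected(xs, Δ) = map(x -> project_like(x, Δ), xs)

xs = [1, 2, 3]
Δ = one(sum(xs))  # integer seed, because sum(xs) is an Int

eltype(sum_pullback_unprojected(xs, Δ))  # Int
eltype(sum_pullback_projected(xs, Δ))    # Float64
```

Under these assumptions the unprojected pullback keeps integer gradients, which is the kind of result the `Dict{Int, Int}` test was checking for, while the projected one yields floats.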

@mcabbott
Member

mcabbott commented Sep 5, 2023

Sorry I didn't look closely, but if the change is just that now you get a Dict of Floats not Ints, then that seems totally fine, we just adjust the test.

@ToucheSir ToucheSir changed the title from "Remove GPU sum() rule" to "Remove redundant sum() rules" Sep 6, 2023
@ToucheSir
Member Author

The one remaining failure:

sum, prod, cumsum: Test Failed at /var/lib/buildkite-agent/builds/gpuci-1/julialang/zygote-dot-jl/test/gradcheck.jl:117
  Expression: gradient(sum, [true, false, true]) == (nothing,)
 Evaluated: nothing == (nothing,)

Which comes from the isnothing ternary on

isnothing(grad) ? nothing : map(_project, args, grad)

@mcabbott do you recall why we're collapsing to nothing here? I can't recall how we're supposed to handle nothing vs (nothing,) vs (nothing, ..., nothing) when returned from the pullback.
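The mismatch in the failing test can be reproduced in isolation. The following is a pure-Julia sketch that only mirrors the shape of the quoted ternary: `_project_sketch` and `wrap_grad` are hypothetical names, and the real `_project` does more than pass values through.

```julia
# Hypothetical stand-in for the internal `_project`; it passes gradients
# through unchanged, which is an assumption made only for this sketch.
_project_sketch(x, dx) = dx
_project_sketch(x, ::Nothing) = nothing

# Same shape as the quoted ternary: a bare `nothing` from the pullback stays bare.
wrap_grad(args, grad) = isnothing(grad) ? nothing : map(_project_sketch, args, grad)

args = ([true, false, true],)

wrap_grad(args, nothing)     # nothing: the per-argument tuple structure is lost
wrap_grad(args, (nothing,))  # (nothing,): one entry per argument is kept
```

A pullback that returns a bare `nothing` therefore surfaces as `nothing` rather than `(nothing,)`, which matches the `Evaluated: nothing == (nothing,)` line in the log.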

@mcabbott
Member

mcabbott commented Sep 8, 2023

My memory is that Zygote is eager to collapse any tuple of nothings to just nothing, but doesn't always manage to do so. I think at least withgradient and perhaps gradient try to restore them & always make a tuple. But I may have forgotten things.

@ToucheSir
Member Author

ToucheSir commented Sep 8, 2023

It looks like gradient does not try to make a tuple when it gets a single nothing. Should we make it do so? A version of this problem (more aggressive collapsing of zeros after moving to CR rules) is also causing the last two (non-unbreaking) test failures in #1328, ref. https://github.com/FluxML/Zygote.jl/actions/runs/6117262926/job/16603631586?pr=1328#step:6:747.
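One possible shape for "making a tuple" is sketched below in plain Julia. `restore_tuple` is a hypothetical helper, not part of Zygote, and whether `gradient` should behave this way is exactly the question left open here.

```julia
# Hypothetical helper (an assumption for illustration): expand a collapsed
# `nothing` into one `nothing` per differentiated argument, and leave
# tuples untouched.
restore_tuple(grad::Nothing, nargs::Integer) = ntuple(_ -> nothing, nargs)
restore_tuple(grad::Tuple, nargs::Integer) = grad

restore_tuple(nothing, 1)         # (nothing,)
restore_tuple(nothing, 3)         # (nothing, nothing, nothing)
restore_tuple((1.0, nothing), 2)  # (1.0, nothing)
```

With a step like this applied at the outermost call, the result would always have one entry per argument, regardless of how eagerly inner rules collapse their zeros.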

@FerreolS

Hi,
Is there any hope to merge this PR soon? Is there anything I can do in that direction?

@ToucheSir
Member Author

Maybe, if we can get some consensus on the behaviour of gradient around collapsing zeros. See #1466 (comment). Once that's been established, the failing test here will either pass automatically or require just a one-line tweak to start passing.

@mcabbott
Member

mcabbott commented Jan 4, 2025

Want to rush this in & pretend it was part of breaking 0.7.0?

@ToucheSir
Member Author

I totally forgot this one existed and was not looking forward to creating another PR to tackle https://discourse.julialang.org/t/second-order-gradient-with-lux-zygote-cuda-enzyme/124301. Take it away!

CC @pevnak for posterity, this likely would've addressed the issue you had with Zygote over Zygote.

@mcabbott
Member

mcabbott commented Jan 4, 2025

Done, I think. Molly.jl error is InitError: Exception[GLFW.GLFWError(65550, "X11: The DISPLAY environment variable is missing"), ErrorException("glfwInit failed")], unrelated.

@mcabbott mcabbott merged commit 4bfc545 into master Jan 4, 2025
13 of 16 checks passed
@mcabbott mcabbott deleted the bc/rm-gpu-sum-adj branch January 4, 2025 23:50