Skip to content

Commit

Permalink
Fix reverse failure (#1396)
Browse files Browse the repository at this point in the history
* Add internal function `_reverse` and overloads

* Add unit tests

* Correct issue number

* Label testset

* Add missing wrappers

* Avoid `collect` in `_reverse` for `Hermitian` and `Symmetric`

Co-authored-by: David Widmann <[email protected]>

* Use `_reverse` instead of `reverse`

Co-authored-by: David Widmann <[email protected]>

* Fix wrong names

:)

Co-authored-by: David Widmann <[email protected]>

* Add end user test case

* Add `using Zygote: _reverse`

---------

Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
simsurace and devmotion authored Mar 15, 2023
1 parent 756dd37 commit 2490c79
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 3 deletions.
17 changes: 15 additions & 2 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ using Random, FillArrays, AbstractFFTs
using FillArrays: AbstractFill, getindex_value
using Base.Broadcast: broadcasted, broadcast_shape
using Distributed: pmap, AbstractWorkerPool
using LinearAlgebra: Diagonal, Hermitian, LowerTriangular, UpperTriangular
using LinearAlgebra: UnitLowerTriangular, UnitUpperTriangular

@adjoint Array(xs::AbstractArray) = Array(xs), ȳ -> (ȳ,)
@adjoint Array(xs::Array) = Array(xs), ȳ -> (ȳ,)
Expand Down Expand Up @@ -165,10 +167,21 @@ end
# This is also used by comprehensions, which do guarantee iteration order.
# Not done for pmap, presumably because all is lost if you are relying on its order.
_tryreverse(m, backs, Δ) = backs, Δ
_tryreverse(m::typeof(map), backs, Δ) = reverse(backs), reverse(Δ)
_tryreverse(m::typeof(map), backs, Δ) = _reverse(backs), _reverse(Δ)

_tryreverse(m, x) = x
_tryreverse(m::typeof(map), x) = reverse(x)
_tryreverse(m::typeof(map), x) = _reverse(x)

# Fallback
_reverse(x) = reverse(x)

# Known cases in the standard library on which `reverse` errors (issue #1393)
_reverse(x::LowerTriangular) = UpperTriangular(_reverse(parent(x)))
_reverse(x::UpperTriangular) = LowerTriangular(_reverse(parent(x)))
_reverse(x::UnitLowerTriangular) = UnitUpperTriangular(_reverse(parent(x)))
_reverse(x::UnitUpperTriangular) = UnitLowerTriangular(_reverse(parent(x)))
_reverse(x::Hermitian) = Hermitian(_reverse(x.data), x.uplo == 'U' ? :L : :U)
_reverse(x::Symmetric) = Symmetric(_reverse(x.data), x.uplo == 'U' ? :L : :U)

# With mismatched lengths, map stops early. With mismatched shapes, it makes a vector.
# So we keep axes(x) to restore gradient dx to its full length & correct shape.
Expand Down
33 changes: 32 additions & 1 deletion test/lib/array.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
using ChainRulesTestUtils
using Zygote: ZygoteRuleConfig, _pullback
using LinearAlgebra: Diagonal, Hermitian, LowerTriangular, UpperTriangular
using LinearAlgebra: UnitLowerTriangular, UnitUpperTriangular
using Zygote: ZygoteRuleConfig, _pullback, _reverse

# issue 897

Expand Down Expand Up @@ -65,3 +67,32 @@ end
end
@test gradient(f_comprehension, w)[1] == ones(5)
end

@testset "_reverse" begin
m = [1 2 3; 4 5 6; 7 8 9]
@testset "$wrapper" for wrapper in [
Hermitian, Symmetric, LowerTriangular, UpperTriangular,
UnitLowerTriangular, UnitUpperTriangular,
]
M = wrapper(m)
@test collect(_reverse(M)) == _reverse(collect(M))
end
end

@testset "rrule for `map`" begin
@testset "MWE from #1393" begin
# https://github.com/FluxML/Zygote.jl/issues/1393#issuecomment-1468496804
struct Foo1393 x::Float64 end
(f::Foo1393)(x) = f.x * x
x = randn(5, 5)
out, pb = Zygote.pullback(x -> map(Foo1393(5.0), x), x)
@testset "$wrapper" for wrapper in [
Hermitian, Symmetric, LowerTriangular, UpperTriangular,
UnitLowerTriangular, UnitUpperTriangular,
]
m = wrapper(rand(5, 5))
res = only(pb(m))
@test res == 5m
end
end
end

0 comments on commit 2490c79

Please sign in to comment.