diff --git a/src/BackwardAlgorithms/BoxBackward.jl b/src/BackwardAlgorithms/BoxBackward.jl index 92d4d11..9a97de5 100644 --- a/src/BackwardAlgorithms/BoxBackward.jl +++ b/src/BackwardAlgorithms/BoxBackward.jl @@ -51,5 +51,10 @@ for T in (Sigmoid, LeakyReLU) h = _inverse(high(Y), act) return Hyperrectangle(; low=l, high=h) end + + # disambiguation + function backward(Y::Singleton, act::$T, algo::BoxBackward) + return Singleton(backward(element(Y), act, algo)) + end end end diff --git a/src/BackwardAlgorithms/PolyhedraBackward.jl b/src/BackwardAlgorithms/PolyhedraBackward.jl index f395244..24c0853 100644 --- a/src/BackwardAlgorithms/PolyhedraBackward.jl +++ b/src/BackwardAlgorithms/PolyhedraBackward.jl @@ -113,6 +113,10 @@ end # apply inverse ReLU activation function function backward(Y::LazySet, act::ReLU, ::PolyhedraBackward) + return _backward_PolyhedraBackward(Y, act) +end + +function _backward_PolyhedraBackward(Y::LazySet, act::ReLU) n = dim(Y) if n == 1 X = _backward_1D(Y, act) @@ -347,6 +351,14 @@ end # disambiguation for T in (:ReLU, :LeakyReLU) @eval begin + function backward(Y::Singleton, act::$T, algo::PolyhedraBackward) + if all(>(0), element(Y)) + return Singleton(backward(element(Y), act, algo)) + else + return _backward_PolyhedraBackward(Y, act) + end + end + function backward(Y::UnionSetArray, act::$T, algo::PolyhedraBackward) return _backward_union(Y, act, algo) end diff --git a/src/BackwardAlgorithms/backward_default.jl b/src/BackwardAlgorithms/backward_default.jl index 5c07fc6..4ffac29 100644 --- a/src/BackwardAlgorithms/backward_default.jl +++ b/src/BackwardAlgorithms/backward_default.jl @@ -70,7 +70,6 @@ append_sets!(Xs, X::LazySet) = push!(Xs, X) append_sets!(Xs, X::UnionSetArray) = append!(Xs, array(X)) # apply inverse piecewise-affine activation function to a union of sets -# COV_EXCL_START for T in (:ReLU, :LeakyReLU) @eval begin function backward(Y::UnionSetArray, act::$T, algo::BackwardAlgorithm) @@ -78,7 +77,6 @@ for T in (:ReLU, :LeakyReLU) end end end -# COV_EXCL_STOP function _backward_union(Y::LazySet{N}, act::ActivationFunction, algo::BackwardAlgorithm) where {N} @@ -97,10 +95,19 @@ function backward(y::AbstractVector, act::ActivationFunction, ::BackwardAlgorith end _inverse(x::AbstractVector, act::ActivationFunction) = [_inverse(xi, act) for xi in x] -_inverse(x::Number, ::ReLU) = x >= zero(x) ? x : zero(x) +_inverse(x::Number, ::ReLU) = x > zero(x) ? x : throw(ArgumentError("ReLU cannot be inverted")) _inverse(x::Number, ::Sigmoid) = @. -log(1 / x - 1) _inverse(x::Number, act::LeakyReLU) = x >= zero(x) ? x : x / act.slope +# invertible activations defined for numbers can be defined for singletons +for T in (:Sigmoid, :LeakyReLU) + @eval begin + function backward(Y::Singleton, act::$T, algo::BackwardAlgorithm) + return Singleton(backward(element(Y), act, algo)) + end + end +end + # activation functions must be explicitly supported for sets function backward(X::LazySet, act::ActivationFunction, algo::BackwardAlgorithm) throw(ArgumentError("activation function $act not supported by algorithm " * diff --git a/test/BackwardAlgorithms/backward.jl b/test/BackwardAlgorithms/backward.jl index 9cab768..af404c4 100755 --- a/test/BackwardAlgorithms/backward.jl +++ b/test/BackwardAlgorithms/backward.jl @@ -210,8 +210,8 @@ end ## union is too complex -> only perform partial tests @test X ⊆ Y && low(X) == [-Inf, 1.0] && high(X) == [Inf, Inf] # union - Y = UnionSetArray([LineSegment([1.0, 1.0], [2.0, 2.0]), Singleton([0.0, 0.0])]) - @test backward(Y, ReLU(), algo) == UnionSetArray([Y[1], Pneg]) + Y = UnionSetArray([LineSegment([1.0, 1.0], [2.0, 2.0]), Singleton([1.0, 1.0])]) + @test backward(Y, ReLU(), algo) == UnionSetArray([Y[1], Singleton([1.0, 1.0])]) # 3D # positive point @@ -363,6 +363,17 @@ end for algo in (BoxBackward(),) @test isequivalent(backward(Y, lr, algo), X) end + + # default algorithm for union + for algo in (DummyBackward(),) + y1 = Singleton([2.0]) + y2 = Singleton([3.0]) + x1 = backward(y1, lr, algo) + x2 = backward(y2, lr, algo) + Y2 = UnionSetArray([y1, y2]) + X2 = backward(Y2, lr, algo) + @test X2 == UnionSetArray([x1, x2]) + end end @testset "Backward layer" begin