Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some linear algebra fixes #42

Merged
merged 2 commits into from
Mar 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ SpecialFunctions = "0.8, 0.9, 0.10"
StatsBase = "0.32"
StatsFuns = "0.8, 0.9"
Tracker = "0.2.5"
Zygote = "0.4.7"
Zygote = "0.4.10"
ZygoteRules = "0.2"
julia = "1"

Expand Down
2 changes: 1 addition & 1 deletion src/DistributionsAD.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ using Tracker: Tracker, TrackedReal, TrackedVector, TrackedMatrix, TrackedArray,
TrackedVecOrMat, track, @grad, data
using SpecialFunctions: logabsgamma, digamma
using ZygoteRules: ZygoteRules, @adjoint, pullback
using LinearAlgebra: copytri!
using LinearAlgebra: copytri!, AbstractTriangular
using Distributions: AbstractMvLogNormal,
ContinuousMultivariateDistribution
using DiffRules, SpecialFunctions, FillArrays
Expand Down
73 changes: 38 additions & 35 deletions src/common.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,42 @@ end

## Linear algebra ##

LinearAlgebra.UpperTriangular(A::TrackedMatrix) = track(UpperTriangular, A)
@grad function LinearAlgebra.UpperTriangular(A::AbstractMatrix)
return UpperTriangular(data(A)), Δ->(UpperTriangular(Δ),)
# Work around https://github.com/FluxML/Tracker.jl/pull/9#issuecomment-480051767

# Triangular views of a plain matrix: thin aliases for the LinearAlgebra wrappers.
lower(A::AbstractMatrix) = LowerTriangular(A)
upper(A::AbstractMatrix) = UpperTriangular(A)

# Upper triangular factor of a Cholesky factorization.
# When the factorization stores the upper triangle (`uplo == 'U'`) the stored
# factors are wrapped directly; otherwise the lower triangle is wrapped,
# adjointed and materialised with `copy`.
function upper(C::Cholesky)
    stored = C.factors
    return C.uplo == 'U' ? upper(stored) : copy(lower(stored)')
end

# Lower triangular factor of a Cholesky factorization.
# Mirror image of `upper(::Cholesky)`: the adjoint copy is taken when the
# factorization stores the upper triangle.
function lower(C::Cholesky)
    stored = C.factors
    return C.uplo == 'U' ? copy(upper(stored)') : lower(stored)
end

# Constructing a `LowerTriangular` from a tracked matrix is routed through
# `lower`, which registers the call with `track`.
function LinearAlgebra.LowerTriangular(A::TrackedMatrix)
    return lower(A)
end

function lower(A::TrackedMatrix)
    return track(lower, A)
end

# Gradient rule for `lower`: the forward value is `lower` of the untracked
# data; the backward closure applies `lower` to the incoming gradient.
@grad function lower(A)
    value = lower(Tracker.data(A))
    propagate = Δ -> (lower(Δ),)
    return value, propagate
end

# Constructing an `UpperTriangular` from a tracked matrix is routed through
# `upper`, which registers the call with `track`.
function LinearAlgebra.UpperTriangular(A::TrackedMatrix)
    return upper(A)
end

function upper(A::TrackedMatrix)
    return track(upper, A)
end

# Gradient rule for `upper`: the forward value is `upper` of the untracked
# data; the backward closure applies `upper` to the incoming gradient.
@grad function upper(A)
    value = upper(Tracker.data(A))
    propagate = Δ -> (upper(Δ),)
    return value, propagate
end

# `copy` of a tracked array whose underlying data is the adjoint of a
# triangular matrix is registered with `track` rather than evaluated directly.
Base.copy(
    A::TrackedArray{T, 2, <:Adjoint{T, <:AbstractTriangular{T, <:AbstractMatrix{T}}}},
) where {T <: Real} = track(copy, A)
# Gradient rule for `copy` on a tracked adjoint-of-triangular array: the
# forward value is a copy of the untracked data, and the backward closure
# returns a copy of the incoming gradient.
@grad function Base.copy(
    A::TrackedArray{T, 2, <:Adjoint{T, <:AbstractTriangular{T, <:AbstractMatrix{T}}}},
) where {T <: Real}
    value = copy(data(A))
    propagate = Δ -> (copy(Δ),)
    return value, propagate
end

function LinearAlgebra.cholesky(A::TrackedMatrix; check=true)
Expand All @@ -57,40 +90,10 @@ function turing_chol(A::AbstractMatrix, check)
end
# Tracked-matrix method of `turing_chol`: registers the call with `track`,
# forwarding `check` unchanged.
function turing_chol(A::TrackedMatrix, check)
    return track(turing_chol, A, check)
end
# Tracker gradient rule for `turing_chol`. The forward pass runs `pullback`
# on the untracked inputs and returns the `factors` and `info` fields of the
# result; the backward closure passes the untracked first entry of `Δ` to
# `back` as the `factors` cotangent.
@grad function turing_chol(A::AbstractMatrix, check)
# NOTE(review): there are two consecutive assignments to `C, back`; the second
# overwrites the first, so only the `_turing_chol` result is used below. The
# `unsafe_cholesky` line looks like the pre-change side of the diff — confirm.
C, back = pullback(unsafe_cholesky, data(A), data(check))
C, back = pullback(_turing_chol, data(A), data(check))
# Only the cotangent of the factors (Δ[1]) is propagated backwards.
return (C.factors, C.info), Δ->back((factors=data(Δ[1]),))
end

# Positional-argument wrapper around `cholesky`: `check` is forwarded as the
# `check` keyword.
function unsafe_cholesky(x, check)
    return cholesky(x; check=check)
end
# Adjoint of `unsafe_cholesky` for a scalar input.
# Returns the factorization together with a closure mapping a NamedTuple
# cotangent (with a `factors` field) to the gradient with respect to `Σ`;
# `check` gets no gradient.
@adjoint function unsafe_cholesky(Σ::Real, check)
    C = cholesky(Σ; check=check)
    function scalar_chol_back(Δ::NamedTuple)
        # A failed factorization yields a zero gradient.
        if !issuccess(C)
            return (zero(Σ), nothing)
        end
        return (Δ.factors[1, 1] / (2 * C.U[1, 1]), nothing)
    end
    return C, scalar_chol_back
end
# Adjoint of `unsafe_cholesky` for a `Diagonal` input.
# Returns the factorization together with a closure mapping a NamedTuple
# cotangent (with a `factors` field) to the gradient with respect to `Σ`;
# `check` gets no gradient.
@adjoint function unsafe_cholesky(Σ::Diagonal, check)
    C = cholesky(Σ; check=check)
    return C, function(Δ::NamedTuple)
        # A failed factorization yields a zero gradient. The explicit `return`
        # matters: without it the zero tuple is built and discarded, and the
        # division below runs on the factors of a failed factorization.
        issuccess(C) || return (Diagonal(zero(diag(Δ.factors))), nothing)
        return (Diagonal(diag(Δ.factors) .* inv.(2 .* C.factors.diag)), nothing)
    end
end
# Adjoint of `unsafe_cholesky` for strided matrices (optionally wrapped in
# `Symmetric`). Returns the factorization together with a closure mapping a
# NamedTuple cotangent (with a `factors` field) to an upper-triangular gradient
# with respect to `Σ`; `check` gets no gradient.
@adjoint function unsafe_cholesky(Σ::Union{StridedMatrix, Symmetric{<:Real, <:StridedMatrix}}, check)
C = cholesky(Σ; check=check)
return C, function(Δ::NamedTuple)
# A failed factorization yields a zero gradient shaped like the cotangent.
issuccess(C) || return (zero(Δ.factors), nothing)
U, Ū = C.U, Δ.factors
Σ̄ = Ū * U'
# Mirror the upper triangle of Σ̄ into its lower triangle.
Σ̄ = copytri!(Σ̄, 'U')
# In-place left solve with U, followed by an in-place right solve against the
# transpose of the upper-triangular data of U (BLAS trsm flags 'R', 'U', 'T', 'N').
Σ̄ = ldiv!(U, Σ̄)
BLAS.trsm!('R', 'U', 'T', 'N', one(eltype(Σ)), U.data, Σ̄)
# Halve every diagonal entry of the result.
@inbounds for n in diagind(Σ̄)
Σ̄[n] /= 2
end
return (UpperTriangular(Σ̄), nothing)
end
end
# Positional-argument wrapper around `cholesky`: `check` is forwarded as the
# `check` keyword.
function _turing_chol(x, check)
    return cholesky(x; check=check)
end

# Specialised logdet for cholesky that works on the triangle directly:
# twice the sum of the logarithms of the diagonal entries of `U`.
function logdet_chol_tri(U::AbstractMatrix)
    diagonal_entries = U[diagind(U)]
    return 2 * sum(log, diagonal_entries)
end
Expand Down
7 changes: 0 additions & 7 deletions test/others.jl
Original file line number Diff line number Diff line change
@@ -1,13 +1,6 @@
using StatsBase: entropy

if get_stage() in ("Others", "all")
@testset "unsafe_cholesky" begin
A = rand(3, 3); A = A + A' + 3I
@test Matrix(DistributionsAD.unsafe_cholesky(A, true)) == Matrix(cholesky(A))
@test !issuccess(DistributionsAD.unsafe_cholesky(rand(3,3), false))
@test_throws PosDefException DistributionsAD.unsafe_cholesky(rand(3,3), true)
end

@testset "TuringWishart" begin
dim = 3
A = Matrix{Float64}(I, dim, dim)
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ using DistributionsAD, Test, LinearAlgebra, Combinatorics
using ForwardDiff: Dual
using StatsFuns: binomlogpdf, logsumexp
const FDM = FiniteDifferences
using DistributionsAD: TuringMvNormal, TuringMvLogNormal, TuringUniform, unsafe_cholesky
using DistributionsAD: TuringMvNormal, TuringMvLogNormal, TuringUniform
using Distributions: meanlogdet

include("test_utils.jl")
Expand Down