Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support inplace methods directly #8

Merged
merged 1 commit into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,9 @@ tangent and adjoint methods (see this
## Current Technical Limitations

* Only supports `gmres`, `cg`, and `bicgstab` methods
* No support for inplace methods `gmres!`, `cg!`, and `bicgstab!`
* No support for options when using Enzyme
* No support for sparse matrices using Enzyme
* No support for linear operators

## Current Open Questions
* How to handle preconditioners?
* How to set the options for the tangent/adjoint solve based on the options for the forward solve? For example `bicgtab` may return `NaN` for the tangents or adjoints.

## Installation
Expand Down
147 changes: 71 additions & 76 deletions src/EnzymeRules/enzymerules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,99 +5,93 @@ using .EnzymeRules
export augmented_primal, reverse, forward

for AMT in (:Matrix, :SparseMatrixCSC)
for solver in (:bicgstab, :gmres)
for solver in (:bicgstab!, :gmres!)
@eval begin
function forward(
func::Const{typeof(Krylov.$solver)},
ret::Type{RT},
solver::Annotation{ST},
_A::Annotation{MT},
_b::Annotation{VT};
M = I,
N = I,
verbose = 0,
options...
) where {RT, MT <: $AMT, VT <: Vector}
) where {RT <: Annotation, ST <: Krylov.KrylovSolver, MT <: $AMT, VT <: Vector}
psolver = $solver
pamt = $AMT
if verbose > 0
@info "($psolver, $pamt) forward rule"
end
A = _A.val
b = _b.val
dx = []
x, stats = Krylov.$solver(A,b; M=M, N=N, verbose=verbose, options...)
Krylov.$solver(solver.val, A,b; M=M, N=N, verbose=verbose, options...)
if isa(_A, Duplicated) && isa(_b, Duplicated)
dA = _A.dval
db = _b.dval
db -= dA*x
dx, dstats = Krylov.$solver(A,db; M=M, N=N, verbose=verbose, options...)
db -= dA*solver.val.x
Krylov.$solver(solver.dval,A,db; M=M, N=N, verbose=verbose, options...)
elseif isa(_A, Duplicated) && isa(_b, Const)
dA = _A.dval
db = -dA*x
dx, dstats = Krylov.$solver(A,db; M=M, N=N, verbose=verbose, options...)
Krylov.$solver(solver.dval,A,db; M=M, N=N, verbose=verbose, options...)
elseif isa(_A, Const) && isa(_b, Duplicated)
db = _b.dval
dx, dstats = Krylov.$solver(A,db; M=M, N=N, verbose=verbose, options...)
Krylov.$solver(solver.dval,A,db; M=M, N=N, verbose=verbose, options...)
elseif isa(_A, Const) && isa(_b, Const)
nothing
else
error("Error in Krylov forward rule: $(typeof(_A)), $(typeof(_b))")
end

if RT <: Const
return (x, stats)
elseif RT <: DuplicatedNoNeed
return (dx, stats)
return solver.val
else
return Duplicated((x, stats), (dx, dstats))
return solver
end
end
end
end
for solver in (:cg,)
for solver in (:cg!,)
@eval begin
function forward(
func::Const{typeof(Krylov.$solver)},
ret::Type{RT},
solver::Annotation{ST},
_A::Annotation{MT},
_b::Annotation{VT};
verbose = 0,
M = I,
options...
) where {RT, MT <: $AMT, VT <: Vector}
) where {RT <: Annotation, ST <: Krylov.KrylovSolver, MT <: $AMT, VT <: Vector}
psolver = $solver
pamt = $AMT
if verbose > 0
@info "($psolver, $pamt) forward rule"
end
A = _A.val
b = _b.val
dx = []
x, stats = Krylov.$solver(A,b; M=M, verbose=verbose, options...)
Krylov.$solver(solver.val,A,b; M=M, verbose=verbose, options...)
if isa(_A, Duplicated) && isa(_b, Duplicated)
dA = _A.dval
db = _b.dval
db -= dA*x
dx, dstats = Krylov.$solver(A,db; M=M, verbose=verbose, options...)
db -= dA*solver.val.x
Krylov.$solver(solver.dval,A,db; M=M, verbose=verbose, options...)
elseif isa(_A, Duplicated) && isa(_b, Const)
dA = _A.dval
db = -dA*x
dx, dstats = Krylov.$solver(A,db; M=M, verbose=verbose, options...)
db = -dA*solver.val.x
Krylov.$solver(solver.dval,A,db; M=M, verbose=verbose, options...)
elseif isa(_A, Const) && isa(_b, Duplicated)
db = _b.dval
dx, dstats = Krylov.$solver(A,db; M=M, verbose=verbose, options...)
Krylov.$solver(solver.dval,A,db; M=M, verbose=verbose, options...)
elseif isa(_A, Const) && isa(_b, Const)
nothing
else
error("Error in Krylov forward rule: $(typeof(_A)), $(typeof(_b))")
end

if RT <: Const
return (x, stats)
elseif RT <: DuplicatedNoNeed
return (dx, stats)
return solver.val
else
return Duplicated((x, stats), (dx, dstats))
return solver
end
end
end
Expand All @@ -106,116 +100,117 @@ end


for AMT in (:Matrix, :SparseMatrixCSC)
for solver in (:bicgstab, :gmres)
for solver in (:bicgstab!, :gmres!)
@eval begin
function augmented_primal(
config,
func::Const{typeof(Krylov.$solver)},
ret::Type{RT},
_A::Annotation{MT},
_b::Annotation{VT};
ret::Type{<:Annotation},
solver::Annotation{ST},
A::Annotation{MT},
b::Annotation{VT};
M=I,
N=I,
verbose=0,
options...
) where {RT, MT <: $AMT, VT <: Vector}
) where {ST <: Krylov.KrylovSolver, MT <: $AMT, VT <: Vector}
psolver = $solver
pamt = $AMT
if verbose > 0
@info "($psolver, $pamt) augmented forward"
end
A = _A.val
b = _b.val
x, stats = Krylov.$solver(A,b; M=M, N=N, verbose=verbose, options...)
bx = zeros(length(x))
bstats = deepcopy(stats)
if needs_primal(config)
return AugmentedReturn(
(x, stats),
(bx, bstats),
(A,x, Ref(bx), verbose, M, N)
)
else
return AugmentedReturn(nothing, (bx, bstats), (A,x))
end
Krylov.$solver(
solver.val, A.val,b.val;
M=M, verbose=verbose, options...
)

cache = (solver.val.x, A.val, verbose,M,N)
return AugmentedReturn(nothing, nothing, cache)
end

function reverse(
config,
::Const{typeof(Krylov.$solver)},
dret::Type{RT},
cache,
solver::Annotation{ST},
_A::Annotation{MT},
_b::Annotation{<:Vector};
_b::Annotation{VT};
options...
) where {RT, MT <: $AMT}
(A,x,bx,verbose,M,N) = cache
) where {ST <: Krylov.KrylovSolver, MT <: $AMT, VT <: Vector, RT}
(x, A, verbose,M,N) = cache
psolver = $solver
pamt = $AMT
if verbose > 0
@info "($psolver, $pamt) reverse"
end
adjM = adjoint(N)
adjN = adjoint(M)
_b.dval .= Krylov.$solver(adjoint(A), bx[]; M=adjM, N=adjN, verbose=verbose, options...)[1]
Krylov.$solver(
solver.dval,
adjoint(A), copy(solver.dval.x); M=adjM, N=adjN,
verbose=verbose, options...
)
copyto!(_b.dval, solver.dval.x)
if isa(_A, Duplicated)
_A.dval .= -x .* _b.dval'
end
return (nothing, nothing)
return (nothing, nothing, nothing)
end
end
end
for solver in (:cg,)
for solver in (:cg!,)
@eval begin
function augmented_primal(
config,
func::Const{typeof(Krylov.$solver)},
ret::Type{RT},
_A::Annotation{MT},
_b::Annotation{VT};
ret::Type{<:Annotation},
solver::Annotation{ST},
A::Annotation{MT},
b::Annotation{VT};
M=I,
verbose=0,
options...
) where {RT, MT <: $AMT, VT <: Vector}
) where {ST <: Krylov.KrylovSolver, MT <: $AMT, VT <: Vector}
psolver = $solver
pamt = $AMT
if verbose > 0
@info "($psolver, $pamt) augmented forward"
end
A = _A.val
b = _b.val
x, stats = Krylov.$solver(A,b; M=M, verbose=verbose, options...)
bx = zeros(length(x))
bstats = deepcopy(stats)
if needs_primal(config)
return AugmentedReturn(
(x, stats),
(bx, bstats),
(A,x, Ref(bx), verbose, M)
)
else
return AugmentedReturn(nothing, (bx, bstats), (A,x))
end
Krylov.$solver(
solver.val, A.val,b.val;
M=M, verbose=verbose, options...
)
cache = (solver.val.x, A.val,verbose,M)
return AugmentedReturn(nothing, nothing, cache)
end

function reverse(
config,
::Const{typeof(Krylov.$solver)},
dret::Type{RT},
cache,
solver::Annotation{ST},
_A::Annotation{MT},
_b::Annotation{<:Vector};
_b::Annotation{VT};
options...
) where {RT, MT <: $AMT}
(A,x,bx,verbose,M) = cache
) where {ST <: Krylov.KrylovSolver, MT <: $AMT, VT <: Vector, RT}
(x, A, verbose,M) = cache
psolver = $solver
pamt = $AMT
if verbose > 0
@info "($psolver, $pamt) reverse"
end
_b.dval .= Krylov.$solver(transpose(A), bx[]; M=M, verbose=verbose, options...)[1]
_A.dval .= -x .* _b.dval'
return (nothing, nothing)
Krylov.$solver(
solver.dval,
A, copy(solver.dval.x); M=M,
verbose=verbose, options...
)
copyto!(_b.dval, solver.dval.x)
if isa(_A, Duplicated)
_A.dval .= -x .* _b.dval'
end
return (nothing, nothing, nothing)
end
end
end
Expand Down
6 changes: 3 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ include("utils.jl")
atol = 1e-12
rtol = 0.0
@testset "DiffKrylov" begin
@testset "ForwardDiff" begin
include("forwarddiff.jl")
end
# @testset "ForwardDiff" begin
# include("forwarddiff.jl")
# end
@testset "Enzyme" begin
include("enzymediff.jl")
end
Expand Down
Loading
Loading