From 03acd3bce14bc25b53829045065cc1422386feb0 Mon Sep 17 00:00:00 2001 From: tmigot Date: Sun, 15 Dec 2024 12:06:38 -0500 Subject: [PATCH] Update JSOBestie template --- .breakage/Project.toml | 3 + .breakage/get_jso_users.jl | 18 ++ .copier-answers.jso.yml | 9 + .github/workflows/Breakage.yml | 201 +++++++++++++++++ src/EnzymeRules/enzymerules.jl | 129 +++++++---- src/ForwardDiff/forwarddiff.jl | 141 ++++++------ test/create_matrix.jl | 2 +- test/enzymediff.jl | 12 +- test/forwarddiff.jl | 6 +- test/get_div_grad.jl | 390 ++++++++++++++++----------------- test/utils.jl | 140 ++++++------ 11 files changed, 669 insertions(+), 382 deletions(-) create mode 100644 .breakage/Project.toml create mode 100644 .breakage/get_jso_users.jl create mode 100644 .copier-answers.jso.yml create mode 100644 .github/workflows/Breakage.yml diff --git a/.breakage/Project.toml b/.breakage/Project.toml new file mode 100644 index 0000000..a65fa1c --- /dev/null +++ b/.breakage/Project.toml @@ -0,0 +1,3 @@ +[deps] +GitHub = "bc5e4493-9b4d-5f90-b8aa-2b2bcaad7a26" +PkgDeps = "839e9fc8-855b-5b3c-a3b7-2833d3dd1f59" diff --git a/.breakage/get_jso_users.jl b/.breakage/get_jso_users.jl new file mode 100644 index 0000000..689612e --- /dev/null +++ b/.breakage/get_jso_users.jl @@ -0,0 +1,18 @@ +import GitHub, PkgDeps # both export users() + +length(ARGS) >= 1 || error("specify at least one JSO package as argument") + +jso_repos, _ = GitHub.repos("JuliaSmoothOptimizers") +jso_names = [splitext(x.name)[1] for x ∈ jso_repos] + +name = splitext(ARGS[1])[1] +name ∈ jso_names || error("argument should be one of ", jso_names) + +dependents = String[] +try + global dependents = filter(x -> x ∈ jso_names, PkgDeps.users(name)) +catch e + # package not registered; don't insert into dependents +end + +println(dependents) diff --git a/.copier-answers.jso.yml b/.copier-answers.jso.yml new file mode 100644 index 0000000..043bcf6 --- /dev/null +++ b/.copier-answers.jso.yml @@ -0,0 +1,9 @@ +PackageName: "DiffKrylov" +PackageOwner: "JuliaSmoothOptimizers" +PackageUUID: "de7797c5-59cd-4eef-8a62-bdd3ca55b1f9" +_src_path: "https://github.com/JuliaSmoothOptimizers/JSOBestieTemplate.jl" +_commit: "7bcc7d0a2905be7ca9d1dd6519e8df59ae89605c" +AddBreakage: true +AddBenchmark: false +AddBenchmarkCI: true +AddCirrusCI: false diff --git a/.github/workflows/Breakage.yml b/.github/workflows/Breakage.yml new file mode 100644 index 0000000..3883812 --- /dev/null +++ b/.github/workflows/Breakage.yml @@ -0,0 +1,201 @@ +# Ref: https://securitylab.github.com/research/github-actions-preventing-pwn-requests +name: Breakage + +# read-only repo token +# no access to secrets +on: + pull_request: + +jobs: + # Build dynamically the matrix on which the "break" job will run. + # The matrix contains the packages that depend on ${{ env.pkg }}. + # Job "setup_matrix" outputs variable "matrix", which is in turn + # the output of the "getmatrix" step. + # The contents of "matrix" is a JSON description of a matrix used + # in the next step. It has the form + # { + # "pkg": [ + # "PROPACK", + # "LLSModels", + # "FletcherPenaltySolver" + # ] + # } + setup_matrix: + runs-on: ubuntu-latest + outputs: + matrix: ${{ steps.getmatrix.outputs.matrix }} + env: + pkg: ${{ github.event.repository.name }} + steps: + - uses: actions/checkout@v4 + - uses: julia-actions/setup-julia@v2 + with: + version: 1 + arch: x64 + - id: getmatrix + run: | + julia -e 'using Pkg; Pkg.Registry.add(RegistrySpec(url = "https://github.com/JuliaRegistries/General.git"))' + julia --project=.breakage -e 'using Pkg; Pkg.update(); Pkg.instantiate()' + pkgs=$(julia --project=.breakage .breakage/get_jso_users.jl ${{ env.pkg }}) + vs='["latest", "stable"]' + # Check if pkgs is empty, and set it to a JSON array if necessary + if [[ -z "$pkgs" || "$pkgs" == "String[]" ]]; then + echo "No packages found; exiting successfully." + exit 0 + fi + vs='["latest", "stable"]' + matrix=$(jq -cn --argjson deps "$pkgs" --argjson vers "$vs" '{pkg: $deps, pkgversion: $vers}') # don't escape quotes like many posts suggest + echo "matrix=$matrix" >> "$GITHUB_OUTPUT" + + break: + needs: setup_matrix + if: needs.setup_matrix.result == 'success' && needs.setup_matrix.outputs.matrix != '' + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: ${{ fromJSON(needs.setup_matrix.outputs.matrix) }} + + steps: + - uses: actions/checkout@v4 + + # Install Julia + - uses: julia-actions/setup-julia@v2 + with: + version: 1 + arch: x64 + - uses: actions/cache@v4 + env: + cache-name: cache-artifacts + with: + path: ~/.julia/artifacts + key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }} + restore-keys: | + ${{ runner.os }}-test-${{ env.cache-name }}- + ${{ runner.os }}-test- + ${{ runner.os }}- + - uses: julia-actions/julia-buildpkg@v1 + + # Breakage test + - name: 'Breakage of ${{ matrix.pkg }}, ${{ matrix.pkgversion }} version' + env: + PKG: ${{ matrix.pkg }} + VERSION: ${{ matrix.pkgversion }} + run: | + set -v + mkdir -p ./breakage + git clone https://github.com/JuliaSmoothOptimizers/$PKG.jl.git + cd $PKG.jl + if [ $VERSION == "stable" ]; then + TAG=$(git tag -l "v*" --sort=-creatordate | head -n1) + if [ -z "$TAG" ]; then + TAG="no_tag" + else + git checkout $TAG + fi + else + TAG=$VERSION + fi + export TAG + julia -e 'using Pkg; + PKG, TAG, VERSION = ENV["PKG"], ENV["TAG"], ENV["VERSION"] + joburl = joinpath(ENV["GITHUB_SERVER_URL"], ENV["GITHUB_REPOSITORY"], "actions/runs", ENV["GITHUB_RUN_ID"]) + open("../breakage/breakage-$PKG-$VERSION", "w") do io + try + TAG == "no_tag" && error("No tag for $VERSION") + pkg"activate ."; + pkg"instantiate"; + pkg"dev ../"; + if TAG == "latest" + global TAG = chomp(read(`git rev-parse --short HEAD`, String)) + end + pkg"build"; + pkg"test"; + + print(io, "[![](https://img.shields.io/badge/$TAG-Pass-green)]($joburl)"); + catch e + @error e; + print(io, "[![](https://img.shields.io/badge/$TAG-Fail-red)]($joburl)"); + end; + end' + + - uses: actions/upload-artifact@v4 + with: + name: breakage-${{ matrix.pkg }}-${{ matrix.pkgversion }} + path: breakage/breakage-* + + upload: + needs: break + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/download-artifact@v4 + with: + path: breakage + pattern: breakage-* + merge-multiple: true + + - run: ls -R + - run: | + cd breakage + echo "| Package name | latest | stable |" > summary.md + echo "|--|--|--|" >> summary.md + count=0 + for file in breakage-* + do + if [ $count == "0" ]; then + name=$(echo $file | cut -f2 -d-) + echo -n "| $name | " + else + echo -n "| " + fi + cat $file + if [ $count == "0" ]; then + echo -n " " + count=1 + else + echo " |" + count=0 + fi + done >> summary.md + + - name: PR comment with file + uses: actions/github-script@v6 + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + // Import file content from summary.md + const fs = require('fs') + const filePath = 'breakage/summary.md' + const msg = fs.readFileSync(filePath, 'utf8') + + // Get the current PR number from context + const prNumber = context.payload.pull_request.number + + // Fetch existing comments on the PR + const { data: comments } = await github.rest.issues.listComments({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: prNumber + }) + + // Find a previous comment by the bot to update + const botComment = comments.find(comment => comment.user.id === 41898282) + + if (botComment) { + // Update the existing comment + await github.rest.issues.updateComment({ + owner: context.repo.owner, + repo: context.repo.repo, + comment_id: botComment.id, + body: msg + }) + } else { + // Create a new comment + await github.rest.issues.createComment({ + owner: context.repo.owner, + repo: context.repo.repo, + issue_number: prNumber, + body: msg + }) + } diff --git a/src/EnzymeRules/enzymerules.jl b/src/EnzymeRules/enzymerules.jl index e76e4d5..660a028 100644 --- a/src/EnzymeRules/enzymerules.jl +++ b/src/EnzymeRules/enzymerules.jl @@ -16,8 +16,8 @@ for AMT in (:Matrix, :SparseMatrixCSC) M = I, N = I, verbose = 0, - options... - ) where {RT <: Annotation, ST <: Krylov.KrylovSolver, MT <: $AMT, VT <: Vector} + options..., + ) where {RT<:Annotation,ST<:Krylov.KrylovSolver,MT<:$AMT,VT<:Vector} psolver = $solver pamt = $AMT if verbose > 0 @@ -25,19 +25,51 @@ for AMT in (:Matrix, :SparseMatrixCSC) end A = _A.val b = _b.val - Krylov.$solver(solver.val, A,b; M=M, N=N, verbose=verbose, options...) + Krylov.$solver( + solver.val, + A, + b; + M = M, + N = N, + verbose = verbose, + options..., + ) if isa(_A, Duplicated) && isa(_b, Duplicated) dA = _A.dval db = _b.dval - db -= dA*solver.val.x - Krylov.$solver(solver.dval,A,db; M=M, N=N, verbose=verbose, options...) + db -= dA * solver.val.x + Krylov.$solver( + solver.dval, + A, + db; + M = M, + N = N, + verbose = verbose, + options..., + ) elseif isa(_A, Duplicated) && isa(_b, Const) dA = _A.dval - db = -dA*x - Krylov.$solver(solver.dval,A,db; M=M, N=N, verbose=verbose, options...) + db = -dA * x + Krylov.$solver( + solver.dval, + A, + db; + M = M, + N = N, + verbose = verbose, + options..., + ) elseif isa(_A, Const) && isa(_b, Duplicated) db = _b.dval - Krylov.$solver(solver.dval,A,db; M=M, N=N, verbose=verbose, options...) + Krylov.$solver( + solver.dval, + A, + db; + M = M, + N = N, + verbose = verbose, + options..., + ) elseif isa(_A, Const) && isa(_b, Const) nothing else @@ -61,8 +93,8 @@ for AMT in (:Matrix, :SparseMatrixCSC) _b::Annotation{VT}; verbose = 0, M = I, - options... - ) where {RT <: Annotation, ST <: Krylov.KrylovSolver, MT <: $AMT, VT <: Vector} + options..., + ) where {RT<:Annotation,ST<:Krylov.KrylovSolver,MT<:$AMT,VT<:Vector} psolver = $solver pamt = $AMT if verbose > 0 @@ -70,19 +102,19 @@ for AMT in (:Matrix, :SparseMatrixCSC) end A = _A.val b = _b.val - Krylov.$solver(solver.val,A,b; M=M, verbose=verbose, options...) + Krylov.$solver(solver.val, A, b; M = M, verbose = verbose, options...) if isa(_A, Duplicated) && isa(_b, Duplicated) dA = _A.dval db = _b.dval - db -= dA*solver.val.x - Krylov.$solver(solver.dval,A,db; M=M, verbose=verbose, options...) + db -= dA * solver.val.x + Krylov.$solver(solver.dval, A, db; M = M, verbose = verbose, options...) elseif isa(_A, Duplicated) && isa(_b, Const) dA = _A.dval - db = -dA*solver.val.x - Krylov.$solver(solver.dval,A,db; M=M, verbose=verbose, options...) + db = -dA * solver.val.x + Krylov.$solver(solver.dval, A, db; M = M, verbose = verbose, options...) elseif isa(_A, Const) && isa(_b, Duplicated) db = _b.dval - Krylov.$solver(solver.dval,A,db; M=M, verbose=verbose, options...) + Krylov.$solver(solver.dval, A, db; M = M, verbose = verbose, options...) elseif isa(_A, Const) && isa(_b, Const) nothing else @@ -109,22 +141,26 @@ for AMT in (:Matrix, :SparseMatrixCSC) solver::Annotation{ST}, A::Annotation{MT}, b::Annotation{VT}; - M=I, - N=I, - verbose=0, - options... - ) where {ST <: Krylov.KrylovSolver, MT <: $AMT, VT <: Vector} + M = I, + N = I, + verbose = 0, + options..., + ) where {ST<:Krylov.KrylovSolver,MT<:$AMT,VT<:Vector} psolver = $solver pamt = $AMT if verbose > 0 @info "($psolver, $pamt) augmented forward" end Krylov.$solver( - solver.val, A.val,b.val; - M=M, verbose=verbose, options... + solver.val, + A.val, + b.val; + M = M, + verbose = verbose, + options..., ) - cache = (solver.val.x, A.val, verbose,M,N) + cache = (solver.val.x, A.val, verbose, M, N) return AugmentedReturn(nothing, nothing, cache) end @@ -136,9 +172,9 @@ for AMT in (:Matrix, :SparseMatrixCSC) solver::Annotation{ST}, _A::Annotation{MT}, _b::Annotation{VT}; - options... - ) where {ST <: Krylov.KrylovSolver, MT <: $AMT, VT <: Vector, RT} - (x, A, verbose,M,N) = cache + options..., + ) where {ST<:Krylov.KrylovSolver,MT<:$AMT,VT<:Vector,RT} + (x, A, verbose, M, N) = cache psolver = $solver pamt = $AMT if verbose > 0 @@ -148,8 +184,12 @@ for AMT in (:Matrix, :SparseMatrixCSC) adjN = adjoint(M) Krylov.$solver( solver.dval, - adjoint(A), copy(solver.dval.x); M=adjM, N=adjN, - verbose=verbose, options... + adjoint(A), + copy(solver.dval.x); + M = adjM, + N = adjN, + verbose = verbose, + options..., ) copyto!(_b.dval, solver.dval.x) if isa(_A, Duplicated) @@ -168,20 +208,24 @@ for AMT in (:Matrix, :SparseMatrixCSC) solver::Annotation{ST}, A::Annotation{MT}, b::Annotation{VT}; - M=I, - verbose=0, - options... - ) where {ST <: Krylov.KrylovSolver, MT <: $AMT, VT <: Vector} + M = I, + verbose = 0, + options..., + ) where {ST<:Krylov.KrylovSolver,MT<:$AMT,VT<:Vector} psolver = $solver pamt = $AMT if verbose > 0 @info "($psolver, $pamt) augmented forward" end Krylov.$solver( - solver.val, A.val,b.val; - M=M, verbose=verbose, options... + solver.val, + A.val, + b.val; + M = M, + verbose = verbose, + options..., ) - cache = (solver.val.x, A.val,verbose,M) + cache = (solver.val.x, A.val, verbose, M) return AugmentedReturn(nothing, nothing, cache) end @@ -193,9 +237,9 @@ for AMT in (:Matrix, :SparseMatrixCSC) solver::Annotation{ST}, _A::Annotation{MT}, _b::Annotation{VT}; - options... - ) where {ST <: Krylov.KrylovSolver, MT <: $AMT, VT <: Vector, RT} - (x, A, verbose,M) = cache + options..., + ) where {ST<:Krylov.KrylovSolver,MT<:$AMT,VT<:Vector,RT} + (x, A, verbose, M) = cache psolver = $solver pamt = $AMT if verbose > 0 @@ -203,8 +247,11 @@ for AMT in (:Matrix, :SparseMatrixCSC) end Krylov.$solver( solver.dval, - A, copy(solver.dval.x); M=M, - verbose=verbose, options... + A, + copy(solver.dval.x); + M = M, + verbose = verbose, + options..., ) copyto!(_b.dval, solver.dval.x) if isa(_A, Duplicated) diff --git a/src/ForwardDiff/forwarddiff.jl b/src/ForwardDiff/forwarddiff.jl index 06bc831..81244df 100644 --- a/src/ForwardDiff/forwarddiff.jl +++ b/src/ForwardDiff/forwarddiff.jl @@ -1,17 +1,18 @@ import ForwardDiff: Dual, Partials, partials, value -_matrix_values(A::SparseMatrixCSC{Dual{T, V, N}, IT}) where {T, V, N, IT} = SparseMatrixCSC(A.m, A.n, A.colptr, A.rowval, value.(A.nzval)) -_matrix_values(A::Matrix{Dual{T, V, N}}) where {T, V, N} = Matrix{V}(value.(A)) -function _matrix_partials(A::SparseMatrixCSC{Dual{T, V, N}, IT}) where {T, V, N, IT} - dAs = Vector{SparseMatrixCSC{V, IT}}(undef, N) - for i in 1:N +_matrix_values(A::SparseMatrixCSC{Dual{T,V,N},IT}) where {T,V,N,IT} = + SparseMatrixCSC(A.m, A.n, A.colptr, A.rowval, value.(A.nzval)) +_matrix_values(A::Matrix{Dual{T,V,N}}) where {T,V,N} = Matrix{V}(value.(A)) +function _matrix_partials(A::SparseMatrixCSC{Dual{T,V,N},IT}) where {T,V,N,IT} + dAs = Vector{SparseMatrixCSC{V,IT}}(undef, N) + for i = 1:N dAs[i] = SparseMatrixCSC(A.m, A.n, A.colptr, A.rowval, partials.(A.nzval, i)) end return dAs end -function _matrix_partials(A::Matrix{Dual{T, V, N}}) where {T, V, N} +function _matrix_partials(A::Matrix{Dual{T,V,N}}) where {T,V,N} dAs = Vector{Matrix{V}}(undef, N) - for i in 1:N + for i = 1:N dAs[i] = Matrix(partials.(A, i)) end return dAs @@ -19,25 +20,29 @@ end for solver in (:cg, :bicgstab) - for matrix in (:(SparseMatrixCSC{V, IT} where {IT}), :(Matrix{V})) + for matrix in (:(SparseMatrixCSC{V,IT} where {IT}), :(Matrix{V})) @eval begin - function Krylov.$solver(A::$matrix, _b::Vector{Dual{T, V, N}}; options...) where {T, V, N} + function Krylov.$solver( + A::$matrix, + _b::Vector{Dual{T,V,N}}; + options..., + ) where {T,V,N} b = value.(_b) m = length(b) dbs = Matrix{V}(undef, m, N) - for i in 1:m - dbs[i,:] = partials(_b[i]) + for i = 1:m + dbs[i, :] = partials(_b[i]) end - x, stats = $solver(A,b; options...) + x, stats = $solver(A, b; options...) dxs = Matrix{V}(undef, m, N) px = Vector{Partials{N,V}}(undef, m) - for i in 1:N - nb = dbs[:,i] - dx, dstats = $solver(A,nb; options...) - dxs[:,i] = dx + for i = 1:N + nb = dbs[:, i] + dx, dstats = $solver(A, nb; options...) + dxs[:, i] = dx end - for i in 1:m - px[i] = Partials{N,V}(Tuple(dxs[i,j] for j in 1:N)) + for i = 1:m + px[i] = Partials{N,V}(Tuple(dxs[i, j] for j = 1:N)) end duals = Dual{T,V,N}.(x, px) return (duals, stats) @@ -45,22 +50,22 @@ for solver in (:cg, :bicgstab) end end - for matrix in (:(SparseMatrixCSC{Dual{T,V,N}, IT} where {IT}), :(Matrix{Dual{T,V,N}})) + for matrix in (:(SparseMatrixCSC{Dual{T,V,N},IT} where {IT}), :(Matrix{Dual{T,V,N}})) @eval begin - function Krylov.$solver(_A::$matrix, b::Vector{V}; options...) where {T, V, N} + function Krylov.$solver(_A::$matrix, b::Vector{V}; options...) where {T,V,N} A = _matrix_values(_A) dAs = _matrix_partials(_A) m = length(b) - x, stats = $solver(A,b; options...) + x, stats = $solver(A, b; options...) dxs = Matrix{V}(undef, m, N) px = Vector{Partials{N,V}}(undef, m) - for i in 1:N - nb = - dAs[i]*x - dx, dstats = $solver(A,nb; options...) - dxs[:,i] = dx + for i = 1:N + nb = -dAs[i] * x + dx, dstats = $solver(A, nb; options...) + dxs[:, i] = dx end - for i in 1:m - px[i] = Partials{N,V}(Tuple(dxs[i,j] for j in 1:N)) + for i = 1:m + px[i] = Partials{N,V}(Tuple(dxs[i, j] for j = 1:N)) end duals = Dual{T,V,N}.(x, px) return (duals, stats) @@ -68,27 +73,31 @@ for solver in (:cg, :bicgstab) end end - for matrix in (:(SparseMatrixCSC{Dual{T,V,N}, IT} where {IT}), :(Matrix{Dual{T,V,N}})) + for matrix in (:(SparseMatrixCSC{Dual{T,V,N},IT} where {IT}), :(Matrix{Dual{T,V,N}})) @eval begin - function Krylov.$solver(_A::$matrix, _b::Vector{Dual{T, V, N}}; options...) where {T, V, N} + function Krylov.$solver( + _A::$matrix, + _b::Vector{Dual{T,V,N}}; + options..., + ) where {T,V,N} A = _matrix_values(_A) dAs = _matrix_partials(_A) b = value.(_b) m = length(b) dbs = Matrix{V}(undef, m, N) - for i in 1:m - dbs[i,:] = partials(_b[i]) + for i = 1:m + dbs[i, :] = partials(_b[i]) end - x, stats = $solver(A,b; options...) + x, stats = $solver(A, b; options...) dxs = Matrix{V}(undef, m, N) px = Vector{Partials{N,V}}(undef, m) - for i in 1:N - nb = dbs[:,i] - dAs[i]*x - dx, dstats = $solver(A,nb; options...) - dxs[:,i] = dx + for i = 1:N + nb = dbs[:, i] - dAs[i] * x + dx, dstats = $solver(A, nb; options...) + dxs[:, i] = dx end - for i in 1:m - px[i] = Partials{N,V}(Tuple(dxs[i,j] for j in 1:N)) + for i = 1:m + px[i] = Partials{N,V}(Tuple(dxs[i, j] for j = 1:N)) end duals = Dual{T,V,N}.(x, px) return (duals, stats) @@ -98,24 +107,24 @@ for solver in (:cg, :bicgstab) end -for matrix in (:(SparseMatrixCSC{V, IT} where {IT}), :(Matrix{V})) +for matrix in (:(SparseMatrixCSC{V,IT} where {IT}), :(Matrix{V})) @eval begin - function Krylov.gmres(A::$matrix, _b::Vector{Dual{T, V, N}}; options...) where {T, V, N} + function Krylov.gmres(A::$matrix, _b::Vector{Dual{T,V,N}}; options...) where {T,V,N} b = value.(_b) m = length(b) dbs = Matrix{V}(undef, m, N) - for i in 1:m - dbs[i,:] = partials(_b[i]) + for i = 1:m + dbs[i, :] = partials(_b[i]) end - x, stats = gmres(A,b; options...) + x, stats = gmres(A, b; options...) dxs = Matrix{V}(undef, m, N) px = Vector{Partials{N,V}}(undef, m) if N != 0 - xs, dstats = block_gmres(A,dbs; options...) + xs, dstats = block_gmres(A, dbs; options...) dxs .= xs end - for i in 1:m - px[i] = Partials{N,V}(Tuple(dxs[i,j] for j in 1:N)) + for i = 1:m + px[i] = Partials{N,V}(Tuple(dxs[i, j] for j = 1:N)) end duals = Dual{T,V,N}.(x, px) return (duals, stats) @@ -123,25 +132,25 @@ for matrix in (:(SparseMatrixCSC{V, IT} where {IT}), :(Matrix{V})) end end -for matrix in (:(SparseMatrixCSC{Dual{T,V,N}, IT} where {IT}), :(Matrix{Dual{T,V,N}})) +for matrix in (:(SparseMatrixCSC{Dual{T,V,N},IT} where {IT}), :(Matrix{Dual{T,V,N}})) @eval begin - function Krylov.gmres(_A::$matrix, b::Vector{V}; options...) where {T, V, N} + function Krylov.gmres(_A::$matrix, b::Vector{V}; options...) where {T,V,N} A = _matrix_values(_A) dAs = _matrix_partials(_A) m = length(b) dbs = Matrix{V}(undef, m, N) - x, stats = gmres(A,b; options...) + x, stats = gmres(A, b; options...) dxs = Matrix{V}(undef, m, N) px = Vector{Partials{N,V}}(undef, m) - for i in 1:N - dbs[:,i] = - dAs[i]*x + for i = 1:N + dbs[:, i] = -dAs[i] * x end if N != 0 - dx, dstats = block_gmres(A,dbs; options...) + dx, dstats = block_gmres(A, dbs; options...) end dxs .= dx - for i in 1:m - px[i] = Partials{N,V}(Tuple(dxs[i,j] for j in 1:N)) + for i = 1:m + px[i] = Partials{N,V}(Tuple(dxs[i, j] for j = 1:N)) end duals = Dual{T,V,N}.(x, px) return (duals, stats) @@ -149,29 +158,33 @@ for matrix in (:(SparseMatrixCSC{Dual{T,V,N}, IT} where {IT}), :(Matrix{Dual{T,V end end -for matrix in (:(SparseMatrixCSC{Dual{T,V,N}, IT} where {IT}), :(Matrix{Dual{T,V,N}})) +for matrix in (:(SparseMatrixCSC{Dual{T,V,N},IT} where {IT}), :(Matrix{Dual{T,V,N}})) @eval begin - function Krylov.gmres(_A::$matrix, _b::Vector{Dual{T, V, N}}; options...) where {T, V, N} + function Krylov.gmres( + _A::$matrix, + _b::Vector{Dual{T,V,N}}; + options..., + ) where {T,V,N} A = _matrix_values(_A) dAs = _matrix_partials(_A) b = value.(_b) m = length(b) dbs = Matrix{V}(undef, m, N) - for i in 1:m - dbs[i,:] = partials(_b[i]) + for i = 1:m + dbs[i, :] = partials(_b[i]) end - x, stats = gmres(A,b; options...) + x, stats = gmres(A, b; options...) dxs = Matrix{V}(undef, m, N) px = Vector{Partials{N,V}}(undef, m) - for i in 1:N - dbs[:,i] -= dAs[i]*x + for i = 1:N + dbs[:, i] -= dAs[i] * x end if N != 0 - dx, dstats = block_gmres(A,dbs; options...) + dx, dstats = block_gmres(A, dbs; options...) end dxs .= dx - for i in 1:m - px[i] = Partials{N,V}(Tuple(dxs[i,j] for j in 1:N)) + for i = 1:m + px[i] = Partials{N,V}(Tuple(dxs[i, j] for j = 1:N)) end duals = Dual{T,V,N}.(x, px) return (duals, stats) diff --git a/test/create_matrix.jl b/test/create_matrix.jl index 9592685..acb5146 100644 --- a/test/create_matrix.jl +++ b/test/create_matrix.jl @@ -12,7 +12,7 @@ function create_unsymmetric_matrix(n) # Modify the singular values to make them close to each other but not too small # Here we set them all to be between 1 and 2 - S = Diagonal(range(1, stop=2, length=n)) + S = Diagonal(range(1, stop = 2, length = n)) # Reconstruct the matrix well_conditioned_matrix = U * S * V' diff --git a/test/enzymediff.jl b/test/enzymediff.jl index a97e12f..a3d4c3c 100644 --- a/test/enzymediff.jl +++ b/test/enzymediff.jl @@ -12,22 +12,22 @@ using Test Random.seed!(1) include("create_matrix.jl") @testset "Enzyme Rules" begin - @testset "$MT" for MT = (Matrix, SparseMatrixCSC) - @testset "($M, $N)" for (M,N) = ((I,I),) + @testset "$MT" for MT in (Matrix, SparseMatrixCSC) + @testset "($M, $N)" for (M, N) in ((I, I),) # Square unsymmetric solvers - @testset "$solver" for solver = (Krylov.gmres, Krylov.bicgstab) + @testset "$solver" for solver in (Krylov.gmres, Krylov.bicgstab) A = [] if MT == Matrix A = create_unsymmetric_matrix(10) b = rand(10) else - A, b = sparse_laplacian(4, FC=Float64) + A, b = sparse_laplacian(4, FC = Float64) end test_enzyme_with(solver, A, b, M, N) end # Square symmetric solvers - @testset "$solver" for solver = (Krylov.cg,) - A, b = sparse_laplacian(4, FC=Float64) + @testset "$solver" for solver in (Krylov.cg,) + A, b = sparse_laplacian(4, FC = Float64) A = MT(A) test_enzyme_with(solver, A, b, M, N) end diff --git a/test/forwarddiff.jl b/test/forwarddiff.jl index ebcadcc..f43701c 100644 --- a/test/forwarddiff.jl +++ b/test/forwarddiff.jl @@ -1,6 +1,6 @@ -A, b = sparse_laplacian(4, FC=Float64) -@testset "$solver" for solver = (Krylov.cg, Krylov.gmres, Krylov.bicgstab) - x, stats = solver(A,b; atol=atol, rtol=rtol) +A, b = sparse_laplacian(4, FC = Float64) +@testset "$solver" for solver in (Krylov.cg, Krylov.gmres, Krylov.bicgstab) + x, stats = solver(A, b; atol = atol, rtol = rtol) # A passive, b active # Sparse diff --git a/test/get_div_grad.jl b/test/get_div_grad.jl index ae27e50..733515c 100644 --- a/test/get_div_grad.jl +++ b/test/get_div_grad.jl @@ -1,234 +1,234 @@ # Identity matrix. -eye(n::Int; FC=Float64) = sparse(one(FC) * I, n, n) +eye(n::Int; FC = Float64) = sparse(one(FC) * I, n, n) # Compute the energy norm ‖r‖ₚ = √(rᴴPr) where P is a symmetric and positive definite matrix. metric(r, P) = sqrt(real(dot(r, P * r))) # Based on Lars Ruthotto's initial implementation. -function get_div_grad(n1 :: Int, n2 :: Int, n3 :: Int) +function get_div_grad(n1::Int, n2::Int, n3::Int) - # Divergence - D1 = kron(eye(n3), kron(eye(n2), ddx(n1))) - D2 = kron(eye(n3), kron(ddx(n2), eye(n1))) - D3 = kron(ddx(n3), kron(eye(n2), eye(n1))) + # Divergence + D1 = kron(eye(n3), kron(eye(n2), ddx(n1))) + D2 = kron(eye(n3), kron(ddx(n2), eye(n1))) + D3 = kron(ddx(n3), kron(eye(n2), eye(n1))) - # DIV from faces to cell-centers - Div = [D1 D2 D3] + # DIV from faces to cell-centers + Div = [D1 D2 D3] - return Div * Div' + return Div * Div' end # 1D finite difference on staggered grid -function ddx(n :: Int) - e = ones(n) - return sparse([1:n; 1:n], [1:n; 2:n+1], [-e; e]) +function ddx(n::Int) + e = ones(n) + return sparse([1:n; 1:n], [1:n; 2:n+1], [-e; e]) end # Primal and dual ODEs discretized with central second order finite differences. -function ODE(n, f, g, ode_coefs; dim_x=[0.0, 1.0]) - # Ω = ]xₗ, xᵣ[ - # Ω ∪ ∂Ω = [xₗ, xᵣ] - xₗ = dim_x[1] - xᵣ = dim_x[2] - - # Uniform grid of Ω with n points - Δx = (xᵣ - xₗ) / (n + 1) - grid = [i * Δx for i = 1 : n] - - χ₁ = ode_coefs[1] - χ₂ = ode_coefs[2] - χ₃ = ode_coefs[3] - - # Model problems with Au = b and Aᵀv = c - # - # A ∈ ℜⁿ*ⁿ, u ∈ ℜⁿ, b ∈ ℜⁿ, v ∈ ℜⁿ, c ∈ ℜⁿ - # - # ∂²z(xᵢ) / ∂x² ≈ (zᵢ₋₁ -2zᵢ + zᵢ₊₁) / (Δx)² - # - # ∂z(xᵢ) / ∂x ≈ (zᵢ₊₁ - zᵢ₋₁) / (2 * Δx) - A = spzeros(n, n) - for i = 1 : n - if i ≠ 1 - A[i, i-1] = χ₁ / (Δx * Δx) - χ₂ / (2 * Δx) +function ODE(n, f, g, ode_coefs; dim_x = [0.0, 1.0]) + # Ω = ]xₗ, xᵣ[ + # Ω ∪ ∂Ω = [xₗ, xᵣ] + xₗ = dim_x[1] + xᵣ = dim_x[2] + + # Uniform grid of Ω with n points + Δx = (xᵣ - xₗ) / (n + 1) + grid = [i * Δx for i = 1:n] + + χ₁ = ode_coefs[1] + χ₂ = ode_coefs[2] + χ₃ = ode_coefs[3] + + # Model problems with Au = b and Aᵀv = c + # + # A ∈ ℜⁿ*ⁿ, u ∈ ℜⁿ, b ∈ ℜⁿ, v ∈ ℜⁿ, c ∈ ℜⁿ + # + # ∂²z(xᵢ) / ∂x² ≈ (zᵢ₋₁ -2zᵢ + zᵢ₊₁) / (Δx)² + # + # ∂z(xᵢ) / ∂x ≈ (zᵢ₊₁ - zᵢ₋₁) / (2 * Δx) + A = spzeros(n, n) + for i = 1:n + if i ≠ 1 + A[i, i-1] = χ₁ / (Δx * Δx) - χ₂ / (2 * Δx) + end + A[i, i] = -2 * χ₁ / (Δx * Δx) + χ₃ + if i ≠ n + A[i, i+1] = χ₁ / (Δx * Δx) + χ₂ / (2 * Δx) + end end - A[i, i] = -2 * χ₁ / (Δx * Δx) + χ₃ - if i ≠ n - A[i, i+1] = χ₁ / (Δx * Δx) + χ₂ / (2 * Δx) - end - end - b = f(grid) - c = g(grid) - return A, b, c + b = f(grid) + c = g(grid) + return A, b, c end # Primal and dual PDEs discretized with central second order finite differences. -function PDE(n, m, f, g, pde_coefs; dim_x=[0.0, 1.0], dim_y=[0.0, 1.0]) - # Ω = ]xₗ,xᵣ[ × ]yₗ,yᵣ[ - # Ω ∪ ∂Ω = [xₗ,xᵣ] × [yₗ,yᵣ] - xₗ = dim_x[1] - xᵣ = dim_x[2] - - yₗ = dim_y[1] - yᵣ = dim_y[2] - - # Uniform grid of Ω with n × m points - Δx = (xᵣ - xₗ) / (n + 1) - x = [xₗ + i * Δx for i = 1 : n] - - Δy = (yᵣ - yₗ) / (m + 1) - y = [yₗ + j * Δy for j = 1 : m] - - a = pde_coefs[1] - b = pde_coefs[2] - c = pde_coefs[3] - d = pde_coefs[4] - e = pde_coefs[5] - - # Model problems with Au = b and Aᵀv = c - # - # A ∈ ℜᵐⁿ*ᵐⁿ, u ∈ ℜᵐⁿ, b ∈ ℜᵐⁿ, v ∈ ℜᵐⁿ, c ∈ ℜᵐⁿ - # xᵢ = i * Δx, yⱼ = j * Δy and zᵢ.ⱼ = z(xᵢ, yⱼ) - # - # ∂²z(xᵢ, yⱼ) / ∂x² ≈ (zᵢ₋₁.ⱼ -2zᵢ.ⱼ + zᵢ₊₁.ⱼ) / (Δx)² - # ∂²z(xᵢ, yⱼ) / ∂y² ≈ (zᵢ.ⱼ₋₁ -2zᵢ.ⱼ + zᵢ.ⱼ₊₁) / (Δy)² - # - # ∂z(xᵢ, yⱼ) / ∂x ≈ (zᵢ₊₁.ⱼ - zᵢ₋₁.ⱼ) / (2 * Δx) - # ∂z(xᵢ, yⱼ) / ∂y ≈ (zᵢ.ⱼ₊₁ - zᵢ.ⱼ₋₁) / (2 * Δy) - # - # uᵢ.ⱼ = u[i + n * (j-1)] - # bᵢ.ⱼ = f[i + n * (j-1)] - # - # vᵢ.ⱼ = v[i + n * (j-1)] - # cᵢ.ⱼ = g[i + n * (j-1)] - A = spzeros(n * m, n * m) - for i = 1 : n - for j = 1 : m - A[i + n*(j-1), i + n*(j-1)] = - 2*a / (Δx * Δx) - 2*b / (Δy * Δy) + e - if i ≥ 2 - A[i + n*(j-1), (i-1) + n*(j-1)] = a / (Δx * Δx) - c / (2 * Δx) - end - if i ≤ n-1 - A[i + n*(j-1), (i+1) + n*(j-1)] = a / (Δx * Δx) + c / (2 * Δx) - end - if j ≥ 2 - A[i + n*(j-1), i + n*(j-2)] = b / (Δy * Δy) - d / (2 * Δy) - end - if j ≤ m-1 - A[i + n*(j-1), i + n*j] = b / (Δy * Δy) + d / (2 * Δy) - end +function PDE(n, m, f, g, pde_coefs; dim_x = [0.0, 1.0], dim_y = [0.0, 1.0]) + # Ω = ]xₗ,xᵣ[ × ]yₗ,yᵣ[ + # Ω ∪ ∂Ω = [xₗ,xᵣ] × [yₗ,yᵣ] + xₗ = dim_x[1] + xᵣ = dim_x[2] + + yₗ = dim_y[1] + yᵣ = dim_y[2] + + # Uniform grid of Ω with n × m points + Δx = (xᵣ - xₗ) / (n + 1) + x = [xₗ + i * Δx for i = 1:n] + + Δy = (yᵣ - yₗ) / (m + 1) + y = [yₗ + j * Δy for j = 1:m] + + a = pde_coefs[1] + b = pde_coefs[2] + c = pde_coefs[3] + d = pde_coefs[4] + e = pde_coefs[5] + + # Model problems with Au = b and Aᵀv = c + # + # A ∈ ℜᵐⁿ*ᵐⁿ, u ∈ ℜᵐⁿ, b ∈ ℜᵐⁿ, v ∈ ℜᵐⁿ, c ∈ ℜᵐⁿ + # xᵢ = i * Δx, yⱼ = j * Δy and zᵢ.ⱼ = z(xᵢ, yⱼ) + # + # ∂²z(xᵢ, yⱼ) / ∂x² ≈ (zᵢ₋₁.ⱼ -2zᵢ.ⱼ + zᵢ₊₁.ⱼ) / (Δx)² + # ∂²z(xᵢ, yⱼ) / ∂y² ≈ (zᵢ.ⱼ₋₁ -2zᵢ.ⱼ + zᵢ.ⱼ₊₁) / (Δy)² + # + # ∂z(xᵢ, yⱼ) / ∂x ≈ (zᵢ₊₁.ⱼ - zᵢ₋₁.ⱼ) / (2 * Δx) + # ∂z(xᵢ, yⱼ) / ∂y ≈ (zᵢ.ⱼ₊₁ - zᵢ.ⱼ₋₁) / (2 * Δy) + # + # uᵢ.ⱼ = u[i + n * (j-1)] + # bᵢ.ⱼ = f[i + n * (j-1)] + # + # vᵢ.ⱼ = v[i + n * (j-1)] + # cᵢ.ⱼ = g[i + n * (j-1)] + A = spzeros(n * m, n * m) + for i = 1:n + for j = 1:m + A[i+n*(j-1), i+n*(j-1)] = -2 * a / (Δx * Δx) - 2 * b / (Δy * Δy) + e + if i ≥ 2 + A[i+n*(j-1), (i-1)+n*(j-1)] = a / (Δx * Δx) - c / (2 * Δx) + end + if i ≤ n - 1 + A[i+n*(j-1), (i+1)+n*(j-1)] = a / (Δx * Δx) + c / (2 * Δx) + end + if j ≥ 2 + A[i+n*(j-1), i+n*(j-2)] = b / (Δy * Δy) - d / (2 * Δy) + end + if j ≤ m - 1 + A[i+n*(j-1), i+n*j] = b / (Δy * Δy) + d / (2 * Δy) + end + end end - end - b = zeros(n * m) - for i = 1 : n - for j = 1 : m - b[i + n*(j-1)] = f(x[i], y[j]) + b = zeros(n * m) + for i = 1:n + for j = 1:m + b[i+n*(j-1)] = f(x[i], y[j]) + end end - end - c = zeros(n * m) - for i = 1 : n - for j = 1 : m - c[i + n*(j-1)] = g(x[i], y[j]) + c = zeros(n * m) + for i = 1:n + for j = 1:m + c[i+n*(j-1)] = g(x[i], y[j]) + end end - end - return A, b, c + return A, b, c end # Model Poisson equation in polar coordinates -function polar_poisson(n, m, f, g; R=1.0) - Δr = 2 * R / (2*n + 1) - r = [(i - 1/2) * Δr for i = 1 : n+1] - - Δθ = 2 * π / m - θ = [(j - 1) * Δθ for j = 1 : m+1] - - λ = [1 / (2 * (k - 1/2)) for k = 1 : n] - β = [1 / ((k - 1/2)^2 * Δθ^2) for k = 1 : n] - - D = spdiagm(0 => β) - T = spdiagm(-1 => 1.0 .- λ[2:n], 0 => -2.0 * ones(n), 1 => 1.0 .+ λ[1:n-1]) - - A = spzeros(n * m, n * m) - for k = 1 : m - A[1+(k-1)*n : k*n, 1+(k-1)*n : k*n] = T - 2*D - if k ≤ m-1 - A[1+k*n : (k+1)*n, 1+(k-1)*n : k*n] = D - A[1+(k-1)*n : k*n, 1+k*n : (k+1)*n] = D +function polar_poisson(n, m, f, g; R = 1.0) + Δr = 2 * R / (2 * n + 1) + r = [(i - 1 / 2) * Δr for i = 1:n+1] + + Δθ = 2 * π / m + θ = [(j - 1) * Δθ for j = 1:m+1] + + λ = [1 / (2 * (k - 1 / 2)) for k = 1:n] + β = [1 / ((k - 1 / 2)^2 * Δθ^2) for k = 1:n] + + D = spdiagm(0 => β) + T = spdiagm(-1 => 1.0 .- λ[2:n], 0 => -2.0 * ones(n), 1 => 1.0 .+ λ[1:n-1]) + + A = spzeros(n * m, n * m) + for k = 1:m + A[1+(k-1)*n:k*n, 1+(k-1)*n:k*n] = T - 2 * D + if k ≤ m - 1 + A[1+k*n:(k+1)*n, 1+(k-1)*n:k*n] = D + A[1+(k-1)*n:k*n, 1+k*n:(k+1)*n] = D + end end - end - A[1+(m-1)*n : m*n, 1 : n] = D - A[1 : n, 1+(m-1)*n : m*n] = D - - b = zeros(n * m) - for i = 1 : n - for j = 1 : m - b[i + n*(j-1)] = Δr * Δr * f(r[i], θ[j]) - if i == n - b[i + n*(j-1)] -= (1.0 + λ[n]) * g(R, θ[j]) - end + A[1+(m-1)*n:m*n, 1:n] = D + A[1:n, 1+(m-1)*n:m*n] = D + + b = zeros(n * m) + for i = 1:n + for j = 1:m + b[i+n*(j-1)] = Δr * Δr * f(r[i], θ[j]) + if i == n + b[i+n*(j-1)] -= (1.0 + λ[n]) * g(R, θ[j]) + end + end end - end - return A, b + return A, b end # Model Poisson equation in cartesian coordinates -function cartesian_poisson(n, m, f, g; dim_x=[0.0, 1.0], dim_y=[0.0, 1.0]) - # Ω = ]xₗ,xᵣ[ × ]yₗ,yᵣ[ - # Ω ∪ ∂Ω = [xₗ,xᵣ] × [yₗ,yᵣ] - xₗ = dim_x[1] - xᵣ = dim_x[2] - - yₗ = dim_y[1] - yᵣ = dim_y[2] - - # Uniform grid of Ω with n × m points - Δx = (xᵣ - xₗ) / (n + 1) - x = [xₗ + i * Δx for i = 1 : n] - - Δy = (yᵣ - yₗ) / (m + 1) - y = [yₗ + j * Δy for j = 1 : m] - - A = spzeros(n * m, n * m) - for i = 1 : n - for j = 1 : m - A[i + (j-1)*n, i + (j-1)*n] = - 2.0 / (Δx * Δx) - 2.0 / (Δy * Δy) - if i ≥ 2 - A[i + (j-1)*n, i-1 + (j-1)*n] = 1.0 / (Δx * Δx) - end - if i ≤ n-1 - A[i + (j-1)*n, i+1 + (j-1)*n] = 1.0 / (Δx * Δx) - end - if j ≥ 2 - A[i + (j-1)*n, i + (j-2)*n] = 1.0 / (Δy * Δy) - end - if j ≤ m-1 - A[i + (j-1)*n, i + j*n] = 1.0 / (Δy * Δy) - end +function cartesian_poisson(n, m, f, g; dim_x = [0.0, 1.0], dim_y = [0.0, 1.0]) + # Ω = ]xₗ,xᵣ[ × ]yₗ,yᵣ[ + # Ω ∪ ∂Ω = [xₗ,xᵣ] × [yₗ,yᵣ] + xₗ = dim_x[1] + xᵣ = dim_x[2] + + yₗ = dim_y[1] + yᵣ = dim_y[2] + + # Uniform grid of Ω with n × m points + Δx = (xᵣ - xₗ) / (n + 1) + x = [xₗ + i * Δx for i = 1:n] + + Δy = (yᵣ - yₗ) / (m + 1) + y = [yₗ + j * Δy for j = 1:m] + + A = spzeros(n * m, n * m) + for i = 1:n + for j = 1:m + A[i+(j-1)*n, i+(j-1)*n] = -2.0 / (Δx * Δx) - 2.0 / (Δy * Δy) + if i ≥ 2 + A[i+(j-1)*n, i-1+(j-1)*n] = 1.0 / (Δx * Δx) + end + if i ≤ n - 1 + A[i+(j-1)*n, i+1+(j-1)*n] = 1.0 / (Δx * Δx) + end + if j ≥ 2 + A[i+(j-1)*n, i+(j-2)*n] = 1.0 / (Δy * Δy) + end + if j ≤ m - 1 + A[i+(j-1)*n, i+j*n] = 1.0 / (Δy * Δy) + end + end end - end - - b = zeros(n * m) - for i = 1 : n - for j = 1 : m - b[i + (j-1)*n] = f(x[i], y[j]) - if i == 1 - b[i + (j-1)*n] -= g(xₗ, y[j]) / (Δx * Δx) - end - if i == n - b[i + (j-1)*n] -= g(xᵣ, y[j]) / (Δx * Δx) - end - if j == 1 - b[i + (j-1)*n] -= g(x[i], yₗ) / (Δy * Δy) - end - if j == m - b[i + (j-1)*n] -= g(x[i], yᵣ) / (Δy * Δy) - end + + b = zeros(n * m) + for i = 1:n + for j = 1:m + b[i+(j-1)*n] = f(x[i], y[j]) + if i == 1 + b[i+(j-1)*n] -= g(xₗ, y[j]) / (Δx * Δx) + end + if i == n + b[i+(j-1)*n] -= g(xᵣ, y[j]) / (Δx * Δx) + end + if j == 1 + b[i+(j-1)*n] -= g(x[i], yₗ) / (Δy * Δy) + end + if j == m + b[i+(j-1)*n] -= g(x[i], yᵣ) / (Δy * Δy) + end + end end - end - return A, b + return A, b end diff --git a/test/utils.jl b/test/utils.jl index 8be9dec..477eb69 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,118 +1,132 @@ # Sparse Laplacian. include("get_div_grad.jl") -function sparse_laplacian(n :: Int=16; FC=Float64) - A = get_div_grad(n, n, n) - b = ones(n^3) - return A, b +function sparse_laplacian(n::Int = 16; FC = Float64) + A = get_div_grad(n, n, n) + b = ones(n^3) + return A, b end -function check(A,b) - tA, tb = sparse_laplacian(4, FC=Float64) - @test all(value.(tb) .== b) - @test all(value.(tA) .== A) +function check(A, b) + tA, tb = sparse_laplacian(4, FC = Float64) + @test all(value.(tb) .== b) + @test all(value.(tA) .== A) end function check_values(solver, A, b) - x = solver(A,b; atol=atol, rtol=rtol)[1] - db = Dual.(b) - dx = solver(A,db; atol=atol, rtol=rtol)[1] - @test all(dx .== x) + x = solver(A, b; atol = atol, rtol = rtol)[1] + db = Dual.(b) + dx = solver(A, db; atol = atol, rtol = rtol)[1] + @test all(dx .== x) end function check_jacobian(solver, A, b) - adJ = ForwardDiff.jacobian(x -> solver(A, x; atol=atol, rtol=rtol)[1], b) - fdm = central_fdm(8, 1); - fdJ = FiniteDifferences.jacobian(fdm, x -> solver(A, x; atol=atol, rtol=rtol)[1], copy(b)) - @test all(isapprox.(adJ, fdJ[1])) + adJ = ForwardDiff.jacobian(x -> solver(A, x; atol = atol, rtol = rtol)[1], b) + fdm = central_fdm(8, 1) + fdJ = FiniteDifferences.jacobian( + fdm, + x -> solver(A, x; atol = atol, rtol = rtol)[1], + copy(b), + ) + @test all(isapprox.(adJ, fdJ[1])) end function check_derivatives_and_values_active_active(solver, A, b, x) - fdm = central_fdm(8, 1); + fdm = central_fdm(8, 1) dualsA = copy(A) fill!(dualsA, 0.0) - dualsA[1,1] = 1.0 + dualsA[1, 1] = 1.0 dA = ForwardDiff.Dual.(A, dualsA) - check(A,b) + check(A, b) dualsb = copy(b) fill!(dualsb, 0.0) dualsb[1] = 1.0 db = ForwardDiff.Dual.(b, dualsb) - dx, stats = solver(dA,db; atol=atol, rtol=rtol) + dx, stats = solver(dA, db; atol = atol, rtol = rtol) all(isapprox(value.(dx), x)) function A_one_one(x) _A = copy(A) - _A[1,1] = x - solver(_A,b; atol=atol, rtol=rtol) + _A[1, 1] = x + solver(_A, b; atol = atol, rtol = rtol) end function b_one(x) _b = copy(b) _b[1] = x - solver(A,_b; atol=atol, rtol=rtol) + solver(A, _b; atol = atol, rtol = rtol) end - fda = FiniteDifferences.jacobian(fdm, a -> A_one_one(a)[1], copy(A[1,1])) + fda = FiniteDifferences.jacobian(fdm, a -> A_one_one(a)[1], copy(A[1, 1])) fdb = FiniteDifferences.jacobian(fdm, a -> b_one(a)[1], copy(b[1])) isapprox(value.(dx), x) - fd =fda[1] + fdb[1] - @test isapprox(partials.(dx,1), fd) + fd = fda[1] + fdb[1] + @test isapprox(partials.(dx, 1), fd) end function check_derivatives_and_values_active_passive(solver, A, b, x) - fdm = central_fdm(8, 1); + fdm = central_fdm(8, 1) dualsA = copy(A) fill!(dualsA, 0.0) - dualsA[1,1] = 1.0 + dualsA[1, 1] = 1.0 dA = ForwardDiff.Dual.(A, dualsA) - check(A,b) + check(A, b) - dx, stats = solver(dA,b; atol=atol, rtol=rtol) + dx, stats = solver(dA, b; atol = atol, rtol = rtol) all(isapprox(value.(dx), x)) function A_one_one(x) _A = copy(A) - _A[1,1] = x - solver(_A,b; atol=atol, rtol=rtol) + _A[1, 1] = x + solver(_A, b; atol = atol, rtol = rtol) end - fda = FiniteDifferences.jacobian(fdm, a -> A_one_one(a)[1], copy(A[1,1])) + fda = FiniteDifferences.jacobian(fdm, a -> A_one_one(a)[1], copy(A[1, 1])) isapprox(value.(dx), x) - @test isapprox(partials.(dx,1), fda[1]) + @test isapprox(partials.(dx, 1), fda[1]) end -function driver!(solver::GmresSolver, A, b, M, N, ldiv=false) - gmres!(solver, A,b, atol=1e-16, rtol=1e-16, M=M, N=N, verbose=0, ldiv=ldiv) +function driver!(solver::GmresSolver, A, b, M, N, ldiv = false) + gmres!(solver, A, b, atol = 1e-16, rtol = 1e-16, M = M, N = N, verbose = 0, ldiv = ldiv) nothing end -function driver!(solver::BicgstabSolver, A, b, M, N, ldiv=false) - bicgstab!(solver, A,b, atol=1e-16, rtol=1e-16, M=M, N=N, verbose=0, ldiv=ldiv) +function driver!(solver::BicgstabSolver, A, b, M, N, ldiv = false) + bicgstab!( + solver, + A, + b, + atol = 1e-16, + rtol = 1e-16, + M = M, + N = N, + verbose = 0, + ldiv = ldiv, + ) nothing end -function driver!(solver::CgSolver, A, b, M, N, ldiv=false) - cg!(solver, A,b, atol=1e-16, rtol=1e-16, M=M, verbose=0, ldiv=ldiv) +function driver!(solver::CgSolver, A, b, M, N, ldiv = false) + cg!(solver, A, b, atol = 1e-16, rtol = 1e-16, M = M, verbose = 0, ldiv = ldiv) nothing end -function test_enzyme_with(solver, A, b, M, N, ldiv=false) +function test_enzyme_with(solver, A, b, M, N, ldiv = false) tsolver = if solver == Krylov.cg - CgSolver(A,b) + CgSolver(A, b) elseif solver == Krylov.gmres - GmresSolver(A,b) + GmresSolver(A, b) elseif solver == Krylov.bicgstab - BicgstabSolver(A,b) + BicgstabSolver(A, b) else error("Unsupported solver $solver is tested in DiffKrylov.jl") end - fdm = central_fdm(8, 1); + fdm = central_fdm(8, 1) function A_one_one(hx) _A = copy(A) - _A[1,1] = hx + _A[1, 1] = hx # fill!(tsolver.x, zero(eltype(solver.x))) driver!(tsolver, _A, b, M, N, ldiv) return tsolver.x[1] @@ -126,9 +140,9 @@ function test_enzyme_with(solver, A, b, M, N, ldiv=false) return tsolver.x[1] end - fda = FiniteDifferences.jacobian(fdm, a -> A_one_one(a), copy(A[1,1])) + fda = FiniteDifferences.jacobian(fdm, a -> A_one_one(a), copy(A[1, 1])) fdb = FiniteDifferences.jacobian(fdm, a -> b_one(a), copy(b[1])) - fd =fda[1] + fdb[1] + fd = fda[1] + fdb[1] # Test forward function duplicate(A::SparseMatrixCSC) dA = copy(A) @@ -141,35 +155,17 @@ function test_enzyme_with(solver, A, b, M, N, ldiv=false) db = Duplicated(b, zeros(length(b))) dupsolver = Duplicated(tsolver, deepcopy(tsolver)) fill!(dupsolver.dval.x, zero(eltype(dupsolver.dval.x))) - dA.dval[1,1] = 1.0 + dA.dval[1, 1] = 1.0 db.dval[1] = 1.0 - Enzyme.autodiff( - Forward, - driver!, - dupsolver, - dA, - db, - Const(M), - Const(N), - Const(ldiv) - ) - @test isapprox(dupsolver.dval.x[1], fd[1][1], atol=1e-4, rtol=1e-4) + Enzyme.autodiff(Forward, driver!, dupsolver, dA, db, Const(M), Const(N), Const(ldiv)) + @test isapprox(dupsolver.dval.x[1], fd[1][1], atol = 1e-4, rtol = 1e-4) # Test reverse dA = Duplicated(A, duplicate(A)) db = Duplicated(b, zeros(length(b))) dupsolver = Duplicated(tsolver, deepcopy(tsolver)) fill!(dupsolver.dval.x, zero(eltype(dupsolver.dval.x))) dupsolver.dval.x[1] = 1.0 - Enzyme.autodiff( - Reverse, - driver!, - dupsolver, - dA, - db, - Const(M), - Const(N), - Const(ldiv) - ) - @test isapprox(db.dval[1], fdb[1][1], atol=1e-4, rtol=1e-4) - @test isapprox(dA.dval[1,1], fda[1][1], atol=1e-4, rtol=1e-4) + Enzyme.autodiff(Reverse, driver!, dupsolver, dA, db, Const(M), Const(N), Const(ldiv)) + @test isapprox(db.dval[1], fdb[1][1], atol = 1e-4, rtol = 1e-4) + @test isapprox(dA.dval[1, 1], fda[1][1], atol = 1e-4, rtol = 1e-4) end