Skip to content

Commit

Permalink
Fixup more than simple jacobian
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Aug 8, 2024
1 parent 356ef34 commit 556ca4b
Show file tree
Hide file tree
Showing 3 changed files with 323 additions and 43 deletions.
236 changes: 193 additions & 43 deletions src/Enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1002,6 +1002,10 @@ end
end
end

# A scalar input has a single seed direction: the multiplicative identity of
# `x`'s own type, wrapped in a 1-tuple.
@inline onehot(x::AbstractFloat) = (one(x),)

"""
gradient(::ReverseMode, f, x)
Expand Down Expand Up @@ -1126,10 +1130,15 @@ grad = gradient(Forward, f, [2.0, 3.0])
```
"""
@inline function gradient(::ForwardMode, f, x; shadow=onehot(x))
    # Forward-mode gradient of `f` at `x`, evaluated in one batched pass that is
    # seeded with every entry of `shadow`.
    # Without any seed there is nothing to differentiate against.
    if length(shadow) == 0
        return ()
    end
    res = values(only(autodiff(
        Forward, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow)
    )))
    # For a scalar input, return the first derivative itself rather than the tuple.
    return x isa AbstractFloat ? res[1] : res
end

@inline function chunkedonehot(x, ::Val{chunk}) where chunk
Expand All @@ -1141,6 +1150,10 @@ end
end
end

# A scalar input needs only one seed, so whatever the chunk size the result is a
# single chunk holding the single seed `one(x)`.
@inline chunkedonehot(x::AbstractFloat, ::Val{chunk}) where {chunk} = ((one(x),),)

# Flatten any number of tuples into one tuple, preserving order.
# A lone argument is handed back untouched.
@inline tupleconcat(x) = x
@inline tupleconcat(x, y) = (x..., y...)
# Merge the two leading arguments first, then continue with the remainder.
@inline tupleconcat(x, y, z...) = tupleconcat((x..., y...), z...)
Expand Down Expand Up @@ -1171,62 +1184,140 @@ grad = gradient(Forward, f, [2.0, 3.0], Val(2))
tmp = ntuple(length(shadow)) do i
values(autodiff(Forward, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow[i]))[1])
end
tupleconcat(tmp...)
res = tupleconcat(tmp...)
if x isa AbstractFloat
res[1]
else
res
end
end

@inline function gradient(::ForwardMode, f::F, x::X, ::Val{1}; shadow=onehot(x)) where {F, X}
    # Chunk size 1: run one unbatched forward pass per seed in `shadow` and
    # collect the derivative from each pass.
    res = ntuple(length(shadow)) do i
        autodiff(Forward, f, DuplicatedNoNeed, Duplicated(x, shadow[i]))[1]
    end
    # For a scalar input, return the first derivative itself rather than the tuple.
    return x isa AbstractFloat ? res[1] : res
end

"""
jacobian(::ForwardMode, f, x; shadow=onehot(x))
jacobian(::ForwardMode, f, x, ::Val{chunk}; shadow=chunkedonehot(x, Val(chunk)))
Compute the jacobian of an array-input function `f` using (potentially vector)
forward mode. This is a simple rename of the [`gradient`](@ref) function,
and all relevant arguments apply here.
Compute the jacobian of an array or scalar-input function `f` using (potentially vector)
forward mode. All relevant arguments of the forward-mode [`gradient`](@ref) function
apply here.
Example:
```jldoctest
f(x) = [x[1]*x[2], x[2]]
f(x) = [ x[1] * x[2], x[2] + x[3] ]
grad = jacobian(Forward, f, [2.0, 3.0])
grad = jacobian(Forward, f, [2.0, 3.0, 4.0])
# output
2×2 Matrix{Float64}:
3.0 2.0
0.0 1.0
2×3 Matrix{Float64}:
3.0 2.0 0.0
0.0 1.0 1.0
```
For functions which return an AbstractArray, this function will return an array
whose shape is `(size(output)..., size(input)...)`
For functions which return other types, this function will return an array or tuple
of shape `size(input)` of values of the output type.
"""
@inline function jacobian(::ForwardMode, f, x; shadow=onehot(x))
    # One batched forward pass over every seed; no seeds means no columns.
    cols = if length(shadow) == 0
        ()
    else
        values(only(autodiff(
            Forward, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow)
        )))
    end
    return _forward_jacobian_result(x, cols)
end

@inline function jacobian(::ForwardMode, f::F, x::X, ::Val{chunk}; shadow=chunkedonehot(x, Val(chunk))) where {F, X, chunk}
    if chunk == 0
        throw(ErrorException("Cannot differentiate with a batch size of 0"))
    end
    # One batched forward pass per chunk of seeds, then flatten the per-chunk
    # derivative tuples into a single tuple with one entry per seed.
    tmp = ntuple(length(shadow)) do i
        Base.@_inline_meta
        values(autodiff(
            Forward, f, BatchDuplicatedNoNeed, BatchDuplicated(x, shadow[i])
        )[1])
    end
    cols = tupleconcat(tmp...)
    return _forward_jacobian_result(x, cols)
end

@inline function jacobian(::ForwardMode, f::F, x::X, ::Val{1}; shadow=onehot(x)) where {F,X}
    # Chunk size 1: one unbatched forward pass per seed.
    cols = ntuple(length(shadow)) do i
        Base.@_inline_meta
        autodiff(Forward, f, DuplicatedNoNeed, Duplicated(x, shadow[i]))[1]
    end
    return _forward_jacobian_result(x, cols)
end

# Shape the per-seed derivatives `cols` into the value returned by the
# forward-mode `jacobian` methods above.
#
# Scalar input: the first derivative is returned as-is.
@inline _forward_jacobian_result(x::AbstractFloat, cols) = cols[1]

# Non-scalar input:
#   * array-valued derivatives are stacked along a trailing dimension and, when the
#     input has more than one dimension, reshaped to `(size(output)..., size(x)...)`;
#   * otherwise, for an array input, the derivatives are collected into an array of
#     shape `size(x)`;
#   * anything else is returned unchanged.
@inline function _forward_jacobian_result(x, cols)
    if length(cols) > 0 && cols[1] isa AbstractArray
        inshape = size(x)
        outshape = size(cols[1])
        # stacked : outshape x total number of inputs
        stacked = Base.stack(cols)
        if length(inshape) <= 1
            return stacked
        end
        return reshape(stacked, (outshape..., inshape...))
    elseif x isa AbstractArray
        return reshape(collect(cols), size(x))
    else
        return cols
    end
end

"""
Expand All @@ -1239,27 +1330,35 @@ denotes the number of outputs `f` will return in an array.
Example:
```jldoctest
f(x) = [x[1]*x[2], x[2]]
f(x) = [ x[1] * x[2], x[2] + x[3] ]
grad = jacobian(Reverse, f, [2.0, 3.0], Val(2))
grad = jacobian(Reverse, f, [2.0, 3.0, 4.0])
# output
2×2 Matrix{Float64}:
3.0 2.0
0.0 1.0
2×3 Matrix{Float64}:
3.0 2.0 0.0
0.0 1.0 1.0
```
For functions which return an AbstractArray, this function will return an array
whose shape is `(size(output)..., size(input)...)`
For functions which return other types, this function will return an array or tuple
of shape `size(output)` of values of the input type.
"""
@inline function jacobian(::ReverseMode{ReturnPrimal,RABI, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{chunk}) where {F, X, chunk, n_out_val, ReturnPrimal, RABI<:ABI, ErrIfFuncWritten}
@assert !ReturnPrimal
@inline function jacobian(::ReverseMode{#=ReturnPrimal=#false,RABI, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{chunk}) where {F, X, chunk, n_out_val, RABI<:ABI, ErrIfFuncWritten}
num = ((n_out_val + chunk - 1) ÷ chunk)

if chunk == 0
throw(ErrorException("Cannot differentiate with a batch size of 0"))
end

tt′ = Tuple{BatchDuplicated{Core.Typeof(x), chunk}}
tt = Tuple{Core.Typeof(x)}
XT = Core.Typeof(x)
MD = Compiler.active_reg_inner(XT, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState
tt′ = MD ? Tuple{BatchMixedDuplicated{XT, chunk}} : Tuple{BatchDuplicated{XT, chunk}}
tt = Tuple{XT}
rt = Core.Compiler.return_type(f, tt)
ModifiedBetween = Val((false, false))
FA = Const{Core.Typeof(f)}
Expand All @@ -1281,28 +1380,55 @@ grad = jacobian(Reverse, f, [2.0, 3.0], Val(2))

tmp = ntuple(num) do i
Base.@_inline_meta
dx = ntuple(i == num ? last_size : chunk) do idx
dx = ntuple(Val(i == num ? last_size : chunk)) do idx
Base.@_inline_meta
zero(x)
z = make_zero(x)
MD ? Ref(z) : z
end
res = (i == num ? primal2 : primal)(Const(f), BatchDuplicated(x, dx))
res = (i == num ? primal2 : primal)(Const(f), MD ? BatchMixedDuplicated(x, dx) : BatchDuplicated(x, dx))
tape = res[1]
j = 0
for shadow in res[3]
j += 1
@inbounds shadow[(i-1)*chunk+j] += Compiler.default_adjoint(eltype(typeof(shadow)))
end
(i == num ? adjoint2 : adjoint)(Const(f), BatchDuplicated(x, dx), tape)
return dx
(i == num ? adjoint2 : adjoint)(Const(f), MD ? BatchMixedDuplicated(x, dx) : BatchDuplicated(x, dx), tape)
return MD ? (ntuple(Val(i == num ? last_size : chunk)) do idx
Base.@_inline_meta
dx[idx][]
end) : dx, (i == 1 ? size(res[3][1]) : nothing)
end
@show tmp
rows = tupleconcat(map(first, tmp)...)
@show rows
outshape = tmp[1][2]
if x isa AbstractArray
inshape = size(x)
st = Base.stack(rows)
st2 = if length(outshape) == 1
st
else
reshape(st, (inshape..., outshape...))
end

st3 = if length(outshape) == 1 && length(inshape) == 1
transpose(st2)
else
transp = ( ((length(inshape)+1):(length(inshape)+length(outshape)))... , (1:length(inshape))... )
PermutedDimsArray(st2, transp)
end

st3
else
reshape(collect(rows), outshape)
end
rows = tupleconcat(tmp...)
mapreduce(LinearAlgebra.adjoint, vcat, rows)
end

@inline function jacobian(::ReverseMode{ReturnPrimal,RABI, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{1} = Val(1)) where {F, X, n_out_val,ReturnPrimal,RABI<:ABI, ErrIfFuncWritten}
@assert !ReturnPrimal
tt′ = Tuple{Duplicated{Core.Typeof(x)}}
tt = Tuple{Core.Typeof(x)}
@inline function jacobian(::ReverseMode{#=ReturnPrimal=#false,RABI, ErrIfFuncWritten}, f::F, x::X, n_outs::Val{n_out_val}, ::Val{1} = Val(1)) where {F, X, n_out_val,RABI<:ABI, ErrIfFuncWritten}
XT = Core.Typeof(x)
MD = Compiler.active_reg_inner(XT, #=seen=#(), #=world=#nothing, #=justActive=#Val(true)) == Compiler.ActiveState
tt′ = MD ? Tuple{MixedDuplicated{XT}} : Tuple{Duplicated{XT}}
tt = Tuple{XT}
rt = Core.Compiler.return_type(f, tt)
ModifiedBetween = Val((false, false))
FA = Const{Core.Typeof(f)}
Expand All @@ -1312,16 +1438,40 @@ end
Val(codegen_world_age(Core.Typeof(f), tt))
end
primal, adjoint = Enzyme.Compiler.thunk(opt_mi, FA, DuplicatedNoNeed{rt}, tt′, #=Split=# Val(API.DEM_ReverseModeGradient), #=width=#Val(1), ModifiedBetween, #=ReturnPrimal=#Val(false), #=ShadowInit=#Val(false), RABI, Val(ErrIfFuncWritten))
rows = ntuple(n_outs) do i
tmp = ntuple(n_outs) do i
Base.@_inline_meta
dx = zero(x)
res = primal(Const(f), Duplicated(x, dx))
z = make_zero(x)
dx = MD ? Ref(z) : z
res = primal(Const(f), MD ? MixedDuplicated(x, dx) : Duplicated(x, dx))
tape = res[1]
@inbounds res[3][i] += Compiler.default_adjoint(eltype(typeof(res[3])))
adjoint(Const(f), Duplicated(x, dx), tape)
return dx
adjoint(Const(f), MD ? MixedDuplicated(x, dx) : Duplicated(x, dx), tape)
return MD ? dx[] : dx, (i == 1 ? size(res[3]) : nothing)
end
@show tmp
rows = map(first, tmp)
@show rows
outshape = tmp[1][2]
if x isa AbstractArray
inshape = size(x)
st = Base.stack(rows)
st2 = if length(outshape) == 1
st
else
reshape(st, (inshape..., outshape...))
end

st3 = if length(outshape) == 1 && length(inshape) == 1
transpose(st2)
else
transp = ( ((length(inshape)+1):(length(inshape)+length(outshape)))... , (1:length(inshape))... )
PermutedDimsArray(st2, transp)
end

st3
else
reshape(collect(rows), outshape)
end
mapreduce(LinearAlgebra.adjoint, vcat, rows)
end

"""
Expand Down
5 changes: 5 additions & 0 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3886,7 +3886,12 @@ include("rules/activityrules.jl")
# Both the plain and the batched "no need" duplicated activities convert to
# `API.DFT_DUP_NONEED`.
@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:DuplicatedNoNeed} = API.DFT_DUP_NONEED
@inline Base.convert(::Type{API.CDIFFE_TYPE}, ::Type{A}) where {A<:BatchDuplicatedNoNeed} = API.DFT_DUP_NONEED

# Debug switch, off by default: when set to `true`, `enzyme!` calls
# `API.EnzymeDumpModuleRef` on the module as its first step.
const DumpPreEnzyme = Ref{Bool}(false)

function enzyme!(job, mod, primalf, TT, mode, width, parallel, actualRetType, wrap, modifiedBetween, returnPrimal, expectedTapeType, loweredArgs, boxedArgs)
if DumpPreEnzyme[]
API.EnzymeDumpModuleRef(mod.ref)
end
world = job.world
interp = GPUCompiler.get_interpreter(job)
rt = job.config.params.rt
Expand Down
Loading

0 comments on commit 556ca4b

Please sign in to comment.