Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

feat!: 1.0 release #43

Merged
merged 22 commits into from
Aug 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,29 +1,31 @@
name = "LuxCore"
uuid = "bb33d45b-7691-41d6-9220-0943567d0623"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.1.25"
version = "1.0.0"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"

[weakdeps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[extensions]
LuxCoreArrayInterfaceReverseDiffExt = ["ArrayInterface", "ReverseDiff"]
LuxCoreArrayInterfaceTrackerExt = ["ArrayInterface", "Tracker"]
LuxCoreChainRulesCoreExt = "ChainRulesCore"
LuxCoreEnzymeCoreExt = "EnzymeCore"
LuxCoreFunctorsExt = "Functors"
LuxCoreMLDataDevicesExt = "MLDataDevices"
LuxCoreSetfieldExt = "Setfield"

[compat]
ArrayInterface = "7.9"
Expand Down
11 changes: 5 additions & 6 deletions ext/LuxCoreArrayInterfaceReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,22 @@
module LuxCoreArrayInterfaceReverseDiffExt

using ArrayInterface: ArrayInterface
using LuxCore: LuxCore, AbstractExplicitLayer
using LuxCore: LuxCore, AbstractLuxLayer
using ReverseDiff: TrackedReal, TrackedArray

# AoS to SoA conversion
function LuxCore.apply(
m::AbstractExplicitLayer, x::AbstractArray{<:TrackedReal}, ps, st)
@warn "Lux.apply(m::AbstractExplicitLayer, \
m::AbstractLuxLayer, x::AbstractArray{<:TrackedReal}, ps, st)
@warn "Lux.apply(m::AbstractLuxLayer, \

Check warning on line 10 in ext/LuxCoreArrayInterfaceReverseDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxCoreArrayInterfaceReverseDiffExt.jl#L10

Added line #L10 was not covered by tests
x::AbstractArray{<:ReverseDiff.TrackedReal}, ps, st) input was corrected to \
Lux.apply(m::AbstractExplicitLayer, x::ReverseDiff.TrackedArray}, ps, \
st).\n\n\
Lux.apply(m::AbstractLuxLayer, x::ReverseDiff.TrackedArray}, ps, st).\n\n\
1. If this was not the desired behavior overload the dispatch on `m`.\n\n\
2. This might have performance implications. Check which layer was causing this \
problem using `Lux.Experimental.@debug_mode`." maxlog=1
return LuxCore.apply(m, reshape(ArrayInterface.aos_to_soa(x), size(x)), ps, st)
end

## Prevent an infinite loop
LuxCore.apply(m::AbstractExplicitLayer, x::TrackedArray, ps, st) = m(x, ps, st)
LuxCore.apply(m::AbstractLuxLayer, x::TrackedArray, ps, st) = m(x, ps, st)

Check warning on line 20 in ext/LuxCoreArrayInterfaceReverseDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxCoreArrayInterfaceReverseDiffExt.jl#L20

Added line #L20 was not covered by tests

end
10 changes: 5 additions & 5 deletions ext/LuxCoreArrayInterfaceTrackerExt.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
module LuxCoreArrayInterfaceTrackerExt

using ArrayInterface: ArrayInterface
using LuxCore: LuxCore, AbstractExplicitLayer
using LuxCore: LuxCore, AbstractLuxLayer
using Tracker: TrackedReal, TrackedArray

# AoS to SoA conversion
function LuxCore.apply(m::AbstractExplicitLayer, x::AbstractArray{<:TrackedReal}, ps, st)
@warn "LuxCore.apply(m::AbstractExplicitLayer, \
function LuxCore.apply(m::AbstractLuxLayer, x::AbstractArray{<:TrackedReal}, ps, st)
@warn "LuxCore.apply(m::AbstractLuxLayer, \

Check warning on line 9 in ext/LuxCoreArrayInterfaceTrackerExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxCoreArrayInterfaceTrackerExt.jl#L8-L9

Added lines #L8 - L9 were not covered by tests
x::AbstractArray{<:Tracker.TrackedReal}, ps, st) input was corrected to \
LuxCore.apply(m::AbstractExplicitLayer, x::Tracker.TrackedArray}, ps, st).\n\n\
LuxCore.apply(m::AbstractLuxLayer, x::Tracker.TrackedArray}, ps, st).\n\n\
1. If this was not the desired behavior overload the dispatch on `m`.\n\n\
2. This might have performance implications. Check which layer was causing this \
problem using `Lux.Experimental.@debug_mode`." maxlog=1
return LuxCore.apply(m, ArrayInterface.aos_to_soa(x), ps, st)
end

## Prevent an infinite loop
LuxCore.apply(m::AbstractExplicitLayer, x::TrackedArray, ps, st) = m(x, ps, st)
LuxCore.apply(m::AbstractLuxLayer, x::TrackedArray, ps, st) = m(x, ps, st)

Check warning on line 19 in ext/LuxCoreArrayInterfaceTrackerExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxCoreArrayInterfaceTrackerExt.jl#L19

Added line #L19 was not covered by tests

end
4 changes: 2 additions & 2 deletions ext/LuxCoreChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
module LuxCoreChainRulesCoreExt

using ChainRulesCore: ChainRulesCore, @non_differentiable
using LuxCore: LuxCore, AbstractExplicitLayer
using LuxCore: LuxCore, AbstractLuxLayer
using Random: AbstractRNG

@non_differentiable LuxCore.replicate(::AbstractRNG)

function ChainRulesCore.rrule(::typeof(getproperty), m::AbstractExplicitLayer, x::Symbol)
function ChainRulesCore.rrule(::typeof(getproperty), m::AbstractLuxLayer, x::Symbol)

Check warning on line 9 in ext/LuxCoreChainRulesCoreExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/LuxCoreChainRulesCoreExt.jl#L9

Added line #L9 was not covered by tests
mₓ = getproperty(m, x)
∇getproperty(_) = ntuple(Returns(ChainRulesCore.NoTangent()), 3)
return mₓ, ∇getproperty
Expand Down
6 changes: 3 additions & 3 deletions ext/LuxCoreEnzymeCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,20 @@ compute the gradients w.r.t. the layer's parameters, use the first argument retu
by `LuxCore.setup(rng, layer)` instead.
"""

function EnzymeCore.Active(::LuxCore.AbstractExplicitLayer)
function EnzymeCore.Active(::LuxCore.AbstractLuxLayer)
throw(ArgumentError(LAYER_DERIVATIVE_ERROR_MSG))
end

for annotation in (:Duplicated, :DuplicatedNoNeed)
@eval function EnzymeCore.$(annotation)(
::LuxCore.AbstractExplicitLayer, ::LuxCore.AbstractExplicitLayer)
::LuxCore.AbstractLuxLayer, ::LuxCore.AbstractLuxLayer)
throw(ArgumentError(LAYER_DERIVATIVE_ERROR_MSG))
end
end

for annotation in (:BatchDuplicated, :BatchDuplicatedNoNeed)
@eval function EnzymeCore.$(annotation)(
::LuxCore.AbstractExplicitLayer, ::NTuple{N, <:LuxCore.AbstractExplicitLayer},
::LuxCore.AbstractLuxLayer, ::NTuple{N, <:LuxCore.AbstractLuxLayer},
check::Bool=true) where {N}
throw(ArgumentError(LAYER_DERIVATIVE_ERROR_MSG))
end
Expand Down
33 changes: 33 additions & 0 deletions ext/LuxCoreFunctorsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
module LuxCoreFunctorsExt

using LuxCore: LuxCore
using Functors: Functors

LuxCore.Internal.is_extension_loaded(::Val{:Functors}) = true

LuxCore.Internal.isleaf_impl(args...; kwargs...) = Functors.isleaf(args...; kwargs...)
LuxCore.Internal.fmap_impl(args...; kwargs...) = Functors.fmap(args...; kwargs...)
function LuxCore.Internal.fmap_with_path_impl(args...; kwargs...)
return Functors.fmap_with_path(args...; kwargs...)
end
LuxCore.Internal.fleaves_impl(args...; kwargs...) = Functors.fleaves(args...; kwargs...)

function Functors.functor(::Type{<:LuxCore.AbstractLuxContainerLayer{layers}},
x) where {layers}
children = NamedTuple{layers}(getproperty.((x,), layers))
layer_reconstructor = let x = x, layers = layers
z -> reduce(LuxCore.Internal.setfield, zip(layers, z); init=x)
end
return children, layer_reconstructor
end

function Functors.functor(::Type{<:LuxCore.AbstractLuxWrapperLayer{layer}},
x) where {layer}
children = NamedTuple{(layer,)}((getproperty(x, layer),))
layer_reconstructor = let x = x, layer = layer
z -> LuxCore.Internal.setfield(x, layer, getproperty(z, layer))
end
return children, layer_reconstructor
end

end
2 changes: 1 addition & 1 deletion ext/LuxCoreMLDataDevicesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ using MLDataDevices: MLDataDevices

for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI)
ldev = Symbol(dev, :Device)
@eval function (::MLDataDevices.$(ldev))(NN::LuxCore.AbstractExplicitLayer)
@eval function (::MLDataDevices.$(ldev))(NN::LuxCore.AbstractLuxLayer)
@warn "Lux layers are stateless and hence don't participate in device transfers. \
Apply this function on the parameters and states generated using \
`LuxCore.setup`."
Expand Down
15 changes: 15 additions & 0 deletions ext/LuxCoreSetfieldExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
module LuxCoreSetfieldExt

using LuxCore: LuxCore
using Setfield: Setfield

LuxCore.Internal.is_extension_loaded(::Val{:Setfield}) = true

function LuxCore.Internal.setfield_impl(x, prop, val)
return Setfield.set(x, Setfield.PropertyLens{prop}(), val)
end
function LuxCore.Internal.setfield_impl(x, (prop, val))
return LuxCore.Internal.setfield_impl(x, prop, val)
end

end
Loading
Loading