diff --git a/src/linear.jl b/src/linear.jl index 0278ebe..071dc45 100644 --- a/src/linear.jl +++ b/src/linear.jl @@ -33,65 +33,64 @@ out, st = l(x, ps, st) println(out == x * transpose(W))) # true ``` """ -struct LinearLayer{FEATFIRST} <: AbstractExplicitLayer +struct LinearLayer{FEATFIRST} <: AbstractP4MLLayer in_dim::Integer out_dim::Integer - use_cache::Bool @reqfields() end -LinearLayer(in_dim::Int, out_dim::Int; feature_first = false, use_cache = true) = LinearLayer{feature_first}(in_dim, out_dim, use_cache, _make_reqfields()...) +LinearLayer(in_dim::Int, out_dim::Int; feature_first = false) = LinearLayer{feature_first}(in_dim, out_dim, _make_reqfields()...) -_valtype(l::LinearLayer, x::AbstractArray, ps) = - promote_type(eltype(x), eltype(ps.W)) +# ----------------------- evaluation and allocation interfaces -function (l::LinearLayer)(x::AbstractVector, ps, st) - out = acquire!(st.pool, :A, (l.out_dim, ), _valtype(l, x, ps)) - mul!(unwrap(out), ps.W, unwrap(x)); release!(x); - return out, st +_valtype(l::LinearLayer, x::AbstractArray, ps, st) = promote_type(eltype(x), eltype(ps.W)) +_gradtype(l::LinearLayer, x, ps, st) = promote_type(eltype(x), eltype(ps.W)) + +_out_size(l::LinearLayer, x::AbstractVector, ps, st) = (l.out_dim, ) +_out_size(l::LinearLayer{true}, x::AbstractMatrix, ps, st) = (l.out_dim, size(x, 2)) +_out_size(l::LinearLayer{false}, x::AbstractMatrix, ps, st) = (size(x, 1), l.out_dim) + +LuxCore.initialparameters(rng::AbstractRNG, l::LinearLayer) = ( W = randn(rng, l.out_dim, l.in_dim), ) +LuxCore.initialstates(rng::AbstractRNG, l::LinearLayer) = NamedTuple() + +function whatalloc(::typeof(evaluate!), l::LinearLayer, x::AbstractArray, ps, st) + TV = _valtype(l, x, ps, st) + sz = _out_size(l, x, ps, st) + return (TV, sz...) end -function (l::LinearLayer{true})(x::AbstractMatrix, ps, st) - out = acquire!(st.pool, :bA, (l.out_dim, size(x, 2)), _valtype(l, x, ps)); - mul!(unwrap(out), ps.W, unwrap(x)); release!(x); +function evaluate!(out, l::LinearLayer, x::AbstractVecOrMat, ps, st) + mul!(out, ps.W, x) return out, st end -(l::LinearLayer{false})(x::AbstractMatrix, ps, st) = begin - out = acquire!(st.pool, :bA, (size(x, 1), l.out_dim), _valtype(l, x, ps)); - mul!(unwrap(out), unwrap(x), transpose(PtrArray(ps.W))); release!(x); +function evaluate!(out, l::LinearLayer{false}, x::AbstractMatrix, ps, st) + mul!(out, x, transpose(PtrArray(ps.W))) return out, st end - -# Jerry: Maybe we should use Glorot Uniform if we have no idea about what we should use? -LuxCore.initialparameters(rng::AbstractRNG, l::LinearLayer) = - ( W = randn(rng, l.out_dim, l.in_dim), ) - -LuxCore.initialstates(rng::AbstractRNG, l::LinearLayer) = - ( l.use_cache ? (pool = ArrayPool(FlexArrayCache), ) - : (pool = ArrayPool(FlexArray), )) - + +# --------------------- connect with ChainRules +# can this be generalized again? # TODO: check whether we can do this without multiple dispatch on vec/mat without loss of performance -function rrule(::typeof(LuxCore.apply), l::LinearLayer, x::AbstractVector, ps, st) - val = l(x, ps, st) - function pb(A) - return NoTangent(), NoTangent(), ps.W' * A[1], (W = A[1] * x',), NoTangent() - end - return val, pb -end -function rrule(::typeof(LuxCore.apply), l::LinearLayer, x::AbstractMatrix, ps, st) +import ChainRulesCore: rrule, NoTangent + +function rrule(::typeof(evaluate), l::LinearLayer, x::AbstractVecOrMat, ps, st) val = l(x, ps, st) + function pb(A) return NoTangent(), NoTangent(), ps.W' * A[1], (W = A[1] * x',), NoTangent() end + return val, pb end -function rrule(::typeof(LuxCore.apply), l::LinearLayer{false}, x::AbstractMatrix, ps, st) +function rrule(::typeof(evaluate), l::LinearLayer{false}, x::AbstractMatrix, ps, st) val = l(x, ps, st) + function pb(A) - return NoTangent(), NoTangent(), A[1] * ps.W, (W = transpose(PtrArray(A[1])) * unwrap(x),), NoTangent() + return NoTangent(), NoTangent(), A[1] * ps.W, (W = transpose(PtrArray(A[1])) * x,), NoTangent() end + return val, pb end \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index f10ac63..be33f4a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -22,7 +22,7 @@ using Test @testset "Sparse Symmetric Product" begin include("ace/test_sparsesymmprod.jl"); end @testset "Sparse Symmetric Product - DAG" begin include("ace/test_sparsesymmproddag.jl"); end @testset "Sparse Product" begin include("test_sparseproduct.jl"); end - # @testset "Linear layer" begin include("test_linear.jl"); end + @testset "Linear layer" begin include("test_linear.jl"); end # Misc @testset "Static Prod" begin include("test_staticprod.jl"); end