Skip to content

Commit

Permalink
reviving linear layer and tests - to be cleaned up soon
Browse files Browse the repository at this point in the history
  • Loading branch information
CheukHinHoJerry committed Jun 15, 2024
1 parent 7d9f677 commit 3c65a77
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 35 deletions.
67 changes: 33 additions & 34 deletions src/linear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,65 +33,64 @@ out, st = l(x, ps, st)
println(out == x * transpose(W)) # true
```
"""
struct LinearLayer{FEATFIRST} <: AbstractExplicitLayer
struct LinearLayer{FEATFIRST} <: AbstractP4MLLayer
in_dim::Integer
out_dim::Integer
use_cache::Bool
@reqfields()
end

LinearLayer(in_dim::Int, out_dim::Int; feature_first = false, use_cache = true) = LinearLayer{feature_first}(in_dim, out_dim, use_cache, _make_reqfields()...)
LinearLayer(in_dim::Int, out_dim::Int; feature_first = false) = LinearLayer{feature_first}(in_dim, out_dim, _make_reqfields()...)

_valtype(l::LinearLayer, x::AbstractArray, ps) =
promote_type(eltype(x), eltype(ps.W))
# ----------------------- evaluation and allocation interfaces

function (l::LinearLayer)(x::AbstractVector, ps, st)
out = acquire!(st.pool, :A, (l.out_dim, ), _valtype(l, x, ps))
mul!(unwrap(out), ps.W, unwrap(x)); release!(x);
return out, st
# Element type of the layer output: the promotion of the element type of the
# input `x` and of the weight matrix `ps.W`.
function _valtype(l::LinearLayer, x::AbstractArray, ps, st)
    return promote_type(eltype(x), eltype(ps.W))
end

# Element type used for gradients; computed by the same promotion as `_valtype`.
function _gradtype(l::LinearLayer, x, ps, st)
    return promote_type(eltype(x), eltype(ps.W))
end

# Size of the output array for each supported input layout.

# Vector input: the output is a vector of length `out_dim`.
function _out_size(l::LinearLayer, x::AbstractVector, ps, st)
    return (l.out_dim,)
end

# `LinearLayer{true}` with matrix input: the second dimension of `x` is kept.
function _out_size(l::LinearLayer{true}, x::AbstractMatrix, ps, st)
    ncols = size(x, 2)
    return (l.out_dim, ncols)
end

# `LinearLayer{false}` with matrix input: the first dimension of `x` is kept.
function _out_size(l::LinearLayer{false}, x::AbstractMatrix, ps, st)
    nrows = size(x, 1)
    return (nrows, l.out_dim)
end

# Initial parameters: a single `out_dim × in_dim` weight matrix `W` whose
# entries are drawn with `randn` from the supplied `rng`.
function LuxCore.initialparameters(rng::AbstractRNG, l::LinearLayer)
    W = randn(rng, l.out_dim, l.in_dim)
    return (W = W,)
end

# Initial state: an empty `NamedTuple`.
function LuxCore.initialstates(rng::AbstractRNG, l::LinearLayer)
    return NamedTuple()
end

# Allocation descriptor for `evaluate!`: a tuple holding the output element
# type (from `_valtype`) followed by the output dimensions (from `_out_size`).
function whatalloc(::typeof(evaluate!), l::LinearLayer, x::AbstractArray, ps, st)
    return (_valtype(l, x, ps, st), _out_size(l, x, ps, st)...)
end

function (l::LinearLayer{true})(x::AbstractMatrix, ps, st)
out = acquire!(st.pool, :bA, (l.out_dim, size(x, 2)), _valtype(l, x, ps));
mul!(unwrap(out), ps.W, unwrap(x)); release!(x);
# In-place evaluation for vector or matrix input: stores `ps.W * x` in `out`
# and returns `out` together with the unchanged state `st`.
function evaluate!(out, l::LinearLayer, x::AbstractVecOrMat, ps, st)
    W = ps.W
    mul!(out, W, x)
    return (out, st)
end

(l::LinearLayer{false})(x::AbstractMatrix, ps, st) = begin
out = acquire!(st.pool, :bA, (size(x, 1), l.out_dim), _valtype(l, x, ps));
mul!(unwrap(out), unwrap(x), transpose(PtrArray(ps.W))); release!(x);
# In-place evaluation for `LinearLayer{false}` with matrix input: stores
# `x * transpose(W)` in `out`, with `ps.W` wrapped in a `PtrArray` before it is
# transposed. Returns `out` together with the unchanged state `st`.
function evaluate!(out, l::LinearLayer{false}, x::AbstractMatrix, ps, st)
    Wt = transpose(PtrArray(ps.W))
    mul!(out, x, Wt)
    return (out, st)
end

# Jerry: Maybe we should use Glorot Uniform if we have no idea about what we should use?
LuxCore.initialparameters(rng::AbstractRNG, l::LinearLayer) =
( W = randn(rng, l.out_dim, l.in_dim), )

LuxCore.initialstates(rng::AbstractRNG, l::LinearLayer) =
( l.use_cache ? (pool = ArrayPool(FlexArrayCache), )
: (pool = ArrayPool(FlexArray), ))


# --------------------- connect with ChainRules
# can this be generalized again?
# TODO: check whether we can do this without multiple dispatch on vec/mat without loss of performance
function rrule(::typeof(LuxCore.apply), l::LinearLayer, x::AbstractVector, ps, st)
val = l(x, ps, st)
function pb(A)
return NoTangent(), NoTangent(), ps.W' * A[1], (W = A[1] * x',), NoTangent()
end
return val, pb
end

function rrule(::typeof(LuxCore.apply), l::LinearLayer, x::AbstractMatrix, ps, st)
import ChainRulesCore: rrule, NoTangent

# Reverse rule for `evaluate` on a `LinearLayer` with vector or matrix input.
# Returns the forward result of `l(x, ps, st)` and a pullback.
function rrule(::typeof(evaluate), l::LinearLayer, x::AbstractVecOrMat, ps, st)
    val = l(x, ps, st)

    # The pullback reads `A[1]` and returns one entry per primal argument:
    # no tangent for the function, the layer and the state; `W' * A[1]` for
    # `x`; and `(W = A[1] * x',)` for the parameters.
    function pb(A)
        dout = A[1]
        dx = ps.W' * dout
        dps = (W = dout * x',)
        return NoTangent(), NoTangent(), dx, dps, NoTangent()
    end

    return val, pb
end

function rrule(::typeof(LuxCore.apply), l::LinearLayer{false}, x::AbstractMatrix, ps, st)
# Reverse rule for `evaluate` on a `LinearLayer{false}` with matrix input.
# Returns the forward result of `l(x, ps, st)` and a pullback `pb`.
function rrule(::typeof(evaluate), l::LinearLayer{false}, x::AbstractMatrix, ps, st)
val = l(x, ps, st)

# The pullback reads `A[1]` (presumably the cotangent of the first element of
# `val` -- TODO confirm) and returns five entries: no tangent for the function,
# the layer and the state; `A[1] * ps.W` for `x`; and a `(W = ...,)` NamedTuple
# for the parameters.
function pb(A)
# NOTE(review): there are two consecutive `return` statements here, and the
# second is unreachable as written. They differ only in `unwrap(x)` versus `x`
# and look like the old and new lines of a diff -- confirm which is intended
# and delete the other.
return NoTangent(), NoTangent(), A[1] * ps.W, (W = transpose(PtrArray(A[1])) * unwrap(x),), NoTangent()
return NoTangent(), NoTangent(), A[1] * ps.W, (W = transpose(PtrArray(A[1])) * x,), NoTangent()
end

return val, pb
end
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ using Test
@testset "Sparse Symmetric Product" begin include("ace/test_sparsesymmprod.jl"); end
@testset "Sparse Symmetric Product - DAG" begin include("ace/test_sparsesymmproddag.jl"); end
@testset "Sparse Product" begin include("test_sparseproduct.jl"); end
# @testset "Linear layer" begin include("test_linear.jl"); end
@testset "Linear layer" begin include("test_linear.jl"); end

# Misc
@testset "Static Prod" begin include("test_staticprod.jl"); end
Expand Down

0 comments on commit 3c65a77

Please sign in to comment.