Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
CheukHinHoJerry committed Jun 15, 2024
1 parent 3c65a77 commit a276da5
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/Polynomials4ML.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ include("sparseproduct.jl")
# LinearLayer implementation
# this is needed to better play with cached arrays + to give the correct
# behaviour when the feature dimension is different from expected.
# include("linear.jl")
include("linear.jl")

# generic machinery for wrapping poly4ml bases into lux layers
include("lux.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/linear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ out, st = l(x, ps, st)
println(out == x * transpose(W))) # true
```
"""
struct LinearLayer{FEATFIRST} <: AbstractP4MLLayer
struct LinearLayer{FEATFIRST} <: AbstractP4MLTensor
in_dim::Integer
out_dim::Integer
@reqfields()
Expand Down
17 changes: 6 additions & 11 deletions test/test_linear.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ for (feat, in_size, out_fun) in zip(feature_arr, in_size_arr, out_fun_arr)
if !feat
@info("Testing evaluate on vector input vs batch input")
for ntest = 1:30
X = randn(N, in_d)
X = randn(N, in_d)
Y1, _ = l(X, ps, st)
Y2 = hcat([l(X[i,:], ps, st)[1] for i = 1:N]...)'
Y3 = hcat([ps.W * X[i,:] for i = 1:N]...)'
Expand All @@ -52,7 +52,7 @@ for (feat, in_size, out_fun) in zip(feature_arr, in_size_arr, out_fun_arr)
u = randn(size(val))
F(t) = dot(u, l(_BB(t), ps, st)[1])
dF(t) = begin
val, pb = Zygote.pullback(LuxCore.apply, l, _BB(t), ps, st)
val, pb = Zygote.pullback(P4ML.evaluate, l, _BB(t), ps, st)
∂BB = pb((u, st))[2]
return dot(∂BB, bu)
end
Expand All @@ -79,7 +79,7 @@ for (feat, in_size, out_fun) in zip(feature_arr, in_size_arr, out_fun_arr)
u = randn(size(val))
F(t) = dot(u, l(_BB(t), ps, st)[1])
dF(t) = begin
val, pb = Zygote.pullback(LuxCore.apply, l, _BB(t), ps, st)
val, pb = Zygote.pullback(P4ML.evaluate, l, _BB(t), ps, st)
∂BB = pb((u, st))[2]
return dot(∂BB, bu)
end
Expand All @@ -99,7 +99,7 @@ for (feat, in_size, out_fun) in zip(feature_arr, in_size_arr, out_fun_arr)
u = randn(size(val))
F(t) = dot(u, l(x, re([_BB(t)...]), st)[1])
dF(t) = begin
val, pb = Zygote.pullback(LuxCore.apply, l, x, re([_BB(t)...]), st)
val, pb = Zygote.pullback(P4ML.evaluate, l, x, re([_BB(t)...]), st)
∂BB = pb((u, st))[3]
return dot(∂BB[1], bu)
end
Expand All @@ -113,21 +113,16 @@ for (feat, in_size, out_fun) in zip(feature_arr, in_size_arr, out_fun_arr)
# test_rrule(LuxCore.apply, l, x, ps, st)
end

##

#

# check which matmul it is calling
# l = P4ML.LinearLayer(in_d, out_d; feature_first = false)
# ps, st = LuxCore.setup(MersenneTwister(1234), l)
# X = rand(N, in_d)
# using ObjectPools
# release!(X)
# X = rand(N,in_d)

# @profview let l = l, ps = ps, st = st, X = X
# for _ = 1:100_000
# out = l(X, ps, st)[1]
# release!(out)
# l(X, ps, st)
# end
# end

0 comments on commit a276da5

Please sign in to comment.