Skip to content

Commit

Permalink
πŸ‘¨β€πŸ”§ Zero dropout and deep copy
Browse files Browse the repository at this point in the history
  • Loading branch information
EssamWisam committed Dec 29, 2024
1 parent 0d738d3 commit 4ce351e
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions test/entity_embedding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,10 +187,13 @@ end
3 4
])

stable_rng=StableRNG(123)

for j in eachindex(embedding_dims)
for i in eachindex(models)
# Without lightweight wrapper
clf = models[1](
builder = MLJFlux.Short(n_hidden = 5, dropout = 0.2),
builder = MLJFlux.Short(n_hidden = 5, dropout = 0.0),
optimiser = Optimisers.Adam(0.01),
batch_size = 8,
epochs = 100,
Expand All @@ -199,13 +202,14 @@ end
embedding_dims = embedding_dims[3],
rng=42
)
emb = MLJFlux.EntityEmbedder(clf)
mach_emb = machine(emb, X, ys[1])
mach = machine(clf, X, ys[1])

fit!(mach, verbosity = 0)
fit!(mach_emb, verbosity = 0)
Xnew = transform(mach, X)
# With lightweight wrapper
clf2 = deepcopy(clf)
emb = MLJFlux.EntityEmbedder(clf2)
mach_emb = machine(emb, X, ys[1])
fit!(mach_emb, verbosity = 0)
Xnew_emb = transform(mach_emb, X)
@test Xnew == Xnew_emb

Expand Down

0 comments on commit 4ce351e

Please sign in to comment.