From 4ce351e021ca02e5facbc473017f19eef0a461e4 Mon Sep 17 00:00:00 2001 From: Essam Date: Sun, 29 Dec 2024 11:56:45 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=91=A8=E2=80=8D=F0=9F=94=A7=20Zero=20drop?= =?UTF-8?q?out=20and=20deep=20copy?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/entity_embedding.jl | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/test/entity_embedding.jl b/test/entity_embedding.jl index e98fef9..8954933 100644 --- a/test/entity_embedding.jl +++ b/test/entity_embedding.jl @@ -187,10 +187,13 @@ end 3 4 ]) + stable_rng=StableRNG(123) + for j in eachindex(embedding_dims) for i in eachindex(models) + # Without lightweight wrapper clf = models[1]( - builder = MLJFlux.Short(n_hidden = 5, dropout = 0.2), + builder = MLJFlux.Short(n_hidden = 5, dropout = 0.0), optimiser = Optimisers.Adam(0.01), batch_size = 8, epochs = 100, @@ -199,13 +202,14 @@ end embedding_dims = embedding_dims[3], rng=42 ) - emb = MLJFlux.EntityEmbedder(clf) - mach_emb = machine(emb, X, ys[1]) mach = machine(clf, X, ys[1]) - fit!(mach, verbosity = 0) - fit!(mach_emb, verbosity = 0) Xnew = transform(mach, X) + # With lightweight wrapper + clf2 = deepcopy(clf) + emb = MLJFlux.EntityEmbedder(clf2) + mach_emb = machine(emb, X, ys[1]) + fit!(mach_emb, verbosity = 0) Xnew_emb = transform(mach_emb, X) @test Xnew == Xnew_emb