From 0d738d30d5084331df813af7e2954d6353a0de8e Mon Sep 17 00:00:00 2001 From: Essam Date: Fri, 27 Dec 2024 19:44:45 +0200 Subject: [PATCH 01/12] =?UTF-8?q?=E2=AD=90=EF=B8=8F=20Entity=20embedder=20?= =?UTF-8?q?interface=20is=20here?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/MLJFlux.jl | 2 + src/entity_embedding.jl | 12 ++-- src/entity_embedding_utils.jl | 2 +- src/mlj_embedder_interface.jl | 128 ++++++++++++++++++++++++++++++++++ src/types.jl | 6 +- test/entity_embedding.jl | 20 ++++-- 6 files changed, 156 insertions(+), 14 deletions(-) create mode 100644 src/mlj_embedder_interface.jl diff --git a/src/MLJFlux.jl b/src/MLJFlux.jl index bbc0b66..592fbbe 100644 --- a/src/MLJFlux.jl +++ b/src/MLJFlux.jl @@ -30,9 +30,11 @@ include("image.jl") include("fit_utils.jl") include("entity_embedding_utils.jl") include("mlj_model_interface.jl") +include("mlj_embedder_interface.jl") export NeuralNetworkRegressor, MultitargetNeuralNetworkRegressor export NeuralNetworkClassifier, NeuralNetworkBinaryClassifier, ImageClassifier +export EntityEmbedder export CUDALibs, CPU1 include("deprecated.jl") diff --git a/src/entity_embedding.jl b/src/entity_embedding.jl index 313e3e6..9fac24c 100644 --- a/src/entity_embedding.jl +++ b/src/entity_embedding.jl @@ -25,7 +25,7 @@ entityprops = [ numfeats = 4 # Run it through the categorical embedding layer -embedder = EntityEmbedder(entityprops, 4) +embedder = EntityEmbedderLayer(entityprops, 4) julia> output = embedder(batch) 5×10 Matrix{Float64}: 0.2 0.3 0.4 0.5 … 0.8 0.9 1.0 1.1 @@ -35,18 +35,18 @@ julia> output = embedder(batch) -0.847354 -0.847354 -1.66261 -1.66261 -1.66261 -1.66261 -0.847354 -0.847354 ``` """ # 1. Define layer struct to hold parameters -struct EntityEmbedder{A1 <: AbstractVector, A2 <: AbstractVector, I <: Integer} +struct EntityEmbedderLayer{A1 <: AbstractVector, A2 <: AbstractVector, I <: Integer} embedders::A1 modifiers::A2 # applied on the input before passing it to the embedder numfeats::I end # 2. Define the forward pass (i.e., calling an instance of the layer) -(m::EntityEmbedder)(x) = +(m::EntityEmbedderLayer)(x) = (vcat([m.embedders[i](m.modifiers[i](x, i)) for i in 1:m.numfeats]...)) # 3. Define the constructor which initializes the parameters and returns the instance -function EntityEmbedder(entityprops, numfeats; init = Flux.randn32) +function EntityEmbedderLayer(entityprops, numfeats; init = Flux.randn32) embedders = [] modifiers = [] # Setup entityprops @@ -66,8 +66,8 @@ function EntityEmbedder(entityprops, numfeats; init = Flux.randn32) end end - EntityEmbedder(embedders, modifiers, numfeats) + EntityEmbedderLayer(embedders, modifiers, numfeats) end # 4. 
Register it as layer with Flux -Flux.@layer EntityEmbedder \ No newline at end of file +Flux.@layer EntityEmbedderLayer \ No newline at end of file diff --git a/src/entity_embedding_utils.jl b/src/entity_embedding_utils.jl index 21e77ed..aa1e54d 100644 --- a/src/entity_embedding_utils.jl +++ b/src/entity_embedding_utils.jl @@ -100,7 +100,7 @@ function construct_model_chain_with_entityembs( ) chain = try Flux.Chain( - EntityEmbedder(entityprops, shape[1]; init = Flux.glorot_uniform(rng)), + EntityEmbedderLayer(entityprops, shape[1]; init = Flux.glorot_uniform(rng)), build(model, rng, (entityemb_output_dim, shape[2])), ) |> move catch ex diff --git a/src/mlj_embedder_interface.jl b/src/mlj_embedder_interface.jl new file mode 100644 index 0000000..868dd9a --- /dev/null +++ b/src/mlj_embedder_interface.jl @@ -0,0 +1,128 @@ +### EntityEmbedder with MLJ Interface + +# 1. Interface Struct +mutable struct EntityEmbedder{M <: MLJFluxModel} <: Unsupervised + model::M +end; + + +# 2. Constructor +function EntityEmbedder(model;) + return EntityEmbedder(model) +end; + + +# 4. Fitted parameters (for user access) +MMI.fitted_params(::EntityEmbedder, fitresult) = fitresult + +# 5. Fit method +function MMI.fit(transformer::EntityEmbedder, verbosity::Int, X, y) + return MLJModelInterface.fit(transformer.model, verbosity, X, y) +end; + + +# 6. Transform method +function MMI.transform(transformer::EntityEmbedder, fitresult, Xnew) + Xnew_transf = MLJModelInterface.transform(transformer.model, fitresult, Xnew) + return Xnew_transf +end + +# 8. Extra metadata +MMI.metadata_pkg( + EntityEmbedder, + package_name = "MLJTransforms", + package_uuid = "23777cdb-d90c-4eb0-a694-7c2b83d5c1d6", + package_url = "https://github.com/JuliaAI/MLJTransforms.jl", + is_pure_julia = true, +) + +MMI.metadata_model( + EntityEmbedder, + input_scitype = Table, + output_scitype = Table, + load_path = "MLJTransforms.EntityEmbedder", +) + +MMI.target_in_fit(::Type{<:EntityEmbedder}) = true + + + + + +""" +$(MMI.doc_header(EntityEmbedder)) + +`EntityEmbedder` implements entity embeddings as in the "Entity Embeddings of Categorical Variables" paper by Cheng Guo, Felix Berkhahn. + +# Training data + +In MLJ (or MLJBase) bind an instance unsupervised `model` to data with + + mach = machine(model, X, y) + +Here: + + +- `X` is any table of input features (eg, a `DataFrame`). Features to be transformed must + have element scitype `Multiclass` or `OrderedFactor`. Use `schema(X)` to + check scitypes. + +- `y` is the target, which can be any `AbstractVector` whose element + scitype is `Continuous` or `Count` for regression problems and + `Multiclass` or `OrderedFactor` for classification problems; check the scitype with `schema(y)` + +Train the machine using `fit!(mach)`. + +# Hyper-parameters + +- `model`: The underlying deep learning model to be used for entity embedding. So far this supports `NeuralNetworkClassifier`, `NeuralNetworkRegressor`, and `MultitargetNeuralNetworkRegressor`. + +# Operations + +- `transform(mach, Xnew)`: Transform the categorical features of `Xnew` into dense `Continuous` vectors using the trained `MLJFlux.EntityEmbedderLayer` layer present in the network. + Check relevant documentation [here](https://fluxml.ai/MLJFlux.jl/dev/) and in particular, the `embedding_dims` hyperparameter. 
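+
+  As a rough illustration of what `transform` does under the hood (a sketch only — the
+  actual implementation forwards to the wrapped model's `transform`; here `mach` is a
+  machine bound directly to the wrapped MLJFlux model, and `Xmat` is a hypothetical
+  pre-processed `numfeats × nobs` input matrix):
+
+  ```julia
+  chain = fitted_params(mach).chain   # the fitted Flux chain
+  embedder = chain[1]                 # first layer: the EntityEmbedderLayer
+  Xemb = embedder(Xmat)               # categorical rows replaced by dense embeddings
+  ```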
+ + +# Examples + +```julia +using MLJFlux +using MLJ +using CategoricalArrays + +# Setup some data +N = 200 +X = (; + Column1 = repeat(Float32[1.0, 2.0, 3.0, 4.0, 5.0], Int(N / 5)), + Column2 = categorical(repeat(['a', 'b', 'c', 'd', 'e'], Int(N / 5))), + Column3 = categorical(repeat(["b", "c", "d", "f", "f"], Int(N / 5)), ordered = true), + Column4 = repeat(Float32[1.0, 2.0, 3.0, 4.0, 5.0], Int(N / 5)), + Column5 = randn(Float32, N), + Column6 = categorical( + repeat(["group1", "group1", "group2", "group2", "group3"], Int(N / 5)), + ), +) +y = categorical([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) # Classification + +# Initiate model +NeuralNetworkClassifier = @load NeuralNetworkClassifier pkg=MLJFlux + +clf = NeuralNetworkClassifier(embedding_dims=Dict(:Column2 => 2, :Column3 => 2)) + +emb = EntityEmbedder(clf) + +# Construct machine +mach = machine(emb, X, y) + +# Train model +fit!(mach) + +# Transform data using model to encode categorical columns +Xnew = transform(mach, X) +Xnew +``` + +See also +[`TargetEncoder`](@ref) +""" +EntityEmbedder \ No newline at end of file diff --git a/src/types.jl b/src/types.jl index bf64145..dc1adef 100644 --- a/src/types.jl +++ b/src/types.jl @@ -194,7 +194,7 @@ MMI.metadata_pkg.( const MODELSUPPORTDOC = """ In addition to features with `Continuous` scientific element type, this model supports categorical features in the input table. If present, such features are embedded into dense -vectors by the use of an additional `EntityEmbedder` layer after the input, as described in +vectors by the use of an additional `EntityEmbedderLayer` layer after the input, as described in Entity Embeddings of Categorical Variables by Cheng Guo, Felix Berkhahn arXiv, 2016. """ @@ -204,7 +204,7 @@ const XDOC = """ scitype (typically `Float32`); or (ii) a table of input features (eg, a `DataFrame`) whose columns have `Continuous`, `Multiclass` or `OrderedFactor` element scitype; check column scitypes with `schema(X)`. If any `Multiclass` or `OrderedFactor` features - appear, the constructed network will use an `EntityEmbedder` layer to transform + appear, the constructed network will use an `EntityEmbedderLayer` layer to transform them into dense vectors. If `X` is a `Matrix`, it is assumed that columns correspond to features and rows corresponding to observations. @@ -222,7 +222,7 @@ const EMBDOC = """ const TRANSFORMDOC = """ - `transform(mach, Xnew)`: Assuming `Xnew` has the same schema as `X`, transform the categorical features of `Xnew` into dense `Continuous` vectors using the - `MLJFlux.EntityEmbedder` layer present in the network. Does nothing in case the model + `MLJFlux.EntityEmbedderLayer` layer present in the network. Does nothing in case the model was trained on an input `X` that lacks categorical features. """ diff --git a/test/entity_embedding.jl b/test/entity_embedding.jl index da0c89b..e98fef9 100644 --- a/test/entity_embedding.jl +++ b/test/entity_embedding.jl @@ -22,7 +22,7 @@ entityprops = [ (index = 4, levels = 2, newdim = 2), ] - embedder = MLJFlux.EntityEmbedder(entityprops, 4) + embedder = MLJFlux.EntityEmbedderLayer(entityprops, 4) output = embedder(batch) @@ -68,7 +68,7 @@ end ] cat_model = Chain( - MLJFlux.EntityEmbedder(entityprops, 4), + MLJFlux.EntityEmbedderLayer(entityprops, 4), Dense(9 => (ind == 1) ? 
10 : 1), finalizer[ind], ) @@ -143,7 +143,7 @@ end @testset "Transparent when no categorical variables" begin entityprops = [] numfeats = 4 - embedder = MLJFlux.EntityEmbedder(entityprops, 4) + embedder = MLJFlux.EntityEmbedderLayer(entityprops, 4) output = embedder(batch) @test output ≈ batch @test eltype(output) == Float32 @@ -197,11 +197,23 @@ end acceleration = CUDALibs(), optimiser_changes_trigger_retraining = true, embedding_dims = embedding_dims[3], + rng=42 ) - + emb = MLJFlux.EntityEmbedder(clf) + mach_emb = machine(emb, X, ys[1]) mach = machine(clf, X, ys[1]) fit!(mach, verbosity = 0) + fit!(mach_emb, verbosity = 0) + Xnew = transform(mach, X) + Xnew_emb = transform(mach_emb, X) + @test Xnew == Xnew_emb + + # Pipeline doesn't throw an error + pipeline = emb |> clf + mach_pipe = machine(pipeline, X, y) + fit!(mach_pipe, verbosity = 0) + y = predict_mode(mach_pipe, X) mapping_matrices = MLJFlux.get_embedding_matrices( fitted_params(mach).chain, From 4ce351e021ca02e5facbc473017f19eef0a461e4 Mon Sep 17 00:00:00 2001 From: Essam Date: Sun, 29 Dec 2024 11:56:45 +0200 Subject: [PATCH 02/12] =?UTF-8?q?=F0=9F=91=A8=E2=80=8D=F0=9F=94=A7=20Zero?= =?UTF-8?q?=20dropout=20and=20deep=20copy?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test/entity_embedding.jl | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/test/entity_embedding.jl b/test/entity_embedding.jl index e98fef9..8954933 100644 --- a/test/entity_embedding.jl +++ b/test/entity_embedding.jl @@ -187,10 +187,13 @@ end 3 4 ]) + stable_rng=StableRNG(123) + for j in eachindex(embedding_dims) for i in eachindex(models) + # Without lightweight wrapper clf = models[1]( - builder = MLJFlux.Short(n_hidden = 5, dropout = 0.2), + builder = MLJFlux.Short(n_hidden = 5, dropout = 0.0), optimiser = Optimisers.Adam(0.01), batch_size = 8, epochs = 100, @@ -199,13 +202,14 @@ end embedding_dims = embedding_dims[3], rng=42 ) - emb = MLJFlux.EntityEmbedder(clf) - mach_emb = machine(emb, X, ys[1]) mach = machine(clf, X, ys[1]) - fit!(mach, verbosity = 0) - fit!(mach_emb, verbosity = 0) Xnew = transform(mach, X) + # With lightweight wrapper + clf2 = deepcopy(clf) + emb = MLJFlux.EntityEmbedder(clf2) + mach_emb = machine(emb, X, ys[1]) + fit!(mach_emb, verbosity = 0) Xnew_emb = transform(mach_emb, X) @test Xnew == Xnew_emb From 757c186e9e8edfd721e3e0abd3ac8d4b2900903f Mon Sep 17 00:00:00 2001 From: Essam Date: Tue, 31 Dec 2024 08:40:34 +0200 Subject: [PATCH 03/12] Update entity_embedding.jl --- test/entity_embedding.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/entity_embedding.jl b/test/entity_embedding.jl index 8954933..8bf35fb 100644 --- a/test/entity_embedding.jl +++ b/test/entity_embedding.jl @@ -193,7 +193,7 @@ end for i in eachindex(models) # Without lightweight wrapper clf = models[1]( - builder = MLJFlux.Short(n_hidden = 5, dropout = 0.0), + builder = MLJFlux.MLP(hidden=(10, 10)), optimiser = Optimisers.Adam(0.01), batch_size = 8, epochs = 100, From d95fc8e38d9124ae5c615a2fe7b1f9f25a27604e Mon Sep 17 00:00:00 2001 From: Essam Date: Wed, 15 Jan 2025 16:38:12 -0600 Subject: [PATCH 04/12] Update src/mlj_embedder_interface.jl Co-authored-by: Anthony Blaom, PhD --- src/mlj_embedder_interface.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/mlj_embedder_interface.jl b/src/mlj_embedder_interface.jl index 868dd9a..ecdbc97 100644 --- a/src/mlj_embedder_interface.jl +++ b/src/mlj_embedder_interface.jl @@ -75,7 +75,9 @@ Train 
the machine using `fit!(mach)`. # Hyper-parameters -- `model`: The underlying deep learning model to be used for entity embedding. So far this supports `NeuralNetworkClassifier`, `NeuralNetworkRegressor`, and `MultitargetNeuralNetworkRegressor`. +- `model`: The supervised MLJFlux neural network model to be used for entity embedding. + This must be one of these: `MLJFlux.NeuralNetworkClassifier`, + `MLJFlux.NeuralNetworkRegressor`,`MLJFlux.MultitargetNeuralNetworkRegressor`. # Operations From 9de68f36aeabe4ad6534d710fe8f13570216c063 Mon Sep 17 00:00:00 2001 From: Essam Date: Wed, 15 Jan 2025 16:38:22 -0600 Subject: [PATCH 05/12] Update src/mlj_embedder_interface.jl Co-authored-by: Anthony Blaom, PhD --- src/mlj_embedder_interface.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/src/mlj_embedder_interface.jl b/src/mlj_embedder_interface.jl index ecdbc97..e188775 100644 --- a/src/mlj_embedder_interface.jl +++ b/src/mlj_embedder_interface.jl @@ -107,6 +107,7 @@ X = (; y = categorical([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) # Classification # Initiate model +EntityEmbedder = @load EntityEmbedder pkg=MLJFlux NeuralNetworkClassifier = @load NeuralNetworkClassifier pkg=MLJFlux clf = NeuralNetworkClassifier(embedding_dims=Dict(:Column2 => 2, :Column3 => 2)) From 4e11f18ca2b597a637414811f4aa6bdc70674e4f Mon Sep 17 00:00:00 2001 From: Essam Date: Wed, 15 Jan 2025 16:38:36 -0600 Subject: [PATCH 06/12] Update src/mlj_embedder_interface.jl Co-authored-by: Anthony Blaom, PhD --- src/mlj_embedder_interface.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/mlj_embedder_interface.jl b/src/mlj_embedder_interface.jl index e188775..b874f99 100644 --- a/src/mlj_embedder_interface.jl +++ b/src/mlj_embedder_interface.jl @@ -88,7 +88,6 @@ Train the machine using `fit!(mach)`. # Examples ```julia -using MLJFlux using MLJ using CategoricalArrays From baf10d9b8988952e95e9f41ddddb40fa5d8dbdb2 Mon Sep 17 00:00:00 2001 From: Essam Date: Thu, 16 Jan 2025 20:36:53 -0600 Subject: [PATCH 07/12] =?UTF-8?q?=F0=9F=8C=9F=20Keyword=20argument=20model?= =?UTF-8?q?=20for=20consistency?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/mlj_embedder_interface.jl | 4 +++- test/entity_embedding.jl | 5 ++++- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/mlj_embedder_interface.jl b/src/mlj_embedder_interface.jl index 868dd9a..a39ffcb 100644 --- a/src/mlj_embedder_interface.jl +++ b/src/mlj_embedder_interface.jl @@ -6,8 +6,10 @@ mutable struct EntityEmbedder{M <: MLJFluxModel} <: Unsupervised end; +const ERR_MODEL_UNSPECIFIED = ErrorException("You must specify a suitable MLJFlux supervised model, as in `EntityEmbedder(model=...)`. ") # 2. 
Constructor -function EntityEmbedder(model;) +function EntityEmbedder(;model=nothing) + model === nothing && throw(ERR_MODEL_UNSPECIFIED) return EntityEmbedder(model) end; diff --git a/test/entity_embedding.jl b/test/entity_embedding.jl index 8bf35fb..63af4bb 100644 --- a/test/entity_embedding.jl +++ b/test/entity_embedding.jl @@ -150,7 +150,7 @@ end end -@testset "get_embedding_matrices works and has the right dimensions" begin +@testset "get_embedding_matrices works as well as the light wrapper" begin models = [ MLJFlux.NeuralNetworkBinaryClassifier, MLJFlux.NeuralNetworkClassifier, @@ -208,6 +208,9 @@ end # With lightweight wrapper clf2 = deepcopy(clf) emb = MLJFlux.EntityEmbedder(clf2) + @test_throws MLJFlux.ERR_MODEL_UNSPECIFIED begin + MLJFlux.EntityEmbedder() + end mach_emb = machine(emb, X, ys[1]) fit!(mach_emb, verbosity = 0) Xnew_emb = transform(mach_emb, X) From 8ac77e4ed2cce4869dcdf98a8ebd57634cd47f06 Mon Sep 17 00:00:00 2001 From: Essam Date: Thu, 16 Jan 2025 20:42:07 -0600 Subject: [PATCH 08/12] =?UTF-8?q?=F0=9F=93=96=20Update=20docs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/mlj_embedder_interface.jl | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/mlj_embedder_interface.jl b/src/mlj_embedder_interface.jl index 5f8f3d4..a180098 100644 --- a/src/mlj_embedder_interface.jl +++ b/src/mlj_embedder_interface.jl @@ -65,13 +65,11 @@ In MLJ (or MLJBase) bind an instance unsupervised `model` to data with Here: -- `X` is any table of input features (eg, a `DataFrame`). Features to be transformed must +- `X` is any table of input features supported by the model being wrapped. Features to be transformed must have element scitype `Multiclass` or `OrderedFactor`. Use `schema(X)` to check scitypes. -- `y` is the target, which can be any `AbstractVector` whose element - scitype is `Continuous` or `Count` for regression problems and - `Multiclass` or `OrderedFactor` for classification problems; check the scitype with `schema(y)` +- `y` is the target, which can be any `AbstractVector` supported by the model being wrapped. Train the machine using `fit!(mach)`. @@ -127,6 +125,6 @@ Xnew ``` See also -[`TargetEncoder`](@ref) +[`NeuralNetworkClassifier`, `NeuralNetworkRegressor`](@ref) """ EntityEmbedder \ No newline at end of file From 0753e94f35ec73a72cb634e8c85452fdca05a6b5 Mon Sep 17 00:00:00 2001 From: Essam Date: Thu, 16 Jan 2025 20:49:15 -0600 Subject: [PATCH 09/12] =?UTF-8?q?=E2=9C=8F=EF=B8=8F=20Better=20docs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/mlj_embedder_interface.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/mlj_embedder_interface.jl b/src/mlj_embedder_interface.jl index a180098..8ed58f6 100644 --- a/src/mlj_embedder_interface.jl +++ b/src/mlj_embedder_interface.jl @@ -52,7 +52,7 @@ MMI.target_in_fit(::Type{<:EntityEmbedder}) = true """ -$(MMI.doc_header(EntityEmbedder)) + EntityEmbedder(; model=mljflux_neural_model) `EntityEmbedder` implements entity embeddings as in the "Entity Embeddings of Categorical Variables" paper by Cheng Guo, Felix Berkhahn. @@ -77,7 +77,8 @@ Train the machine using `fit!(mach)`. - `model`: The supervised MLJFlux neural network model to be used for entity embedding. This must be one of these: `MLJFlux.NeuralNetworkClassifier`, - `MLJFlux.NeuralNetworkRegressor`,`MLJFlux.MultitargetNeuralNetworkRegressor`. 
+ `MLJFlux.NeuralNetworkRegressor`,`MLJFlux.MultitargetNeuralNetworkRegressor`. The selected model may have hyperparameters + that may affect embedding performance, the most notable of which could be the `builder` argument. # Operations From 34a5af3c8e9247f7871e22b76686c8b6d326f028 Mon Sep 17 00:00:00 2001 From: Essam Date: Thu, 16 Jan 2025 21:09:20 -0600 Subject: [PATCH 10/12] =?UTF-8?q?=F0=9F=9A=80=20Update=20the=20interface?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/mlj_embedder_interface.jl | 29 +++++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/src/mlj_embedder_interface.jl b/src/mlj_embedder_interface.jl index 8ed58f6..df230e5 100644 --- a/src/mlj_embedder_interface.jl +++ b/src/mlj_embedder_interface.jl @@ -40,13 +40,38 @@ MMI.metadata_pkg( MMI.metadata_model( EntityEmbedder, - input_scitype = Table, - output_scitype = Table, load_path = "MLJTransforms.EntityEmbedder", ) MMI.target_in_fit(::Type{<:EntityEmbedder}) = true +# 9. Forwarding traits +MMI.is_wrapper(::Type{<:EntityEmbedder}) =true +MMI.supports_training_losses(::Type{<:EntityEmbedder}) = true + + +for trait in [ + :input_scitype, + :output_scitype, + :target_scitype, + ] + + quote + MMI.$trait(::Type{<:EntityEmbedder{M}}) where M = MMI.$trait(M) + end |> eval +end + +# ## Iteration parameter +prepend(s::Symbol, ::Nothing) = nothing +prepend(s::Symbol, t::Symbol) = Expr(:(.), s, QuoteNode(t)) +prepend(s::Symbol, ex::Expr) = Expr(:(.), prepend(s, ex.args[1]), ex.args[2]) +quote + MMI.iteration_parameter(::Type{<:EntityEmbedder{M}}) where M = + prepend(:model, MMI.iteration_parameter(M)) +end |> eval + +MMI.training_losses(embedder::EntityEmbedder, report) = + MMI.training_losses(embedder.model, report) From b879986378588f2290661a0ca39c7bb9b71087a5 Mon Sep 17 00:00:00 2001 From: Essam Date: Sun, 19 Jan 2025 20:08:01 -0600 Subject: [PATCH 11/12] Update src/mlj_embedder_interface.jl --- src/mlj_embedder_interface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mlj_embedder_interface.jl b/src/mlj_embedder_interface.jl index df230e5..8c67934 100644 --- a/src/mlj_embedder_interface.jl +++ b/src/mlj_embedder_interface.jl @@ -101,7 +101,7 @@ Train the machine using `fit!(mach)`. # Hyper-parameters - `model`: The supervised MLJFlux neural network model to be used for entity embedding. - This must be one of these: `MLJFlux.NeuralNetworkClassifier`, + This must be one of these: `MLJFlux.NeuralNetworkClassifier`, `NeuralNetworkBinaryClassifier` `MLJFlux.NeuralNetworkRegressor`,`MLJFlux.MultitargetNeuralNetworkRegressor`. The selected model may have hyperparameters that may affect embedding performance, the most notable of which could be the `builder` argument. From 981155428c7c4d1c7c8d73e42556db9766bae847 Mon Sep 17 00:00:00 2001 From: Essam Date: Sun, 19 Jan 2025 20:08:45 -0600 Subject: [PATCH 12/12] Update src/mlj_embedder_interface.jl --- src/mlj_embedder_interface.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/mlj_embedder_interface.jl b/src/mlj_embedder_interface.jl index 8c67934..b37a90d 100644 --- a/src/mlj_embedder_interface.jl +++ b/src/mlj_embedder_interface.jl @@ -101,7 +101,7 @@ Train the machine using `fit!(mach)`. # Hyper-parameters - `model`: The supervised MLJFlux neural network model to be used for entity embedding. 
- This must be one of these: `MLJFlux.NeuralNetworkClassifier`, `NeuralNetworkBinaryClassifier`
- `MLJFlux.NeuralNetworkRegressor`,`MLJFlux.MultitargetNeuralNetworkRegressor`. The selected model may have hyperparameters
- that may affect embedding performance, the most notable of which could be the `builder` argument.
+ This must be one of these: `MLJFlux.NeuralNetworkClassifier`, `MLJFlux.NeuralNetworkBinaryClassifier`,
+ `MLJFlux.NeuralNetworkRegressor`, or `MLJFlux.MultitargetNeuralNetworkRegressor`. The selected model has
+ hyperparameters that may affect embedding performance, the most notable of which could be the `builder` argument.
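
For reference, a minimal end-to-end sketch of the wrapper as it stands at the end of this series, combining the keyword constructor from patch 07 with the pipeline composition exercised in the tests. It assumes `EntityEmbedder` is registered for `@load` under `pkg=MLJFlux` (as the docstring suggests), and that `X` and `y` are the table and target from the docstring example:

```julia
using MLJ

# Load the wrapper and a supervised MLJFlux model to wrap.
EntityEmbedder = @load EntityEmbedder pkg=MLJFlux
NeuralNetworkClassifier = @load NeuralNetworkClassifier pkg=MLJFlux

clf = NeuralNetworkClassifier(embedding_dims = Dict(:Column2 => 2, :Column3 => 2))

# Keyword constructor from patch 07; calling `EntityEmbedder()` with no model
# throws ERR_MODEL_UNSPECIFIED.
emb = EntityEmbedder(model = clf)

# Standalone use: categorical columns of X become dense Continuous columns.
mach = machine(emb, X, y)
fit!(mach, verbosity = 0)
Xnew = transform(mach, X)

# Pipeline use, as in the tests: embed, then classify.
pipe = emb |> clf
mach_pipe = machine(pipe, X, y)
fit!(mach_pipe, verbosity = 0)
yhat = predict_mode(mach_pipe, X)
```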