Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

⭐️ Entity embedder interface is here #286

Merged
merged 13 commits into from
Jan 20, 2025
2 changes: 2 additions & 0 deletions src/MLJFlux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,11 @@ include("image.jl")
include("fit_utils.jl")
include("entity_embedding_utils.jl")
include("mlj_model_interface.jl")
include("mlj_embedder_interface.jl")

export NeuralNetworkRegressor, MultitargetNeuralNetworkRegressor
export NeuralNetworkClassifier, NeuralNetworkBinaryClassifier, ImageClassifier
export EntityEmbedder
export CUDALibs, CPU1

include("deprecated.jl")
Expand Down
12 changes: 6 additions & 6 deletions src/entity_embedding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ entityprops = [
numfeats = 4

# Run it through the categorical embedding layer
embedder = EntityEmbedder(entityprops, 4)
embedder = EntityEmbedderLayer(entityprops, 4)
julia> output = embedder(batch)
5×10 Matrix{Float64}:
0.2 0.3 0.4 0.5 … 0.8 0.9 1.0 1.1
Expand All @@ -35,18 +35,18 @@ julia> output = embedder(batch)
-0.847354 -0.847354 -1.66261 -1.66261 -1.66261 -1.66261 -0.847354 -0.847354
```
""" # 1. Define layer struct to hold parameters
struct EntityEmbedder{A1 <: AbstractVector, A2 <: AbstractVector, I <: Integer}
struct EntityEmbedderLayer{A1 <: AbstractVector, A2 <: AbstractVector, I <: Integer}
embedders::A1
modifiers::A2 # applied on the input before passing it to the embedder
numfeats::I
end

# 2. Define the forward pass (i.e., calling an instance of the layer)
(m::EntityEmbedder)(x) =
(m::EntityEmbedderLayer)(x) =
(vcat([m.embedders[i](m.modifiers[i](x, i)) for i in 1:m.numfeats]...))

# 3. Define the constructor which initializes the parameters and returns the instance
function EntityEmbedder(entityprops, numfeats; init = Flux.randn32)
function EntityEmbedderLayer(entityprops, numfeats; init = Flux.randn32)
embedders = []
modifiers = []
# Setup entityprops
Expand All @@ -66,8 +66,8 @@ function EntityEmbedder(entityprops, numfeats; init = Flux.randn32)
end
end

EntityEmbedder(embedders, modifiers, numfeats)
EntityEmbedderLayer(embedders, modifiers, numfeats)
end

# 4. Register it as layer with Flux
Flux.@layer EntityEmbedder
Flux.@layer EntityEmbedderLayer
2 changes: 1 addition & 1 deletion src/entity_embedding_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ function construct_model_chain_with_entityembs(
)
chain = try
Flux.Chain(
EntityEmbedder(entityprops, shape[1]; init = Flux.glorot_uniform(rng)),
EntityEmbedderLayer(entityprops, shape[1]; init = Flux.glorot_uniform(rng)),
build(model, rng, (entityemb_output_dim, shape[2])),
) |> move
catch ex
Expand Down
156 changes: 156 additions & 0 deletions src/mlj_embedder_interface.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,156 @@
### EntityEmbedder with MLJ Interface

# 1. Interface Struct
# Thin MLJ-facing wrapper around a supervised MLJFlux model. The `fit`, `transform`
# and trait methods defined for this type delegate to the wrapped `model`.
mutable struct EntityEmbedder{M<:MLJFluxModel} <: Unsupervised
    model::M   # the wrapped MLJFlux model
end


# Thrown by the keyword constructor when no model is supplied.
const ERR_MODEL_UNSPECIFIED = ErrorException("You must specify a suitable MLJFlux supervised model, as in `EntityEmbedder(model=...)`. ")

# 2. Constructor
# Keyword constructor. `model` has no usable default: leaving it as `nothing` throws
# `ERR_MODEL_UNSPECIFIED`; otherwise the model is wrapped through the default
# positional constructor.
function EntityEmbedder(; model = nothing)
    model === nothing && throw(ERR_MODEL_UNSPECIFIED)
    return EntityEmbedder(model)
end


# 4. Fitted parameters (for user access)
# The fitresult is handed back to the user unchanged.
function MMI.fitted_params(::EntityEmbedder, fitresult)
    return fitresult
end

# 5. Fit method
# Training is delegated entirely to the wrapped model: whatever its `fit` returns is
# returned here as-is.
MMI.fit(transformer::EntityEmbedder, verbosity::Int, X, y) =
    MLJModelInterface.fit(transformer.model, verbosity, X, y)


# 6. Transform method
# Transforming new data is delegated to the wrapped model's own `transform`, using the
# fitresult produced by `fit`.
MMI.transform(transformer::EntityEmbedder, fitresult, Xnew) =
    MLJModelInterface.transform(transformer.model, fitresult, Xnew)

# 8. Extra metadata
# Package metadata. `EntityEmbedder` is defined in and exported from MLJFlux (and is
# loaded with `@load EntityEmbedder pkg=MLJFlux`), so the package fields and the load
# path must name MLJFlux rather than MLJTransforms.
# NOTE(review): confirm `package_uuid` against the `uuid` entry in MLJFlux's Project.toml.
MMI.metadata_pkg(
    EntityEmbedder,
    package_name = "MLJFlux",
    package_uuid = "094fc8d1-fd35-5302-93ea-dabda2abf845",
    package_url = "https://github.com/FluxML/MLJFlux.jl",
    is_pure_julia = true,
)

# `load_path` is the fully qualified name under which the type can be found.
MMI.metadata_model(
    EntityEmbedder,
    load_path = "MLJFlux.EntityEmbedder",
)

# `fit` for this model takes the target `y` as a training argument, although the model
# subtypes `Unsupervised`.
function MMI.target_in_fit(::Type{<:EntityEmbedder})
    return true
end

# 9. Forwarding traits
# These two hold for every `EntityEmbedder`, whatever model it wraps.
function MMI.is_wrapper(::Type{<:EntityEmbedder})
    return true
end

function MMI.supports_training_losses(::Type{<:EntityEmbedder})
    return true
end


# The scitype traits are not fixed for the wrapper: each one is defined to return the
# value of the same trait evaluated on the wrapped model type `M`.
for trait in [:input_scitype, :output_scitype, :target_scitype]
    @eval MMI.$trait(::Type{<:EntityEmbedder{M}}) where {M} = MMI.$trait(M)
end

# ## Iteration parameter
# `prepend(s, x)` re-roots a property path under the symbol `s`:
# - `nothing` is passed through as `nothing`;
# - a symbol `t` becomes the expression `s.t`;
# - for an expression, the first argument is re-rooted recursively and the second
#   argument is kept as the trailing field, so `:(a.b)` becomes `:(s.a.b)`.
function prepend(s::Symbol, ::Nothing)
    return nothing
end

function prepend(s::Symbol, t::Symbol)
    return Expr(:(.), s, QuoteNode(t))
end

function prepend(s::Symbol, ex::Expr)
    rerooted = prepend(s, ex.args[1])
    return Expr(:(.), rerooted, ex.args[2])
end
# The wrapper's iteration parameter is the wrapped model's iteration parameter nested
# under the `model` field (or `nothing` when the wrapped model declares none).
function MMI.iteration_parameter(::Type{<:EntityEmbedder{M}}) where {M}
    return prepend(:model, MMI.iteration_parameter(M))
end

# Training losses are read from the report by the wrapped model's own method.
function MMI.training_losses(embedder::EntityEmbedder, report)
    return MMI.training_losses(embedder.model, report)
end




"""
EntityEmbedder(; model=mljflux_neural_model)

`EntityEmbedder` implements entity embeddings as described in the paper "Entity Embeddings of Categorical Variables" by Cheng Guo and Felix Berkhahn (arXiv, 2016).

# Training data

In MLJ (or MLJBase), bind an instance `model` of this unsupervised model to data with

mach = machine(model, X, y)

Here:


- `X` is any table of input features supported by the model being wrapped. Features to be transformed must
have element scitype `Multiclass` or `OrderedFactor`. Use `schema(X)` to
check scitypes.

- `y` is the target, which can be any `AbstractVector` supported by the model being wrapped.

Train the machine using `fit!(mach)`.

# Hyper-parameters

- `model`: The supervised MLJFlux neural network model to be used for entity embedding.
  This must be one of: `MLJFlux.NeuralNetworkClassifier`,
  `MLJFlux.NeuralNetworkRegressor`, `MLJFlux.MultitargetNeuralNetworkRegressor`.
  Hyper-parameters of the selected model can affect the quality of the embeddings; the
  most notable of these is likely the `builder` argument.

# Operations

- `transform(mach, Xnew)`: Transform the categorical features of `Xnew` into dense `Continuous` vectors using the trained `MLJFlux.EntityEmbedderLayer` layer present in the network.
  See the [MLJFlux documentation](https://fluxml.ai/MLJFlux.jl/dev/), in particular the `embedding_dims` hyper-parameter, for details.


# Examples

```julia
using MLJ
using CategoricalArrays

# Setup some data
N = 200
X = (;
Column1 = repeat(Float32[1.0, 2.0, 3.0, 4.0, 5.0], Int(N / 5)),
Column2 = categorical(repeat(['a', 'b', 'c', 'd', 'e'], Int(N / 5))),
Column3 = categorical(repeat(["b", "c", "d", "f", "f"], Int(N / 5)), ordered = true),
Column4 = repeat(Float32[1.0, 2.0, 3.0, 4.0, 5.0], Int(N / 5)),
Column5 = randn(Float32, N),
Column6 = categorical(
repeat(["group1", "group1", "group2", "group2", "group3"], Int(N / 5)),
),
)
y = categorical(repeat(1:10, Int(N / 10))) # Classification target, one label per row of X

# Initiate model
EntityEmbedder = @load EntityEmbedder pkg=MLJFlux
NeuralNetworkClassifier = @load NeuralNetworkClassifier pkg=MLJFlux

clf = NeuralNetworkClassifier(embedding_dims=Dict(:Column2 => 2, :Column3 => 2))

emb = EntityEmbedder(clf)

# Construct machine
mach = machine(emb, X, y)

# Train model
fit!(mach)

# Transform data using model to encode categorical columns
Xnew = transform(mach, X)
Xnew
```

See also
[`NeuralNetworkClassifier`](@ref), [`NeuralNetworkRegressor`](@ref)
"""
EntityEmbedder
6 changes: 3 additions & 3 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ MMI.metadata_pkg.(
const MODELSUPPORTDOC = """
In addition to features with `Continuous` scientific element type, this model supports
categorical features in the input table. If present, such features are embedded into dense
vectors by the use of an additional `EntityEmbedder` layer after the input, as described in
vectors by the use of an additional `EntityEmbedderLayer` layer after the input, as described in
Entity Embeddings of Categorical Variables by Cheng Guo, Felix Berkhahn arXiv, 2016.
"""

Expand All @@ -204,7 +204,7 @@ const XDOC = """
scitype (typically `Float32`); or (ii) a table of input features (eg, a `DataFrame`)
whose columns have `Continuous`, `Multiclass` or `OrderedFactor` element scitype; check
column scitypes with `schema(X)`. If any `Multiclass` or `OrderedFactor` features
appear, the constructed network will use an `EntityEmbedder` layer to transform
appear, the constructed network will use an `EntityEmbedderLayer` layer to transform
them into dense vectors. If `X` is a `Matrix`, it is assumed that columns correspond to
features and rows corresponding to observations.

Expand All @@ -222,7 +222,7 @@ const EMBDOC = """
const TRANSFORMDOC = """
- `transform(mach, Xnew)`: Assuming `Xnew` has the same schema as `X`, transform the
categorical features of `Xnew` into dense `Continuous` vectors using the
`MLJFlux.EntityEmbedder` layer present in the network. Does nothing in case the model
`MLJFlux.EntityEmbedderLayer` layer present in the network. Does nothing in case the model
was trained on an input `X` that lacks categorical features.
"""

Expand Down
33 changes: 26 additions & 7 deletions test/entity_embedding.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ entityprops = [
(index = 4, levels = 2, newdim = 2),
]

embedder = MLJFlux.EntityEmbedder(entityprops, 4)
embedder = MLJFlux.EntityEmbedderLayer(entityprops, 4)

output = embedder(batch)

Expand Down Expand Up @@ -68,7 +68,7 @@ end
]

cat_model = Chain(
MLJFlux.EntityEmbedder(entityprops, 4),
MLJFlux.EntityEmbedderLayer(entityprops, 4),
Dense(9 => (ind == 1) ? 10 : 1),
finalizer[ind],
)
Expand Down Expand Up @@ -143,14 +143,14 @@ end
@testset "Transparent when no categorical variables" begin
entityprops = []
numfeats = 4
embedder = MLJFlux.EntityEmbedder(entityprops, 4)
embedder = MLJFlux.EntityEmbedderLayer(entityprops, 4)
output = embedder(batch)
@test output ≈ batch
@test eltype(output) == Float32
end


@testset "get_embedding_matrices works and has the right dimensions" begin
@testset "get_embedding_matrices works as well as the light wrapper" begin
models = [
MLJFlux.NeuralNetworkBinaryClassifier,
MLJFlux.NeuralNetworkClassifier,
Expand Down Expand Up @@ -187,21 +187,40 @@ end
3 4
])

stable_rng=StableRNG(123)

for j in eachindex(embedding_dims)
for i in eachindex(models)
# Without lightweight wrapper
clf = models[1](
builder = MLJFlux.Short(n_hidden = 5, dropout = 0.2),
builder = MLJFlux.MLP(hidden=(10, 10)),
optimiser = Optimisers.Adam(0.01),
batch_size = 8,
epochs = 100,
acceleration = CUDALibs(),
optimiser_changes_trigger_retraining = true,
embedding_dims = embedding_dims[3],
rng=42
)

mach = machine(clf, X, ys[1])

fit!(mach, verbosity = 0)
Xnew = transform(mach, X)
# With lightweight wrapper
clf2 = deepcopy(clf)
emb = MLJFlux.EntityEmbedder(clf2)
@test_throws MLJFlux.ERR_MODEL_UNSPECIFIED begin
MLJFlux.EntityEmbedder()
end
mach_emb = machine(emb, X, ys[1])
fit!(mach_emb, verbosity = 0)
Xnew_emb = transform(mach_emb, X)
@test Xnew == Xnew_emb

# Pipeline doesn't throw an error
pipeline = emb |> clf
mach_pipe = machine(pipeline, X, y)
fit!(mach_pipe, verbosity = 0)
y = predict_mode(mach_pipe, X)

mapping_matrices = MLJFlux.get_embedding_matrices(
fitted_params(mach).chain,
Expand Down
Loading