diff --git a/src/classifier.jl b/src/classifier.jl
index 40bcf5bd..ed9d4cf9 100644
--- a/src/classifier.jl
+++ b/src/classifier.jl
@@ -1,6 +1,12 @@
 # if `b` is a builder, then `b(model, rng, shape...)` is called to make a
 # new chain, where `shape` is the return value of this method:
+"""
+    shape(model::NeuralNetworkClassifier, X, y)
+
+A private method that returns the shape of the input and output of the model for given data `X` and `y`.
+"""
 function MLJFlux.shape(model::NeuralNetworkClassifier, X, y)
+    X = X isa Matrix ? Tables.table(X) : X
     levels = MLJModelInterface.classes(y[1])
     n_output = length(levels)
     n_input = Tables.schema(X).names |> length
@@ -10,7 +16,7 @@ end
 # builds the end-to-end Flux chain needed, given the `model` and `shape`:
 MLJFlux.build(model::NeuralNetworkClassifier, rng, shape) =
     Flux.Chain(build(model.builder, rng, shape...),
-               model.finaliser)
+        model.finaliser)
 
 # returns the model `fitresult` (see "Adding Models for General Use"
 # section of the MLJ manual) which must always have the form `(chain,
@@ -19,15 +25,15 @@ MLJFlux.fitresult(model::NeuralNetworkClassifier, chain, y) =
     (chain, MLJModelInterface.classes(y[1]))
 
 function MLJModelInterface.predict(model::NeuralNetworkClassifier,
-                                   fitresult,
-                                   Xnew)
+    fitresult,
+    Xnew)
     chain, levels = fitresult
     X = reformat(Xnew)
-    probs = vcat([chain(tomat(X[:,i]))' for i in 1:size(X, 2)]...)
+    probs = vcat([chain(tomat(X[:, i]))' for i in 1:size(X, 2)]...)
     return MLJModelInterface.UnivariateFinite(levels, probs)
 end
 
 MLJModelInterface.metadata_model(NeuralNetworkClassifier,
-                                 input=Table(Continuous),
-                                 target=AbstractVector{<:Finite},
-                                 path="MLJFlux.NeuralNetworkClassifier")
+    input=Union{AbstractMatrix{Continuous},Table(Continuous)},
+    target=AbstractVector{<:Finite},
+    path="MLJFlux.NeuralNetworkClassifier")
diff --git a/src/core.jl b/src/core.jl
index 3dcb3fbe..166af8f5 100644
--- a/src/core.jl
+++ b/src/core.jl
@@ -141,12 +141,16 @@ function nrows(X)
     return length(cols[1])
 end
 nrows(y::AbstractVector) = length(y)
+nrows(X::AbstractMatrix) = size(X, 1)
 
 reformat(X) = reformat(X, scitype(X))
 
 # ---------------------------------
-# Reformatting tables
+# Reformatting matrices
+reformat(X, ::Type{<:AbstractMatrix}) = X'
+# ---------------------------------
+# Reformatting tables
 
 reformat(X, ::Type{<:Table}) = MLJModelInterface.matrix(X)'
 
 # ---------------------------------
diff --git a/src/regressor.jl b/src/regressor.jl
index 85a431aa..18aaef22 100644
--- a/src/regressor.jl
+++ b/src/regressor.jl
@@ -1,6 +1,12 @@
 # # NEURAL NETWORK REGRESSOR
 
+"""
+    shape(model::NeuralNetworkRegressor, X, y)
+
+A private method that returns the shape of the input and output of the model for given data `X` and `y`.
+"""
 function shape(model::NeuralNetworkRegressor, X, y)
+    X = X isa Matrix ? Tables.table(X) : X
     n_input = Tables.schema(X).names |> length
     n_ouput = 1
     return (n_input, 1)
@@ -12,47 +18,55 @@ build(model::NeuralNetworkRegressor, rng, shape) =
 
 fitresult(model::NeuralNetworkRegressor, chain, y) = (chain, nothing)
 
 function MLJModelInterface.predict(model::NeuralNetworkRegressor,
-                                   fitresult,
-                                   Xnew)
-        chain = fitresult[1]
+    fitresult,
+    Xnew)
+    chain = fitresult[1]
     Xnew_ = reformat(Xnew)
-    return [chain(values.(tomat(Xnew_[:,i])))[1]
-            for i in 1:size(Xnew_, 2)]
+    return [chain(values.(tomat(Xnew_[:, i])))[1]
+            for i in 1:size(Xnew_, 2)]
 end
 
 MLJModelInterface.metadata_model(NeuralNetworkRegressor,
-                                 input=Table(Continuous),
-                                 target=AbstractVector{<:Continuous},
-                                 path="MLJFlux.NeuralNetworkRegressor")
+    input=Union{AbstractMatrix{Continuous},Table(Continuous)},
+    target=AbstractVector{<:Continuous},
+    path="MLJFlux.NeuralNetworkRegressor")
 
 
 # # MULTITARGET NEURAL NETWORK REGRESSOR
 
+ncols(X::AbstractMatrix) = size(X, 2)
+ncols(X) = Tables.columns(X) |> Tables.columnnames |> length
 
-function shape(model::MultitargetNeuralNetworkRegressor, X, y)
-    n_input = Tables.schema(X).names |> length
-    n_output = Tables.schema(y).names |> length
-    return (n_input, n_output)
-end
+"""
+    shape(model::MultitargetNeuralNetworkRegressor, X, y)
+
+A private method that returns the shape of the input and output of the model for given data `X` and `y`.
+"""
+shape(model::MultitargetNeuralNetworkRegressor, X, y) = (ncols(X), ncols(y))
 
 build(model::MultitargetNeuralNetworkRegressor, rng, shape) =
     build(model.builder, rng, shape...)
 
 function fitresult(model::MultitargetNeuralNetworkRegressor, chain, y)
-    target_column_names = Tables.schema(y).names
+    if y isa Matrix
+        target_column_names = nothing
+    else
+        target_column_names = Tables.schema(y).names
+    end
     return (chain, target_column_names)
 end
 
 function MLJModelInterface.predict(model::MultitargetNeuralNetworkRegressor,
-                                   fitresult, Xnew)
-        chain, target_column_names = fitresult
+    fitresult, Xnew)
+    chain, target_column_names = fitresult
     X = reformat(Xnew)
-    ypred = [chain(values.(tomat(X[:,i])))
+    ypred = [chain(values.(tomat(X[:, i])))
             for i in 1:size(X, 2)]
-    return MLJModelInterface.table(reduce(hcat, y for y in ypred)',
-                                   names=target_column_names)
+    output = isnothing(target_column_names) ? permutedims(reduce(hcat, ypred)) :
+             MLJModelInterface.table(reduce(hcat, ypred)', names=target_column_names)
+    return output
 end
 
 MLJModelInterface.metadata_model(MultitargetNeuralNetworkRegressor,
-                                 input=Table(Continuous),
-                                 target=Table(Continuous),
-                                 path="MLJFlux.MultitargetNeuralNetworkRegressor")
+    input=Union{AbstractMatrix{Continuous},Table(Continuous)},
+    target=Table(Continuous),
+    path="MLJFlux.MultitargetNeuralNetworkRegressor")
diff --git a/src/types.jl b/src/types.jl
index 2e7958ef..c608abbf 100644
--- a/src/types.jl
+++ b/src/types.jl
@@ -5,127 +5,89 @@ const MLJFluxModel = Union{MLJFluxProbabilistic,MLJFluxDeterministic}
 
 for Model in [:NeuralNetworkClassifier, :ImageClassifier]
 
     default_builder_ex =
         Model == :ImageClassifier ? :(image_builder(VGGHack)) : Short()
 
-    ex = quote
-        mutable struct $Model{B,F,O,L} <: MLJFluxProbabilistic
-            builder::B
-            finaliser::F
-            optimiser::O     # mutable struct from Flux/src/optimise/optimisers.jl
-            loss::L          # can be called as in `loss(yhat, y)`
-            epochs::Int      # number of epochs
-            batch_size::Int  # size of a batch
-            lambda::Float64  # regularization strength
-            alpha::Float64   # regularizaton mix (0 for all l2, 1 for all l1)
-            rng::Union{AbstractRNG,Int64}
-            optimiser_changes_trigger_retraining::Bool
-            acceleration::AbstractResource  # eg, `CPU1()` or `CUDALibs()`
-        end
-
-        function $Model(; builder::B = $default_builder_ex
-                        , finaliser::F = Flux.softmax
-                        , optimiser::O = Flux.Optimise.Adam()
-                        , loss::L = Flux.crossentropy
-                        , epochs = 10
-                        , batch_size = 1
-                        , lambda = 0
-                        , alpha = 0
-                        , rng = Random.GLOBAL_RNG
-                        , optimiser_changes_trigger_retraining = false
-                        , acceleration = CPU1()
-                        ) where {B,F,O,L}
-
-            model = $Model{B,F,O,L}(builder
-                                    , finaliser
-                                    , optimiser
-                                    , loss
-                                    , epochs
-                                    , batch_size
-                                    , lambda
-                                    , alpha
-                                    , rng
-                                    , optimiser_changes_trigger_retraining
-                                    , acceleration
-                                    )
-
-            message = clean!(model)
-            isempty(message) || @warn message
-
-            return model
-        end
+    ex = quote
+        mutable struct $Model{B,F,O,L} <: MLJFluxProbabilistic
+            builder::B
+            finaliser::F
+            optimiser::O  # mutable struct from Flux/src/optimise/optimisers.jl
+            loss::L  # can be called as in `loss(yhat, y)`
+            epochs::Int  # number of epochs
+            batch_size::Int  # size of a batch
+            lambda::Float64  # regularization strength
+            alpha::Float64  # regularizaton mix (0 for all l2, 1 for all l1)
+            rng::Union{AbstractRNG,Int64}
+            optimiser_changes_trigger_retraining::Bool
+            acceleration::AbstractResource  # eg, `CPU1()` or `CUDALibs()`
+        end
+
+        function $Model(; builder::B=$default_builder_ex, finaliser::F=Flux.softmax, optimiser::O=Flux.Optimise.Adam(), loss::L=Flux.crossentropy, epochs=10, batch_size=1, lambda=0, alpha=0, rng=Random.GLOBAL_RNG, optimiser_changes_trigger_retraining=false, acceleration=CPU1()
+        ) where {B,F,O,L}
+
+            model = $Model{B,F,O,L}(builder, finaliser, optimiser, loss, epochs, batch_size, lambda, alpha, rng, optimiser_changes_trigger_retraining, acceleration
+            )
+            message = clean!(model)
+            isempty(message) || @warn message
+
+            return model
         end
-        eval(ex)
+
+    end
+    eval(ex)
 end
 
 for Model in [:NeuralNetworkRegressor, :MultitargetNeuralNetworkRegressor]
 
-    ex = quote
-        mutable struct $Model{B,O,L} <: MLJFluxDeterministic
-            builder::B
-            optimiser::O     # mutable struct from Flux/src/optimise/optimisers.jl
-            loss::L          # can be called as in `loss(yhat, y)`
-            epochs::Int      # number of epochs
-            batch_size::Int  # size of a batch
-            lambda::Float64  # regularization strength
-            alpha::Float64   # regularizaton mix (0 for all l2, 1 for all l1)
-            rng::Union{AbstractRNG,Integer}
-            optimiser_changes_trigger_retraining::Bool
-            acceleration::AbstractResource  # eg, `CPU1()` or `CUDALibs()`
-        end
-
-        function $Model(; builder::B = Linear()
-                        , optimiser::O = Flux.Optimise.Adam()
-                        , loss::L = Flux.mse
-                        , epochs = 10
-                        , batch_size = 1
-                        , lambda = 0
-                        , alpha = 0
-                        , rng = Random.GLOBAL_RNG
-                        , optimiser_changes_trigger_retraining=false
-                        , acceleration = CPU1()
-                        ) where {B,O,L}
-
-            model = $Model{B,O,L}(builder
-                                  , optimiser
-                                  , loss
-                                  , epochs
-                                  , batch_size
-                                  , lambda
-                                  , alpha
-                                  , rng
-                                  , optimiser_changes_trigger_retraining
-                                  , acceleration)
-
-            message = clean!(model)
-            isempty(message) || @warn message
-
-            return model
-        end
+    ex = quote
+        mutable struct $Model{B,O,L} <: MLJFluxDeterministic
+            builder::B
+            optimiser::O  # mutable struct from Flux/src/optimise/optimisers.jl
+            loss::L  # can be called as in `loss(yhat, y)`
+            epochs::Int  # number of epochs
+            batch_size::Int  # size of a batch
+            lambda::Float64  # regularization strength
+            alpha::Float64  # regularizaton mix (0 for all l2, 1 for all l1)
+            rng::Union{AbstractRNG,Integer}
+            optimiser_changes_trigger_retraining::Bool
+            acceleration::AbstractResource  # eg, `CPU1()` or `CUDALibs()`
+        end
+
+        function $Model(; builder::B=Linear(), optimiser::O=Flux.Optimise.Adam(), loss::L=Flux.mse, epochs=10, batch_size=1, lambda=0, alpha=0, rng=Random.GLOBAL_RNG, optimiser_changes_trigger_retraining=false, acceleration=CPU1()
+        ) where {B,O,L}
+
+            model = $Model{B,O,L}(builder, optimiser, loss, epochs, batch_size, lambda, alpha, rng, optimiser_changes_trigger_retraining, acceleration)
+            message = clean!(model)
+            isempty(message) || @warn message
+
+            return model
         end
-        eval(ex)
+
+    end
+    eval(ex)
 end
 
 const Regressor =
-    Union{NeuralNetworkRegressor, MultitargetNeuralNetworkRegressor}
+    Union{NeuralNetworkRegressor,MultitargetNeuralNetworkRegressor}
 
 MMI.metadata_pkg.(
-    (
-     NeuralNetworkRegressor,
-     MultitargetNeuralNetworkRegressor,
-     NeuralNetworkClassifier,
-     ImageClassifier,
-    ),
-    name="MLJFlux",
-    uuid="094fc8d1-fd35-5302-93ea-dabda2abf845",
-    url="https://github.com/alan-turing-institute/MLJFlux.jl",
-    julia=true,
-    license="MIT",
+    (
+        NeuralNetworkRegressor,
+        MultitargetNeuralNetworkRegressor,
+        NeuralNetworkClassifier,
+        ImageClassifier,
+    ),
+    name="MLJFlux",
+    uuid="094fc8d1-fd35-5302-93ea-dabda2abf845",
+    url="https://github.com/alan-turing-institute/MLJFlux.jl",
+    julia=true,
+    license="MIT",
 )
@@ -148,8 +110,8 @@ In MLJ or MLJBase, bind an instance `model` to data with
 
 Here:
 
-- `X` is any table of input features (eg, a `DataFrame`) whose columns are of scitype
-  `Continuous`; check column scitypes with `schema(X)`.
+- `X` is either a `Matrix` or any table of input features (eg, a `DataFrame`) whose columns are of scitype
+  `Continuous`; check column scitypes with `schema(X)`. If `X` is a `Matrix`, it is assumed to have columns corresponding to features and rows corresponding to observations.
 
 - `y` is the target, which can be any `AbstractVector` whose element scitype is
   `Multiclass` or `OrderedFactor`; check the scitype with `scitype(y)`
@@ -583,8 +545,8 @@ In MLJ or MLJBase, bind an instance `model` to data with
 
 Here:
 
-- `X` is any table of input features (eg, a `DataFrame`) whose columns
-  are of scitype `Continuous`; check the column scitypes with `schema(X)`.
+- `X` is either a `Matrix` or any table of input features (eg, a `DataFrame`) whose columns are of scitype
+  `Continuous`; check column scitypes with `schema(X)`. If `X` is a `Matrix`, it is assumed to have columns corresponding to features and rows corresponding to observations.
 
 - `y` is the target, which can be any `AbstractVector` whose element scitype is
   `Continuous`; check the scitype with `scitype(y)`
@@ -810,11 +772,11 @@ In MLJ or MLJBase, bind an instance `model` to data with
 
 Here:
 
-- `X` is any table of input features (eg, a `DataFrame`) whose columns are of scitype
-  `Continuous`; check column scitypes with `schema(X)`.
+- `X` is either a `Matrix` or any table of input features (eg, a `DataFrame`) whose columns are of scitype
+  `Continuous`; check column scitypes with `schema(X)`. If `X` is a `Matrix`, it is assumed to have columns corresponding to features and rows corresponding to observations.
 
-- `y` is the target, which can be any table of output targets whose element scitype is
-  `Continuous`; check column scitypes with `schema(y)`.
+- `y` is the target, which can be any table or matrix of output targets whose element scitype is
+  `Continuous`; check column scitypes with `schema(y)`. If `y` is a `Matrix`, it is assumed to have columns corresponding to variables and rows corresponding to observations.
 
 # Hyper-parameters
 
diff --git a/test/classifier.jl b/test/classifier.jl
index 55bade43..3eb73699 100644
--- a/test/classifier.jl
+++ b/test/classifier.jl
@@ -25,13 +25,26 @@ losses = []
 @testset_accelerated "NeuralNetworkClassifier" accel begin
 
     Random.seed!(123)
 
-    basictest(MLJFlux.NeuralNetworkClassifier,
-              X,
-              y,
-              builder,
-              optimiser,
-              0.85,
-              accel)
+    # Table input:
+    @testset "Table input" begin
+        basictest(MLJFlux.NeuralNetworkClassifier,
+            X,
+            y,
+            builder,
+            optimiser,
+            0.85,
+            accel)
+    end
+    # Matrix input:
+    @testset "Matrix input" begin
+        basictest(MLJFlux.NeuralNetworkClassifier,
+            matrix(X),
+            y,
+            builder,
+            optimiser,
+            0.85,
+            accel)
+    end
 
     train, test = MLJBase.partition(1:N, 0.7)
diff --git a/test/regressor.jl b/test/regressor.jl
index 0f05ee72..1345125f 100644
--- a/test/regressor.jl
+++ b/test/regressor.jl
@@ -18,13 +18,27 @@ train, test = MLJBase.partition(1:N, 0.7)
 
     Random.seed!(123)
 
-    basictest(MLJFlux.NeuralNetworkRegressor,
-              X,
-              y,
-              builder,
-              optimiser,
-              0.7,
-              accel)
+    # Table input:
+    @testset "Table input" begin
+        basictest(MLJFlux.NeuralNetworkRegressor,
+            X,
+            y,
+            builder,
+            optimiser,
+            0.7,
+            accel)
+    end
+
+    # Matrix input:
+    @testset "Matrix input" begin
+        basictest(MLJFlux.NeuralNetworkRegressor,
+            matrix(X),
+            y,
+            builder,
+            optimiser,
+            0.7,
+            accel)
+    end
 
     # test model is a bit better than constant predictor:
     stable_rng = StableRNGs.StableRNG(123)
@@ -64,13 +78,26 @@ losses = []
 
     Random.seed!(123)
 
-    basictest(MLJFlux.MultitargetNeuralNetworkRegressor,
-              X,
-              y,
-              builder,
-              optimiser,
-              1.0,
-              accel)
+    # Table input:
+    @testset "Table input" begin
+        basictest(MLJFlux.MultitargetNeuralNetworkRegressor,
+            X,
+            y,
+            builder,
+            optimiser,
+            1.0,
+            accel)
+    end
+    # Matrix input:
+    @testset "Matrix input" begin
+        basictest(MLJFlux.MultitargetNeuralNetworkRegressor,
+            matrix(X),
+            ymatrix,
+            builder,
+            optimiser,
+            1.0,
+            accel)
+    end
 
     # test model is a bit better than constant predictor
     model = MLJFlux.MultitargetNeuralNetworkRegressor(acceleration=accel,
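
For illustration only (not part of the diff): a minimal sketch of how the matrix input support added above might be exercised through the MLJ machine interface. It assumes MLJBase, MLJFlux and StableRNGs are installed; the variable names are arbitrary.

using MLJBase, MLJFlux, StableRNGs

rng = StableRNGs.StableRNG(123)

# A plain Float32 matrix has scitype AbstractMatrix{Continuous};
# per the updated docstrings, rows are observations and columns are features.
X = rand(rng, Float32, 100, 3)
y = X[:, 1] .+ 2 .* X[:, 2]

mach = machine(MLJFlux.NeuralNetworkRegressor(epochs=5), X, y)
fit!(mach, verbosity=0)
yhat = predict(mach, X)   # vector of predictions

# NeuralNetworkClassifier accepts a matrix `X` in the same way, and
# MultitargetNeuralNetworkRegressor additionally accepts a matrix target
# (columns are target variables), returning a matrix of predictions.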