Skip to content

Commit

Permalink
overload constructor trait for IteratedModel types
Browse files Browse the repository at this point in the history
  • Loading branch information
ablaom committed Jun 3, 2024
1 parent 67584bf commit 3bbfab0
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 5 deletions.
6 changes: 2 additions & 4 deletions src/traits.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
MLJBase.is_wrapper(::Type{<:EitherIteratedModel}) = true
MLJBase.caches_data_by_default(::Type{<:EitherIteratedModel}) = false
MLJBase.load_path(::Type{<:DeterministicIteratedModel}) =
"MLJIteration.DeterministicIteratedModel"
MLJBase.load_path(::Type{<:ProbabilisticIteratedModel}) =
"MLJIteration.ProbabilisticIteratedModel"
MLJBase.load_path(::Type{<:EitherIteratedModel}) = "MLJIteration.IteratedModel"
MLJBase.constructor(::Type{<:EitherIteratedModel}) = IteratedModel
MLJBase.package_name(::Type{<:EitherIteratedModel}) = "MLJIteration"
MLJBase.package_uuid(::Type{<:EitherIteratedModel}) =
"614be32b-d00c-4edb-bd02-1eb411ab5e55"
Expand Down
3 changes: 2 additions & 1 deletion test/traits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ imodel = IteratedModel(model=model, measure=mae)
@test !MLJBase.caches_data_by_default(imodel)
@test !supports_weights(imodel)
@test !supports_class_weights(imodel)
@test load_path(imodel) == "MLJIteration.DeterministicIteratedModel"
@test load_path(imodel) == "MLJIteration.IteratedModel"
@test package_name(imodel) == "MLJIteration"
@test package_uuid(imodel) == "614be32b-d00c-4edb-bd02-1eb411ab5e55"
@test package_url(imodel) == "https://github.com/JuliaAI/MLJIteration.jl"
Expand All @@ -22,6 +22,7 @@ imodel = IteratedModel(model=model, measure=mae)
@test input_scitype(imodel) == input_scitype(model)
@test output_scitype(imodel) == output_scitype(model)
@test target_scitype(imodel) == target_scitype(model)
@test constructor(imodel) == IteratedModel

end

Expand Down

0 comments on commit 3bbfab0

Please sign in to comment.