diff --git a/ibis_ml/core.py b/ibis_ml/core.py index 68dacf9..ee43a45 100644 --- a/ibis_ml/core.py +++ b/ibis_ml/core.py @@ -1,6 +1,7 @@ from __future__ import annotations import copy +import inspect import os import pprint from collections import defaultdict @@ -306,6 +307,69 @@ def categorize(df: pd.DataFrame, categories: dict[str, list[Any]]) -> pd.DataFra class Step: + @classmethod + def _get_param_names(cls) -> list[str]: + """Get parameter names for the estimator. + + Notes + ----- + Copied from [1]_. + + References + ---------- + .. [1] https://github.com/scikit-learn/scikit-learn/blob/ab2f539/sklearn/base.py#L148-L173 + """ + # fetch the constructor or the original constructor before + # deprecation wrapping if any + init = getattr(cls.__init__, "deprecated_original", cls.__init__) + if init is object.__init__: + # No explicit constructor to introspect + return [] + + # introspect the constructor arguments to find the model parameters + # to represent + init_signature = inspect.signature(init) + # Consider the constructor parameters excluding 'self' + parameters = [ + p + for p in init_signature.parameters.values() + if p.name != "self" and p.kind != p.VAR_KEYWORD + ] + for p in parameters: + if p.kind == p.VAR_POSITIONAL: + raise RuntimeError( + "scikit-learn estimators should always " + "specify their parameters in the signature" + " of their __init__ (no varargs)." + f" {cls} with constructor {init_signature} doesn't " + " follow this convention." + ) + # Extract and sort argument names excluding 'self' + return sorted([p.name for p in parameters]) + + def get_params(self, deep=True) -> dict[str, Any]: + """Get parameters for this estimator. + + Parameters + ---------- + deep : bool, default=True + Has no effect, because steps cannot contain nested substeps. + + Returns + ------- + params : dict + Parameter names mapped to their values. + + Notes + ----- + Derived from [1]_. + + References + ---------- + .. [1] https://github.com/scikit-learn/scikit-learn/blob/626b460/sklearn/base.py#L145-L167 + """ + return {key: getattr(self, key) for key in self._get_param_names()} + def __repr__(self) -> str: return pprint.pformat(self) @@ -449,8 +513,6 @@ def set_output( ) -> Recipe: """Set output type returned by `transform`. - This is part of the standard Scikit-Learn API. - Parameters ---------- transform : {"default", "pandas"}, default=None diff --git a/ibis_ml/steps/_common.py b/ibis_ml/steps/_common.py index 0c8b67c..ec8ef59 100644 --- a/ibis_ml/steps/_common.py +++ b/ibis_ml/steps/_common.py @@ -135,6 +135,11 @@ def __init__( self.expr = expr self.named_exprs = named_exprs + @classmethod + def _get_param_names(cls) -> list[str]: + """Get parameter names for the estimator.""" + return ["expr", "inputs", "named_exprs"] + def _repr(self) -> Iterable[tuple[str, Any]]: yield ("", self.inputs) if self.expr is not None: @@ -191,11 +196,15 @@ def __init__( self.exprs = exprs self.named_exprs = named_exprs + @classmethod + def _get_param_names(cls) -> list[str]: + """Get parameter names for the estimator.""" + return ["exprs", "named_exprs"] + def _repr(self) -> Iterable[tuple[str, Any]]: for expr in self.exprs: yield "", expr - for name, expr in self.named_exprs.items(): - yield name, expr + yield from self.named_exprs.items() def is_fitted(self): return True diff --git a/tests/test_common.py b/tests/test_common.py index 5ff5abd..effea3a 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -34,6 +34,7 @@ def test_mutate_at_expr(): res = step.transform_table(t) sol = t.mutate(x=_.x.abs(), y=_.y.abs()) assert res.equals(sol) + assert list(step.get_params()) == ["expr", "inputs", "named_exprs"] def test_mutate_at_named_exprs(): @@ -44,6 +45,7 @@ def test_mutate_at_named_exprs(): res = step.transform_table(t) sol = t.mutate(x=_.x.abs(), y=_.y.abs(), x_log=_.x.log(), y_log=_.y.log()) assert res.equals(sol) + assert list(step.get_params()) == ["expr", "inputs", "named_exprs"] def test_mutate(): @@ -54,3 +56,4 @@ def test_mutate(): res = step.transform_table(t) sol = t.mutate(_.x.abs().name("x_abs"), y_log=lambda t: t.y.log()) assert res.equals(sol) + assert list(step.get_params()) == ["exprs", "named_exprs"] diff --git a/tests/test_core.py b/tests/test_core.py index b16cf59..3994f60 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -365,6 +365,13 @@ def test_errors_nicely_if_not_fitted(table, method): getattr(r, method)(table) +def test_get_params(): + rec = ml.Recipe(ml.ExpandDateTime(ml.timestamp())) + + assert "expanddatetime__components" in rec.get_params(deep=True) + assert "expanddatetime__components" not in rec.get_params(deep=False) + + @pytest.mark.parametrize( ("step", "url"), [