Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(core): implement get_params() API for steps #145

Merged
merged 4 commits into from
Sep 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 64 additions & 2 deletions ibis_ml/core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import copy
import inspect
import os
import pprint
from collections import defaultdict
Expand Down Expand Up @@ -306,6 +307,69 @@ def categorize(df: pd.DataFrame, categories: dict[str, list[Any]]) -> pd.DataFra


class Step:
@classmethod
def _get_param_names(cls) -> list[str]:
"""Get parameter names for the estimator.
deepyaman marked this conversation as resolved.
Show resolved Hide resolved

Notes
-----
Copied from [1]_.

References
----------
.. [1] https://github.com/scikit-learn/scikit-learn/blob/ab2f539/sklearn/base.py#L148-L173
"""
# fetch the constructor or the original constructor before
# deprecation wrapping if any
init = getattr(cls.__init__, "deprecated_original", cls.__init__)
if init is object.__init__:
# No explicit constructor to introspect
return []

# introspect the constructor arguments to find the model parameters
# to represent
init_signature = inspect.signature(init)
# Consider the constructor parameters excluding 'self'
parameters = [
p
for p in init_signature.parameters.values()
if p.name != "self" and p.kind != p.VAR_KEYWORD
]
for p in parameters:
if p.kind == p.VAR_POSITIONAL:
raise RuntimeError(
"scikit-learn estimators should always "
"specify their parameters in the signature"
" of their __init__ (no varargs)."
f" {cls} with constructor {init_signature} doesn't "
" follow this convention."
)
Comment on lines +339 to +346
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This validation won't get hit unless you don't override _get_param_names for a step with varargs in the constructor (e.g. Mutate or MutateAt. I suppose I can add a test for malformed custom step?

# Extract and sort argument names excluding 'self'
return sorted([p.name for p in parameters])

def get_params(self, deep=True) -> dict[str, Any]:
"""Get parameters for this estimator.
deepyaman marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
deep : bool, default=True
Has no effect, because steps cannot contain nested substeps.

Returns
-------
params : dict
Parameter names mapped to their values.

Notes
-----
Derived from [1]_.

References
----------
.. [1] https://github.com/scikit-learn/scikit-learn/blob/626b460/sklearn/base.py#L145-L167
"""
return {key: getattr(self, key) for key in self._get_param_names()}

def __repr__(self) -> str:
return pprint.pformat(self)

Expand Down Expand Up @@ -449,8 +513,6 @@ def set_output(
) -> Recipe:
"""Set output type returned by `transform`.

This is part of the standard Scikit-Learn API.

Parameters
----------
transform : {"default", "pandas"}, default=None
Expand Down
13 changes: 11 additions & 2 deletions ibis_ml/steps/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,11 @@ def __init__(
self.expr = expr
self.named_exprs = named_exprs

@classmethod
def _get_param_names(cls) -> list[str]:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't love that I have to override this, but I guess it's better to do this for now (rather than change the API to require specifying expressions in a dictionary).

"""Get parameter names for the estimator."""
return ["expr", "inputs", "named_exprs"]

def _repr(self) -> Iterable[tuple[str, Any]]:
yield ("", self.inputs)
if self.expr is not None:
Expand Down Expand Up @@ -191,11 +196,15 @@ def __init__(
self.exprs = exprs
self.named_exprs = named_exprs

@classmethod
def _get_param_names(cls) -> list[str]:
"""Get parameter names for the estimator."""
return ["exprs", "named_exprs"]

def _repr(self) -> Iterable[tuple[str, Any]]:
for expr in self.exprs:
yield "", expr
for name, expr in self.named_exprs.items():
yield name, expr
yield from self.named_exprs.items()

def is_fitted(self):
return True
Expand Down
3 changes: 3 additions & 0 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def test_mutate_at_expr():
res = step.transform_table(t)
sol = t.mutate(x=_.x.abs(), y=_.y.abs())
assert res.equals(sol)
assert list(step.get_params()) == ["expr", "inputs", "named_exprs"]


def test_mutate_at_named_exprs():
Expand All @@ -44,6 +45,7 @@ def test_mutate_at_named_exprs():
res = step.transform_table(t)
sol = t.mutate(x=_.x.abs(), y=_.y.abs(), x_log=_.x.log(), y_log=_.y.log())
assert res.equals(sol)
assert list(step.get_params()) == ["expr", "inputs", "named_exprs"]


def test_mutate():
Expand All @@ -54,3 +56,4 @@ def test_mutate():
res = step.transform_table(t)
sol = t.mutate(_.x.abs().name("x_abs"), y_log=lambda t: t.y.log())
assert res.equals(sol)
assert list(step.get_params()) == ["exprs", "named_exprs"]
7 changes: 7 additions & 0 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,13 @@ def test_errors_nicely_if_not_fitted(table, method):
getattr(r, method)(table)


def test_get_params():
rec = ml.Recipe(ml.ExpandDateTime(ml.timestamp()))

assert "expanddatetime__components" in rec.get_params(deep=True)
assert "expanddatetime__components" not in rec.get_params(deep=False)


@pytest.mark.parametrize(
("step", "url"),
[
Expand Down
Loading