
feat(core): implement get_params() API for steps #145

Merged · 4 commits merged into ibis-project:main from feat/core/step-get-params on Sep 14, 2024

Conversation

@deepyaman (Collaborator) commented Aug 30, 2024

Description of changes

For instance, from the example notebook:

import ibis  # needed for the `ibis._` deferred expressions below
import ibis_ml as ml

flights_rec = ml.Recipe(
    ml.ExpandDate("date", components=["dow", "month"]),
    ml.Drop("date"),
    ml.TargetEncode(ml.nominal()),
    ml.DropZeroVariance(ml.everything()),
    ml.MutateAt("dep_time", ibis._.hour() * 60 + ibis._.minute()),
    ml.MutateAt(ml.timestamp(), ibis._.epoch_seconds()),
    # By default, PyTorch requires that the type of `X` is `np.float32`.
    # https://discuss.pytorch.org/t/mat1-and-mat2-must-have-the-same-dtype-but-got-double-and-float/197555/2
    ml.Cast(ml.numeric(), "float32"),
)
flights_rec.get_params()

will yield:

{'steps': (ExpandDate(cols(('date',)), components=['dow', 'month']),
  Drop(cols(('date',))),
  TargetEncode(nominal(), smooth=0.0),
  DropZeroVariance(everything(), tolerance=0.0001),
  MutateAt(cols(('dep_time',)), ((_.hour() * 60) + _.minute())),
  MutateAt(timestamp(), _.epoch_seconds()),
  Cast(numeric(), 'float32')),
 'expanddate': ExpandDate(cols(('date',)), components=['dow', 'month']),
 'drop': Drop(cols(('date',))),
 'targetencode': TargetEncode(nominal(), smooth=0.0),
 'dropzerovariance': DropZeroVariance(everything(), tolerance=0.0001),
 'mutateat-1': MutateAt(cols(('dep_time',)), ((_.hour() * 60) + _.minute())),
 'mutateat-2': MutateAt(timestamp(), _.epoch_seconds()),
 'cast': Cast(numeric(), 'float32'),
 'expanddate__components': ['dow', 'month'],
 'expanddate__inputs': cols(('date',)),
 'drop__inputs': cols(('date',)),
 'targetencode__inputs': nominal(),
 'targetencode__smooth': 0.0,
 'dropzerovariance__inputs': everything(),
 'dropzerovariance__tolerance': 0.0001,
 'mutateat-1__expr': ((_.hour() * 60) + _.minute()),
 'mutateat-1__inputs': cols(('dep_time',)),
 'mutateat-1__named_exprs': {},
 'mutateat-2__expr': _.epoch_seconds(),
 'mutateat-2__inputs': timestamp(),
 'mutateat-2__named_exprs': {},
 'cast__dtype': Float32(nullable=True),
 'cast__inputs': numeric()}
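Each flattened key maps straight to a nested parameter, so a single value can be pulled out of the dict directly:

params = flights_rec.get_params()
params["targetencode__smooth"]   # 0.0
params["mutateat-1__inputs"]     # cols(('dep_time',))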

Further down in the notebook, pipe.get_params() also works:

{'memory': None,
 'steps': [('flights_rec',
   Recipe(ExpandDate(cols(('date',)), components=['dow', 'month']),
          Drop(cols(('date',))),
          TargetEncode(nominal(), smooth=0.0),
          DropZeroVariance(everything(), tolerance=0.0001),
          MutateAt(cols(('dep_time',)), ((_.hour() * 60) + _.minute())),
          MutateAt(timestamp(), _.epoch_seconds()),
          Cast(numeric(), 'float32'))),
  ('mod',
   <class 'skorch.classifier.NeuralNetClassifier'>[uninitialized](
     module=<class '__main__.MyModule'>,
   ))],
 'verbose': False,
 'flights_rec': Recipe(ExpandDate(cols(('date',)), components=['dow', 'month']),
        Drop(cols(('date',))),
        TargetEncode(nominal(), smooth=0.0),
        DropZeroVariance(everything(), tolerance=0.0001),
        MutateAt(cols(('dep_time',)), ((_.hour() * 60) + _.minute())),
        MutateAt(timestamp(), _.epoch_seconds()),
        Cast(numeric(), 'float32')),
 'mod': <class 'skorch.classifier.NeuralNetClassifier'>[uninitialized](
   module=<class '__main__.MyModule'>,
 ),
 'flights_rec__steps': (ExpandDate(cols(('date',)), components=['dow', 'month']),
  Drop(cols(('date',))),
  TargetEncode(nominal(), smooth=0.0),
  DropZeroVariance(everything(), tolerance=0.0001),
  MutateAt(cols(('dep_time',)), ((_.hour() * 60) + _.minute())),
  MutateAt(timestamp(), _.epoch_seconds()),
  Cast(numeric(), 'float32')),
 'flights_rec__expanddate': ExpandDate(cols(('date',)), components=['dow', 'month']),
 'flights_rec__drop': Drop(cols(('date',))),
 'flights_rec__targetencode': TargetEncode(nominal(), smooth=0.0),
 'flights_rec__dropzerovariance': DropZeroVariance(everything(), tolerance=0.0001),
 'flights_rec__mutateat-1': MutateAt(cols(('dep_time',)), ((_.hour() * 60) + _.minute())),
 'flights_rec__mutateat-2': MutateAt(timestamp(), _.epoch_seconds()),
 'flights_rec__cast': Cast(numeric(), 'float32'),
 'flights_rec__expanddate__components': ['dow', 'month'],
 'flights_rec__expanddate__inputs': cols(('date',)),
 'flights_rec__drop__inputs': cols(('date',)),
 'flights_rec__targetencode__inputs': nominal(),
 'flights_rec__targetencode__smooth': 0.0,
 'flights_rec__dropzerovariance__inputs': everything(),
 'flights_rec__dropzerovariance__tolerance': 0.0001,
 'flights_rec__mutateat-1__expr': ((_.hour() * 60) + _.minute()),
 'flights_rec__mutateat-1__inputs': cols(('dep_time',)),
 'flights_rec__mutateat-1__named_exprs': {},
 'flights_rec__mutateat-2__expr': _.epoch_seconds(),
 'flights_rec__mutateat-2__inputs': timestamp(),
 'flights_rec__mutateat-2__named_exprs': {},
 'flights_rec__cast__dtype': Float32(nullable=True),
 'flights_rec__cast__inputs': numeric(),
 'mod__module': __main__.MyModule,
 'mod__criterion': torch.nn.modules.loss.NLLLoss,
 'mod__optimizer': torch.optim.sgd.SGD,
 'mod__lr': 0.1,
 'mod__max_epochs': 10,
 'mod__batch_size': 128,
 'mod__iterator_train': torch.utils.data.dataloader.DataLoader,
 'mod__iterator_valid': torch.utils.data.dataloader.DataLoader,
 'mod__dataset': skorch.dataset.Dataset,
 'mod__train_split': <skorch.dataset.ValidSplit object at 0x31cbccad0>,
 'mod__callbacks': None,
 'mod__predict_nonlinearity': 'auto',
 'mod__warm_start': False,
 'mod__verbose': 1,
 'mod__device': 'cpu',
 'mod__compile': False,
 'mod__use_caching': 'auto',
 'mod___params_to_validate': {'iterator_train__shuffle'},
 'mod__iterator_train__shuffle': True,
 'mod__classes': None,
 'mod__callbacks__epoch_timer': <skorch.callbacks.logging.EpochTimer at 0x15fc20690>,
 'mod__callbacks__train_loss': <skorch.callbacks.scoring.PassthroughScoring at 0x31c67ce50>,
 'mod__callbacks__train_loss__name': 'train_loss',
 'mod__callbacks__train_loss__lower_is_better': True,
 'mod__callbacks__train_loss__on_train': True,
 'mod__callbacks__valid_loss': <skorch.callbacks.scoring.PassthroughScoring at 0x31c7c47d0>,
 'mod__callbacks__valid_loss__name': 'valid_loss',
 'mod__callbacks__valid_loss__lower_is_better': True,
 'mod__callbacks__valid_loss__on_train': False,
 'mod__callbacks__valid_acc': <skorch.callbacks.scoring.EpochScoring at 0x1499f68d0>,
 'mod__callbacks__valid_acc__scoring': 'accuracy',
 'mod__callbacks__valid_acc__lower_is_better': False,
 'mod__callbacks__valid_acc__on_train': False,
 'mod__callbacks__valid_acc__name': 'valid_acc',
 'mod__callbacks__valid_acc__target_extractor': <function skorch.utils.to_numpy(X)>,
 'mod__callbacks__valid_acc__use_caching': True,
 'mod__callbacks__print_log': <skorch.callbacks.logging.PrintLog at 0x31cbbfc50>,
 'mod__callbacks__print_log__keys_ignored': None,
 'mod__callbacks__print_log__sink': <function print(*args, sep=' ', end='\n', file=None, flush=False)>,
 'mod__callbacks__print_log__tablefmt': 'simple',
 'mod__callbacks__print_log__floatfmt': '.4f',
 'mod__callbacks__print_log__stralign': 'right'}
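This flat, double-underscore naming is the same convention scikit-learn uses for nested estimators, which is what makes tools like GridSearchCV work. As a hypothetical sketch (it also requires set_params(), which per "Issues closed" below is still outstanding):

from sklearn.model_selection import GridSearchCV

# Hypothetical: the grid keys are exactly the get_params() keys shown above.
search = GridSearchCV(
    pipe,  # the notebook's Pipeline of ('flights_rec', ...) and ('mod', ...)
    param_grid={
        "flights_rec__targetencode__smooth": [0.0, 1.0],
        "mod__lr": [0.05, 0.1],
    },
)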

Issues closed

Partially addresses #135

Will probably leave it open until we address set_params(), too.

@@ -135,6 +135,11 @@ def __init__(
         self.expr = expr
         self.named_exprs = named_exprs
+
+    @classmethod
+    def _get_param_names(cls) -> list[str]:
@deepyaman (Collaborator, Author) commented:

I don't love that I have to override this, but I guess it's better to do this for now (rather than change the API to require specifying expressions in a dictionary).
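For reference, a minimal sketch of what that override might look like for a step like MutateAt, whose constructor takes variadic expressions (parameter names inferred from the get_params() output above, not necessarily the exact PR implementation):

@classmethod
def _get_param_names(cls) -> list[str]:
    # scikit-learn's default implementation inspects the __init__ signature
    # and rejects varargs, so list the parameters explicitly instead.
    return ["expr", "inputs", "named_exprs"]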

@deepyaman requested review from jcrist and jitingxu1 on August 30, 2024 07:34
@deepyaman self-assigned this on Aug 30, 2024
@codecov-commenter commented Aug 30, 2024

Codecov Report

Attention: Patch coverage is 89.28571% with 3 lines in your changes missing coverage. Please review.

Project coverage is 85.53%. Comparing base (aa71647) to head (0387595).

Files with missing lines     Patch %   Lines
ibis_ml/core.py              85.71%    2 Missing ⚠️
ibis_ml/steps/_common.py     85.71%    1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #145      +/-   ##
==========================================
+ Coverage   85.24%   85.53%   +0.29%     
==========================================
  Files          27       27              
  Lines        1938     1964      +26     
==========================================
+ Hits         1652     1680      +28     
+ Misses        286      284       -2     


@deepyaman marked this pull request as ready for review on August 30, 2024 07:45
ibis_ml/core.py Outdated
Comment on lines 363 to 380
# `hasattr()` always returns `True` for deferred objects
and not isinstance(value, (type, Deferred))
@deepyaman (Collaborator, Author) commented Aug 30, 2024:

Alternatively, I guess I can just get rid of this whole if block? A Step can't be nested in another Step, at least for now.

Probably better...

Lines 366–367 aren't currently being hit, and I don't see any reasonable way they would be.
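A quick illustration of the hasattr() gotcha the comment mentions (assuming only standard ibis Deferred behavior):

import ibis

d = ibis._.hour() * 60  # a Deferred expression, not a concrete value
# Attribute access on a Deferred never raises; it just builds a bigger
# deferred expression, so hasattr() reports True for any name.
hasattr(d, "definitely_not_a_real_attribute")  # True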

Comment on lines +324 to +347
if p.kind == p.VAR_POSITIONAL:
    raise RuntimeError(
        "scikit-learn estimators should always "
        "specify their parameters in the signature"
        " of their __init__ (no varargs)."
        f" {cls} with constructor {init_signature} doesn't "
        " follow this convention."
    )
@deepyaman (Collaborator, Author) commented:

This validation won't get hit unless you don't override _get_param_names for a step with varargs in the constructor (e.g. Mutate or MutateAt). I suppose I can add a test for a malformed custom step?
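A hypothetical test along those lines (BadStep is illustrative, and it assumes the Step base class exported from ibis_ml.core):

import pytest
from ibis_ml.core import Step

class BadStep(Step):
    def __init__(self, *exprs):  # varargs and no _get_param_names override
        self.exprs = exprs

def test_varargs_step_rejected():
    with pytest.raises(RuntimeError, match="no varargs"):
        BadStep().get_params()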

@jitingxu1 (Collaborator) left a comment:

It looks great to me. Feel free to merge it on your own.

@deepyaman force-pushed the feat/core/step-get-params branch from 08deb9d to 6c8c815 on September 13, 2024 23:34
@deepyaman force-pushed the feat/core/step-get-params branch from 5e8c761 to 0387595 on September 13, 2024 23:59
@deepyaman merged commit 1c8706e into ibis-project:main on Sep 14, 2024 (4 checks passed)
@deepyaman deleted the feat/core/step-get-params branch on September 14, 2024 00:39