Skip to content

Commit

Permalink
fix(sklearn): add return_std parameter in VarianceEstimator (#101)
Browse files Browse the repository at this point in the history
  • Loading branch information
34j authored Oct 24, 2023
1 parent f168239 commit d7c1625
Showing 1 changed file with 41 additions and 1 deletion.
42 changes: 41 additions & 1 deletion src/boost_loss/regression/sklearn.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from copy import copy
from typing import Any, Literal, Sequence
from typing import Any, Literal, Sequence, overload

import numpy as np
from joblib import Parallel, delayed
Expand Down Expand Up @@ -184,13 +184,39 @@ def predict_raw(self, X: Any, **predict_params: Any) -> NDArray[Any]:
[estimator.predict(X, **predict_params) for estimator in self.estimators_]
)

@overload
def predict(
self,
X: Any,
type_: Literal["mean", "median", "var", "std", "range", "mae", "mse"]
| None = None,
return_std: Literal[False] = False,
**predict_params: Any,
) -> NDArray[Any]:
...

@overload
def predict(
self,
X: Any,
type_: tuple[
Literal["mean", "median"], Literal["var", "std", "range", "mae", "mse"]
]
| None = None,
return_std: Literal[True] = ...,
**predict_params: Any,
) -> tuple[NDArray[Any], NDArray[Any]]:
...

def predict(
self,
X: Any,
type_: Literal["mean", "median", "var", "std", "range", "mae", "mse"]
| tuple[Literal["mean", "median"], Literal["var", "std", "range", "mae", "mse"]]
| None = None,
return_std: bool = False,
**predict_params: Any,
) -> NDArray[Any] | tuple[NDArray[Any], NDArray[Any]]:
"""Returns predictions of the ensemble.
Parameters
Expand All @@ -200,6 +226,9 @@ def predict(
type_ : Literal['mean', 'median', 'var', 'std', 'range', 'mae', 'mse'], optional
Type of the prediction, by default None
If None, self.m_type is used.
return_std : bool, optional
Whether to return a tuple of (predictions, standard deviation),
by default False
**predict_params : Any
The parameters to be passed to `predict` method of each estimator.
Expand All @@ -213,6 +242,17 @@ def predict(
ValueError
When type_ is not supported.
"""
if return_std or isinstance(type_, tuple):
if isinstance(type_, str):
type_tuple = (type_, self.var_type)
elif type_ is None:
type_tuple = (self.m_type, self.var_type)
else:
type_tuple = type_
return self.predict(
X, type_=type_tuple[0], **predict_params
), self.predict_var(X, type_=type_tuple[1], **predict_params)

type_ = type_ or self.m_type
if type_ == "mean":
return self.predict_raw(X, **predict_params).mean(axis=0)
Expand Down

0 comments on commit d7c1625

Please sign in to comment.