From d7c16255483f87fe0f81b84bb393a0aefa4774d8 Mon Sep 17 00:00:00 2001 From: 34j <55338215+34j@users.noreply.github.com> Date: Tue, 24 Oct 2023 09:41:09 +0900 Subject: [PATCH] fix(sklearn): add `return_std` parameter in `VarianceEstimator` (#101) --- src/boost_loss/regression/sklearn.py | 42 +++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/src/boost_loss/regression/sklearn.py b/src/boost_loss/regression/sklearn.py index d645d8a..3fa95eb 100644 --- a/src/boost_loss/regression/sklearn.py +++ b/src/boost_loss/regression/sklearn.py @@ -1,7 +1,7 @@ from __future__ import annotations from copy import copy -from typing import Any, Literal, Sequence +from typing import Any, Literal, Sequence, overload import numpy as np from joblib import Parallel, delayed @@ -184,13 +184,39 @@ def predict_raw(self, X: Any, **predict_params: Any) -> NDArray[Any]: [estimator.predict(X, **predict_params) for estimator in self.estimators_] ) + @overload def predict( self, X: Any, type_: Literal["mean", "median", "var", "std", "range", "mae", "mse"] | None = None, + return_std: Literal[False] = False, **predict_params: Any, ) -> NDArray[Any]: + ... + + @overload + def predict( + self, + X: Any, + type_: tuple[ + Literal["mean", "median"], Literal["var", "std", "range", "mae", "mse"] + ] + | None = None, + return_std: Literal[True] = ..., + **predict_params: Any, + ) -> tuple[NDArray[Any], NDArray[Any]]: + ... + + def predict( + self, + X: Any, + type_: Literal["mean", "median", "var", "std", "range", "mae", "mse"] + | tuple[Literal["mean", "median"], Literal["var", "std", "range", "mae", "mse"]] + | None = None, + return_std: bool = False, + **predict_params: Any, + ) -> NDArray[Any] | tuple[NDArray[Any], NDArray[Any]]: """Returns predictions of the ensemble. Parameters @@ -200,6 +226,9 @@ def predict( type_ : Literal['mean', 'median', 'var', 'std', 'range', 'mae', 'mse'], optional Type of the prediction, by default None If None, self.m_type is used. + return_std : bool, optional + Whether to return a tuple of (predictions, standard deviation), + by default False **predict_params : Any The parameters to be passed to `predict` method of each estimator. @@ -213,6 +242,17 @@ def predict( ValueError When type_ is not supported. """ + if return_std or isinstance(type_, tuple): + if isinstance(type_, str): + type_tuple = (type_, self.var_type) + elif type_ is None: + type_tuple = (self.m_type, self.var_type) + else: + type_tuple = type_ + return self.predict( + X, type_=type_tuple[0], **predict_params + ), self.predict_var(X, type_=type_tuple[1], **predict_params) + type_ = type_ or self.m_type if type_ == "mean": return self.predict_raw(X, **predict_params).mean(axis=0)