diff --git a/examples/plot_01_survival_analysis.py b/examples/plot_01_survival_analysis.py
index 316bbfd..8ebb7f4 100644
--- a/examples/plot_01_survival_analysis.py
+++ b/examples/plot_01_survival_analysis.py
@@ -41,8 +41,6 @@
 # random variable :math:`T`, and the censoring date, represented by :math:`C`.
 #
 # In this dataset, approximately 42% of the data is censored..
-
-# %%
 y["event"].value_counts(normalize=True)
 
 # %%
@@ -71,7 +69,6 @@
 from hazardous import SurvivalBoost
 
 survival_boost = SurvivalBoost(show_progressbar=False).fit(X_train, y_train)
-
 survival_boost
 
 # %%
@@ -94,34 +91,70 @@
 # Let's plot the estimated survival function for some patients.
 
 import matplotlib.pyplot as plt
 
-fig, ax = plt.subplots()
+
+def plot_survival_curves(patient_ids_to_plot, time_grid, survival_curves):
+    fig, ax = plt.subplots()
+
+    for idx in patient_ids_to_plot:
+        ax.plot(time_grid, survival_curves[idx], label=f"Patient {idx}")
+
+        # plot symbols for death or censoring
+        event = y_test.iloc[idx]["event"]
+        duration = y_test.iloc[idx]["duration"]
+
+        # find the index of time closest to duration
+        jdx = np.searchsorted(time_grid, duration)
+        smiley = "☠️" if event == 1 else "✖"
+        ax.text(
+            duration,
+            survival_curves[idx, jdx],
+            smiley,
+            fontsize=20,
+            color=ax.lines[idx].get_color(),
+        )
+
+    ax.legend()
+    ax.set_title("")
+    ax.set_xlabel("Months")
+    ax.set_ylabel("Predicted Survival Probability")
+
+    plt.show()
+
 
 patient_ids_to_plot = [0, 1, 2, 3]
+plot_survival_curves(
+    patient_ids_to_plot,
+    survival_boost.time_grid_,
+    survival_curves,
+)
 
-for idx in patient_ids_to_plot:
-    ax.plot(survival_boost.time_grid_, survival_curves[idx], label=f"Patient {idx}")
-
-    # plot symbols for death or censoring
-    event = y_test.iloc[idx]["event"]
-    duration = y_test.iloc[idx]["duration"]
-
-    # find the index of time closest to duration
-    jdx = np.searchsorted(survival_boost.time_grid_, duration)
-    smiley = "☠️" if event == 1 else "✖"
-    ax.text(
-        duration,
-        survival_curves[idx, jdx],
-        smiley,
-        fontsize=20,
-        color=ax.lines[idx].get_color(),
-    )
+# %%
+# Bagging for curves smoothing
+# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#
+# Bagging can help us smooth our survival and incidence curves, at the cost of
+# fitting ``SurvivalBoost`` multiple times.
 
-ax.legend()
-ax.set_title("")
-ax.set_xlabel("Months")
-ax.set_ylabel("Predicted Survival Probability")
+from hazardous import BaggingSurvival
-plt.show()
+
+bagging_est = BaggingSurvival(
+    survival_boost,
+    n_estimators=5,
+    bootstrap=False,
+).fit(X, y)
+
+smooth_curves = bagging_est.predict_cumulative_incidence(
+    X_test,
+    times=None,
+)
+smooth_survival_curves = smooth_curves[:, 0]  # survival function S(t)
+
+plot_survival_curves(
+    patient_ids_to_plot,
+    bagging_est.time_grid_,
+    smooth_survival_curves,
+)
 
 # %%
 #
diff --git a/hazardous/__init__.py b/hazardous/__init__.py
index dac08e0..f591e46 100644
--- a/hazardous/__init__.py
+++ b/hazardous/__init__.py
@@ -1,5 +1,6 @@
 from pathlib import Path
 
+from ._bagging import BaggingSurvival
 from ._survival_boost import SurvivalBoost
 
 with open(Path(__file__).parent / "VERSION.txt") as _fh:
@@ -9,4 +10,5 @@
 __all__ = [
     "metrics",
     "SurvivalBoost",
+    "BaggingSurvival",
 ]
diff --git a/hazardous/_bagging.py b/hazardous/_bagging.py
new file mode 100644
index 0000000..f95a0d2
--- /dev/null
+++ b/hazardous/_bagging.py
@@ -0,0 +1,271 @@
+from copy import deepcopy
+from warnings import warn
+
+import numpy as np
+from joblib import Parallel, delayed, effective_n_jobs
+from sklearn.base import check_array, check_is_fitted
+from sklearn.ensemble._bagging import BaseBagging
+from sklearn.utils._param_validation import HasMethods
+
+from ._survival_boost import SurvivalBoost
+from .base import SurvivalMixin
+from .metrics import mean_integrated_brier_score
+from .utils import (
+    _dict_to_recarray,
+    check_y_survival,
+    get_unique_events,
+    make_time_grid,
+)
+
+
+class BaggingSurvival(BaseBagging, SurvivalMixin):
+    """TODO"""
+
+    _parameter_constraints = deepcopy(BaseBagging._parameter_constraints)
+    _parameter_constraints["estimator"] = [
+        HasMethods(["fit", "score", "predict_cumulative_incidence"])
+    ]
+
+    def __init__(
+        self,
+        estimator=None,
+        n_estimators=3,
+        *,
+        max_samples=1.0,
+        max_features=1.0,
+        bootstrap=True,
+        bootstrap_features=False,
+        oob_score=False,
+        warm_start=False,
+        n_jobs=None,
+        random_state=None,
+        verbose=0,
+    ):
+        super().__init__(
+            estimator=estimator,
+            n_estimators=n_estimators,
+            max_samples=max_samples,
+            max_features=max_features,
+            bootstrap=bootstrap,
+            bootstrap_features=bootstrap_features,
+            oob_score=oob_score,
+            warm_start=warm_start,
+            n_jobs=n_jobs,
+            random_state=random_state,
+            verbose=verbose,
+        )
+
+    def _get_estimator(self):
+        """Resolve which estimator to return"""
+        if self.estimator is None:
+            return SurvivalBoost(show_progressbar=False)
+        return self.estimator
+
+    def _set_oob_score(self, X, y):
+        n_samples = y.shape[0]
+        n_events_ = self.n_events_
+        n_time_steps_ = self.time_grid_.shape[0]
+
+        y_pred = np.zeros((n_samples, n_events_, n_time_steps_))
+
+        for estimator, samples, features in zip(
+            self.estimators_, self.estimators_samples_, self.estimators_features_
+        ):
+            # Create mask for OOB samples
+            mask = ~indices_to_mask(samples, n_samples)
+
+            y_pred[mask, :] += estimator.predict_proba((X[mask, :])[:, features])
+
+        if (y_pred.sum(axis=(1, 2)) == 0).any():
+            warn(
+                "Some inputs do not have OOB scores. "
+                "This probably means too few estimators were used "
+                "to compute any reliable oob estimates."
+            )
+
+        self.oob_score_ = -mean_integrated_brier_score(
+            y_train=self.weighted_targets_.y_train,
+            y_test=y,
+            y_pred=y_pred,
+            time_grid=self.time_grid_,
+        )
+
+    def _validate_y(self, y):
+        event, duration = check_y_survival(y)
+        self.event_ids_ = get_unique_events(event)
+        self.n_events_ = len(self.event_ids_)
+
+        base_estimator = self._get_estimator()
+        self.time_grid_ = make_time_grid(
+            event,
+            duration,
+            base_estimator.n_time_grid_steps,
+        )
+        self.y_train_ = y  # XXX: Used by SurvivalMixin.score()
+        self.time_horizon_ = base_estimator.time_horizon
+
+        return y
+
+    def fit(self, X, y, **fit_params):
+        y = _dict_to_recarray(y)
+        return super().fit(X, y, **fit_params)
+
+    def predict_cumulative_incidence(self, X, times=None):
+        """TODO"""
+        check_is_fitted(self)
+
+        # Check data
+        X = check_array(X, force_all_finite="allow-nan")
+
+        # Parallel loop
+        n_jobs, _, starts = _partition_estimators(self.n_estimators, self.n_jobs)
+
+        # Get time grid
+        times = times or self.time_grid_
+
+        all_proba = Parallel(
+            n_jobs=n_jobs, verbose=self.verbose, **self._parallel_args()
+        )(
+            delayed(_parallel_predict_cumulative_incidence)(
+                self.estimators_[starts[i] : starts[i + 1]],
+                self.estimators_features_[starts[i] : starts[i + 1]],
+                X,
+                times,
+                self.n_events_,
+            )
+            for i in range(n_jobs)
+        )
+
+        # Reduce
+        proba = sum(all_proba) / self.n_estimators
+
+        return proba
+
+    def predict_survival_function(self, X, times=None):
+        return self.predict_cumulative_incidence(X, times=times)[:, 0, :]
+
+    def predict_proba(self, X, time_horizon=None):
+        """TODO"""
+        check_is_fitted(self)
+
+        # Check data
+        X = check_array(X, force_all_finite="allow-nan")
+
+        # Parallel loop
+        n_jobs, _, starts = _partition_estimators(self.n_estimators, self.n_jobs)
+
+        # Get time grid
+        time_horizon = time_horizon or self.time_horizon_
+
+        all_proba = Parallel(
+            n_jobs=n_jobs, verbose=self.verbose, **self._parallel_args()
+        )(
+            delayed(_parallel_predict_proba)(
+                self.estimators_[starts[i] : starts[i + 1]],
+                self.estimators_features_[starts[i] : starts[i + 1]],
+                X,
+                time_horizon,
+                self.n_events_,
+            )
+            for i in range(n_jobs)
+        )
+
+        # Reduce
+        proba = sum(all_proba) / self.n_estimators
+
+        return proba
+
+
+def _parallel_predict_cumulative_incidence(
+    estimators,
+    estimators_features,
+    X,
+    times,
+    n_events,
+):
+    """Private function used to compute (proba-)predictions within a job."""
+    n_samples = X.shape[0]
+    n_time_steps = times.shape[0]
+    proba = np.zeros((n_samples, n_events, n_time_steps))
+
+    for estimator, features in zip(estimators, estimators_features):
+        proba_estimator = estimator.predict_cumulative_incidence(
+            X[:, features], times=times
+        )
+
+        if n_events == len(estimator.event_ids_):
+            proba += proba_estimator
+
+        else:
+            proba[:, estimator.event_ids_] += proba_estimator[
+                :, range(len(estimator.event_ids_))
+            ]
+
+    return proba
+
+
+def _parallel_predict_proba(
+    estimators,
+    estimators_features,
+    X,
+    time_horizon,
+    n_events,
+):
+    """Private function used to compute (proba-)predictions within a job."""
+    n_samples = X.shape[0]
+    proba = np.zeros((n_samples, n_events))
+
+    for estimator, features in zip(estimators, estimators_features):
+        proba_estimator = estimator.predict_proba(
+            X[:, features], time_horizon=time_horizon
+        )
+
+        if n_events == len(estimator.event_ids_):
+            proba += proba_estimator
+
+        else:
+            proba[:, estimator.event_ids_] += proba_estimator[
+                :, range(len(estimator.event_ids_))
+            ]
+
+    return proba
+
+
+# Vendored from a private module in sklearn.
+def indices_to_mask(indices, mask_length):
+    """Convert list of indices to boolean mask.
+
+    Parameters
+    ----------
+    indices : list-like
+        List of integers treated as indices.
+    mask_length : int
+        Length of boolean mask to be generated.
+        This parameter must be greater than max(indices).
+
+    Returns
+    -------
+    mask : 1d boolean nd-array
+        Boolean array that is True where indices are present, else False.
+    """
+    if mask_length <= np.max(indices):
+        raise ValueError("mask_length must be greater than max(indices)")
+
+    mask = np.zeros(mask_length, dtype=bool)
+    mask[indices] = True
+
+    return mask
+
+
+# Vendored from a private module in sklearn.
+def _partition_estimators(n_estimators, n_jobs):
+    """Private function used to partition estimators between jobs."""
+    # Compute the number of jobs
+    n_jobs = min(effective_n_jobs(n_jobs), n_estimators)
+
+    # Partition estimators between jobs
+    n_estimators_per_job = np.full(n_jobs, n_estimators // n_jobs, dtype=int)
+    n_estimators_per_job[: n_estimators % n_jobs] += 1
+    starts = np.cumsum(n_estimators_per_job)
+
+    return n_jobs, n_estimators_per_job.tolist(), [0] + starts.tolist()
diff --git a/hazardous/_survival_boost.py b/hazardous/_survival_boost.py
index 0400806..cd66d19 100644
--- a/hazardous/_survival_boost.py
+++ b/hazardous/_survival_boost.py
@@ -1,18 +1,15 @@
 from numbers import Real
 
 import numpy as np
-from sklearn.base import BaseEstimator, ClassifierMixin
+from sklearn.base import BaseEstimator
 from sklearn.ensemble import HistGradientBoostingClassifier
 from sklearn.utils.validation import check_array, check_random_state
 from tqdm import tqdm
 
 from ._ipcw import AlternatingCensoringEstimator, KaplanMeierIPCW
-from .metrics._brier_score import (
-    IncidenceScoreComputer,
-    integrated_brier_score_incidence,
-    integrated_brier_score_survival,
-)
-from .utils import check_y_survival
+from .base import SurvivalMixin
+from .metrics._brier_score import IncidenceScoreComputer
+from .utils import check_y_survival, get_unique_events, make_time_grid
 
 
 class WeightedMultiClassTargetSampler(IncidenceScoreComputer):
@@ -168,7 +165,7 @@ def fit(self, X):
        )
 
 
-class SurvivalBoost(BaseEstimator, ClassifierMixin):
+class SurvivalBoost(BaseEstimator, SurvivalMixin):
     r"""Cause-specific Cumulative Incidence Function (CIF) with GBDT [1]_.
 
     This model estimates the cause-specific Cumulative Incidence Function (CIF) for
@@ -381,45 +378,18 @@ def fit(self, X, y, times=None):
         event, duration = check_y_survival(y)
 
         # Add 0 as a special event id for the survival function.
-        self.event_ids_ = np.array(sorted(list(set([0]) | set(event))))
-
-        self.estimator_ = self._build_base_estimator()
+        self.event_ids_ = get_unique_events(event)
 
         # Compute the default time grid used at prediction time.
-        any_event_mask = event > 0
-        observed_times = duration[any_event_mask]
-
         if times is None:
-            if observed_times.shape[0] > self.n_time_grid_steps:
-                self.time_grid_ = np.quantile(
-                    observed_times, np.linspace(0, 1, num=self.n_time_grid_steps)
-                )
-            else:
-                self.time_grid_ = observed_times.copy()
-                self.time_grid_.sort()
+            self.time_grid_ = make_time_grid(event, duration, self.n_time_grid_steps)
         else:
             self.time_grid_ = times.copy()
             self.time_grid_.sort()
 
-        if self.ipcw_strategy == "alternating":
-            ipcw_estimator = AlternatingCensoringEstimator(
-                incidence_estimator=self.estimator_
-            )
-        elif self.ipcw_strategy == "kaplan-meier":
-            ipcw_estimator = KaplanMeierIPCW()
-        else:
-            raise ValueError(
-                f"Invalid parameter value: ipcw_strategy={self.ipcw_strategy!r}. "
-                "Valid values are 'alternating' and 'kaplan-meier'."
-            )
+        self.estimator_ = self._build_base_estimator()
 
-        self.weighted_targets_ = WeightedMultiClassTargetSampler(
-            y,
-            hard_zero_fraction=self.hard_zero_fraction,
-            random_state=self.random_state,
-            ipcw_estimator=ipcw_estimator,
-            n_iter_before_feedback=self.n_iter_before_feedback,
-        )
+        self.weighted_targets_ = self._check_target_sampling(y)
 
         iterator = range(self.n_iter)
         if self.show_progressbar:
@@ -453,7 +423,7 @@
            )
 
            if (idx_iter % self.n_iter_before_feedback == 0) and isinstance(
-                ipcw_estimator, AlternatingCensoringEstimator
+                self.weighted_targets_.ipcw_estimator, AlternatingCensoringEstimator
            ):
                self.weighted_targets_.fit(X)
 
@@ -584,52 +554,26 @@
            min_samples_leaf=self.min_samples_leaf,
        )
 
-    def score(self, X, y):
-        """Return the mean of IBS for each event of interest and survival.
-
-        This returns the negative of the mean of the Integrated Brier Score
-        (IBS, a proper scoring rule) of each competing event as well as the IBS
-        of the survival to any event. So, the higher the value, the better the
-        model to be consistent with the scoring convention of scikit-learn to
-        make it possible to use this class with scikit-learn model selection
-        utilities such as ``GridSearchCV`` and ``RandomizedSearchCV``.
-
-        Parameters
-        ----------
-        X : array-like of shape (n_samples, n_features)
-            The input samples.
-        y : dict with keys "event" and "duration"
-            The target values. "event" is a boolean array of shape (n_samples,)
-            indicating whether the event was observed or not. "duration" is a
-            float array of shape (n_samples,) indicating the time of the event
-            or the time of censoring.
+    def _check_target_sampling(self, y):
+        if self.ipcw_strategy == "alternating":
+            ipcw_estimator = AlternatingCensoringEstimator(
+                incidence_estimator=self.estimator_
+            )
+        elif self.ipcw_strategy == "kaplan-meier":
+            ipcw_estimator = KaplanMeierIPCW()
+        else:
+            raise ValueError(
+                f"Invalid parameter value: ipcw_strategy={self.ipcw_strategy!r}. "
+                "Valid values are 'alternating' and 'kaplan-meier'."
+            )
 
-        Returns
-        -------
-        score : float
-            The negative of time-integrated Brier score (IBS).
+        weighted_targets = WeightedMultiClassTargetSampler(
+            y,
+            hard_zero_fraction=self.hard_zero_fraction,
+            random_state=self.random_state,
+            ipcw_estimator=ipcw_estimator,
+            n_iter_before_feedback=self.n_iter_before_feedback,
+        )
+        self.y_train_ = y
 
-        TODO: implement time integrated NLL and use as the default for the
-        .score method to match the objective function used at fit time.
-        """
-        predicted_curves = self.predict_cumulative_incidence(X)
-        ibs_events = []
-        for event_idx in self.event_ids_:
-            predicted_curves_for_event = predicted_curves[:, event_idx]
-            if event_idx == 0:
-                ibs_event = integrated_brier_score_survival(
-                    y_train=self.weighted_targets_.y_train,
-                    y_test=y,
-                    y_pred=predicted_curves_for_event,
-                    times=self.time_grid_,
-                )
-            else:
-                ibs_event = integrated_brier_score_incidence(
-                    y_train=self.weighted_targets_.y_train,
-                    y_test=y,
-                    y_pred=predicted_curves_for_event,
-                    times=self.time_grid_,
-                    event_of_interest=event_idx,
-                )
-            ibs_events.append(ibs_event)
-        return -np.mean(ibs_events)
+        return weighted_targets
diff --git a/hazardous/base.py b/hazardous/base.py
new file mode 100644
index 0000000..b052060
--- /dev/null
+++ b/hazardous/base.py
@@ -0,0 +1,44 @@
+from .metrics._brier_score import mean_integrated_brier_score
+
+
+class SurvivalMixin:
+    _estimator_type = "survival"
+
+    def score(self, X, y):
+        """Return the mean of IBS for each event of interest and survival.
+
+        This returns the negative of the mean of the Integrated Brier Score
+        (IBS, a proper scoring rule) of each competing event as well as the IBS
+        of the survival to any event. So, the higher the value, the better the
+        model to be consistent with the scoring convention of scikit-learn to
+        make it possible to use this class with scikit-learn model selection
+        utilities such as ``GridSearchCV`` and ``RandomizedSearchCV``.
+
+        Parameters
+        ----------
+        X : array-like of shape (n_samples, n_features)
+            The input samples.
+        y : dict with keys "event" and "duration"
+            The target values. "event" is a boolean array of shape (n_samples,)
+            indicating whether the event was observed or not. "duration" is a
+            float array of shape (n_samples,) indicating the time of the event
+            or the time of censoring.
+
+        Returns
+        -------
+        score : float
+            The negative of time-integrated Brier score (IBS).
+
+        TODO: implement time integrated NLL and use as the default for the
+        .score method to match the objective function used at fit time.
+        """
+        y_pred = self.predict_cumulative_incidence(X)
+        return -mean_integrated_brier_score(
+            y_train=self.y_train_,
+            y_test=y,
+            y_pred=y_pred,
+            time_grid=self.time_grid_,
+        )
+
+    def __sklearn_tags__(self):
+        return {"requires_y": True}
diff --git a/hazardous/metrics/__init__.py b/hazardous/metrics/__init__.py
index b7978de..ad5ab1f 100644
--- a/hazardous/metrics/__init__.py
+++ b/hazardous/metrics/__init__.py
@@ -3,6 +3,7 @@
     brier_score_survival,
     integrated_brier_score_incidence,
     integrated_brier_score_survival,
+    mean_integrated_brier_score,
 )
 
 __all__ = [
@@ -10,4 +11,5 @@
     "brier_score_incidence",
     "integrated_brier_score_survival",
     "integrated_brier_score_incidence",
+    "mean_integrated_brier_score",
 ]
diff --git a/hazardous/metrics/_brier_score.py b/hazardous/metrics/_brier_score.py
index f573f82..0d03f78 100644
--- a/hazardous/metrics/_brier_score.py
+++ b/hazardous/metrics/_brier_score.py
@@ -572,3 +572,29 @@
         event_of_interest=event_of_interest,
     )
     return computer.integrated_brier_score_incidence(y_test, y_pred, times)
+
+
+def mean_integrated_brier_score(y_train, y_test, y_pred, time_grid):
+    if y_pred.ndim != 3:
+        raise ValueError(f"y_pred must be 3D, got shape: {y_pred.shape}")
+
+    ibs_events = []
+    for event_idx in range(y_pred.shape[1]):
+        if event_idx == 0:
+            ibs_event = integrated_brier_score_survival(
+                y_train=y_train,
+                y_test=y_test,
+                y_pred=y_pred[:, event_idx],
+                times=time_grid,
+            )
+        else:
+            ibs_event = integrated_brier_score_incidence(
+                y_train=y_train,
+                y_test=y_test,
+                y_pred=y_pred[:, event_idx],
+                times=time_grid,
+                event_of_interest=event_idx,
+            )
+        ibs_events.append(ibs_event)
+
+    return np.mean(ibs_events)
diff --git a/hazardous/utils.py b/hazardous/utils.py
index 2c74e9b..7f9d249 100644
--- a/hazardous/utils.py
+++ b/hazardous/utils.py
@@ -16,6 +16,7 @@ def _dict_to_recarray(y, cast_event_to_bool=False):
     )
     y_out["event"] = y["event"]
     y_out["duration"] = y["duration"]
+
     return y_out
 
 
@@ -53,3 +54,49 @@ def check_event_of_interest(k):
            f"got: event_of_interest={k}"
        )
    return
+
+
+def get_unique_events(event):
+    """Get the unique events, including censoring 0.
+
+    Parameters
+    ----------
+    event : array of shape (n_samples,)
+
+    Returns
+    -------
+    unique_event : array of shape (n_unique_event,)
+    """
+    return np.array(sorted(list(set([0]) | set(event))))
+
+
+def make_time_grid(event, duration, n_time_grid_steps):
+    """Compute a time grid on observed events.
+
+    The time grid size is the minimum between ``n_time_grid_steps`` and
+    the number of unique observed durations.
+
+    Parameters
+    ----------
+    event : array of shape (n_samples,)
+    duration : array of shape (n_samples,)
+    n_time_grid_steps : int
+
+    Returns
+    -------
+    time_grid : array of shape (n_time_steps,)
+        Note that n_time_steps <= n_time_grid_steps
+    """
+
+    any_event_mask = event > 0
+    observed_times = duration[any_event_mask]
+
+    if observed_times.shape[0] > n_time_grid_steps:
+        time_grid = np.quantile(
+            observed_times, np.linspace(0, 1, num=n_time_grid_steps)
+        )
+    else:
+        time_grid = observed_times.copy()
+        time_grid.sort()
+
+    return time_grid
diff --git a/pyproject.toml b/pyproject.toml
index a7cc94c..dd22945 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -37,7 +37,7 @@ test = [
 ]
 oldest_deps = [
     "numpy==1.22",
-    "scikit-learn==1.1.3",
+    "scikit-learn==1.3.2",
     "pandas==1.5.1",
 ]
 doc = [