diff --git a/eli5/permutation_importance.py b/eli5/permutation_importance.py index b5c4a3f0..b3cdad7d 100644 --- a/eli5/permutation_importance.py +++ b/eli5/permutation_importance.py @@ -15,7 +15,7 @@ import numpy as np # type: ignore from sklearn.utils import check_random_state # type: ignore - +from multiprocess import Pool # type: ignore def iter_shuffled(X, columns_to_shuffle=None, pre_shuffle=False, random_state=None): @@ -58,7 +58,8 @@ def get_score_importances( y, n_iter=5, # type: int columns_to_shuffle=None, - random_state=None + random_state=None, + n_jobs=1 ): # type: (...) -> Tuple[float, List[np.ndarray]] """ @@ -84,12 +85,15 @@ def get_score_importances( """ rng = check_random_state(random_state) base_score = score_func(X, y) + seed0 = rng.randint(2**32) + pool = Pool(n_jobs) + result = pool.map( + lambda seed: _get_scores_shufled(score_func, X, y, + columns_to_shuffle=columns_to_shuffle, + random_state=np.random.RandomState(seed)), + range(seed0, seed0+n_iter)) scores_decreases = [] - for i in range(n_iter): - scores_shuffled = _get_scores_shufled( - score_func, X, y, columns_to_shuffle=columns_to_shuffle, - random_state=rng - ) + for scores_shuffled in result: scores_decreases.append(-scores_shuffled + base_score) return base_score, scores_decreases diff --git a/eli5/sklearn/permutation_importance.py b/eli5/sklearn/permutation_importance.py index 30ab3cad..dfa08acf 100644 --- a/eli5/sklearn/permutation_importance.py +++ b/eli5/sklearn/permutation_importance.py @@ -117,6 +117,8 @@ class PermutationImportance(BaseEstimator, MetaEstimatorMixin): Whether to fit the estimator on the whole data if cross-validation is used (default is True). + n_jobs : int, number of parallel jobs for shuffle iterations + Attributes ---------- feature_importances_ : array @@ -142,7 +144,7 @@ class PermutationImportance(BaseEstimator, MetaEstimatorMixin): random state """ def __init__(self, estimator, scoring=None, n_iter=5, random_state=None, - cv='prefit', refit=True): + cv='prefit', refit=True, n_jobs=1): # type: (...) -> None if isinstance(cv, str) and cv != "prefit": raise ValueError("Invalid cv value: {!r}".format(cv)) @@ -152,6 +154,7 @@ def __init__(self, estimator, scoring=None, n_iter=5, random_state=None, self.n_iter = n_iter self.random_state = random_state self.cv = cv + self.n_jobs = n_jobs self.rng_ = check_random_state(random_state) def _wrap_scorer(self, base_scorer, pd_columns): @@ -228,7 +231,7 @@ def _non_cv_scores_importances(self, X, y): def _get_score_importances(self, score_func, X, y): return get_score_importances(score_func, X, y, n_iter=self.n_iter, - random_state=self.rng_) + random_state=self.rng_, n_jobs=self.n_jobs) @property def caveats_(self): diff --git a/requirements.txt b/requirements.txt index ca97e5d2..b70181b1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ attrs > 16.0.0 jinja2 pip >= 8.1 setuptools >= 20.7 +multiprocess diff --git a/setup.py b/setup.py index 011fdd28..ab430f25 100755 --- a/setup.py +++ b/setup.py @@ -40,6 +40,7 @@ def get_long_description(): 'typing', 'graphviz', 'tabulate>=0.7.7', + 'multiprocess', ], extras_require={ ":python_version<'3.5.6'": [ diff --git a/tests/test_permutation_importance.py b/tests/test_permutation_importance.py index effb4ff9..05567e0f 100644 --- a/tests/test_permutation_importance.py +++ b/tests/test_permutation_importance.py @@ -42,10 +42,11 @@ def is_shuffled(X, X_sh, col): def test_get_feature_importances(boston_train): X, y, feat_names = boston_train svr = SVR(C=20).fit(X, y) - score, importances = get_score_importances(svr.score, X, y) - assert score > 0.7 - importances = dict(zip(feat_names, np.mean(importances, axis=0))) - print(score) - print(importances) - assert importances['AGE'] > importances['NOX'] - assert importances['B'] > importances['CHAS'] + for n_jobs in [1, 2]: + score, importances = get_score_importances(svr.score, X, y, n_jobs=n_jobs) + assert score > 0.7 + importances = dict(zip(feat_names, np.mean(importances, axis=0))) + print(score) + print(importances) + assert importances['AGE'] > importances['NOX'] + assert importances['B'] > importances['CHAS']