From bf4789b5c84bce174490dbc3aa998008f6586f84 Mon Sep 17 00:00:00 2001 From: zm711 <92116279+zm711@users.noreply.github.com> Date: Sun, 17 Sep 2023 10:01:34 -0400 Subject: [PATCH] np.in1d -> np.isin --- phy/apps/base.py | 2 +- phy/cluster/clustering.py | 2 +- phy/cluster/tests/test_clustering.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/phy/apps/base.py b/phy/apps/base.py index 0425ddc3f..8a1698928 100644 --- a/phy/apps/base.py +++ b/phy/apps/base.py @@ -160,7 +160,7 @@ def get_spike_raw_amplitudes(self, spike_ids, channel_id=None, **kwargs): # The cluster assignments of the requested spikes. spike_clusters = self.supervisor.clustering.spike_clusters[spike_ids] # Only keep spikes from clusters on the "best" channel. - to_keep = np.in1d(spike_clusters, self.get_clusters_on_channel(channel_id)) + to_keep = np.isin(spike_clusters, self.get_clusters_on_channel(channel_id)) waveforms = self.model.get_waveforms(spike_ids[to_keep], [channel_id]) if waveforms is not None: waveforms = waveforms[..., 0] diff --git a/phy/cluster/clustering.py b/phy/cluster/clustering.py index 6b5fda81f..872a49ba1 100644 --- a/phy/cluster/clustering.py +++ b/phy/cluster/clustering.py @@ -232,7 +232,7 @@ def _update_cluster_ids(self, to_remove=None, to_add=None): self._spikes_per_cluster[clu] = spk # If spikes_per_cluster is invalid, recompute the entire # spikes_per_cluster array. - coherent = np.all(np.in1d(self._cluster_ids, sorted(self._spikes_per_cluster))) + coherent = np.all(np.isin(self._cluster_ids, sorted(self._spikes_per_cluster))) if not coherent: logger.debug("Recompute spikes_per_cluster manually: this might take a while.") sc = self._spike_clusters diff --git a/phy/cluster/tests/test_clustering.py b/phy/cluster/tests/test_clustering.py index 0a10fbe5e..f920c8db8 100644 --- a/phy/cluster/tests/test_clustering.py +++ b/phy/cluster/tests/test_clustering.py @@ -50,7 +50,7 @@ def test_extend_spikes(): # These are the spikes belonging to those clusters, but not in the # originally-specified spikes. extended = _extend_spikes(spike_ids, spike_clusters) - assert np.all(np.in1d(spike_clusters[extended], clusters)) + assert np.all(np.isin(spike_clusters[extended], clusters)) # The function only returns spikes that weren't in the passed spikes. assert len(np.intersect1d(extended, spike_ids)) == 0 @@ -58,7 +58,7 @@ def test_extend_spikes(): # Check that all spikes from our clusters have been selected. rest = np.setdiff1d(np.arange(n_spikes), extended) rest = np.setdiff1d(rest, spike_ids) - assert not np.any(np.in1d(spike_clusters[rest], clusters)) + assert not np.any(np.isin(spike_clusters[rest], clusters)) def test_concatenate_spike_clusters():