diff --git a/phy/apps/base.py b/phy/apps/base.py index 0425ddc3..8a169892 100644 --- a/phy/apps/base.py +++ b/phy/apps/base.py @@ -160,7 +160,7 @@ def get_spike_raw_amplitudes(self, spike_ids, channel_id=None, **kwargs): # The cluster assignments of the requested spikes. spike_clusters = self.supervisor.clustering.spike_clusters[spike_ids] # Only keep spikes from clusters on the "best" channel. - to_keep = np.in1d(spike_clusters, self.get_clusters_on_channel(channel_id)) + to_keep = np.isin(spike_clusters, self.get_clusters_on_channel(channel_id)) waveforms = self.model.get_waveforms(spike_ids[to_keep], [channel_id]) if waveforms is not None: waveforms = waveforms[..., 0] diff --git a/phy/cluster/clustering.py b/phy/cluster/clustering.py index 6b5fda81..872a49ba 100644 --- a/phy/cluster/clustering.py +++ b/phy/cluster/clustering.py @@ -232,7 +232,7 @@ def _update_cluster_ids(self, to_remove=None, to_add=None): self._spikes_per_cluster[clu] = spk # If spikes_per_cluster is invalid, recompute the entire # spikes_per_cluster array. - coherent = np.all(np.in1d(self._cluster_ids, sorted(self._spikes_per_cluster))) + coherent = np.all(np.isin(self._cluster_ids, sorted(self._spikes_per_cluster))) if not coherent: logger.debug("Recompute spikes_per_cluster manually: this might take a while.") sc = self._spike_clusters diff --git a/phy/cluster/tests/test_clustering.py b/phy/cluster/tests/test_clustering.py index 0a10fbe5..f920c8db 100644 --- a/phy/cluster/tests/test_clustering.py +++ b/phy/cluster/tests/test_clustering.py @@ -50,7 +50,7 @@ def test_extend_spikes(): # These are the spikes belonging to those clusters, but not in the # originally-specified spikes. extended = _extend_spikes(spike_ids, spike_clusters) - assert np.all(np.in1d(spike_clusters[extended], clusters)) + assert np.all(np.isin(spike_clusters[extended], clusters)) # The function only returns spikes that weren't in the passed spikes. assert len(np.intersect1d(extended, spike_ids)) == 0 @@ -58,7 +58,7 @@ def test_extend_spikes(): # Check that all spikes from our clusters have been selected. rest = np.setdiff1d(np.arange(n_spikes), extended) rest = np.setdiff1d(rest, spike_ids) - assert not np.any(np.in1d(spike_clusters[rest], clusters)) + assert not np.any(np.isin(spike_clusters[rest], clusters)) def test_concatenate_spike_clusters():