Skip to content

Commit

Permalink
np.in1d -> np.isin
Browse files Browse the repository at this point in the history
  • Loading branch information
zm711 committed Sep 17, 2023
1 parent 642c875 commit bf4789b
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion phy/apps/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ def get_spike_raw_amplitudes(self, spike_ids, channel_id=None, **kwargs):
# The cluster assignments of the requested spikes.
spike_clusters = self.supervisor.clustering.spike_clusters[spike_ids]
# Only keep spikes from clusters on the "best" channel.
to_keep = np.in1d(spike_clusters, self.get_clusters_on_channel(channel_id))
to_keep = np.isin(spike_clusters, self.get_clusters_on_channel(channel_id))
waveforms = self.model.get_waveforms(spike_ids[to_keep], [channel_id])
if waveforms is not None:
waveforms = waveforms[..., 0]
Expand Down
2 changes: 1 addition & 1 deletion phy/cluster/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def _update_cluster_ids(self, to_remove=None, to_add=None):
self._spikes_per_cluster[clu] = spk
# If spikes_per_cluster is invalid, recompute the entire
# spikes_per_cluster array.
coherent = np.all(np.in1d(self._cluster_ids, sorted(self._spikes_per_cluster)))
coherent = np.all(np.isin(self._cluster_ids, sorted(self._spikes_per_cluster)))
if not coherent:
logger.debug("Recompute spikes_per_cluster manually: this might take a while.")
sc = self._spike_clusters
Expand Down
4 changes: 2 additions & 2 deletions phy/cluster/tests/test_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ def test_extend_spikes():
# These are the spikes belonging to those clusters, but not in the
# originally-specified spikes.
extended = _extend_spikes(spike_ids, spike_clusters)
assert np.all(np.in1d(spike_clusters[extended], clusters))
assert np.all(np.isin(spike_clusters[extended], clusters))

# The function only returns spikes that weren't in the passed spikes.
assert len(np.intersect1d(extended, spike_ids)) == 0

# Check that all spikes from our clusters have been selected.
rest = np.setdiff1d(np.arange(n_spikes), extended)
rest = np.setdiff1d(rest, spike_ids)
assert not np.any(np.in1d(spike_clusters[rest], clusters))
assert not np.any(np.isin(spike_clusters[rest], clusters))


def test_concatenate_spike_clusters():
Expand Down

0 comments on commit bf4789b

Please sign in to comment.