diff --git a/src/spikeanalysis/plotbase.py b/src/spikeanalysis/plotbase.py index 8b357df..b765257 100644 --- a/src/spikeanalysis/plotbase.py +++ b/src/spikeanalysis/plotbase.py @@ -142,7 +142,8 @@ def set_plot_kwargs(self, ax: plt.axes, plot_kwargs: namedtuple): if plot_kwargs.ylim is not None: ax.set_ylim(plot_kwargs.ylim) - def _save_fig(self, cluster_number, extra_title="", format="png"): + def _save_fig(self, fig, cluster_number, extra_title="", format="png"): title = f"{cluster_number}_{extra_title}" - plt.savefig(title, format=format) + fig.savefig(title + "." + format, format=format) + diff --git a/src/spikeanalysis/spike_analysis.py b/src/spikeanalysis/spike_analysis.py index f6fec26..f61f977 100644 --- a/src/spikeanalysis/spike_analysis.py +++ b/src/spikeanalysis/spike_analysis.py @@ -414,6 +414,8 @@ def get_raw_firing_rate( self.fr_bins[stim] = bins[fr_window_values] self.mean_firing_rate = final_fr + + def zscore_data(self, time_bin_ms, bsl_window, z_window, eps:float=0): self.z_score_data(time_bin_ms=time_bin_ms, bsl_window=bsl_window, z_window=z_window, eps=eps) diff --git a/src/spikeanalysis/spike_plotter.py b/src/spikeanalysis/spike_plotter.py index 95ad027..8fbd4a6 100644 --- a/src/spikeanalysis/spike_plotter.py +++ b/src/spikeanalysis/spike_plotter.py @@ -322,6 +322,8 @@ def _plot_scores( if len(np.shape(sorted_z_scores)) == 2: sorted_z_scores = np.expand_dims(sorted_z_scores, axis=1) + # at baseline we need to eliminate cases of nan's, infinities, and 0's (if all the way across a stimulus) + nan_mask = np.any( np.any(np.isnan(sorted_z_scores) | np.isinf(sorted_z_scores), axis=2) | np.all(np.equal(sorted_z_scores, 0), axis=2), @@ -444,8 +446,14 @@ def _plot_scores( ) plt.figure(dpi=plot_kwargs.dpi) if plot_kwargs.save and plot_kwargs.title is not None: - self._save_fig(plot_kwargs.title, extra_title=plot_kwargs.extra_title, format=plot_kwargs.format) - elif plot_kwargs.title is None: + self._save_fig( + fig=fig, + cluster_number=plot_kwargs.title + str(stimulus), + extra_title=plot_kwargs.extra_title, + format=plot_kwargs.format, + ) + elif plot_kwargs.save and plot_kwargs.title is None: + print("give title to save heat map") plt.show() @@ -603,7 +611,8 @@ def plot_raster( ) plt.figure(dpi=plot_kwargs.dpi) if plot_kwargs.save: - self._save_fig(title, extra_title=plot_kwargs.extra_title, format=plot_kwargs.format) + self._save_fig(fig, title, extra_title=plot_kwargs.extra_title, format=plot_kwargs.format) + plt.show() def plot_sm_fr( @@ -786,7 +795,8 @@ def plot_sm_fr( plt.figure(dpi=plot_kwargs.dpi) if plot_kwargs.save: - self._save_fig(title, extra_title=plot_kwargs.extra_title, format=plot_kwargs.format) + self._save_fig(fig, title, extra_title=plot_kwargs.extra_title, format=plot_kwargs.format) + plt.show() def plot_zscores_ind(self, z_bar: Optional[list[int]] = None, show_stim: bool = True):