diff --git a/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py b/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py index dcc9d8e..4f170a1 100644 --- a/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py +++ b/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py @@ -164,6 +164,10 @@ def __init__( step=0.01, description="Bin size:", ) + self.flip_y_axis = widgets.Checkbox( + value=False, + description="Flip y-axis", + ) spatial_series_label = widgets.Label("Spatial Series:") top_panel = widgets.VBox( [ @@ -174,6 +178,7 @@ def __init__( [ spatial_series_label, self.spatial_series_selector, + self.flip_y_axis ] ), widgets.HBox([self.smoothing_slider, self.bin_size_slider]), @@ -311,13 +316,18 @@ def show_unit_rate_maps( else: legend_kwargs.update(bbox_to_anchor=(1.01, 1)) if self.rate_maps is not None: + origin = "lower" if self.flip_y_axis.value else "upper" axs[0].imshow(self.rate_maps[unit_index], cmap="viridis", origin="lower", aspect="auto", extent=self.extent) axs[0].set_xlabel("x") axs[0].set_ylabel("y") else: axs[0].set_title("Rate maps not computed (spatial_maps not installed)") - axs[1].plot(self.nap_position["y"], self.nap_position["x"], color="grey") + tracking_x = self.nap_position["y"] + tracking_y = self.nap_position["x"] + if self.flip_y_axis.value: + tracking_y = 1 - tracking_y + axs[1].plot(tracking_x, tracking_y, color="grey") spk_pos = self.nap_units[unit_index].value_from(self.nap_position) axs[1].plot(spk_pos["y"], spk_pos["x"], "o", color="red", markersize=5, alpha=0.5) axs[1].set_xlabel("x") diff --git a/src/expipe_plugin_cinpla/widgets/curation.py b/src/expipe_plugin_cinpla/widgets/curation.py index b6b4a43..65a97c3 100644 --- a/src/expipe_plugin_cinpla/widgets/curation.py +++ b/src/expipe_plugin_cinpla/widgets/curation.py @@ -104,9 +104,10 @@ def __init__(self, project): load_from_phy = ipywidgets.Button(description="Load from Phy", layout={"width": "initial"}) load_from_phy.style.button_color = "pink" + self.restore_phy_clicked = False restore_phy = ipywidgets.ToggleButton( value=False, - description="Restore", + description="Restore (click twice to restore)", disabled=False, button_style="", # 'success', 'info', 'warning', 'danger' or '' tooltip="Restore unsorted clusters", @@ -205,6 +206,14 @@ def on_action(change): sorters = [p.name for p in si_path.iterdir() if p.is_dir()] sorter_list.options = sorters if len(sorter_list.value) == 1: + units_raw = self.sorting_curator.load_raw_units(sorter_list.value[0]) + if units_raw is not None: + w = nwb2widget(units_raw, custom_raw_unit_vis) + units_viewers["raw"] = w + units_main = self.sorting_curator.load_main_units() + if units_main is not None: + w = nwb2widget(units_main, custom_main_unit_vis) + units_viewers["main"] = w if strategy.value == "Sortingview": sv_visualization_link.value = self.sorting_curator.get_sortingview_link(sorter_list.value[0]) elif strategy.value == "Phy": @@ -212,6 +221,7 @@ def on_action(change): units_dropdown.value = "Raw" on_choose_units(None) + def on_sorter(change): required_values_filled(actions_list) if len(sorter_list.value) > 1: @@ -313,7 +323,15 @@ def on_restore_phy(change): if len(sorter_list.value) > 1: print("Select one spike sorting output at a time") else: - self.sorting_curator.restore_phy(sorter_list.value[0]) + if self.restore_phy_clicked: + self.sorting_curator.restore_phy(sorter_list.value[0]) + self.restore_phy_clicked = False + restore_phy.description = "Restore (click twice to restore)" + restore_phy.button_style = "primary" + else: + self.restore_phy_clicked = True + restore_phy.description = "Restore (click again to confirm)" + restore_phy.button_style = "danger" @self.output.capture() def on_apply_qm_curation(change):