Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pyvista transition #14

Merged
merged 6 commits into from
Dec 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 95 additions & 18 deletions caltrig/gui/exploration_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from PyQt5.QtGui import (QIntValidator, QDoubleValidator, QFont)
from pyqtgraph import (PlotItem, PlotCurveItem, ScatterPlotItem, InfiniteLine, TextItem)
import pyqtgraph as pg
import colorcet as cc
import numpy as np
from scipy.signal import find_peaks
from skimage.segmentation import flood_fill
Expand All @@ -19,11 +20,12 @@
from ..core.exploration_statistics import (GeneralStatsWidget, LocalStatsWidget, MetricsWidget)
from .pyqtgraph_override import ImageViewOverride
from .cofiring_2d_widgets import Cofiring2DWidget
from .sda_widgets import (MayaviQWidget, base_visualization)
from .sda_widgets import (base_visualization, PyVistaWidget)
import os
import matplotlib.pyplot as plt
import pickle
from mayavi.core.lut_manager import lut_mode_list
from concurrent.futures import ProcessPoolExecutor
import time


try:
Expand Down Expand Up @@ -58,23 +60,21 @@ def __init__(self, session, name, main_window_ref, timestamps=None, parent=None)
self.windows = {}
self.single_plot_options = {"enabled": False, "inter_distance": 0, "intra_distance": 0}


# Initialize executor to load next chunks in the background
self.executor = ProcessPoolExecutor(max_workers=4)

# Set up main view
pg.setConfigOptions(imageAxisOrder='row-major')
self.imv_cell = ImageViewOverride()
self.imv_behavior = ImageViewOverride()
self.imv_behavior.setVisible(False)
self.visualization_3D = MayaviQWidget(self.session)
self.visualization_3D.setVisible(False)
self.visualization_3D.point_signal.connect(self.point_selection)


self.session.load_videos()
if not self.session.video_data:
print("Missing Videos")
return None
self.current_video = self.session.video_data["varr"]
self.current_video_serialized = pickle.dumps(self.current_video)
self.video_length = self.current_video.shape[0]
self.mask = np.ones((self.current_video.shape[1], self.current_video.shape[2]))
# We need two seperate masks here. One for the missed cells we confirmed and one for drawing a new missed cell
Expand All @@ -85,6 +85,10 @@ def __init__(self, session, name, main_window_ref, timestamps=None, parent=None)
if "behavior_video" in self.session.video_data:
self.imv_behavior.setImage(self.session.video_data["behavior_video"].sel(frame=self.current_frame).values)

self.visualization_3D = PyVistaWidget(self.session, self.executor)
self.visualization_3D.setVisible(False)
#self.visualization_3D.point_signal.connect(self.point_selection)

# Add Context Menu Action
self.video_to_title = {"varr": "Original", "Y_fm_chk": "Processed"}
self.submenu_videos = self.imv_cell.getView().menu.addMenu('&Video Format')
Expand Down Expand Up @@ -739,9 +743,11 @@ def __init__(self, session, name, main_window_ref, timestamps=None, parent=None)
layout_colormap = QHBoxLayout()
layout_colormap.addWidget(QLabel("Colormap:"))
dropdown_3D_colormap = QComboBox()
dropdown_3D_colormap.addItems(lut_mode_list())
dropdown_3D_colormap.addItems(cc.cm.keys())
# Set index to whatever hot is
dropdown_3D_colormap.setCurrentIndex(lut_mode_list().index("hot"))
name_cmap = "linear_bmy_10_95_c71"
dropdown_3D_colormap.setCurrentIndex(list(cc.cm.keys()).index(name_cmap))
self.visualization_3D.change_colormap(name_cmap)
dropdown_3D_colormap.currentIndexChanged.connect(lambda: self.visualization_3D.change_colormap(dropdown_3D_colormap.currentText()))
layout_colormap.addWidget(dropdown_3D_colormap)

Expand Down Expand Up @@ -1515,7 +1521,7 @@ def visualize_3D(self):
smoothing_size=smoothing_size, smoothing_type=smoothing_type, window_size=window_size, normalize=normalize, average=average, cumulative=cumulative)

self.cofiring_chkbox.setChecked(False)
self.visualization_3D.remove_cofiring()
#self.visualization_3D.remove_cofiring()

def check_if_results_exist(self):
idx_to_cells = {"0":"1", "1":"2", "2":"5", "3":"10", "4":"15", "5":"20"}
Expand Down Expand Up @@ -2648,35 +2654,98 @@ def prev_frame(self):
if bimage is not None:
self.imv_behavior.setImage(bimage, autoRange=False, autoLevels=False)


def create_next_chunk_image(self, current_video_serialized, start, end):
# Submit the task to the pool and return a future
return self.executor.submit(load_next, current_video_serialized, start, end)

def check_preload_image(self):
chunk_length = self.current_video.chunks[0][0] * 10
chunk_length = self.current_video.chunks[0][0]

if self.pre_images is None:
# Check which chunk the current frame is in
chunk_idx = self.current_frame // chunk_length
self.pre_images = self.current_video.sel(frame=slice(chunk_idx*chunk_length, (chunk_idx+1)*chunk_length)).load()

self.pre_images = self.current_video.sel(
frame=slice(chunk_idx * chunk_length, (chunk_idx + 1) * chunk_length)
).load()
self.next_images_future = self.create_next_chunk_image(
self.current_video_serialized, (chunk_idx + 1) * chunk_length, (chunk_idx + 2) * chunk_length
)
else:
frames = self.pre_images.coords["frame"].values

if frames[0] <= self.current_frame <= frames[-1]:
return

elif frames[0] + chunk_length <= self.current_frame <= frames[-1] + chunk_length:
chunk_idx = self.current_frame // chunk_length

# Check if the future is ready
while not self.next_images_future.done():
time.sleep(0.01)

self.pre_images = self.next_images_future.result()

# Start loading the next chunk in the background
self.next_images_future = self.create_next_chunk_image(
self.current_video_serialized, (chunk_idx + 1) * chunk_length, (chunk_idx + 2) * chunk_length
)
else:
# Load the current and next chunks synchronously if the jump is large
chunk_idx = self.current_frame // chunk_length
self.pre_images = self.current_video.sel(frame=slice(chunk_idx*chunk_length, (chunk_idx+1)*chunk_length)).load()
self.pre_images = self.current_video.sel(
frame=slice(chunk_idx * chunk_length, (chunk_idx + 1) * chunk_length)
).load()
self.next_images_future = self.create_next_chunk_image(
self.current_video_serialized, (chunk_idx + 1) * chunk_length, (chunk_idx + 2) * chunk_length
)

def check_preload_bimage(self, current_frame):
chunk_length = self.session.video_data["behavior_video"].chunks[0][0]

if self.pre_bimages is None:
# Check which chunk the current frame is in
chunk_idx = current_frame // chunk_length
self.pre_bimages = self.session.video_data["behavior_video"].sel(frame=slice(chunk_idx*chunk_length, (chunk_idx+1)*chunk_length)).load()

self.pre_bimages = self.session.video_data["behavior_video"].sel(
frame=slice(chunk_idx * chunk_length, (chunk_idx + 1) * chunk_length)
).load()
self.next_bimages_future = self.create_next_chunk_image(
self.session.video_data["behavior_video"],
(chunk_idx + 1) * chunk_length,
(chunk_idx + 2) * chunk_length
)
else:
frames = self.pre_bimages.coords["frame"].values

if frames[0] <= current_frame <= frames[-1]:
return

elif frames[0] + chunk_length <= current_frame <= frames[-1] + chunk_length:
chunk_idx = current_frame // chunk_length

# Check if the future is ready
while not self.next_bimages_future.done():
time.sleep(0.01)

self.pre_bimages = self.next_bimages_future.result()

# Start loading the next chunk in the background
self.next_bimages_future = self.create_next_chunk_image(
self.session.video_data["behavior_video"],
(chunk_idx + 1) * chunk_length,
(chunk_idx + 2) * chunk_length
)
else:
# Load the current and next chunks synchronously if the jump is large
chunk_idx = current_frame // chunk_length
self.pre_bimages = self.session.video_data["behavior_video"].sel(frame=slice(chunk_idx*chunk_length, (chunk_idx+1)*chunk_length)).load()
self.pre_bimages = self.session.video_data["behavior_video"].sel(
frame=slice(chunk_idx * chunk_length, (chunk_idx + 1) * chunk_length)
).load()
self.next_bimages_future = self.create_next_chunk_image(
self.session.video_data["behavior_video"],
(chunk_idx + 1) * chunk_length,
(chunk_idx + 2) * chunk_length
)



Expand Down Expand Up @@ -2711,6 +2780,7 @@ def generate_image(self):
if self.chkbox_3D.isChecked():
self.visualization_3D.set_frame(self.current_frame)


return image, bimage

def refresh_image(self):
Expand Down Expand Up @@ -2831,6 +2901,7 @@ def toggle_videos(self):

def change_cell_video(self, type):
self.current_video = self.session.video_data[type]
self.current_video_serialized = pickle.dumps(self.current_video)
self.pre_images = None
for action in self.submenu_videos.actions():
if action.text() == "&Behavior Video":
Expand Down Expand Up @@ -3202,4 +3273,10 @@ def clicked(self, _, event):
self.main_plot.add_selection(self)
else:
self.setPen('r')
self.main_plot.remove_selection(self)
self.main_plot.remove_selection(self)

def load_next(current_video_serialized, start, end):
"""Load the next chunk and store the result in a shared dictionary."""
video = pickle.loads(current_video_serialized)
return video.sel(frame=slice(start, end)).load()

Loading