From 1edde66ed07c1cb62cf1c20c426633446c9fc2d1 Mon Sep 17 00:00:00 2001 From: healthonrails Date: Fri, 18 Oct 2024 17:03:56 -0400 Subject: [PATCH] feat: add behavior data loading method with event timestamp processing - Implemented the method to load behavior data from a CSV file into a pandas DataFrame. - The method processes each row to extract timestamps, events, and calculates frame numbers based on recording time and fps. - Mark types are determined based on the event description, distinguishing between 'event_start' and 'event_end'. - Data is stored in the with frame number and mark type as keys, and relevant behavior, subject, and trial time information as values. --- annolid/gui/app.py | 66 +++++++++++++++++++++++++++++++++++++--------- 1 file changed, 54 insertions(+), 12 deletions(-) diff --git a/annolid/gui/app.py b/annolid/gui/app.py index 6c90eee..be9ee3a 100644 --- a/annolid/gui/app.py +++ b/annolid/gui/app.py @@ -583,10 +583,10 @@ def __init__(self, self._selectAiModelComboBox.clear() self.custom_ai_model_names = [ 'SAM_HQ', 'Cutie', "EfficientVit_SAM", - "CoTracker", - "sam2_hiera_s", - "sam2_hiera_l", - ] + "CoTracker", + "sam2_hiera_s", + "sam2_hiera_l", + ] model_names = [model.name for model in MODELS] + \ self.custom_ai_model_names self._selectAiModelComboBox.addItems(model_names) @@ -1821,8 +1821,9 @@ def add_highlighted_mark(self, val=None, if init_load or (val, mark_type) not in self.timestamp_dict: highlighted_mark = VideoSliderMark(mark_type=mark_type, val=val, _color=color) - self.timestamp_dict[(val, mark_type) - ] = self._time_stamp if self._time_stamp else convert_frame_number_to_time(val) + if (val, mark_type) not in self.timestamp_dict: + self.timestamp_dict[(val, mark_type) + ] = self._time_stamp if self._time_stamp else convert_frame_number_to_time(val) self.seekbar.addMark(highlighted_mark) return highlighted_mark @@ -1961,6 +1962,8 @@ def load_tracking_results(self, cur_video_folder, video_filename): if 'timestamp' in tr.name and video_name in tr.name: self._load_timestamps(tr) + if tr.name.endswith(f"{video_name}.csv"): + self._load_behavior(tr) if 'tracking' in tr.name and video_name in tr.name and '_nix' not in tr.name: tracking_csv_file = tr @@ -1991,6 +1994,34 @@ def _load_timestamps(self, timestamp_csv_file): self.timestamp_dict[(frame_number, mark_type)] = timestamp + def _load_behavior(self, behavior_csv_file: str) -> None: + """Loads behavior data from a CSV file and stores it in timestamp_dict. + + Args: + behavior_csv_file (str): Path to the CSV file containing behavior data. + """ + # Load the CSV file into a DataFrame + df_behaviors = pd.read_csv(behavior_csv_file) + + # Iterate through each row of the DataFrame + for _, row in df_behaviors.iterrows(): + timestamp: float = row["Recording time"] + event: str = row["Event"] + + # Calculate the frame number based on timestamp and fps + frame_number: int = int(float(timestamp) * self.fps) + + # Determine the type of event (start or end) + mark_type: str = 'event_start' if 'start' in event.lower() else 'event_end' + + # Store the relevant data in the timestamp_dict + self.timestamp_dict[(frame_number, mark_type)] = ( + timestamp, + row['Behavior'], + row["Subject"], + row["Trial time"] + ) + def _load_labels(self, labels_csv_file): """Load labels from the given CSV file.""" self._df = pd.read_csv(labels_csv_file) @@ -2037,11 +2068,6 @@ def openVideo(self, _value=False): if video_filename: cur_video_folder = Path(video_filename).parent - # go over all the tracking csv files - # use the first matched file with video name - # and segmentation - self.load_tracking_results(cur_video_folder, video_filename) - self.video_results_folder = Path(video_filename).with_suffix('') self.video_results_folder.mkdir( @@ -2108,6 +2134,11 @@ def openVideo(self, _value=False): self.statusBar().addWidget(self.playButton) self.statusBar().addWidget(self.seekbar, stretch=1) self.statusBar().addWidget(self.saveButton) + # go over all the tracking csv files + # use the first matched file with video name + # and segmentation + self.load_tracking_results(cur_video_folder, video_filename) + if self.timestamp_dict: for frame_number, mark_type in self.timestamp_dict.keys(): self.add_highlighted_mark(val=frame_number, @@ -2164,6 +2195,17 @@ def image_to_canvas(self, qimage, filename, frame_number): prev_shapes = self.canvas.shapes self.canvas.loadPixmap(QtGui.QPixmap.fromImage(qimage)) flags = {k: False for k in self._config["flags"] or []} + _event_key = (frame_number, 'event_start') + if _event_key not in self.timestamp_dict: + _event_key = (frame_number, 'event_end') + if _event_key in self.timestamp_dict: + try: + timestamp, behaivor, subject, trial_time = self.timestamp_dict[_event_key] + flags[behaivor] = True + except: + print(self.timestamp_dict[_event_key]) + + self.flag_widget.clear() self.loadFlags(flags) if self._config["keep_prev"] and self.noShapes(): self.loadShapes(prev_shapes, replace=False) @@ -2285,7 +2327,7 @@ def loadShapes(self, shapes, replace=True): def loadPredictShapes(self, frame_number, filename): label_json_file = str(filename).replace(".png", ".json") - #try to load json files generated by SAM2 like 000000000.json + # try to load json files generated by SAM2 like 000000000.json if not Path(label_json_file).exists(): label_json_file = os.path.join(os.path.dirname(label_json_file), os.path.basename(label_json_file).split('_')[-1])