Skip to content

Commit

Permalink
feat: add behavior data loading method with event timestamp processing
Browse files Browse the repository at this point in the history
- Implemented the  method to load behavior data from a CSV file into a pandas DataFrame.
- The method processes each row to extract timestamps, events, and calculates frame numbers based on recording time and fps.
- Mark types are determined based on the event description, distinguishing between 'event_start' and 'event_end'.
- Data is stored in the  with frame number and mark type as keys, and relevant behavior, subject, and trial time information as values.
  • Loading branch information
healthonrails committed Oct 18, 2024
1 parent 47b3afe commit 1edde66
Showing 1 changed file with 54 additions and 12 deletions.
66 changes: 54 additions & 12 deletions annolid/gui/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,10 +583,10 @@ def __init__(self,
self._selectAiModelComboBox.clear()
self.custom_ai_model_names = [
'SAM_HQ', 'Cutie', "EfficientVit_SAM",
"CoTracker",
"sam2_hiera_s",
"sam2_hiera_l",
]
"CoTracker",
"sam2_hiera_s",
"sam2_hiera_l",
]
model_names = [model.name for model in MODELS] + \
self.custom_ai_model_names
self._selectAiModelComboBox.addItems(model_names)
Expand Down Expand Up @@ -1821,8 +1821,9 @@ def add_highlighted_mark(self, val=None,
if init_load or (val, mark_type) not in self.timestamp_dict:
highlighted_mark = VideoSliderMark(mark_type=mark_type,
val=val, _color=color)
self.timestamp_dict[(val, mark_type)
] = self._time_stamp if self._time_stamp else convert_frame_number_to_time(val)
if (val, mark_type) not in self.timestamp_dict:
self.timestamp_dict[(val, mark_type)
] = self._time_stamp if self._time_stamp else convert_frame_number_to_time(val)
self.seekbar.addMark(highlighted_mark)
return highlighted_mark

Expand Down Expand Up @@ -1961,6 +1962,8 @@ def load_tracking_results(self, cur_video_folder, video_filename):

if 'timestamp' in tr.name and video_name in tr.name:
self._load_timestamps(tr)
if tr.name.endswith(f"{video_name}.csv"):
self._load_behavior(tr)

if 'tracking' in tr.name and video_name in tr.name and '_nix' not in tr.name:
tracking_csv_file = tr
Expand Down Expand Up @@ -1991,6 +1994,34 @@ def _load_timestamps(self, timestamp_csv_file):

self.timestamp_dict[(frame_number, mark_type)] = timestamp

def _load_behavior(self, behavior_csv_file: str) -> None:
"""Loads behavior data from a CSV file and stores it in timestamp_dict.
Args:
behavior_csv_file (str): Path to the CSV file containing behavior data.
"""
# Load the CSV file into a DataFrame
df_behaviors = pd.read_csv(behavior_csv_file)

# Iterate through each row of the DataFrame
for _, row in df_behaviors.iterrows():
timestamp: float = row["Recording time"]
event: str = row["Event"]

# Calculate the frame number based on timestamp and fps
frame_number: int = int(float(timestamp) * self.fps)

# Determine the type of event (start or end)
mark_type: str = 'event_start' if 'start' in event.lower() else 'event_end'

# Store the relevant data in the timestamp_dict
self.timestamp_dict[(frame_number, mark_type)] = (
timestamp,
row['Behavior'],
row["Subject"],
row["Trial time"]
)

def _load_labels(self, labels_csv_file):
"""Load labels from the given CSV file."""
self._df = pd.read_csv(labels_csv_file)
Expand Down Expand Up @@ -2037,11 +2068,6 @@ def openVideo(self, _value=False):

if video_filename:
cur_video_folder = Path(video_filename).parent
# go over all the tracking csv files
# use the first matched file with video name
# and segmentation
self.load_tracking_results(cur_video_folder, video_filename)

self.video_results_folder = Path(video_filename).with_suffix('')

self.video_results_folder.mkdir(
Expand Down Expand Up @@ -2108,6 +2134,11 @@ def openVideo(self, _value=False):
self.statusBar().addWidget(self.playButton)
self.statusBar().addWidget(self.seekbar, stretch=1)
self.statusBar().addWidget(self.saveButton)
# go over all the tracking csv files
# use the first matched file with video name
# and segmentation
self.load_tracking_results(cur_video_folder, video_filename)

if self.timestamp_dict:
for frame_number, mark_type in self.timestamp_dict.keys():
self.add_highlighted_mark(val=frame_number,
Expand Down Expand Up @@ -2164,6 +2195,17 @@ def image_to_canvas(self, qimage, filename, frame_number):
prev_shapes = self.canvas.shapes
self.canvas.loadPixmap(QtGui.QPixmap.fromImage(qimage))
flags = {k: False for k in self._config["flags"] or []}
_event_key = (frame_number, 'event_start')
if _event_key not in self.timestamp_dict:
_event_key = (frame_number, 'event_end')
if _event_key in self.timestamp_dict:
try:
timestamp, behaivor, subject, trial_time = self.timestamp_dict[_event_key]
flags[behaivor] = True
except:
print(self.timestamp_dict[_event_key])

self.flag_widget.clear()
self.loadFlags(flags)
if self._config["keep_prev"] and self.noShapes():
self.loadShapes(prev_shapes, replace=False)
Expand Down Expand Up @@ -2285,7 +2327,7 @@ def loadShapes(self, shapes, replace=True):
def loadPredictShapes(self, frame_number, filename):

label_json_file = str(filename).replace(".png", ".json")
#try to load json files generated by SAM2 like 000000000.json
# try to load json files generated by SAM2 like 000000000.json
if not Path(label_json_file).exists():
label_json_file = os.path.join(os.path.dirname(label_json_file),
os.path.basename(label_json_file).split('_')[-1])
Expand Down

0 comments on commit 1edde66

Please sign in to comment.