Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Getter Docstrings #1185

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

## Improvements
* Simple writing no longer uses a context manager [PR #1180](https://github.com/catalystneuro/neuroconv/pull/1180)
* Added Returns section to all getter docstrings [PR #1185](https://github.com/catalystneuro/neuroconv/pull/1185)


# v0.6.7 (January 20, 2025)
Expand Down
29 changes: 25 additions & 4 deletions src/neuroconv/basedatainterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,12 +72,26 @@ def __init__(self, verbose: bool = False, **source_data):
self.source_data = source_data

def get_metadata_schema(self) -> dict:
"""Retrieve JSON schema for metadata."""
"""
Retrieve JSON schema for metadata.

Returns
-------
dict
The JSON schema defining the metadata structure.
"""
metadata_schema = load_dict_from_file(Path(__file__).parent / "schemas" / "base_metadata_schema.json")
return metadata_schema

def get_metadata(self) -> DeepDict:
"""Child DataInterface classes should override this to match their metadata."""
"""
Child DataInterface classes should override this to match their metadata.

Returns
-------
DeepDict
The metadata dictionary containing basic NWBFile metadata.
"""
metadata = DeepDict()
metadata["NWBFile"]["session_description"] = ""
metadata["NWBFile"]["identifier"] = str(uuid.uuid4())
Expand Down Expand Up @@ -105,7 +119,14 @@ def validate_metadata(self, metadata: dict, append_mode: bool = False) -> None:
validate(instance=decoded_metadata, schema=metdata_schema)

def get_conversion_options_schema(self) -> dict:
"""Infer the JSON schema for the conversion options from the method signature (annotation typing)."""
"""
Infer the JSON schema for the conversion options from the method signature (annotation typing).

Returns
-------
dict
The JSON schema for the conversion options.
"""
return get_json_schema_from_method_signature(self.add_to_nwbfile, exclude=["nwbfile", "metadata"])

def create_nwbfile(self, metadata: Optional[dict] = None, **conversion_options) -> NWBFile:
Expand Down Expand Up @@ -249,7 +270,7 @@ def get_default_backend_configuration(

Returns
-------
backend_configuration : HDF5BackendConfiguration or ZarrBackendConfiguration
Union[HDF5BackendConfiguration, ZarrBackendConfiguration]
The default configuration for the specified backend type.
"""
return get_default_backend_configuration(nwbfile=nwbfile, backend=backend)
8 changes: 8 additions & 0 deletions src/neuroconv/baseextractorinterface.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ class BaseExtractorInterface(BaseTemporalAlignmentInterface, ABC):

@classmethod
def get_extractor(cls):
"""
Get the extractor class for this interface.

Returns
-------
type
The extractor class that will be used to read the data.
"""
if cls.Extractor is not None:
return cls.Extractor
extractor_module = get_package(package_name=cls.ExtractorModuleName)
Expand Down
83 changes: 78 additions & 5 deletions src/neuroconv/datainterfaces/behavior/video/video_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ def get_video_timestamps(file_path: FilePath, max_frames: Optional[int] = None,
The path to a multimedia video file
max_frames : Optional[int], optional
If provided, extract the timestamps of the video only up to max_frames.
display_progress : bool, default: True
Whether to display a progress bar during timestamp extraction.

Returns
-------
Expand All @@ -45,7 +47,24 @@ def __init__(self, file_path: FilePath):
self._video_open_msg = "The video file is not open!"

def get_video_timestamps(self, max_frames: Optional[int] = None, display_progress: bool = True):
"""Return numpy array of the timestamps(s) for a video file."""
"""
Return numpy array of the timestamps(s) for a video file.

Parameters
----------
max_frames : Optional[int], optional
If provided, extract the timestamps of the video only up to max_frames.
display_progress : bool, default: True
Whether to display a progress bar during timestamp extraction.

Returns
-------
numpy.ndarray
Array of timestamps in seconds, representing the time from the start
of the video for each frame. Timestamps are extracted from the video's
metadata using cv2.CAP_PROP_POS_MSEC and converted from milliseconds
to seconds.
"""
cv2 = get_package(package_name="cv2", installation_instructions="pip install opencv-python-headless")

timestamps = []
Expand All @@ -65,13 +84,27 @@ def get_video_timestamps(self, max_frames: Optional[int] = None, display_progres
return np.array(timestamps) / 1000

def get_video_fps(self):
"""Return the internal frames per second (fps) for a video file."""
"""
Return the internal frames per second (fps) for a video file.

Returns
-------
float
The frames per second of the video.
"""
assert self.isOpened(), self._video_open_msg
prop = self.get_cv_attribute("CAP_PROP_FPS")
return self.vc.get(prop)

def get_frame_shape(self) -> Tuple:
"""Return the shape of frames from a video file."""
"""
Return the shape of frames from a video file.

Returns
-------
Tuple
The shape of the video frames (height, width, channels).
"""
frame = self.get_video_frame(0)
if frame is not None:
return frame.shape
Expand All @@ -91,6 +124,14 @@ def frame_count(self, val: int):
self._frame_count = val

def get_video_frame_count(self):
"""
Get the total number of frames in the video.

Returns
-------
int
The total number of frames in the video.
"""
return self.frame_count

def _video_frame_count(self):
Expand All @@ -101,6 +142,19 @@ def _video_frame_count(self):

@staticmethod
def get_cv_attribute(attribute_name: str):
"""
Get an OpenCV attribute by name.

Parameters
----------
attribute_name : str
The name of the OpenCV attribute to get.

Returns
-------
Any
The OpenCV attribute value.
"""
cv2 = get_package(package_name="cv2", installation_instructions="pip install opencv-python-headless")

if int(cv2.__version__.split(".")[0]) < 3: # pragma: no cover
Expand All @@ -122,7 +176,19 @@ def current_frame(self, frame_number: int):
raise ValueError(f"Could not set frame number (received {frame_number}).")

def get_video_frame(self, frame_number: int):
"""Return the specific frame from a video as an RGB colorspace."""
"""
Return the specific frame from a video as an RGB colorspace.

Parameters
----------
frame_number : int
The index of the frame to retrieve.

Returns
-------
numpy.ndarray
The video frame in RGB colorspace with shape (height, width, 3).
"""
assert self.isOpened(), self._video_open_msg
assert frame_number < self.get_video_frame_count(), "frame number is greater than length of video"
initial_frame_number = self.current_frame
Expand All @@ -132,7 +198,14 @@ def get_video_frame(self, frame_number: int):
return np.flip(frame, 2) # np.flip to re-order color channels to RGB

def get_video_frame_dtype(self):
"""Return the dtype for frame in a video file."""
"""
Return the dtype for frame in a video file.

Returns
-------
numpy.dtype
The data type of the video frames.
"""
frame = self.get_video_frame(0)
if frame is not None:
return frame.dtype
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def get_timing_type(self) -> Literal["starting_time and rate", "timestamps"]:

Returns
-------
timing_type : 'starting_time and rate' or 'timestamps'
Literal["starting_time and rate", "timestamps"]
The type of timing that has been set explicitly according to alignment.

If only timestamps have been set, then only those will be used.
Expand Down
24 changes: 15 additions & 9 deletions src/neuroconv/datainterfaces/ecephys/axona/axona_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def get_eeg_sampling_frequency(file_path: FilePath) -> float:
Returns
-------
float
Sampling frequency
The sampling frequency in Hz extracted from the file header's sample_rate field.
"""
Fs_entry = parse_generic_header(file_path, ["sample_rate"])
Fs = float(Fs_entry.get("sample_rate").split(" ")[0])
Expand Down Expand Up @@ -76,8 +76,9 @@ def get_all_file_paths(file_path: FilePath) -> list:

Returns
-------
path_list : list
List of file_paths
list
List of file paths for all .eeg or .egf files found in the same directory
as the input file path.
"""

suffix = Path(file_path).suffix[0:4]
Expand Down Expand Up @@ -183,8 +184,9 @@ def get_header_bstring(file: FilePath) -> bytes:

Returns
-------
str
header byte content
bytes
The header content as bytes, including everything from the start of the file
up to and including the 'data_start' marker.
"""
header = b""
with open(file, "rb") as f:
Expand Down Expand Up @@ -365,14 +367,18 @@ def get_position_object(file_path: FilePath) -> Position:
be preferred to read position data from the `.bin` file to ensure
samples are locked to ecephys time courses.

Parameters:
Parameters
----------
file_path (Path or Str):
file_path : Path or str
Full file_path of Axona file with any extension.

Returns:
Returns
-------
position: pynwb.behavior.Position
pynwb.behavior.Position
Position object containing multiple SpatialSeries, one for each position
channel (time, X, Y, x, y, PX, px, px_total, unused). Each series contains
timestamps and corresponding position data. The timestamps are in milliseconds
and are aligned to the start of raw acquisition if reading from a .bin file.
"""
position = Position()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,15 @@ def __init__(self, verbose: bool = False, es_key: str = "ElectricalSeries", **so
self._number_of_segments = self.recording_extractor.get_num_segments()

def get_metadata_schema(self) -> dict:
"""Compile metadata schema for the RecordingExtractor."""
"""
Compile metadata schema for the RecordingExtractor.

Returns
-------
dict
The metadata schema dictionary containing definitions for Device, ElectrodeGroup,
Electrodes, and optionally ElectricalSeries.
"""
metadata_schema = super().get_metadata_schema()
metadata_schema["properties"]["Ecephys"] = get_base_schema(tag="Ecephys")
metadata_schema["properties"]["Ecephys"]["required"] = ["Device", "ElectrodeGroup"]
Expand Down Expand Up @@ -86,6 +94,15 @@ def get_metadata_schema(self) -> dict:
return metadata_schema

def get_metadata(self) -> DeepDict:
"""
Get metadata for the recording extractor.

Returns
-------
DeepDict
Dictionary containing metadata including device information, electrode groups,
and electrical series configuration.
"""
metadata = super().get_metadata()

from ...tools.spikeinterface.spikeinterface import _get_group_name
Expand Down Expand Up @@ -250,8 +267,8 @@ def has_probe(self) -> bool:

Returns
-------
has_probe : bool
True if the recording extractor has probe information.
bool
True if the recording extractor has probe information, False otherwise.
"""
return self.recording_extractor.has_probe()

Expand All @@ -274,6 +291,12 @@ def subset_recording(self, stub_test: bool = False):
Parameters
----------
stub_test : bool, default: False
If True, only a subset of frames will be included.

Returns
-------
spikeinterface.core.BaseRecording
The subsetted recording extractor.
"""
from spikeinterface.core.segmentutils import AppendSegmentRecording

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,15 @@ def __init__(self, verbose: bool = False, **source_data):
self._number_of_segments = self.sorting_extractor.get_num_segments()

def get_metadata_schema(self) -> dict:
"""Compile metadata schema for the RecordingExtractor."""
"""
Compile metadata schema for the RecordingExtractor.

Returns
-------
dict
The metadata schema dictionary containing definitions for Device, ElectrodeGroup,
Electrodes, and UnitProperties.
"""

# Initiate Ecephys metadata
metadata_schema = super().get_metadata_schema()
Expand Down Expand Up @@ -85,6 +93,20 @@ def get_original_timestamps(self) -> np.ndarray:
)

def get_timestamps(self) -> Union[np.ndarray, list[np.ndarray]]:
"""
Get the timestamps for the sorting data.

Returns
-------
numpy.ndarray or list of numpy.ndarray
The timestamps for each spike in the sorting data. If there are multiple segments,
returns a list of timestamp arrays.

Raises
------
NotImplementedError
If no recording is attached to the sorting extractor.
"""
if not self.sorting_extractor.has_recording():
raise NotImplementedError(
"In order to align timestamps for a SortingInterface, it must have a recording "
Expand Down
Loading
Loading