diff --git a/search/search_index.json b/search/search_index.json
index daaab21..c2779e9 100644
--- a/search/search_index.json
+++ b/search/search_index.json
@@ -1 +1 @@
-{"config":{"lang":["en"],"separator":"[\\s\\-]+","pipeline":["stopWordFilter"]},"docs":[{"location":"","title":"Introduction","text":" Chirpdetector
Detect communication signals of electric fish using deep neural networks \ud83d\udc1f\u26a1\ud83e\udde0 This project is still a work in progress and is expected to be released in spring 2024.
Why? \ud83e\udd28
Chirps are by far the most thoroughly researched communication signal of electric fish, probably even of all fish. But detecting chirps becomes hard as soon as more than one fish is recorded. As a result, most of the research to date analyzes this signal in isolated individuals, which tells us little about how chirps are used between interacting fish.
To tackle this issue, this package provides a simple toolbox to detect chirps of multiple fish on spectrograms. This enables true quantitative analyses of chirping between freely behaving fish for the first time.
"},{"location":"assingment/","title":"Assingment","text":"Wow, such empty
"},{"location":"contributing/","title":"Contributing","text":"We are thrilled to have you join in making this project even better. Please feel free to browse through the resources and guidelines provided here, and let us know if there is anything specific you would like to contribute or discuss.
If you would like to help develop this package, you can skim through the to-do list below as well as the contribution guidelines. Just fork the project, add your code, and send a pull request. We are always happy to get some help!
If you encounter an issue while using chirpdetector
, feel free to open an issue here.
"},{"location":"contributing/#contributors-guidelines","title":"Contributors guidelines","text":"I try our best to adhere to good coding practices and catch up on writing tests for this package. As I am currently the only one working on it, here is some documentation of the development packages I use:
pre-commit
for pre-commit hooks pytest
and pytest-coverage
for unit tests ruff
for linting and formatting pyright
for static type checking
Before every commit, a pre-commit hook runs all these packages on the code base and rejects the commit if errors are raised. If you want to contribute, please make sure that your code is properly formatted and run the tests before issuing a pull request. The formatting guidelines should be automatically picked up by your ruff
installation from the pyproject.toml
file.
"},{"location":"contributing/#to-do","title":"To Do","text":"After the first release, this section will be removed an tasks will be organized as github issues. Until them, if you fixed something, please check it off on this list before opening a pull request.
- Refactor train, detect, convert into much smaller functions. Move accessory functions to utils
- Move hardcoded params from assignment algo into config.toml
- Resolve all pylint, mypy, and ruff errors and warnings
- Fix make test, which fails after a ruff run
- Build github actions CI/CD pipeline for codecov etc.
- Move the dataconverter from
gridtools
to chirpdetector
- Extend the dataconverter to just output the spectrograms so that hand-labelling can be done in a separate step
- Add a main script so that the cli is
chirpdetector <task> --<flag> <args>
- Improve simulation of chirps to include more realistic noise, undershoot and maybe even phasic-tonic evolution of the frequency of the big chirps
- make the
copyconfig
script more - start writing the chirp assignment algorithm
- Move all the pretty-printing and logging constructors to a separate module and build a unified console object, so that saving logs to file is easier, and log to file as well
- Split the messy training loop into functions
- Add label-studio
- Supply scripts to convert completely unannotated or partially annotated data to the label-studio format to make manual labeling easier
- Make possible to output detections as a yolo dataset
- Look up how to convert a yolo dataset to a label-studio input so we can label pre-annotated data, facilitating a full human-in-the-loop approach
- Add augmentation transforms to the dataset class and add augmentations to the simulation in
gridtools
. Note on this: unnecessary, since we use real data. - Change bbox to the actual yolo format, not the weird one I made up (which is x1, y1, x2, y2 instead of x1, y1, w, h); this is why the label-studio export is not working. See the conversion sketch after this list.
- Port the cli to click, which works better
- Try clustering the detected chirp windows on a spectrogram, could be interesting
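As a reference for the bbox-format item above, converting corner coordinates to normalized YOLO coordinates takes only a few lines. A minimal sketch in Python (the function name and arguments are illustrative, not part of the package API):
def corners_to_yolo(x1: float, y1: float, x2: float, y2: float, img_w: int, img_h: int) -> tuple:\n    \"\"\"Convert (x1, y1, x2, y2) corners to normalized YOLO (cx, cy, w, h).\"\"\"\n    cx = (x1 + x2) / 2 / img_w  # box center, relative to image width\n    cy = (y1 + y2) / 2 / img_h  # box center, relative to image height\n    w = (x2 - x1) / img_w  # box width, relative to image width\n    h = (y2 - y1) / img_h  # box height, relative to image height\n    return cx, cy, w, h\n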
"},{"location":"dataset/","title":"Creating a dataset","text":"Wow, such empty
"},{"location":"demo/","title":"Detecting chirps with a few terminal commands","text":"Once everything is set up correctly, detecting chirps is a breeze. The terminal utility can be called by chirpdetector
or simply cpd
.
Simply run
cpd detect --path \"/path/to/dataset\"\n
And the bounding boxes will be computed and saved to a .csv
file. Then run cpd assign --path \"/path/to/dataset\"\n
to assign each detected chirp to a fundamental frequency of a fish. The results will be added to the .csv
file in the dataset. To check if this went well, you can run cpd plot --path \"/path/to/dataset\"\n
And the spectrograms, bounding boxes, and assignments of all detected chirps will be plotted and saved as .png
images into a subfolder of your dataset. The result will look something like this:
15 seconds of a recording containing two chirping fish, with bounding boxes around chirps and dots indicating the frequency to which each chirp is assigned.
"},{"location":"detection/","title":"Detection","text":"Wow, such empty
"},{"location":"how_it_works/","title":"How it works","text":" How? \ud83e\udd14
Chirps manifest as excursions in the electric organ discharge frequency. To discern the individual chirps in a recording featuring multiple fish separated solely by frequency, we delve into the frequency domain. This involves the computation of spectrograms, ensuring ample temporal resolution for chirp distinction and sufficient frequency resolution for fish differentiation. The outcome is a series of images.
This framework facilitates the application of potent computer vision algorithms, such as a Faster R-CNN, for the detection of objects like chirps within these 'images.' Each chirp detection yields a bounding box, a motif echoed in the package's logo.
Post-processing steps refine the results, assigning chirp times to the fundamental frequencies of each fish captured in the recording.
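To make the resolution trade-off concrete, here is a minimal sketch of such a spectrogram computation using scipy (the sample rate, frequency resolution, and overlap are illustrative assumptions, not the package defaults):
import numpy as np\nfrom scipy.signal import spectrogram\n\nsamplerate = 20000  # Hz, assumed grid sampling rate\nfreq_res = 6  # Hz, desired frequency resolution to separate fish\nnfft = int(samplerate / freq_res)  # the frequency resolution fixes the window length\nsig = np.random.randn(10 * samplerate)  # placeholder for one electrode's signal\n\n# a large window overlap preserves enough temporal resolution to separate chirps\nfreqs, times, spec = spectrogram(sig, fs=samplerate, nperseg=nfft, noverlap=int(0.9 * nfft))\n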
Still not sold? Check out the demo \u00bb
"},{"location":"installation/","title":"Installation","text":"Wow, such empty
"},{"location":"labeling/","title":"Labeling a dataset","text":"Wow, such empty
"},{"location":"setup/","title":"Setup","text":"Wow, such empty
"},{"location":"training/","title":"Training","text":"Wow, such empty
"},{"location":"visualization/","title":"Visualization","text":"Wow, such empty
"},{"location":"yolo-helpers/","title":"Helper commands","text":"Wow, such empty
"},{"location":"api/assign_chirps/","title":"assign_chirps","text":"Assign chirps detected on a spectrogram to wavetracker tracks.
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.assign_chirps","title":"assign_chirps(assign_data, chirp_df, data)
","text":"Assign chirps to wavetracker tracks.
This function uses the extracted envelope troughs to assign chirps to tracks. It computes a cost function that is high when the trough prominence is high and the distance to the chirp center is low. For each chirp, the track with the highest cost function value is chosen.
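The selection step boils down to one comparison per chirp. A toy numeric sketch (all values are made up):
import numpy as np\n\nproms = np.array([0.02, 0.08])  # trough prominences for two candidate tracks\ndists = np.array([40.0, 10.0])  # distance of each trough to the chirp center, in samples\ncost = proms / dists**2  # high prominence and small distance win\nprint(np.argmax(cost))  # 1 -> the chirp is assigned to the second track\n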
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.assign_chirps--parameters","title":"Parameters","text":" assign_data
: dict
Dictionary containing the data needed for assignment chirp_df
: pd.dataframe
Dataframe containing the chirp bboxes data
: gridtools.datasets.Dataset
Dataset object containing the data
Source code in chirpdetector/assign_chirps.py
def assign_chirps(\n assign_data: Dict[str, np.ndarray],\n chirp_df: pd.DataFrame,\n data: Dataset,\n) -> None:\n \"\"\"Assign chirps to wavetracker tracks.\n\n This function uses the extracted envelope troughs to assign chirps to\n tracks. It computes a cost function that is high when the trough prominence\n is high and the distance to the chirp center is low. For each chirp, the\n track with the highest cost function value is chosen.\n\n Parameters\n ----------\n - `assign_data`: `dict`\n Dictionary containing the data needed for assignment\n - `chirp_df`: `pd.dataframe`\n Dataframe containing the chirp bboxes\n - `data`: `gridtools.datasets.Dataset`\n Dataset object containing the data\n \"\"\"\n # extract data from assign_data\n peak_prominences = assign_data[\"proms\"]\n peak_distances = assign_data[\"peaks\"]\n peak_times = assign_data[\"ptimes\"]\n chirp_indices = assign_data[\"cindices\"]\n track_ids = assign_data[\"track_ids\"]\n\n # compute cost function.\n # this function is high when the trough prominence is high\n # (-> chirp with high contrast)\n # and when the trough is close to the chirp center as detected by the\n # r-cnn (-> detected chirp is close to the actual chirp)\n cost = peak_prominences / peak_distances**2\n\n # set cost to zero for cases where no peak was found\n cost[np.isnan(cost)] = 0\n\n # for each chirp, choose the track where the cost is highest\n # TODO: to avoid confusion make a cost function where high is good and low\n # is bad. this is more like a \"gain function\"\n chosen_tracks = []\n chosen_track_times = []\n for idx in np.unique(chirp_indices):\n candidate_tracks = track_ids[chirp_indices == idx]\n candidate_costs = cost[chirp_indices == idx]\n candidate_times = peak_times[chirp_indices == idx]\n chosen_tracks.append(candidate_tracks[np.argmax(candidate_costs)])\n chosen_track_times.append(candidate_times[np.argmax(candidate_costs)])\n\n # store chosen tracks in chirp_df\n chirp_df[\"assigned_track\"] = chosen_tracks\n\n # store chirp time estimated from envelope trough in chirp_df\n chirp_df[\"envelope_trough_time\"] = chosen_track_times\n\n # save chirp_df\n chirp_df.to_csv(data.path / \"chirpdetector_bboxes.csv\", index=False)\n\n # save old format:\n np.save(data.path / \"chirp_ids_rcnn.npy\", chosen_tracks)\n np.save(data.path / \"chirp_times_rcnn.npy\", chosen_track_times)\n
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.assign_cli","title":"assign_cli(path)
","text":"Assign chirps to wavetracker tracks.
This is the command line interface for the assign_chirps function.
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.assign_cli--parameters","title":"Parameters","text":" path
: pathlib.path
path to the directory containing the chirpdetector.toml file
Source code in chirpdetector/assign_chirps.py
def assign_cli(path: pathlib.Path) -> None:\n \"\"\"Assign chirps to wavetracker tracks.\n\n this is the command line interface for the assign_chirps function.\n\n Parameters\n ----------\n - `path`: `pathlib.path`\n path to the directory containing the chirpdetector.toml file\n \"\"\"\n if not path.is_dir():\n msg = f\"{path} is not a directory\"\n raise ValueError(msg)\n\n if not (path / \"chirpdetector.toml\").is_file():\n msg = f\"{path} does not contain a chirpdetector.toml file\"\n raise ValueError(msg)\n\n logger = make_logger(__name__, path / \"chirpdetector.log\")\n # config = load_config(path / \"chirpdetector.toml\")\n recs = list(path.iterdir())\n recs = [r for r in recs if r.is_dir()]\n # recs = [path / \"subset_2020-03-18-10_34_t0_9320.0_t1_9920.0\"]\n\n msg = f\"found {len(recs)} recordings in {path}, starting assignment\"\n prog.console.log(msg)\n logger.info(msg)\n\n prog.console.rule(\"starting assignment\")\n with prog:\n task = prog.add_task(\"assigning chirps\", total=len(recs))\n for rec in recs:\n msg = f\"assigning chirps in {rec}\"\n logger.info(msg)\n prog.console.log(msg)\n\n data = load(rec)\n chirp_df = pd.read_csv(rec / \"chirpdetector_bboxes.csv\")\n assign_data, chirp_df, data = extract_assignment_data(\n data, chirp_df\n )\n assign_chirps(assign_data, chirp_df, data)\n prog.update(task, advance=1)\n
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.bbox_to_chirptimes","title":"bbox_to_chirptimes(chirp_df)
","text":"Convert chirp bboxes to chirp times.
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.bbox_to_chirptimes--parameters","title":"Parameters","text":" chirp_df
: pd.dataframe
dataframe containing the chirp bboxes
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.bbox_to_chirptimes--returns","title":"Returns","text":" chirp_df
: pd.dataframe
dataframe containing the chirp bboxes with chirp times.
Source code in chirpdetector/assign_chirps.py
def bbox_to_chirptimes(chirp_df: pd.DataFrame) -> pd.DataFrame:\n \"\"\"Convert chirp bboxes to chirp times.\n\n Parameters\n ----------\n - `chirp_df`: `pd.dataframe`\n dataframe containing the chirp bboxes\n\n Returns\n -------\n - `chirp_df`: `pd.dataframe`\n dataframe containing the chirp bboxes with chirp times.\n \"\"\"\n chirp_df[\"chirp_times\"] = np.mean(chirp_df[[\"t1\", \"t2\"]], axis=1)\n\n return chirp_df\n
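A minimal usage sketch, assuming the function is importable from chirpdetector.assign_chirps as the source path above suggests; the chirp time is simply the center of the bbox time axis:
import pandas as pd\nfrom chirpdetector.assign_chirps import bbox_to_chirptimes\n\nchirp_df = pd.DataFrame({\"t1\": [1.0, 4.0], \"t2\": [1.5, 4.5]})\nprint(bbox_to_chirptimes(chirp_df).chirp_times.to_list())  # [1.25, 4.25]\n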
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.clean_bboxes","title":"clean_bboxes(data, chirp_df)
","text":"Clean the chirp bboxes.
This is a collection of filters that remove bboxes that either overlap, are out of range or otherwise do not make sense.
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.clean_bboxes--parameters","title":"Parameters","text":" data
: gridtools.datasets.Dataset
Dataset object containing the data chirp_df
: pd.dataframe
Dataframe containing the chirp bboxes
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.clean_bboxes--returns","title":"Returns","text":" chirp_df_tf
: pd.dataframe
Dataframe containing the chirp bboxes that overlap with the range
Source code in chirpdetector/assign_chirps.py
def clean_bboxes(data: Dataset, chirp_df: pd.DataFrame) -> pd.DataFrame:\n \"\"\"Clean the chirp bboxes.\n\n This is a collection of filters that remove bboxes that\n either overlap, are out of range or otherwise do not make sense.\n\n Parameters\n ----------\n - `data`: `gridtools.datasets.Dataset`\n Dataset object containing the data\n - `chirp_df`: `pd.dataframe`\n Dataframe containing the chirp bboxes\n\n Returns\n -------\n - `chirp_df_tf`: `pd.dataframe`\n Dataframe containing the chirp bboxes that overlap with the range\n \"\"\"\n # non-max suppression: remove all chirp bboxes that overlap with\n # another more than threshold\n pick_indices = non_max_suppression_fast(chirp_df, 0.5)\n chirp_df_nms = chirp_df.loc[pick_indices, :]\n\n # track filter: remove all chirp bboxes that do not overlap with\n # the range spanned by the min and max of the wavetracker frequency tracks\n minf = np.min(data.track.freqs).astype(float)\n maxf = np.max(data.track.freqs).astype(float)\n # maybe add some more cleaning here, such\n # as removing chirps that are too short or too long\n return track_filter(chirp_df_nms, minf, maxf)\n
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.cleanup","title":"cleanup(chirp_df, data)
","text":"Clean the chirp bboxes.
This is a collection of filters that remove bboxes that either overlap, are out of range or otherwise do not make sense.
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.cleanup--parameters","title":"Parameters","text":" chirp_df
: pd.dataframe
Dataframe containing the chirp bboxes data
: gridtools.datasets.Dataset
Dataset object containing the data
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.cleanup--returns","title":"Returns","text":" chirp_df
: pd.dataframe
Dataframe containing the chirp bboxes that overlap with the range
Source code in chirpdetector/assign_chirps.py
def cleanup(chirp_df: pd.DataFrame, data: Dataset) -> pd.DataFrame:\n \"\"\"Clean the chirp bboxes.\n\n This is a collection of filters that remove bboxes that\n either overlap, are out of range or otherwise do not make sense.\n\n Parameters\n ----------\n - `chirp_df`: `pd.dataframe`\n Dataframe containing the chirp bboxes\n - `data`: `gridtools.datasets.Dataset`\n Dataset object containing the data\n\n Returns\n -------\n - `chirp_df`: `pd.dataframe`\n Dataframe containing the chirp bboxes that overlap with the range\n \"\"\"\n # first clean the bboxes\n chirp_df = clean_bboxes(data, chirp_df)\n # sort chirps in df by time, i.e. t1\n chirp_df = chirp_df.sort_values(by=\"t1\", ascending=True)\n # compute chirp times, i.e. center of the bbox x axis\n return bbox_to_chirptimes(chirp_df)\n
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.extract_assignment_data","title":"extract_assignment_data(data, chirp_df)
","text":"Get envelope troughs to determine chirp assignment.
This algorithm assigns chirps to wavetracker tracks by a series of steps: 1. clean the chirp bboxes 2. for each fish track, filter the signal on the best electrode 3. find troughs in the envelope of the filtered signal 4. compute the prominence of the trough and the distance to the chirp center 5. compute a cost function that is high when the trough prominence is high and the distance to the chirp center is low 6. compare the value of the cost function for each track and choose the track with the highest cost function value
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.extract_assignment_data--parameters","title":"Parameters","text":" data
: dataset
Dataset object containing the data chirp_df
: pd.dataframe
Dataframe containing the chirp bboxes
Source code in chirpdetector/assign_chirps.py
def extract_assignment_data(\n data: Dataset, chirp_df: pd.DataFrame\n) -> Tuple[Dict[str, np.ndarray], pd.DataFrame, Dataset]:\n \"\"\"Get envelope troughs to determine chirp assignment.\n\n This algorigthm assigns chirps to wavetracker tracks by a series of steps:\n 1. clean the chirp bboxes\n 2. for each fish track, filter the signal on the best electrode\n 3. find troughs in the envelope of the filtered signal\n 4. compute the prominence of the trough and the distance to the chirp\n center\n 5. compute a cost function that is high when the trough prominence is high\n and the distance to the chirp center is low\n 6. compare the value of the cost function for each track and choose the\n track with the highest cost function value\n\n Parameters\n ----------\n - `data`: `dataset`\n Dataset object containing the data\n - `chirp_df`: `pd.dataframe`\n Dataframe containing the chirp bboxes\n \"\"\"\n # clean the chirp bboxes\n chirp_df = cleanup(chirp_df, data)\n\n # now loop over all tracks and assign chirps to tracks\n chirp_indices = [] # index of chirp in chirp_df\n track_ids = [] # id of track / fish\n peak_prominences = [] # prominence of trough in envelope\n peak_distances = [] # distance of trough to chirp center\n peak_times = [] # time of trough in envelope, should be close to chirp\n\n for fish_id in data.track.ids:\n # get chirps, times and freqs and powers for this track\n chirps = np.array(chirp_df.chirp_times.values)\n time = data.track.times[\n data.track.indices[data.track.idents == fish_id]\n ]\n freq = data.track.freqs[data.track.idents == fish_id]\n powers = data.track.powers[data.track.idents == fish_id, :]\n\n if len(time) == 0:\n continue # skip if no track is found\n\n for idx, chirp in enumerate(chirps):\n # find the closest time, freq and power to the chirp time\n closest_idx = np.argmin(np.abs(time - chirp))\n best_electrode = np.argmax(powers[closest_idx, :]).astype(int)\n second_best_electrode = np.argsort(powers[closest_idx, :])[-2]\n best_freq = freq[closest_idx]\n\n # check if chirp overlaps with track\n f1 = chirp_df.f1.to_numpy()[idx]\n f2 = chirp_df.f2.to_numpy()[idx]\n f2 = f1 + (f2 - f1) * 0.5 # range is the lower half of the bbox\n if (f1 > best_freq) or (f2 < best_freq):\n peak_distances.append(np.nan)\n peak_prominences.append(np.nan)\n peak_times.append(np.nan)\n chirp_indices.append(idx)\n track_ids.append(fish_id)\n continue\n\n # determine start and stop index of time window on raw data\n # using bounding box start and stop times of chirp detection\n start_idx, stop_idx, center_idx = make_indices(\n chirp_df, data, idx, chirp\n )\n\n indices = (start_idx, stop_idx, center_idx)\n peaks, proms = extract_envelope_trough(\n data,\n best_electrode,\n second_best_electrode,\n best_freq,\n indices,\n )\n\n # if no peaks are found, skip this chirp\n if len(peaks) == 0:\n peak_distances.append(np.nan)\n peak_prominences.append(np.nan)\n peak_times.append(np.nan)\n chirp_indices.append(idx)\n track_ids.append(fish_id)\n continue\n\n # compute index to closest peak to chirp center\n distances = np.abs(peaks - (center_idx - start_idx))\n closest_peak_idx = np.argmin(distances)\n\n # store peak prominence and distance to chirp center\n peak_distances.append(distances[closest_peak_idx])\n peak_prominences.append(proms[closest_peak_idx])\n peak_times.append(\n (start_idx + peaks[closest_peak_idx]) / data.grid.samplerate,\n )\n chirp_indices.append(idx)\n track_ids.append(fish_id)\n\n peak_prominences = np.array(peak_prominences)\n peak_distances = (\n 
np.array(peak_distances) + 1\n ) # add 1 to avoid division by zero\n peak_times = np.array(peak_times)\n chirp_indices = np.array(chirp_indices)\n track_ids = np.array(track_ids)\n\n assignment_data = {\n \"proms\": peak_prominences,\n \"peaks\": peak_distances,\n \"ptimes\": peak_times,\n \"cindices\": chirp_indices,\n \"track_ids\": track_ids,\n }\n return (\n assignment_data,\n chirp_df,\n data,\n )\n
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.extract_envelope_trough","title":"extract_envelope_trough(data, best_electrode, second_best_electrode, best_freq, indices)
","text":"Extract envelope troughs.
Extracts a snippet from the raw data around the chirp time and computes the envelope of the bandpass filtered signal. Then finds the troughs in the envelope and computes their prominences.
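A minimal sketch of this pipeline built from scipy alone (the filter design and the Hilbert-transform envelope are assumptions; the package uses its own bandpass_filter and envelope helpers):
import numpy as np\nfrom scipy.signal import butter, hilbert, sosfiltfilt\n\ndef bandpass(sig: np.ndarray, fs: float, low: float, high: float) -> np.ndarray:\n    sos = butter(4, [low, high], btype=\"bandpass\", fs=fs, output=\"sos\")\n    return sosfiltfilt(sos, sig)\n\nfs = 20000  # Hz, assumed sampling rate\nbest_freq = 800  # Hz, assumed baseline frequency of the fish\nraw = np.random.randn(fs)  # placeholder for the differential electrode signal\n\nfiltered = bandpass(raw, fs, best_freq - 15, best_freq + 15)\nenv = np.abs(hilbert(filtered))  # instantaneous amplitude as envelope estimate\n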
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.extract_envelope_trough--parameters","title":"Parameters","text":" data
: gridtools.datasets.Dataset
Dataset object containing the data best_electrode
: int
Index of the best electrode second_best_electrode
: int
Index of the second best electrode best_freq
: float
Frequency of the chirp indices
: Tuple[int, int, int]
Tuple containing the start, center, stop indices of the chirp
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.extract_envelope_trough--returns","title":"Returns","text":" peaks
: np.ndarray
Indices of the envelope troughs proms
: np.ndarray
Prominences of the envelope troughs
Source code in chirpdetector/assign_chirps.py
def extract_envelope_trough(\n data: Dataset,\n best_electrode: int,\n second_best_electrode: int,\n best_freq: float,\n indices: Tuple[int, int, int],\n) -> Tuple[np.ndarray, np.ndarray]:\n \"\"\"Extract envelope troughs.\n\n Extracts a snippet from the raw data around the chirp time and computes\n the envelope of the bandpass filtered signal. Then finds the troughs in\n the envelope and computes their prominences.\n\n Parameters\n ----------\n - `data`: `gridtools.datasets.Dataset`\n Dataset object containing the data\n - `best_electrode`: `int`\n Index of the best electrode\n - `second_best_electrode`: `int`\n Index of the second best electrode\n - `best_freq`: `float`\n Frequency of the chirp\n - `indices`: `Tuple[int, int, int]`\n Tuple containing the start, center, stop indices of the chirp\n\n Returns\n -------\n - `peaks`: `np.ndarray`\n Indices of the envelope troughs\n - `proms`: `np.ndarray`\n Prominences of the envelope troughs\n \"\"\"\n start_idx, stop_idx, _= indices\n\n # determine bandpass cutoffs above and below baseline frequency\n lower_f = best_freq - 15\n upper_f = best_freq + 15\n\n # get the raw signal on the 2 best electrodes and make differential\n raw1 = data.grid.rec[start_idx:stop_idx, best_electrode]\n raw2 = data.grid.rec[start_idx:stop_idx, second_best_electrode]\n raw = raw1 - raw2\n\n # bandpass filter the raw signal\n raw_filtered = bandpass_filter(\n raw,\n data.grid.samplerate,\n lower_f,\n upper_f,\n )\n\n # compute the envelope of the filtered signal\n env = envelope(\n signal=raw_filtered,\n samplerate=data.grid.samplerate,\n cutoff_frequency=50,\n )\n peaks, proms = get_env_trough(env, raw_filtered)\n # mpl.use(\"TkAgg\")\n # plt.plot(env)\n # plt.plot(raw_filtered)\n # plt.plot(peaks, env[peaks], \"x\")\n # plt.show()\n return peaks, proms\n
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.get_env_trough","title":"get_env_trough(env, raw)
","text":"Get the envelope troughs and their prominences.
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.get_env_trough--parameters","title":"Parameters","text":" env
: np.ndarray
Envelope of the filtered signal raw
: np.ndarray
Raw signal
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.get_env_trough--returns","title":"Returns","text":" peaks
: np.ndarray
Indices of the envelope troughs proms
: np.ndarray
Prominences of the envelope troughs
Source code in chirpdetector/assign_chirps.py
def get_env_trough(\n env: np.ndarray,\n raw: np.ndarray,\n) -> Tuple[np.ndarray, np.ndarray]:\n \"\"\"Get the envelope troughs and their prominences.\n\n Parameters\n ----------\n - `env`: `np.ndarray`\n Envelope of the filtered signal\n - `raw`: `np.ndarray`\n Raw signal\n\n Returns\n -------\n - `peaks`: `np.ndarray`\n Indices of the envelope troughs\n - `proms`: `np.ndarray`\n Prominences of the envelope troughs\n \"\"\"\n # normalize the envelope using the amplitude of the raw signal\n # to preserve the amplitude of the envelope\n env = env / np.max(np.abs(raw))\n\n # cut off the first and last 25% of the envelope\n env[: int(0.25 * len(env))] = np.nan\n env[int(0.75 * len(env)) :] = np.nan\n\n # find troughs in the envelope and compute trough prominences\n peaks, params = find_peaks(-env, prominence=1e-3)\n proms = params[\"prominences\"]\n return peaks, proms\n
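Since scipy's find_peaks only detects maxima, troughs are located by negating the signal, exactly as in the source above. A tiny standalone illustration:
import numpy as np\nfrom scipy.signal import find_peaks\n\nenv = np.array([1.0, 0.9, 0.3, 0.9, 1.0, 0.95, 0.5, 0.95, 1.0])\npeaks, params = find_peaks(-env, prominence=1e-3)  # troughs of env\nprint(peaks)  # [2 6]\nprint(params[\"prominences\"])  # [0.7 0.5]\n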
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.make_indices","title":"make_indices(chirp_df, data, idx, chirp)
","text":"Make indices for the chirp window.
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.make_indices--parameters","title":"Parameters","text":" chirp_df
: pd.dataframe
Dataframe containing the chirp bboxes data
: gridtools.datasets.Dataset
Dataset object containing the data idx
: int
Index of the chirp in the chirp_df chirp
: float
Chirp time
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.make_indices--returns","title":"Returns","text":" start_idx
: int
Start index of the chirp window stop_idx
: int
Stop index of the chirp window center_idx
: int
Center index of the chirp window
Source code in chirpdetector/assign_chirps.py
def make_indices(\n chirp_df: pd.DataFrame, data: Dataset, idx: int, chirp: float\n) -> Tuple[int, int, int]:\n \"\"\"Make indices for the chirp window.\n\n Parameters\n ----------\n - `chirp_df`: `pd.dataframe`\n Dataframe containing the chirp bboxes\n - `data`: `gridtools.datasets.Dataset`\n Dataset object containing the data\n - `idx`: `int`\n Index of the chirp in the chirp_df\n - `chirp`: `float`\n Chirp time\n\n Returns\n -------\n - `start_idx`: `int`\n Start index of the chirp window\n - `stop_idx`: `int`\n Stop index of the chirp window\n - `center_idx`: `int`\n Center index of the chirp window\n \"\"\"\n # determine start and stop index of time window on raw data\n # using bounding box start and stop times of chirp detection\n diffr = chirp_df.t2.to_numpy()[idx] - chirp_df.t1.to_numpy()[idx]\n t1 = chirp_df.t1.to_numpy()[idx] - 0.5 * diffr\n t2 = chirp_df.t2.to_numpy()[idx] + 0.5 * diffr\n\n start_idx = int(np.round(t1 * data.grid.samplerate))\n stop_idx = int(np.round(t2 * data.grid.samplerate))\n center_idx = int(np.round(chirp * data.grid.samplerate))\n\n return start_idx, stop_idx, center_idx\n
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.non_max_suppression_fast","title":"non_max_suppression_fast(chirp_df, overlapthresh)
","text":"Raster implementation of non-maximum suppression.
To remove overlapping bounding boxes.
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.non_max_suppression_fast--parameters","title":"Parameters","text":" chirp_df
: pd.dataframe
Dataframe containing the chirp bboxes overlapthresh
: float
Threshold for overlap between bboxes
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.non_max_suppression_fast--returns","title":"Returns","text":" pick
: list
List of indices of bboxes to keep
Source code in chirpdetector/assign_chirps.py
def non_max_suppression_fast(\n chirp_df: pd.DataFrame,\n overlapthresh: float,\n) -> list:\n \"\"\"Raster implementation of non-maximum suppression.\n\n To remove overlapping bounding boxes.\n\n Parameters\n ----------\n - `chirp_df`: `pd.dataframe`\n Dataframe containing the chirp bboxes\n - `overlapthresh`: `float`\n Threshold for overlap between bboxes\n\n Returns\n -------\n - `pick`: `list`\n List of indices of bboxes to keep\n \"\"\"\n # slightly modified version of\n # https://pyimagesearch.com/2015/02/16/faster-non-maximum-suppression-python/\n\n # convert boxes to list of tuples and then to numpy array\n boxes = chirp_df[[\"t1\", \"f1\", \"t2\", \"f2\"]].to_numpy()\n\n # if there are no boxes, return an empty list\n if len(boxes) == 0:\n return []\n\n # initialize the list of picked indexes\n pick = []\n\n # grab the coordinates of the bounding boxes\n x1 = boxes[:, 0]\n y1 = boxes[:, 1]\n x2 = boxes[:, 2]\n y2 = boxes[:, 3]\n\n # compute the area of the bounding boxes and sort the bounding\n # boxes by the bottom-right y-coordinate of the bounding box\n area = (x2 - x1) * (y2 - y1)\n idxs = np.argsort(y2)\n\n # keep looping while some indexes still remain in the indexes\n # list\n while len(idxs) > 0:\n # grab the last index in the indexes list and add the\n # index value to the list of picked indexes\n last = len(idxs) - 1\n i = idxs[last]\n pick.append(i)\n\n # find the largest (x, y) coordinates for the start of\n # the bounding box and the smallest (x, y) coordinates\n # for the end of the bounding box\n xx1 = np.maximum(x1[i], x1[idxs[:last]])\n yy1 = np.maximum(y1[i], y1[idxs[:last]])\n xx2 = np.minimum(x2[i], x2[idxs[:last]])\n yy2 = np.minimum(y2[i], y2[idxs[:last]])\n\n # compute the width and height of the bounding box\n w = np.maximum(0, xx2 - xx1)\n h = np.maximum(0, yy2 - yy1)\n\n # compute the ratio of overlap (intersection over union)\n overlap = (w * h) / area[idxs[:last]]\n\n # delete all indexes from the index list that have\n idxs = np.delete(\n idxs,\n np.concatenate(([last], np.where(overlap > overlapthresh)[0])),\n )\n # return the indicies of the picked boxes\n return pick\n
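A hypothetical usage example on a toy dataframe, assuming the function is importable from chirpdetector.assign_chirps; the first two boxes overlap heavily, so only one of them survives:
import pandas as pd\nfrom chirpdetector.assign_chirps import non_max_suppression_fast\n\nchirp_df = pd.DataFrame({\n    \"t1\": [0.0, 0.1, 5.0],\n    \"f1\": [600, 610, 700],\n    \"t2\": [1.0, 1.1, 6.0],\n    \"f2\": [700, 710, 800],\n})\nprint(non_max_suppression_fast(chirp_df, overlapthresh=0.5))  # keeps the distant box and one of the overlapping pair\n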
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.track_filter","title":"track_filter(chirp_df, minf, maxf)
","text":"Remove chirp bboxes that do not overlap with tracks.
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.track_filter--parameters","title":"Parameters","text":" chirp_df
: pd.dataframe
Dataframe containing the chirp bboxes minf
: float
Minimum frequency of the range maxf
: float
Maximum frequency of the range
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.track_filter--returns","title":"Returns","text":" chirp_df_tf
: pd.dataframe
Dataframe containing the chirp bboxes that overlap with the range
Source code in chirpdetector/assign_chirps.py
def track_filter(\n chirp_df: pd.DataFrame,\n minf: float,\n maxf: float,\n) -> pd.DataFrame:\n \"\"\"Remove chirp bboxes that do not overlap with tracks.\n\n Parameters\n ----------\n - `chirp_df`: `pd.dataframe`\n Dataframe containing the chirp bboxes\n - `minf`: `float`\n Minimum frequency of the range\n - `maxf`: `float`\n Maximum frequency of the range\n\n Returns\n -------\n - `chirp_df_tf`: `pd.dataframe`\n Dataframe containing the chirp bboxes that overlap with the range\n \"\"\"\n # remove all chirp bboxes that have no overlap with the range spanned by\n # minf and maxf\n\n # first build a box that spans the entire range\n range_box = np.array([0, minf, np.max(chirp_df.t2), maxf])\n\n # now compute the intersection between the range box and each chirp bboxes\n # and keep only those that have an intersection area > 0\n chirp_df_tf = chirp_df.copy()\n intersection = chirp_df_tf.apply(\n lambda row: (\n max(0, min(row[\"t2\"], range_box[2]) - max(row[\"t1\"], range_box[0]))\n * max(\n 0,\n min(row[\"f2\"], range_box[3]) - max(row[\"f1\"], range_box[1]),\n )\n ),\n axis=1,\n )\n return chirp_df_tf.loc[intersection > 0, :]\n
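The overlap test reduces to the standard rectangle-intersection area. A toy check with made-up numbers:
t_overlap = max(0, min(2.0, 10.0) - max(1.0, 0.0))  # 1.0 s of overlap in time\nf_overlap = max(0, min(650, 900) - max(600, 500))  # 50 Hz of overlap in frequency\nprint(t_overlap * f_overlap > 0)  # True -> this bbox is kept\n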
"},{"location":"api/convert_data/","title":"convert_data","text":"Functions and classes for converting data.
"},{"location":"api/convert_data/#chirpdetector.convert_data.chirp_bounding_boxes","title":"chirp_bounding_boxes(data, nfft)
","text":"Make bounding boxes of simulated chirps using the chirp parameters.
"},{"location":"api/convert_data/#chirpdetector.convert_data.chirp_bounding_boxes--parameters","title":"Parameters","text":" data
: Dataset
The dataset to make bounding boxes for. nfft
: int The number of samples in the FFT.
"},{"location":"api/convert_data/#chirpdetector.convert_data.chirp_bounding_boxes--returns","title":"Returns","text":"pandas.DataFrame
A dataframe with the bounding boxes.
Source code in chirpdetector/convert_data.py
def chirp_bounding_boxes(data: Dataset, nfft: int) -> pd.DataFrame:\n \"\"\"Make bounding boxes of simulated chirps using the chirp parameters.\n\n Parameters\n ----------\n - `data` : `Dataset`\n The dataset to make bounding boxes for.\n - `nfft` : int\n The number of samples in the FFT.\n\n Returns\n -------\n `pandas.DataFrame`\n A dataframe with the bounding boxes.\n \"\"\"\n assert hasattr(\n data.com.chirp,\n \"params\",\n ), \"Dataset must have a chirp attribute with a params attribute\"\n\n # Time padding is one NFFT window\n pad_time = nfft / data.grid.samplerate\n\n # Freq padding is fixed by the frequency resolution\n freq_res = data.grid.samplerate / nfft\n pad_freq = freq_res * 50\n\n boxes = []\n ids = []\n for fish_id in data.track.ids:\n freqs = data.track.freqs[data.track.idents == fish_id]\n times = data.track.times[\n data.track.indices[data.track.idents == fish_id]\n ]\n chirps = data.com.chirp.times[data.com.chirp.idents == fish_id]\n params = data.com.chirp.params[data.com.chirp.idents == fish_id]\n\n for chirp, param in zip(chirps, params):\n # take the two closest frequency points\n f_closest = freqs[np.argsort(np.abs(times - chirp))[:2]]\n\n # take the two closest time points\n t_closest = times[np.argsort(np.abs(times - chirp))[:2]]\n\n # compute the weighted average of the two closest frequency points\n # using the dt between chirp time and sampled time as weights\n f_closest = np.average(\n f_closest,\n weights=np.abs(t_closest - chirp),\n )\n\n # we now have baseline eodf and time point of the chirp. Now\n # we get some parameters from the params to build the bounding box\n # for the chirp\n height = param[1]\n width = param[2]\n\n # now define bounding box as center coordinates, width and height\n t_center = chirp\n f_center = f_closest + height / 2\n\n bbox_height = height + pad_freq\n bbox_width = width + pad_time\n\n boxes.append((t_center, f_center, bbox_width, bbox_height))\n ids.append(fish_id)\n\n dataframe = pd.DataFrame(\n boxes,\n columns=[\"t_center\", \"f_center\", \"width\", \"height\"],\n )\n dataframe[\"fish_id\"] = ids\n return dataframe\n
"},{"location":"api/convert_data/#chirpdetector.convert_data.convert","title":"convert(data, conf, output, label_mode)
","text":"Convert a gridtools dataset to a YOLO dataset.
"},{"location":"api/convert_data/#chirpdetector.convert_data.convert--parameters","title":"Parameters","text":" data
: Dataset
The dataset to convert. conf
: Config
The configuration. output
: pathlib.Path
The output directory. label_mode
: str
The label mode. Can be one of 'none', 'synthetic' or 'detected'.
"},{"location":"api/convert_data/#chirpdetector.convert_data.convert--returns","title":"Returns","text":""},{"location":"api/convert_data/#chirpdetector.convert_data.convert--notes","title":"Notes","text":"This function iterates through a raw recording in chunks and computes the sum spectrogram of each chunk. The chunk size needs to be chosen such that the images can be nicely fed to a detector. The function also computes the bounding boxes of chirps in that chunk and saves them to a dataframe and a txt file into a labels directory.
Source code in chirpdetector/convert_data.py
def convert(\n data: Dataset,\n conf: Config,\n output: pathlib.Path,\n label_mode: str,\n) -> None:\n \"\"\"Convert a gridtools dataset to a YOLO dataset.\n\n Parameters\n ----------\n - `data` : `Dataset`\n The dataset to convert.\n - `conf` : `Config`\n The configuration.\n - `output` : `pathlib.Path`\n The output directory.\n - `label_mode` : `str`\n The label mode. Can be one of 'none', 'synthetic' or 'detected'.\n\n Returns\n -------\n - `None`\n\n Notes\n -----\n This function iterates through a raw recording in chunks and computes the\n sum spectrogram of each chunk. The chunk size needs to be chosen such that\n the images can be nicely fed to a detector. The function also computes\n the bounding boxes of chirps in that chunk and saves them to a dataframe\n and a txt file into a labels directory.\n \"\"\"\n assert hasattr(data, \"grid\"), \"Dataset must have a grid attribute\"\n assert label_mode in [\n \"none\",\n \"synthetic\",\n \"detected\",\n ], \"label_mode must be one of 'none', 'synthetic' or 'detected'\"\n\n dataroot = output\n\n n_electrodes = data.grid.rec.shape[1]\n\n # How much time to put into each spectrogram\n time_window = conf.spec.time_window # seconds\n window_overlap = conf.spec.spec_overlap # seconds\n freq_pad = conf.spec.freq_pad # Hz\n window_overlap_samples = window_overlap * data.grid.samplerate # samples\n\n # Spectrogram computation parameters\n nfft = freqres_to_nfft(conf.spec.freq_res, data.grid.samplerate) # samples\n hop_len = overlap_to_hoplen(conf.spec.overlap_frac, nfft) # samples\n chunksize = time_window * data.grid.samplerate # samples\n n_chunks = np.ceil(data.grid.rec.shape[0] / chunksize).astype(int)\n\n rprint(\n \"Dividing recording of duration\"\n f\"{data.grid.rec.shape[0] / data.grid.samplerate} into {n_chunks}\"\n f\"chunks of {time_window} seconds each.\",\n )\n\n bbox_dfs = []\n\n # shift the time of the tracks to start at 0\n # because a subset starts at the orignal time\n # TODO: Remove this when gridtools is fixed\n data.track.times -= data.track.times[0]\n\n for chunk_no in range(n_chunks):\n # get start and stop indices for the current chunk\n # including some overlap to compensate for edge effects\n # this diffrers for the first and last chunk\n\n if chunk_no == 0:\n idx1 = sint(chunk_no * chunksize)\n idx2 = sint((chunk_no + 1) * chunksize + window_overlap_samples)\n elif chunk_no == n_chunks - 1:\n idx1 = sint(chunk_no * chunksize - window_overlap_samples)\n idx2 = sint((chunk_no + 1) * chunksize)\n else:\n idx1 = sint(chunk_no * chunksize - window_overlap_samples)\n idx2 = sint((chunk_no + 1) * chunksize + window_overlap_samples)\n\n # idx1 and idx2 now determine the window I cut out of the raw signal\n # to compute the spectrogram of.\n\n # compute the time and frequency axes of the spectrogram now that we\n # include the start and stop indices of the current chunk and thus the\n # right start and stop time. 
The `spectrogram` function does not know\n # about this and would start every time axis at 0.\n spec_times = np.arange(idx1, idx2 + 1, hop_len) / data.grid.samplerate\n spec_freqs = np.arange(0, nfft / 2 + 1) * data.grid.samplerate / nfft\n\n # create a subset from the grid dataset\n if idx2 > data.grid.rec.shape[0]:\n idx2 = data.grid.rec.shape[0] - 1\n\n chunk = subset(data, idx1, idx2, mode=\"index\")\n\n # compute the spectrogram for each electrode of the current chunk\n spec = None\n for el in range(n_electrodes):\n # get the signal for the current electrode\n sig = chunk.grid.rec[:, el]\n\n # compute the spectrogram for the current electrode\n chunk_spec, _, _ = spectrogram(\n data=sig.copy(),\n samplingrate=data.grid.samplerate,\n nfft=nfft,\n hop_length=hop_len,\n )\n\n # sum spectrogram over all electrodes\n # the spec is a tensor\n if el == 0:\n spec = chunk_spec\n else:\n spec += chunk_spec\n\n if spec is None:\n msg = \"Failed to compute spectrogram.\"\n raise ValueError(msg)\n\n # normalize spectrogram by the number of electrodes\n # the spec is still a tensor\n spec /= n_electrodes\n\n # convert the spectrogram to dB\n # .. still a tensor\n spec = decibel(spec)\n\n # cut off everything outside the upper frequency limit\n # the spec is still a tensor\n\n spectrogram_freq_limits = (\n np.min(chunk.track.freqs) - freq_pad,\n np.max(chunk.track.freqs) + freq_pad,\n )\n\n spec = spec[\n (spec_freqs >= spectrogram_freq_limits[0])\n & (spec_freqs <= spectrogram_freq_limits[1]),\n :,\n ]\n spec_freqs = spec_freqs[\n (spec_freqs >= spectrogram_freq_limits[0])\n & (spec_freqs <= spectrogram_freq_limits[1])\n ]\n\n # normalize the spectrogram to zero mean and unit variance\n # the spec is still a tensor\n spec = (spec - spec.mean()) / spec.std()\n\n # convert the spectrogram to a PIL image\n spec = spec.detach().cpu().numpy()\n img = numpy_to_pil(spec)\n\n imgname = f\"{chunk.path.name}.png\"\n if label_mode == \"synthetic\":\n bbox_df, img = synthetic_labels(\n dataroot,\n chunk,\n nfft,\n spec,\n spec_times,\n spec_freqs,\n imgname,\n chunk_no,\n img,\n )\n if bbox_df is None:\n continue\n bbox_dfs.append(bbox_df)\n elif label_mode == \"detected\":\n detected_labels(dataroot, chunk, imgname, spec, spec_times)\n\n # save image\n img.save(dataroot / \"images\" / f\"{imgname}\")\n\n if label_mode == \"synthetic\":\n bbox_df = pd.concat(bbox_dfs, ignore_index=True)\n bbox_df.to_csv(dataroot / f\"{data.path.name}_bboxes.csv\", index=False)\n\n # save the classes.txt file\n classes = [\"__background__\", \"chirp\"]\n with pathlib.Path.open(dataroot / \"classes.txt\", \"w\") as f:\n f.write(\"\\n\".join(classes))\n
"},{"location":"api/convert_data/#chirpdetector.convert_data.convert_cli","title":"convert_cli(path, output, label_mode)
","text":"Parse all datasets in a directory and convert them to a YOLO dataset.
"},{"location":"api/convert_data/#chirpdetector.convert_data.convert_cli--parameters","title":"Parameters","text":" path
: pathlib.Path
The root directory of the datasets.
"},{"location":"api/convert_data/#chirpdetector.convert_data.convert_cli--returns","title":"Returns","text":" Source code in chirpdetector/convert_data.py
def convert_cli(\n path: pathlib.Path,\n output: pathlib.Path,\n label_mode: str,\n) -> None:\n \"\"\"Parse all datasets in a directory and convert them to a YOLO dataset.\n\n Parameters\n ----------\n - `path` : `pathlib.Path`\n The root directory of the datasets.\n\n Returns\n -------\n - `None`\n \"\"\"\n make_file_tree(output)\n config = load_config(str(path / \"chirpdetector.toml\"))\n\n for p in track(list(path.iterdir()), description=\"Building datasets\"):\n if p.is_file():\n continue\n data = load(p)\n convert(data, config, output, label_mode)\n
"},{"location":"api/convert_data/#chirpdetector.convert_data.detected_labels","title":"detected_labels(output, chunk, imgname, spec, spec_times)
","text":"Use the detect_chirps to make a YOLO dataset.
"},{"location":"api/convert_data/#chirpdetector.convert_data.detected_labels--parameters","title":"Parameters","text":" output
: pathlib.Path
The output directory. chunk
: Dataset
The dataset to make bounding boxes for. imgname
: str
The name of the image. spec
: np.ndarray
The spectrogram. spec_times
: np.ndarray
The time axis of the spectrogram.
"},{"location":"api/convert_data/#chirpdetector.convert_data.detected_labels--returns","title":"Returns","text":" Source code in chirpdetector/convert_data.py
def detected_labels(\n output: pathlib.Path,\n chunk: Dataset,\n imgname: str,\n spec: np.ndarray,\n spec_times: np.ndarray,\n) -> None:\n \"\"\"Use the detect_chirps to make a YOLO dataset.\n\n Parameters\n ----------\n - `output` : `pathlib.Path`\n The output directory.\n - `chunk` : `Dataset`\n The dataset to make bounding boxes for.\n - `imgname` : `str`\n The name of the image.\n - `spec` : `np.ndarray`\n The spectrogram.\n - `spec_times` : `np.ndarray`\n The time axis of the spectrogram.\n\n Returns\n -------\n - `None`\n \"\"\"\n # load the detected bboxes csv\n # TODO: This is a workaround. Instead improve the subset naming convention\n # in gridtools\n source_dataset = chunk.path.name.split(\"_\")[1:-4]\n source_dataset = \"_\".join(source_dataset)\n source_dataset = chunk.path.parent / source_dataset\n\n dataframe = pd.read_csv(source_dataset / \"chirpdetector_bboxes.csv\")\n\n # get chunk start and stop time\n start, stop = spec_times[0], spec_times[-1]\n\n # get the bboxes for this chunk\n bboxes = dataframe[(dataframe.t1 >= start) & (dataframe.t2 <= stop)]\n\n # get the x and y coordinates of the bboxes in pixels as dataframe\n bboxes_xy = bboxes[[\"x1\", \"y1\", \"x2\", \"y2\"]]\n\n # convert from x1, y1, x2, y2 to centerx, centery, width, height\n centerx = np.array((bboxes_xy[\"x1\"] + bboxes_xy[\"x2\"]) / 2)\n centery = np.array((bboxes_xy[\"y1\"] + bboxes_xy[\"y2\"]) / 2)\n width = np.array(bboxes_xy[\"x2\"] - bboxes_xy[\"x1\"])\n height = np.array(bboxes_xy[\"y2\"] - bboxes_xy[\"y1\"])\n\n # flip centery because origin is top left\n centery = spec.shape[0] - centery\n\n # make relative to image size\n centerx = centerx / spec.shape[1]\n centery = centery / spec.shape[0]\n width = width / spec.shape[1]\n height = height / spec.shape[0]\n labels = np.ones_like(centerx, dtype=int)\n\n # make a new dataframe with the relative coordinates\n new_bboxes = pd.DataFrame(\n {\"l\": labels, \"x\": centerx, \"y\": centery, \"w\": width, \"h\": height},\n )\n\n # save dataframe for every spec without headers as txt\n new_bboxes.to_csv(\n output / \"labels\" / f\"{imgname[:-4]}.txt\",\n header=False,\n index=False,\n sep=\" \",\n )\n
"},{"location":"api/convert_data/#chirpdetector.convert_data.make_file_tree","title":"make_file_tree(path)
","text":"Build a file tree for the training dataset.
"},{"location":"api/convert_data/#chirpdetector.convert_data.make_file_tree--parameters","title":"Parameters","text":"path : pathlib.Path The root directory of the dataset.
Source code in chirpdetector/convert_data.py
def make_file_tree(path: pathlib.Path) -> None:\n \"\"\"Build a file tree for the training dataset.\n\n Parameters\n ----------\n path : pathlib.Path\n The root directory of the dataset.\n \"\"\"\n if path.parent.exists() and path.parent.is_file():\n msg = (\n f\"Parent directory of {path} is a file. \"\n \"Please specify a directory.\"\n )\n raise ValueError(msg)\n\n if path.exists():\n shutil.rmtree(path)\n\n path.mkdir(exist_ok=True, parents=True)\n\n train_imgs = path / \"images\"\n train_labels = path / \"labels\"\n train_imgs.mkdir(exist_ok=True, parents=True)\n train_labels.mkdir(exist_ok=True, parents=True)\n
"},{"location":"api/convert_data/#chirpdetector.convert_data.numpy_to_pil","title":"numpy_to_pil(img)
","text":"Convert a 2D numpy array to a PIL image.
"},{"location":"api/convert_data/#chirpdetector.convert_data.numpy_to_pil--parameters","title":"Parameters","text":"img : np.ndarray The input image.
"},{"location":"api/convert_data/#chirpdetector.convert_data.numpy_to_pil--returns","title":"Returns","text":"PIL.Image The converted image.
Source code in chirpdetector/convert_data.py
def numpy_to_pil(img: np.ndarray) -> Image.Image:\n \"\"\"Convert a 2D numpy array to a PIL image.\n\n Parameters\n ----------\n img : np.ndarray\n The input image.\n\n Returns\n -------\n PIL.Image\n The converted image.\n \"\"\"\n img_dimens = 2\n if len(img.shape) != img_dimens:\n msg = f\"Image must be {img_dimens}D\"\n raise ValueError(msg)\n\n if img.max() == img.min():\n msg = \"Image must have more than one value\"\n raise ValueError(msg)\n\n img = np.flipud(img)\n intimg = np.uint8((img - img.min()) / (img.max() - img.min()) * 255)\n return Image.fromarray(intimg)\n
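A hypothetical usage example, assuming the function is importable from chirpdetector.convert_data; any 2D float array with more than one distinct value becomes an 8-bit grayscale image:
import numpy as np\nfrom chirpdetector.convert_data import numpy_to_pil\n\nspec = np.random.randn(128, 256)  # fake spectrogram, frequency x time\nimg = numpy_to_pil(spec)\nprint(img.size, img.mode)  # (256, 128) L\n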
"},{"location":"api/convert_data/#chirpdetector.convert_data.synthetic_labels","title":"synthetic_labels(output, chunk, nfft, spec, spec_times, spec_freqs, imgname, chunk_no, img)
","text":"Generate labels of a simulated dataset.
"},{"location":"api/convert_data/#chirpdetector.convert_data.synthetic_labels--parameters","title":"Parameters","text":" output
: pathlib.Path
The output directory. chunk
: Dataset
The dataset to make bounding boxes for. nfft
: int
The number of samples in the FFT. spec
: np.ndarray
The spectrogram. spec_times
: np.ndarray
The time axis of the spectrogram. spec_freqs
: np.ndarray
The frequency axis of the spectrogram. imgname
: str
The name of the image. chunk_no
: int
The chunk number. img
: Image
The image.
"},{"location":"api/convert_data/#chirpdetector.convert_data.synthetic_labels--returns","title":"Returns","text":" pandas.DataFrame
A dataframe with the bounding boxes.
Source code in chirpdetector/convert_data.py
def synthetic_labels(\n output: pathlib.Path,\n chunk: Dataset,\n nfft: int,\n spec: np.ndarray,\n spec_times: np.ndarray,\n spec_freqs: np.ndarray,\n imgname: str,\n chunk_no: int,\n img: Image.Image,\n) -> Union[Tuple[pd.DataFrame, Image.Image], Tuple[None, None]]:\n \"\"\"Generate labels of a simulated dataset.\n\n Parameters\n ----------\n - `output` : `pathlib.Path`\n The output directory.\n - `chunk` : `Dataset`\n The dataset to make bounding boxes for.\n - `nfft` : `int`\n The number of samples in the FFT.\n - `spec` : `np.ndarray`\n The spectrogram.\n - `spec_times` : `np.ndarray`\n The time axis of the spectrogram.\n - `spec_freqs` : `np.ndarray`\n The frequency axis of the spectrogram.\n - `imgname` : `str`\n The name of the image.\n - `chunk_no` : `int`\n The chunk number.\n - `img` : `Image`\n The image.\n\n Returns\n -------\n - `pandas.DataFrame`\n A dataframe with the bounding boxes.\n \"\"\"\n # compute the bounding boxes for this chunk\n bboxes = chirp_bounding_boxes(chunk, nfft)\n\n if len(bboxes) == 0:\n return None, None\n\n # convert bounding box center coordinates to spectrogram coordinates\n # find the indices on the spec_times corresponding to the center times\n x = np.searchsorted(spec_times, bboxes.t_center)\n y = np.searchsorted(spec_freqs, bboxes.f_center)\n widths = np.searchsorted(spec_times - spec_times[0], bboxes.width)\n heights = np.searchsorted(spec_freqs - spec_freqs[0], bboxes.height)\n\n # now we have center coordinates, widths and heights in indices. But PIL\n # expects coordinates in pixels in the format\n # (Upper left x coordinate, upper left y coordinate,\n # lower right x coordinate, lower right y coordinate)\n # In addiotion, an image starts in the top left corner so the bboxes\n # need to be mirrored horizontally.\n\n y = spec.shape[0] - y # flip the y values to fit y=0 at the top\n lxs, lys = x - widths / 2, y - heights / 2\n rxs, rys = x + widths / 2, y + heights / 2\n\n # add them to the bboxes dataframe\n bboxes[\"upperleft_img_x\"] = lxs\n bboxes[\"upperleft_img_y\"] = lys\n bboxes[\"lowerright_img_x\"] = rxs\n bboxes[\"lowerright_img_y\"] = rys\n\n # yolo format is centerx, centery, width, height\n # convert xmin, ymin, xmax, ymax to centerx, centery, width, height\n centerx = (lxs + rxs) / 2\n centery = (lys + rys) / 2\n width = rxs - lxs\n height = rys - lys\n\n # most deep learning frameworks expect bounding box coordinates\n # as relative to the image size. So we normalize the coordinates\n # to the image size\n centerx_norm = centerx / spec.shape[1]\n centery_norm = centery / spec.shape[0]\n width_norm = width / spec.shape[1]\n height_norm = height / spec.shape[0]\n\n # add them to the bboxes dataframe\n bboxes[\"centerx_norm\"] = centerx_norm\n bboxes[\"centery_norm\"] = centery_norm\n bboxes[\"width_norm\"] = width_norm\n bboxes[\"height_norm\"] = height_norm\n\n # add chunk ID to the bboxes dataframe\n bboxes[\"chunk_id\"] = chunk_no\n\n # put them into a dataframe to save for eahc spectrogram\n dataframe = pd.DataFrame(\n {\n \"cx\": centerx_norm,\n \"cy\": centery_norm,\n \"w\": width_norm,\n \"h\": height_norm,\n },\n )\n\n # add as first colum instance id\n dataframe.insert(0, \"instance_id\", np.ones_like(lxs, dtype=int))\n\n # stash the bboxes dataframe for this chunk\n bboxes[\"image\"] = imgname\n\n # save dataframe for every spec without headers as txt\n dataframe.to_csv(\n output / \"labels\" / f\"{chunk.path.name}.txt\",\n header=False,\n index=False,\n sep=\" \",\n )\n return bboxes, img\n
"},{"location":"api/dataset_utils/","title":"dataset_utils","text":"Utility functions for training datasets in the YOLO format.
"},{"location":"api/dataset_utils/#chirpdetector.dataset_utils.clean_yolo_dataset","title":"clean_yolo_dataset(path, img_ext)
","text":"Remove images and labels when the label file is empty.
"},{"location":"api/dataset_utils/#chirpdetector.dataset_utils.clean_yolo_dataset--parameters","title":"Parameters","text":"path : pathlib.Path The path to the dataset. img_ext : str
"},{"location":"api/dataset_utils/#chirpdetector.dataset_utils.clean_yolo_dataset--returns","title":"Returns","text":"None
Source code in chirpdetector/dataset_utils.py
def clean_yolo_dataset(path: pathlib.Path, img_ext: str) -> None:\n \"\"\"Remove images and labels when the label file is empty.\n\n Parameters\n ----------\n path : pathlib.Path\n The path to the dataset.\n img_ext : str\n\n Returns\n -------\n None\n \"\"\"\n img_path = path / \"images\"\n lbl_path = path / \"labels\"\n\n images = list(img_path.glob(f\"*{img_ext}\"))\n\n for image in images:\n lbl = lbl_path / f\"{image.stem}.txt\"\n if lbl.stat().st_size == 0:\n image.unlink()\n lbl.unlink()\n
"},{"location":"api/dataset_utils/#chirpdetector.dataset_utils.load_img","title":"load_img(path)
","text":"Load an image from a path as a numpy array.
"},{"location":"api/dataset_utils/#chirpdetector.dataset_utils.load_img--parameters","title":"Parameters","text":"path : pathlib.Path The path to the image.
"},{"location":"api/dataset_utils/#chirpdetector.dataset_utils.load_img--returns","title":"Returns","text":"img : np.ndarray The image as a numpy array.
Source code in chirpdetector/dataset_utils.py
def load_img(path: pathlib.Path) -> np.ndarray:\n \"\"\"Load an image from a path as a numpy array.\n\n Parameters\n ----------\n path : pathlib.Path\n The path to the image.\n\n Returns\n -------\n img : np.ndarray\n The image as a numpy array.\n \"\"\"\n img = Image.open(path)\n return np.asarray(img)\n
"},{"location":"api/dataset_utils/#chirpdetector.dataset_utils.merge_yolo_datasets","title":"merge_yolo_datasets(dataset1, dataset2, output)
","text":"Merge two yolo-style datasets into one.
"},{"location":"api/dataset_utils/#chirpdetector.dataset_utils.merge_yolo_datasets--parameters","title":"Parameters","text":"dataset1 : str The path to the first dataset. dataset2 : str The path to the second dataset. output : str The path to the output dataset.
"},{"location":"api/dataset_utils/#chirpdetector.dataset_utils.merge_yolo_datasets--returns","title":"Returns","text":"None
Source code in chirpdetector/dataset_utils.py
def merge_yolo_datasets(\n dataset1: pathlib.Path,\n dataset2: pathlib.Path,\n output: pathlib.Path,\n) -> None:\n \"\"\"Merge two yolo-style datasets into one.\n\n Parameters\n ----------\n dataset1 : str\n The path to the first dataset.\n dataset2 : str\n The path to the second dataset.\n output : str\n The path to the output dataset.\n\n Returns\n -------\n None\n \"\"\"\n dataset1 = pathlib.Path(dataset1)\n dataset2 = pathlib.Path(dataset2)\n output = pathlib.Path(output)\n\n if not dataset1.exists():\n msg = f\"{dataset1} does not exist.\"\n raise FileNotFoundError(msg)\n if not dataset2.exists():\n msg = f\"{dataset2} does not exist.\"\n raise FileNotFoundError(msg)\n if output.exists():\n msg = f\"{output} already exists.\"\n raise FileExistsError(msg)\n\n output_images = output / \"images\"\n output_images.mkdir(parents=True, exist_ok=False)\n output_labels = output / \"labels\"\n output_labels.mkdir(parents=True, exist_ok=False)\n\n imgs1 = list((dataset1 / \"images\").iterdir())\n labels1 = list((dataset1 / \"labels\").iterdir())\n imgs2 = list((dataset2 / \"images\").iterdir())\n labels2 = list((dataset2 / \"labels\").iterdir())\n\n print(f\"Found {len(imgs1)} images in {dataset1}.\")\n print(f\"Found {len(imgs2)} images in {dataset2}.\")\n\n print(f\"Copying images and labels to {output}...\")\n for idx, _ in enumerate(imgs1):\n shutil.copy(imgs1[idx], output_images / imgs1[idx].name)\n shutil.copy(labels1[idx], output_labels / labels1[idx].name)\n\n for idx, _ in enumerate(imgs2):\n shutil.copy(imgs2[idx], output_images / imgs2[idx].name)\n shutil.copy(labels2[idx], output_labels / labels2[idx].name)\n\n classes = dataset1 / \"classes.txt\"\n shutil.copy(classes, output / classes.name)\n\n print(f\"Done. Merged {len(imgs1) + len(imgs2)} images.\")\n
"},{"location":"api/dataset_utils/#chirpdetector.dataset_utils.plot_yolo_dataset","title":"plot_yolo_dataset(path, n)
","text":"Plot n random images YOLO-style dataset.
"},{"location":"api/dataset_utils/#chirpdetector.dataset_utils.plot_yolo_dataset--parameters","title":"Parameters","text":"path : pathlib.Path The path to the dataset.
"},{"location":"api/dataset_utils/#chirpdetector.dataset_utils.plot_yolo_dataset--returns","title":"Returns","text":"None
Source code in chirpdetector/dataset_utils.py
def plot_yolo_dataset(path: pathlib.Path, n: int) -> None:\n \"\"\"Plot n random images from a YOLO-style dataset.\n\n Parameters\n ----------\n path : pathlib.Path\n The path to the dataset.\n n : int\n The number of images to plot.\n\n Returns\n -------\n None\n \"\"\"\n mpl.use(\"TkAgg\")\n labelpath = path / \"labels\"\n imgpath = path / \"images\"\n\n label_paths = np.array(list(labelpath.glob(\"*.txt\")))\n label_paths = np.random.choice(label_paths, n)\n\n for lp in label_paths:\n imgp = imgpath / (lp.stem + \".png\")\n img = load_img(imgp)\n labs = np.loadtxt(lp, dtype=np.float32).reshape(-1, 5)\n\n coords = labs[:, 1:]\n\n # make the normalized coords absolute\n coords[:, 0] *= img.shape[1]\n coords[:, 1] *= img.shape[0]\n coords[:, 2] *= img.shape[1]\n coords[:, 3] *= img.shape[0]\n\n # turn centerx, centery, width, height into xmin, ymin, xmax, ymax\n xmin = coords[:, 0] - coords[:, 2] / 2\n ymin = coords[:, 1] - coords[:, 3] / 2\n xmax = coords[:, 0] + coords[:, 2] / 2\n ymax = coords[:, 1] + coords[:, 3] / 2\n\n # plot the image\n _, ax = plt.subplots(figsize=(15, 5), constrained_layout=True)\n ax.imshow(img, cmap=\"magma\")\n for i in range(len(xmin)):\n ax.add_patch(\n Rectangle(\n (xmin[i], ymin[i]),\n xmax[i] - xmin[i],\n ymax[i] - ymin[i],\n fill=False,\n color=\"white\",\n ),\n )\n ax.set_title(imgp.stem)\n plt.axis(\"off\")\n plt.show()\n
"},{"location":"api/dataset_utils/#chirpdetector.dataset_utils.subset_yolo_dataset","title":"subset_yolo_dataset(path, img_ext, n)
","text":"Subset a YOLO dataset.
"},{"location":"api/dataset_utils/#chirpdetector.dataset_utils.subset_yolo_dataset--parameters","title":"Parameters","text":"path : pathlib.Path The path to the dataset root. img_ext : str The image extension, e.g. .png or .jpg n : int The size of the subset
"},{"location":"api/dataset_utils/#chirpdetector.dataset_utils.subset_yolo_dataset--returns","title":"Returns","text":"None
Source code in chirpdetector/dataset_utils.py
def subset_yolo_dataset(path: pathlib.Path, img_ext: str, n: int) -> None:\n \"\"\"Subset a YOLO dataset.\n\n Parameters\n ----------\n path : pathlib.Path\n The path to the dataset root.\n img_ext : str\n The image extension, e.g. .png or .jpg\n n : int\n The size of the subset\n\n Returns\n -------\n None\n \"\"\"\n img_path = path / \"images\"\n lbl_path = path / \"labels\"\n\n images = np.array(list(img_path.glob(f\"*{img_ext}\")))\n np.random.shuffle(images)\n\n images = images[:n]\n\n subset_dir = path.parent / f\"{path.name}_subset\"\n subset_dir.mkdir(exist_ok=True)\n\n subset_img_path = subset_dir / \"images\"\n subset_img_path.mkdir(exist_ok=True)\n subset_lbl_path = subset_dir / \"labels\"\n subset_lbl_path.mkdir(exist_ok=True)\n\n shutil.copy(path / \"classes.txt\", subset_dir)\n\n for image in images:\n shutil.copy(image, subset_img_path)\n shutil.copy(lbl_path / f\"{image.stem}.txt\", subset_lbl_path)\n
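Taken together, these helpers cover a typical curation workflow. The following is a minimal sketch of how they might be chained; the paths under /data are placeholders, not part of the package: import pathlib\nfrom chirpdetector.dataset_utils import (\n    clean_yolo_dataset,\n    merge_yolo_datasets,\n    subset_yolo_dataset,\n)\n\nroot = pathlib.Path(\"/data\")  # hypothetical dataset root\n\n# drop image-label pairs whose label file is empty\nclean_yolo_dataset(root / \"session1\", img_ext=\".png\")\n\n# copy a random subset of 100 images to /data/session1_subset\nsubset_yolo_dataset(root / \"session1\", img_ext=\".png\", n=100)\n\n# merge two curated datasets into a new one\nmerge_yolo_datasets(root / \"session1\", root / \"session2\", root / \"merged\")\n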
"},{"location":"api/detect_chirps/","title":"detect_chirps","text":"Detect chirps on a spectrogram.
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.coords_to_mpl_rectangle","title":"coords_to_mpl_rectangle(boxes)
","text":"Convert normal bounding box to matplotlib.pathes.Rectangle format.
Convert box defined by corner coordinates (x1, y1, x2, y2) to box defined by lower left, width and height (x1, y1, w, h).
The corner coordinates are the model output, but the lower-left corner plus width and height are needed by the matplotlib.patches.Rectangle
object for plotting.
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.coords_to_mpl_rectangle--parameters","title":"Parameters","text":" boxes
: numpy.ndarray
The boxes to be converted.
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.coords_to_mpl_rectangle--returns","title":"Returns","text":" numpy.ndarray
The converted boxes.
Source code in chirpdetector/detect_chirps.py
def coords_to_mpl_rectangle(boxes: np.ndarray) -> np.ndarray:\n \"\"\"Convert normal bounding box to matplotlib.patches.Rectangle format.\n\n Convert box defined by corner coordinates (x1, y1, x2, y2)\n to box defined by lower left, width and height (x1, y1, w, h).\n\n The corner coordinates are the model output, but the lower-left based\n coordinates are needed by the `matplotlib.patches.Rectangle` object\n for plotting.\n\n Parameters\n ----------\n - `boxes` : `numpy.ndarray`\n The boxes to be converted.\n\n Returns\n -------\n - `numpy.ndarray`\n The converted boxes.\n \"\"\"\n boxes_dims = 2\n if len(boxes.shape) != boxes_dims:\n msg = (\n \"The boxes array must be 2-dimensional.\\n\"\n f\"Shape of boxes: {boxes.shape}\"\n )\n raise ValueError(msg)\n boxes_cols = 4\n if boxes.shape[1] != boxes_cols:\n msg = (\n \"The boxes array must have 4 columns.\\n\"\n f\"Shape of boxes: {boxes.shape}\"\n )\n raise ValueError(msg)\n\n new_boxes = np.zeros_like(boxes)\n new_boxes[:, 0] = boxes[:, 0]\n new_boxes[:, 1] = boxes[:, 1]\n new_boxes[:, 2] = boxes[:, 2] - boxes[:, 0]\n new_boxes[:, 3] = boxes[:, 3] - boxes[:, 1]\n\n return new_boxes\n
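As a quick sanity check, a single corner-format box converts like this (a minimal usage sketch of the function above): import numpy as np\n\nboxes = np.array([[10.0, 20.0, 30.0, 60.0]])  # (x1, y1, x2, y2)\nrects = coords_to_mpl_rectangle(boxes)\nprint(rects)  # [[10. 20. 20. 40.]] -> (x, y, width, height)\n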
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.detect_chirps","title":"detect_chirps(conf, data)
","text":"Detect chirps on a spectrogram.
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.detect_chirps--parameters","title":"Parameters","text":" conf
: Config
The configuration object. data
: Dataset
The gridtools dataset to detect chirps on.
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.detect_chirps--returns","title":"Returns","text":" Source code in chirpdetector/detect_chirps.py
def detect_chirps(conf: Config, data: Dataset) -> None:\n \"\"\"Detect chirps on a spectrogram.\n\n Parameters\n ----------\n - `conf` : `Config`\n The configuration object.\n - `data` : `Dataset`\n The gridtools dataset to detect chirps on.\n\n Returns\n -------\n - `None`\n \"\"\"\n # get the number of electrodes\n n_electrodes = data.grid.rec.shape[1]\n\n # load the model and the checkpoint, and set it to evaluation mode\n device = get_device()\n model = load_fasterrcnn(num_classes=len(conf.hyper.classes))\n checkpoint = torch.load(\n f\"{conf.hyper.modelpath}/model.pt\",\n map_location=device,\n )\n model.load_state_dict(checkpoint[\"model_state_dict\"])\n model.to(device).eval()\n\n # make spec config\n nfft = freqres_to_nfft(conf.spec.freq_res, data.grid.samplerate) # samples\n hop_len = overlap_to_hoplen(conf.spec.overlap_frac, nfft) # samples\n chunksize = conf.spec.time_window * data.grid.samplerate # samples\n nchunks = np.ceil(data.grid.rec.shape[0] / chunksize).astype(int)\n window_overlap_samples = int(conf.spec.spec_overlap * data.grid.samplerate)\n\n bbox_dfs = []\n\n # iterate over the chunks\n overwritten = False\n for chunk_no in range(nchunks):\n # get start and stop indices for the current chunk\n # including some overlap to compensate for edge effects\n # this differs for the first and last chunk\n\n if chunk_no == 0:\n idx1 = int(chunk_no * chunksize)\n idx2 = int((chunk_no + 1) * chunksize + window_overlap_samples)\n elif chunk_no == nchunks - 1:\n idx1 = int(chunk_no * chunksize - window_overlap_samples)\n idx2 = int((chunk_no + 1) * chunksize)\n else:\n idx1 = int(chunk_no * chunksize - window_overlap_samples)\n idx2 = int((chunk_no + 1) * chunksize + window_overlap_samples)\n\n # idx1 and idx2 now determine the window I cut out of the raw signal\n # to compute the spectrogram of.\n\n # compute the time and frequency axes of the spectrogram now that we\n # include the start and stop indices of the current chunk and thus the\n # right start and stop time. The `spectrogram` function does not know\n # about this and would start every time axis at 0.\n spec_times = np.arange(idx1, idx2 + 1, hop_len) / data.grid.samplerate\n spec_freqs = np.arange(0, nfft / 2 + 1) * data.grid.samplerate / nfft\n\n # create a subset from the grid dataset\n if idx2 > data.grid.rec.shape[0]:\n idx2 = data.grid.rec.shape[0] - 1\n\n # This bit should alleviate the edge effects of the tracks\n # by limiting the start and stop times of the spectrogram\n # to the start and stop times of the track.\n start_t = idx1 / data.grid.samplerate\n stop_t = idx2 / data.grid.samplerate\n if data.track.times[-1] < stop_t:\n stop_t = data.track.times[-1]\n idx2 = int(stop_t * data.grid.samplerate)\n if data.track.times[0] > start_t:\n start_t = data.track.times[0]\n idx1 = int(start_t * data.grid.samplerate)\n if start_t > data.track.times[-1] or stop_t < data.track.times[0]:\n continue\n\n chunk = subset(data, idx1, idx2, mode=\"index\")\n if len(chunk.track.indices) == 0:\n continue\n\n # compute the spectrogram for each electrode of the current chunk\n spec = torch.zeros((len(spec_freqs), len(spec_times)))\n for el in range(n_electrodes):\n # get the signal for the current electrode\n sig = chunk.grid.rec[:, el]\n\n # compute the spectrogram for the current electrode\n chunk_spec, _, _ = spectrogram(\n data=sig.copy(),\n samplingrate=data.grid.rec.samplerate,\n nfft=nfft,\n hop_length=hop_len,\n )\n\n # sum spectrogram over all electrodes\n # the spec is a tensor\n if el == 0:\n spec = chunk_spec\n else:\n spec += chunk_spec\n\n # normalize spectrogram by the number of electrodes\n # the spec is still a tensor\n spec /= n_electrodes\n\n # convert the spectrogram to dB\n # .. still a tensor\n spec = decibel(spec)\n\n # cut off everything outside the upper frequency limit\n # the spec is still a tensor\n # TODO: THIS IS SKETCHY AS HELL! As a result, only time and frequency\n # bounding boxes can be used later! The spectrogram limits change\n # for every window!\n flims = (\n np.min(chunk.track.freqs) - conf.spec.freq_pad,\n np.max(chunk.track.freqs) + conf.spec.freq_pad,\n )\n spec = spec[(spec_freqs >= flims[0]) & (spec_freqs <= flims[1]), :]\n spec_freqs = spec_freqs[\n (spec_freqs >= flims[0]) & (spec_freqs <= flims[1])\n ]\n\n # make a path to save the spectrogram\n path = data.path / \"chirpdetections\"\n if path.exists() and overwritten is False:\n shutil.rmtree(path)\n overwritten = True\n path.mkdir(exist_ok=True)\n path /= f\"chunk{chunk_no:05d}.png\"\n\n # add the 3 channels, normalize to 0-1, etc\n img = spec_to_image(spec)\n\n # perform the detection\n with torch.inference_mode():\n outputs = model([img])\n\n # put the boxes, scores and labels into the dataset\n bboxes = outputs[0][\"boxes\"].detach().cpu().numpy()\n scores = outputs[0][\"scores\"].detach().cpu().numpy()\n labels = outputs[0][\"labels\"].detach().cpu().numpy()\n\n # remove all boxes with a score below the threshold\n bboxes = bboxes[scores > conf.det.threshold]\n labels = labels[scores > conf.det.threshold]\n scores = scores[scores > conf.det.threshold]\n\n # if np.any(scores > conf.det.threshold):\n # plot_detections(img, outputs[0], conf.det.threshold, path, conf)\n\n # save the bboxes to a dataframe\n bbox_df = pd.DataFrame(\n data=bboxes,\n columns=[\"x1\", \"y1\", \"x2\", \"y2\"],\n )\n bbox_df[\"score\"] = scores\n bbox_df[\"label\"] = labels\n\n # convert x values to time on spec_times\n spec_times_index = np.arange(0, len(spec_times))\n bbox_df[\"t1\"] = float_index_interpolation(\n bbox_df[\"x1\"].to_numpy(),\n spec_times_index,\n spec_times,\n )\n bbox_df[\"t2\"] = float_index_interpolation(\n bbox_df[\"x2\"].to_numpy(),\n spec_times_index,\n spec_times,\n )\n\n # convert y values to frequency on spec_freqs\n spec_freqs_index = np.arange(len(spec_freqs))\n bbox_df[\"f1\"] = float_index_interpolation(\n bbox_df[\"y1\"].to_numpy(),\n spec_freqs_index,\n spec_freqs,\n )\n bbox_df[\"f2\"] = float_index_interpolation(\n bbox_df[\"y2\"].to_numpy(),\n spec_freqs_index,\n spec_freqs,\n )\n\n # save df to list\n bbox_dfs.append(bbox_df)\n\n # concatenate all dataframes\n bbox_df = pd.concat(bbox_dfs)\n bbox_reset = bbox_df.reset_index(drop=True)\n\n # sort the dataframe by t1\n bbox_sorted = bbox_reset.sort_values(by=\"t1\")\n\n # sort the columns\n bbox_sorted = bbox_sorted[\n [\"label\", \"score\", \"x1\", \"y1\", \"x2\", \"y2\", \"t1\", \"f1\", \"t2\", \"f2\"]\n ]\n\n # save the dataframe\n bbox_sorted.to_csv(data.path / \"chirpdetector_bboxes.csv\", index=False)\n
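The chunk windowing above is easier to see in isolation. A minimal sketch of the index bookkeeping with made-up sizes (the real values come from the config): import numpy as np\n\nchunksize = 300_000  # samples, hypothetical\nwindow_overlap_samples = 15_000  # samples, hypothetical\nn_samples = 1_000_000\nnchunks = int(np.ceil(n_samples / chunksize))\n\nfor chunk_no in range(nchunks):\n    # pad interior chunks on both sides to absorb edge effects,\n    # first and last chunks only on their inner side\n    pad_left = window_overlap_samples if chunk_no > 0 else 0\n    pad_right = window_overlap_samples if chunk_no < nchunks - 1 else 0\n    idx1 = int(chunk_no * chunksize - pad_left)\n    idx2 = int((chunk_no + 1) * chunksize + pad_right)\n    print(chunk_no, idx1, idx2)\n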
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.detect_cli","title":"detect_cli(input_path)
","text":"Terminal interface for the detection function.
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.detect_cli--parameters","title":"Parameters","text":""},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.detect_cli--returns","title":"Returns","text":" Source code in chirpdetector/detect_chirps.py
def detect_cli(input_path: pathlib.Path) -> None:\n \"\"\"Terminal interface for the detection function.\n\n Parameters\n ----------\n - `input_path` : `pathlib.Path`\n The path to the folder containing the datasets.\n\n Returns\n -------\n - `None`\n \"\"\"\n # make the global logger object\n # global logger # pylint: disable=global-statement\n path = pathlib.Path(input_path)\n logger = make_logger(__name__, path / \"chirpdetector.log\")\n datasets = [folder for folder in path.iterdir() if folder.is_dir()]\n confpath = path / \"chirpdetector.toml\"\n\n # load the config file and print a warning if it does not exist\n if confpath.exists():\n config = load_config(str(confpath))\n else:\n msg = (\n \"The configuration file could not be found in the specified path. \"\n \"Please run `chirpdetector copyconfig` and change the \"\n \"configuration file to your needs.\"\n )\n raise FileNotFoundError(msg)\n\n # detect chirps in all datasets in the specified path\n # and show a progress bar\n prog.console.rule(\"Starting detection\")\n with prog:\n task = prog.add_task(\"Detecting chirps...\", total=len(datasets))\n for dataset in datasets:\n msg = f\"Detecting chirps in {dataset.name}...\"\n prog.console.log(msg)\n logger.info(msg)\n\n data = load(dataset)\n detect_chirps(config, data)\n prog.update(task, advance=1)\n prog.update(task, completed=len(datasets))\n
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.float_index_interpolation","title":"float_index_interpolation(values, index_arr, data_arr)
","text":"Convert float indices to values by linear interpolation.
Interpolates a set of float indices within the given index array to obtain corresponding values from the data array using linear interpolation.
Given a set of float indices (values
), this function determines the corresponding values in the data_arr
by linearly interpolating between adjacent indices in the index_arr
. Linear interpolation involves calculating weighted averages based on the fractional parts of the float indices.
This function is useful to transform float coordinates on a spectrogram matrix to the corresponding time and frequency values. The reason for this is that the model outputs bounding boxes in float coordinates, i.e. it does not care about the exact pixel location of the bounding box.
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.float_index_interpolation--parameters","title":"Parameters","text":" values
: np.ndarray
The index value as a float that should be interpolated. index_arr
: numpy.ndarray
The array of indices on the data array. data_arr
: numpy.ndarray
The array of data.
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.float_index_interpolation--returns","title":"Returns","text":" numpy.ndarray
The interpolated value.
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.float_index_interpolation--raises","title":"Raises","text":" ValueError
If any of the input float indices (values
) are outside the range of the provided index_arr
.
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.float_index_interpolation--examples","title":"Examples","text":"values = np.array([2.5, 3.2, 4.8]) index_arr = np.array([2, 3, 4, 5]) data_arr = np.array([10, 15, 20, 25]) result = float_index_interpolation(values, index_arr, data_arr) print(result) array([12.5, 16. , 22.5])
Source code in chirpdetector/detect_chirps.py
def float_index_interpolation(\n values: np.ndarray,\n index_arr: np.ndarray,\n data_arr: np.ndarray,\n) -> np.ndarray:\n \"\"\"Convert float indices to values by linear interpolation.\n\n Interpolates a set of float indices within the given index\n array to obtain corresponding values from the data\n array using linear interpolation.\n\n Given a set of float indices (`values`), this function determines\n the corresponding values in the `data_arr` by linearly interpolating\n between adjacent indices in the `index_arr`. Linear interpolation\n involves calculating weighted averages based on the fractional\n parts of the float indices.\n\n This function is useful to transform float coordinates on a spectrogram\n matrix to the corresponding time and frequency values. The reason for\n this is that the model outputs bounding boxes in float coordinates,\n i.e. it does not care about the exact pixel location of the bounding\n box.\n\n Parameters\n ----------\n - `values` : `np.ndarray`\n The index value as a float that should be interpolated.\n - `index_arr` : `numpy.ndarray`\n The array of indices on the data array.\n - `data_arr` : `numpy.ndarray`\n The array of data.\n\n Returns\n -------\n - `numpy.ndarray`\n The interpolated value.\n\n Raises\n ------\n - `ValueError`\n If any of the input float indices (`values`) are outside\n the range of the provided `index_arr`.\n\n Examples\n --------\n >>> values = np.array([0.5, 1.2, 2.5])\n >>> index_arr = np.array([0, 1, 2, 3])\n >>> data_arr = np.array([10, 15, 20, 25])\n >>> result = float_index_interpolation(values, index_arr, data_arr)\n >>> result\n array([12.5, 16. , 22.5])\n \"\"\"\n # Check if the values are within the range of the index array\n if np.any(values < (np.min(index_arr) - 1)) or np.any(\n values > (np.max(index_arr) + 1),\n ):\n msg = (\n \"Values outside the range of index array\\n\"\n f\"Target values: {values}\\n\"\n f\"Index array: {index_arr}\\n\"\n f\"Data array: {data_arr}\"\n )\n raise ValueError(msg)\n\n # Find the indices corresponding to the values\n lower_indices = np.floor(values).astype(int)\n upper_indices = np.ceil(values).astype(int)\n\n # Ensure upper indices are within the array bounds\n upper_indices = np.minimum(upper_indices, len(index_arr) - 1)\n lower_indices = np.minimum(lower_indices, len(index_arr) - 1)\n\n # Calculate the interpolation weights\n weights = values - lower_indices\n\n # Linear interpolation\n return (1 - weights) * data_arr[lower_indices] + weights * data_arr[\n upper_indices\n ]\n
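In the detector, this is how pixel coordinates of a bounding box become seconds and Hertz; a minimal sketch with made-up axis values, where spec_times stands in for the axis computed in detect_chirps: import numpy as np\n\nspec_times = np.linspace(0.0, 15.0, 1501)  # hypothetical time axis in seconds\nspec_times_index = np.arange(len(spec_times))\n\nx_pixels = np.array([12.3, 480.9])  # float x-coordinates of two box edges\nt_seconds = float_index_interpolation(x_pixels, spec_times_index, spec_times)\nprint(t_seconds)  # box edges in seconds instead of pixels\n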
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.plot_detections","title":"plot_detections(img_tensor, output, threshold, save_path, conf)
","text":"Plot the detections on the spectrogram.
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.plot_detections--parameters","title":"Parameters","text":" img_tensor
: torch.Tensor
The spectrogram. output
: torch.Tensor
The output of the model. threshold
: float
The threshold for the detections. save_path
: pathlib.Path
The path to save the plot to. conf
: Config
The configuration object.
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.plot_detections--returns","title":"Returns","text":" Source code in chirpdetector/detect_chirps.py
def plot_detections(\n img_tensor: torch.Tensor,\n output: torch.Tensor,\n threshold: float,\n save_path: pathlib.Path,\n conf: Config,\n) -> None:\n \"\"\"Plot the detections on the spectrogram.\n\n Parameters\n ----------\n - `img_tensor` : `torch.Tensor`\n The spectrogram.\n - `output` : `torch.Tensor`\n The output of the model.\n - `threshold` : `float`\n The threshold for the detections.\n - `save_path` : `pathlib.Path`\n The path to save the plot to.\n - `conf` : `Config`\n The configuration object.\n\n Returns\n -------\n - `None`\n \"\"\"\n # retrieve all the data from the output and convert\n # spectrogram to numpy array\n img = img_tensor.detach().cpu().numpy().transpose(1, 2, 0)[..., 0]\n boxes = output[\"boxes\"].detach().cpu().numpy()\n boxes = coords_to_mpl_rectangle(boxes)\n scores = output[\"scores\"].detach().cpu().numpy()\n labels = output[\"labels\"].detach().cpu().numpy()\n labels = [conf.hyper.classes[i] for i in labels]\n\n _, ax = plt.subplots(figsize=(20, 10))\n\n ax.pcolormesh(img, cmap=\"magma\")\n\n for i, box in enumerate(boxes):\n if scores[i] > threshold:\n ax.scatter(\n box[0],\n box[1],\n )\n ax.add_patch(\n Rectangle(\n box[:2],\n box[2],\n box[3],\n fill=False,\n color=\"white\",\n linewidth=1,\n ),\n )\n ax.text(\n box[0],\n box[1],\n f\"{scores[i]:.2f}\",\n color=\"black\",\n fontsize=8,\n bbox={\"facecolor\":\"white\", \"alpha\":1},\n )\n plt.axis(\"off\")\n plt.savefig(save_path, dpi=300, bbox_inches=\"tight\", pad_inches=0)\n plt.close()\n
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.spec_to_image","title":"spec_to_image(spec)
","text":"Convert a spectrogram to an image.
Add 3 color channels, normalize to 0-1, etc.
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.spec_to_image--parameters","title":"Parameters","text":""},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.spec_to_image--returns","title":"Returns","text":" Source code in chirpdetector/detect_chirps.py
def spec_to_image(spec: torch.Tensor) -> torch.Tensor:\n \"\"\"Convert a spectrogram to an image.\n\n Add 3 color channels, normalize to 0-1, etc.\n\n Parameters\n ----------\n - `spec` : `torch.Tensor`\n\n Returns\n -------\n - `torch.Tensor`\n \"\"\"\n # make sure the spectrogram is a tensor\n if not isinstance(spec, torch.Tensor):\n msg = (\n \"The spectrogram must be a torch.Tensor.\\n\"\n f\"Type of spectrogram: {type(spec)}\"\n )\n raise TypeError(msg)\n\n # make sure the spectrogram is 2-dimensional\n spec_dims = 2\n if len(spec.size()) != spec_dims:\n msg = (\n \"The spectrogram must be a 2-dimensional matrix.\\n\"\n f\"Shape of spectrogram: {spec.size()}\"\n )\n raise ValueError(msg)\n\n # make sure the spectrogram contains some data\n if (\n np.max(spec.detach().cpu().numpy())\n - np.min(spec.detach().cpu().numpy())\n == 0\n ):\n msg = (\n \"The spectrogram must contain some data.\\n\"\n f\"Max value: {np.max(spec.detach().cpu().numpy())}\\n\"\n f\"Min value: {np.min(spec.detach().cpu().numpy())}\"\n )\n raise ValueError(msg)\n\n # Get the dimensions of the original matrix\n original_shape = spec.size()\n\n # Calculate the number of rows and columns in the matrix\n num_rows, num_cols = original_shape\n\n # duplicate the matrix 3 times\n spec = spec.repeat(3, 1, 1)\n\n # Reshape the matrix to the desired shape (3, num_rows, num_cols)\n desired_shape = (3, num_rows, num_cols)\n reshaped_tensor = spec.view(desired_shape)\n\n # normalize the spectrogram to be between 0 and 1\n normalized_tensor = (reshaped_tensor - reshaped_tensor.min()) / (\n reshaped_tensor.max() - reshaped_tensor.min()\n )\n\n # make sure image is float32\n return normalized_tensor.float()\n
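A short sketch of the expected input and output, using random data just to illustrate the shape and value range: import torch\n\nspec = torch.rand(512, 1024)  # (frequency bins, time bins), 2-dimensional\nimg = spec_to_image(spec)\nprint(img.shape)  # torch.Size([3, 512, 1024])\nprint(img.min().item(), img.max().item())  # 0.0 1.0\n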
"},{"location":"api/plot_detections/","title":"plot_detections","text":"Functions to visualize detections on images.
"},{"location":"api/plot_detections/#chirpdetector.plot_detections.clean_all_plots_cli","title":"clean_all_plots_cli(path)
","text":"Remove all plots from the chirpdetections folder.
"},{"location":"api/plot_detections/#chirpdetector.plot_detections.clean_all_plots_cli--parameters","title":"Parameters","text":"path : pathlib.Path Path to the config file.
Source code in chirpdetector/plot_detections.py
def clean_all_plots_cli(path: pathlib.Path) -> None:\n \"\"\"Remove all plots from the chirpdetections folder.\n\n Parameters\n ----------\n path : pathlib.Path\n Path to the config file.\n \"\"\"\n dirs = [dataset for dataset in path.iterdir() if dataset.is_dir()]\n with prog:\n task = prog.add_task(\"Cleaning plots...\", total=len(dirs))\n for dataset in dirs:\n prog.console.log(f\"Cleaning plots for {dataset.name}\")\n clean_plots_cli(dataset)\n prog.advance(task)\n
"},{"location":"api/plot_detections/#chirpdetector.plot_detections.clean_plots_cli","title":"clean_plots_cli(path)
","text":"Remove all plots from the chirpdetections folder.
"},{"location":"api/plot_detections/#chirpdetector.plot_detections.clean_plots_cli--parameters","title":"Parameters","text":"path : pathlib.Path Path to the config file.
Source code in chirpdetector/plot_detections.py
def clean_plots_cli(path: pathlib.Path) -> None:\n \"\"\"Remove all plots from the chirpdetections folder.\n\n Parameters\n ----------\n path : pathlib.Path\n Path to the config file.\n \"\"\"\n savepath = path / \"chirpdetections\"\n for f in savepath.iterdir():\n f.unlink()\n
"},{"location":"api/plot_detections/#chirpdetector.plot_detections.plot_all_detections_cli","title":"plot_all_detections_cli(path)
","text":"Plot detections on images.
"},{"location":"api/plot_detections/#chirpdetector.plot_detections.plot_all_detections_cli--parameters","title":"Parameters","text":"path : pathlib.Path Path to the config file.
Source code in chirpdetector/plot_detections.py
def plot_all_detections_cli(path: pathlib.Path) -> None:\n \"\"\"Plot detections on images.\n\n Parameters\n ----------\n path : pathlib.Path\n Path to the config file.\n \"\"\"\n conf = load_config(path / \"chirpdetector.toml\")\n\n dirs = [dataset for dataset in path.iterdir() if dataset.is_dir()]\n with prog:\n task = prog.add_task(\"Plotting detections...\", total=len(dirs))\n for dataset in dirs:\n prog.console.log(f\"Plotting detections for {dataset.name}\")\n data = load(dataset)\n chirp_df = pd.read_csv(dataset / \"chirpdetector_bboxes.csv\")\n plot_detections(data, chirp_df, conf)\n prog.advance(task)\n
"},{"location":"api/plot_detections/#chirpdetector.plot_detections.plot_detections","title":"plot_detections(data, chirp_df, conf)
","text":"Plot detections on spectrograms.
"},{"location":"api/plot_detections/#chirpdetector.plot_detections.plot_detections--parameters","title":"Parameters","text":"data : Dataset The dataset. chirp_df : pd.DataFrame The dataframe containing the chirp detections. conf : Config The config file.
Source code in chirpdetector/plot_detections.py
def plot_detections(\n data: Dataset,\n chirp_df: pd.DataFrame,\n conf: Config,\n) -> None:\n \"\"\"Plot detections on spectrograms.\n\n Parameters\n ----------\n data : Dataset\n The dataset.\n chirp_df : pd.DataFrame\n The dataframe containing the chirp detections.\n conf : Config\n The config file.\n \"\"\"\n time_window = 15\n n_electrodes = data.grid.rec.shape[1]\n\n nfft = freqres_to_nfft(conf.spec.freq_res, data.grid.samplerate) # samples\n hop_len = overlap_to_hoplen(conf.spec.overlap_frac, nfft) # samples\n chunksize = time_window * data.grid.samplerate # samples\n nchunks = np.ceil(data.grid.rec.shape[0] / chunksize).astype(int)\n window_overlap_samples = int(conf.spec.spec_overlap * data.grid.samplerate)\n\n for chunk_no in range(nchunks):\n # get start and stop indices for the current chunk\n # including some overlap to compensate for edge effects\n # this differs for the first and last chunk\n\n if chunk_no == 0:\n idx1 = int(chunk_no * chunksize)\n idx2 = int((chunk_no + 1) * chunksize + window_overlap_samples)\n elif chunk_no == nchunks - 1:\n idx1 = int(chunk_no * chunksize - window_overlap_samples)\n idx2 = int((chunk_no + 1) * chunksize)\n else:\n idx1 = int(chunk_no * chunksize - window_overlap_samples)\n idx2 = int((chunk_no + 1) * chunksize + window_overlap_samples)\n\n # idx1 and idx2 now determine the window I cut out of the raw signal\n # to compute the spectrogram of.\n\n # compute the time and frequency axes of the spectrogram now that we\n # include the start and stop indices of the current chunk and thus the\n # right start and stop time. The `spectrogram` function does not know\n # about this and would start every time axis at 0.\n spec_times = np.arange(idx1, idx2 + 1, hop_len) / data.grid.samplerate\n spec_freqs = np.arange(0, nfft / 2 + 1) * data.grid.samplerate / nfft\n\n # create a subset from the grid dataset\n if idx2 > data.grid.rec.shape[0]:\n idx2 = data.grid.rec.shape[0] - 1\n chunk = subset(data, idx1, idx2, mode=\"index\")\n\n # don't plot chunks without chirps\n if len(chunk.com.chirp.times) == 0:\n continue\n\n # compute the spectrogram for each electrode of the current chunk\n spec = torch.zeros((len(spec_freqs), len(spec_times)))\n for el in range(n_electrodes):\n # get the signal for the current electrode\n sig = chunk.grid.rec[:, el]\n\n # compute the spectrogram for the current electrode\n chunk_spec, _, _ = spectrogram(\n data=sig.copy(),\n samplingrate=data.grid.samplerate,\n nfft=nfft,\n hop_length=hop_len,\n )\n\n # sum spectrogram over all electrodes\n if el == 0:\n spec = chunk_spec\n else:\n spec += chunk_spec\n\n # normalize spectrogram by the number of electrodes\n spec /= n_electrodes\n\n # convert the spectrogram to dB\n spec = decibel(spec)\n spec = spec.detach().cpu().numpy()\n\n # Set y limits\n flims = (\n np.min(data.track.freqs) - 200,\n np.max(data.track.freqs) + 700,\n )\n spec = spec[(spec_freqs >= flims[0]) & (spec_freqs <= flims[1]), :]\n spec_freqs = spec_freqs[\n (spec_freqs >= flims[0]) & (spec_freqs <= flims[1])\n ]\n\n # Extract the bounding boxes for the current chunk\n chunk_t1 = idx1 / data.grid.samplerate\n chunk_t2 = idx2 / data.grid.samplerate\n chunk_df = chirp_df[\n (chirp_df[\"t1\"] >= chunk_t1) & (chirp_df[\"t2\"] <= chunk_t2)\n ]\n\n # get t1, t2, f1, f2 from chunk_df\n bboxes = chunk_df[[\"score\", \"t1\", \"f1\", \"t2\", \"f2\"]].to_numpy()\n\n # get chirp times and chirp ids\n chirp_times = chunk_df[\"envelope_trough_time\"]\n chirp_ids = chunk_df[\"assigned_track\"]\n\n _, ax = plt.subplots(figsize=(10, 5), constrained_layout=True)\n\n # plot bounding boxes\n ax.imshow(\n spec,\n aspect=\"auto\",\n origin=\"lower\",\n interpolation=\"gaussian\",\n extent=[\n spec_times[0],\n spec_times[-1],\n spec_freqs[0],\n spec_freqs[-1],\n ],\n cmap=\"magma\",\n vmin=-80,\n vmax=-45,\n )\n for bbox in bboxes:\n ax.add_patch(\n Rectangle(\n (bbox[1], bbox[2]),\n bbox[3] - bbox[1],\n bbox[4] - bbox[2],\n fill=False,\n color=\"gray\",\n linewidth=1,\n label=\"faster-R-CNN predictions\",\n ),\n )\n ax.text(\n bbox[1],\n bbox[4] + 15,\n f\"{bbox[0]:.2f}\",\n color=\"gray\",\n fontsize=10,\n verticalalignment=\"bottom\",\n horizontalalignment=\"left\",\n rotation=90,\n )\n\n # plot chirp times and frequency traces\n for track_id in np.unique(data.track.idents):\n ctimes = chirp_times[chirp_ids == track_id]\n\n freqs = data.track.freqs[data.track.idents == track_id]\n times = data.track.times[\n data.track.indices[data.track.idents == track_id]\n ]\n freqs = freqs[\n (times >= spec_times[0] - 10) & (times <= spec_times[-1] + 10)\n ]\n times = times[\n (times >= spec_times[0] - 10) & (times <= spec_times[-1] + 10)\n ]\n\n # get freqs where times are closest to ctimes\n cfreqs = np.zeros_like(ctimes)\n for i, ctime in enumerate(ctimes):\n try:\n indx = np.argmin(np.abs(times - ctime))\n cfreqs[i] = freqs[indx]\n except ValueError:\n msg = (\n \"Failed to find track time closest to chirp time \"\n f\"in chunk {chunk_no}, check the plots.\"\n )\n prog.console.log(msg)\n\n if len(times) != 0:\n ax.plot(\n times,\n freqs,\n lw=2,\n color=\"black\",\n label=\"Frequency traces\",\n )\n\n ax.scatter(\n ctimes,\n cfreqs,\n marker=\"o\",\n lw=1,\n facecolor=\"white\",\n edgecolor=\"black\",\n s=25,\n zorder=10,\n label=\"Chirp assignments\",\n )\n\n ax.set_ylim(flims[0] + 5, flims[1] - 5)\n ax.set_xlim([spec_times[0], spec_times[-1]])\n ax.set_xlabel(\"Time [s]\", fontsize=12)\n ax.set_ylabel(\"Frequency [Hz]\", fontsize=12)\n\n handles, labels = plt.gca().get_legend_handles_labels()\n by_label = dict(zip(labels, handles))\n plt.legend(\n by_label.values(),\n by_label.keys(),\n bbox_to_anchor=(0.5, 1.02),\n loc=\"lower center\",\n mode=\"None\",\n borderaxespad=0,\n ncol=3,\n fancybox=False,\n framealpha=0,\n )\n\n savepath = data.path / \"chirpdetections\"\n savepath.mkdir(exist_ok=True)\n plt.savefig(\n savepath / f\"cpd_{chunk_no}.png\",\n dpi=300,\n bbox_inches=\"tight\",\n )\n\n plt.close()\n plt.clf()\n plt.cla()\n plt.close(\"all\")\n
"},{"location":"api/plot_detections/#chirpdetector.plot_detections.plot_detections_cli","title":"plot_detections_cli(path)
","text":"Plot detections on images.
"},{"location":"api/plot_detections/#chirpdetector.plot_detections.plot_detections_cli--parameters","title":"Parameters","text":"path : pathlib.Path Path to the config file.
Source code in chirpdetector/plot_detections.py
def plot_detections_cli(path: pathlib.Path) -> None:\n \"\"\"Plot detections on images.\n\n Parameters\n ----------\n path : pathlib.Path\n Path to the config file.\n \"\"\"\n conf = load_config(path.parent / \"chirpdetector.toml\")\n data = load(path)\n chirp_df = pd.read_csv(path / \"chirpdetector_bboxes.csv\")\n plot_detections(data, chirp_df, conf)\n
"},{"location":"api/train_model/","title":"train_model","text":""},{"location":"api/train_model/#chirpdetector.train_model--train-the-faster-r-cnn-model","title":"Train the faster-R-CNN model.","text":"Train and test the neural network specified in the config file.
"},{"location":"api/train_model/#chirpdetector.train_model.plot_epochs","title":"plot_epochs(epoch_train_loss, epoch_val_loss, epoch_avg_train_loss, epoch_avg_val_loss, path)
","text":"Plot the loss for each epoch.
"},{"location":"api/train_model/#chirpdetector.train_model.plot_epochs--parameters","title":"Parameters","text":" epoch_train_loss
: list
The training loss for each epoch. epoch_val_loss
: list
The validation loss for each epoch. epoch_avg_train_loss
: list
The average training loss for each epoch. epoch_avg_val_loss
: list
The average validation loss for each epoch. path
: pathlib.Path
The path to save the plot to.
"},{"location":"api/train_model/#chirpdetector.train_model.plot_epochs--returns","title":"Returns","text":" Source code in chirpdetector/train_model.py
def plot_epochs(\n epoch_train_loss: list,\n epoch_val_loss: list,\n epoch_avg_train_loss: list,\n epoch_avg_val_loss: list,\n path: pathlib.Path,\n) -> None:\n \"\"\"Plot the loss for each epoch.\n\n Parameters\n ----------\n - `epoch_train_loss`: `list`\n The training loss for each epoch.\n - `epoch_val_loss`: `list`\n The validation loss for each epoch.\n - `epoch_avg_train_loss`: `list`\n The average training loss for each epoch.\n - `epoch_avg_val_loss`: `list`\n The average validation loss for each epoch.\n - `path`: `pathlib.Path`\n The path to save the plot to.\n\n Returns\n -------\n - `None`\n \"\"\"\n _, ax = plt.subplots(1, 2, figsize=(10, 5), constrained_layout=True)\n\n x_train = np.arange(len(epoch_train_loss[0])) + 1\n x_val = np.arange(len(epoch_val_loss[0])) + len(epoch_train_loss[0]) + 1\n\n for train_loss, val_loss in zip(epoch_train_loss, epoch_val_loss):\n ax[0].plot(x_train, train_loss, c=\"tab:blue\", label=\"_\")\n ax[0].plot(x_val, val_loss, c=\"tab:orange\", label=\"_\")\n x_train = np.arange(len(epoch_train_loss[0])) + x_val[-1]\n x_val = np.arange(len(epoch_val_loss[0])) + x_train[-1]\n\n x_avg = np.arange(len(epoch_avg_train_loss)) + 1\n ax[1].plot(\n x_avg,\n epoch_avg_train_loss,\n label=\"Training Loss\",\n c=\"tab:blue\",\n )\n ax[1].plot(\n x_avg,\n epoch_avg_val_loss,\n label=\"Validation Loss\",\n c=\"tab:orange\",\n )\n\n ax[0].set_ylabel(\"Loss\")\n ax[0].set_xlabel(\"Batch\")\n ax[0].set_ylim(bottom=0)\n ax[0].set_title(\"Loss per batch\")\n\n ax[1].set_ylabel(\"Loss\")\n ax[1].set_xlabel(\"Epoch\")\n ax[1].legend()\n ax[1].set_ylim(bottom=0)\n ax[1].set_title(\"Avg loss per epoch\")\n\n plt.savefig(path)\n plt.close()\n
"},{"location":"api/train_model/#chirpdetector.train_model.plot_folds","title":"plot_folds(fold_avg_train_loss, fold_avg_val_loss, path)
","text":"Plot the loss for each fold.
"},{"location":"api/train_model/#chirpdetector.train_model.plot_folds--parameters","title":"Parameters","text":" fold_avg_train_loss
: list
The average training loss for each fold. fold_avg_val_loss
: list
The average validation loss for each fold. path
: pathlib.Path
The path to save the plot to.
"},{"location":"api/train_model/#chirpdetector.train_model.plot_folds--returns","title":"Returns","text":" Source code in chirpdetector/train_model.py
def plot_folds(\n fold_avg_train_loss: list,\n fold_avg_val_loss: list,\n path: pathlib.Path,\n) -> None:\n \"\"\"Plot the loss for each fold.\n\n Parameters\n ----------\n - `fold_avg_train_loss`: `list`\n The average training loss for each fold.\n - `fold_avg_val_loss`: `list`\n The average validation loss for each fold.\n - `path`: `pathlib.Path`\n The path to save the plot to.\n\n Returns\n -------\n - `None`\n \"\"\"\n _, ax = plt.subplots(figsize=(10, 5), constrained_layout=True)\n\n for train_loss, val_loss in zip(fold_avg_train_loss, fold_avg_val_loss):\n x = np.arange(len(train_loss)) + 1\n ax.plot(x, train_loss, c=\"tab:blue\", alpha=0.3, label=\"_\")\n ax.plot(x, val_loss, c=\"tab:orange\", alpha=0.3, label=\"_\")\n\n avg_train = np.mean(fold_avg_train_loss, axis=0)\n avg_val = np.mean(fold_avg_val_loss, axis=0)\n x = np.arange(len(avg_train)) + 1\n ax.plot(\n x,\n avg_train,\n label=\"Training Loss\",\n c=\"tab:blue\",\n )\n ax.plot(\n x,\n avg_val,\n label=\"Validation Loss\",\n c=\"tab:orange\",\n )\n\n ax.set_ylabel(\"Loss\")\n ax.set_xlabel(\"Epoch\")\n ax.legend()\n ax.set_ylim(bottom=0)\n\n plt.savefig(path)\n plt.close()\n
"},{"location":"api/train_model/#chirpdetector.train_model.save_model","title":"save_model(epoch, model, optimizer, path)
","text":"Save the model state dict.
"},{"location":"api/train_model/#chirpdetector.train_model.save_model--parameters","title":"Parameters","text":" epoch
: int
The current epoch. model
: torch.nn.Module
The model to save. optimizer
: torch.optim.Optimizer
The optimizer to save. path
: str
The path to save the model to.
"},{"location":"api/train_model/#chirpdetector.train_model.save_model--returns","title":"Returns","text":" Source code in chirpdetector/train_model.py
def save_model(\n epoch: int,\n model: torch.nn.Module,\n optimizer: torch.optim.Optimizer,\n path: str,\n) -> None:\n \"\"\"Save the model state dict.\n\n Parameters\n ----------\n - `epoch`: `int`\n The current epoch.\n - `model`: `torch.nn.Module`\n The model to save.\n - `optimizer`: `torch.optim.Optimizer`\n The optimizer to save.\n - `path`: `str`\n The path to save the model to.\n\n Returns\n -------\n - `None`\n \"\"\"\n path = pathlib.Path(path)\n path.mkdir(parents=True, exist_ok=True)\n torch.save(\n {\n \"epoch\": epoch,\n \"model_state_dict\": model.state_dict(),\n \"optimizer_state_dict\": optimizer.state_dict(),\n },\n path / \"model.pt\",\n )\n
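The checkpoint written here is what the detection stage loads again. A minimal sketch of restoring it, assuming the same helpers (get_device, load_fasterrcnn) and config object used elsewhere in this package: import pathlib\nimport torch\n\ndevice = get_device()\nmodel = load_fasterrcnn(num_classes=len(config.hyper.classes))\ncheckpoint = torch.load(\n    pathlib.Path(config.hyper.modelpath) / \"model.pt\",\n    map_location=device,\n)\nmodel.load_state_dict(checkpoint[\"model_state_dict\"])\nmodel.to(device).eval()\n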
"},{"location":"api/train_model/#chirpdetector.train_model.train","title":"train(config, mode='pretrain')
","text":"Train the model.
"},{"location":"api/train_model/#chirpdetector.train_model.train--parameters","title":"Parameters","text":" config
: Config
The config file. mode
: str
The mode to train in. Either pretrain
or finetune
.
"},{"location":"api/train_model/#chirpdetector.train_model.train--returns","title":"Returns","text":" Source code in chirpdetector/train_model.py
def train(config: Config, mode: str = \"pretrain\") -> None:\n \"\"\"Train the model.\n\n Parameters\n ----------\n - `config`: `Config`\n The config file.\n - `mode`: `str`\n The mode to train in. Either `pretrain` or `finetune`.\n\n Returns\n -------\n - `None`\n \"\"\"\n # Load a pretrained model from pytorch if in pretrain mode,\n # otherwise open an already trained model from the\n # model state dict.\n assert mode in [\"pretrain\", \"finetune\"]\n if mode == \"pretrain\":\n assert config.train.datapath is not None\n datapath = config.train.datapath\n elif mode == \"finetune\":\n assert config.finetune.datapath is not None\n datapath = config.finetune.datapath\n\n # Check if the path to the data actually exists\n if not pathlib.Path(datapath).exists():\n raise FileNotFoundError(f\"Path {datapath} does not exist.\")\n\n # Initialize the logger and progress bar, make the logger global\n global logger\n logger = make_logger(\n __name__,\n pathlib.Path(config.path).parent / \"chirpdetector.log\",\n )\n\n # Get the device (e.g. GPU or CPU)\n device = get_device()\n\n # Print information about starting training\n progress.console.rule(\"Starting training\")\n msg = (\n f\"Device: {device}, Config: {config.path},\"\n f\" Mode: {mode}, Data: {datapath}\"\n )\n progress.console.log(msg)\n logger.info(msg)\n\n # initialize the dataset\n data = CustomDataset(\n path=datapath,\n classes=config.hyper.classes,\n )\n\n # initialize the k-fold cross-validation\n splits = KFold(n_splits=config.hyper.kfolds, shuffle=True, random_state=42)\n\n # initialize the best validation loss to a large number\n best_val_loss = float(\"inf\")\n\n # iterate over the folds for k-fold cross-validation\n with progress:\n # save loss across all epochs and folds\n fold_train_loss = []\n fold_val_loss = []\n fold_avg_train_loss = []\n fold_avg_val_loss = []\n\n # Add kfolds progress bar that runs alongside the epochs progress bar\n task_folds = progress.add_task(\n f\"[blue]{config.hyper.kfolds}-Fold Crossvalidation\",\n total=config.hyper.kfolds,\n )\n\n # iterate over the folds\n for fold, (train_idx, val_idx) in enumerate(\n splits.split(np.arange(len(data))),\n ):\n # initialize the model and optimizer\n model = load_fasterrcnn(num_classes=len(config.hyper.classes)).to(\n device,\n )\n\n # If the mode is finetune, load the model state dict from\n # previous training\n if mode == \"finetune\":\n modelpath = pathlib.Path(config.hyper.modelpath) / \"model.pt\"\n checkpoint = torch.load(modelpath, map_location=device)\n model.load_state_dict(checkpoint[\"model_state_dict\"])\n\n # Initialize stochastic gradient descent optimizer\n params = [p for p in model.parameters() if p.requires_grad]\n optimizer = torch.optim.SGD(\n params,\n lr=config.hyper.learning_rate,\n momentum=config.hyper.momentum,\n weight_decay=config.hyper.weight_decay,\n )\n\n # make train and validation dataloaders for the current fold\n train_data = torch.utils.data.Subset(data, train_idx)\n val_data = torch.utils.data.Subset(data, val_idx)\n\n # this is for training\n train_loader = DataLoader(\n train_data,\n batch_size=config.hyper.batch_size,\n shuffle=True,\n num_workers=config.hyper.num_workers,\n collate_fn=collate_fn,\n )\n\n # this is only for validation\n val_loader = DataLoader(\n val_data,\n batch_size=config.hyper.batch_size,\n shuffle=True,\n num_workers=config.hyper.num_workers,\n collate_fn=collate_fn,\n )\n\n # save loss across all epochs\n epoch_avg_train_loss = []\n epoch_avg_val_loss = []\n epoch_train_loss = []\n epoch_val_loss = []\n\n # train the model for the specified number of epochs\n task_epochs = progress.add_task(\n f\"{config.hyper.num_epochs} Epochs for fold k={fold + 1}\",\n total=config.hyper.num_epochs,\n )\n\n # iterate across n epochs\n for epoch in range(config.hyper.num_epochs):\n # print information about the current epoch\n msg = (\n f\"Training epoch {epoch + 1} of {config.hyper.num_epochs} \"\n f\"for fold {fold + 1} of {config.hyper.kfolds}\"\n )\n progress.console.log(msg)\n logger.info(msg)\n\n # train the epoch\n train_loss = train_epoch(\n dataloader=train_loader,\n device=device,\n model=model,\n optimizer=optimizer,\n )\n\n # validate the epoch\n _, val_loss = val_epoch(\n dataloader=val_loader,\n device=device,\n model=model,\n )\n\n # save losses for this epoch\n epoch_train_loss.append(train_loss)\n epoch_val_loss.append(val_loss)\n\n # save the average loss for this epoch\n epoch_avg_train_loss.append(np.median(train_loss))\n epoch_avg_val_loss.append(np.median(val_loss))\n\n # save the model if it is the best so far\n if np.mean(val_loss) < best_val_loss:\n best_val_loss = sum(val_loss) / len(val_loss)\n\n msg = (\n f\"New best validation loss: {best_val_loss:.4f}, \"\n \"saving model...\"\n )\n progress.console.log(msg)\n logger.info(msg)\n\n save_model(\n epoch=epoch,\n model=model,\n optimizer=optimizer,\n path=config.hyper.modelpath,\n )\n\n # plot the losses for this epoch\n plot_epochs(\n epoch_train_loss=epoch_train_loss,\n epoch_val_loss=epoch_val_loss,\n epoch_avg_train_loss=epoch_avg_train_loss,\n epoch_avg_val_loss=epoch_avg_val_loss,\n path=pathlib.Path(config.hyper.modelpath)\n / f\"fold{fold + 1}.png\",\n )\n\n # update the progress bar for the epochs\n progress.update(task_epochs, advance=1)\n\n # update the progress bar for the epochs and hide it if done\n progress.update(task_epochs, visible=False)\n\n # save the losses for this fold\n fold_train_loss.append(epoch_train_loss)\n fold_val_loss.append(epoch_val_loss)\n fold_avg_train_loss.append(epoch_avg_train_loss)\n fold_avg_val_loss.append(epoch_avg_val_loss)\n\n plot_folds(\n fold_avg_train_loss=fold_avg_train_loss,\n fold_avg_val_loss=fold_avg_val_loss,\n path=pathlib.Path(config.hyper.modelpath) / \"losses.png\",\n )\n\n # update the progress bar for the folds\n progress.update(task_folds, advance=1)\n\n # update the progress bar for the folds and hide it if done\n progress.update(task_folds, visible=False)\n\n # print information about the training\n msg = (\n \"Average validation loss of last epoch across folds: \"\n f\"{np.mean(fold_val_loss):.4f}\"\n )\n progress.console.log(msg)\n logger.info(msg)\n progress.console.rule(\"[bold blue]Finished training\")\n
"},{"location":"api/train_model/#chirpdetector.train_model.train_cli","title":"train_cli(config_path, mode)
","text":"Train the model from the command line.
"},{"location":"api/train_model/#chirpdetector.train_model.train_cli--parameters","title":"Parameters","text":" config_path
: pathlib.Path
The path to the config file. mode
: str
The mode to train in. Either pretrain
or finetune
.
"},{"location":"api/train_model/#chirpdetector.train_model.train_cli--returns","title":"Returns","text":" Source code in chirpdetector/train_model.py
def train_cli(config_path: pathlib.Path, mode: str) -> None:\n \"\"\"Train the model from the command line.\n\n Parameters\n ----------\n - `config_path`: `pathlib.Path`\n The path to the config file.\n - `mode`: `str`\n The mode to train in. Either `pretrain` or `finetune`.\n\n Returns\n -------\n - `None`\n \"\"\"\n config = load_config(config_path)\n train(config, mode=mode)\n
"},{"location":"api/train_model/#chirpdetector.train_model.train_epoch","title":"train_epoch(dataloader, device, model, optimizer)
","text":"Train the model for one epoch.
"},{"location":"api/train_model/#chirpdetector.train_model.train_epoch--parameters","title":"Parameters","text":" dataloader
: DataLoader
The dataloader for the training data. device
: torch.device
The device to train on. model
: torch.nn.Module
The model to train. optimizer
: torch.optim.Optimizer
The optimizer to use.
"},{"location":"api/train_model/#chirpdetector.train_model.train_epoch--returns","title":"Returns","text":" train_loss
: List
The training loss for each batch.
Source code in chirpdetector/train_model.py
def train_epoch(\n dataloader: DataLoader,\n device: torch.device,\n model: torch.nn.Module,\n optimizer: torch.optim.Optimizer,\n) -> List:\n \"\"\"Train the model for one epoch.\n\n Parameters\n ----------\n - `dataloader`: `DataLoader`\n The dataloader for the training data.\n - `device`: `torch.device`\n The device to train on.\n - `model`: `torch.nn.Module`\n The model to train.\n - `optimizer`: `torch.optim.Optimizer`\n The optimizer to use.\n\n Returns\n -------\n - `train_loss`: `List`\n The training loss for each batch.\n \"\"\"\n train_loss = []\n\n for samples, targets in dataloader:\n images = list(sample.to(device) for sample in samples)\n targets = [\n {k: v.to(device) for k, v in t.items() if k != \"image_name\"}\n for t in targets\n ]\n\n loss_dict = model(images, targets)\n losses = sum(loss for loss in loss_dict.values())\n train_loss.append(losses.item())\n\n optimizer.zero_grad()\n losses.backward()\n optimizer.step()\n\n return train_loss\n
"},{"location":"api/train_model/#chirpdetector.train_model.val_epoch","title":"val_epoch(dataloader, device, model)
","text":"Validate the model for one epoch.
"},{"location":"api/train_model/#chirpdetector.train_model.val_epoch--parameters","title":"Parameters","text":" dataloader
: DataLoader
The dataloader for the validation data. device
: torch.device
The device to train on. model
: torch.nn.Module
The model to train.
"},{"location":"api/train_model/#chirpdetector.train_model.val_epoch--returns","title":"Returns","text":" loss_dict
: dict
The loss dictionary of the last batch. val_loss : List The validation loss for each batch.
Source code in chirpdetector/train_model.py
def val_epoch(\n dataloader: DataLoader,\n device: torch.device,\n model: torch.nn.Module,\n) -> List:\n \"\"\"Validate the model for one epoch.\n\n Parameters\n ----------\n - `dataloader`: `DataLoader`\n The dataloader for the validation data.\n - `device`: `torch.device`\n The device to train on.\n - `model`: `torch.nn.Module`\n The model to train.\n\n Returns\n -------\n - `loss_dict`: `dict`\n The loss dictionary of the last batch.\n - `val_loss`: `List`\n The validation loss for each batch.\n \"\"\"\n val_loss = []\n for samples, targets in dataloader:\n images = list(sample.to(device) for sample in samples)\n targets = [\n {k: v.to(device) for k, v in t.items() if k != \"image_name\"}\n for t in targets\n ]\n\n with torch.inference_mode():\n loss_dict = model(images, targets)\n\n losses = sum(loss for loss in loss_dict.values())\n val_loss.append(losses.item())\n\n return loss_dict, val_loss\n
"}]}
\ No newline at end of file
+{"config":{"lang":["en"],"separator":"[\\s\\-]+","pipeline":["stopWordFilter"]},"docs":[{"location":"","title":"Introduction","text":" Chirpdetector
Detect communication signals of electric fish using deep neural networks \ud83d\udc1f\u26a1\ud83e\udde0 This project is still work in progress and will approximately be released in spring of 2024.
Why? \ud83e\udd28
Chirps are by far the most thoroughly researched communication signal of electric fish, probably even of all fish. But detecting chirps becomes hard when more than one fish is recorded. As a result, most of the research to date analyzes this signal in isolated individuals. This is not good.
To tackle this issue, this package provides a simple toolbox to detect chirps of multiple fish on spectrograms. This enables true quantitative analyses of chirping between freely behaving fish for the first time.
"},{"location":"assingment/","title":"Assingment","text":"Wow, such empty
"},{"location":"contributing/","title":"Contributing","text":"We are thrilled to have you join in making this project even better. Please feel free to browse through the resources and guidelines provided here, and let us know if there is anything specific you would like to contribute or discuss.
If you would like to help develop this package, you can skim through the to-do list below as well as the contribution guidelines. Just fork the project, add your code, and send a pull request. We are always happy to get some help!
If you encountered an issue using the chirpdetector
, feel free to open an issue here.
"},{"location":"contributing/#contributors-guidelines","title":"Contributors guidelines","text":"I try our best to adhere to good coding practices and catch up on writing tests for this package. As I am currently the only one working on it, here is some documentation of the development packages I use:
pre-commit
for pre-commit hooks pytest
and pytest-coverage
for unit tests ruff
for linting and formatting pyright
for static type checking
Before every commit, a pre-commit hook runs all these packages on the code base and refuses a push if errors are raised. If you want to contribute, please make sure that your code is properly formatted and run the tests before issuing a pull request. The formatting guidelines should be automatically picked up by your ruff
installation from the pyproject.toml
file.
"},{"location":"contributing/#to-do","title":"To Do","text":"After the first release, this section will be removed an tasks will be organized as github issues. Until them, if you fixed something, please check it off on this list before opening a pull request.
- Refactor train, detect, convert. All into much smaller functions. Move accessory functions to utils
- Move hardcoded params from assignment algo into config.toml
- Resolve all pylint, mypy, and ruff errors and warnings
- Fix make test, fails after ruff run
- Build github actions CI/CD pipeline for codecov etc.
- Move the dataconverter from
gridtools
to chirpdetector
- Extend the dataconverter to just output the spectrograms so that hand-labelling can be done in a separate step
- Add a main script so that the cli is
chirpdetector <task> --<flag> <args>
- Improve simulation of chirps to include more realistic noise, undershoot and maybe even phasic-tonic evolution of the frequency of the big chirps
- make the
copyconfig
script more - start writing the chirp assignment algorithm
- Move all the pprinting and logging constructors to a separate module and build a unified console object so that saving logs to file is easier, also log to file as well
- Split the messy training loop into functions
- Add label-studio
- Supply scripts to convert completely unannotated or partially annotated data to the label-studio format to make manual labeling easier
- Make possible to output detections as a yolo dataset
- Look up how to convert a yolo dataset to a label-studio input so we can label pre-annotated data, facilitating a full human-in-the-loop approach
- Add augmentation transforms to the dataset class and add augmentations to the simulation in
gridtools
. Note on this: Unnecessary, using real data. - Change bbox to actual yolo format, not the weird one I made up (which is x1, y1, x2, y2 instead of x1, y1, w, h). This is why the label-studio export is not working.
- Port cli to click, works better
- Try clustering the detected chirp windows on a spectrogram, could be interesting
"},{"location":"dataset/","title":"Creating a dataset","text":"Wow, such empty
"},{"location":"demo/","title":"Detecting chirps with a few terminal commands","text":"Once everything is set up correctly, detecting chirps is a breeze. The terminal utility can be called by chirpdetector
or simply cpd
.
Simply run
cpd detect --path \"/path/to/dataset\"\n
And the bounding boxes will be computed and saved to a .csv
file. Then run cpd assign --path \"/path/to/dataset\"\n
to assign each detected chirp to a fundamental frequency of a fish. The results will be added to the .csv
file in the dataset. To check if this went well, you can run cpd plot --path \"/path/to/dataset\"\n
And the spectrograms, bounding boxes, and assigned chirps of all the detected chirps will be plotted and saved as .png
images into a subfolder of your dataset. The result will look something like this:
15 seconds of a recording containing two chirping fish, with bounding boxes around chirps and dots indicating which frequency each chirp is assigned to.
"},{"location":"detection/","title":"Detection","text":"Wow, such empty
"},{"location":"how_it_works/","title":"How it works","text":" How? \ud83e\udd14
Chirps manifest as excursions in the electric organ discharge frequency. To discern the individual chirps in a recording featuring multiple fish separated solely by frequency, we delve into the frequency domain. This involves the computation of spectrograms, ensuring ample temporal resolution for chirp distinction and sufficient frequency resolution for fish differentiation. The outcome is a series of images.
This framework facilitates the application of potent computer vision algorithms, such as a faster-R-CNN, for the detection of objects like chirps within these 'images.' Each chirp detection yields a bounding box, a motif echoed in the package's logo.
Post-processing steps refine the results, assigning chirp times to the fundamental frequencies of each fish captured in the recording.
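For intuition, here is a minimal, package-independent sketch of turning a signal with a brief frequency excursion into such an image; it uses synthetic data and plain scipy, not the package pipeline: import numpy as np\nfrom scipy.signal import spectrogram\n\nfs = 20_000  # Hz, hypothetical sampling rate\nt = np.arange(0, 5, 1 / fs)\n# baseline frequency with a short Gaussian excursion, mimicking a chirp\nfreq = 600 + 100 * np.exp(-((t - 2.5) ** 2) / (2 * 0.01**2))\nsig = np.sin(2 * np.pi * np.cumsum(freq) / fs)\n\nf, tt, sxx = spectrogram(sig, fs=fs, nperseg=4096, noverlap=3584)\nimg = 10 * np.log10(sxx + 1e-12)  # decibel image a detector could consume\n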
Still not sold? Check out the demo \u00bb
"},{"location":"installation/","title":"Installation","text":"Wow, such empty
"},{"location":"labeling/","title":"Labeling a dataset","text":"Wow, such empty
"},{"location":"setup/","title":"Setup","text":"Wow, such empty
"},{"location":"training/","title":"Training","text":"Wow, such empty
"},{"location":"visualization/","title":"Visualization","text":"Wow, such empty
"},{"location":"yolo-helpers/","title":"Helper commands","text":"Wow, such empty
"},{"location":"api/assign_chirps/","title":"assign_chirps","text":"Assign chirps detected on a spectrogram to wavetracker tracks.
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.assign_chirps","title":"assign_chirps(assign_data, chirp_df, data)
","text":"Assign chirps to wavetracker tracks.
This function uses the extracted envelope troughs to assign chirps to tracks. It computes a cost function that is high when the trough prominence is high and the distance to the chirp center is low. For each chirp, the track with the highest cost function value is chosen.
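As a toy illustration of this cost (really a gain) function, using the prominence-over-squared-distance form from the source below (all numbers invented):

```python
import numpy as np

# one chirp, three candidate tracks: envelope trough prominences and
# distances (in samples) between each trough and the detected chirp center
proms = np.array([0.08, 0.30, 0.02])
dists = np.array([40.0, 25.0, 10.0]) + 1  # +1 avoids division by zero

cost = proms / dists**2
print(np.argmax(cost))  # -> 1: the deep trough close to the chirp center wins
```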
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.assign_chirps--parameters","title":"Parameters","text":" assign_data
: dict
Dictionary containing the data needed for assignment chirp_df
: pd.dataframe
Dataframe containing the chirp bboxes data
: gridtools.datasets.Dataset
Dataset object containing the data
Source code in chirpdetector/assign_chirps.py
def assign_chirps(\n assign_data: Dict[str, np.ndarray],\n chirp_df: pd.DataFrame,\n data: Dataset,\n) -> None:\n \"\"\"Assign chirps to wavetracker tracks.\n\n This function uses the extracted envelope troughs to assign chirps to\n tracks. It computes a cost function that is high when the trough prominence\n is high and the distance to the chirp center is low. For each chirp, the\n track with the highest cost function value is chosen.\n\n Parameters\n ----------\n - `assign_data`: `dict`\n Dictionary containing the data needed for assignment\n - `chirp_df`: `pd.dataframe`\n Dataframe containing the chirp bboxes\n - `data`: `gridtools.datasets.Dataset`\n Dataset object containing the data\n \"\"\"\n # extract data from assign_data\n peak_prominences = assign_data[\"proms\"]\n peak_distances = assign_data[\"peaks\"]\n peak_times = assign_data[\"ptimes\"]\n chirp_indices = assign_data[\"cindices\"]\n track_ids = assign_data[\"track_ids\"]\n envs = assign_data[\"envs\"]\n\n # compute cost function.\n # this function is high when the trough prominence is high\n # (-> chirp with high contrast)\n # and when the trough is close to the chirp center as detected by the\n # r-cnn (-> detected chirp is close to the actual chirp)\n cost = peak_prominences / peak_distances**2\n\n # set cost to zero for cases where no peak was found\n cost[np.isnan(cost)] = 0\n\n # for each chirp, choose the track where the cost is highest\n # TODO: to avoid confusion make a cost function where high is good and low\n # is bad. this is more like a \"gain function\"\n chosen_tracks = []\n chosen_track_times = []\n chirp_envs = []\n non_chirp_envs = []\n for idx in np.unique(chirp_indices):\n candidate_tracks = track_ids[chirp_indices == idx]\n candidate_costs = cost[chirp_indices == idx]\n candidate_times = peak_times[chirp_indices == idx]\n chosen_tracks.append(candidate_tracks[np.argmax(candidate_costs)])\n chosen_track_times.append(candidate_times[np.argmax(candidate_costs)])\n # TODO: Save envs do disk for plotting\n\n # store chosen tracks in chirp_df\n chirp_df[\"assigned_track\"] = chosen_tracks\n\n # store chirp time estimated from envelope trough in chirp_df\n chirp_df[\"envelope_trough_time\"] = chosen_track_times\n\n # save chirp_df\n chirp_df.to_csv(data.path / \"chirpdetector_bboxes.csv\", index=False)\n\n # save old format:\n np.save(data.path / \"chirp_ids_rcnn.npy\", chosen_tracks)\n np.save(data.path / \"chirp_times_rcnn.npy\", chosen_track_times)\n
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.assign_cli","title":"assign_cli(path)
","text":"Assign chirps to wavetracker tracks.
This is the command line interface for the assign_chirps function.
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.assign_cli--parameters","title":"Parameters","text":" path
: pathlib.path
Path to the directory containing the chirpdetector.toml file
Source code in chirpdetector/assign_chirps.py
def assign_cli(path: pathlib.Path) -> None:\n \"\"\"Assign chirps to wavetracker tracks.\n\n this is the command line interface for the assign_chirps function.\n\n Parameters\n ----------\n - `path`: `pathlib.path`\n path to the directory containing the chirpdetector.toml file\n \"\"\"\n if not path.is_dir():\n msg = f\"{path} is not a directory\"\n raise ValueError(msg)\n\n if not (path / \"chirpdetector.toml\").is_file():\n msg = f\"{path} does not contain a chirpdetector.toml file\"\n raise ValueError(msg)\n\n logger = make_logger(__name__, path / \"chirpdetector.log\")\n # config = load_config(path / \"chirpdetector.toml\")\n recs = list(path.iterdir())\n recs = [r for r in recs if r.is_dir()]\n # recs = [path / \"subset_2020-03-18-10_34_t0_9320.0_t1_9920.0\"]\n\n msg = f\"found {len(recs)} recordings in {path}, starting assignment\"\n prog.console.log(msg)\n logger.info(msg)\n\n prog.console.rule(\"starting assignment\")\n with prog:\n task = prog.add_task(\"assigning chirps\", total=len(recs))\n for rec in recs:\n msg = f\"assigning chirps in {rec}\"\n logger.info(msg)\n prog.console.log(msg)\n\n data = load(rec)\n chirp_df = pd.read_csv(rec / \"chirpdetector_bboxes.csv\")\n assign_data, chirp_df, data = extract_assignment_data(\n data, chirp_df\n )\n assign_chirps(assign_data, chirp_df, data)\n prog.update(task, advance=1)\n
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.bbox_to_chirptimes","title":"bbox_to_chirptimes(chirp_df)
","text":"Convert chirp bboxes to chirp times.
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.bbox_to_chirptimes--parameters","title":"Parameters","text":" chirp_df
: pd.dataframe
Dataframe containing the chirp bboxes
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.bbox_to_chirptimes--returns","title":"Returns","text":" chirp_df
: pd.dataframe
Dataframe containing the chirp bboxes with chirp times.
Source code in chirpdetector/assign_chirps.py
def bbox_to_chirptimes(chirp_df: pd.DataFrame) -> pd.DataFrame:\n \"\"\"Convert chirp bboxes to chirp times.\n\n Parameters\n ----------\n - `chirp_df`: `pd.dataframe`\n dataframe containing the chirp bboxes\n\n Returns\n -------\n - `chirp_df`: `pd.dataframe`\n dataframe containing the chirp bboxes with chirp times.\n \"\"\"\n chirp_df[\"chirp_times\"] = np.mean(chirp_df[[\"t1\", \"t2\"]], axis=1)\n\n return chirp_df\n
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.clean_bboxes","title":"clean_bboxes(data, chirp_df)
","text":"Clean the chirp bboxes.
This is a collection of filters that remove bboxes that either overlap, are out of range or otherwise do not make sense.
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.clean_bboxes--parameters","title":"Parameters","text":" data
: gridtools.datasets.Dataset
Dataset object containing the data chirp_df
: pd.dataframe
Dataframe containing the chirp bboxes
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.clean_bboxes--returns","title":"Returns","text":" chirp_df_tf
: pd.dataframe
Dataframe containing the chirp bboxes that overlap with the range
Source code in chirpdetector/assign_chirps.py
def clean_bboxes(data: Dataset, chirp_df: pd.DataFrame) -> pd.DataFrame:\n \"\"\"Clean the chirp bboxes.\n\n This is a collection of filters that remove bboxes that\n either overlap, are out of range or otherwise do not make sense.\n\n Parameters\n ----------\n - `data`: `gridtools.datasets.Dataset`\n Dataset object containing the data\n - `chirp_df`: `pd.dataframe`\n Dataframe containing the chirp bboxes\n\n Returns\n -------\n - `chirp_df_tf`: `pd.dataframe`\n Dataframe containing the chirp bboxes that overlap with the range\n \"\"\"\n # non-max suppression: remove all chirp bboxes that overlap with\n # another more than threshold\n pick_indices = non_max_suppression_fast(chirp_df, 0.5)\n chirp_df_nms = chirp_df.loc[pick_indices, :]\n\n # track filter: remove all chirp bboxes that do not overlap with\n # the range spanned by the min and max of the wavetracker frequency tracks\n minf = np.min(data.track.freqs).astype(float)\n maxf = np.max(data.track.freqs).astype(float)\n # maybe add some more cleaning here, such\n # as removing chirps that are too short or too long\n return remove_non_overlapping_boxes(chirp_df_nms, minf, maxf)\n
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.cleanup","title":"cleanup(chirp_df, data)
","text":"Clean the chirp bboxes.
This is a collection of filters that remove bboxes that either overlap, are out of range or otherwise do not make sense.
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.cleanup--parameters","title":"Parameters","text":" chirp_df
: pd.dataframe
Dataframe containing the chirp bboxes data
: gridtools.datasets.Dataset
Dataset object containing the data
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.cleanup--returns","title":"Returns","text":" chirp_df
: pd.dataframe
Dataframe containing the chirp bboxes that overlap with the range
Source code in chirpdetector/assign_chirps.py
def cleanup(chirp_df: pd.DataFrame, data: Dataset) -> pd.DataFrame:\n \"\"\"Clean the chirp bboxes.\n\n This is a collection of filters that remove bboxes that\n either overlap, are out of range or otherwise do not make sense.\n\n Parameters\n ----------\n - `chirp_df`: `pd.dataframe`\n Dataframe containing the chirp bboxes\n - `data`: `gridtools.datasets.Dataset`\n Dataset object containing the data\n\n Returns\n -------\n - `chirp_df`: `pd.dataframe`\n Dataframe containing the chirp bboxes that overlap with the range\n \"\"\"\n # first clean the bboxes\n chirp_df = clean_bboxes(data, chirp_df)\n # sort chirps in df by time, i.e. t1\n chirp_df = chirp_df.sort_values(by=\"t1\", ascending=True)\n # compute chirp times, i.e. center of the bbox x axis\n return bbox_to_chirptimes(chirp_df)\n
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.extract_assignment_data","title":"extract_assignment_data(data, chirp_df)
","text":"Get envelope troughs to determine chirp assignment.
This algorithm assigns chirps to wavetracker tracks by a series of steps: 1. clean the chirp bboxes 2. for each fish track, filter the signal on the best electrode 3. find troughs in the envelope of the filtered signal 4. compute the prominence of the trough and the distance to the chirp center 5. compute a cost function that is high when the trough prominence is high and the distance to the chirp center is low 6. compare the value of the cost function for each track and choose the track with the highest cost function value
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.extract_assignment_data--parameters","title":"Parameters","text":" data
: dataset
Dataset object containing the data chirp_df
: pd.dataframe
Dataframe containing the chirp bboxes
Source code in chirpdetector/assign_chirps.py
def extract_assignment_data(\n data: Dataset, chirp_df: pd.DataFrame\n) -> Tuple[Dict[str, np.ndarray], pd.DataFrame, Dataset]:\n \"\"\"Get envelope troughs to determine chirp assignment.\n\n This algorithm assigns chirps to wavetracker tracks by a series of steps:\n 1. clean the chirp bboxes\n 2. for each fish track, filter the signal on the best electrode\n 3. find troughs in the envelope of the filtered signal\n 4. compute the prominence of the trough and the distance to the chirp\n center\n 5. compute a cost function that is high when the trough prominence is high\n and the distance to the chirp center is low\n 6. compare the value of the cost function for each track and choose the\n track with the highest cost function value\n\n Parameters\n ----------\n - `data`: `dataset`\n Dataset object containing the data\n - `chirp_df`: `pd.dataframe`\n Dataframe containing the chirp bboxes\n \"\"\"\n # clean the chirp bboxes\n chirp_df = cleanup(chirp_df, data)\n\n # now loop over all tracks and assign chirps to tracks\n chirp_indices = [] # index of chirp in chirp_df\n track_ids = [] # id of track / fish\n peak_prominences = [] # prominence of trough in envelope\n peak_distances = [] # distance of trough to chirp center\n peak_times = [] # time of trough in envelope, should be close to chirp\n envs = [] # envelope of filtered signal\n\n for fish_id in data.track.ids:\n # get chirps, times and freqs and powers for this track\n chirps = np.array(chirp_df.chirp_times.values)\n time = data.track.times[\n data.track.indices[data.track.idents == fish_id]\n ]\n freq = data.track.freqs[data.track.idents == fish_id]\n powers = data.track.powers[data.track.idents == fish_id, :]\n\n if len(time) == 0:\n continue # skip if no track is found\n\n for idx, chirp in enumerate(chirps):\n # find the closest time, freq and power to the chirp time\n closest_idx = np.argmin(np.abs(time - chirp))\n best_electrode = np.argmax(powers[closest_idx, :]).astype(int)\n second_best_electrode = np.argsort(powers[closest_idx, :])[-2]\n best_freq = freq[closest_idx]\n\n # check if chirp overlaps with track\n f1 = chirp_df.f1.to_numpy()[idx]\n f2 = chirp_df.f2.to_numpy()[idx]\n f2 = f1 + (f2 - f1) * 0.5 # range is the lower half of the bbox\n if (f1 > best_freq) or (f2 < best_freq):\n peak_distances.append(np.nan)\n peak_prominences.append(np.nan)\n peak_times.append(np.nan)\n chirp_indices.append(idx)\n track_ids.append(fish_id)\n continue\n\n # determine start and stop index of time window on raw data\n # using bounding box start and stop times of chirp detection\n start_idx, stop_idx, center_idx = make_indices(\n chirp_df, data, idx, chirp\n )\n\n indices = (start_idx, stop_idx, center_idx)\n peaks, proms, env = extract_envelope_trough(\n data,\n best_electrode,\n second_best_electrode,\n best_freq,\n indices,\n )\n\n # if no peaks are found, skip this chirp\n if len(peaks) == 0:\n peak_distances.append(np.nan)\n peak_prominences.append(np.nan)\n peak_times.append(np.nan)\n chirp_indices.append(idx)\n track_ids.append(fish_id)\n continue\n\n # compute index to closest peak to chirp center\n distances = np.abs(peaks - (center_idx - start_idx))\n closest_peak_idx = np.argmin(distances)\n\n # store peak prominence and distance to chirp center\n peak_distances.append(distances[closest_peak_idx])\n peak_prominences.append(proms[closest_peak_idx])\n peak_times.append(\n (start_idx + peaks[closest_peak_idx]) / data.grid.samplerate,\n )\n chirp_indices.append(idx)\n track_ids.append(fish_id)\n envs.append(env)\n\n peak_prominences = np.array(peak_prominences)\n peak_distances = (\n np.array(peak_distances) + 1\n ) # add 1 to avoid division by zero\n peak_times = np.array(peak_times)\n chirp_indices = np.array(chirp_indices)\n track_ids = np.array(track_ids)\n envs = np.array(envs)\n\n assignment_data = {\n \"proms\": peak_prominences,\n \"peaks\": peak_distances,\n \"ptimes\": peak_times,\n \"cindices\": chirp_indices,\n \"track_ids\": track_ids,\n \"envs\": envs,\n }\n return (\n assignment_data,\n chirp_df,\n data,\n )\n
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.extract_envelope_trough","title":"extract_envelope_trough(data, best_electrode, second_best_electrode, best_freq, indices)
","text":"Extract envelope troughs.
Extracts a snippet from the raw data around the chirp time and computes the envelope of the bandpass filtered signal. Then finds the troughs in the envelope and computes their prominences.
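The implementation relies on bandpass_filter and envelope helpers; a rough scipy-only sketch of the same idea (the ±15 Hz band and the prominence threshold are taken from the source below, the rest is an assumption):

```python
import numpy as np
from scipy.signal import butter, find_peaks, hilbert, sosfiltfilt

def envelope_troughs(raw: np.ndarray, fs: float, best_freq: float):
    # narrow bandpass around the fish's baseline frequency
    sos = butter(4, [best_freq - 15, best_freq + 15], "bandpass", fs=fs, output="sos")
    filtered = sosfiltfilt(sos, raw)
    # amplitude envelope via the analytic signal
    env = np.abs(hilbert(filtered))
    # chirps appear as dips in the envelope, so search the negated envelope
    troughs, props = find_peaks(-env, prominence=1e-3)
    return troughs, props["prominences"], env
```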
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.extract_envelope_trough--parameters","title":"Parameters","text":" data
: gridtools.datasets.Dataset
Dataset object containing the data best_electrode
: int
Index of the best electrode second_best_electrode
: int
Index of the second best electrode best_freq
: float
Frequency of the chirp indices
: Tuple[int, int, int]
Tuple containing the start, stop, and center indices of the chirp
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.extract_envelope_trough--returns","title":"Returns","text":" peaks
: np.ndarray
Indices of the envelope troughs proms
: np.ndarray
Prominences of the envelope troughs env
: np.ndarray
Envelope of the filtered signal
Source code in chirpdetector/assign_chirps.py
def extract_envelope_trough(\n data: Dataset,\n best_electrode: int,\n second_best_electrode: int,\n best_freq: float,\n indices: Tuple[int, int, int],\n) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:\n \"\"\"Extract envelope troughs.\n\n Extracts a snippet from the raw data around the chirp time and computes\n the envelope of the bandpass filtered signal. Then finds the troughs in\n the envelope and computes their prominences.\n\n Parameters\n ----------\n - `data`: `gridtools.datasets.Dataset`\n Dataset object containing the data\n - `best_electrode`: `int`\n Index of the best electrode\n - `second_best_electrode`: `int`\n Index of the second best electrode\n - `best_freq`: `float`\n Frequency of the chirp\n - `indices`: `Tuple[int, int, int]`\n Tuple containing the start, center, stop indices of the chirp\n\n Returns\n -------\n - `peaks`: `np.ndarray`\n Indices of the envelope troughs\n - `proms`: `np.ndarray`\n Prominences of the envelope troughs\n - `env`: `np.ndarray`\n Envelope of the filtered signal\n \"\"\"\n start_idx, stop_idx, _= indices\n\n # determine bandpass cutoffs above and below baseline frequency\n lower_f = best_freq - 15\n upper_f = best_freq + 15\n\n # get the raw signal on the 2 best electrodes and make differential\n raw1 = data.grid.rec[start_idx:stop_idx, best_electrode]\n raw2 = data.grid.rec[start_idx:stop_idx, second_best_electrode]\n raw = raw1 - raw2\n\n # bandpass filter the raw signal\n raw_filtered = bandpass_filter(\n raw,\n data.grid.samplerate,\n lower_f,\n upper_f,\n )\n\n # compute the envelope of the filtered signal\n env = envelope(\n signal=raw_filtered,\n samplerate=data.grid.samplerate,\n cutoff_frequency=50,\n )\n peaks, proms, env = get_env_trough(env, raw_filtered)\n # mpl.use(\"TkAgg\")\n # plt.plot(env)\n # plt.plot(raw_filtered)\n # plt.plot(peaks, env[peaks], \"x\")\n # plt.show()\n return peaks, proms, env\n
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.get_env_trough","title":"get_env_trough(env, raw)
","text":"Get the envelope troughs and their prominences.
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.get_env_trough--parameters","title":"Parameters","text":" env
: np.ndarray
Envelope of the filtered signal raw
: np.ndarray
Raw signal
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.get_env_trough--returns","title":"Returns","text":" peaks
: np.ndarray
Indices of the envelope troughs proms
: np.ndarray
Prominences of the envelope troughs env
: np.ndarray
Envelope of the filtered signal
Source code in chirpdetector/assign_chirps.py
def get_env_trough(\n env: np.ndarray,\n raw: np.ndarray,\n) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:\n \"\"\"Get the envelope troughs and their prominences.\n\n Parameters\n ----------\n - `env`: `np.ndarray`\n Envelope of the filtered signal\n - `raw`: `np.ndarray`\n Raw signal\n\n Returns\n -------\n - `peaks`: `np.ndarray`\n Indices of the envelope troughs\n - `proms`: `np.ndarray`\n Prominences of the envelope troughs\n - `env`: `np.ndarray`\n Envelope of the filtered signal\n \"\"\"\n # normalize the envelope using the amplitude of the raw signal\n # to preserve the amplitude of the envelope\n env = env / np.max(np.abs(raw))\n\n # cut off the first and last 25% of the envelope\n env[: int(0.25 * len(env))] = np.nan\n env[int(0.75 * len(env)) :] = np.nan\n\n # find troughs in the envelope and compute trough prominences\n peaks, params = find_peaks(-env, prominence=1e-3)\n proms = params[\"prominences\"]\n return peaks, proms, env\n
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.make_indices","title":"make_indices(chirp_df, data, idx, chirp)
","text":"Make indices for the chirp window.
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.make_indices--parameters","title":"Parameters","text":" chirp_df
: pd.dataframe
Dataframe containing the chirp bboxes data
: gridtools.datasets.Dataset
Dataset object containing the data idx
: int
Index of the chirp in the chirp_df chirp
: float
Chirp time
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.make_indices--returns","title":"Returns","text":" start_idx
: int
Start index of the chirp window stop_idx
: int
Stop index of the chirp window center_idx
: int
Center index of the chirp window
Source code in chirpdetector/assign_chirps.py
def make_indices(\n chirp_df: pd.DataFrame, data: Dataset, idx: int, chirp: float\n) -> Tuple[int, int, int]:\n \"\"\"Make indices for the chirp window.\n\n Parameters\n ----------\n - `chirp_df`: `pd.dataframe`\n Dataframe containing the chirp bboxes\n - `data`: `gridtools.datasets.Dataset`\n Dataset object containing the data\n - `idx`: `int`\n Index of the chirp in the chirp_df\n - `chirp`: `float`\n Chirp time\n\n Returns\n -------\n - `start_idx`: `int`\n Start index of the chirp window\n - `stop_idx`: `int`\n Stop index of the chirp window\n - `center_idx`: `int`\n Center index of the chirp window\n \"\"\"\n # determine start and stop index of time window on raw data\n # using bounding box start and stop times of chirp detection\n diffr = chirp_df.t2.to_numpy()[idx] - chirp_df.t1.to_numpy()[idx]\n t1 = chirp_df.t1.to_numpy()[idx] - 0.5 * diffr\n t2 = chirp_df.t2.to_numpy()[idx] + 0.5 * diffr\n\n start_idx = int(np.round(t1 * data.grid.samplerate))\n stop_idx = int(np.round(t2 * data.grid.samplerate))\n center_idx = int(np.round(chirp * data.grid.samplerate))\n\n return start_idx, stop_idx, center_idx\n
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.non_max_suppression_fast","title":"non_max_suppression_fast(chirp_df, overlapthresh)
","text":"Faster implementation of non-maximum suppression.
Removes overlapping bounding boxes. This is a slightly modified version of https://pyimagesearch.com/2015/02/16/faster-non-maximum-suppression-python/.
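Note that the overlap ratio computed here is the intersection area divided by the area of the not-yet-picked box rather than a symmetric IoU. A quick hand-worked example:

```python
# box A spans t = [0, 2], f = [0, 2]; box B spans t = [0, 1], f = [0, 1].
# A is picked first (larger f2), then B is tested against it:
inter = max(0, min(2, 1) - max(0, 0)) * max(0, min(2, 1) - max(0, 0))  # 1.0
area_b = (1 - 0) * (1 - 0)  # 1.0
print(inter / area_b)  # 1.0 > 0.5, so B is suppressed
```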
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.non_max_suppression_fast--parameters","title":"Parameters","text":" chirp_df
: pd.dataframe
Dataframe containing the chirp bboxes overlapthresh
: float
Threshold for overlap between bboxes
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.non_max_suppression_fast--returns","title":"Returns","text":" pick
: list
List of indices of bboxes to keep
Source code in chirpdetector/assign_chirps.py
def non_max_suppression_fast(\n chirp_df: pd.DataFrame,\n overlapthresh: float,\n) -> list:\n \"\"\"Faster implementation of non-maximum suppression.\n\n To remove overlapping bounding boxes.\n Is a slightly modified version of https://pyimagesearch.com/2015/02/16/faster-non-maximum-suppression-python/.\n\n Parameters\n ----------\n - `chirp_df`: `pd.dataframe`\n Dataframe containing the chirp bboxes\n - `overlapthresh`: `float`\n Threshold for overlap between bboxes\n\n Returns\n -------\n - `pick`: `list`\n List of indices of bboxes to keep\n \"\"\"\n # convert boxes to list of tuples and then to numpy array\n boxes = chirp_df[[\"t1\", \"f1\", \"t2\", \"f2\"]].to_numpy()\n\n # if there are no boxes, return an empty list\n if len(boxes) == 0:\n return []\n\n # initialize the list of picked indexes\n pick = []\n\n # grab the coordinates of the bounding boxes\n x1 = boxes[:, 0]\n y1 = boxes[:, 1]\n x2 = boxes[:, 2]\n y2 = boxes[:, 3]\n\n # compute the area of the bounding boxes and sort the bounding\n # boxes by the bottom-right y-coordinate of the bounding box\n area = (x2 - x1) * (y2 - y1)\n idxs = np.argsort(y2)\n\n # keep looping while some indexes still remain in the indexes\n # list\n while len(idxs) > 0:\n # grab the last index in the indexes list and add the\n # index value to the list of picked indexes\n last = len(idxs) - 1\n i = idxs[last]\n pick.append(i)\n\n # find the largest (x, y) coordinates for the start of\n # the bounding box and the smallest (x, y) coordinates\n # for the end of the bounding box\n xx1 = np.maximum(x1[i], x1[idxs[:last]])\n yy1 = np.maximum(y1[i], y1[idxs[:last]])\n xx2 = np.minimum(x2[i], x2[idxs[:last]])\n yy2 = np.minimum(y2[i], y2[idxs[:last]])\n\n # compute the width and height of the bounding box\n w = np.maximum(0, xx2 - xx1)\n h = np.maximum(0, yy2 - yy1)\n\n # compute the ratio of overlap (intersection over union)\n overlap = (w * h) / area[idxs[:last]]\n\n # delete all indexes from the index list that have\n idxs = np.delete(\n idxs,\n np.concatenate(([last], np.where(overlap > overlapthresh)[0])),\n )\n # return the indicies of the picked boxes\n return pick\n
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.remove_non_overlapping_boxes","title":"remove_non_overlapping_boxes(chirp_df, minf, maxf)
","text":"Remove chirp bboxes that do not overlap with tracks.
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.remove_non_overlapping_boxes--parameters","title":"Parameters","text":" chirp_df
: pd.dataframe
Dataframe containing the chirp bboxes minf
: float
Minimum frequency of the range maxf
: float
Maximum frequency of the range
"},{"location":"api/assign_chirps/#chirpdetector.assign_chirps.remove_non_overlapping_boxes--returns","title":"Returns","text":" chirp_df_tf
: pd.dataframe
Dataframe containing the chirp bboxes that overlap with the range
Source code in chirpdetector/assign_chirps.py
def remove_non_overlapping_boxes(\n chirp_df: pd.DataFrame,\n minf: float,\n maxf: float,\n) -> pd.DataFrame:\n \"\"\"Remove chirp bboxes that do not overlap with tracks.\n\n Parameters\n ----------\n - `chirp_df`: `pd.dataframe`\n Dataframe containing the chirp bboxes\n - `minf`: `float`\n Minimum frequency of the range\n - `maxf`: `float`\n Maximum frequency of the range\n\n Returns\n -------\n - `chirp_df_tf`: `pd.dataframe`\n Dataframe containing the chirp bboxes that overlap with the range\n \"\"\"\n # remove all chirp bboxes that have no overlap with the range spanned by\n # minf and maxf\n\n # first build a box that spans the entire range\n range_box = np.array([0, minf, np.max(chirp_df.t2), maxf])\n\n # now compute the intersection between the range box and each chirp bboxes\n # and keep only those that have an intersection area > 0\n chirp_df_tf = chirp_df.copy()\n intersection = chirp_df_tf.apply(\n lambda row: (\n max(0, min(row[\"t2\"], range_box[2]) - max(row[\"t1\"], range_box[0]))\n * max(\n 0,\n min(row[\"f2\"], range_box[3]) - max(row[\"f1\"], range_box[1]),\n )\n ),\n axis=1,\n )\n return chirp_df_tf.loc[intersection > 0, :]\n
"},{"location":"api/convert_data/","title":"convert_data","text":"Functions and classes for converting data.
"},{"location":"api/convert_data/#chirpdetector.convert_data.chirp_bounding_boxes","title":"chirp_bounding_boxes(data, nfft)
","text":"Make bounding boxes of simulated chirps using the chirp parameters.
"},{"location":"api/convert_data/#chirpdetector.convert_data.chirp_bounding_boxes--parameters","title":"Parameters","text":" data
: Dataset
The dataset to make bounding boxes for. nfft
: int The number of samples in the FFT.
"},{"location":"api/convert_data/#chirpdetector.convert_data.chirp_bounding_boxes--returns","title":"Returns","text":"pandas.DataFrame
A dataframe with the bounding boxes.
Source code in chirpdetector/convert_data.py
def chirp_bounding_boxes(data: Dataset, nfft: int) -> pd.DataFrame:\n \"\"\"Make bounding boxes of simulated chirps using the chirp parameters.\n\n Parameters\n ----------\n - `data` : `Dataset`\n The dataset to make bounding boxes for.\n - `nfft` : int\n The number of samples in the FFT.\n\n Returns\n -------\n `pandas.DataFrame`\n A dataframe with the bounding boxes.\n \"\"\"\n assert hasattr(\n data.com.chirp,\n \"params\",\n ), \"Dataset must have a chirp attribute with a params attribute\"\n\n # Time padding is one NFFT window\n pad_time = nfft / data.grid.samplerate\n\n # Freq padding is fixed by the frequency resolution\n freq_res = data.grid.samplerate / nfft\n pad_freq = freq_res * 50\n\n boxes = []\n ids = []\n for fish_id in data.track.ids:\n freqs = data.track.freqs[data.track.idents == fish_id]\n times = data.track.times[\n data.track.indices[data.track.idents == fish_id]\n ]\n chirps = data.com.chirp.times[data.com.chirp.idents == fish_id]\n params = data.com.chirp.params[data.com.chirp.idents == fish_id]\n\n for chirp, param in zip(chirps, params):\n # take the two closest frequency points\n f_closest = freqs[np.argsort(np.abs(times - chirp))[:2]]\n\n # take the two closest time points\n t_closest = times[np.argsort(np.abs(times - chirp))[:2]]\n\n # compute the weighted average of the two closest frequency points\n # using the dt between chirp time and sampled time as weights\n f_closest = np.average(\n f_closest,\n weights=np.abs(t_closest - chirp),\n )\n\n # we now have baseline eodf and time point of the chirp. Now\n # we get some parameters from the params to build the bounding box\n # for the chirp\n height = param[1]\n width = param[2]\n\n # now define bounding box as center coordinates, width and height\n t_center = chirp\n f_center = f_closest + height / 2\n\n bbox_height = height + pad_freq\n bbox_width = width + pad_time\n\n boxes.append((t_center, f_center, bbox_width, bbox_height))\n ids.append(fish_id)\n\n dataframe = pd.DataFrame(\n boxes,\n columns=[\"t_center\", \"f_center\", \"width\", \"height\"],\n )\n dataframe[\"fish_id\"] = ids\n return dataframe\n
"},{"location":"api/convert_data/#chirpdetector.convert_data.convert","title":"convert(data, conf, output, label_mode)
","text":"Convert a gridtools dataset to a YOLO dataset.
"},{"location":"api/convert_data/#chirpdetector.convert_data.convert--parameters","title":"Parameters","text":" data
: Dataset
The dataset to convert. conf
: Config
The configuration. output
: pathlib.Path
The output directory. label_mode
: str
The label mode. Can be one of 'none', 'synthetic' or 'detected'.
"},{"location":"api/convert_data/#chirpdetector.convert_data.convert--returns","title":"Returns","text":""},{"location":"api/convert_data/#chirpdetector.convert_data.convert--notes","title":"Notes","text":"This function iterates through a raw recording in chunks and computes the sum spectrogram of each chunk. The chunk size needs to be chosen such that the images can be nicely fed to a detector. The function also computes the bounding boxes of chirps in that chunk and saves them to a dataframe and a txt file into a labels directory.
Source code in chirpdetector/convert_data.py
def convert(\n data: Dataset,\n conf: Config,\n output: pathlib.Path,\n label_mode: str,\n) -> None:\n \"\"\"Convert a gridtools dataset to a YOLO dataset.\n\n Parameters\n ----------\n - `data` : `Dataset`\n The dataset to convert.\n - `conf` : `Config`\n The configuration.\n - `output` : `pathlib.Path`\n The output directory.\n - `label_mode` : `str`\n The label mode. Can be one of 'none', 'synthetic' or 'detected'.\n\n Returns\n -------\n - `None`\n\n Notes\n -----\n This function iterates through a raw recording in chunks and computes the\n sum spectrogram of each chunk. The chunk size needs to be chosen such that\n the images can be nicely fed to a detector. The function also computes\n the bounding boxes of chirps in that chunk and saves them to a dataframe\n and a txt file into a labels directory.\n \"\"\"\n assert hasattr(data, \"grid\"), \"Dataset must have a grid attribute\"\n assert label_mode in [\n \"none\",\n \"synthetic\",\n \"detected\",\n ], \"label_mode must be one of 'none', 'synthetic' or 'detected'\"\n\n dataroot = output\n\n n_electrodes = data.grid.rec.shape[1]\n\n # How much time to put into each spectrogram\n time_window = conf.spec.time_window # seconds\n window_overlap = conf.spec.spec_overlap # seconds\n freq_pad = conf.spec.freq_pad # Hz\n window_overlap_samples = window_overlap * data.grid.samplerate # samples\n\n # Spectrogram computation parameters\n nfft = freqres_to_nfft(conf.spec.freq_res, data.grid.samplerate) # samples\n hop_len = overlap_to_hoplen(conf.spec.overlap_frac, nfft) # samples\n chunksize = time_window * data.grid.samplerate # samples\n n_chunks = np.ceil(data.grid.rec.shape[0] / chunksize).astype(int)\n\n rprint(\n \"Dividing recording of duration\"\n f\"{data.grid.rec.shape[0] / data.grid.samplerate} into {n_chunks}\"\n f\"chunks of {time_window} seconds each.\",\n )\n\n bbox_dfs = []\n\n # shift the time of the tracks to start at 0\n # because a subset starts at the original time\n # TODO: Remove this when gridtools is fixed\n data.track.times -= data.track.times[0]\n\n for chunk_no in range(n_chunks):\n # get start and stop indices for the current chunk\n # including some overlap to compensate for edge effects\n # this differs for the first and last chunk\n\n if chunk_no == 0:\n idx1 = sint(chunk_no * chunksize)\n idx2 = sint((chunk_no + 1) * chunksize + window_overlap_samples)\n elif chunk_no == n_chunks - 1:\n idx1 = sint(chunk_no * chunksize - window_overlap_samples)\n idx2 = sint((chunk_no + 1) * chunksize)\n else:\n idx1 = sint(chunk_no * chunksize - window_overlap_samples)\n idx2 = sint((chunk_no + 1) * chunksize + window_overlap_samples)\n\n # idx1 and idx2 now determine the window I cut out of the raw signal\n # to compute the spectrogram of.\n\n # compute the time and frequency axes of the spectrogram now that we\n # include the start and stop indices of the current chunk and thus the\n # right start and stop time. The `spectrogram` function does not know\n # about this and would start every time axis at 0.\n spec_times = np.arange(idx1, idx2 + 1, hop_len) / data.grid.samplerate\n spec_freqs = np.arange(0, nfft / 2 + 1) * data.grid.samplerate / nfft\n\n # create a subset from the grid dataset\n if idx2 > data.grid.rec.shape[0]:\n idx2 = data.grid.rec.shape[0] - 1\n\n chunk = subset(data, idx1, idx2, mode=\"index\")\n\n # compute the spectrogram for each electrode of the current chunk\n spec = None\n for el in range(n_electrodes):\n # get the signal for the current electrode\n sig = chunk.grid.rec[:, el]\n\n # compute the spectrogram for the current electrode\n chunk_spec, _, _ = spectrogram(\n data=sig.copy(),\n samplingrate=data.grid.samplerate,\n nfft=nfft,\n hop_length=hop_len,\n )\n\n # sum spectrogram over all electrodes\n # the spec is a tensor\n if el == 0:\n spec = chunk_spec\n else:\n spec += chunk_spec\n\n if spec is None:\n msg = \"Failed to compute spectrogram.\"\n raise ValueError(msg)\n\n # normalize spectrogram by the number of electrodes\n # the spec is still a tensor\n spec /= n_electrodes\n\n # convert the spectrogram to dB\n # .. still a tensor\n spec = decibel(spec)\n\n # cut off everything outside the upper frequency limit\n # the spec is still a tensor\n\n spectrogram_freq_limits = (\n np.min(chunk.track.freqs) - freq_pad,\n np.max(chunk.track.freqs) + freq_pad,\n )\n\n spec = spec[\n (spec_freqs >= spectrogram_freq_limits[0])\n & (spec_freqs <= spectrogram_freq_limits[1]),\n :,\n ]\n spec_freqs = spec_freqs[\n (spec_freqs >= spectrogram_freq_limits[0])\n & (spec_freqs <= spectrogram_freq_limits[1])\n ]\n\n # normalize the spectrogram to zero mean and unit variance\n # the spec is still a tensor\n spec = (spec - spec.mean()) / spec.std()\n\n # convert the spectrogram to a PIL image\n spec = spec.detach().cpu().numpy()\n img = numpy_to_pil(spec)\n\n imgname = f\"{chunk.path.name}.png\"\n if label_mode == \"synthetic\":\n bbox_df, img = synthetic_labels(\n dataroot,\n chunk,\n nfft,\n spec,\n spec_times,\n spec_freqs,\n imgname,\n chunk_no,\n img,\n )\n if bbox_df is None:\n continue\n bbox_dfs.append(bbox_df)\n elif label_mode == \"detected\":\n detected_labels(dataroot, chunk, imgname, spec, spec_times)\n\n # save image\n img.save(dataroot / \"images\" / f\"{imgname}\")\n\n if label_mode == \"synthetic\":\n bbox_df = pd.concat(bbox_dfs, ignore_index=True)\n bbox_df.to_csv(dataroot / f\"{data.path.name}_bboxes.csv\", index=False)\n\n # save the classes.txt file\n classes = [\"__background__\", \"chirp\"]\n with pathlib.Path.open(dataroot / \"classes.txt\", \"w\") as f:\n f.write(\"\\n\".join(classes))\n
"},{"location":"api/convert_data/#chirpdetector.convert_data.convert_cli","title":"convert_cli(path, output, label_mode)
","text":"Parse all datasets in a directory and convert them to a YOLO dataset.
"},{"location":"api/convert_data/#chirpdetector.convert_data.convert_cli--parameters","title":"Parameters","text":" path
: pathlib.Path
The root directory of the datasets. output : pathlib.Path The output directory for the YOLO dataset. label_mode : str The label mode; one of 'none', 'synthetic' or 'detected'.
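A minimal usage sketch (the paths are placeholders; the signature is taken from the source below):

```python
import pathlib

from chirpdetector.convert_data import convert_cli

# converts every recording under `path` into one YOLO-style dataset at `output`;
# "synthetic" derives labels from simulated chirp parameters, "detected"
# re-exports previous detections for relabeling, "none" writes images only
convert_cli(
    path=pathlib.Path("/path/to/recordings"),
    output=pathlib.Path("/path/to/yolo_dataset"),
    label_mode="synthetic",
)
```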
"},{"location":"api/convert_data/#chirpdetector.convert_data.convert_cli--returns","title":"Returns","text":" Source code in chirpdetector/convert_data.py
def convert_cli(\n path: pathlib.Path,\n output: pathlib.Path,\n label_mode: str,\n) -> None:\n \"\"\"Parse all datasets in a directory and convert them to a YOLO dataset.\n\n Parameters\n ----------\n - `path` : `pathlib.Path`\n The root directory of the datasets.\n\n Returns\n -------\n - `None`\n \"\"\"\n make_file_tree(output)\n config = load_config(str(path / \"chirpdetector.toml\"))\n\n for p in track(list(path.iterdir()), description=\"Building datasets\"):\n if p.is_file():\n continue\n data = load(p)\n convert(data, config, output, label_mode)\n
"},{"location":"api/convert_data/#chirpdetector.convert_data.detected_labels","title":"detected_labels(output, chunk, imgname, spec, spec_times)
","text":"Use the detect_chirps to make a YOLO dataset.
"},{"location":"api/convert_data/#chirpdetector.convert_data.detected_labels--parameters","title":"Parameters","text":" output
: pathlib.Path
The output directory. chunk
: Dataset
The dataset to make bounding boxes for. imgname
: str
The name of the image. spec
: np.ndarray
The spectrogram. spec_times
: np.ndarray
The time axis of the spectrogram.
"},{"location":"api/convert_data/#chirpdetector.convert_data.detected_labels--returns","title":"Returns","text":" Source code in chirpdetector/convert_data.py
def detected_labels(\n output: pathlib.Path,\n chunk: Dataset,\n imgname: str,\n spec: np.ndarray,\n spec_times: np.ndarray,\n) -> None:\n \"\"\"Use the detect_chirps to make a YOLO dataset.\n\n Parameters\n ----------\n - `output` : `pathlib.Path`\n The output directory.\n - `chunk` : `Dataset`\n The dataset to make bounding boxes for.\n - `imgname` : `str`\n The name of the image.\n - `spec` : `np.ndarray`\n The spectrogram.\n - `spec_times` : `np.ndarray`\n The time axis of the spectrogram.\n\n Returns\n -------\n - `None`\n \"\"\"\n # load the detected bboxes csv\n # TODO: This is a workaround. Instead improve the subset naming convention\n # in gridtools\n source_dataset = chunk.path.name.split(\"_\")[1:-4]\n source_dataset = \"_\".join(source_dataset)\n source_dataset = chunk.path.parent / source_dataset\n\n dataframe = pd.read_csv(source_dataset / \"chirpdetector_bboxes.csv\")\n\n # get chunk start and stop time\n start, stop = spec_times[0], spec_times[-1]\n\n # get the bboxes for this chunk\n bboxes = dataframe[(dataframe.t1 >= start) & (dataframe.t2 <= stop)]\n\n # get the x and y coordinates of the bboxes in pixels as dataframe\n bboxes_xy = bboxes[[\"x1\", \"y1\", \"x2\", \"y2\"]]\n\n # convert from x1, y1, x2, y2 to centerx, centery, width, height\n centerx = np.array((bboxes_xy[\"x1\"] + bboxes_xy[\"x2\"]) / 2)\n centery = np.array((bboxes_xy[\"y1\"] + bboxes_xy[\"y2\"]) / 2)\n width = np.array(bboxes_xy[\"x2\"] - bboxes_xy[\"x1\"])\n height = np.array(bboxes_xy[\"y2\"] - bboxes_xy[\"y1\"])\n\n # flip centery because origin is top left\n centery = spec.shape[0] - centery\n\n # make relative to image size\n centerx = centerx / spec.shape[1]\n centery = centery / spec.shape[0]\n width = width / spec.shape[1]\n height = height / spec.shape[0]\n labels = np.ones_like(centerx, dtype=int)\n\n # make a new dataframe with the relative coordinates\n new_bboxes = pd.DataFrame(\n {\"l\": labels, \"x\": centerx, \"y\": centery, \"w\": width, \"h\": height},\n )\n\n # save dataframe for every spec without headers as txt\n new_bboxes.to_csv(\n output / \"labels\" / f\"{imgname[:-4]}.txt\",\n header=False,\n index=False,\n sep=\" \",\n )\n
"},{"location":"api/convert_data/#chirpdetector.convert_data.make_file_tree","title":"make_file_tree(path)
","text":"Build a file tree for the training dataset.
"},{"location":"api/convert_data/#chirpdetector.convert_data.make_file_tree--parameters","title":"Parameters","text":"path : pathlib.Path The root directory of the dataset.
Source code in chirpdetector/convert_data.py
def make_file_tree(path: pathlib.Path) -> None:\n \"\"\"Build a file tree for the training dataset.\n\n Parameters\n ----------\n path : pathlib.Path\n The root directory of the dataset.\n \"\"\"\n if path.parent.exists() and path.parent.is_file():\n msg = (\n f\"Parent directory of {path} is a file. \"\n \"Please specify a directory.\"\n )\n raise ValueError(msg)\n\n if path.exists():\n shutil.rmtree(path)\n\n path.mkdir(exist_ok=True, parents=True)\n\n train_imgs = path / \"images\"\n train_labels = path / \"labels\"\n train_imgs.mkdir(exist_ok=True, parents=True)\n train_labels.mkdir(exist_ok=True, parents=True)\n
"},{"location":"api/convert_data/#chirpdetector.convert_data.numpy_to_pil","title":"numpy_to_pil(img)
","text":"Convert a 2D numpy array to a PIL image.
"},{"location":"api/convert_data/#chirpdetector.convert_data.numpy_to_pil--parameters","title":"Parameters","text":"img : np.ndarray The input image.
"},{"location":"api/convert_data/#chirpdetector.convert_data.numpy_to_pil--returns","title":"Returns","text":"PIL.Image The converted image.
Source code in chirpdetector/convert_data.py
def numpy_to_pil(img: np.ndarray) -> Image.Image:\n \"\"\"Convert a 2D numpy array to a PIL image.\n\n Parameters\n ----------\n img : np.ndarray\n The input image.\n\n Returns\n -------\n PIL.Image\n The converted image.\n \"\"\"\n img_dimens = 2\n if len(img.shape) != img_dimens:\n msg = f\"Image must be {img_dimens}D\"\n raise ValueError(msg)\n\n if img.max() == img.min():\n msg = \"Image must have more than one value\"\n raise ValueError(msg)\n\n img = np.flipud(img)\n intimg = np.uint8((img - img.min()) / (img.max() - img.min()) * 255)\n return Image.fromarray(intimg)\n
"},{"location":"api/convert_data/#chirpdetector.convert_data.synthetic_labels","title":"synthetic_labels(output, chunk, nfft, spec, spec_times, spec_freqs, imgname, chunk_no, img)
","text":"Generate labels of a simulated dataset.
"},{"location":"api/convert_data/#chirpdetector.convert_data.synthetic_labels--parameters","title":"Parameters","text":" output
: pathlib.Path
The output directory. chunk
: Dataset
The dataset to make bounding boxes for. nfft
: int
The number of samples in the FFT. spec
: np.ndarray
The spectrogram. spec_times
: np.ndarray
The time axis of the spectrogram. spec_freqs
: np.ndarray
The frequency axis of the spectrogram. imgname
: str
The name of the image. chunk_no
: int
The chunk number. img
: Image
The image.
"},{"location":"api/convert_data/#chirpdetector.convert_data.synthetic_labels--returns","title":"Returns","text":" pandas.DataFrame
A dataframe with the bounding boxes.
Source code in chirpdetector/convert_data.py
def synthetic_labels(\n output: pathlib.Path,\n chunk: Dataset,\n nfft: int,\n spec: np.ndarray,\n spec_times: np.ndarray,\n spec_freqs: np.ndarray,\n imgname: str,\n chunk_no: int,\n img: Image.Image,\n) -> Union[Tuple[pd.DataFrame, Image.Image], Tuple[None, None]]:\n \"\"\"Generate labels of a simulated dataset.\n\n Parameters\n ----------\n - `output` : `pathlib.Path`\n The output directory.\n - `chunk` : `Dataset`\n The dataset to make bounding boxes for.\n - `nfft` : `int`\n The number of samples in the FFT.\n - `spec` : `np.ndarray`\n The spectrogram.\n - `spec_times` : `np.ndarray`\n The time axis of the spectrogram.\n - `spec_freqs` : `np.ndarray`\n The frequency axis of the spectrogram.\n - `imgname` : `str`\n The name of the image.\n - `chunk_no` : `int`\n The chunk number.\n - `img` : `Image`\n The image.\n\n Returns\n -------\n - `pandas.DataFrame`\n A dataframe with the bounding boxes.\n \"\"\"\n # compute the bounding boxes for this chunk\n bboxes = chirp_bounding_boxes(chunk, nfft)\n\n if len(bboxes) == 0:\n return None, None\n\n # convert bounding box center coordinates to spectrogram coordinates\n # find the indices on the spec_times corresponding to the center times\n x = np.searchsorted(spec_times, bboxes.t_center)\n y = np.searchsorted(spec_freqs, bboxes.f_center)\n widths = np.searchsorted(spec_times - spec_times[0], bboxes.width)\n heights = np.searchsorted(spec_freqs - spec_freqs[0], bboxes.height)\n\n # now we have center coordinates, widths and heights in indices. But PIL\n # expects coordinates in pixels in the format\n # (Upper left x coordinate, upper left y coordinate,\n # lower right x coordinate, lower right y coordinate)\n # In addiotion, an image starts in the top left corner so the bboxes\n # need to be mirrored horizontally.\n\n y = spec.shape[0] - y # flip the y values to fit y=0 at the top\n lxs, lys = x - widths / 2, y - heights / 2\n rxs, rys = x + widths / 2, y + heights / 2\n\n # add them to the bboxes dataframe\n bboxes[\"upperleft_img_x\"] = lxs\n bboxes[\"upperleft_img_y\"] = lys\n bboxes[\"lowerright_img_x\"] = rxs\n bboxes[\"lowerright_img_y\"] = rys\n\n # yolo format is centerx, centery, width, height\n # convert xmin, ymin, xmax, ymax to centerx, centery, width, height\n centerx = (lxs + rxs) / 2\n centery = (lys + rys) / 2\n width = rxs - lxs\n height = rys - lys\n\n # most deep learning frameworks expect bounding box coordinates\n # as relative to the image size. So we normalize the coordinates\n # to the image size\n centerx_norm = centerx / spec.shape[1]\n centery_norm = centery / spec.shape[0]\n width_norm = width / spec.shape[1]\n height_norm = height / spec.shape[0]\n\n # add them to the bboxes dataframe\n bboxes[\"centerx_norm\"] = centerx_norm\n bboxes[\"centery_norm\"] = centery_norm\n bboxes[\"width_norm\"] = width_norm\n bboxes[\"height_norm\"] = height_norm\n\n # add chunk ID to the bboxes dataframe\n bboxes[\"chunk_id\"] = chunk_no\n\n # put them into a dataframe to save for eahc spectrogram\n dataframe = pd.DataFrame(\n {\n \"cx\": centerx_norm,\n \"cy\": centery_norm,\n \"w\": width_norm,\n \"h\": height_norm,\n },\n )\n\n # add as first colum instance id\n dataframe.insert(0, \"instance_id\", np.ones_like(lxs, dtype=int))\n\n # stash the bboxes dataframe for this chunk\n bboxes[\"image\"] = imgname\n\n # save dataframe for every spec without headers as txt\n dataframe.to_csv(\n output / \"labels\" / f\"{chunk.path.name}.txt\",\n header=False,\n index=False,\n sep=\" \",\n )\n return bboxes, img\n
"},{"location":"api/dataset_utils/","title":"dataset_utils","text":"Utility functions for training datasets in the YOLO format.
"},{"location":"api/dataset_utils/#chirpdetector.dataset_utils.clean_yolo_dataset","title":"clean_yolo_dataset(path, img_ext)
","text":"Remove images and labels when the label file is empty.
"},{"location":"api/dataset_utils/#chirpdetector.dataset_utils.clean_yolo_dataset--parameters","title":"Parameters","text":"path : pathlib.Path The path to the dataset. img_ext : str
"},{"location":"api/dataset_utils/#chirpdetector.dataset_utils.clean_yolo_dataset--returns","title":"Returns","text":"None
Source code in chirpdetector/dataset_utils.py
def clean_yolo_dataset(path: pathlib.Path, img_ext: str) -> None:\n \"\"\"Remove images and labels when the label file is empty.\n\n Parameters\n ----------\n path : pathlib.Path\n The path to the dataset.\n img_ext : str\n\n Returns\n -------\n None\n \"\"\"\n img_path = path / \"images\"\n lbl_path = path / \"labels\"\n\n images = list(img_path.glob(f\"*{img_ext}\"))\n\n for image in images:\n lbl = lbl_path / f\"{image.stem}.txt\"\n if lbl.stat().st_size == 0:\n image.unlink()\n lbl.unlink()\n
"},{"location":"api/dataset_utils/#chirpdetector.dataset_utils.load_img","title":"load_img(path)
","text":"Load an image from a path as a numpy array.
"},{"location":"api/dataset_utils/#chirpdetector.dataset_utils.load_img--parameters","title":"Parameters","text":"path : pathlib.Path The path to the image.
"},{"location":"api/dataset_utils/#chirpdetector.dataset_utils.load_img--returns","title":"Returns","text":"img : np.ndarray The image as a numpy array.
Source code in chirpdetector/dataset_utils.py
def load_img(path: pathlib.Path) -> np.ndarray:\n \"\"\"Load an image from a path as a numpy array.\n\n Parameters\n ----------\n path : pathlib.Path\n The path to the image.\n\n Returns\n -------\n img : np.ndarray\n The image as a numpy array.\n \"\"\"\n img = Image.open(path)\n return np.asarray(img)\n
"},{"location":"api/dataset_utils/#chirpdetector.dataset_utils.merge_yolo_datasets","title":"merge_yolo_datasets(dataset1, dataset2, output)
","text":"Merge two yolo-style datasets into one.
"},{"location":"api/dataset_utils/#chirpdetector.dataset_utils.merge_yolo_datasets--parameters","title":"Parameters","text":"dataset1 : str The path to the first dataset. dataset2 : str The path to the second dataset. output : str The path to the output dataset.
"},{"location":"api/dataset_utils/#chirpdetector.dataset_utils.merge_yolo_datasets--returns","title":"Returns","text":"None
Source code in chirpdetector/dataset_utils.py
def merge_yolo_datasets(\n dataset1: pathlib.Path,\n dataset2: pathlib.Path,\n output: pathlib.Path,\n) -> None:\n \"\"\"Merge two yolo-style datasets into one.\n\n Parameters\n ----------\n dataset1 : str\n The path to the first dataset.\n dataset2 : str\n The path to the second dataset.\n output : str\n The path to the output dataset.\n\n Returns\n -------\n None\n \"\"\"\n dataset1 = pathlib.Path(dataset1)\n dataset2 = pathlib.Path(dataset2)\n output = pathlib.Path(output)\n\n if not dataset1.exists():\n msg = f\"{dataset1} does not exist.\"\n raise FileNotFoundError(msg)\n if not dataset2.exists():\n msg = f\"{dataset2} does not exist.\"\n raise FileNotFoundError(msg)\n if output.exists():\n msg = f\"{output} already exists.\"\n raise FileExistsError(msg)\n\n output_images = output / \"images\"\n output_images.mkdir(parents=True, exist_ok=False)\n output_labels = output / \"labels\"\n output_labels.mkdir(parents=True, exist_ok=False)\n\n imgs1 = list((dataset1 / \"images\").iterdir())\n labels1 = list((dataset1 / \"labels\").iterdir())\n imgs2 = list((dataset2 / \"images\").iterdir())\n labels2 = list((dataset2 / \"labels\").iterdir())\n\n print(f\"Found {len(imgs1)} images in {dataset1}.\")\n print(f\"Found {len(imgs2)} images in {dataset2}.\")\n\n print(f\"Copying images and labels to {output}...\")\n for idx, _ in enumerate(imgs1):\n shutil.copy(imgs1[idx], output_images / imgs1[idx].name)\n shutil.copy(labels1[idx], output_labels / labels1[idx].name)\n\n for idx, _ in enumerate(imgs2):\n shutil.copy(imgs2[idx], output_images / imgs2[idx].name)\n shutil.copy(labels2[idx], output_labels / labels2[idx].name)\n\n classes = dataset1 / \"classes.txt\"\n shutil.copy(classes, output / classes.name)\n\n print(f\"Done. Merged {len(imgs1) + len(imgs2)} images.\")\n
"},{"location":"api/dataset_utils/#chirpdetector.dataset_utils.plot_yolo_dataset","title":"plot_yolo_dataset(path, n)
","text":"Plot n random images YOLO-style dataset.
"},{"location":"api/dataset_utils/#chirpdetector.dataset_utils.plot_yolo_dataset--parameters","title":"Parameters","text":"path : pathlib.Path The path to the dataset.
"},{"location":"api/dataset_utils/#chirpdetector.dataset_utils.plot_yolo_dataset--returns","title":"Returns","text":"None
Source code in chirpdetector/dataset_utils.py
def plot_yolo_dataset(path: pathlib.Path, n: int) -> None:\n \"\"\"Plot n random images YOLO-style dataset.\n\n Parameters\n ----------\n path : pathlib.Path\n The path to the dataset.\n\n Returns\n -------\n None\n \"\"\"\n mpl.use(\"TkAgg\")\n labelpath = path / \"labels\"\n imgpath = path / \"images\"\n\n label_paths = np.array(list(labelpath.glob(\"*.txt\")))\n label_paths = np.random.choice(label_paths, n)\n\n for lp in label_paths:\n imgp = imgpath / (lp.stem + \".png\")\n img = load_img(imgp)\n labs = np.loadtxt(lp, dtype=np.float32).reshape(-1, 5)\n\n coords = labs[:, 1:]\n\n # make coords absolute and normalize\n coords[:, 0] *= img.shape[1]\n coords[:, 1] *= img.shape[0]\n coords[:, 2] *= img.shape[1]\n coords[:, 3] *= img.shape[0]\n\n # turn centerx, centery, width, height into xmin, ymin, xmax, ymax\n xmin = coords[:, 0] - coords[:, 2] / 2\n ymin = coords[:, 1] - coords[:, 3] / 2\n xmax = coords[:, 0] + coords[:, 2] / 2\n ymax = coords[:, 1] + coords[:, 3] / 2\n\n # plot the image\n _, ax = plt.subplots(figsize=(15, 5), constrained_layout=True)\n ax.imshow(img, cmap=\"magma\")\n for i in range(len(xmin)):\n ax.add_patch(\n Rectangle(\n (xmin[i], ymin[i]),\n xmax[i] - xmin[i],\n ymax[i] - ymin[i],\n fill=False,\n color=\"white\",\n ),\n )\n ax.set_title(imgp.stem)\n plt.axis(\"off\")\n plt.show()\n
"},{"location":"api/dataset_utils/#chirpdetector.dataset_utils.subset_yolo_dataset","title":"subset_yolo_dataset(path, img_ext, n)
","text":"Subset a YOLO dataset.
"},{"location":"api/dataset_utils/#chirpdetector.dataset_utils.subset_yolo_dataset--parameters","title":"Parameters","text":"path : pathlib.Path The path to the dataset root. img_ext : str The image extension, e.g. .png or .jpg n : int The size of the subset
"},{"location":"api/dataset_utils/#chirpdetector.dataset_utils.subset_yolo_dataset--returns","title":"Returns","text":"None
Source code in chirpdetector/dataset_utils.py
def subset_yolo_dataset(path: pathlib.Path, img_ext: str, n: int) -> None:\n \"\"\"Subset a YOLO dataset.\n\n Parameters\n ----------\n path : pathlib.Path\n The path to the dataset root.\n img_ext : str\n The image extension, e.g. .png or .jpg\n n : int\n The size of the subset\n\n Returns\n -------\n None\n \"\"\"\n img_path = path / \"images\"\n lbl_path = path / \"labels\"\n\n # materialize the glob generator so numpy builds a 1-d array of paths\n images = np.array(list(img_path.glob(f\"*{img_ext}\")))\n np.random.shuffle(images)\n\n images = images[:n]\n\n subset_dir = path.parent / f\"{path.name}_subset\"\n subset_dir.mkdir(exist_ok=True)\n\n subset_img_path = subset_dir / \"images\"\n subset_img_path.mkdir(exist_ok=True)\n subset_lbl_path = subset_dir / \"labels\"\n subset_lbl_path.mkdir(exist_ok=True)\n\n shutil.copy(path / \"classes.txt\", subset_dir)\n\n for image in images:\n shutil.copy(image, subset_img_path)\n shutil.copy(lbl_path / f\"{image.stem}.txt\", subset_lbl_path)\n
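Taken together, a typical tidy-up pass over a labeled dataset could look like this sketch (paths are placeholders; signatures as documented above):

```python
import pathlib

from chirpdetector.dataset_utils import (
    clean_yolo_dataset,
    merge_yolo_datasets,
    plot_yolo_dataset,
    subset_yolo_dataset,
)

root = pathlib.Path("/path/to/yolo_dataset")

clean_yolo_dataset(root, img_ext=".png")          # drop images with empty label files
subset_yolo_dataset(root, img_ext=".png", n=100)  # stash a 100-image subset
merge_yolo_datasets(root, root.parent / "other_dataset", root.parent / "merged")
plot_yolo_dataset(root.parent / "merged", n=5)    # eyeball five random images
```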
"},{"location":"api/detect_chirps/","title":"detect_chirps","text":"Detect chirps on a spectrogram.
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.coords_to_mpl_rectangle","title":"coords_to_mpl_rectangle(boxes)
","text":"Convert normal bounding box to matplotlib.pathes.Rectangle format.
Convert box defined by corner coordinates (x1, y1, x2, y2) to box defined by lower left, width and height (x1, y1, w, h).
The corner coordinates are the model output, but the lower-left corner plus width and height are needed by the matplotlib.patches.Rectangle
object for plotting.
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.coords_to_mpl_rectangle--parameters","title":"Parameters","text":" boxes
: numpy.ndarray
The boxes to be converted.
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.coords_to_mpl_rectangle--returns","title":"Returns","text":" numpy.ndarray
The converted boxes.
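For example, a corner-format box (2, 3, 6, 8) becomes (2, 3, 4, 5), i.e. the same lower-left corner plus a width of 4 and a height of 5:

```python
import numpy as np

from chirpdetector.detect_chirps import coords_to_mpl_rectangle

boxes = np.array([[2.0, 3.0, 6.0, 8.0]])  # (x1, y1, x2, y2) from the model
print(coords_to_mpl_rectangle(boxes))     # [[2. 3. 4. 5.]] -> (x1, y1, w, h)
```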
Source code in chirpdetector/detect_chirps.py
def coords_to_mpl_rectangle(boxes: np.ndarray) -> np.ndarray:\n \"\"\"Convert normal bounding box to matplotlib.pathes.Rectangle format.\n\n Convert box defined by corner coordinates (x1, y1, x2, y2)\n to box defined by lower left, width and height (x1, y1, w, h).\n\n The corner coordinates are the model output, but the center coordinates\n are needed by the `matplotlib.patches.Rectangle` object for plotting.\n\n Parameters\n ----------\n - `boxes` : `numpy.ndarray`\n The boxes to be converted.\n\n Returns\n -------\n - `numpy.ndarray`\n The converted boxes.\n \"\"\"\n boxes_dims = 2\n if len(boxes.shape) != boxes_dims:\n msg = (\n \"The boxes array must be 2-dimensional.\\n\"\n f\"Shape of boxes: {boxes.shape}\"\n )\n raise ValueError(msg)\n boxes_cols = 4\n if boxes.shape[1] != boxes_cols:\n msg = (\n \"The boxes array must have 4 columns.\\n\"\n f\"Shape of boxes: {boxes.shape}\"\n )\n raise ValueError(msg)\n\n new_boxes = np.zeros_like(boxes)\n new_boxes[:, 0] = boxes[:, 0]\n new_boxes[:, 1] = boxes[:, 1]\n new_boxes[:, 2] = boxes[:, 2] - boxes[:, 0]\n new_boxes[:, 3] = boxes[:, 3] - boxes[:, 1]\n\n return new_boxes\n
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.detect_chirps","title":"detect_chirps(conf, data)
","text":"Detect chirps on a spectrogram.
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.detect_chirps--parameters","title":"Parameters","text":" conf
: Config
The configuration object. data
: Dataset
The gridtools dataset to detect chirps on.
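A sketch of calling the detector from Python rather than via cpd detect (the import paths for load_config and the gridtools load helper are assumptions based on how they are used in the CLI sources elsewhere in this reference):

```python
import pathlib

from gridtools.datasets import load  # assumed import path for `load`

from chirpdetector.config import load_config  # assumed module for `load_config`
from chirpdetector.detect_chirps import detect_chirps

path = pathlib.Path("/path/to/dataset/recording1")  # placeholder

conf = load_config(str(path.parent / "chirpdetector.toml"))
data = load(path)

# writes chirpdetector_bboxes.csv into the recording directory
detect_chirps(conf, data)
```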
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.detect_chirps--returns","title":"Returns","text":" Source code in chirpdetector/detect_chirps.py
def detect_chirps(conf: Config, data: Dataset) -> None:\n \"\"\"Detect chirps on a spectrogram.\n\n Parameters\n ----------\n - `conf` : `Config`\n The configuration object.\n - `data` : `Dataset`\n The gridtools dataset to detect chirps on.\n\n Returns\n -------\n - `None`\n \"\"\"\n # get the number of electrodes\n n_electrodes = data.grid.rec.shape[1]\n\n # load the model and the checkpoint, and set it to evaluation mode\n device = get_device()\n model = load_fasterrcnn(num_classes=len(conf.hyper.classes))\n checkpoint = torch.load(\n f\"{conf.hyper.modelpath}/model.pt\",\n map_location=device,\n )\n model.load_state_dict(checkpoint[\"model_state_dict\"])\n model.to(device).eval()\n\n # make spec config\n nfft = freqres_to_nfft(conf.spec.freq_res, data.grid.samplerate) # samples\n hop_len = overlap_to_hoplen(conf.spec.overlap_frac, nfft) # samples\n chunksize = conf.spec.time_window * data.grid.samplerate # samples\n nchunks = np.ceil(data.grid.rec.shape[0] / chunksize).astype(int)\n window_overlap_samples = int(conf.spec.spec_overlap * data.grid.samplerate)\n\n bbox_dfs = []\n\n # iterate over the chunks\n overwritten = False\n for chunk_no in range(nchunks):\n # get start and stop indices for the current chunk\n # including some overlap to compensate for edge effects\n # this diffrers for the first and last chunk\n\n if chunk_no == 0:\n idx1 = int(chunk_no * chunksize)\n idx2 = int((chunk_no + 1) * chunksize + window_overlap_samples)\n elif chunk_no == nchunks - 1:\n idx1 = int(chunk_no * chunksize - window_overlap_samples)\n idx2 = int((chunk_no + 1) * chunksize)\n else:\n idx1 = int(chunk_no * chunksize - window_overlap_samples)\n idx2 = int((chunk_no + 1) * chunksize + window_overlap_samples)\n\n # idx1 and idx2 now determine the window I cut out of the raw signal\n # to compute the spectrogram of.\n\n # compute the time and frequency axes of the spectrogram now that we\n # include the start and stop indices of the current chunk and thus the\n # right start and stop time. 
The `spectrogram` function does not know\n # about this and would start every time axis at 0.\n spec_times = np.arange(idx1, idx2 + 1, hop_len) / data.grid.samplerate\n spec_freqs = np.arange(0, nfft / 2 + 1) * data.grid.samplerate / nfft\n\n # create a subset from the grid dataset\n if idx2 > data.grid.rec.shape[0]:\n idx2 = data.grid.rec.shape[0] - 1\n\n # This bit should alleviate the edge effects of the tracks\n # by limiting the start and stop times of the spectrogram\n # to the start and stop times of the track.\n start_t = idx1 / data.grid.samplerate\n stop_t = idx2 / data.grid.samplerate\n if data.track.times[-1] < stop_t:\n stop_t = data.track.times[-1]\n idx2 = int(stop_t * data.grid.samplerate)\n if data.track.times[0] > start_t:\n start_t = data.track.times[0]\n idx1 = int(start_t * data.grid.samplerate)\n if start_t > data.track.times[-1] or stop_t < data.track.times[0]:\n continue\n\n chunk = subset(data, idx1, idx2, mode=\"index\")\n if len(chunk.track.indices) == 0:\n continue\n\n # compute the spectrogram for each electrode of the current chunk\n spec = torch.zeros((len(spec_freqs), len(spec_times)))\n for el in range(n_electrodes):\n # get the signal for the current electrode\n sig = chunk.grid.rec[:, el]\n\n # compute the spectrogram for the current electrode\n chunk_spec, _, _ = spectrogram(\n data=sig.copy(),\n samplingrate=data.grid.samplerate,\n nfft=nfft,\n hop_length=hop_len,\n )\n\n # sum spectrogram over all electrodes\n # the spec is a tensor\n if el == 0:\n spec = chunk_spec\n else:\n spec += chunk_spec\n\n # normalize spectrogram by the number of electrodes\n # the spec is still a tensor\n spec /= n_electrodes\n\n # convert the spectrogram to dB\n # .. still a tensor\n spec = decibel(spec)\n\n # cut off everything outside the upper frequency limit\n # the spec is still a tensor\n # TODO: THIS IS SKETCHY AS HELL! As a result, only time and frequency\n # bounding boxes can be used later! 
The spectrogram limits change\n # for every window!\n flims = (\n np.min(chunk.track.freqs) - conf.spec.freq_pad,\n np.max(chunk.track.freqs) + conf.spec.freq_pad,\n )\n spec = spec[(spec_freqs >= flims[0]) & (spec_freqs <= flims[1]), :]\n spec_freqs = spec_freqs[\n (spec_freqs >= flims[0]) & (spec_freqs <= flims[1])\n ]\n\n # make a path to save the spectrogram\n path = data.path / \"chirpdetections\"\n if path.exists() and overwritten is False:\n shutil.rmtree(path)\n overwritten = True\n path.mkdir(exist_ok=True)\n path /= f\"chunk{chunk_no:05d}.png\"\n\n # add the 3 channels, normalize to 0-1, etc\n img = spec_to_image(spec)\n\n # perform the detection\n with torch.inference_mode():\n outputs = model([img])\n\n # put the boxes, scores and labels into the dataset\n bboxes = outputs[0][\"boxes\"].detach().cpu().numpy()\n scores = outputs[0][\"scores\"].detach().cpu().numpy()\n labels = outputs[0][\"labels\"].detach().cpu().numpy()\n\n # remove all boxes with a score below the threshold\n bboxes = bboxes[scores > conf.det.threshold]\n labels = labels[scores > conf.det.threshold]\n scores = scores[scores > conf.det.threshold]\n\n # if np.any(scores > conf.det.threshold):\n # plot_detections(img, outputs[0], conf.det.threshold, path, conf)\n\n # save the bboxes to a dataframe\n bbox_df = pd.DataFrame(\n data=bboxes,\n columns=[\"x1\", \"y1\", \"x2\", \"y2\"],\n )\n bbox_df[\"score\"] = scores\n bbox_df[\"label\"] = labels\n\n # convert x values to time on spec_times\n spec_times_index = np.arange(0, len(spec_times))\n bbox_df[\"t1\"] = float_index_interpolation(\n bbox_df[\"x1\"].to_numpy(),\n spec_times_index,\n spec_times,\n )\n bbox_df[\"t2\"] = float_index_interpolation(\n bbox_df[\"x2\"].to_numpy(),\n spec_times_index,\n spec_times,\n )\n\n # convert y values to frequency on spec_freqs\n spec_freqs_index = np.arange(len(spec_freqs))\n bbox_df[\"f1\"] = float_index_interpolation(\n bbox_df[\"y1\"].to_numpy(),\n spec_freqs_index,\n spec_freqs,\n )\n bbox_df[\"f2\"] = float_index_interpolation(\n bbox_df[\"y2\"].to_numpy(),\n spec_freqs_index,\n spec_freqs,\n )\n\n # save df to list\n bbox_dfs.append(bbox_df)\n\n # concatenate all dataframes\n bbox_df = pd.concat(bbox_dfs)\n bbox_reset = bbox_df.reset_index(drop=True)\n\n # sort the dataframe by t1\n bbox_sorted = bbox_reset.sort_values(by=\"t1\")\n\n # sort the columns\n bbox_sorted = bbox_sorted[\n [\"label\", \"score\", \"x1\", \"y1\", \"x2\", \"y2\", \"t1\", \"f1\", \"t2\", \"f2\"]\n ]\n\n # save the dataframe\n bbox_sorted.to_csv(data.path / \"chirpdetector_bboxes.csv\", index=False)\n
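The chunk windowing above can be read in isolation. A minimal sketch of the index arithmetic, with an assumed samplerate and assumed config values (they are not taken from a real config):\n\nimport numpy as np\n\nsamplerate = 20_000 # Hz, assumed\nchunksize = 15 * samplerate # conf.spec.time_window = 15 s, assumed\nwindow_overlap_samples = int(1.0 * samplerate) # conf.spec.spec_overlap = 1 s, assumed\nn_samples = 100 * samplerate\nnchunks = int(np.ceil(n_samples / chunksize))\nfor chunk_no in range(nchunks):\n # interior chunks are padded on both sides, the first and last on one side only\n idx1 = int(chunk_no * chunksize - (window_overlap_samples if chunk_no > 0 else 0))\n idx2 = int((chunk_no + 1) * chunksize + (window_overlap_samples if chunk_no < nchunks - 1 else 0))\n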
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.detect_cli","title":"detect_cli(input_path)
","text":"Terminal interface for the detection function.
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.detect_cli--parameters","title":"Parameters","text":""},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.detect_cli--returns","title":"Returns","text":" Source code in chirpdetector/detect_chirps.py
def detect_cli(input_path: pathlib.Path) -> None:\n \"\"\"Terminal interface for the detection function.\n\n Parameters\n ----------\n - `input_path` : `pathlib.Path`\n The path to the directory containing the datasets.\n\n Returns\n -------\n - `None`\n \"\"\"\n # make the global logger object\n # global logger # pylint: disable=global-statement\n path = pathlib.Path(input_path)\n logger = make_logger(__name__, path / \"chirpdetector.log\")\n datasets = [folder for folder in path.iterdir() if folder.is_dir()]\n confpath = path / \"chirpdetector.toml\"\n\n # load the config file and print a warning if it does not exist\n if confpath.exists():\n config = load_config(str(confpath))\n else:\n msg = (\n \"The configuration file could not be found in the specified path. \"\n \"Please run `chirpdetector copyconfig` and change the \"\n \"configuration file to your needs.\"\n )\n raise FileNotFoundError(msg)\n\n # detect chirps in all datasets in the specified path\n # and show a progress bar\n prog.console.rule(\"Starting detection\")\n with prog:\n task = prog.add_task(\"Detecting chirps...\", total=len(datasets))\n for dataset in datasets:\n msg = f\"Detecting chirps in {dataset.name}...\"\n prog.console.log(msg)\n logger.info(msg)\n\n data = load(dataset)\n detect_chirps(config, data)\n prog.update(task, advance=1)\n prog.update(task, completed=len(datasets))\n
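A minimal usage sketch (the layout is an assumption based on the code above: one subdirectory per dataset plus a chirpdetector.toml next to them; \"recordings\" is a hypothetical directory):\n\n>>> import pathlib\n>>> detect_cli(pathlib.Path(\"recordings\"))\n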
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.float_index_interpolation","title":"float_index_interpolation(values, index_arr, data_arr)
","text":"Convert float indices to values by linear interpolation.
Interpolates a set of float indices within the given index array to obtain corresponding values from the data array using linear interpolation.
Given a set of float indices (values
), this function determines the corresponding values in the data_arr
by linearly interpolating between adjacent indices in the index_arr
. Linear interpolation involves calculating weighted averages based on the fractional parts of the float indices.
This function is useful for transforming float coordinates on a spectrogram matrix into the corresponding time and frequency values. The reason for this is that the model outputs bounding boxes in float coordinates, i.e. it does not care about the exact pixel location of the bounding box. Note that the interpolation only behaves as intended when index_arr is numpy.arange(len(data_arr)), which is how it is used throughout this module.
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.float_index_interpolation--parameters","title":"Parameters","text":" values
: np.ndarray
The index value as a float that should be interpolated. index_arr
: numpy.ndarray
The array of indices on the data array. data_arr
: numpy.ndarray
The array of data.
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.float_index_interpolation--returns","title":"Returns","text":" numpy.ndarray
The interpolated value.
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.float_index_interpolation--raises","title":"Raises","text":" ValueError
If any of the input float indices (values
) are outside the range of the provided index_arr
.
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.float_index_interpolation--examples","title":"Examples","text":"values = np.array([2.5, 3.2, 4.8]) index_arr = np.array([2, 3, 4, 5]) data_arr = np.array([10, 15, 20, 25]) result = float_index_interpolation(values, index_arr, data_arr) print(result) array([12.5, 16. , 22.5])
Source code in chirpdetector/detect_chirps.py
def float_index_interpolation(\n values: np.ndarray,\n index_arr: np.ndarray,\n data_arr: np.ndarray,\n) -> np.ndarray:\n \"\"\"Convert float indices to values by linear interpolation.\n\n Interpolates a set of float indices within the given index\n array to obtain corresponding values from the data\n array using linear interpolation.\n\n Given a set of float indices (`values`), this function determines\n the corresponding values in the `data_arr` by linearly interpolating\n between adjacent indices in the `index_arr`. Linear interpolation\n involves calculating weighted averages based on the fractional\n parts of the float indices.\n\n This function is useful for transforming float coordinates on a\n spectrogram matrix into the corresponding time and frequency values.\n The reason for this is that the model outputs bounding boxes in float\n coordinates, i.e. it does not care about the exact pixel location of\n the bounding box. Note that the interpolation only behaves as intended\n when `index_arr` is `np.arange(len(data_arr))`, which is how it is\n used throughout this module.\n\n Parameters\n ----------\n - `values` : `np.ndarray`\n The index value as a float that should be interpolated.\n - `index_arr` : `numpy.ndarray`\n The array of indices on the data array.\n - `data_arr` : `numpy.ndarray`\n The array of data.\n\n Returns\n -------\n - `numpy.ndarray`\n The interpolated value.\n\n Raises\n ------\n - `ValueError`\n If any of the input float indices (`values`) are outside\n the range of the provided `index_arr`.\n\n Examples\n --------\n >>> values = np.array([0.5, 1.5, 2.5])\n >>> index_arr = np.array([0, 1, 2, 3])\n >>> data_arr = np.array([10, 15, 20, 25])\n >>> result = float_index_interpolation(values, index_arr, data_arr)\n >>> result\n array([12.5, 17.5, 22.5])\n \"\"\"\n # Check if the values are within the range of the index array\n if np.any(values < (np.min(index_arr) - 1)) or np.any(\n values > (np.max(index_arr) + 1),\n ):\n msg = (\n \"Values outside the range of index array\\n\"\n f\"Target values: {values}\\n\"\n f\"Index array: {index_arr}\\n\"\n f\"Data array: {data_arr}\"\n )\n raise ValueError(msg)\n\n # Find the indices corresponding to the values\n lower_indices = np.floor(values).astype(int)\n upper_indices = np.ceil(values).astype(int)\n\n # Ensure upper indices are within the array bounds\n upper_indices = np.minimum(upper_indices, len(index_arr) - 1)\n lower_indices = np.minimum(lower_indices, len(index_arr) - 1)\n\n # Calculate the interpolation weights\n weights = values - lower_indices\n\n # Linear interpolation\n return (1 - weights) * data_arr[lower_indices] + weights * data_arr[\n upper_indices\n ]\n
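For a zero-based index array the function matches NumPy's own linear interpolation, which allows a quick cross-check (a sketch, not part of the package):\n\n>>> values = np.array([0.5, 1.5, 2.5])\n>>> index_arr = np.arange(4)\n>>> data_arr = np.array([10, 15, 20, 25])\n>>> np.allclose(\n... float_index_interpolation(values, index_arr, data_arr),\n... np.interp(values, index_arr, data_arr),\n... )\nTrue\n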
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.plot_detections","title":"plot_detections(img_tensor, output, threshold, save_path, conf)
","text":"Plot the detections on the spectrogram.
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.plot_detections--parameters","title":"Parameters","text":" img_tensor
: torch.Tensor
The spectrogram. output
: torch.Tensor
The output of the model. threshold
: float
The threshold for the detections. save_path
: pathlib.Path
The path to save the plot to. conf
: Config
The configuration object.
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.plot_detections--returns","title":"Returns","text":" Source code in chirpdetector/detect_chirps.py
def plot_detections(\n img_tensor: torch.Tensor,\n output: torch.Tensor,\n threshold: float,\n save_path: pathlib.Path,\n conf: Config,\n) -> None:\n \"\"\"Plot the detections on the spectrogram.\n\n Parameters\n ----------\n - `img_tensor` : `torch.Tensor`\n The spectrogram.\n - `output` : `torch.Tensor`\n The output of the model.\n - `threshold` : `float`\n The threshold for the detections.\n - `save_path` : `pathlib.Path`\n The path to save the plot to.\n - `conf` : `Config`\n The configuration object.\n\n Returns\n -------\n - `None`\n \"\"\"\n # retrieve all the data from the output and convert\n # spectrogram to numpy array\n img = img_tensor.detach().cpu().numpy().transpose(1, 2, 0)[..., 0]\n boxes = output[\"boxes\"].detach().cpu().numpy()\n boxes = coords_to_mpl_rectangle(boxes)\n scores = output[\"scores\"].detach().cpu().numpy()\n labels = output[\"labels\"].detach().cpu().numpy()\n labels = [conf.hyper.classes[i] for i in labels]\n\n _, ax = plt.subplots(figsize=(20, 10))\n\n ax.pcolormesh(img, cmap=\"magma\")\n\n for i, box in enumerate(boxes):\n if scores[i] > threshold:\n ax.scatter(\n box[0],\n box[1],\n )\n ax.add_patch(\n Rectangle(\n box[:2],\n box[2],\n box[3],\n fill=False,\n color=\"white\",\n linewidth=1,\n ),\n )\n ax.text(\n box[0],\n box[1],\n f\"{scores[i]:.2f}\",\n color=\"black\",\n fontsize=8,\n bbox={\"facecolor\":\"white\", \"alpha\":1},\n )\n plt.axis(\"off\")\n plt.savefig(save_path, dpi=300, bbox_inches=\"tight\", pad_inches=0)\n plt.close()\n
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.spec_to_image","title":"spec_to_image(spec)
","text":"Convert a spectrogram to an image.
Add 3 color channels, normalize to 0-1, etc.
"},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.spec_to_image--parameters","title":"Parameters","text":""},{"location":"api/detect_chirps/#chirpdetector.detect_chirps.spec_to_image--returns","title":"Returns","text":" Source code in chirpdetector/detect_chirps.py
def spec_to_image(spec: torch.Tensor) -> torch.Tensor:\n \"\"\"Convert a spectrogram to an image.\n\n Add 3 color channels, normalize to 0-1, etc.\n\n Parameters\n ----------\n - `spec` : `torch.Tensor`\n\n Returns\n -------\n - `torch.Tensor`\n \"\"\"\n # make sure the spectrogram is a tensor\n if not isinstance(spec, torch.Tensor):\n msg = (\n \"The spectrogram must be a torch.Tensor.\\n\"\n f\"Type of spectrogram: {type(spec)}\"\n )\n raise TypeError(msg)\n\n # make sure the spectrogram is 2-dimensional\n spec_dims = 2\n if len(spec.size()) != spec_dims:\n msg = (\n \"The spectrogram must be a 2-dimensional matrix.\\n\"\n f\"Shape of spectrogram: {spec.size()}\"\n )\n raise ValueError(msg)\n\n # make sure the spectrogram contains some data\n if (\n np.max(spec.detach().cpu().numpy())\n - np.min(spec.detach().cpu().numpy())\n == 0\n ):\n msg = (\n \"The spectrogram must contain some data.\\n\"\n f\"Max value: {np.max(spec.detach().cpu().numpy())}\\n\"\n f\"Min value: {np.min(spec.detach().cpu().numpy())}\"\n )\n raise ValueError(msg)\n\n # Get the dimensions of the original matrix\n original_shape = spec.size()\n\n # Calculate the number of rows and columns in the matrix\n num_rows, num_cols = original_shape\n\n # duplicate the matrix 3 times\n spec = spec.repeat(3, 1, 1)\n\n # Reshape the matrix to the desired shape (3, num_rows, num_cols)\n desired_shape = (3, num_rows, num_cols)\n reshaped_tensor = spec.view(desired_shape)\n\n # normalize the spectrogram to be between 0 and 1\n normalized_tensor = (reshaped_tensor - reshaped_tensor.min()) / (\n reshaped_tensor.max() - reshaped_tensor.min()\n )\n\n # make sure image is float32\n return normalized_tensor.float()\n
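A minimal shape check (a sketch; the random spectrogram is only a stand-in and is almost surely non-constant, so the data check passes):\n\n>>> spec = torch.rand(128, 256)\n>>> img = spec_to_image(spec)\n>>> img.shape\ntorch.Size([3, 128, 256])\n>>> bool(img.min() >= 0) and bool(img.max() <= 1)\nTrue\n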
"},{"location":"api/plot_detections/","title":"plot_detections","text":"Functions to visualize detections on images.
"},{"location":"api/plot_detections/#chirpdetector.plot_detections.clean_all_plots_cli","title":"clean_all_plots_cli(path)
","text":"Remove all plots from the chirpdetections folder.
"},{"location":"api/plot_detections/#chirpdetector.plot_detections.clean_all_plots_cli--parameters","title":"Parameters","text":"path : pathlib.Path Path to the config file.
Source code in chirpdetector/plot_detections.py
def clean_all_plots_cli(path: pathlib.Path) -> None:\n \"\"\"Remove all plots from the chirpdetections folder.\n\n Parameters\n ----------\n path : pathlib.Path\n Path to the directory containing the datasets.\n \"\"\"\n dirs = [dataset for dataset in path.iterdir() if dataset.is_dir()]\n with prog:\n task = prog.add_task(\"Cleaning plots...\", total=len(dirs))\n for dataset in dirs:\n prog.console.log(f\"Cleaning plots for {dataset.name}\")\n clean_plots_cli(dataset)\n prog.advance(task)\n
"},{"location":"api/plot_detections/#chirpdetector.plot_detections.clean_plots_cli","title":"clean_plots_cli(path)
","text":"Remove all plots from the chirpdetections folder.
"},{"location":"api/plot_detections/#chirpdetector.plot_detections.clean_plots_cli--parameters","title":"Parameters","text":"path : pathlib.Path Path to the config file.
Source code in chirpdetector/plot_detections.py
def clean_plots_cli(path: pathlib.Path) -> None:\n \"\"\"Remove all plots from the chirpdetections folder.\n\n Parameters\n ----------\n path : pathlib.Path\n Path to the dataset directory.\n \"\"\"\n savepath = path / \"chirpdetections\"\n for f in savepath.iterdir():\n f.unlink()\n
"},{"location":"api/plot_detections/#chirpdetector.plot_detections.plot_all_detections_cli","title":"plot_all_detections_cli(path)
","text":"Plot detections on images.
"},{"location":"api/plot_detections/#chirpdetector.plot_detections.plot_all_detections_cli--parameters","title":"Parameters","text":"path : pathlib.Path Path to the config file.
Source code in chirpdetector/plot_detections.py
def plot_all_detections_cli(path: pathlib.Path) -> None:\n \"\"\"Plot detections on images.\n\n Parameters\n ----------\n path : pathlib.Path\n Path to the directory containing the datasets and the\n chirpdetector.toml config file.\n \"\"\"\n conf = load_config(path / \"chirpdetector.toml\")\n\n dirs = [dataset for dataset in path.iterdir() if dataset.is_dir()]\n with prog:\n task = prog.add_task(\"Plotting detections...\", total=len(dirs))\n for dataset in dirs:\n prog.console.log(f\"Plotting detections for {dataset.name}\")\n data = load(dataset)\n chirp_df = pd.read_csv(dataset / \"chirpdetector_bboxes.csv\")\n plot_detections(data, chirp_df, conf)\n prog.advance(task)\n
"},{"location":"api/plot_detections/#chirpdetector.plot_detections.plot_detections","title":"plot_detections(data, chirp_df, conf)
","text":"Plot detections on spectrograms.
"},{"location":"api/plot_detections/#chirpdetector.plot_detections.plot_detections--parameters","title":"Parameters","text":"data : Dataset The dataset. chirp_df : pd.DataFrame The dataframe containing the chirp detections. conf : Config The config file.
Source code in chirpdetector/plot_detections.py
def plot_detections(\n data: Dataset,\n chirp_df: pd.DataFrame,\n conf: Config,\n) -> None:\n \"\"\"Plot detections on spectrograms.\n\n Parameters\n ----------\n data : Dataset\n The dataset.\n chirp_df : pd.DataFrame\n The dataframe containing the chirp detections.\n conf : Config\n The configuration object.\n \"\"\"\n time_window = 15\n n_electrodes = data.grid.rec.shape[1]\n\n nfft = freqres_to_nfft(conf.spec.freq_res, data.grid.samplerate) # samples\n hop_len = overlap_to_hoplen(conf.spec.overlap_frac, nfft) # samples\n chunksize = time_window * data.grid.samplerate # samples\n nchunks = np.ceil(data.grid.rec.shape[0] / chunksize).astype(int)\n window_overlap_samples = int(conf.spec.spec_overlap * data.grid.samplerate)\n\n for chunk_no in range(nchunks):\n # get start and stop indices for the current chunk\n # including some overlap to compensate for edge effects\n # this differs for the first and last chunk\n\n if chunk_no == 0:\n idx1 = int(chunk_no * chunksize)\n idx2 = int((chunk_no + 1) * chunksize + window_overlap_samples)\n elif chunk_no == nchunks - 1:\n idx1 = int(chunk_no * chunksize - window_overlap_samples)\n idx2 = int((chunk_no + 1) * chunksize)\n else:\n idx1 = int(chunk_no * chunksize - window_overlap_samples)\n idx2 = int((chunk_no + 1) * chunksize + window_overlap_samples)\n\n # idx1 and idx2 now determine the window I cut out of the raw signal\n # to compute the spectrogram of.\n\n # compute the time and frequency axes of the spectrogram now that we\n # include the start and stop indices of the current chunk and thus the\n # right start and stop time. The `spectrogram` function does not know\n # about this and would start every time axis at 0.\n spec_times = np.arange(idx1, idx2 + 1, hop_len) / data.grid.samplerate\n spec_freqs = np.arange(0, nfft / 2 + 1) * data.grid.samplerate / nfft\n\n # create a subset from the grid dataset\n if idx2 > data.grid.rec.shape[0]:\n idx2 = data.grid.rec.shape[0] - 1\n chunk = subset(data, idx1, idx2, mode=\"index\")\n\n # don't plot chunks without chirps\n if len(chunk.com.chirp.times) == 0:\n continue\n\n # compute the spectrogram for each electrode of the current chunk\n spec = torch.zeros((len(spec_freqs), len(spec_times)))\n for el in range(n_electrodes):\n # get the signal for the current electrode\n sig = chunk.grid.rec[:, el]\n\n # compute the spectrogram for the current electrode\n chunk_spec, _, _ = spectrogram(\n data=sig.copy(),\n samplingrate=data.grid.samplerate,\n nfft=nfft,\n hop_length=hop_len,\n )\n\n # sum spectrogram over all electrodes\n if el == 0:\n spec = chunk_spec\n else:\n spec += chunk_spec\n\n # normalize spectrogram by the number of electrodes\n spec /= n_electrodes\n\n # convert the spectrogram to dB\n spec = decibel(spec)\n spec = spec.detach().cpu().numpy()\n\n # Set y limits\n flims = (\n np.min(data.track.freqs) - 200,\n np.max(data.track.freqs) + 700,\n )\n spec = spec[(spec_freqs >= flims[0]) & (spec_freqs <= flims[1]), :]\n spec_freqs = spec_freqs[\n (spec_freqs >= flims[0]) & (spec_freqs <= flims[1])\n ]\n\n # Extract the bounding boxes for the current chunk\n chunk_t1 = idx1 / data.grid.samplerate\n chunk_t2 = idx2 / data.grid.samplerate\n chunk_df = chirp_df[\n (chirp_df[\"t1\"] >= chunk_t1) & (chirp_df[\"t2\"] <= chunk_t2)\n ]\n\n # get t1, t2, f1, f2 from chunk_df\n bboxes = chunk_df[[\"score\", \"t1\", \"f1\", \"t2\", \"f2\"]].to_numpy()\n\n # get chirp times and chirp ids\n chirp_times = chunk_df[\"envelope_trough_time\"]\n chirp_ids = chunk_df[\"assigned_track\"]\n\n _, ax = 
plt.subplots(figsize=(10, 5), constrained_layout=True)\n\n # plot the spectrogram\n ax.imshow(\n spec,\n aspect=\"auto\",\n origin=\"lower\",\n interpolation=\"gaussian\",\n extent=[\n spec_times[0],\n spec_times[-1],\n spec_freqs[0],\n spec_freqs[-1],\n ],\n cmap=\"magma\",\n vmin=-80,\n vmax=-45,\n )\n\n # plot the bounding boxes\n for bbox in bboxes:\n ax.add_patch(\n Rectangle(\n (bbox[1], bbox[2]),\n bbox[3] - bbox[1],\n bbox[4] - bbox[2],\n fill=False,\n color=\"gray\",\n linewidth=1,\n label=\"Faster R-CNN predictions\",\n ),\n )\n ax.text(\n bbox[1],\n bbox[4] + 15,\n f\"{bbox[0]:.2f}\",\n color=\"gray\",\n fontsize=10,\n verticalalignment=\"bottom\",\n horizontalalignment=\"left\",\n rotation=90,\n )\n\n # plot chirp times and frequency traces\n for track_id in np.unique(data.track.idents):\n ctimes = chirp_times[chirp_ids == track_id]\n\n freqs = data.track.freqs[data.track.idents == track_id]\n times = data.track.times[\n data.track.indices[data.track.idents == track_id]\n ]\n freqs = freqs[\n (times >= spec_times[0] - 10) & (times <= spec_times[-1] + 10)\n ]\n times = times[\n (times >= spec_times[0] - 10) & (times <= spec_times[-1] + 10)\n ]\n\n # get freqs where times are closest to ctimes\n cfreqs = np.zeros_like(ctimes)\n for i, ctime in enumerate(ctimes):\n try:\n indx = np.argmin(np.abs(times - ctime))\n cfreqs[i] = freqs[indx]\n except ValueError:\n msg = (\n \"Failed to find track time closest to chirp time \"\n f\"in chunk {chunk_no}, check the plots.\"\n )\n prog.console.log(msg)\n\n if len(times) != 0:\n ax.plot(\n times,\n freqs,\n lw=2,\n color=\"black\",\n label=\"Frequency traces\",\n )\n\n ax.scatter(\n ctimes,\n cfreqs,\n marker=\"o\",\n lw=1,\n facecolor=\"white\",\n edgecolor=\"black\",\n s=25,\n zorder=10,\n label=\"Chirp assignments\",\n )\n\n ax.set_ylim(flims[0] + 5, flims[1] - 5)\n ax.set_xlim([spec_times[0], spec_times[-1]])\n ax.set_xlabel(\"Time [s]\", fontsize=12)\n ax.set_ylabel(\"Frequency [Hz]\", fontsize=12)\n\n handles, labels = plt.gca().get_legend_handles_labels()\n by_label = dict(zip(labels, handles))\n plt.legend(\n by_label.values(),\n by_label.keys(),\n bbox_to_anchor=(0.5, 1.02),\n loc=\"lower center\",\n mode=None,\n borderaxespad=0,\n ncol=3,\n fancybox=False,\n framealpha=0,\n )\n\n savepath = data.path / \"chirpdetections\"\n savepath.mkdir(exist_ok=True)\n plt.savefig(\n savepath / f\"cpd_{chunk_no}.png\",\n dpi=300,\n bbox_inches=\"tight\",\n )\n\n plt.close()\n plt.clf()\n plt.cla()\n plt.close(\"all\")\n
"},{"location":"api/plot_detections/#chirpdetector.plot_detections.plot_detections_cli","title":"plot_detections_cli(path)
","text":"Plot detections on images.
"},{"location":"api/plot_detections/#chirpdetector.plot_detections.plot_detections_cli--parameters","title":"Parameters","text":"path : pathlib.Path Path to the config file.
Source code in chirpdetector/plot_detections.py
def plot_detections_cli(path: pathlib.Path) -> None:\n \"\"\"Plot detections on images.\n\n Parameters\n ----------\n path : pathlib.Path\n Path to the dataset directory.\n \"\"\"\n conf = load_config(path.parent / \"chirpdetector.toml\")\n data = load(path)\n chirp_df = pd.read_csv(path / \"chirpdetector_bboxes.csv\")\n plot_detections(data, chirp_df, conf)\n
"},{"location":"api/train_model/","title":"train_model","text":""},{"location":"api/train_model/#chirpdetector.train_model--train-the-faster-r-cnn-model","title":"Train the faster-R-CNN model.","text":"Train and test the neural network specified in the config file.
"},{"location":"api/train_model/#chirpdetector.train_model.plot_epochs","title":"plot_epochs(epoch_train_loss, epoch_val_loss, epoch_avg_train_loss, epoch_avg_val_loss, path)
","text":"Plot the loss for each epoch.
"},{"location":"api/train_model/#chirpdetector.train_model.plot_epochs--parameters","title":"Parameters","text":" epoch_train_loss
: list
The training loss for each epoch. epoch_val_loss
: list
The validation loss for each epoch. epoch_avg_train_loss
: list
The average training loss for each epoch. epoch_avg_val_loss
: list
The average validation loss for each epoch. path
: pathlib.Path
The path to save the plot to.
"},{"location":"api/train_model/#chirpdetector.train_model.plot_epochs--returns","title":"Returns","text":" Source code in chirpdetector/train_model.py
def plot_epochs(\n epoch_train_loss: list,\n epoch_val_loss: list,\n epoch_avg_train_loss: list,\n epoch_avg_val_loss: list,\n path: pathlib.Path,\n) -> None:\n \"\"\"Plot the loss for each epoch.\n\n Parameters\n ----------\n - `epoch_train_loss`: `list`\n The training loss for each epoch.\n - `epoch_val_loss`: `list`\n The validation loss for each epoch.\n - `epoch_avg_train_loss`: `list`\n The average training loss for each epoch.\n - `epoch_avg_val_loss`: `list`\n The average validation loss for each epoch.\n - `path`: `pathlib.Path`\n The path to save the plot to.\n\n Returns\n -------\n - `None`\n \"\"\"\n _, ax = plt.subplots(1, 2, figsize=(10, 5), constrained_layout=True)\n\n x_train = np.arange(len(epoch_train_loss[0])) + 1\n x_val = np.arange(len(epoch_val_loss[0])) + len(epoch_train_loss[0]) + 1\n\n for train_loss, val_loss in zip(epoch_train_loss, epoch_val_loss):\n ax[0].plot(x_train, train_loss, c=\"tab:blue\", label=\"_\")\n ax[0].plot(x_val, val_loss, c=\"tab:orange\", label=\"_\")\n x_train = np.arange(len(epoch_train_loss[0])) + x_val[-1]\n x_val = np.arange(len(epoch_val_loss[0])) + x_train[-1]\n\n x_avg = np.arange(len(epoch_avg_train_loss)) + 1\n ax[1].plot(\n x_avg,\n epoch_avg_train_loss,\n label=\"Training Loss\",\n c=\"tab:blue\",\n )\n ax[1].plot(\n x_avg,\n epoch_avg_val_loss,\n label=\"Validation Loss\",\n c=\"tab:orange\",\n )\n\n ax[0].set_ylabel(\"Loss\")\n ax[0].set_xlabel(\"Batch\")\n ax[0].set_ylim(bottom=0)\n ax[0].set_title(\"Loss per batch\")\n\n ax[1].set_ylabel(\"Loss\")\n ax[1].set_xlabel(\"Epoch\")\n ax[1].legend()\n ax[1].set_ylim(bottom=0)\n ax[1].set_title(\"Avg loss per epoch\")\n\n plt.savefig(path)\n plt.close()\n
"},{"location":"api/train_model/#chirpdetector.train_model.plot_folds","title":"plot_folds(fold_avg_train_loss, fold_avg_val_loss, path)
","text":"Plot the loss for each fold.
"},{"location":"api/train_model/#chirpdetector.train_model.plot_folds--parameters","title":"Parameters","text":" fold_avg_train_loss
: list
The average training loss for each fold. fold_avg_val_loss
: list
The average validation loss for each fold. path
: pathlib.Path
The path to save the plot to.
"},{"location":"api/train_model/#chirpdetector.train_model.plot_folds--returns","title":"Returns","text":" Source code in chirpdetector/train_model.py
def plot_folds(\n fold_avg_train_loss: list,\n fold_avg_val_loss: list,\n path: pathlib.Path,\n) -> None:\n \"\"\"Plot the loss for each fold.\n\n Parameters\n ----------\n - `fold_avg_train_loss`: `list`\n The average training loss for each fold.\n - `fold_avg_val_loss`: `list`\n The average validation loss for each fold.\n - `path`: `pathlib.Path`\n The path to save the plot to.\n\n Returns\n -------\n - `None`\n \"\"\"\n _, ax = plt.subplots(figsize=(10, 5), constrained_layout=True)\n\n for train_loss, val_loss in zip(fold_avg_train_loss, fold_avg_val_loss):\n x = np.arange(len(train_loss)) + 1\n ax.plot(x, train_loss, c=\"tab:blue\", alpha=0.3, label=\"_\")\n ax.plot(x, val_loss, c=\"tab:orange\", alpha=0.3, label=\"_\")\n\n avg_train = np.mean(fold_avg_train_loss, axis=0)\n avg_val = np.mean(fold_avg_val_loss, axis=0)\n x = np.arange(len(avg_train)) + 1\n ax.plot(\n x,\n avg_train,\n label=\"Training Loss\",\n c=\"tab:blue\",\n )\n ax.plot(\n x,\n avg_val,\n label=\"Validation Loss\",\n c=\"tab:orange\",\n )\n\n ax.set_ylabel(\"Loss\")\n ax.set_xlabel(\"Epoch\")\n ax.legend()\n ax.set_ylim(bottom=0)\n\n plt.savefig(path)\n plt.close()\n
"},{"location":"api/train_model/#chirpdetector.train_model.save_model","title":"save_model(epoch, model, optimizer, path)
","text":"Save the model state dict.
"},{"location":"api/train_model/#chirpdetector.train_model.save_model--parameters","title":"Parameters","text":" epoch
: int
The current epoch. model
: torch.nn.Module
The model to save. optimizer
: torch.optim.Optimizer
The optimizer to save. path
: str
The path to save the model to.
"},{"location":"api/train_model/#chirpdetector.train_model.save_model--returns","title":"Returns","text":" Source code in chirpdetector/train_model.py
def save_model(\n epoch: int,\n model: torch.nn.Module,\n optimizer: torch.optim.Optimizer,\n path: str,\n) -> None:\n \"\"\"Save the model state dict.\n\n Parameters\n ----------\n - `epoch`: `int`\n The current epoch.\n - `model`: `torch.nn.Module`\n The model to save.\n - `optimizer`: `torch.optim.Optimizer`\n The optimizer to save.\n - `path`: `str`\n The path to save the model to.\n\n Returns\n -------\n - `None`\n \"\"\"\n path = pathlib.Path(path)\n path.mkdir(parents=True, exist_ok=True)\n torch.save(\n {\n \"epoch\": epoch,\n \"model_state_dict\": model.state_dict(),\n \"optimizer_state_dict\": optimizer.state_dict(),\n },\n path / \"model.pt\",\n )\n
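A minimal sketch of restoring a checkpoint written by this function (this mirrors what detect_chirps does; constructing model and optimizer beforehand is assumed):\n\ncheckpoint = torch.load(pathlib.Path(path) / \"model.pt\", map_location=\"cpu\")\nmodel.load_state_dict(checkpoint[\"model_state_dict\"])\noptimizer.load_state_dict(checkpoint[\"optimizer_state_dict\"])\nstart_epoch = checkpoint[\"epoch\"] + 1 # resume counting from the next epoch\n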
"},{"location":"api/train_model/#chirpdetector.train_model.train","title":"train(config, mode='pretrain')
","text":"Train the model.
"},{"location":"api/train_model/#chirpdetector.train_model.train--parameters","title":"Parameters","text":" config
: Config
The config file. mode
: str
The mode to train in. Either pretrain
or finetune
.
"},{"location":"api/train_model/#chirpdetector.train_model.train--returns","title":"Returns","text":" Source code in chirpdetector/train_model.py
def train(config: Config, mode: str = \"pretrain\") -> None:\n \"\"\"Train the model.\n\n Parameters\n ----------\n - `config`: `Config`\n The config file.\n - `mode`: `str`\n The mode to train in. Either `pretrain` or `finetune`.\n\n Returns\n -------\n - `None`\n \"\"\"\n # Load a pretrained model from pytorch if in pretrain mode,\n # otherwise open an already trained model from the\n # model state dict.\n assert mode in [\"pretrain\", \"finetune\"]\n if mode == \"pretrain\":\n assert config.train.datapath is not None\n datapath = config.train.datapath\n elif mode == \"finetune\":\n assert config.finetune.datapath is not None\n datapath = config.finetune.datapath\n\n # Check if the path to the data actually exists\n if not pathlib.Path(datapath).exists():\n raise FileNotFoundError(f\"Path {datapath} does not exist.\")\n\n # Initialize the logger and progress bar, make the logger global\n global logger\n logger = make_logger(\n __name__,\n pathlib.Path(config.path).parent / \"chirpdetector.log\",\n )\n\n # Get the device (e.g. GPU or CPU)\n device = get_device()\n\n # Print information about starting training\n progress.console.rule(\"Starting training\")\n msg = (\n f\"Device: {device}, Config: {config.path},\"\n f\" Mode: {mode}, Data: {datapath}\"\n )\n progress.console.log(msg)\n logger.info(msg)\n\n # initialize the dataset\n data = CustomDataset(\n path=datapath,\n classes=config.hyper.classes,\n )\n\n # initialize the k-fold cross-validation\n splits = KFold(n_splits=config.hyper.kfolds, shuffle=True, random_state=42)\n\n # initialize the best validation loss to a large number\n best_val_loss = float(\"inf\")\n\n # iterate over the folds for k-fold cross-validation\n with progress:\n # save loss across all epochs and folds\n fold_train_loss = []\n fold_val_loss = []\n fold_avg_train_loss = []\n fold_avg_val_loss = []\n\n # Add kfolds progress bar that runs alongside the epochs progress bar\n task_folds = progress.add_task(\n f\"[blue]{config.hyper.kfolds}-Fold Crossvalidation\",\n total=config.hyper.kfolds,\n )\n\n # iterate over the folds\n for fold, (train_idx, val_idx) in enumerate(\n splits.split(np.arange(len(data))),\n ):\n # initialize the model and optimizer\n model = load_fasterrcnn(num_classes=len(config.hyper.classes)).to(\n device,\n )\n\n # If the mode is finetune, load the model state dict from\n # previous training\n if mode == \"finetune\":\n modelpath = pathlib.Path(config.hyper.modelpath) / \"model.pt\"\n checkpoint = torch.load(modelpath, map_location=device)\n model.load_state_dict(checkpoint[\"model_state_dict\"])\n\n # Initialize stochastic gradient descent optimizer\n params = [p for p in model.parameters() if p.requires_grad]\n optimizer = torch.optim.SGD(\n params,\n lr=config.hyper.learning_rate,\n momentum=config.hyper.momentum,\n weight_decay=config.hyper.weight_decay,\n )\n\n # make train and validation dataloaders for the current fold\n train_data = torch.utils.data.Subset(data, train_idx)\n val_data = torch.utils.data.Subset(data, val_idx)\n\n # this is for training\n train_loader = DataLoader(\n train_data,\n batch_size=config.hyper.batch_size,\n shuffle=True,\n num_workers=config.hyper.num_workers,\n collate_fn=collate_fn,\n )\n\n # this is only for validation\n val_loader = DataLoader(\n val_data,\n batch_size=config.hyper.batch_size,\n shuffle=True,\n num_workers=config.hyper.num_workers,\n collate_fn=collate_fn,\n )\n\n # save loss across all epochs\n epoch_avg_train_loss = []\n epoch_avg_val_loss = []\n epoch_train_loss = []\n epoch_val_loss = 
[]\n\n # train the model for the specified number of epochs\n task_epochs = progress.add_task(\n f\"{config.hyper.num_epochs} Epochs for fold k={fold + 1}\",\n total=config.hyper.num_epochs,\n )\n\n # iterate across n epochs\n for epoch in range(config.hyper.num_epochs):\n # print information about the current epoch\n msg = (\n f\"Training epoch {epoch + 1} of {config.hyper.num_epochs} \"\n f\"for fold {fold + 1} of {config.hyper.kfolds}\"\n )\n progress.console.log(msg)\n logger.info(msg)\n\n # train the epoch\n train_loss = train_epoch(\n dataloader=train_loader,\n device=device,\n model=model,\n optimizer=optimizer,\n )\n\n # validate the epoch\n _, val_loss = val_epoch(\n dataloader=val_loader,\n device=device,\n model=model,\n )\n\n # save losses for this epoch\n epoch_train_loss.append(train_loss)\n epoch_val_loss.append(val_loss)\n\n # save the average loss for this epoch\n epoch_avg_train_loss.append(np.median(train_loss))\n epoch_avg_val_loss.append(np.median(val_loss))\n\n # save the model if it is the best so far\n if np.mean(val_loss) < best_val_loss:\n best_val_loss = sum(val_loss) / len(val_loss)\n\n msg = (\n f\"New best validation loss: {best_val_loss:.4f}, \"\n \"saving model...\"\n )\n progress.console.log(msg)\n logger.info(msg)\n\n save_model(\n epoch=epoch,\n model=model,\n optimizer=optimizer,\n path=config.hyper.modelpath,\n )\n\n # plot the losses for this epoch\n plot_epochs(\n epoch_train_loss=epoch_train_loss,\n epoch_val_loss=epoch_val_loss,\n epoch_avg_train_loss=epoch_avg_train_loss,\n epoch_avg_val_loss=epoch_avg_val_loss,\n path=pathlib.Path(config.hyper.modelpath)\n / f\"fold{fold + 1}.png\",\n )\n\n # update the progress bar for the epochs\n progress.update(task_epochs, advance=1)\n\n # update the progress bar for the epochs and hide it if done\n progress.update(task_epochs, visible=False)\n\n # save the losses for this fold\n fold_train_loss.append(epoch_train_loss)\n fold_val_loss.append(epoch_val_loss)\n fold_avg_train_loss.append(epoch_avg_train_loss)\n fold_avg_val_loss.append(epoch_avg_val_loss)\n\n plot_folds(\n fold_avg_train_loss=fold_avg_train_loss,\n fold_avg_val_loss=fold_avg_val_loss,\n path=pathlib.Path(config.hyper.modelpath) / \"losses.png\",\n )\n\n # update the progress bar for the folds\n progress.update(task_folds, advance=1)\n\n # update the progress bar for the folds and hide it if done\n progress.update(task_folds, visible=False)\n\n # print information about the training\n msg = (\n \"Average validation loss across all folds and epochs: \"\n f\"{np.mean(fold_val_loss):.4f}\"\n )\n progress.console.log(msg)\n logger.info(msg)\n progress.console.rule(\"[bold blue]Finished training\")\n
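The cross-validation split above follows the standard scikit-learn pattern. A minimal sketch with a made-up dataset length of 100 samples:\n\nimport numpy as np\nfrom sklearn.model_selection import KFold\n\nsplits = KFold(n_splits=5, shuffle=True, random_state=42)\nfor fold, (train_idx, val_idx) in enumerate(splits.split(np.arange(100))):\n # build torch.utils.data.Subset(data, train_idx) and\n # torch.utils.data.Subset(data, val_idx) here, as train() does\n pass\n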
"},{"location":"api/train_model/#chirpdetector.train_model.train_cli","title":"train_cli(config_path, mode)
","text":"Train the model from the command line.
"},{"location":"api/train_model/#chirpdetector.train_model.train_cli--parameters","title":"Parameters","text":" config_path
: pathlib.Path
The path to the config file. mode
: str
The mode to train in. Either pretrain
or finetune
.
"},{"location":"api/train_model/#chirpdetector.train_model.train_cli--returns","title":"Returns","text":" Source code in chirpdetector/train_model.py
def train_cli(config_path: pathlib.Path, mode: str) -> None:\n \"\"\"Train the model from the command line.\n\n Parameters\n ----------\n - `config_path`: `pathlib.Path`\n The path to the config file.\n - `mode`: `str`\n The mode to train in. Either `pretrain` or `finetune`.\n\n Returns\n -------\n - `None`\n \"\"\"\n config = load_config(config_path)\n train(config, mode=mode)\n
"},{"location":"api/train_model/#chirpdetector.train_model.train_epoch","title":"train_epoch(dataloader, device, model, optimizer)
","text":"Train the model for one epoch.
"},{"location":"api/train_model/#chirpdetector.train_model.train_epoch--parameters","title":"Parameters","text":" dataloader
: DataLoader
The dataloader for the training data. device
: torch.device
The device to train on. model
: torch.nn.Module
The model to train. optimizer
: torch.optim.Optimizer
The optimizer to use.
"},{"location":"api/train_model/#chirpdetector.train_model.train_epoch--returns","title":"Returns","text":" train_loss
: List
The training loss for each batch.
Source code in chirpdetector/train_model.py
def train_epoch(\n dataloader: DataLoader,\n device: torch.device,\n model: torch.nn.Module,\n optimizer: torch.optim.Optimizer,\n) -> List:\n \"\"\"Train the model for one epoch.\n\n Parameters\n ----------\n - `dataloader`: `DataLoader`\n The dataloader for the training data.\n - `device`: `torch.device`\n The device to train on.\n - `model`: `torch.nn.Module`\n The model to train.\n - `optimizer`: `torch.optim.Optimizer`\n The optimizer to use.\n\n Returns\n -------\n - `train_loss`: `List`\n The training loss for each batch.\n \"\"\"\n train_loss = []\n\n for samples, targets in dataloader:\n images = list(sample.to(device) for sample in samples)\n targets = [\n {k: v.to(device) for k, v in t.items() if k != \"image_name\"}\n for t in targets\n ]\n\n loss_dict = model(images, targets)\n losses = sum(loss for loss in loss_dict.values())\n train_loss.append(losses.item())\n\n optimizer.zero_grad()\n losses.backward()\n optimizer.step()\n\n return train_loss\n
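The dataloader is expected to yield (samples, targets) batches in the torchvision detection format. A minimal sketch of a single target dict (the image_name key is specific to this package's CustomDataset and is stripped before the forward pass; the values are made up):\n\ntarget = {\n \"boxes\": torch.tensor([[10.0, 20.0, 30.0, 60.0]]), # (x1, y1, x2, y2) in pixels\n \"labels\": torch.tensor([1], dtype=torch.int64), # class index, 0 is background\n \"image_name\": \"chunk00001.png\", # hypothetical, removed in train_epoch\n}\n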
"},{"location":"api/train_model/#chirpdetector.train_model.val_epoch","title":"val_epoch(dataloader, device, model)
","text":"Validate the model for one epoch.
"},{"location":"api/train_model/#chirpdetector.train_model.val_epoch--parameters","title":"Parameters","text":" dataloader
: DataLoader
The dataloader for the validation data. device
: torch.device
The device to train on. model
: torch.nn.Module
The model to train.
"},{"location":"api/train_model/#chirpdetector.train_model.val_epoch--returns","title":"Returns","text":" loss_dict
: dict
The loss dictionary.
Source code in chirpdetector/train_model.py
def val_epoch(\n dataloader: DataLoader,\n device: torch.device,\n model: torch.nn.Module,\n) -> Tuple[dict, List]:\n \"\"\"Validate the model for one epoch.\n\n Parameters\n ----------\n - `dataloader`: `DataLoader`\n The dataloader for the validation data.\n - `device`: `torch.device`\n The device to train on.\n - `model`: `torch.nn.Module`\n The model to train.\n\n Returns\n -------\n - `loss_dict`: `dict`\n The loss dictionary of the last batch.\n - `val_loss`: `List`\n The validation loss for each batch.\n \"\"\"\n val_loss = []\n for samples, targets in dataloader:\n images = list(sample.to(device) for sample in samples)\n targets = [\n {k: v.to(device) for k, v in t.items() if k != \"image_name\"}\n for t in targets\n ]\n\n with torch.inference_mode():\n loss_dict = model(images, targets)\n\n losses = sum(loss for loss in loss_dict.values())\n val_loss.append(losses.item())\n\n return loss_dict, val_loss\n
"}]}
\ No newline at end of file
diff --git a/setup/index.html b/setup/index.html
index 796c190..f95e35f 100644
--- a/setup/index.html
+++ b/setup/index.html
@@ -236,23 +236,6 @@
-
-
-
-
-
-
-
- Demo
-
-
-
-
-
-
-
-
-
@@ -448,15 +431,8 @@
-
-