removed old config entries
weygoldt committed Jan 30, 2025
1 parent a39a608 commit d16fdb5
Showing 7 changed files with 80 additions and 52 deletions.
25 changes: 25 additions & 0 deletions README.md
@@ -49,3 +49,28 @@ It consists of two main steps:
of a single chirp, the detected chirps are assigned to individual fish.

![Flowchart](assets/chirpdetector_pipeline.png)

## How to use

**Step 1:** Track frequencies of fish using the [`wavetracker`](https://github.com/weygoldt/wavetracker).

**Step 2:** Install this package using pip:

```bash
pip install git+https://github.com/weygoldt/chirpdetector.git
```

**Step 3:** Copy the default config to the dataset root:

```bash
chirpdetector copyconfig -p /path/to/dataset
```

**Step 4:** Run the pipeline:
```bash
chirpdetector detect -p /path/to/dataset
```
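
If you prefer to drive the detector from Python instead of the CLI, the entry point used by the command can be called directly. A minimal sketch, assuming `detect_cli` from `chirpdetector.detection.detect_chirps` is intended to be imported this way:

```python
# Minimal sketch (assumption: detect_cli is importable as a public entry point).
import pathlib

from chirpdetector.detection.detect_chirps import detect_cli

# Equivalent to `chirpdetector detect -p /path/to/dataset`: walks the dataset
# root and runs chirp detection on the recordings that still need it.
detect_cli(pathlib.Path("/path/to/dataset"))
```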

## Open issues

There are many things that still need to be done here! A list can be found in the [docs](docs/contributing.md).

37 changes: 19 additions & 18 deletions chirpdetector/config/config.toml
@@ -1,31 +1,32 @@
[hyperparameters]
classes = ["__background__", "chirp"] # classes for the detection
modelpath = "/home/weygoldt/wrk/analyses/chirp_competition/models" # path to save and load models from

### The following params are deprecated and do not have any effect
num_epochs = 10 # number of epochs to train for
batch_size = 2 # batch size for training
kfolds = 5 # number of folds for cross validation
learning_rate = 0.001 # learning rate for SGD
momentum = 0.9 # momentum for SGD
weight_decay = 0.0005 # regularization parameter
num_workers = 4 # number of workers for the data loader
###
modelpath = "/home/weygoldt/Projects/mscthesis/models" # path to save and load models from
# classes = ["__background__", "chirp"] # classes for the detection

[training]
datapath = "/home/weygoldt/Projects/mscthesis/data/interrim/labeled200" # path to training data
# ### The following params are deprecated and do not have any effect
# num_epochs = 10 # number of epochs to train for
# batch_size = 2 # batch size for training
# kfolds = 5 # number of folds for cross validation
# learning_rate = 0.001 # learning rate for SGD
# momentum = 0.9 # momentum for SGD
# weight_decay = 0.0005 # regularization parameter
# num_workers = 4 # number of workers for the data loader
# ###

[finetuning]
datapath = "/home/weygoldt/Projects/mscthesis/data/interrim/finetune" # path to finetuning data
# [training]
# datapath = "/home/weygoldt/wrk/analyses/chirp_competition/data/interrim/labeled200" # path to training data

# [finetuning]
# datapath = "/home/weygoldt/Projects/mscthesis/data/interrim/finetune" # path to finetuning data

[detection]
threshold = 0.5 # threshold for the detection

[spectrogram]
time_window = 15 # time window in seconds, determines size of the images
freq_ranges = [500, 1500] # frequency ranges to split the spectrograms into
batch_size = 4 # how many spectrograms to detect on at once,
freq_pad = 600 # padding in Hz so that limits are (lowest fish - padding, highest fish + padding)
freq_ranges = [0, 1000, 500, 1500, 1000, 2000] # frequency ranges to split the spectrograms into
batch_size = 4 # how many spectrograms (time windows) to detect on at once,
# freq_pad = 600 # padding in Hz so that limits are (lowest fish - padding, highest fish + padding)
freq_res = 6 # target frequency resolution of the spectrograms in Hz
overlap_frac = 0.9 # overlap of the fft windows as a fraction (0-1)
spec_overlap = 1 # overlap of the whole spectrogram in seconds
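
The new `freq_ranges` entry packs several bands into one flat list. A small sketch of how such a list can be unpacked into `(low, high)` pairs, assuming consecutive entries form one band (the old value `[500, 1500]` was a single such pair):

```python
# Sketch only: unpacking the flat freq_ranges list from [spectrogram] into
# (low, high) bands -- the pairing is an assumption based on the old value.
freq_ranges = [0, 1000, 500, 1500, 1000, 2000]

bands = list(zip(freq_ranges[::2], freq_ranges[1::2]))
# -> [(0, 1000), (500, 1500), (1000, 2000)]

for low, high in bands:
    print(f"spectrogram band: {low}-{high} Hz")
```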
30 changes: 15 additions & 15 deletions chirpdetector/config/configfiles.py
@@ -11,14 +11,14 @@
class Hyperparams(BaseModel):
"""Class to store hyperparameters for training and finetuning."""

classes: List
num_epochs: int
batch_size: int
kfolds: int
learning_rate: float
momentum: float
weight_decay: float
num_workers: int
# classes: List
# num_epochs: int
# batch_size: int
# kfolds: int
# learning_rate: float
# momentum: float
# weight_decay: float
# num_workers: int
modelpath: str


@@ -46,7 +46,7 @@ class Spectrogram(BaseModel):
time_window: float
freq_ranges: List
freq_res: float
freq_pad: float
# freq_pad: float
overlap_frac: float
spec_overlap: float
batch_size: int
@@ -58,8 +58,8 @@ class Config(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
path: str
hyper: Hyperparams
train: Training
finetune: Finetune
# train: Training
# finetune: Finetune
det: Detection
spec: Spectrogram

@@ -118,15 +118,15 @@ def load_config(path: Union[str, pathlib.Path]) -> Config:
"""
file = toml.load(path)
hy = Hyperparams(**file["hyperparameters"])
tr = Training(**file["training"])
fi = Finetune(**file["finetuning"])
# tr = Training(**file["training"])
# fi = Finetune(**file["finetuning"])
det = Detection(**file["detection"])
spec = Spectrogram(**file["spectrogram"])
return Config(
path=str(path),
hyper=hy,
train=tr,
finetune=fi,
# train=tr,
# finetune=fi,
det=det,
spec=spec,
)
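
With the training and finetuning sections commented out, `load_config` now only needs the `[hyperparameters]`, `[detection]`, and `[spectrogram]` tables. A minimal usage sketch, assuming the file copied by `chirpdetector copyconfig` lives in the dataset root as `config.toml`:

```python
# Sketch: loading the trimmed config (the filename config.toml in the dataset
# root is an assumption -- adjust the path to your layout).
import pathlib

from chirpdetector.config.configfiles import load_config

conf = load_config(pathlib.Path("/path/to/dataset/config.toml"))

print(conf.det.threshold)     # mirrors the [detection] table, e.g. 0.5
print(conf.spec.freq_ranges)  # frequency bands the spectrograms are split into
print(conf.hyper.modelpath)   # where models are saved to and loaded from
```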
1 change: 1 addition & 0 deletions chirpdetector/datahandling/dataset_parsing.py
@@ -125,6 +125,7 @@ def make_batch_specs(
for idxs in indices
]
batch_specs = [(spec, *ax) for spec, ax in zip(batch_sum_specs, axes)]

# Add the metadata to each spec tuple
batch_specs = [(meta, *spec) for meta, spec in zip(metadata, batch_specs)]

24 changes: 13 additions & 11 deletions chirpdetector/detection/detect_chirps.py
@@ -301,16 +301,21 @@ def detect_cli(input_path: pathlib.Path) -> None:
# print("chirpdetector_bboxes.csv exists, skipping")
# continue

dataset = dataset.parent
good_datasets.append(dataset)

print(
f"Out of {len(datasets)} a total of {len(good_datasets)} still need detecting"
)
good_datasets = sorted(good_datasets)
# Limit processing to the first four datasets
good_datasets = good_datasets[:4]
print(good_datasets)

with prog:
task = prog.add_task("Detecting chirps...", total=len(good_datasets))
for dataset in good_datasets:
prog.console.log(f"Detecting chirps in {dataset.name}")
data = load(dataset)
data = load(dataset, search_intermediate=True)
data = interpolate_tracks(data, samplerate=120)

# TODO: This is a mess, standardize this
@@ -494,24 +499,21 @@ def detect(self: Self) -> None:
prog.console.log(f"No detections in batch {i}.")
continue

# STEP 7: Associate the fundamental frequency of the emitter
# TODO: This should be all done at once after the full file is processed

# to the closest wavetracker track
# with Timer(prog.console, "Associate emitter frequency to tracks"):
# assigned_batch_df = assign_ffreqs_to_tracks(
# assigned_batch_df, self.data
# )
# UNDER CONSTRUCTION 🚧 -----------------------------------------
# TODO: Move shape of spec snippets for each chirp to config file
# TODO: Build a saving function to write all data into a HDF5
# TODO: Use the gridtools.datasets.models.ChirpDataV2 for this

# Extract spec snippet of every detected chirp for saving
# spec_snippets, time_snippets, freq_snippets = extract_spec_snippets(
# specs, times, freqs, assigned_batch_df
# )

# TODO: Move shape to config

# Resize each spec snippet to same dimensions for saving
# spec_snippets, time_snippets, freq_snippets, orig_shapes = resize_spec_snippets(
# spec_snippets, time_snippets, freq_snippets, 256
# )
# ---------------------------------------------------------------

# STEP 8: plot the detections
with Timer(prog.console, "Saving plot for current batch"):
6 changes: 1 addition & 5 deletions chirpdetector/detection/visualization_functions.py
@@ -12,6 +12,7 @@
from matplotlib.patches import Rectangle
from PIL import Image
from rich.console import Console
import seaborn as sns

console = Console()
backend = "Agg"
@@ -145,13 +146,8 @@ def plot_batch_detections(
t2 = assigned_batch_df["t2"].iloc[j]
f2 = assigned_batch_df["f2"].iloc[j]
score = assigned_batch_df["score"].iloc[j]
# track_id = assigned_batch_df["track_id"].iloc[j]
predicted_eodf = assigned_batch_df["emitter_eodf"].iloc[j]

# if np.isnan(track_id):
# continue

# color = track_colors[data.track.ids == track_id][0]
color = "white"

patches.append(
9 changes: 6 additions & 3 deletions docs/contributing.md
@@ -35,6 +35,7 @@ After the first release, this section will be removed and tasks will be
organized as GitHub issues. Until then, if you fixed something, please check it
off on this list before opening a pull request.

- [ ] Rebuild the logging, verbosity and progress bar system like in the wavetracker
- [ ] Idea: Instead of simple non-maximum suppression for the final output
we could try max averaging (a rough sketch follows below): Run non-maximum
suppression with a specific threshold and then group overlapping bboxes
and average their
@@ -64,7 +65,7 @@ off on this list before opening a pull request.
and most importantly, give us an idea of which data in the validation
dataset is not well detected so that we can improve the training data
using quantitative measures.
- [ ] Move all dataframe operations to cudf, a pandas-like dataframe library
- [ ] Move all dataframe operations to cudf/polars, a pandas-like dataframe library
that runs on the GPU.
- [ ] Rethink the output: Needs to be a HDF5 file that not only includes
chirp time and ID but also the full chirp spectrograms so that
@@ -79,13 +80,15 @@ off on this list before opening a pull request.
- [ ] Find out why current assignment algo is failing at raw = raw1 - raw2
- [ ] Try a random forest classifier on PCAed envelope extractions to assign
chirps
- [ ] Finish a script to analyze the JSON dumps from the training loop
- [ ] Update all the docstrings after refactoring.
- [ ] Move hardcoded params from assignment algo into config.toml
- [ ] Split the messy training loop into functions or remove it altogether
and rely on external libraries for training
and rely on external libraries for training. Regarding this:
the only thing this package should do is generate good training data, then
train with external libraries (e.g. ultralytics, ...).
- [ ] Remove all pyright warnings.
- [ ] Build github actions CI/CD pipeline for codecov etc.
- [x] Finish a script to analyze the JSON dumps from the training loop
- [x] Implement detector class that works with the trained yolov8
- [x] Write an assignment benchmarking
- [x] Try a small NN, might work better than the random forest in this case
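
The max-averaging idea from the first open item above can be prototyped in a few lines. A rough sketch, not part of the package, assuming boxes arrive as `(x1, y1, x2, y2)` arrays with one confidence score each:

```python
# Rough sketch of "max averaging": group boxes that overlap above an IoU
# threshold and average each group's coordinates instead of keeping only the
# top-scoring box. Purely illustrative; not the package's implementation.
import numpy as np


def iou(a: np.ndarray, b: np.ndarray) -> float:
    """Intersection over union of two (x1, y1, x2, y2) boxes."""
    x1, y1 = max(a[0], b[0]), max(a[1], b[1])
    x2, y2 = min(a[2], b[2]), min(a[3], b[3])
    inter = max(0.0, x2 - x1) * max(0.0, y2 - y1)
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    return float(inter / (area_a + area_b - inter + 1e-9))


def average_overlapping(
    boxes: np.ndarray, scores: np.ndarray, iou_thresh: float = 0.5
) -> np.ndarray:
    """Group boxes overlapping above iou_thresh and average each group."""
    order = np.argsort(scores)[::-1]
    used = np.zeros(len(boxes), dtype=bool)
    merged = []
    for i in order:
        if used[i]:
            continue
        group = [j for j in order if not used[j] and iou(boxes[i], boxes[j]) >= iou_thresh]
        for j in group:
            used[j] = True
        # Average the grouped boxes instead of keeping only the top-scoring one.
        merged.append(boxes[group].mean(axis=0))
    return np.asarray(merged)
```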
