Skip to content

Commit

Permalink
adding biosr dataloader
Browse files Browse the repository at this point in the history
  • Loading branch information
della Maggiora Valdes, Gabriel Eugenio (FWU) - 154694 committed Dec 22, 2024
1 parent a7320de commit 7b80ed5
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 4 deletions.
6 changes: 3 additions & 3 deletions configs/biosr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ task: "biosr_sr"
model:
noise_model_type: "sr3"
alpha: 0.001
load_weights: '/home/gabriel/Documents/p_code/cvdm/outputs/biosr/weights/model_0_4e9fa08f-beb8-4335-907f-32bbe3f70a72.h5'
load_weights: null
load_mu_weights: null
snr_expansion_n: 1
zmd: False
Expand All @@ -23,9 +23,9 @@ eval:
val_len: 100

data:
dataset_path: "/media/gabriel/data_hdd/biosr_dataset/train/biosr_ds.npz"
dataset_path: "D:/biosr_dataset/valid/biosr_ds_test.npz"
n_samples: 100
batch_size: 4
batch_size: 1
im_size: 256

neptune:
Expand Down
31 changes: 31 additions & 0 deletions cvdm/data/biosr_dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from typing import Iterator, Tuple

import numpy as np


class BioSRDataloader:
    """Iterable loader for the BioSR dataset stored as a ``.npz`` archive.

    The archive must contain two arrays under the keys ``'x'`` (inputs) and
    ``'y'`` (targets), with the sample dimension first and the channel
    dimension last.
    """

    def __init__(
        self,
        path: str,
        n_samples: int,
        im_size: int,
    ) -> None:
        """Load the dataset arrays from ``path``.

        Args:
            path: Filesystem path to the ``.npz`` archive.
            n_samples: Accepted for interface compatibility.
                NOTE(review): this value is currently ignored — the sample
                count is taken from the data itself. TODO confirm whether it
                should cap the dataset length.
            im_size: Stored but otherwise unused by this class.
                NOTE(review): presumably consumed by downstream code —
                verify against callers.
        """
        # Open the archive once; loading it separately per key would read
        # and parse the same file twice.
        archive = np.load(path)
        self._x = archive["x"]
        self._y = archive["y"]
        self._im_size = im_size
        # Sample count comes from the data, not from the n_samples argument.
        self._n_samples: int = self._x.shape[0]

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        return self._n_samples

    def get_channels(self) -> Tuple[int, int]:
        """Return ``(input_channels, target_channels)`` from the last axes."""
        return self._x.shape[-1], self._y.shape[-1]

    def __getitem__(self, idx: int) -> Tuple[np.ndarray, np.ndarray]:
        """Return the ``(x, y)`` pair at position ``idx``."""
        return self._x[idx], self._y[idx]

    def __call__(self) -> Iterator[Tuple[np.ndarray, np.ndarray]]:
        """Yield every ``(x, y)`` pair in order (generator-style interface)."""
        for i in range(len(self)):
            yield self[i]
2 changes: 1 addition & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def main() -> None:
print("Getting data...")
batch_size = data_config.batch_size
dataset, x_shape, y_shape = prepare_dataset(task, data_config, training=True)
dataset = dataset.shuffle(1000000, reshuffle_each_iteration=False)
dataset = dataset.shuffle(10000, reshuffle_each_iteration=False)
val_len = eval_config.val_len
dataset = dataset.skip(val_len)

Expand Down

0 comments on commit 7b80ed5

Please sign in to comment.