Add batch dataloading
pierlj committed Dec 13, 2023
1 parent 0f76a9f commit 3a03110
Showing 4 changed files with 108 additions and 42 deletions.
107 changes: 73 additions & 34 deletions examples/test.ipynb

Large diffs are not rendered by default.

35 changes: 32 additions & 3 deletions loreal_poc/dataloaders/base.py
@@ -1,3 +1,4 @@
+import random
 from abc import ABC, abstractmethod
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Tuple, Union
@@ -6,6 +7,9 @@
 
 
 class DataIteratorBase(ABC):
+    batch_size: int
+    index_sampler: List[int]
+
     def __init__(self) -> None:
         super().__init__()
         self.index = 0
@@ -31,6 +35,7 @@ def get_meta(self, idx: int) -> Optional[Dict]:
     def __getitem__(
         self, idx: int
     ) -> Tuple[np.ndarray, Optional[np.ndarray], Optional[Dict[Any, Any]]]: # (image, marks, meta)
+        idx = self.index_sampler[idx]
         return self.get_image(idx), self.get_marks(idx), self.get_meta(idx)
 
     @property
@@ -49,9 +54,24 @@ def all_meta(self) -> List: # (meta)
     def __next__(self) -> Tuple[np.ndarray, np.ndarray]:
         if self.index >= len(self):
             raise StopIteration
-        elt = self[self.index]
-        self.index += 1
-        return elt
+        elt = [self[idx] for idx in range(self.index, min(len(self), self.index + self.batch_size))]
+        self.index += self.batch_size
+        if self.batch_size == 1:
+            return elt[0]
+        else:
+            return self._collate(elt)
+
+    def _collate(
+        self, elements: List[Tuple[np.ndarray, Optional[np.ndarray], Optional[Dict[Any, Any]]]]
+    ) -> Tuple[np.ndarray, Optional[np.ndarray], Optional[Dict[Any, Any]]]:
+        batched_elements = list(zip(*elements))
+        # TO DO: create image stack but require same size images and therefore automatic padding or resize.
+        # batched_elements[0] = batched_elements[0]
+        if elements[0][1] is not None:
+            batched_elements[1] = np.stack(batched_elements[1], axis=0)
+        if elements[0][2] is not None:
+            batched_elements[2] = {key: [meta[key] for meta in batched_elements[2]] for key in batched_elements[2][0]}
+        return batched_elements
 
 
 class DataLoaderBase(DataIteratorBase):
@@ -66,6 +86,8 @@ def __init__(
         images_dir_path: Union[str, Path],
         landmarks_dir_path: Union[str, Path],
         meta: Optional[Dict[str, Any]] = None,
+        batch_size: Optional[int] = None,
+        shuffle: Optional[bool] = False,
     ) -> None:
         super().__init__()
         images_dir_path = self._get_absolute_local_path(images_dir_path)
@@ -79,6 +101,13 @@ def __init__(
                 f"for {len(self.marks_paths)} of the images."
             )
 
+        self.batch_size = batch_size if batch_size else 1
+
+        self.shuffle = shuffle
+        self.index_sampler = [idx for idx in range(len(self))]
+        if shuffle:
+            random.shuffle(self.index_sampler)
+
         self.meta = {
             **(meta if meta is not None else {}),
             "num_samples": len(self),
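To make the base.py change concrete: __next__ now gathers up to batch_size items and, for batches larger than one, collates them with _collate (marks are stacked into one array, meta is merged into a dict of lists; images are left as a list until the padding/resize TODO is addressed). Below is a minimal standalone sketch of that collation behaviour, with made-up sample data and variable names; only batch_size and the collation logic mirror the diff.

import numpy as np

# Five fake samples shaped like the loader's (image, marks, meta) tuples.
samples = [(np.zeros((16, 16)), np.ones((68, 2)), {"name": f"img_{i}"}) for i in range(5)]

batch_size = 2
for start in range(0, len(samples), batch_size):
    batch = samples[start:start + batch_size]
    images, marks, meta = zip(*batch)                  # same idea as list(zip(*elements)) in _collate
    marks = np.stack(marks, axis=0)                    # marks stacked along a new batch axis
    meta = {k: [m[k] for m in meta] for k in meta[0]}  # meta merged into a dict of lists
    print(len(images), marks.shape, meta["name"])
# -> 2 (2, 68, 2) ['img_0', 'img_1'], then img_2/img_3, then a final batch of one

Note that when batch_size is 1 the iterator in the diff keeps the previous behaviour and yields a single un-collated (image, marks, meta) tuple.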
6 changes: 2 additions & 4 deletions loreal_poc/dataloaders/loaders.py
@@ -13,10 +13,7 @@ class DataLoader300W(DataLoaderBase):
     n_landmarks: int = 68
     n_dimensions: int = 2
 
-    def __init__(
-        self,
-        dir_path: Union[str, Path],
-    ) -> None:
+    def __init__(self, dir_path: Union[str, Path], **kwargs) -> None:
         super().__init__(
             dir_path,
             dir_path,
@@ -28,6 +25,7 @@ def __init__(
                 "preprocessed": False,
                 "preprocessing_time": 0.0,
             },
+            **kwargs,
         )
 
     @classmethod
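The loaders.py change simply forwards extra keyword arguments to DataLoaderBase, so the new batch_size and shuffle options can be passed straight through DataLoader300W. A hypothetical usage sketch follows; the dataset path is an assumption, not something from this commit, and it requires the 300W images and .pts files on disk.

from loreal_poc.dataloaders.loaders import DataLoader300W

# "datasets/300W/sample" is an illustrative path, not part of the repository.
loader = DataLoader300W("datasets/300W/sample", batch_size=4, shuffle=True)
images, marks, meta = next(loader)  # first shuffled batch
# images: list of np.ndarray (not stacked yet, per the TODO in _collate)
# marks: np.ndarray of shape (4, 68, 2); meta: dict mapping each key to a list of 4 values

next() works directly here because the base iterator defines __next__; whether it also defines __iter__ for use in a for-loop is not visible in this diff.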
2 changes: 1 addition & 1 deletion loreal_poc/dataloaders/wrappers.py
@@ -60,4 +60,4 @@ def __getitem__(self, idx: int) -> Tuple[np.ndarray, Optional[np.ndarray], Optio
         self._cache_idxs.insert(0, idx)
         if len(self._cache_idxs) > self._max_size:
             self._cache.pop(self._cache_idxs.pop(-1))
-        return self._cache_idxs[idx]
+        return self._cache[idx]
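For context on the one-line wrappers.py fix: the caching wrapper keeps a dict of computed items keyed by index (_cache) plus a most-recently-added list of indices (_cache_idxs), and the old code mistakenly indexed the recency list instead of the cache. A standalone sketch of that pattern with illustrative names; only _cache, _cache_idxs and the eviction step mirror the diff.

class CachedGetter:
    # Wraps an expensive idx -> item function with a bounded cache.
    def __init__(self, getter, max_size=128):
        self._getter = getter
        self._max_size = max_size
        self._cache = {}        # idx -> item
        self._cache_idxs = []   # indices, most recently added first

    def __getitem__(self, idx):
        if idx not in self._cache:
            self._cache[idx] = self._getter(idx)
            self._cache_idxs.insert(0, idx)
            if len(self._cache_idxs) > self._max_size:
                # Evict the oldest entry when the cache is full.
                self._cache.pop(self._cache_idxs.pop(-1))
        return self._cache[idx]  # read from the cache dict, not the recency list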
