v2.8.0 #735

Merged · 22 commits · Dec 11, 2024
180 changes: 180 additions & 0 deletions docs/datasets.md
@@ -0,0 +1,180 @@
# Datasets

Dataset classes give you a way to automatically download a dataset and transform it into a PyTorch dataset.

All implemented datasets have disjoint train-test splits, ideal for benchmarking on image retrieval and one-shot/few-shot classification tasks.

## BaseDataset

All dataset classes extend this class and therefore inherit its ```__init__``` parameters.

```python
datasets.base_dataset.BaseDataset(
root,
split="train+test",
transform=None,
target_transform=None,
download=False
)
```

**Parameters**:

* **root**: The path where the dataset files are saved.
* **split**: A string that determines which split of the dataset is loaded.
* **transform**: A `torchvision.transforms` object which will be used on the input images.
* **target_transform**: A `torchvision.transforms` object which will be used on the labels.
* **download**: Whether to download the dataset. Setting this to `False` when the dataset is not already on disk will raise a `ValueError`.
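
For instance, the `download` flag behaves like this (a minimal sketch, assuming the `CUB` subclass documented below and an empty `data` directory):

```python
from pytorch_metric_learning.datasets.cub import CUB

try:
    # Fails: "data" does not exist on disk yet and download=False
    dataset = CUB(root="data", split="train", download=False)
except ValueError as err:
    print(err)  # suggests initializing with download=True

# Downloads and extracts the files into "data", then loads the split
dataset = CUB(root="data", split="train", download=True)
```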

**Required Implementations**:
```python
@abstractmethod
def download_and_remove(self):
    raise NotImplementedError

@abstractmethod
def generate_split(self):
    raise NotImplementedError
```

## CUB-200-2011

```python
datasets.cub.CUB(*args, **kwargs)
```

**Defined splits**:

- `train` - Consists of 5864 examples, taken from classes 1 to 100.
- `test` - Consists of 5924 examples, taken from classes 101 to 200.
- `train+test` - Consists of 11,788 examples, taken from all classes.

**Loading different dataset splits**
```python
from pytorch_metric_learning.datasets.cub import CUB

train_dataset = CUB(root="data",
split="train",
transform=None,
target_transform=None,
download=True
)
# No need to download the dataset again once it is already on disk
test_dataset = CUB(root="data",
split="test",
transform=None,
target_transform=None,
download=False
)
train_and_test_dataset = CUB(root="data",
split="train+test",
transform=None,
target_transform=None,
download=False
)
```
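
Since every dataset class is a regular `torch.utils.data.Dataset`, it can be fed straight into a `DataLoader`. A sketch, assuming a transform that produces fixed-size RGB tensors so that default batching works (the resize values are illustrative):

```python
from torch.utils.data import DataLoader
from torchvision import transforms

from pytorch_metric_learning.datasets.cub import CUB

# Fixed-size RGB tensors are needed for the default collate function
transform = transforms.Compose([
    transforms.Lambda(lambda img: img.convert("RGB")),  # a few images are grayscale
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

train_dataset = CUB(root="data", split="train",
                    transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)

for images, labels in train_loader:
    # images: (32, 3, 224, 224) float tensor; labels: (32,) tensor
    break
```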

## Cars196

```python
datasets.cars196.Cars196(*args, **kwargs)
```

**Defined splits**:

- `train` - Consists of 8054 examples, taken from classes 1 to 98.
- `test` - Consists of 8131 examples, taken from classes 99 to 196.
- `train+test` - Consists of 16,185 examples, taken from all classes.

**Loading different dataset splits**
```python
from pytorch_metric_learning.datasets.cars196 import Cars196

train_dataset = Cars196(root="data",
split="train",
transform=None,
target_transform=None,
download=True
)
# No need to download the dataset again once it is already on disk
test_dataset = Cars196(root="data",
split="test",
transform=None,
target_transform=None,
download=False
)
train_and_test_dataset = Cars196(root="data",
split="train+test",
transform=None,
target_transform=None,
download=False
)
```

## INaturalist2018

```python
datasets.inaturalist2018.INaturalist2018(*args, **kwargs)
```

**Defined splits**:

- `train` - Consists of 325,846 examples.
- `test` - Consists of 136,093 examples.
- `train+test` - Consists of 461,939 examples.

**Loading different dataset splits**
```python
from pytorch_metric_learning.datasets.inaturalist2018 import INaturalist2018

# The download takes a while - the dataset is very large
train_dataset = INaturalist2018(root="data",
split="train",
transform=None,
target_transform=None,
download=True
)
# No need to download the dataset again once it is already on disk
test_dataset = INaturalist2018(root="data",
split="test",
transform=None,
target_transform=None,
download=False
)
train_and_test_dataset = INaturalist2018(root="data",
split="train+test",
transform=None,
target_transform=None,
download=False
)
```

## StanfordOnlineProducts

```python
datasets.sop.StanfordOnlineProducts(*args, **kwargs)
```

**Defined splits**:

- `train` - Consists of 59,551 examples.
- `test` - Consists of 60,502 examples.
- `train+test` - Consists of 120,053 examples.

**Loading different dataset splits**
```python
from pytorch_metric_learning.datasets.sop import StanfordOnlineProducts

# The download takes a while - the dataset is very large
train_dataset = StanfordOnlineProducts(root="data",
split="train",
transform=None,
target_transform=None,
download=True
)
# No need to download the dataset again once it is already on disk
test_dataset = StanfordOnlineProducts(root="data",
split="test",
transform=None,
target_transform=None,
download=False
)
train_and_test_dataset = StanfordOnlineProducts(root="data",
split="train+test",
transform=None,
target_transform=None,
download=False
)
```
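
For metric-learning training, these datasets combine naturally with the library's samplers. A sketch, assuming the dataset keeps its label list in `.labels` (as the `BaseDataset` implementation in this release does) and using `MPerClassSampler` so each batch contains several images per class:

```python
from torch.utils.data import DataLoader
from torchvision import transforms

from pytorch_metric_learning.datasets.sop import StanfordOnlineProducts
from pytorch_metric_learning.samplers import MPerClassSampler

transform = transforms.Compose([
    transforms.Lambda(lambda img: img.convert("RGB")),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

train_dataset = StanfordOnlineProducts(root="data", split="train",
                                       transform=transform, download=True)

# m=4 images per class per batch; most metric-learning losses need
# multiple positives per batch to form pairs or triplets
sampler = MPerClassSampler(train_dataset.labels, m=4, batch_size=32,
                           length_before_new_iter=len(train_dataset))
train_loader = DataLoader(train_dataset, batch_size=32, sampler=sampler)
```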
34 changes: 34 additions & 0 deletions docs/extend/datasets.md
@@ -0,0 +1,34 @@
# How to write custom datasets

1. Subclass the ```datasets.base_dataset.BaseDataset``` class
2. Add implementations for abstract methods from the base class:
- ```download_and_remove()```
- ```generate_split()```


```python
from pytorch_metric_learning.datasets.base_dataset import BaseDataset

class MyDataset(BaseDataset):

def __init__(self, my_parameter, *args, **kwargs):
super().__init__(*args, **kwargs)
        self.my_parameter = my_parameter

def download_and_remove(self):
# Downloads the dataset files needed
#
# If you're using a dataset that you've already downloaded elsewhere,
# just use an empty implementation
pass

def generate_split(self):
# Creates a list of image paths, and saves them into self.paths
# Creates a list of labels for the images, and saves them into self.labels
#
        # The default splits that need to be covered are `train`, `test`, and `train+test`
        # If you need a different split setup, override `get_available_splits(self)` to return
        # the split names you want (see the sketch below)
pass

```
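
A hypothetical usage of the subclass above (`my_parameter` and its value are illustrative):

```python
dataset = MyDataset(
    my_parameter=42,   # custom argument, consumed by MyDataset itself
    root="data",
    split="train",
    download=True,     # triggers MyDataset.download_and_remove()
)
img, label = dataset[0]
```

If your dataset needs split names other than the defaults, override ```get_available_splits()```:

```python
class MyZeroShotDataset(BaseDataset):
    # download_and_remove() and generate_split() are still required; omitted here
    def get_available_splits(self):
        return ["seen", "unseen"]
```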
2 changes: 2 additions & 0 deletions mkdocs.yml
@@ -1,6 +1,7 @@
site_name: PyTorch Metric Learning
nav:
- Home: index.md
- Datasets: datasets.md
- Distances: distances.md
- Losses: losses.md
- Miners: miners.md
@@ -16,6 +17,7 @@ nav:
- Common Functions: common_functions.md
- Distributed: distributed.md
- How to extend this library:
- Custom datasets: extend/datasets.md
- Custom losses: extend/losses.md
- Custom miners: extend/miners.md
- Frequently Asked Questions: faq.md
2 changes: 1 addition & 1 deletion src/pytorch_metric_learning/__init__.py
@@ -1 +1 @@
__version__ = "2.7.0"
__version__ = "2.8.0"
71 changes: 71 additions & 0 deletions src/pytorch_metric_learning/datasets/base_dataset.py
@@ -0,0 +1,71 @@
import os
from abc import ABC, abstractmethod

from PIL import Image
from torch.utils.data import Dataset


class BaseDataset(ABC, Dataset):

def __init__(
self,
root,
split="train+test",
transform=None,
target_transform=None,
download=False,
):
self.root = root

if download:
if not os.path.isdir(self.root):
os.makedirs(self.root, exist_ok=False)
self.download_and_remove()
            elif not os.listdir(self.root):
self.download_and_remove()
else:
            # When download=False, the root directory must already contain the dataset
if not os.path.isdir(self.root):
raise ValueError(
"The given path does not exist. "
"You should probably initialize the dataset with download=True."
)

self.transform = transform
self.target_transform = target_transform

if split not in self.get_available_splits():
raise ValueError(
f"Supported splits are: {', '.join(self.get_available_splits())}"
)

self.split = split

self.generate_split()

    @abstractmethod
    def generate_split(self):
        raise NotImplementedError

    @abstractmethod
    def download_and_remove(self):
        raise NotImplementedError

def get_available_splits(self):
return ["train", "test", "train+test"]

def __len__(self):
return len(self.labels)

def __getitem__(self, idx):
img = Image.open(self.paths[idx])
label = self.labels[idx]

if self.transform is not None:
img = self.transform(img)

if self.target_transform is not None:
label = self.target_transform(label)

return (img, label)
67 changes: 67 additions & 0 deletions src/pytorch_metric_learning/datasets/cars196.py
@@ -0,0 +1,67 @@
import os
import zipfile

from ..datasets.base_dataset import BaseDataset
from ..utils.common_functions import _urlretrieve


class Cars196(BaseDataset):

DOWNLOAD_URL = "https://www.kaggle.com/api/v1/datasets/download/jutrera/stanford-car-dataset-by-classes-folder"

def generate_split(self):
        # Training split is classes 1-98, test split is classes 99-196
if self.split == "train":
classes = set(range(1, 99))
elif self.split == "test":
classes = set(range(99, 197))
else:
classes = set(range(1, 197))

with open(os.path.join(self.root, "names.csv"), "r") as f:
names = [x.strip() for x in f.readlines()]

paths_train, labels_train = self._load_csv(
os.path.join(self.root, "anno_train.csv"), names, split="train"
)
paths_test, labels_test = self._load_csv(
os.path.join(self.root, "anno_test.csv"), names, split="test"
)
paths = paths_train + paths_test
labels = labels_train + labels_test

self.paths, self.labels = [], []
for p, l in zip(paths, labels):
if l in classes:
self.paths.append(p)
self.labels.append(l)

def _load_csv(self, path, names, split):
all_paths, all_labels = [], []
with open(path, "r") as f:
            for line in f:
                # First column is the image filename, last column is the class label
                path_annos = line.split(",")
curr_path = path_annos[0]
curr_label = path_annos[-1]
all_paths.append(
os.path.join(
self.root,
"car_data",
"car_data",
split,
names[int(curr_label) - 1].replace("/", "-"),
curr_path,
)
)
all_labels.append(int(curr_label))
return all_paths, all_labels

def download_and_remove(self):
os.makedirs(self.root, exist_ok=True)
download_folder_path = os.path.join(
self.root, Cars196.DOWNLOAD_URL.split("/")[-1]
)
_urlretrieve(url=Cars196.DOWNLOAD_URL, filename=download_folder_path)
with zipfile.ZipFile(download_folder_path, "r") as zip_ref:
zip_ref.extractall(self.root)
os.remove(download_folder_path)