Skip to content

Commit

Permalink
Merge branch 'master' of https://github.com/ContinualAI/avalanche int…
Browse files Browse the repository at this point in the history
…o remove_conda_files
  • Loading branch information
AntonioCarta committed Jan 25, 2024
2 parents 0f267e9 + f208a6c commit 0fd6873
Show file tree
Hide file tree
Showing 56 changed files with 1,656 additions and 301 deletions.
19 changes: 14 additions & 5 deletions avalanche/benchmarks/classic/cimagenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def SplitImageNet(
class_ids_from_zero_in_each_exp: bool = False,
class_ids_from_zero_from_first_exp: bool = False,
train_transform: Optional[Any] = _default_train_transform,
eval_transform: Optional[Any] = _default_eval_transform
eval_transform: Optional[Any] = _default_eval_transform,
meta_root: Optional[Union[str, Path]] = None,
):
"""
Creates a CL benchmark using the ImageNet dataset.
Expand Down Expand Up @@ -130,11 +131,19 @@ class "34" will be mapped to "1", class "11" to "2" and so on.
comprehensive list of possible transformations).
If no transformation is passed, the default test transformation
will be used.
:param meta_root: Directory where the `ILSVRC2012_devkit_t12.tar.gz`
file can be found. The first time you use this dataset, the meta file will be
extracted from the archive and a `meta.bin` file will be created in the `meta_root`
directory. Defaults to None, which means that the meta file is expected to be
in the path provied in the `root` argument.
This is an additional argument not found in the original ImageNet class
from the torchvision package. For more info, see the `meta_root` argument
in the :class:`AvalancheImageNet` class.
:returns: A properly initialized :class:`NCScenario` instance.
"""

train_set, test_set = _get_imagenet_dataset(dataset_root)
train_set, test_set = _get_imagenet_dataset(dataset_root, meta_root=meta_root)

return nc_benchmark(
train_dataset=train_set,
Expand All @@ -152,10 +161,10 @@ class "34" will be mapped to "1", class "11" to "2" and so on.
)


def _get_imagenet_dataset(root):
train_set = ImageNet(root, split="train")
def _get_imagenet_dataset(root, meta_root=None):
train_set = ImageNet(root, split="train", meta_root=meta_root)

test_set = ImageNet(root, split="val")
test_set = ImageNet(root, split="val", meta_root=meta_root)

return train_set, test_set

Expand Down
3 changes: 1 addition & 2 deletions avalanche/benchmarks/classic/openloris.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
a number of configuration parameters."""

from pathlib import Path
from typing import Union, Any, Optional
from typing_extensions import Literal
from typing import Union, Any, Optional, Literal

from avalanche.benchmarks.classic.classic_benchmarks_utils import (
check_vision_benchmark,
Expand Down
4 changes: 1 addition & 3 deletions avalanche/benchmarks/classic/stream51.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@
# Website: www.continualai.org #
################################################################################
from pathlib import Path
from typing import List, Optional, Union

from typing_extensions import Literal
from typing import List, Optional, Union, Literal

from avalanche.benchmarks.datasets import Stream51
from avalanche.benchmarks.scenarios.deprecated.generic_benchmark_creation import (
Expand Down
21 changes: 21 additions & 0 deletions avalanche/benchmarks/datasets/core50/core50.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import glob
import os
import pickle as pkl
import dill
from pathlib import Path
from typing import List, Optional, Tuple, Union
from warnings import warn
Expand All @@ -26,6 +27,7 @@
from avalanche.benchmarks.datasets.downloadable_dataset import (
DownloadableDataset,
)
from avalanche.checkpointing import constructor_based_serialization


class CORe50Dataset(DownloadableDataset):
Expand Down Expand Up @@ -247,6 +249,25 @@ def CORe50(*args, **kwargs):
return CORe50Dataset(*args, **kwargs)


@dill.register(CORe50Dataset)
def checkpoint_CORe50Dataset(pickler, obj: CORe50Dataset):
constructor_based_serialization(
pickler,
obj,
CORe50Dataset,
deduplicate=True,
kwargs=dict(
root=obj.root,
train=obj.train,
transform=obj.transform,
target_transform=obj.target_transform,
loader=obj.loader,
mini=obj.mini,
object_level=obj.object_level,
),
)


if __name__ == "__main__":
# this litte example script can be used to visualize the first image
# leaded from the dataset.
Expand Down
19 changes: 19 additions & 0 deletions avalanche/benchmarks/datasets/cub200/cub200.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

import gdown
import os
import dill
from collections import OrderedDict
from torchvision.datasets.folder import default_loader

Expand All @@ -31,6 +32,7 @@
DownloadableDataset,
)
from avalanche.benchmarks.utils import PathsDataset
from avalanche.checkpointing import constructor_based_serialization


class CUB200(PathsDataset, DownloadableDataset):
Expand Down Expand Up @@ -178,6 +180,23 @@ def _load_metadata(self):
return True


@dill.register(CUB200)
def checkpoint_CUB200(pickler, obj: CUB200):
constructor_based_serialization(
pickler,
obj,
CUB200,
deduplicate=True,
kwargs=dict(
root=obj.root,
train=obj.train,
transform=obj.transform,
target_transform=obj.target_transform,
loader=obj.loader,
),
)


if __name__ == "__main__":
"""Simple test that will start if you run this script directly"""

Expand Down
41 changes: 25 additions & 16 deletions avalanche/benchmarks/datasets/external_datasets/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from torchvision.datasets import CIFAR100, CIFAR10

from avalanche.benchmarks.datasets import default_dataset_location
from avalanche.checkpointing import constructor_based_serialization


def get_cifar10_dataset(dataset_root):
Expand Down Expand Up @@ -31,26 +32,34 @@ def load_CIFAR100(root, train, transform, target_transform):


@dill.register(CIFAR100)
def save_CIFAR100(pickler, obj: CIFAR100):
pickler.save_reduce(
load_CIFAR100,
(obj.root, obj.train, obj.transform, obj.target_transform),
obj=obj,
)


def load_CIFAR10(root, train, transform, target_transform):
return CIFAR10(
root=root, train=train, transform=transform, target_transform=target_transform
def checkpoint_CIFAR100(pickler, obj: CIFAR100):
constructor_based_serialization(
pickler,
obj,
CIFAR100,
deduplicate=True,
kwargs=dict(
root=obj.root,
train=obj.train,
transform=obj.transform,
target_transform=obj.target_transform,
),
)


@dill.register(CIFAR10)
def save_CIFAR10(pickler, obj: CIFAR10):
pickler.save_reduce(
load_CIFAR10,
(obj.root, obj.train, obj.transform, obj.target_transform),
obj=obj,
def checkpoint_CIFAR10(pickler, obj: CIFAR10):
constructor_based_serialization(
pickler,
obj,
CIFAR10,
deduplicate=True,
kwargs=dict(
root=obj.root,
train=obj.train,
transform=obj.transform,
target_transform=obj.target_transform,
),
)


Expand Down
24 changes: 13 additions & 11 deletions avalanche/benchmarks/datasets/external_datasets/fmnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from torchvision.datasets import FashionMNIST

from avalanche.benchmarks.datasets import default_dataset_location
from avalanche.checkpointing import constructor_based_serialization


def get_fmnist_dataset(dataset_root):
Expand All @@ -13,18 +14,19 @@ def get_fmnist_dataset(dataset_root):
return train_set, test_set


def load_FashionMNIST(root, train, transform, target_transform):
return FashionMNIST(
root=root, train=train, transform=transform, target_transform=target_transform
)


@dill.register(FashionMNIST)
def save_FashionMNIST(pickler, obj: FashionMNIST):
pickler.save_reduce(
load_FashionMNIST,
(obj.root, obj.train, obj.transform, obj.target_transform),
obj=obj,
def checkpoint_FashionMNIST(pickler, obj: FashionMNIST):
constructor_based_serialization(
pickler,
obj,
FashionMNIST,
deduplicate=True,
kwargs=dict(
root=obj.root,
train=obj.train,
transform=obj.transform,
target_transform=obj.target_transform,
),
)


Expand Down
22 changes: 13 additions & 9 deletions avalanche/benchmarks/datasets/external_datasets/mnist.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import dill
from torchvision.datasets import MNIST
from avalanche.benchmarks.datasets import default_dataset_location
from avalanche.checkpointing import constructor_based_serialization


class TensorMNIST(MNIST):
Expand Down Expand Up @@ -35,16 +36,19 @@ def get_mnist_dataset(dataset_root):
return train_set, test_set


def load_MNIST(root, train, transform, target_transform):
return TensorMNIST(
root=root, train=train, transform=transform, target_transform=target_transform
)


@dill.register(TensorMNIST)
def save_MNIST(pickler, obj: TensorMNIST):
pickler.save_reduce(
load_MNIST, (obj.root, obj.train, obj.transform, obj.target_transform), obj=obj
def checkpoint_TensorMNIST(pickler, obj: TensorMNIST):
constructor_based_serialization(
pickler,
obj,
TensorMNIST,
deduplicate=True,
kwargs=dict(
root=obj.root,
train=obj.train,
transform=obj.transform,
target_transform=obj.target_transform,
),
)


Expand Down
1 change: 1 addition & 0 deletions avalanche/benchmarks/datasets/imagenet/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .imagenet import *
Loading

0 comments on commit 0fd6873

Please sign in to comment.