Skip to content

Commit

Permalink
Merge pull request #39 from neuro-ml/dev
Browse files (browse the repository at this point in the history)
bumped reqs
  • Loading branch information
maxme1 authored Jun 25, 2023
2 parents 4b3cbf8 + b014e03 commit 5a3b899
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 113 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ jobs:
remote:
storage:
- http: $CACHE_HOST
optional: true
remote:
nginx: $CACHE_HOST
meta:
fallback: local
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,8 @@ jobs:
remote:
storage:
- http: $CACHE_HOST
optional: true
remote:
nginx: $CACHE_HOST
meta:
fallback: local
Expand Down
2 changes: 1 addition & 1 deletion amid/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.11.1'
__version__ = '0.12.0'
70 changes: 11 additions & 59 deletions amid/internals/cache.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from gzip import GzipFile
from pathlib import Path
from typing import Dict, Sequence, Union
from typing import Any, Callable, Sequence, Union

import numpy as np
from bev import Repository
from connectome import CacheColumns as Columns, CacheToDisk as Disk
from connectome.utils import StringsLike
from tarn import HashKeyStorage, ReadError
from tarn import ReadError
from tarn.serializers import (
ChainSerializer,
ContentsIn,
ContentsOut,
DictSerializer,
JsonSerializer as BaseJsonSerializer,
NumpySerializer,
PickleSerializer,
Serializer,
SerializerError,
Expand All @@ -27,7 +28,7 @@ def __init__(
repo = Repository.from_here('../data')
cache = repo.cache
super().__init__(
[x.root for x in cache.local[0].locations],
cache.local,
cache.storage,
serializer=default_serializer(serializer),
names=names,
Expand All @@ -46,7 +47,7 @@ def __init__(
repo = Repository.from_here('../data')
cache = repo.cache
super().__init__(
[x.root for x in cache.local[0].locations],
cache.local,
cache.storage,
serializer=default_serializer(serializer),
names=names,
Expand All @@ -68,67 +69,18 @@ def default_serializer(serializers):
return serializers


class NumpySerializer(Serializer):
def __init__(self, compression: Union[int, Dict[type, int]] = None):
self.compression = compression

def _choose_compression(self, value):
if isinstance(self.compression, int) or self.compression is None:
return self.compression

if isinstance(self.compression, dict):
for dtype in self.compression:
if np.issubdtype(value.dtype, dtype):
return self.compression[dtype]

def save(self, value, folder: Path):
if not isinstance(value, (np.ndarray, np.generic)):
raise SerializerError

compression = self._choose_compression(value)
if compression is not None:
assert isinstance(compression, int)
with GzipFile(folder / 'value.npy.gz', 'wb', compresslevel=compression, mtime=0) as file:
np.save(file, value, allow_pickle=False)

else:
np.save(folder / 'value.npy', value, allow_pickle=False)

def load(self, folder: Path, storage: HashKeyStorage):
paths = list(folder.iterdir())
if len(paths) != 1:
raise SerializerError

(path,) = paths
if path.name == 'value.npy':
loader = np.load
elif path.name == 'value.npy.gz':

def loader(x):
with GzipFile(x, 'rb') as file:
return np.load(file)

else:
raise SerializerError

try:
return self._load_file(storage, loader, path)
except (ValueError, EOFError) as e:
raise ReadError from e


class JsonSerializer(BaseJsonSerializer):
def save(self, value, folder: Path):
def save(self, value: Any, write: Callable) -> ContentsOut:
# if namedtuple
if isinstance(value, tuple) and hasattr(value, '_asdict') and hasattr(value, '_fields'):
raise SerializerError

return super().save(value, folder)
return super().save(value, write)


class CleanInvalid(Serializer):
def save(self, value, folder: Path):
def save(self, value: Any, write: Callable) -> ContentsOut:
raise SerializerError

def load(self, folder: Path, storage: HashKeyStorage):
def load(self, contents: ContentsIn, read: Callable) -> Any:
raise ReadError
74 changes: 27 additions & 47 deletions amid/internals/checksum.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import functools
import tempfile
from collections import defaultdict
from contextlib import suppress
from pathlib import Path
Expand Down Expand Up @@ -59,11 +58,9 @@ def _raw(root: str = None, version: str = Local):
args.append(
CacheAndCheck(
set(dir(ds)) - {'id', 'ids', *ignore},
repository,
repository.copy(version=version, fetch=True),
path,
fetch=True,
serializer=serializer,
version=version,
)
)

Expand All @@ -86,11 +83,9 @@ def _populate(

ds = ds >> CacheAndCheck(
fields,
repository,
repository.copy(fetch=fetch, version=Local),
path,
fetch=fetch,
serializer=serializer,
version=self._version,
return_tree=True,
)
_loader = ds._compile(fields)
Expand Down Expand Up @@ -137,28 +132,26 @@ def __init__(
*,
serializer=None,
return_tree: bool = False,
fetch: bool = True,
version,
):
super().__init__(names, False)
serializer = default_serializer(serializer)
# name -> identifier -> tree
checksums = defaultdict(lambda: defaultdict(dict))

with suppress(HashNotFound, ReadError):
for key, value in repository.load_tree(to_hash(Path(path)), version=version, fetch=fetch).items():
for key, value in repository.load_tree(to_hash(Path(path))).items():
name, identifier, relative = key.split('/', 2)
if name in names:
checksums[name][identifier][relative] = value

self.checksums = dict(checksums)
self.return_tree = return_tree
self.storage = repository.storage
self.repository = repository
self.serializer = serializer
self.keys = 'ids'

def _get_storage(self) -> Cache:
return self.storage
return self.repository

def _prepare_container(self, previous: EdgesBag) -> EdgesBag:
if len(previous.inputs) != 1:
Expand Down Expand Up @@ -208,10 +201,10 @@ def _prepare_container(self, previous: EdgesBag) -> EdgesBag:


class CheckSumEdge(StaticGraph, StaticHash):
def __init__(self, tree, serializer, storage, return_tree: bool, check: bool):
def __init__(self, tree, serializer, repository, return_tree: bool, check: bool):
super().__init__(arity=2)
self.return_tree = return_tree
self._serializer, self._storage = serializer, storage
self._serializer, self._repository = serializer, repository
self.tree = tree
self.check = check

Expand Down Expand Up @@ -248,41 +241,28 @@ def evaluate(self):
return value

def _deserialize(self, tree):
with tempfile.TemporaryDirectory() as base:
base = Path(base)
for k, v in tree.items():
k = base / k
k.parent.mkdir(parents=True, exist_ok=True)
with open(k, 'w') as file:
file.write(v)

try:
return self._serializer.load(base, self._storage), True
except ReadError as e:
if isinstance(e, DeserializationError):
locations = {}
for k, v in tree.items():
try:
locations[k] = self._storage.read(lambda x: x, v)
except ReadError:
pass

raise DeserializationError(f'{tree}: {locations}')
return None, False
def read(fn, x):
return self._repository.storage.read(fn, x, fetch=self._repository.fetch)

try:
return self._serializer.load(list(tree.items()), read), True

except ReadError as e:
if isinstance(e, DeserializationError):
locations = {}
for k, v in tree.items():
try:
locations[k] = read(lambda x: x, v)
except ReadError:
pass

raise DeserializationError(f'{tree}: {locations}')
return None, False

def _serialize(self, value):
with tempfile.TemporaryDirectory() as base:
base = Path(base)
self._serializer.save(value, base)
tree = {}
# TODO: this is basically `mirror to storage`
for file in base.glob('**/*'):
if file.is_dir():
continue

tree[str(file.relative_to(base))] = self._storage.write(file, labels=['amid.checksum']).hex()

return tree
return dict(
self._serializer.save(value, lambda v: self._repository.storage.write(v, labels=['amid.checksum']).hex())
)


# source: https://stackoverflow.com/a/61027781
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
connectome>=0.6.1,<1.0.0
bev>=0.8.0,<1.0.0
tarn>=0.5.0,<1.0.0
bev>=0.9.0,<1.0.0
tarn>=0.8.0,<1.0.0
numpy
nibabel
more-itertools
Expand Down

0 comments on commit 5a3b899

Please sign in to comment.