Skip to content

Commit

Permalink
Merge pull request #84 from neuro-ml/develop
Browse files Browse the repository at this point in the history
Faster patches grid
  • Loading branch information
vovaf709 authored Dec 26, 2023
2 parents 0448e2d + 6259a7f commit 913e671
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 43 deletions.
2 changes: 1 addition & 1 deletion dpipe/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.3.0'
__version__ = '0.3.1'
52 changes: 46 additions & 6 deletions dpipe/im/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Iterable, Type, Tuple, Callable

import numpy as np
import torch
from imops.numeric import pointwise_add

from .shape_ops import crop_to_box
Expand All @@ -14,7 +15,7 @@
from .shape_utils import shape_after_convolution
from .utils import build_slices

__all__ = 'get_boxes', 'divide', 'combine', 'PatchCombiner', 'Average'
__all__ = 'get_boxes', 'make_batch', 'break_batch', 'divide', 'combine', 'PatchCombiner', 'Average'


def get_boxes(shape: AxesLike, box_size: AxesLike, stride: AxesLike, axis: AxesLike = None,
Expand Down Expand Up @@ -49,6 +50,27 @@ def get_boxes(shape: AxesLike, box_size: AxesLike, stride: AxesLike, axis: AxesL
yield make_box_([start, np.minimum(start + box_size, shape)])


def make_batch(divide_iterator, batch_size: int = 1):
    """
    Group consecutive patches from ``divide_iterator`` into batches.

    Every patch is wrapped with ``torch.from_numpy`` as soon as it is received, and each
    group of ``batch_size`` patches is joined with ``torch.cat`` (i.e. along the first axis)
    and yielded as a numpy array. Patches left over once the iterator is exhausted
    are yielded as a final, smaller batch.

    Parameters
    ----------
    divide_iterator
        an iterable of numpy arrays (the patches).
    batch_size
        the number of patches joined into a single batch.
    """
    pending = []
    for patch in divide_iterator:
        pending.append(torch.from_numpy(patch))

        if len(pending) == batch_size:
            yield torch.cat(pending).numpy()
            pending = []

    # flush the incomplete trailing batch, if there is one
    if pending:
        yield torch.cat(pending).numpy()


def break_batch(prediction_iterator: Iterable):
    """
    Split batched predictions back into single-entry predictions.

    Each element of ``prediction_iterator`` is iterated over its first axis, and every entry
    is yielded with a new leading axis of size 1 in front of it.
    """
    for batch in prediction_iterator:
        # indexing with ``(None,)`` prepends the singleton leading axis
        yield from (entry[(None,)] for entry in batch)


def divide(x: np.ndarray, patch_size: AxesLike, stride: AxesLike, axis: AxesLike = None,
valid: bool = False, get_boxes: Callable = get_boxes) -> Iterable[np.ndarray]:
"""
Expand Down Expand Up @@ -88,23 +110,41 @@ def build(self) -> np.ndarray:


class Average(PatchCombiner):
def __init__(self, shape: Tuple[int, ...], dtype: np.dtype, **imops_kwargs: dict):
def __init__(self, shape: Tuple[int, ...], dtype: np.dtype, use_torch: bool = True, **imops_kwargs: dict):
super().__init__(shape, dtype)
self._result = np.zeros(shape, dtype)
self._counts = np.zeros(shape, int)
self._counts = np.zeros(shape, np.uint8 if use_torch else int)
self._imops_kwargs = imops_kwargs

self._use_torch = use_torch

def update(self, box: Box, patch: np.ndarray):
slc = build_slices(*box)

result_slc = self._result[slc]
pointwise_add(result_slc, patch.astype(result_slc.dtype, copy=False), result_slc, **self._imops_kwargs)
if self._use_torch:
result_slc_torch = torch.from_numpy(result_slc)
patch_torch = torch.from_numpy(patch.astype(result_slc.dtype, copy=False))
result_slc_torch += patch_torch
else:
pointwise_add(result_slc, patch.astype(result_slc.dtype, copy=False), result_slc, **self._imops_kwargs)

counts_slc = self._counts[slc]
pointwise_add(counts_slc, 1, counts_slc, **self._imops_kwargs)
if self._use_torch:
counts_slc_torch = torch.from_numpy(counts_slc)
counts_slc_torch += 1
else:
pointwise_add(counts_slc, 1, counts_slc, **self._imops_kwargs)

def build(self):
np.true_divide(self._result, self._counts, out=self._result, where=self._counts > 0)
if self._use_torch:
result_torch = torch.from_numpy(self._result)
counts_torch = torch.from_numpy(self._counts)

counts_torch[counts_torch == 0] = 1
torch.div(result_torch, counts_torch, out=result_torch)
else:
np.true_divide(self._result, self._counts, out=self._result, where=self._counts > 0)
return self._result


Expand Down
39 changes: 39 additions & 0 deletions dpipe/itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from functools import wraps
from itertools import chain
from operator import itemgetter
from threading import Thread
from typing import Iterable, Sized, Union, Callable, Sequence, Any, Tuple
from queue import Queue

import numpy as np

Expand Down Expand Up @@ -111,6 +113,43 @@ def pmap(func: Callable, iterable: Iterable, *args, **kwargs) -> Iterable:
yield func(value, *args, **kwargs)


class FinishToken:
    """Sentinel that the worker thread enqueues after the last result to mark the end of the stream."""


class AsyncPmap:
    """
    Iterator that maps ``func`` over ``iterable`` in a background thread.

    The worker thread calls ``func(value, *args, **kwargs)`` for every value and hands the
    results over through a queue that holds a single item, so it blocks until the consumer
    has taken the previous result. ``start`` must be called before iterating.

    An exception raised by ``func`` or by ``iterable`` inside the worker thread is re-raised
    in the consuming thread by the corresponding ``next`` call; afterwards the iterator is
    exhausted.

    Parameters
    ----------
    func
        the function applied to each value.
    iterable
        the source of values.
    args, kwargs
        additional arguments passed to ``func`` on every call.
    """

    def __init__(self, func: Callable, iterable: Iterable, *args, **kwargs) -> None:
        self.__func = func
        self.__iterable = iterable
        self.__args = args
        self.__kwargs = kwargs

        # items are ``(payload, error)`` pairs; ``error`` is None for regular results
        self.__result_queue = Queue(1)
        self.__exhausted = False

        # daemon: a worker left blocked in ``put`` by an abandoned iterator
        # must not keep the interpreter from exiting
        self.__working_thread = Thread(target=self._prediction_func, daemon=True)

    def start(self) -> None:
        """Start the background thread that computes the results."""
        self.__working_thread.start()

    def _prediction_func(self) -> None:
        """Worker body: enqueue every result, then either the finish sentinel or the failure."""
        try:
            for value in self.__iterable:
                result = self.__func(value, *self.__args, **self.__kwargs)
                self.__result_queue.put((result, None))
        except BaseException as error:
            # without this hand-over the consumer would wait on the queue forever
            self.__result_queue.put((None, error))
        else:
            self.__result_queue.put((FinishToken, None))

    def __close(self) -> None:
        """Mark the iterator as exhausted and wait for the worker thread to terminate."""
        self.__exhausted = True
        self.__working_thread.join()

    def __iter__(self):
        return self

    def __next__(self) -> Any:
        # once exhausted, keep raising StopIteration instead of blocking on the empty queue
        if self.__exhausted:
            raise StopIteration

        obj, error = self.__result_queue.get()
        if error is not None:
            self.__close()
            raise error
        if obj is FinishToken:
            self.__close()
            raise StopIteration
        return obj


def dmap(func: Callable, dictionary: dict, *args, **kwargs):
"""
Transform the ``dictionary`` by mapping ``func`` over its values.
Expand Down
31 changes: 17 additions & 14 deletions dpipe/predict/shape.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from functools import wraps
from typing import Union, Callable, Type
from typing import Union, Callable, Type, Iterable
from warnings import warn

import numpy as np

from ..im.axes import broadcast_to_axis, AxesLike, AxesParams, axis_from_dim, resolve_deprecation
from ..im.grid import divide, combine, get_boxes, PatchCombiner, Average
from ..itertools import extract, pmap
from ..im.grid import divide, combine, get_boxes, PatchCombiner, Average, make_batch, break_batch
from ..itertools import extract, pmap, AsyncPmap
from ..im.shape_ops import pad_to_shape, crop_to_shape, pad_to_divisible
from ..im.shape_utils import prepend_dims, extract_dims

Expand Down Expand Up @@ -81,8 +82,8 @@ def wrapper(x, *args, **kwargs):

def patches_grid(patch_size: AxesLike, stride: AxesLike, axis: AxesLike = None,
padding_values: Union[AxesParams, Callable] = 0, ratio: AxesParams = 0.5,
combiner: Type[PatchCombiner] = Average, get_boxes: Callable = get_boxes, stream: bool = False,
**imops_kwargs):
combiner: Type[PatchCombiner] = Average, get_boxes: Callable = get_boxes,
use_torch: bool = True, async_predict: bool = True, batch_size: int = None, **imops_kwargs):
"""
Divide an incoming array into patches of corresponding ``patch_size`` and ``stride`` and then combine
the predicted patches by aggregating the overlapping regions using the ``combiner`` - Average by default.
Expand Down Expand Up @@ -111,18 +112,20 @@ def wrapper(x, *args, **kwargs):
elif ((shape - local_size) < 0).any() or ((local_stride - shape + local_size) % local_stride).any():
raise ValueError('Input cannot be patched without remainder.')

if stream:
patches = predict(divide(x, local_size, local_stride, input_axis, get_boxes=get_boxes), *args, **kwargs)
divide_wrapper = make_batch if batch_size is not None else lambda x, batch_size: x
patches_wrapper = break_batch if batch_size is not None else lambda x: x

input_patches = divide_wrapper(divide(x, local_size, local_stride, input_axis, get_boxes=get_boxes), batch_size=batch_size)

if async_predict:
patches = AsyncPmap(predict, input_patches, *args, **kwargs)
patches.start()
else:
patches = pmap(
predict,
divide(x, local_size, local_stride, input_axis, get_boxes=get_boxes),
*args, **kwargs
)
patches = pmap(predict, input_patches, *args, **kwargs)

prediction = combine(
patches, extract(x.shape, input_axis), local_stride, axis,
combiner=combiner, get_boxes=get_boxes,
patches_wrapper(patches), extract(x.shape, input_axis), local_stride, axis,
combiner=combiner, get_boxes=get_boxes, use_torch=use_torch
)

if valid:
Expand Down
33 changes: 12 additions & 21 deletions tests/predict/test_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,20 @@ def batch_size(request):


@pytest.fixture(params=[False, True])
def stream(request):
def use_torch(request):
return request.param


def test_patches_grid(stream):
@pytest.fixture(params=[False, True])
def async_predict(request):
return request.param


def test_patches_grid(use_torch, async_predict, batch_size):
def check_equal(**kwargs):
assert_eq(x, patches_grid(**kwargs, stream=stream, axis=-1)(identity)(x))
predict = patches_grid(**kwargs, use_torch=use_torch, async_predict=async_predict, axis=-1, batch_size=batch_size)(lambda x: x + 1)
predict = add_extract_dims(1)(predict)
assert_eq(x + 1, predict(x))

x = np.random.randn(3, 23, 20, 27) * 10
check_equal(patch_size=10, stride=1, padding_values=0)
Expand All @@ -40,29 +47,13 @@ def check_equal(**kwargs):
check_equal(patch_size=15, stride=12, padding_values=None)


def test_divisible_patches(stream):
def test_divisible_patches():
def check_equal(**kwargs):
assert_eq(x, divisible_shape(divisible)(patches_grid(**kwargs, stream=stream)(identity))(x))
assert_eq(x, divisible_shape(divisible)(patches_grid(**kwargs)(identity))(x))

size = [80] * 3
stride = [20] * 3
divisible = [8] * 3
for shape in [(373, 302, 55), (330, 252, 67)]:
x = np.random.randn(*shape)
check_equal(patch_size=size, stride=stride)


@pytest.mark.skipif(sys.version_info < (3, 7), reason='Requires python3.7 or higher.')
def test_batched_patches_grid(batch_size):
from more_itertools import batched
from itertools import chain

def patch_predict(patch):
return patch + 1

def stream_predict(patches_generator):
return chain.from_iterable(pmap(patch_predict, map(np.array, batched(patches_generator, batch_size))))

x = np.random.randn(3, 23, 20, 27) * 10

assert_eq(x + 1, patches_grid(patch_size=(6, 8, 9), stride=(4, 3, 2), stream=True, axis=(-1, -2, -3))(stream_predict)(x))
12 changes: 11 additions & 1 deletion tests/test_itertools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np

from dpipe.im.utils import filter_mask
from dpipe.itertools import zip_equal, flatten, extract, negate_indices, head_tail, peek
from dpipe.itertools import zip_equal, flatten, extract, negate_indices, head_tail, peek, AsyncPmap


class TestItertools(unittest.TestCase):
Expand Down Expand Up @@ -65,3 +65,13 @@ def test_peek(self):
head, new_it = peek(self.make_iterable(it))
self.assertEqual(head, it[0])
self.assertListEqual(list(new_it), it)

def test_async_pmap(self):
    """``AsyncPmap`` yields ``func`` applied to every value, in order, and then stops."""
    def square(x):
        return x ** 2

    iterable = range(10)
    async_results = AsyncPmap(square, iterable)
    async_results.start()
    for i in iterable:
        # unittest assertion: not stripped under ``-O`` and reports both values on failure
        self.assertEqual(square(i), next(async_results))
    with self.assertRaises(StopIteration):
        next(async_results)

0 comments on commit 913e671

Please sign in to comment.