Skip to content

Commit

Permalink
updating cellpose args documentation and fixing cellpose import
Browse files Browse the repository at this point in the history
  • Loading branch information
JoOkuma committed Jun 20, 2024
1 parent 7f60ad6 commit a516ad0
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 7 deletions.
7 changes: 7 additions & 0 deletions ultrack/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@
logger = logging.getLogger()
logger.setLevel(logging.INFO)

# Cellpose and ultrack had conflicts due to torch/cuda leading to Segmentation Fault
# importing Cellpose first avoids the issue, https://github.com/royerlab/ultrack/issues/108
try:
from cellpose.models import Cellpose # noqa: F401
except (ImportError, ModuleNotFoundError):
pass

# ignoring small float32/64 zero flush warning
warnings.filterwarnings("ignore", message="The value of the smallest subnormal for")

Expand Down
28 changes: 21 additions & 7 deletions ultrack/imgproc/segmentation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import logging
from typing import Optional
from typing import Callable, Optional

import edt
import numpy as np
Expand Down Expand Up @@ -211,20 +212,33 @@ def inverted_edt(
return dist


def _maybe_wrap(wrapper_name: str) -> Callable:
"""Wraps function with cellpose model method if cellpose is available."""
try:
from cellpose.models import CellposeModel as _Cellpose
except ImportError:
return lambda x: x

return functools.wraps(getattr(_Cellpose, wrapper_name))


class Cellpose:
@_maybe_wrap("__init__")
def __init__(self, **kwargs) -> None:
"""See cellpose.models.Cellpose documentation for details."""
from cellpose.models import CellposeModel as _Cellpose
try:
from cellpose.models import CellposeModel as _Cellpose
except ImportError as e:
raise ImportError(
"Cellpose not found, please install it."
"See for instructions https://github.com/MouseLand/cellpose"
) from e

if "pretrained_model" not in kwargs and "model_type" not in kwargs:
kwargs["model_type"] = "cyto"

self.model = _Cellpose(**kwargs)

@_maybe_wrap("eval")
def __call__(self, image: ArrayLike, **kwargs) -> np.ndarray:
"""
Predicts image labels.
See cellpose.models.Cellpose.eval documentation for details.
"""
labels, _, _ = self.model.eval(image, **kwargs)
return labels

0 comments on commit a516ad0

Please sign in to comment.