Skip to content

Commit

Permalink
Merge pull request #71 from wolny/modelzoo-2dunet
Browse files Browse the repository at this point in the history
Add a single 2D U-Net trained on ovules to bioimage-io
  • Loading branch information
Adrian Wolny authored Jun 22, 2020
2 parents 4dac5a9 + f7562e1 commit f7cb39f
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
name: 2D UNet Arabidopsis Ovules
description: A 2D U-Net trained to predict the cell boundaries in confocal stacks of Arabidopsis ovules. Trained on z-slices of 3D confocal images.
cite:
- text: "Wolny, Adrian et al. Accurate and Versatile 3D Segmentation of Plant Tissues at Cellular Resolution. BioRxiv 2020."
doi: https://doi.org/10.1101/2020.01.17.910562
authors:
- Adrian Wolny;@bioimage-io
documentation: ../../README.md
tags: [unet3d, pytorch, arabidopsis, ovuls, cell membrane, segmentation, plant tissue]
license: MIT

format_version: 0.1.0
language: python
framework: pytorch

source: pytorch3dunet.unet3d.model.UNet2D
optional_kwargs:
in_channels: 1
out_channels: 1
layer_order: gcr # determines the order of operators in a single layer (crg - Conv3d+ReLU+GroupNorm)
f_maps: [32, 64, 128, 256] # initial number of feature maps
num_groups: 8 # number of groups in the groupnorm
final_sigmoid: true # apply element-wise nn.Sigmoid after the final 1x1x1 convolution, otherwise apply nn.Softmax
is_segmentation: true # don't touch, use postprocessing instead
testing: false # don't touch, use postprocessing instead

#test_input: test_input.npz
#test_output: test_output.npz
covers: [raw.png, pred.png]

inputs:
- name: raw
axes: bcyx
data_type: float32
data_range: [-inf, inf]
shape: [1, 1, 256, 256]

outputs:
- name: cell_boundaries
axes: bcyx
data_type: float32
data_range: [0, 1]
halo: [0, 0, 16, 16]
shape:
reference_input: raw
scale: [1, 1, 1, 1]
offset: [0, 0, 0, 0]

prediction:
preprocess:
- spec: https://github.com/bioimage-io/pytorch-bioimage-io/blob/f71b8ac598267de88cd39e5495abd93dcda1d0a4/specs/transformations/EnsureTorch.transformation.yaml
- spec: https://github.com/bioimage-io/pytorch-bioimage-io/blob/f71b8ac598267de88cd39e5495abd93dcda1d0a4/specs/transformations/Cast.transformation.yaml
kwargs: {dtype: float32}
- spec: https://github.com/bioimage-io/pytorch-bioimage-io/blob/f71b8ac598267de88cd39e5495abd93dcda1d0a4/specs/transformations/NormalizeZeroMeanUnitVariance.transformation.yaml
kwargs: {apply_to: [0]}
weights:
source: https://oc.embl.de/index.php/s/61s67Mg5VQy7dh9/download?path=%2FArabidopsis-Ovules%2F2dunet_bce_dice_ds2x&files=best_checkpoint.pytorch
hash: {md5: 47ee0d24991e758eab65f87b2fc22de1}
postprocess:
- spec: https://github.com/bioimage-io/pytorch-bioimage-io/blob/f71b8ac598267de88cd39e5495abd93dcda1d0a4/specs/transformations/Sigmoid.transformation.yaml
- spec: https://github.com/bioimage-io/pytorch-bioimage-io/blob/f71b8ac598267de88cd39e5495abd93dcda1d0a4/specs/transformations/EnsureNumpy.transformation.yaml

dependencies: conda:environment.yaml
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
8 changes: 4 additions & 4 deletions tests/test_bioimage-io/test_UNet3DArabidopsisOvules.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from io import BytesIO
from pathlib import Path

import h5py
import imageio
import numpy
import numpy as np
import pytest
import torch

from pybio.core.transformations import apply_transformations
from pybio.spec import load_model
from pybio.spec.utils import get_instance

from pytorch3dunet.unet3d.model import UNet3D


Expand Down Expand Up @@ -37,8 +37,8 @@ def test_Net3DArabidopsisOvules_forward(cache_path):
assert spec_path.exists(), spec_path
pybio_model = load_model(str(spec_path), cache_path=cache_path)
assert pybio_model.spec.outputs[0].shape.reference_input == "raw"
assert pybio_model.spec.outputs[0].shape.scale == (1, 1, 1, 1, 1)
assert pybio_model.spec.outputs[0].shape.offset == (0, 0, 0, 0, 0)
assert np.allclose(pybio_model.spec.outputs[0].shape.scale, (1, 1, 1, 1, 1))
assert np.allclose(pybio_model.spec.outputs[0].shape.offset, (0, 0, 0, 0, 0))

assert isinstance(pybio_model.spec.prediction.weights.source, BytesIO)
assert pybio_model.spec.test_input is not None
Expand Down

0 comments on commit f7cb39f

Please sign in to comment.