-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Loading status checks…
Merge pull request #196 from saalfeldlab/feature/streak-finder
Various tools for de-streaking
Showing
12 changed files
with
1,030 additions
and
2 deletions.
There are no files selected for viewing
122 changes: 122 additions & 0 deletions
122
render-app/src/main/java/org/janelia/alignment/destreak/StreakFinder.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,122 @@ | ||
package org.janelia.alignment.destreak; | ||
|
||
import ij.IJ; | ||
import ij.ImagePlus; | ||
import ij.process.FloatProcessor; | ||
import ij.process.ImageProcessor; | ||
|
||
import java.io.Serializable; | ||
|
||
/** | ||
* This class detects streaks in an image and returns a corresponding mask. | ||
* <p> | ||
* The finder first applies a derivative filter in the x-direction to detect vertical edges. Then, it applies a mean | ||
* filter in the y-direction to smooth out the edges in the y-direction. The resulting image is then thresholded | ||
* (from above and below) to create a mask of the streaks. Finally, an optional Gaussian blur is applied to the mask to | ||
* smooth it. The mask is 0 where there are no streaks and 255 where there are streaks. | ||
* <p> | ||
* There are three parameters that can be set: | ||
* <ul> | ||
* <li>meanFilterSize: the number of pixels to average in the y-direction (e.g., 0 means no averaging, 50 means averaging +/-50 pixels in y)</li> | ||
* <li>threshold: the threshold used to convert the streak mask to a binary mask</li> | ||
* <li>blurRadius: the radius of the Gaussian blur applied to the streak mask (0 means no smoothing)</li> | ||
* </ul> | ||
*/ | ||
public class StreakFinder implements Serializable { | ||
|
||
private final int meanFilterSize; | ||
private final double threshold; | ||
private final int blurRadius; | ||
|
||
public StreakFinder(final int meanFilterSize, final double threshold, final int blurRadius) { | ||
if (meanFilterSize < 0) { | ||
throw new IllegalArgumentException("meanFilterSize must be non-negative"); | ||
} | ||
if (threshold < 0) { | ||
throw new IllegalArgumentException("threshold must be non-negative"); | ||
} | ||
if (blurRadius < 0) { | ||
throw new IllegalArgumentException("blurRadius must be 0 (no blur) or positive"); | ||
} | ||
|
||
this.meanFilterSize = meanFilterSize; | ||
this.threshold = threshold; | ||
this.blurRadius = blurRadius; | ||
} | ||
|
||
public ImagePlus createStreakMask(final ImagePlus input) { | ||
ImageProcessor filtered = differenceFilterX(input.getProcessor()); | ||
filtered = meanFilterY(filtered, meanFilterSize); | ||
filtered = bidirectionalThreshold(filtered, threshold); | ||
|
||
final ImagePlus mask = new ImagePlus("Mask", filtered); | ||
if (blurRadius > 0) { | ||
IJ.run(mask, "Gaussian Blur...", String.format("sigma=%d", blurRadius)); | ||
} | ||
return mask; | ||
} | ||
|
||
private static ImageProcessor differenceFilterX(final ImageProcessor in) { | ||
final ImageProcessor out = new FloatProcessor(in.getWidth(), in.getHeight()); | ||
final int width = in.getWidth(); | ||
final int height = in.getHeight(); | ||
|
||
for (int y = 0; y < height; y++) { | ||
for (int x = 0; x < width; x++) { | ||
final float left = in.getf(projectPeriodically(x - 1, width), y); | ||
final float right = in.getf(projectPeriodically(x + 1, width), y); | ||
out.setf(x, y, (right - left) / 2); | ||
} | ||
} | ||
return out; | ||
} | ||
|
||
private static ImageProcessor meanFilterY(final ImageProcessor in, final int size) { | ||
final ImageProcessor out = new FloatProcessor(in.getWidth(), in.getHeight()); | ||
final int width = in.getWidth(); | ||
final int height = in.getHeight(); | ||
final int n = 2 * size + 1; | ||
|
||
for (int x = 0; x < width; x++) { | ||
// initialize running sum | ||
float sum = in.getf(x, 0); | ||
for (int y = 1; y <= size; y++) { | ||
sum += 2 * in.getf(x, y); | ||
} | ||
out.setf(x, 0, sum / n); | ||
|
||
// update running sum by adding the next value and subtracting the oldest value | ||
for (int y = 1; y < height; y++) { | ||
final float oldest = in.getf(x, projectPeriodically(y - size - 1, height)); | ||
final float newest = in.getf(x, projectPeriodically(y + size, height)); | ||
sum += newest - oldest; | ||
out.setf(x, y, sum / n); | ||
} | ||
} | ||
return out; | ||
} | ||
|
||
private static ImageProcessor bidirectionalThreshold(final ImageProcessor in, final double threshold) { | ||
final ImageProcessor out = new FloatProcessor(in.getWidth(), in.getHeight()); | ||
final int width = in.getWidth(); | ||
final int height = in.getHeight(); | ||
|
||
for (int y = 0; y < height; y++) { | ||
for (int x = 0; x < width; x++) { | ||
final float value = Math.abs(in.getf(x, y)); | ||
out.setf(x, y, (value > threshold) ? 255 : 0); | ||
} | ||
} | ||
return out; | ||
} | ||
|
||
private static int projectPeriodically(final int index, final int max) { | ||
if (index < 0) { | ||
return -index; | ||
} else if (index >= max) { | ||
return 2 * max - index - 2; | ||
} else { | ||
return index; | ||
} | ||
} | ||
} |
44 changes: 44 additions & 0 deletions
44
render-app/src/main/java/org/janelia/alignment/inpainting/AnisotropicDirection2D.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
package org.janelia.alignment.inpainting; | ||
|
||
import java.util.Random; | ||
|
||
/** | ||
* A statistic that yields a small perturbation of a given 2D direction for each sample. | ||
*/ | ||
public class AnisotropicDirection2D implements DirectionalStatistic { | ||
|
||
private final Random random; | ||
private final double[] primalAxis; | ||
private final double[] secondaryAxis; | ||
private final double perturbation; | ||
|
||
/** | ||
* Creates a new statistic with a random seed. | ||
*/ | ||
public AnisotropicDirection2D(final double[] primalAxis, final double perturbation) { | ||
this(primalAxis, perturbation, new Random()); | ||
} | ||
|
||
/** | ||
* Creates a new statistic with the given random number generator. | ||
* | ||
* @param random the random number generator to use | ||
*/ | ||
public AnisotropicDirection2D(final double[] primalAxis, final double perturbation, final Random random) { | ||
final double norm = Math.sqrt(primalAxis[0] * primalAxis[0] + primalAxis[1] * primalAxis[1]); | ||
this.primalAxis = new double[] { primalAxis[0] / norm, primalAxis[1] / norm }; | ||
this.secondaryAxis = new double[] { -primalAxis[1] / norm, primalAxis[0] / norm }; | ||
this.perturbation = perturbation; | ||
this.random = random; | ||
} | ||
|
||
@Override | ||
public void sample(final double[] direction) { | ||
// TODO: this should be a von Mises distribution instead of this homegrown implementation | ||
final int sign = random.nextBoolean() ? 1 : -1; | ||
final double eps = perturbation * random.nextGaussian(); | ||
final double norm = 1 + eps * eps; // because axes are orthonormal | ||
direction[0] = (sign * primalAxis[0] + eps * secondaryAxis[0]) / norm; | ||
direction[1] = (sign * primalAxis[1] + eps * secondaryAxis[1]) / norm; | ||
} | ||
} |
14 changes: 14 additions & 0 deletions
14
render-app/src/main/java/org/janelia/alignment/inpainting/DirectionalStatistic.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
package org.janelia.alignment.inpainting; | ||
|
||
/** | ||
* Interface for distributions that model the direction of a ray used in {@link RayCastingInpainter}. | ||
*/ | ||
public interface DirectionalStatistic { | ||
|
||
/** | ||
* Initializes the direction of the next ray. The array that is passed in is filled with the direction. | ||
* | ||
* @param direction the array in which to initialize the direction | ||
*/ | ||
void sample(double[] direction); | ||
} |
34 changes: 34 additions & 0 deletions
34
render-app/src/main/java/org/janelia/alignment/inpainting/RandomDirection2D.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
package org.janelia.alignment.inpainting; | ||
|
||
import java.util.Random; | ||
|
||
/** | ||
* A statistic that yields a completely random 2D direction for each sample. | ||
*/ | ||
public class RandomDirection2D implements DirectionalStatistic { | ||
|
||
private final Random random; | ||
|
||
/** | ||
* Creates a new statistic with a random seed. | ||
*/ | ||
public RandomDirection2D() { | ||
this(new Random()); | ||
} | ||
|
||
/** | ||
* Creates a new statistic with the given random number generator. | ||
* | ||
* @param random the random number generator to use | ||
*/ | ||
public RandomDirection2D(final Random random) { | ||
this.random = random; | ||
} | ||
|
||
@Override | ||
public void sample(final double[] direction) { | ||
final double angle = random.nextDouble() * 2 * Math.PI; | ||
direction[0] = Math.cos(angle); | ||
direction[1] = Math.sin(angle); | ||
} | ||
} |
38 changes: 38 additions & 0 deletions
38
render-app/src/main/java/org/janelia/alignment/inpainting/RandomDirection3D.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
package org.janelia.alignment.inpainting; | ||
|
||
import java.util.Random; | ||
|
||
/** | ||
* A statistic that yields a completely random 3D direction for each sample. | ||
*/ | ||
public class RandomDirection3D implements DirectionalStatistic { | ||
|
||
private final Random random; | ||
|
||
/** | ||
* Creates a new statistic with a random seed. | ||
*/ | ||
public RandomDirection3D() { | ||
this(new Random()); | ||
} | ||
|
||
/** | ||
* Creates a new statistic with the given random number generator. | ||
* | ||
* @param random the random number generator to use | ||
*/ | ||
public RandomDirection3D(final Random random) { | ||
this.random = random; | ||
} | ||
|
||
@Override | ||
public void sample(final double[] direction) { | ||
final double x = random.nextGaussian(); | ||
final double y = random.nextGaussian(); | ||
final double z = random.nextGaussian(); | ||
final double norm = Math.sqrt(x * x + y * y + z * z); | ||
direction[0] = x / norm; | ||
direction[1] = y / norm; | ||
direction[2] = z / norm; | ||
} | ||
} |
128 changes: 128 additions & 0 deletions
128
render-app/src/main/java/org/janelia/alignment/inpainting/RayCastingInpainter.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
package org.janelia.alignment.inpainting; | ||
|
||
import net.imglib2.RealInterval; | ||
|
||
import net.imglib2.Cursor; | ||
import net.imglib2.Interval; | ||
import net.imglib2.RandomAccessibleInterval; | ||
import net.imglib2.RealLocalizable; | ||
import net.imglib2.RealRandomAccess; | ||
import net.imglib2.interpolation.randomaccess.NLinearInterpolatorFactory; | ||
import net.imglib2.type.numeric.real.FloatType; | ||
import net.imglib2.view.Views; | ||
|
||
|
||
/** | ||
* Infer missing values in an image (up to 3D) by ray casting (which is equivalent to diffusion of image values). | ||
* <p> | ||
* This is adapted from the hotknife repository for testing purposes. | ||
*/ | ||
public class RayCastingInpainter { | ||
|
||
private final int nRays; | ||
private final long maxRayLength; | ||
private final DirectionalStatistic directionStatistic; | ||
|
||
private final double[] direction = new double[3]; | ||
private final Result result = new Result(); | ||
|
||
public RayCastingInpainter(final int nRays, final int maxInpaintingDiameter, final DirectionalStatistic directionStatistic) { | ||
this.nRays = nRays; | ||
this.maxRayLength = maxInpaintingDiameter; | ||
this.directionStatistic = directionStatistic; | ||
} | ||
|
||
private static boolean isInside(final RealLocalizable p, final RealInterval r) { | ||
for (int d = 0; d < p.numDimensions(); ++d) { | ||
final double l = p.getDoublePosition(d); | ||
if (l < r.realMin(d) || l > r.realMax(d)) { | ||
return false; | ||
} | ||
} | ||
return true; | ||
} | ||
|
||
/** | ||
* Inpaints missing values in an image (up to 3D) by casting rays in random directions and averaging the values of | ||
* the first non-masked pixel. | ||
* | ||
* @param img the image to inpaint | ||
* @param mask the mask | ||
*/ | ||
public void inpaint(final RandomAccessibleInterval<FloatType> img, final RandomAccessibleInterval<FloatType> mask) { | ||
final Cursor<FloatType> imgCursor = Views.iterable(img).localizingCursor(); | ||
|
||
final RealRandomAccess<FloatType> imageAccess = Views.interpolate(Views.extendBorder(img), new NLinearInterpolatorFactory<>()).realRandomAccess(); | ||
final RealRandomAccess<FloatType> maskAccess = Views.interpolate(Views.extendBorder(mask), new NLinearInterpolatorFactory<>()).realRandomAccess(); | ||
|
||
while (imgCursor.hasNext()) { | ||
final FloatType o = imgCursor.next(); | ||
final float m = maskAccess.setPositionAndGet(imgCursor).get(); | ||
if (m == 0.0) { | ||
// pixel not masked, no inpainting necessary | ||
continue; | ||
} | ||
|
||
double weightSum = 0; | ||
double valueSum = 0; | ||
|
||
// interpolate value by casting rays in random directions and averaging (weighted by distances) the | ||
// values of the first non-masked pixel | ||
for (int i = 0; i < nRays; ++i) { | ||
final Result result = castRay(maskAccess, mask, imgCursor); | ||
if (result != null) { | ||
final double weight = 1.0 / result.distance; | ||
weightSum += weight; | ||
final double value = imageAccess.setPositionAndGet(result.position).getRealDouble(); | ||
valueSum += value * weight; | ||
} | ||
} | ||
|
||
final float v = (float) (valueSum / weightSum); | ||
final float w = m / 255.0f; | ||
final float oldValue = o.get(); | ||
final float newValue = v * w + oldValue * (1 - w); | ||
o.set(newValue); | ||
} | ||
} | ||
|
||
/** | ||
* Casts a ray from the given position in a random direction until it hits a non-masked (i.e., non-NaN) pixel | ||
* or exits the image boundary. | ||
* | ||
* @param mask the mask indicating which pixels are masked (> 0) and which are not (0) | ||
* @param interval the interval of the image | ||
* @param position the position from which to cast the ray | ||
* @return the result of the ray casting or null if the ray exited the image boundary without hitting a | ||
* non-masked pixel | ||
*/ | ||
private Result castRay(final RealRandomAccess<FloatType> mask, final Interval interval, final RealLocalizable position) { | ||
mask.setPosition(position); | ||
directionStatistic.sample(direction); | ||
long steps = 0; | ||
|
||
while(true) { | ||
mask.move(direction); | ||
++steps; | ||
|
||
if (!isInside(mask, interval) || steps > maxRayLength) { | ||
// the ray exited the image boundaries without hitting a non-masked pixel | ||
return null; | ||
} | ||
|
||
final float value = mask.get().get(); | ||
if (value < 1.0) { | ||
// the ray reached a non-masked pixel | ||
mask.localize(result.position); | ||
result.distance = steps; | ||
return result; | ||
} | ||
} | ||
} | ||
|
||
|
||
private static class Result { | ||
public double[] position = new double[3]; | ||
public double distance = 0; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
71 changes: 71 additions & 0 deletions
71
render-app/src/test/java/org/janelia/alignment/destreak/StreakFinderTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
package org.janelia.alignment.destreak; | ||
|
||
import ij.ImageJ; | ||
import ij.ImagePlus; | ||
import ij.process.FloatProcessor; | ||
import ij.process.ImageProcessor; | ||
import net.imglib2.img.Img; | ||
import net.imglib2.img.display.imagej.ImageJFunctions; | ||
import net.imglib2.type.numeric.real.FloatType; | ||
import org.janelia.alignment.inpainting.RandomDirection2D; | ||
import org.janelia.alignment.inpainting.RayCastingInpainter; | ||
|
||
public class StreakFinderTest { | ||
public static void main(final String[] args) { | ||
final String srcPath = "/home/innerbergerm@hhmi.org/big-data/streak-correction/jrc_mus-liver-zon-3/z00032-0-0-1.png"; | ||
final StreakFinder finder = new StreakFinder(100, 5.0, 3); | ||
// final StreakCorrector corrector = new SmoothMaskStreakCorrector(12, 6161, 8190, 10, 10, 0); | ||
final RayCastingInpainter inpainter = new RayCastingInpainter(128, 100, new RandomDirection2D()); | ||
|
||
final long start = System.currentTimeMillis(); | ||
final ImagePlus original = new ImagePlus(srcPath); | ||
final ImagePlus mask = finder.createStreakMask(original); | ||
// final ImagePlus corrected = streakCorrectFourier(corrector, original, mask); | ||
final ImagePlus corrected = streakCorrectInpainting(inpainter, original, mask); | ||
System.out.println("Processing time: " + (System.currentTimeMillis() - start) + "ms"); | ||
|
||
new ImageJ(); | ||
mask.show(); | ||
original.show(); | ||
corrected.show(); | ||
} | ||
|
||
private static ImagePlus streakCorrectFourier( | ||
final StreakCorrector corrector, | ||
final ImagePlus original, | ||
final ImagePlus mask) { | ||
|
||
final ImagePlus corrected = original.duplicate(); | ||
corrected.setTitle("Corrected"); | ||
corrector.process(corrected.getProcessor(), 1.0); | ||
|
||
final ImageProcessor proc = corrected.getProcessor(); | ||
final ImageProcessor maskProc = mask.getProcessor(); | ||
final ImageProcessor originalProc = original.getProcessor(); | ||
for (int i = 0; i < corrected.getWidth() * corrected.getHeight(); i++) { | ||
final float lambda = maskProc.getf(i) / 255.0f; | ||
final float mergedValue = originalProc.getf(i) * (1 - lambda) + proc.getf(i) * lambda; | ||
proc.setf(i, mergedValue); | ||
} | ||
|
||
return corrected; | ||
} | ||
|
||
private static ImagePlus streakCorrectInpainting( | ||
final RayCastingInpainter inpainter, | ||
final ImagePlus original, | ||
final ImagePlus mask) { | ||
|
||
final FloatProcessor correctedFp = original.getProcessor().convertToFloatProcessor(); | ||
final ImagePlus corrected = new ImagePlus("Corrected", correctedFp); | ||
final Img<FloatType> correctedImg = ImageJFunctions.wrapFloat(corrected); | ||
|
||
final FloatProcessor maskFp = mask.getProcessor().convertToFloatProcessor(); | ||
final ImagePlus maskIp = new ImagePlus("Mask", maskFp); | ||
final Img<FloatType> maskImg = ImageJFunctions.wrapFloat(maskIp); | ||
|
||
inpainter.inpaint(correctedImg, maskImg); | ||
|
||
return corrected; | ||
} | ||
} |
45 changes: 45 additions & 0 deletions
45
...java-client/src/main/java/org/janelia/render/client/parameter/StreakFinderParameters.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
package org.janelia.render.client.parameter; | ||
|
||
import com.beust.jcommander.Parameter; | ||
import org.janelia.alignment.destreak.StreakFinder; | ||
|
||
import java.io.Serializable; | ||
|
||
/** | ||
* Parameters for streak finding with a {@link StreakFinder}. | ||
*/ | ||
public class StreakFinderParameters implements Serializable { | ||
@Parameter( | ||
names = "--meanFilterSize", | ||
description = "Number of pixels to average in the positive and negative y-direction", | ||
required = true) | ||
public int meanFilterSize; | ||
|
||
@Parameter( | ||
names = "--threshold", | ||
description = "Threshold used to convert the streak mask to a binary mask", | ||
required = true) | ||
public double threshold; | ||
|
||
@Parameter( | ||
names = "--blurRadius", | ||
description = "Radius of the Gaussian blur applied to the streak mask", | ||
required = true) | ||
public int blurRadius; | ||
|
||
public StreakFinder createStreakFinder() { | ||
return new StreakFinder(meanFilterSize, threshold, blurRadius); | ||
} | ||
|
||
public void validate() { | ||
if (meanFilterSize < 0) { | ||
throw new IllegalArgumentException("meanFilterSize must be non-negative"); | ||
} | ||
if (threshold < 0) { | ||
throw new IllegalArgumentException("threshold must be non-negative"); | ||
} | ||
if (blurRadius < 0) { | ||
throw new IllegalArgumentException("blurRadius must be 0 (no blur) or positive"); | ||
} | ||
} | ||
} |
86 changes: 86 additions & 0 deletions
86
render-ws-java-client/src/test/java/org/janelia/render/client/InpaintingTest.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
package org.janelia.render.client; | ||
|
||
import ij.ImageJ; | ||
import ij.ImagePlus; | ||
import ij.process.FloatProcessor; | ||
import ij.process.ImageProcessor; | ||
import net.imglib2.img.Img; | ||
import net.imglib2.img.display.imagej.ImageJFunctions; | ||
import net.imglib2.type.numeric.real.FloatType; | ||
import org.janelia.alignment.inpainting.AnisotropicDirection2D; | ||
import org.janelia.alignment.inpainting.DirectionalStatistic; | ||
import org.janelia.alignment.inpainting.RayCastingInpainter; | ||
|
||
public class InpaintingTest { | ||
|
||
private static final int N_RAYS = 256; | ||
private static final int MAX_INPAINTING_DIAMETER = 100; | ||
private static final double THRESHOLD = 20.0; | ||
|
||
private static final int Y_MIN = 3826; | ||
private static final int Y_MAX = 3856; | ||
|
||
public static void main(final String[] args) { | ||
final String srcPath = "/home/innerbergerm@hhmi.org/big-data/streak-correction/jrc_P3-E2-D1-Lip4-19/z14765-0-0-0.png"; | ||
final ImagePlus original = new ImagePlus(srcPath); | ||
final DirectionalStatistic directionStatistic = new AnisotropicDirection2D(new double[] {0, 1}, 0.5); | ||
final RayCastingInpainter inpainter = new RayCastingInpainter(N_RAYS, MAX_INPAINTING_DIAMETER, directionStatistic); | ||
|
||
// final ImagePlus mask = threshold(original.getProcessor(), THRESHOLD); | ||
final ImagePlus mask = bandMask(original.getProcessor(), Y_MIN, Y_MAX); | ||
|
||
final long start = System.currentTimeMillis(); | ||
final ImagePlus corrected = streakCorrectInpainting(inpainter, original, mask); | ||
System.out.println("Processing time: " + (System.currentTimeMillis() - start) + "ms"); | ||
|
||
new ImageJ(); | ||
original.show(); | ||
mask.show(); | ||
corrected.show(); | ||
} | ||
|
||
private static ImagePlus threshold(final ImageProcessor in, final double threshold) { | ||
final ImageProcessor out = new FloatProcessor(in.getWidth(), in.getHeight()); | ||
final int width = in.getWidth(); | ||
final int height = in.getHeight(); | ||
|
||
for (int y = 0; y < height; y++) { | ||
for (int x = 0; x < width; x++) { | ||
final float value = in.getf(x, y); | ||
out.setf(x, y, (value < threshold) ? 255 : 0); | ||
} | ||
} | ||
return new ImagePlus("Mask", out); | ||
} | ||
|
||
private static ImagePlus bandMask(final ImageProcessor in, final int yMin, final int yMax) { | ||
final int width = in.getWidth(); | ||
final int height = in.getHeight(); | ||
final ImageProcessor out = new FloatProcessor(width, height); | ||
|
||
for (int y = yMin; y <= yMax; y++) { | ||
for (int x = 0; x < width; x++) { | ||
out.setf(x, y, 255); | ||
} | ||
} | ||
return new ImagePlus("Mask", out); | ||
} | ||
|
||
private static ImagePlus streakCorrectInpainting( | ||
final RayCastingInpainter inpainter, | ||
final ImagePlus original, | ||
final ImagePlus mask) { | ||
|
||
final FloatProcessor correctedFp = original.getProcessor().convertToFloatProcessor(); | ||
final ImagePlus corrected = new ImagePlus("Corrected", correctedFp); | ||
final Img<FloatType> correctedImg = ImageJFunctions.wrapFloat(corrected); | ||
|
||
final FloatProcessor maskFp = mask.getProcessor().convertToFloatProcessor(); | ||
final ImagePlus maskIp = new ImagePlus("Mask", maskFp); | ||
final Img<FloatType> maskImg = ImageJFunctions.wrapFloat(maskIp); | ||
|
||
inpainter.inpaint(correctedImg, maskImg); | ||
|
||
return corrected; | ||
} | ||
} |
71 changes: 71 additions & 0 deletions
71
render-ws-spark-client/src/main/java/org/janelia/render/client/JrcP3E2D1Lip419Inpainter.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
package org.janelia.render.client; | ||
|
||
import ij.ImageJ; | ||
import net.imglib2.RandomAccessibleInterval; | ||
import net.imglib2.converter.Converters; | ||
import net.imglib2.img.Img; | ||
import net.imglib2.img.display.imagej.ImageJFunctions; | ||
import net.imglib2.loops.LoopBuilder; | ||
import net.imglib2.type.numeric.integer.UnsignedByteType; | ||
import net.imglib2.type.numeric.real.FloatType; | ||
import net.imglib2.util.Util; | ||
import org.janelia.alignment.inpainting.DirectionalStatistic; | ||
import org.janelia.alignment.inpainting.RandomDirection3D; | ||
import org.janelia.alignment.inpainting.RayCastingInpainter; | ||
import org.janelia.render.client.spark.n5.N5Client; | ||
import org.janelia.saalfeldlab.n5.N5FSReader; | ||
import org.janelia.saalfeldlab.n5.N5Reader; | ||
import org.janelia.saalfeldlab.n5.imglib2.N5Utils; | ||
|
||
public class JrcP3E2D1Lip419Inpainter { | ||
|
||
// path to the N5 container containing the data and a mask | ||
public static final String N5_BASE_PATH = System.getenv("HOME") + "/big-data/streak-correction/jrc_P3-E2-D1-Lip4-19/broken_layers.n5"; | ||
|
||
public static void main(final String[] args) { | ||
// saveSubStack(); | ||
doInpainting(); | ||
} | ||
|
||
@SuppressWarnings("unused") | ||
private static void saveSubStack() { | ||
final String[] n5ClientArgs = { | ||
"--baseDataUrl", "http://renderer-dev.int.janelia.org:8080/render-ws/v1", | ||
"--owner", "fibsem", | ||
"--project", "jrc_P3_E2_D1_Lip4_19", | ||
"--stack", "v1_acquire_trimmed_align_v2", | ||
"--n5Path", N5_BASE_PATH, | ||
"--n5Dataset", "tissue", | ||
"--tileWidth", "2048", | ||
"--tileHeight", "2048", | ||
"--blockSize", "512,512,128", | ||
"--minZ", "14763", | ||
"--maxZ", "14805", | ||
}; | ||
|
||
N5Client.main(n5ClientArgs); | ||
} | ||
|
||
@SuppressWarnings("unused") | ||
private static void doInpainting() { | ||
final DirectionalStatistic directionalStatistic = new RandomDirection3D(); | ||
final RayCastingInpainter inpainter = new RayCastingInpainter(64, 50, directionalStatistic); | ||
|
||
try (final N5Reader n5 = new N5FSReader(N5_BASE_PATH)) { | ||
final RandomAccessibleInterval<UnsignedByteType> tissue = N5Utils.open(n5, "tissue_crop"); | ||
final RandomAccessibleInterval<UnsignedByteType> mask = N5Utils.open(n5, "mask_crop_v2"); | ||
|
||
// convert to float (and copy tissue) | ||
final Img<FloatType> tissueFloat = Util.getSuitableImgFactory(tissue, new FloatType()).create(tissue); | ||
LoopBuilder.setImages(tissue, tissueFloat).forEachPixel((i, o) -> o.setReal(i.getRealDouble())); | ||
final RandomAccessibleInterval<FloatType> maskFloat = Converters.convert(mask, (i, o) -> o.setReal(i.getRealDouble()), new FloatType()); | ||
|
||
inpainter.inpaint(tissueFloat, maskFloat); | ||
|
||
new ImageJ(); | ||
ImageJFunctions.show(tissueFloat, "Tissue"); | ||
ImageJFunctions.show(maskFloat, "Mask"); | ||
} | ||
} | ||
|
||
} |
371 changes: 371 additions & 0 deletions
371
...client/src/main/java/org/janelia/render/client/spark/destreak/StreakStatisticsClient.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,371 @@ | ||
package org.janelia.render.client.spark.destreak; | ||
|
||
import com.beust.jcommander.Parameter; | ||
import com.beust.jcommander.ParametersDelegate; | ||
import ij.ImagePlus; | ||
import ij.process.ImageProcessor; | ||
import mpicbg.models.CoordinateTransform; | ||
import mpicbg.models.CoordinateTransformList; | ||
import net.imglib2.img.Img; | ||
import net.imglib2.img.array.ArrayImgs; | ||
import net.imglib2.type.numeric.real.DoubleType; | ||
import net.imglib2.view.Views; | ||
import org.apache.spark.SparkConf; | ||
import org.apache.spark.api.java.JavaSparkContext; | ||
import org.apache.spark.broadcast.Broadcast; | ||
import org.janelia.alignment.destreak.StreakFinder; | ||
import org.janelia.alignment.loader.ImageLoader; | ||
import org.janelia.alignment.spec.Bounds; | ||
import org.janelia.alignment.spec.ResolvedTileSpecCollection; | ||
import org.janelia.alignment.spec.TileSpec; | ||
import org.janelia.alignment.spec.stack.StackMetaData; | ||
import org.janelia.alignment.util.ImageProcessorCache; | ||
import org.janelia.render.client.ClientRunner; | ||
import org.janelia.render.client.RenderDataClient; | ||
import org.janelia.render.client.parameter.CommandLineParameters; | ||
import org.janelia.render.client.parameter.RenderWebServiceParameters; | ||
import org.janelia.render.client.parameter.StreakFinderParameters; | ||
import org.janelia.render.client.parameter.ZRangeParameters; | ||
import org.janelia.render.client.spark.LogUtilities; | ||
import org.janelia.saalfeldlab.n5.GzipCompression; | ||
import org.janelia.saalfeldlab.n5.N5Writer; | ||
import org.janelia.saalfeldlab.n5.imglib2.N5Utils; | ||
import org.janelia.saalfeldlab.n5.zarr.N5ZarrWriter; | ||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
import scala.Tuple2; | ||
|
||
import java.io.IOException; | ||
import java.io.Serializable; | ||
import java.nio.file.Paths; | ||
import java.util.ArrayList; | ||
import java.util.Arrays; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.function.Function; | ||
import java.util.stream.Collectors; | ||
import java.util.stream.IntStream; | ||
|
||
|
||
/** | ||
* This client computes statistics for streaks in a stack of images by accumulating the values of a streak mask over | ||
* subregions of each layer. | ||
*/ | ||
public class StreakStatisticsClient implements Serializable { | ||
|
||
public static class Parameters extends CommandLineParameters { | ||
|
||
@ParametersDelegate | ||
public RenderWebServiceParameters renderWeb = new RenderWebServiceParameters(); | ||
|
||
@ParametersDelegate | ||
public ZRangeParameters zRange = new ZRangeParameters(); | ||
|
||
@ParametersDelegate | ||
public StreakFinderParameters streakFinder = new StreakFinderParameters(); | ||
|
||
@Parameter(names = "--stack", description = "Stack to pull image and transformation data from", required = true) | ||
public String stack; | ||
|
||
@Parameter(names = "--output", description = "Output file path", required = true) | ||
public String outputPath; | ||
|
||
@Parameter(names = "--nCells", description = "Number of cells to use in x and y directions, e.g., 5x3", required = true) | ||
public String cells; | ||
|
||
private int nX = 0; | ||
private int nY = 0; | ||
|
||
public void validate() { | ||
streakFinder.validate(); | ||
try { | ||
final String[] nCells = cells.split("x"); | ||
nX = Integer.parseInt(nCells[0]); | ||
nY = Integer.parseInt(nCells[1]); | ||
if (nX < 1 || nY < 1) { | ||
throw new IllegalArgumentException("nCells must be positive"); | ||
} | ||
} catch (final NumberFormatException e) { | ||
throw new IllegalArgumentException("nCells must be in the format 'NxM'"); | ||
} | ||
} | ||
|
||
public int nCellsX() { | ||
if (nX == 0) { | ||
validate(); | ||
} | ||
return nX; | ||
} | ||
|
||
public int nCellsY() { | ||
if (nY == 0) { | ||
validate(); | ||
} | ||
return nY; | ||
} | ||
} | ||
|
||
public static void main(final String[] args) { | ||
final ClientRunner clientRunner = new ClientRunner(args) { | ||
@Override | ||
public void runClient(final String[] args) throws Exception { | ||
|
||
final Parameters parameters = new Parameters(); | ||
parameters.parse(args); | ||
LOG.info("runClient: entry, parameters={}", parameters); | ||
parameters.validate(); | ||
|
||
final StreakStatisticsClient client = new StreakStatisticsClient(parameters); | ||
client.compileStreakStatistics(); | ||
} | ||
}; | ||
clientRunner.run(); | ||
} | ||
|
||
|
||
private static final Logger LOG = LoggerFactory.getLogger(StreakStatisticsClient.class); | ||
private final Parameters parameters; | ||
|
||
public StreakStatisticsClient(final Parameters parameters) { | ||
this.parameters = parameters; | ||
} | ||
|
||
public void compileStreakStatistics() throws IOException { | ||
final SparkConf conf = new SparkConf().setAppName("StreakStatisticsClient"); | ||
|
||
try (final JavaSparkContext sparkContext = new JavaSparkContext(conf)) { | ||
final String sparkAppId = sparkContext.getConf().getAppId(); | ||
final String executorsJson = LogUtilities.getExecutorsApiJson(sparkAppId); | ||
|
||
LOG.info("run: appId is {}, executors data is {}", sparkAppId, executorsJson); | ||
|
||
compileStreakStatistics(sparkContext); | ||
} | ||
} | ||
|
||
private void compileStreakStatistics(final JavaSparkContext sparkContext) throws IOException { | ||
// get some metadata and broadcast variables accessed by all workers | ||
final Broadcast<Parameters> bcParameters = sparkContext.broadcast(parameters); | ||
final StackMetaData stackMetaData = getMetaData(); | ||
final Broadcast<Bounds> bounds = sparkContext.broadcast(parameters.zRange.overrideBounds(stackMetaData.getStackBounds())); | ||
final List<Double> zValues = IntStream.range(bounds.value().getMinZ().intValue(), bounds.value().getMaxZ().intValue() + 1) | ||
.boxed().map(Double::valueOf).collect(Collectors.toList()); | ||
|
||
LOG.info("run: fetched {} z-values for stack {}", zValues.size(), parameters.stack); | ||
|
||
// do the computation in a distributed way | ||
final List<Tuple2<Double, LayerResults>> result = sparkContext.parallelize(zValues) | ||
.mapToPair(z -> new Tuple2<>(z, pullTileSpecs(bcParameters.value(), z))) | ||
.mapValues(tileSpecs -> computeStreakStatisticsForLayer(bcParameters.value(), bounds.value(), tileSpecs)) | ||
.collect(); | ||
|
||
// convert to image and store on disk - list needs to be copied since the list returned by spark is not sortable | ||
final List<Tuple2<Double, LayerResults>> sortedResults = new ArrayList<>(result); | ||
sortedResults.sort((a, b) -> a._1.compareTo(b._1)); | ||
final Img<DoubleType> data = statisticsToImage(sortedResults); | ||
final Img<DoubleType> min = minMaxToImage(sortedResults, LayerResults::getMin); | ||
final Img<DoubleType> max = minMaxToImage(sortedResults, LayerResults::getMax); | ||
storeData(data, min, max, stackMetaData); | ||
} | ||
|
||
private StackMetaData getMetaData() throws IOException { | ||
final RenderDataClient dataClient = parameters.renderWeb.getDataClient(); | ||
return dataClient.getStackMetaData(parameters.stack); | ||
} | ||
|
||
private static ResolvedTileSpecCollection pullTileSpecs(final Parameters parameters, final double z) throws IOException { | ||
final RenderDataClient dataClient = parameters.renderWeb.getDataClient(); | ||
return dataClient.getResolvedTiles(parameters.stack, z); | ||
} | ||
|
||
private Img<DoubleType> statisticsToImage(final List<Tuple2<Double, LayerResults>> zLayerToResults) { | ||
final double[] fullData = new double[parameters.nCellsX() * parameters.nCellsY() * zLayerToResults.size()]; | ||
final Img<DoubleType> data = ArrayImgs.doubles(fullData, parameters.nCellsX(), parameters.nCellsY(), zLayerToResults.size()); | ||
final long[] position = new long[3]; | ||
int z = 0; | ||
|
||
for (final Tuple2<Double, LayerResults> zLayerToResult : zLayerToResults) { | ||
final double[][] layerStatistics = zLayerToResult._2.getStatistics(); | ||
|
||
position[2] = z++; | ||
for (int x = 0; x < parameters.nCellsX(); x++) { | ||
position[0] = x; | ||
for (int y = 0; y < parameters.nCellsY(); y++) { | ||
position[1] = y; | ||
data.getAt(position).set(layerStatistics[x][y]); | ||
} | ||
} | ||
} | ||
return data; | ||
} | ||
|
||
private Img<DoubleType> minMaxToImage( | ||
final List<Tuple2<Double, LayerResults>> zLayerToResults, | ||
final Function<LayerResults, double[]> minMaxExtractor | ||
) { | ||
final double[] flattenedMinMax = zLayerToResults.stream() | ||
.map(Tuple2::_2) | ||
.map(minMaxExtractor) | ||
.flatMapToDouble(Arrays::stream) | ||
.toArray(); | ||
return ArrayImgs.doubles(flattenedMinMax, 2, zLayerToResults.size()); | ||
} | ||
|
||
private void storeData( | ||
final Img<DoubleType> data, | ||
final Img<DoubleType> layerMin, | ||
final Img<DoubleType> layerMax, | ||
final StackMetaData stackMetaData | ||
) { | ||
// transpose data because images are F-order and python expects C-order | ||
final String group = Paths.get(parameters.renderWeb.project, parameters.stack).toString(); | ||
final int zChunkSize = Math.min(1000, (int) data.dimension(2)); | ||
final int[] dataChunkSize = new int[] {zChunkSize, parameters.nCellsY(), parameters.nCellsX()}; | ||
final int[] minMaxChunkSize = new int[] {2, zChunkSize}; | ||
|
||
final Bounds stackBounds = stackMetaData.getStackBounds(); | ||
final double[] min = new double[3]; | ||
min[0] = stackBounds.getMinX(); | ||
min[1] = stackBounds.getMinY(); | ||
min[2] = stackBounds.getMinZ(); | ||
|
||
final double[] max = new double[3]; | ||
max[0] = stackBounds.getMaxX(); | ||
max[1] = stackBounds.getMaxY(); | ||
max[2] = stackBounds.getMaxZ(); | ||
|
||
try (final N5Writer n5Writer = new N5ZarrWriter(parameters.outputPath)) { | ||
n5Writer.createGroup(group); | ||
N5Utils.save(Views.permute(data, 0, 2), n5Writer, Paths.get(group, "statistics").toString(), dataChunkSize, new GzipCompression()); | ||
N5Utils.save(layerMin, n5Writer, Paths.get(group, "layer_min").toString(), minMaxChunkSize, new GzipCompression()); | ||
N5Utils.save(layerMax, n5Writer, Paths.get(group, "layer_max").toString(), minMaxChunkSize, new GzipCompression()); | ||
|
||
n5Writer.setAttribute(group, "StackBounds", Map.of("min", min, "max", max)); | ||
n5Writer.setAttribute(group, "Resolution_nm", stackMetaData.getCurrentResolutionValues()); | ||
final Map<String, Double> runParameters = Map.of( | ||
"threshold", parameters.streakFinder.threshold, | ||
"meanFilterSize", (double) parameters.streakFinder.meanFilterSize, | ||
"blurRadius", (double) parameters.streakFinder.blurRadius); | ||
n5Writer.setAttribute(group, "RunParameters", runParameters); | ||
} | ||
} | ||
|
||
private static LayerResults computeStreakStatisticsForLayer( | ||
final Parameters parameters, | ||
final Bounds stackBounds, | ||
final ResolvedTileSpecCollection layerTiles) { | ||
|
||
LOG.info("computeStreakStatisticsForLayer: processing {} tiles", layerTiles.getTileSpecs().size()); | ||
final StreakAccumulator accumulator = new StreakAccumulator(stackBounds, parameters.nCellsX(), parameters.nCellsY()); | ||
final ImageProcessorCache cache = ImageProcessorCache.DISABLED_CACHE; | ||
|
||
for (final TileSpec tileSpec : layerTiles.getTileSpecs()) { | ||
LOG.debug("computeStreakStatisticsForLayer: processing tile {}", tileSpec.getTileId()); | ||
|
||
final ImageProcessor imp = cache.get(tileSpec.getImagePath(), 0, false, false, ImageLoader.LoaderType.H5_SLICE, null); | ||
final ImagePlus image = new ImagePlus(tileSpec.getTileId(), imp); | ||
if (image.getProcessor() == null) { | ||
LOG.warn("computeStreakStatisticsForLayer: could not load image for tile {}", tileSpec.getTileId()); | ||
continue; | ||
} | ||
|
||
final StreakFinder streakFinder = parameters.streakFinder.createStreakFinder(); | ||
final ImagePlus mask = streakFinder.createStreakMask(image); | ||
addStreakStatisticsForSingleMask(accumulator, mask, tileSpec); | ||
} | ||
|
||
layerTiles.recalculateBoundingBoxes(); | ||
final Bounds layerBounds = layerTiles.toBounds(); | ||
final double[] min = new double[] {layerBounds.getMinX(), layerBounds.getMinY()}; | ||
final double[] max = new double[] {layerBounds.getMaxX(), layerBounds.getMaxY()}; | ||
|
||
return new LayerResults(accumulator.getResults(), min, max); | ||
} | ||
|
||
private static void addStreakStatisticsForSingleMask(final StreakAccumulator accumulator, final ImagePlus mask, final TileSpec tileSpec) { | ||
final double[] position = new double[2]; | ||
final CoordinateTransformList<CoordinateTransform> transformList = tileSpec.getTransformList(); | ||
for (int x = 0; x < mask.getWidth(); x++) { | ||
for (int y = 0; y < mask.getHeight(); y++) { | ||
position[0] = x; | ||
position[1] = y; | ||
transformList.applyInPlace(position); | ||
accumulator.addValue(mask.getProcessor().getf(x, y), position[0], position[1]); | ||
} | ||
} | ||
} | ||
|
||
|
||
private static class StreakAccumulator { | ||
private final int nX; | ||
private final int nY; | ||
|
||
private final int minX; | ||
private final int minY; | ||
private final int width; | ||
private final int height; | ||
|
||
private final double[][] sum; | ||
private final long[][] counts; | ||
|
||
public StreakAccumulator(final Bounds layerBounds, final int nX, final int nY) { | ||
this.nX = nX; | ||
this.nY = nY; | ||
|
||
this.minX = layerBounds.getMinX().intValue(); | ||
this.minY = layerBounds.getMinY().intValue(); | ||
this.width = layerBounds.getWidth(); | ||
this.height = layerBounds.getHeight(); | ||
|
||
sum = new double[nX][nY]; | ||
counts = new long[nX][nY]; | ||
} | ||
|
||
public void addValue(final double value, final double x, final double y) { | ||
final int i = (int) (nX * (x - minX) / width); | ||
final int j = (int) (nY * (y - minY) / height); | ||
sum[i][j] += value; | ||
counts[i][j]++; | ||
} | ||
|
||
public double[][] getResults() { | ||
final double[][] results = new double[nX][nY]; | ||
for (int i = 0; i < nX; i++) { | ||
for (int j = 0; j < nY; j++) { | ||
// account for the fact that the mask values are in [0, 255], with 255 indicating a streak | ||
if (counts[i][j] == 0) { | ||
results[i][j] = 0; | ||
} else { | ||
results[i][j] = sum[i][j] / (255 * counts[i][j]); | ||
} | ||
} | ||
} | ||
return results; | ||
} | ||
} | ||
|
||
|
||
private static class LayerResults implements Serializable{ | ||
final private double[][] statistics; | ||
final private double[] min; | ||
final private double[] max; | ||
|
||
public LayerResults(final double[][] statistics, final double[] min, final double[] max) { | ||
this.statistics = statistics; | ||
this.min = min; | ||
this.max = max; | ||
} | ||
|
||
public double[][] getStatistics() { | ||
return statistics; | ||
} | ||
|
||
public double[] getMin() { | ||
return min; | ||
} | ||
|
||
public double[] getMax() { | ||
return max; | ||
} | ||
} | ||
} |