Skip to content

Commit

Permalink
Merge pull request #196 from saalfeldlab/feature/streak-finder
Browse files Browse the repository at this point in the history
Various tools for de-streaking
minnerbe authored Jan 7, 2025

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
2 parents 9f9e41c + f3b6e17 commit c1c8906
Showing 12 changed files with 1,030 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
package org.janelia.alignment.destreak;

import ij.IJ;
import ij.ImagePlus;
import ij.process.FloatProcessor;
import ij.process.ImageProcessor;

import java.io.Serializable;

/**
* This class detects streaks in an image and returns a corresponding mask.
* <p>
* The finder first applies a derivative filter in the x-direction to detect vertical edges. Then, it applies a mean
* filter in the y-direction to smooth out the edges in the y-direction. The resulting image is then thresholded
* (from above and below) to create a mask of the streaks. Finally, an optional Gaussian blur is applied to the mask to
* smooth it. The mask is 0 where there are no streaks and 255 where there are streaks.
* <p>
* There are three parameters that can be set:
* <ul>
* <li>meanFilterSize: the number of pixels to average in the y-direction (e.g., 0 means no averaging, 50 means averaging +/-50 pixels in y)</li>
* <li>threshold: the threshold used to convert the streak mask to a binary mask</li>
* <li>blurRadius: the radius of the Gaussian blur applied to the streak mask (0 means no smoothing)</li>
* </ul>
*/
public class StreakFinder implements Serializable {

private final int meanFilterSize;
private final double threshold;
private final int blurRadius;

public StreakFinder(final int meanFilterSize, final double threshold, final int blurRadius) {
if (meanFilterSize < 0) {
throw new IllegalArgumentException("meanFilterSize must be non-negative");
}
if (threshold < 0) {
throw new IllegalArgumentException("threshold must be non-negative");
}
if (blurRadius < 0) {
throw new IllegalArgumentException("blurRadius must be 0 (no blur) or positive");
}

this.meanFilterSize = meanFilterSize;
this.threshold = threshold;
this.blurRadius = blurRadius;
}

public ImagePlus createStreakMask(final ImagePlus input) {
ImageProcessor filtered = differenceFilterX(input.getProcessor());
filtered = meanFilterY(filtered, meanFilterSize);
filtered = bidirectionalThreshold(filtered, threshold);

final ImagePlus mask = new ImagePlus("Mask", filtered);
if (blurRadius > 0) {
IJ.run(mask, "Gaussian Blur...", String.format("sigma=%d", blurRadius));
}
return mask;
}

private static ImageProcessor differenceFilterX(final ImageProcessor in) {
final ImageProcessor out = new FloatProcessor(in.getWidth(), in.getHeight());
final int width = in.getWidth();
final int height = in.getHeight();

for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
final float left = in.getf(projectPeriodically(x - 1, width), y);
final float right = in.getf(projectPeriodically(x + 1, width), y);
out.setf(x, y, (right - left) / 2);
}
}
return out;
}

private static ImageProcessor meanFilterY(final ImageProcessor in, final int size) {
final ImageProcessor out = new FloatProcessor(in.getWidth(), in.getHeight());
final int width = in.getWidth();
final int height = in.getHeight();
final int n = 2 * size + 1;

for (int x = 0; x < width; x++) {
// initialize running sum
float sum = in.getf(x, 0);
for (int y = 1; y <= size; y++) {
sum += 2 * in.getf(x, y);
}
out.setf(x, 0, sum / n);

// update running sum by adding the next value and subtracting the oldest value
for (int y = 1; y < height; y++) {
final float oldest = in.getf(x, projectPeriodically(y - size - 1, height));
final float newest = in.getf(x, projectPeriodically(y + size, height));
sum += newest - oldest;
out.setf(x, y, sum / n);
}
}
return out;
}

private static ImageProcessor bidirectionalThreshold(final ImageProcessor in, final double threshold) {
final ImageProcessor out = new FloatProcessor(in.getWidth(), in.getHeight());
final int width = in.getWidth();
final int height = in.getHeight();

for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
final float value = Math.abs(in.getf(x, y));
out.setf(x, y, (value > threshold) ? 255 : 0);
}
}
return out;
}

private static int projectPeriodically(final int index, final int max) {
if (index < 0) {
return -index;
} else if (index >= max) {
return 2 * max - index - 2;
} else {
return index;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package org.janelia.alignment.inpainting;

import java.util.Random;

/**
* A statistic that yields a small perturbation of a given 2D direction for each sample.
*/
public class AnisotropicDirection2D implements DirectionalStatistic {

private final Random random;
private final double[] primalAxis;
private final double[] secondaryAxis;
private final double perturbation;

/**
* Creates a new statistic with a random seed.
*/
public AnisotropicDirection2D(final double[] primalAxis, final double perturbation) {
this(primalAxis, perturbation, new Random());
}

/**
* Creates a new statistic with the given random number generator.
*
* @param random the random number generator to use
*/
public AnisotropicDirection2D(final double[] primalAxis, final double perturbation, final Random random) {
final double norm = Math.sqrt(primalAxis[0] * primalAxis[0] + primalAxis[1] * primalAxis[1]);
this.primalAxis = new double[] { primalAxis[0] / norm, primalAxis[1] / norm };
this.secondaryAxis = new double[] { -primalAxis[1] / norm, primalAxis[0] / norm };
this.perturbation = perturbation;
this.random = random;
}

@Override
public void sample(final double[] direction) {
// TODO: this should be a von Mises distribution instead of this homegrown implementation
final int sign = random.nextBoolean() ? 1 : -1;
final double eps = perturbation * random.nextGaussian();
final double norm = 1 + eps * eps; // because axes are orthonormal
direction[0] = (sign * primalAxis[0] + eps * secondaryAxis[0]) / norm;
direction[1] = (sign * primalAxis[1] + eps * secondaryAxis[1]) / norm;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package org.janelia.alignment.inpainting;

/**
* Interface for distributions that model the direction of a ray used in {@link RayCastingInpainter}.
*/
public interface DirectionalStatistic {

/**
* Initializes the direction of the next ray. The array that is passed in is filled with the direction.
*
* @param direction the array in which to initialize the direction
*/
void sample(double[] direction);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package org.janelia.alignment.inpainting;

import java.util.Random;

/**
* A statistic that yields a completely random 2D direction for each sample.
*/
public class RandomDirection2D implements DirectionalStatistic {

private final Random random;

/**
* Creates a new statistic with a random seed.
*/
public RandomDirection2D() {
this(new Random());
}

/**
* Creates a new statistic with the given random number generator.
*
* @param random the random number generator to use
*/
public RandomDirection2D(final Random random) {
this.random = random;
}

@Override
public void sample(final double[] direction) {
final double angle = random.nextDouble() * 2 * Math.PI;
direction[0] = Math.cos(angle);
direction[1] = Math.sin(angle);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
package org.janelia.alignment.inpainting;

import java.util.Random;

/**
* A statistic that yields a completely random 3D direction for each sample.
*/
public class RandomDirection3D implements DirectionalStatistic {

private final Random random;

/**
* Creates a new statistic with a random seed.
*/
public RandomDirection3D() {
this(new Random());
}

/**
* Creates a new statistic with the given random number generator.
*
* @param random the random number generator to use
*/
public RandomDirection3D(final Random random) {
this.random = random;
}

@Override
public void sample(final double[] direction) {
final double x = random.nextGaussian();
final double y = random.nextGaussian();
final double z = random.nextGaussian();
final double norm = Math.sqrt(x * x + y * y + z * z);
direction[0] = x / norm;
direction[1] = y / norm;
direction[2] = z / norm;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
package org.janelia.alignment.inpainting;

import net.imglib2.RealInterval;

import net.imglib2.Cursor;
import net.imglib2.Interval;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.RealLocalizable;
import net.imglib2.RealRandomAccess;
import net.imglib2.interpolation.randomaccess.NLinearInterpolatorFactory;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.view.Views;


/**
* Infer missing values in an image (up to 3D) by ray casting (which is equivalent to diffusion of image values).
* <p>
* This is adapted from the hotknife repository for testing purposes.
*/
public class RayCastingInpainter {

private final int nRays;
private final long maxRayLength;
private final DirectionalStatistic directionStatistic;

private final double[] direction = new double[3];
private final Result result = new Result();

public RayCastingInpainter(final int nRays, final int maxInpaintingDiameter, final DirectionalStatistic directionStatistic) {
this.nRays = nRays;
this.maxRayLength = maxInpaintingDiameter;
this.directionStatistic = directionStatistic;
}

private static boolean isInside(final RealLocalizable p, final RealInterval r) {
for (int d = 0; d < p.numDimensions(); ++d) {
final double l = p.getDoublePosition(d);
if (l < r.realMin(d) || l > r.realMax(d)) {
return false;
}
}
return true;
}

/**
* Inpaints missing values in an image (up to 3D) by casting rays in random directions and averaging the values of
* the first non-masked pixel.
*
* @param img the image to inpaint
* @param mask the mask
*/
public void inpaint(final RandomAccessibleInterval<FloatType> img, final RandomAccessibleInterval<FloatType> mask) {
final Cursor<FloatType> imgCursor = Views.iterable(img).localizingCursor();

final RealRandomAccess<FloatType> imageAccess = Views.interpolate(Views.extendBorder(img), new NLinearInterpolatorFactory<>()).realRandomAccess();
final RealRandomAccess<FloatType> maskAccess = Views.interpolate(Views.extendBorder(mask), new NLinearInterpolatorFactory<>()).realRandomAccess();

while (imgCursor.hasNext()) {
final FloatType o = imgCursor.next();
final float m = maskAccess.setPositionAndGet(imgCursor).get();
if (m == 0.0) {
// pixel not masked, no inpainting necessary
continue;
}

double weightSum = 0;
double valueSum = 0;

// interpolate value by casting rays in random directions and averaging (weighted by distances) the
// values of the first non-masked pixel
for (int i = 0; i < nRays; ++i) {
final Result result = castRay(maskAccess, mask, imgCursor);
if (result != null) {
final double weight = 1.0 / result.distance;
weightSum += weight;
final double value = imageAccess.setPositionAndGet(result.position).getRealDouble();
valueSum += value * weight;
}
}

final float v = (float) (valueSum / weightSum);
final float w = m / 255.0f;
final float oldValue = o.get();
final float newValue = v * w + oldValue * (1 - w);
o.set(newValue);
}
}

/**
* Casts a ray from the given position in a random direction until it hits a non-masked (i.e., non-NaN) pixel
* or exits the image boundary.
*
* @param mask the mask indicating which pixels are masked (> 0) and which are not (0)
* @param interval the interval of the image
* @param position the position from which to cast the ray
* @return the result of the ray casting or null if the ray exited the image boundary without hitting a
* non-masked pixel
*/
private Result castRay(final RealRandomAccess<FloatType> mask, final Interval interval, final RealLocalizable position) {
mask.setPosition(position);
directionStatistic.sample(direction);
long steps = 0;

while(true) {
mask.move(direction);
++steps;

if (!isInside(mask, interval) || steps > maxRayLength) {
// the ray exited the image boundaries without hitting a non-masked pixel
return null;
}

final float value = mask.get().get();
if (value < 1.0) {
// the ray reached a non-masked pixel
mask.localize(result.position);
result.distance = steps;
return result;
}
}
}


private static class Result {
public double[] position = new double[3];
public double distance = 0;
}
}
Original file line number Diff line number Diff line change
@@ -148,7 +148,7 @@ public static void main(final String[] args) {
HUM_AIRWAY_FILE_NAMES[0]; // change file names index to test different images

final Map<String, Double> parameters = new HashMap<>();
parameters.put("numThreads", 8.0);
parameters.put("numThreads", 12.0);
parameters.put("fftWidth", 5545.0);
parameters.put("fftHeight", 10920.0);
parameters.put("innerCutoff", 18.0);
@@ -158,10 +158,14 @@ public static void main(final String[] args) {
parameters.put("initialThreshold", 7.0);
parameters.put("finalThreshold", 0.05);

// this shows the effect of varying the parameters on the correction process (symmetrically around the value set above)
// a good strategy for finding a good parameter set is to start with the actual correction parameters "innerCutoff" and "bandWidth"
// once good parameters for correction are found, the locality of correction can be adjusted by varying "gaussianBlurRadius",
// "initialThreshold" and "finalThreshold"; the angle only needs to be adjusted if the streaks are not vertical
displayParameterRange(srcPath, parameters, "innerCutoff", 3.0, 3, false);
displayParameterRange(srcPath, parameters, "bandWidth", 2.0, 3, false);
// displayParameterRange(srcPath, parameters, "angle", 0.0, 1.0, 3, false);
displayParameterRange(srcPath, parameters, "gaussianBlurRadius", 20.0, 3, true);
displayParameterRange(srcPath, parameters, "gaussianBlurRadius", 3.0, 3, true);
displayParameterRange(srcPath, parameters, "initialThreshold", 1.0, 3, true);
displayParameterRange(srcPath, parameters, "finalThreshold", 0.01, 3, true);

Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package org.janelia.alignment.destreak;

import ij.ImageJ;
import ij.ImagePlus;
import ij.process.FloatProcessor;
import ij.process.ImageProcessor;
import net.imglib2.img.Img;
import net.imglib2.img.display.imagej.ImageJFunctions;
import net.imglib2.type.numeric.real.FloatType;
import org.janelia.alignment.inpainting.RandomDirection2D;
import org.janelia.alignment.inpainting.RayCastingInpainter;

public class StreakFinderTest {
public static void main(final String[] args) {
final String srcPath = "/home/innerbergerm@hhmi.org/big-data/streak-correction/jrc_mus-liver-zon-3/z00032-0-0-1.png";
final StreakFinder finder = new StreakFinder(100, 5.0, 3);
// final StreakCorrector corrector = new SmoothMaskStreakCorrector(12, 6161, 8190, 10, 10, 0);
final RayCastingInpainter inpainter = new RayCastingInpainter(128, 100, new RandomDirection2D());

final long start = System.currentTimeMillis();
final ImagePlus original = new ImagePlus(srcPath);
final ImagePlus mask = finder.createStreakMask(original);
// final ImagePlus corrected = streakCorrectFourier(corrector, original, mask);
final ImagePlus corrected = streakCorrectInpainting(inpainter, original, mask);
System.out.println("Processing time: " + (System.currentTimeMillis() - start) + "ms");

new ImageJ();
mask.show();
original.show();
corrected.show();
}

private static ImagePlus streakCorrectFourier(
final StreakCorrector corrector,
final ImagePlus original,
final ImagePlus mask) {

final ImagePlus corrected = original.duplicate();
corrected.setTitle("Corrected");
corrector.process(corrected.getProcessor(), 1.0);

final ImageProcessor proc = corrected.getProcessor();
final ImageProcessor maskProc = mask.getProcessor();
final ImageProcessor originalProc = original.getProcessor();
for (int i = 0; i < corrected.getWidth() * corrected.getHeight(); i++) {
final float lambda = maskProc.getf(i) / 255.0f;
final float mergedValue = originalProc.getf(i) * (1 - lambda) + proc.getf(i) * lambda;
proc.setf(i, mergedValue);
}

return corrected;
}

private static ImagePlus streakCorrectInpainting(
final RayCastingInpainter inpainter,
final ImagePlus original,
final ImagePlus mask) {

final FloatProcessor correctedFp = original.getProcessor().convertToFloatProcessor();
final ImagePlus corrected = new ImagePlus("Corrected", correctedFp);
final Img<FloatType> correctedImg = ImageJFunctions.wrapFloat(corrected);

final FloatProcessor maskFp = mask.getProcessor().convertToFloatProcessor();
final ImagePlus maskIp = new ImagePlus("Mask", maskFp);
final Img<FloatType> maskImg = ImageJFunctions.wrapFloat(maskIp);

inpainter.inpaint(correctedImg, maskImg);

return corrected;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package org.janelia.render.client.parameter;

import com.beust.jcommander.Parameter;
import org.janelia.alignment.destreak.StreakFinder;

import java.io.Serializable;

/**
* Parameters for streak finding with a {@link StreakFinder}.
*/
public class StreakFinderParameters implements Serializable {
@Parameter(
names = "--meanFilterSize",
description = "Number of pixels to average in the positive and negative y-direction",
required = true)
public int meanFilterSize;

@Parameter(
names = "--threshold",
description = "Threshold used to convert the streak mask to a binary mask",
required = true)
public double threshold;

@Parameter(
names = "--blurRadius",
description = "Radius of the Gaussian blur applied to the streak mask",
required = true)
public int blurRadius;

public StreakFinder createStreakFinder() {
return new StreakFinder(meanFilterSize, threshold, blurRadius);
}

public void validate() {
if (meanFilterSize < 0) {
throw new IllegalArgumentException("meanFilterSize must be non-negative");
}
if (threshold < 0) {
throw new IllegalArgumentException("threshold must be non-negative");
}
if (blurRadius < 0) {
throw new IllegalArgumentException("blurRadius must be 0 (no blur) or positive");
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package org.janelia.render.client;

import ij.ImageJ;
import ij.ImagePlus;
import ij.process.FloatProcessor;
import ij.process.ImageProcessor;
import net.imglib2.img.Img;
import net.imglib2.img.display.imagej.ImageJFunctions;
import net.imglib2.type.numeric.real.FloatType;
import org.janelia.alignment.inpainting.AnisotropicDirection2D;
import org.janelia.alignment.inpainting.DirectionalStatistic;
import org.janelia.alignment.inpainting.RayCastingInpainter;

public class InpaintingTest {

private static final int N_RAYS = 256;
private static final int MAX_INPAINTING_DIAMETER = 100;
private static final double THRESHOLD = 20.0;

private static final int Y_MIN = 3826;
private static final int Y_MAX = 3856;

public static void main(final String[] args) {
final String srcPath = "/home/innerbergerm@hhmi.org/big-data/streak-correction/jrc_P3-E2-D1-Lip4-19/z14765-0-0-0.png";
final ImagePlus original = new ImagePlus(srcPath);
final DirectionalStatistic directionStatistic = new AnisotropicDirection2D(new double[] {0, 1}, 0.5);
final RayCastingInpainter inpainter = new RayCastingInpainter(N_RAYS, MAX_INPAINTING_DIAMETER, directionStatistic);

// final ImagePlus mask = threshold(original.getProcessor(), THRESHOLD);
final ImagePlus mask = bandMask(original.getProcessor(), Y_MIN, Y_MAX);

final long start = System.currentTimeMillis();
final ImagePlus corrected = streakCorrectInpainting(inpainter, original, mask);
System.out.println("Processing time: " + (System.currentTimeMillis() - start) + "ms");

new ImageJ();
original.show();
mask.show();
corrected.show();
}

private static ImagePlus threshold(final ImageProcessor in, final double threshold) {
final ImageProcessor out = new FloatProcessor(in.getWidth(), in.getHeight());
final int width = in.getWidth();
final int height = in.getHeight();

for (int y = 0; y < height; y++) {
for (int x = 0; x < width; x++) {
final float value = in.getf(x, y);
out.setf(x, y, (value < threshold) ? 255 : 0);
}
}
return new ImagePlus("Mask", out);
}

private static ImagePlus bandMask(final ImageProcessor in, final int yMin, final int yMax) {
final int width = in.getWidth();
final int height = in.getHeight();
final ImageProcessor out = new FloatProcessor(width, height);

for (int y = yMin; y <= yMax; y++) {
for (int x = 0; x < width; x++) {
out.setf(x, y, 255);
}
}
return new ImagePlus("Mask", out);
}

private static ImagePlus streakCorrectInpainting(
final RayCastingInpainter inpainter,
final ImagePlus original,
final ImagePlus mask) {

final FloatProcessor correctedFp = original.getProcessor().convertToFloatProcessor();
final ImagePlus corrected = new ImagePlus("Corrected", correctedFp);
final Img<FloatType> correctedImg = ImageJFunctions.wrapFloat(corrected);

final FloatProcessor maskFp = mask.getProcessor().convertToFloatProcessor();
final ImagePlus maskIp = new ImagePlus("Mask", maskFp);
final Img<FloatType> maskImg = ImageJFunctions.wrapFloat(maskIp);

inpainter.inpaint(correctedImg, maskImg);

return corrected;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
package org.janelia.render.client;

import ij.ImageJ;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.converter.Converters;
import net.imglib2.img.Img;
import net.imglib2.img.display.imagej.ImageJFunctions;
import net.imglib2.loops.LoopBuilder;
import net.imglib2.type.numeric.integer.UnsignedByteType;
import net.imglib2.type.numeric.real.FloatType;
import net.imglib2.util.Util;
import org.janelia.alignment.inpainting.DirectionalStatistic;
import org.janelia.alignment.inpainting.RandomDirection3D;
import org.janelia.alignment.inpainting.RayCastingInpainter;
import org.janelia.render.client.spark.n5.N5Client;
import org.janelia.saalfeldlab.n5.N5FSReader;
import org.janelia.saalfeldlab.n5.N5Reader;
import org.janelia.saalfeldlab.n5.imglib2.N5Utils;

public class JrcP3E2D1Lip419Inpainter {

// path to the N5 container containing the data and a mask
public static final String N5_BASE_PATH = System.getenv("HOME") + "/big-data/streak-correction/jrc_P3-E2-D1-Lip4-19/broken_layers.n5";

public static void main(final String[] args) {
// saveSubStack();
doInpainting();
}

@SuppressWarnings("unused")
private static void saveSubStack() {
final String[] n5ClientArgs = {
"--baseDataUrl", "http://renderer-dev.int.janelia.org:8080/render-ws/v1",
"--owner", "fibsem",
"--project", "jrc_P3_E2_D1_Lip4_19",
"--stack", "v1_acquire_trimmed_align_v2",
"--n5Path", N5_BASE_PATH,
"--n5Dataset", "tissue",
"--tileWidth", "2048",
"--tileHeight", "2048",
"--blockSize", "512,512,128",
"--minZ", "14763",
"--maxZ", "14805",
};

N5Client.main(n5ClientArgs);
}

@SuppressWarnings("unused")
private static void doInpainting() {
final DirectionalStatistic directionalStatistic = new RandomDirection3D();
final RayCastingInpainter inpainter = new RayCastingInpainter(64, 50, directionalStatistic);

try (final N5Reader n5 = new N5FSReader(N5_BASE_PATH)) {
final RandomAccessibleInterval<UnsignedByteType> tissue = N5Utils.open(n5, "tissue_crop");
final RandomAccessibleInterval<UnsignedByteType> mask = N5Utils.open(n5, "mask_crop_v2");

// convert to float (and copy tissue)
final Img<FloatType> tissueFloat = Util.getSuitableImgFactory(tissue, new FloatType()).create(tissue);
LoopBuilder.setImages(tissue, tissueFloat).forEachPixel((i, o) -> o.setReal(i.getRealDouble()));
final RandomAccessibleInterval<FloatType> maskFloat = Converters.convert(mask, (i, o) -> o.setReal(i.getRealDouble()), new FloatType());

inpainter.inpaint(tissueFloat, maskFloat);

new ImageJ();
ImageJFunctions.show(tissueFloat, "Tissue");
ImageJFunctions.show(maskFloat, "Mask");
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,371 @@
package org.janelia.render.client.spark.destreak;

import com.beust.jcommander.Parameter;
import com.beust.jcommander.ParametersDelegate;
import ij.ImagePlus;
import ij.process.ImageProcessor;
import mpicbg.models.CoordinateTransform;
import mpicbg.models.CoordinateTransformList;
import net.imglib2.img.Img;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.view.Views;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.broadcast.Broadcast;
import org.janelia.alignment.destreak.StreakFinder;
import org.janelia.alignment.loader.ImageLoader;
import org.janelia.alignment.spec.Bounds;
import org.janelia.alignment.spec.ResolvedTileSpecCollection;
import org.janelia.alignment.spec.TileSpec;
import org.janelia.alignment.spec.stack.StackMetaData;
import org.janelia.alignment.util.ImageProcessorCache;
import org.janelia.render.client.ClientRunner;
import org.janelia.render.client.RenderDataClient;
import org.janelia.render.client.parameter.CommandLineParameters;
import org.janelia.render.client.parameter.RenderWebServiceParameters;
import org.janelia.render.client.parameter.StreakFinderParameters;
import org.janelia.render.client.parameter.ZRangeParameters;
import org.janelia.render.client.spark.LogUtilities;
import org.janelia.saalfeldlab.n5.GzipCompression;
import org.janelia.saalfeldlab.n5.N5Writer;
import org.janelia.saalfeldlab.n5.imglib2.N5Utils;
import org.janelia.saalfeldlab.n5.zarr.N5ZarrWriter;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Tuple2;

import java.io.IOException;
import java.io.Serializable;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;


/**
* This client computes statistics for streaks in a stack of images by accumulating the values of a streak mask over
* subregions of each layer.
*/
public class StreakStatisticsClient implements Serializable {

public static class Parameters extends CommandLineParameters {

@ParametersDelegate
public RenderWebServiceParameters renderWeb = new RenderWebServiceParameters();

@ParametersDelegate
public ZRangeParameters zRange = new ZRangeParameters();

@ParametersDelegate
public StreakFinderParameters streakFinder = new StreakFinderParameters();

@Parameter(names = "--stack", description = "Stack to pull image and transformation data from", required = true)
public String stack;

@Parameter(names = "--output", description = "Output file path", required = true)
public String outputPath;

@Parameter(names = "--nCells", description = "Number of cells to use in x and y directions, e.g., 5x3", required = true)
public String cells;

private int nX = 0;
private int nY = 0;

public void validate() {
streakFinder.validate();
try {
final String[] nCells = cells.split("x");
nX = Integer.parseInt(nCells[0]);
nY = Integer.parseInt(nCells[1]);
if (nX < 1 || nY < 1) {
throw new IllegalArgumentException("nCells must be positive");
}
} catch (final NumberFormatException e) {
throw new IllegalArgumentException("nCells must be in the format 'NxM'");
}
}

public int nCellsX() {
if (nX == 0) {
validate();
}
return nX;
}

public int nCellsY() {
if (nY == 0) {
validate();
}
return nY;
}
}

public static void main(final String[] args) {
final ClientRunner clientRunner = new ClientRunner(args) {
@Override
public void runClient(final String[] args) throws Exception {

final Parameters parameters = new Parameters();
parameters.parse(args);
LOG.info("runClient: entry, parameters={}", parameters);
parameters.validate();

final StreakStatisticsClient client = new StreakStatisticsClient(parameters);
client.compileStreakStatistics();
}
};
clientRunner.run();
}


private static final Logger LOG = LoggerFactory.getLogger(StreakStatisticsClient.class);
private final Parameters parameters;

public StreakStatisticsClient(final Parameters parameters) {
this.parameters = parameters;
}

public void compileStreakStatistics() throws IOException {
final SparkConf conf = new SparkConf().setAppName("StreakStatisticsClient");

try (final JavaSparkContext sparkContext = new JavaSparkContext(conf)) {
final String sparkAppId = sparkContext.getConf().getAppId();
final String executorsJson = LogUtilities.getExecutorsApiJson(sparkAppId);

LOG.info("run: appId is {}, executors data is {}", sparkAppId, executorsJson);

compileStreakStatistics(sparkContext);
}
}

private void compileStreakStatistics(final JavaSparkContext sparkContext) throws IOException {
// get some metadata and broadcast variables accessed by all workers
final Broadcast<Parameters> bcParameters = sparkContext.broadcast(parameters);
final StackMetaData stackMetaData = getMetaData();
final Broadcast<Bounds> bounds = sparkContext.broadcast(parameters.zRange.overrideBounds(stackMetaData.getStackBounds()));
final List<Double> zValues = IntStream.range(bounds.value().getMinZ().intValue(), bounds.value().getMaxZ().intValue() + 1)
.boxed().map(Double::valueOf).collect(Collectors.toList());

LOG.info("run: fetched {} z-values for stack {}", zValues.size(), parameters.stack);

// do the computation in a distributed way
final List<Tuple2<Double, LayerResults>> result = sparkContext.parallelize(zValues)
.mapToPair(z -> new Tuple2<>(z, pullTileSpecs(bcParameters.value(), z)))
.mapValues(tileSpecs -> computeStreakStatisticsForLayer(bcParameters.value(), bounds.value(), tileSpecs))
.collect();

// convert to image and store on disk - list needs to be copied since the list returned by spark is not sortable
final List<Tuple2<Double, LayerResults>> sortedResults = new ArrayList<>(result);
sortedResults.sort((a, b) -> a._1.compareTo(b._1));
final Img<DoubleType> data = statisticsToImage(sortedResults);
final Img<DoubleType> min = minMaxToImage(sortedResults, LayerResults::getMin);
final Img<DoubleType> max = minMaxToImage(sortedResults, LayerResults::getMax);
storeData(data, min, max, stackMetaData);
}

private StackMetaData getMetaData() throws IOException {
final RenderDataClient dataClient = parameters.renderWeb.getDataClient();
return dataClient.getStackMetaData(parameters.stack);
}

private static ResolvedTileSpecCollection pullTileSpecs(final Parameters parameters, final double z) throws IOException {
final RenderDataClient dataClient = parameters.renderWeb.getDataClient();
return dataClient.getResolvedTiles(parameters.stack, z);
}

private Img<DoubleType> statisticsToImage(final List<Tuple2<Double, LayerResults>> zLayerToResults) {
final double[] fullData = new double[parameters.nCellsX() * parameters.nCellsY() * zLayerToResults.size()];
final Img<DoubleType> data = ArrayImgs.doubles(fullData, parameters.nCellsX(), parameters.nCellsY(), zLayerToResults.size());
final long[] position = new long[3];
int z = 0;

for (final Tuple2<Double, LayerResults> zLayerToResult : zLayerToResults) {
final double[][] layerStatistics = zLayerToResult._2.getStatistics();

position[2] = z++;
for (int x = 0; x < parameters.nCellsX(); x++) {
position[0] = x;
for (int y = 0; y < parameters.nCellsY(); y++) {
position[1] = y;
data.getAt(position).set(layerStatistics[x][y]);
}
}
}
return data;
}

private Img<DoubleType> minMaxToImage(
final List<Tuple2<Double, LayerResults>> zLayerToResults,
final Function<LayerResults, double[]> minMaxExtractor
) {
final double[] flattenedMinMax = zLayerToResults.stream()
.map(Tuple2::_2)
.map(minMaxExtractor)
.flatMapToDouble(Arrays::stream)
.toArray();
return ArrayImgs.doubles(flattenedMinMax, 2, zLayerToResults.size());
}

private void storeData(
final Img<DoubleType> data,
final Img<DoubleType> layerMin,
final Img<DoubleType> layerMax,
final StackMetaData stackMetaData
) {
// transpose data because images are F-order and python expects C-order
final String group = Paths.get(parameters.renderWeb.project, parameters.stack).toString();
final int zChunkSize = Math.min(1000, (int) data.dimension(2));
final int[] dataChunkSize = new int[] {zChunkSize, parameters.nCellsY(), parameters.nCellsX()};
final int[] minMaxChunkSize = new int[] {2, zChunkSize};

final Bounds stackBounds = stackMetaData.getStackBounds();
final double[] min = new double[3];
min[0] = stackBounds.getMinX();
min[1] = stackBounds.getMinY();
min[2] = stackBounds.getMinZ();

final double[] max = new double[3];
max[0] = stackBounds.getMaxX();
max[1] = stackBounds.getMaxY();
max[2] = stackBounds.getMaxZ();

try (final N5Writer n5Writer = new N5ZarrWriter(parameters.outputPath)) {
n5Writer.createGroup(group);
N5Utils.save(Views.permute(data, 0, 2), n5Writer, Paths.get(group, "statistics").toString(), dataChunkSize, new GzipCompression());
N5Utils.save(layerMin, n5Writer, Paths.get(group, "layer_min").toString(), minMaxChunkSize, new GzipCompression());
N5Utils.save(layerMax, n5Writer, Paths.get(group, "layer_max").toString(), minMaxChunkSize, new GzipCompression());

n5Writer.setAttribute(group, "StackBounds", Map.of("min", min, "max", max));
n5Writer.setAttribute(group, "Resolution_nm", stackMetaData.getCurrentResolutionValues());
final Map<String, Double> runParameters = Map.of(
"threshold", parameters.streakFinder.threshold,
"meanFilterSize", (double) parameters.streakFinder.meanFilterSize,
"blurRadius", (double) parameters.streakFinder.blurRadius);
n5Writer.setAttribute(group, "RunParameters", runParameters);
}
}

private static LayerResults computeStreakStatisticsForLayer(
final Parameters parameters,
final Bounds stackBounds,
final ResolvedTileSpecCollection layerTiles) {

LOG.info("computeStreakStatisticsForLayer: processing {} tiles", layerTiles.getTileSpecs().size());
final StreakAccumulator accumulator = new StreakAccumulator(stackBounds, parameters.nCellsX(), parameters.nCellsY());
final ImageProcessorCache cache = ImageProcessorCache.DISABLED_CACHE;

for (final TileSpec tileSpec : layerTiles.getTileSpecs()) {
LOG.debug("computeStreakStatisticsForLayer: processing tile {}", tileSpec.getTileId());

final ImageProcessor imp = cache.get(tileSpec.getImagePath(), 0, false, false, ImageLoader.LoaderType.H5_SLICE, null);
final ImagePlus image = new ImagePlus(tileSpec.getTileId(), imp);
if (image.getProcessor() == null) {
LOG.warn("computeStreakStatisticsForLayer: could not load image for tile {}", tileSpec.getTileId());
continue;
}

final StreakFinder streakFinder = parameters.streakFinder.createStreakFinder();
final ImagePlus mask = streakFinder.createStreakMask(image);
addStreakStatisticsForSingleMask(accumulator, mask, tileSpec);
}

layerTiles.recalculateBoundingBoxes();
final Bounds layerBounds = layerTiles.toBounds();
final double[] min = new double[] {layerBounds.getMinX(), layerBounds.getMinY()};
final double[] max = new double[] {layerBounds.getMaxX(), layerBounds.getMaxY()};

return new LayerResults(accumulator.getResults(), min, max);
}

private static void addStreakStatisticsForSingleMask(final StreakAccumulator accumulator, final ImagePlus mask, final TileSpec tileSpec) {
final double[] position = new double[2];
final CoordinateTransformList<CoordinateTransform> transformList = tileSpec.getTransformList();
for (int x = 0; x < mask.getWidth(); x++) {
for (int y = 0; y < mask.getHeight(); y++) {
position[0] = x;
position[1] = y;
transformList.applyInPlace(position);
accumulator.addValue(mask.getProcessor().getf(x, y), position[0], position[1]);
}
}
}


private static class StreakAccumulator {
private final int nX;
private final int nY;

private final int minX;
private final int minY;
private final int width;
private final int height;

private final double[][] sum;
private final long[][] counts;

public StreakAccumulator(final Bounds layerBounds, final int nX, final int nY) {
this.nX = nX;
this.nY = nY;

this.minX = layerBounds.getMinX().intValue();
this.minY = layerBounds.getMinY().intValue();
this.width = layerBounds.getWidth();
this.height = layerBounds.getHeight();

sum = new double[nX][nY];
counts = new long[nX][nY];
}

public void addValue(final double value, final double x, final double y) {
final int i = (int) (nX * (x - minX) / width);
final int j = (int) (nY * (y - minY) / height);
sum[i][j] += value;
counts[i][j]++;
}

public double[][] getResults() {
final double[][] results = new double[nX][nY];
for (int i = 0; i < nX; i++) {
for (int j = 0; j < nY; j++) {
// account for the fact that the mask values are in [0, 255], with 255 indicating a streak
if (counts[i][j] == 0) {
results[i][j] = 0;
} else {
results[i][j] = sum[i][j] / (255 * counts[i][j]);
}
}
}
return results;
}
}


private static class LayerResults implements Serializable{
final private double[][] statistics;
final private double[] min;
final private double[] max;

public LayerResults(final double[][] statistics, final double[] min, final double[] max) {
this.statistics = statistics;
this.min = min;
this.max = max;
}

public double[][] getStatistics() {
return statistics;
}

public double[] getMin() {
return min;
}

public double[] getMax() {
return max;
}
}
}

0 comments on commit c1c8906

Please sign in to comment.