Skip to content

Commit

Permalink
Use IntensityTile in matching and block worker
Browse files Browse the repository at this point in the history
  • Loading branch information
minnerbe committed Jan 26, 2025
1 parent 8115508 commit 4209a5f
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 74 deletions.
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
package org.janelia.render.client.newsolver.solvers.intensity;

import mpicbg.models.Affine1D;
import mpicbg.models.AffineModel1D;
import mpicbg.models.ErrorStatistic;
import mpicbg.models.IdentityModel;
import mpicbg.models.InterpolatedAffineModel1D;
import mpicbg.models.NoninvertibleModelException;
import mpicbg.models.PointMatch;
import mpicbg.models.Tile;
import mpicbg.models.TileConfiguration;
import mpicbg.models.TileUtil;
import mpicbg.models.TranslationModel1D;
import net.imglib2.util.ValuePair;

Expand Down Expand Up @@ -76,14 +72,14 @@ public List<BlockData<ArrayList<AffineModel1D>, FIBSEMIntensityCorrectionParamet

final List<TileSpec> wrappedTiles = AdjustBlock.sortTileSpecs(blockData.rtsc());

final HashMap<String, ArrayList<Tile<? extends Affine1D<?>>>> coefficientTiles = computeCoefficients(wrappedTiles);
final Map<String, IntensityTile> coefficientTiles = computeCoefficients(wrappedTiles);

coefficientTiles.forEach((tileId, tiles) -> {
final ArrayList<AffineModel1D> models = new ArrayList<>();
tiles.forEach(tile -> {
final AffineModel1D model = ((InterpolatedAffineModel1D<?, ?>) tile.getModel()).createAffineModel1D();
models.add(model);
});
for (int i = 0; i < tiles.nSubTiles(); i++) {
final InterpolatedAffineModel1D<?, ?> interpolatedModel = (InterpolatedAffineModel1D<?, ?>) tiles.getSubTile(i).getModel();
models.add(interpolatedModel.createAffineModel1D());
}
blockData.getResults().recordModel(tileId, models);
});

Expand All @@ -105,7 +101,7 @@ private void fetchResolvedTiles()
blockData.getResults().init(rtsc);
}

private HashMap<String, ArrayList<Tile<? extends Affine1D<?>>>> computeCoefficients(final List<TileSpec> tiles)
private Map<String, IntensityTile> computeCoefficients(final List<TileSpec> tiles)
throws ExecutionException, InterruptedException {

LOG.info("computeCoefficients: entry");
Expand All @@ -115,7 +111,7 @@ private HashMap<String, ArrayList<Tile<? extends Affine1D<?>>>> computeCoefficie
? ImageProcessorCache.DISABLED_CACHE
: new ImageProcessorCache(parameters.maxNumberOfCachedPixels(), true, false);

final HashMap<String, ArrayList<Tile<? extends Affine1D<?>>>> coefficientTiles = splitIntoCoefficientTiles(tiles, imageProcessorCache);
final Map<String, IntensityTile> coefficientTiles = splitIntoCoefficientTiles(tiles, imageProcessorCache);

if (tiles.size() > 1) {
solveForGlobalCoefficients(coefficientTiles, ITERATIONS);
Expand All @@ -126,7 +122,7 @@ private HashMap<String, ArrayList<Tile<? extends Affine1D<?>>>> computeCoefficie
return coefficientTiles;
}

private HashMap<String, ArrayList<Tile<? extends Affine1D<?>>>> splitIntoCoefficientTiles(
private HashMap<String, IntensityTile> splitIntoCoefficientTiles(
final List<TileSpec> tiles,
final ImageProcessorCache imageProcessorCache
) throws InterruptedException, ExecutionException {
Expand All @@ -139,8 +135,7 @@ private HashMap<String, ArrayList<Tile<? extends Affine1D<?>>>> splitIntoCoeffic
LOG.info("splitIntoCoefficientTiles: entry, collecting pairs for {} patches with zDistance {}", tiles.size(), parameters.zDistance());

// generate coefficient tiles for all patches
final int nGridPoints = parameters.numCoefficients() * parameters.numCoefficients();
final HashMap<String, ArrayList<Tile<? extends Affine1D<?>>>> coefficientTiles = generateCoefficientsTiles(tiles, nGridPoints);
final HashMap<String, IntensityTile> coefficientTiles = generateCoefficientsTiles(tiles);

final List<ValuePair<TileSpec, TileSpec>> patchPairs = findOverlappingPatches(tiles, parameters.zDistance());

Expand Down Expand Up @@ -187,24 +182,18 @@ private IntensityMatcher getIntensityMatcher(
return new IntensityMatcher(filter, parameters, meshResolution, imageProcessorCache);
}

private HashMap<String, ArrayList<Tile<? extends Affine1D<?>>>> generateCoefficientsTiles(
final Collection<TileSpec> patches,
final int nGridPoints
) {
private HashMap<String, IntensityTile> generateCoefficientsTiles(final Collection<TileSpec> patches) {

final InterpolatedAffineModel1D<InterpolatedAffineModel1D<AffineModel1D, TranslationModel1D>, IdentityModel> modelTemplate =
new InterpolatedAffineModel1D<>(
new InterpolatedAffineModel1D<>(
new AffineModel1D(), new TranslationModel1D(), parameters.lambdaTranslation()),
new IdentityModel(), parameters.lambdaIdentity());

final HashMap<String, ArrayList<Tile<? extends Affine1D<?>>>> coefficientTiles = new HashMap<>();
final HashMap<String, IntensityTile> coefficientTiles = new HashMap<>();
for (final TileSpec p : patches) {
final ArrayList<Tile<? extends Affine1D<?>>> coefficientModels = new ArrayList<>();
for (int i = 0; i < nGridPoints; ++i) {
final InterpolatedAffineModel1D<?,?> model = modelTemplate.copy();
coefficientModels.add(new Tile<>(model));
}
coefficientTiles.put(p.getTileId(), coefficientModels);
final IntensityTile tile = new IntensityTile(modelTemplate::copy, parameters.numCoefficients(), 1);
coefficientTiles.put(p.getTileId(), tile);
}
return coefficientTiles;
}
Expand Down Expand Up @@ -237,34 +226,34 @@ private static ArrayList<ValuePair<TileSpec, TileSpec>> findOverlappingPatches(

@SuppressWarnings("SameParameterValue")
private void solveForGlobalCoefficients(
final HashMap<String, ArrayList<Tile<? extends Affine1D<?>>>> coefficientTiles,
final Map<String, IntensityTile> coefficientTiles,
final int iterations
) {
final Tile<? extends Affine1D<?>> equilibrationTile = new Tile<>(new IdentityModel());
final IntensityTile equilibrationTile = new IntensityTile(IdentityModel::new, 1, 1);

connectTilesWithinPatches(coefficientTiles, equilibrationTile);

/* optimize */
final TileConfiguration tc = new TileConfiguration();
coefficientTiles.values().forEach(tc::addTiles);

// anchor the equilibration tile
tc.addTile(equilibrationTile);
tc.fixTile(equilibrationTile);

LOG.info("solveForGlobalCoefficients: optimizing {} tiles with {} threads", tc.getTiles().size(), numThreads);
try {
TileUtil.optimizeConcurrently(new ErrorStatistic(iterations + 1), 0.01f, iterations, iterations, 0.75f, tc, tc.getTiles(), tc.getFixedTiles(), numThreads);
} catch (final Exception e) {
throw new RuntimeException(e);
final List<IntensityTile> tiles = new ArrayList<>(coefficientTiles.values());
final List<IntensityTile> fixedTiles = new ArrayList<>();

// anchor the equilibration tile if it is used, otherwise anchor a random tile (the first one)
if (blockData.solveTypeParameters().equilibrationWeight() > 0.0) {
tiles.add(equilibrationTile);
fixedTiles.add(equilibrationTile);
} else {
final IntensityTile firstTile = tiles.get(0);
fixedTiles.add(firstTile);
}

LOG.info("solveForGlobalCoefficients: optimizing {} tiles with {} threads", tiles.size(), numThreads);
final IntensityTileOptimizer optimizer = new IntensityTileOptimizer(0.01, iterations, iterations, 0.75, numThreads);
optimizer.optimize(tiles, fixedTiles);

// TODO: this is not the right error measure, what is idToBlockErrorMap supposed to be exactly?
coefficientTiles.forEach((tileId, tiles) -> {
final Double error = tiles.stream().mapToDouble(t -> {
t.updateCost();
return t.getDistance();
}).average().orElse(Double.MAX_VALUE);
coefficientTiles.forEach((tileId, tile) -> {
tile.updateDistance();
final double error = tile.getDistance();
final Map<String, Double> errorMap = new HashMap<>();
errorMap.put(tileId, error);
blockData.getResults().recordAllErrors(tileId, errorMap);
Expand All @@ -274,48 +263,39 @@ private void solveForGlobalCoefficients(
}

private void connectTilesWithinPatches(
final HashMap<String, ArrayList<Tile<? extends Affine1D<?>>>> coefficientTiles,
final Tile<? extends Affine1D<?>> equilibrationTile
final Map<String, IntensityTile> coefficientTiles,
final IntensityTile equilibrationTile
) {
final Collection<TileSpec> allTiles = blockData.rtsc().getTileSpecs();
final double equilibrationWeight = blockData.solveTypeParameters().equilibrationWeight();

final ResultContainer<ArrayList<AffineModel1D>> results = blockData.getResults();
for (final TileSpec p : allTiles) {
final List<? extends Tile<?>> coefficientTile = coefficientTiles.get(p.getTileId());
final IntensityTile coefficientTile = coefficientTiles.get(p.getTileId());
for (int i = 1; i < parameters.numCoefficients(); ++i) {
for (int j = 0; j < parameters.numCoefficients(); ++j) {
final int left = getLinearIndex(i-1, j, parameters.numCoefficients());
final int right = getLinearIndex(i, j, parameters.numCoefficients());
final int top = getLinearIndex(j, i, parameters.numCoefficients());
final int bot = getLinearIndex(j, i-1, parameters.numCoefficients());
final Tile<?> left = coefficientTile.getSubTile(i-1, j);
final Tile<?> right = coefficientTile.getSubTile(i, j);
final Tile<?> top = coefficientTile.getSubTile(j, i);
final Tile<?> bot = coefficientTile.getSubTile(j, i-1);

identityConnect(coefficientTile.get(right), coefficientTile.get(left));
identityConnect(coefficientTile.get(top), coefficientTile.get(bot));
identityConnect(right, left);
identityConnect(top, bot);
}
}
if (equilibrationWeight > 0.0) {
final List<Double> averages = results.getAveragesFor(p.getTileId());
for (int i = 0; i < parameters.numCoefficients(); i++) {
for (int j = 0; j < parameters.numCoefficients(); j++) {
final int idx = getLinearIndex(i, j, parameters.numCoefficients());
equilibrateIntensity(coefficientTile.get(idx),
equilibrationTile,
averages.get(idx),
equilibrationWeight);
}
coefficientTile.connectTo(equilibrationTile);
for (int i = 0; i < coefficientTile.nSubTiles(); i++) {
equilibrateIntensity(coefficientTile.getSubTile(i),
equilibrationTile.getSubTile(0),
averages.get(i),
equilibrationWeight);
}
}
}
}

/**
* Get index of the (x,y) pixel in an n x n grid represented by a linear array
*/
private int getLinearIndex(final int x, final int y, final int n) {
return y * n + x;
}

private static void equilibrateIntensity(final Tile<?> coefficientTile,
final Tile<?> equilibrationTile,
final Double average,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import ij.process.ColorProcessor;
import ij.process.FloatProcessor;
import mpicbg.models.Affine1D;
import mpicbg.models.PointMatch;
import mpicbg.models.Tile;
import net.imglib2.util.Pair;
Expand Down Expand Up @@ -55,7 +54,7 @@ public IntensityMatcher(
this.imageProcessorCache = imageProcessorCache;
}

public void match(final TileSpec p1, final TileSpec p2, final HashMap<String, ArrayList<Tile<? extends Affine1D<?>>>> coefficientTiles) {
public void match(final TileSpec p1, final TileSpec p2, final HashMap<String, IntensityTile> intensityTiles) {

final StopWatch stopWatch = StopWatch.createAndStart();

Expand Down Expand Up @@ -115,24 +114,28 @@ public void match(final TileSpec p1, final TileSpec p2, final HashMap<String, Ar
}

/* connect tiles across patches */
final List<Tile<? extends Affine1D<?>>> p1CoefficientTiles = coefficientTiles.get(p1.getTileId());
final List<Tile<? extends Affine1D<?>>> p2CoefficientTiles = coefficientTiles.get(p2.getTileId());
final IntensityTile p1IntensityTile = intensityTiles.get(p1.getTileId());
final IntensityTile p2IntensityTile = intensityTiles.get(p2.getTileId());
int connectionCount = 0;

for (int i = 0; i < nCoefficientTiles; ++i) {
final Tile<?> t1 = p1CoefficientTiles.get(i);
final Tile<?> t1 = p1IntensityTile.getSubTile(i);

for (int j = 0; j < nCoefficientTiles; ++j) {
final List<PointMatch> matches = get(matrix, i, j, nCoefficientTiles);
if (matches.isEmpty())
continue;

final Tile<?> t2 = p2CoefficientTiles.get(j);
final Tile<?> t2 = p2IntensityTile.getSubTile(j);
t1.connect(t2, matches);
connectionCount++;
}
}

if (connectionCount > 0) {
p1IntensityTile.connectTo(p2IntensityTile);
}

stopWatch.stop();
LOG.info("match: pair {} <-> {} has {} connections, matching took {}", p1.getTileId(), p2.getTileId(), connectionCount, stopWatch);
}
Expand Down

0 comments on commit 4209a5f

Please sign in to comment.