Skip to content

Commit

Permalink
Add IntensityTile and corresponding optimizer
Browse files Browse the repository at this point in the history
  • Loading branch information
minnerbe committed Jan 26, 2025
1 parent dc23520 commit 8115508
Show file tree
Hide file tree
Showing 2 changed files with 350 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package org.janelia.render.client.newsolver.solvers.intensity;

import mpicbg.models.Affine1D;
import mpicbg.models.IllDefinedDataPointsException;
import mpicbg.models.Model;
import mpicbg.models.NotEnoughDataPointsException;
import mpicbg.models.Tile;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.function.Supplier;


/**
* A tile that contains a grid of sub-tiles, each of which has a model that can be fitted and applied. This acts as a
* convenience class for handling the fitting and applying of the models of the sub-tiles necessary for intensity
* correction. The encapsulated sub-tiles also potentially speed up the optimization process by reducing the overhead
* of parallelizing the optimization of the sub-tiles.
* <p>
* This class doesn't derive from {@link Tile} because most of the methods there are tagged final, so they cannot be
* overridden. This class should only be used in the context of intensity correction.
*/
class IntensityTile {

final private int nSubTilesPerDimension;
final private int nFittingCycles;
final private List<Tile<? extends Affine1D<?>>> subTiles;

private double distance = 0;
private final Set<IntensityTile> connectedTiles = new HashSet<>();

/**
* Creates a new intensity tile with the specified number of sub-tiles per dimension and the number of fitting
* cycles to perform within one fit of the intensity tile.
* @param modelSupplier supplies instances of the model to use for the sub-tiles
* @param nSubTilesPerDimension the number of sub-tiles per side of the tile
* @param nFittingCycles the number of fitting cycles
*/
@SuppressWarnings({"unchecked", "rawtypes"})
public IntensityTile(
final Supplier<? extends Affine1D<?>> modelSupplier,
final int nSubTilesPerDimension,
final int nFittingCycles
) {
this.nSubTilesPerDimension = nSubTilesPerDimension;
this.nFittingCycles = nFittingCycles;
final int N = nSubTilesPerDimension * nSubTilesPerDimension;
this.subTiles = new ArrayList<>(N);

for (int i = 0; i < N; i++) {
final Affine1D<?> model = modelSupplier.get();
this.subTiles.add(new Tile<>((Model) model));
}
}

public Tile<? extends Affine1D<?>> getSubTile(final int i) {
return this.subTiles.get(i);
}

public Tile<? extends Affine1D<?>> getSubTile(final int i, final int j) {
return this.subTiles.get(i * this.nSubTilesPerDimension + j);
}

public int nSubTiles() {
return this.subTiles.size();
}

public double getDistance() {
return distance;
}

/**
* Updates the distance of this tile. The distance is the maximum distance of all sub-tiles.
*/
public void updateDistance() {
distance = 0;
for (final Tile<?> subTile : this.subTiles) {
subTile.updateCost();
distance = Math.max(distance, subTile.getDistance());
}
}

public Set<IntensityTile> getConnectedTiles() {
return connectedTiles;
}

/**
* Connects this tile to another tile. In contrast to the connect method of the Tile class, this method also
* connects the other tile to this tile.
* @param otherTile the tile to connect to (bidirectional connection)
*/
public void connectTo(final IntensityTile otherTile) {
connectedTiles.add(otherTile);
otherTile.connectedTiles.add(this);
}

/**
* Fits the model of all sub-tiles as often as specified by the nFittingCycles parameter. After fitting the model,
* the model is immediately applied to the sub-tile.
* @param damp the damping factor to apply to the model
* @throws NotEnoughDataPointsException if there are not enough data points to fit the model
* @throws IllDefinedDataPointsException if the data points are such that the model cannot be fitted
*/
public void fitAndApply(final double damp) throws NotEnoughDataPointsException, IllDefinedDataPointsException {
final List<Tile<? extends Affine1D<?>>> shuffledTiles = new ArrayList<>(this.subTiles);
for (int i = 0; i < nFittingCycles; i++) {
Collections.shuffle(shuffledTiles);
for (final Tile<? extends Affine1D<?>> subTile : shuffledTiles) {
subTile.fitModel();
subTile.apply(damp);
}
}
}

/**
* Applies the model of all sub-tiles.
*/
public void apply() {
for (final Tile<? extends Affine1D<?>> subTile : this.subTiles) {
subTile.apply();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
package org.janelia.render.client.newsolver.solvers.intensity;

import mpicbg.models.ErrorStatistic;
import mpicbg.models.IllDefinedDataPointsException;
import mpicbg.models.NotEnoughDataPointsException;
import mpicbg.models.Tile;
import mpicbg.models.TileConfiguration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.DoubleSummaryStatistics;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadPoolExecutor;


/**
* Concurrent optimizer for collections of {@link IntensityTile}s. This is basically a slightly modified implementation
* of {@link mpicbg.models.TileUtil#optimizeConcurrently(ErrorStatistic, double, int, int, double, TileConfiguration, Set, Set, int, boolean)},
* which is necessary because {@link IntensityTile} doesn't derive from {@link Tile}.
* Also, some methods of {@link mpicbg.models.TileConfiguration} are re-implemented here for the same reason.
* <p>
* Since an {@link IntensityTile} hides the sub-tiles it contains, this reduces the parallelization overhead of the
* optimizer, which would otherwise have to synchronize access to the (large number of) sub-tiles.
*/
class IntensityTileOptimizer {

private static final Logger LOG = LoggerFactory.getLogger(IntensityTileOptimizer.class);

private final double maxAllowedError;
private final int maxIterations;
private final int maxPlateauWidth;
private final double damp;
private final int nThreads;

public IntensityTileOptimizer(
final double maxAllowedError,
final int maxIterations,
final int maxPlateauWidth,
final double damp,
final int nThreads
) {
this.maxAllowedError = maxAllowedError;
this.maxIterations = maxIterations;
this.maxPlateauWidth = maxPlateauWidth;
this.damp = damp;
this.nThreads = nThreads;
}

public void optimize(
final List<IntensityTile> tiles,
final List<IntensityTile> fixedTiles
) {

final ThreadPoolExecutor executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(nThreads);
try {
final long t0 = System.currentTimeMillis();
final ErrorStatistic observer = new ErrorStatistic(maxIterations + 1);

final List<IntensityTile> freeTiles = new ArrayList<>(tiles);
freeTiles.removeAll(fixedTiles);
Collections.shuffle(freeTiles);

final long t1 = System.currentTimeMillis();
LOG.debug("Shuffling took {} ms", t1 - t0);

/* initialize the configuration with the current model of each tile */
applyAll(tiles, executor);

final long t2 = System.currentTimeMillis();
LOG.debug("First apply took {} ms", t2 - t1);

int i = 0;
boolean proceed = i < maxIterations;
final Set<IntensityTile> executingTiles = ConcurrentHashMap.newKeySet();

while (proceed) {
Collections.shuffle(freeTiles);
final Deque<IntensityTile> pending = new ConcurrentLinkedDeque<>(freeTiles);
final List<Future<Void>> tasks = new ArrayList<>(nThreads);

for (int j = 0; j < nThreads; j++) {
final boolean cleanUp = (j == 0);
tasks.add(executor.submit(() -> fitAndApplyWorker(pending, executingTiles, damp, cleanUp)));
}

for (final Future<Void> task : tasks) {
try {
task.get();
} catch (final InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
}

final double error = computeErrors(tiles, executor);
observer.add(error);

LOG.debug("{}: {} {}", i, error, observer.max);

if (i > maxPlateauWidth) {
proceed = error > maxAllowedError;

int d = maxPlateauWidth;
while (!proceed && d >= 1) {
try {
proceed = Math.abs(observer.getWideSlope(d)) > 0.0001;
} catch (final Exception e) {
LOG.warn("Error while computing slope: {}", e.getMessage());
}
d /= 2;
}
}

proceed &= ++i < maxIterations;
}

final long t3 = System.currentTimeMillis();
LOG.info("Concurrent tile optimization loop took {} ms, total took {} ms", t3 - t2, t3 - t0);

} finally {
executor.shutdownNow();
}
}

private static void applyAll(final List<IntensityTile> tiles, final ThreadPoolExecutor executor) {
final int nTiles = tiles.size();
final int nThreads = executor.getMaximumPoolSize();
final int tilesPerThread = nTiles / nThreads + (nTiles % nThreads == 0 ? 0 : 1);
final List<Future<Void>> applyTasks = new ArrayList<>(nThreads);

for (int j = 0; j < nThreads; j++) {
final int start = j * tilesPerThread;
final int end = Math.min((j + 1) * tilesPerThread, nTiles);
applyTasks.add(executor.submit(() -> applyToRange(tiles, start, end)));
}

for (final Future<Void> task : applyTasks) {
try {
task.get();
} catch (final InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
}
}

private static Void applyToRange(final List<IntensityTile> tiles, final int start, final int end) {
for (int i = start; i < end; i++) {
final IntensityTile t = tiles.get(i);
t.apply();
}
return null;
}

private static double computeErrors(final List<IntensityTile> tiles, final ThreadPoolExecutor executor) {
final int nTiles = tiles.size();
final int nThreads = executor.getMaximumPoolSize();
final int tilesPerThread = nTiles / nThreads + (nTiles % nThreads == 0 ? 0 : 1);
final List<Future<DoubleSummaryStatistics>> applyTasks = new ArrayList<>(nThreads);

for (int j = 0; j < nThreads; j++) {
final int start = j * tilesPerThread;
final int end = Math.min((j + 1) * tilesPerThread, nTiles);
applyTasks.add(executor.submit(() -> computeErrorsOfRange(tiles, start, end)));
}

final DoubleSummaryStatistics totalStats = new DoubleSummaryStatistics();
for (final Future<DoubleSummaryStatistics> task : applyTasks) {
try {
final DoubleSummaryStatistics taskStats = task.get();
totalStats.combine(taskStats);
} catch (final InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
}

return totalStats.getAverage();
}

private static DoubleSummaryStatistics computeErrorsOfRange(final List<IntensityTile> tiles, final int start, final int end) {
final DoubleSummaryStatistics stats = new DoubleSummaryStatistics();
for (int i = start; i < end; i++) {
final IntensityTile t = tiles.get(i);
t.updateDistance();
stats.accept(t.getDistance());
}
return stats;
}

private static Void fitAndApplyWorker(
final Deque<IntensityTile> pendingTiles,
final Set<IntensityTile> executingTiles,
final double damp,
final boolean cleanUp
) throws NotEnoughDataPointsException, IllDefinedDataPointsException {

final int n = pendingTiles.size();
for (int i = 0; (i < n) || cleanUp; i++){
// the polled tile can only be null if the deque is empty, i.e., there is no more work
final IntensityTile tile = pendingTiles.pollFirst();
if (tile == null)
return null;

executingTiles.add(tile);
final boolean canBeProcessed = Collections.disjoint(tile.getConnectedTiles(), executingTiles);

if (canBeProcessed) {
tile.fitAndApply(damp);
executingTiles.remove(tile);
} else {
executingTiles.remove(tile);
pendingTiles.addLast(tile);
}
}
return null;
}
}

0 comments on commit 8115508

Please sign in to comment.