-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add IntensityTile and corresponding optimizer
- Loading branch information
Showing
2 changed files
with
350 additions
and
0 deletions.
There are no files selected for viewing
126 changes: 126 additions & 0 deletions
126
...nt/src/main/java/org/janelia/render/client/newsolver/solvers/intensity/IntensityTile.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,126 @@ | ||
package org.janelia.render.client.newsolver.solvers.intensity; | ||
|
||
import mpicbg.models.Affine1D; | ||
import mpicbg.models.IllDefinedDataPointsException; | ||
import mpicbg.models.Model; | ||
import mpicbg.models.NotEnoughDataPointsException; | ||
import mpicbg.models.Tile; | ||
|
||
import java.util.ArrayList; | ||
import java.util.Collections; | ||
import java.util.HashSet; | ||
import java.util.List; | ||
import java.util.Set; | ||
import java.util.function.Supplier; | ||
|
||
|
||
/** | ||
* A tile that contains a grid of sub-tiles, each of which has a model that can be fitted and applied. This acts as a | ||
* convenience class for handling the fitting and applying of the models of the sub-tiles necessary for intensity | ||
* correction. The encapsulated sub-tiles also potentially speed up the optimization process by reducing the overhead | ||
* of parallelizing the optimization of the sub-tiles. | ||
* <p> | ||
* This class doesn't derive from {@link Tile} because most of the methods there are tagged final, so they cannot be | ||
* overridden. This class should only be used in the context of intensity correction. | ||
*/ | ||
class IntensityTile { | ||
|
||
final private int nSubTilesPerDimension; | ||
final private int nFittingCycles; | ||
final private List<Tile<? extends Affine1D<?>>> subTiles; | ||
|
||
private double distance = 0; | ||
private final Set<IntensityTile> connectedTiles = new HashSet<>(); | ||
|
||
/** | ||
* Creates a new intensity tile with the specified number of sub-tiles per dimension and the number of fitting | ||
* cycles to perform within one fit of the intensity tile. | ||
* @param modelSupplier supplies instances of the model to use for the sub-tiles | ||
* @param nSubTilesPerDimension the number of sub-tiles per side of the tile | ||
* @param nFittingCycles the number of fitting cycles | ||
*/ | ||
@SuppressWarnings({"unchecked", "rawtypes"}) | ||
public IntensityTile( | ||
final Supplier<? extends Affine1D<?>> modelSupplier, | ||
final int nSubTilesPerDimension, | ||
final int nFittingCycles | ||
) { | ||
this.nSubTilesPerDimension = nSubTilesPerDimension; | ||
this.nFittingCycles = nFittingCycles; | ||
final int N = nSubTilesPerDimension * nSubTilesPerDimension; | ||
this.subTiles = new ArrayList<>(N); | ||
|
||
for (int i = 0; i < N; i++) { | ||
final Affine1D<?> model = modelSupplier.get(); | ||
this.subTiles.add(new Tile<>((Model) model)); | ||
} | ||
} | ||
|
||
public Tile<? extends Affine1D<?>> getSubTile(final int i) { | ||
return this.subTiles.get(i); | ||
} | ||
|
||
public Tile<? extends Affine1D<?>> getSubTile(final int i, final int j) { | ||
return this.subTiles.get(i * this.nSubTilesPerDimension + j); | ||
} | ||
|
||
public int nSubTiles() { | ||
return this.subTiles.size(); | ||
} | ||
|
||
public double getDistance() { | ||
return distance; | ||
} | ||
|
||
/** | ||
* Updates the distance of this tile. The distance is the maximum distance of all sub-tiles. | ||
*/ | ||
public void updateDistance() { | ||
distance = 0; | ||
for (final Tile<?> subTile : this.subTiles) { | ||
subTile.updateCost(); | ||
distance = Math.max(distance, subTile.getDistance()); | ||
} | ||
} | ||
|
||
public Set<IntensityTile> getConnectedTiles() { | ||
return connectedTiles; | ||
} | ||
|
||
/** | ||
* Connects this tile to another tile. In contrast to the connect method of the Tile class, this method also | ||
* connects the other tile to this tile. | ||
* @param otherTile the tile to connect to (bidirectional connection) | ||
*/ | ||
public void connectTo(final IntensityTile otherTile) { | ||
connectedTiles.add(otherTile); | ||
otherTile.connectedTiles.add(this); | ||
} | ||
|
||
/** | ||
* Fits the model of all sub-tiles as often as specified by the nFittingCycles parameter. After fitting the model, | ||
* the model is immediately applied to the sub-tile. | ||
* @param damp the damping factor to apply to the model | ||
* @throws NotEnoughDataPointsException if there are not enough data points to fit the model | ||
* @throws IllDefinedDataPointsException if the data points are such that the model cannot be fitted | ||
*/ | ||
public void fitAndApply(final double damp) throws NotEnoughDataPointsException, IllDefinedDataPointsException { | ||
final List<Tile<? extends Affine1D<?>>> shuffledTiles = new ArrayList<>(this.subTiles); | ||
for (int i = 0; i < nFittingCycles; i++) { | ||
Collections.shuffle(shuffledTiles); | ||
for (final Tile<? extends Affine1D<?>> subTile : shuffledTiles) { | ||
subTile.fitModel(); | ||
subTile.apply(damp); | ||
} | ||
} | ||
} | ||
|
||
/** | ||
* Applies the model of all sub-tiles. | ||
*/ | ||
public void apply() { | ||
for (final Tile<? extends Affine1D<?>> subTile : this.subTiles) { | ||
subTile.apply(); | ||
} | ||
} | ||
} |
224 changes: 224 additions & 0 deletions
224
...in/java/org/janelia/render/client/newsolver/solvers/intensity/IntensityTileOptimizer.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,224 @@ | ||
package org.janelia.render.client.newsolver.solvers.intensity; | ||
|
||
import mpicbg.models.ErrorStatistic; | ||
import mpicbg.models.IllDefinedDataPointsException; | ||
import mpicbg.models.NotEnoughDataPointsException; | ||
import mpicbg.models.Tile; | ||
import mpicbg.models.TileConfiguration; | ||
import org.slf4j.Logger; | ||
import org.slf4j.LoggerFactory; | ||
|
||
import java.util.ArrayList; | ||
import java.util.Collections; | ||
import java.util.Deque; | ||
import java.util.DoubleSummaryStatistics; | ||
import java.util.List; | ||
import java.util.Set; | ||
import java.util.concurrent.ConcurrentHashMap; | ||
import java.util.concurrent.ConcurrentLinkedDeque; | ||
import java.util.concurrent.ExecutionException; | ||
import java.util.concurrent.Executors; | ||
import java.util.concurrent.Future; | ||
import java.util.concurrent.ThreadPoolExecutor; | ||
|
||
|
||
/** | ||
* Concurrent optimizer for collections of {@link IntensityTile}s. This is basically a slightly modified implementation | ||
* of {@link mpicbg.models.TileUtil#optimizeConcurrently(ErrorStatistic, double, int, int, double, TileConfiguration, Set, Set, int, boolean)}, | ||
* which is necessary because {@link IntensityTile} doesn't derive from {@link Tile}. | ||
* Also, some methods of {@link mpicbg.models.TileConfiguration} are re-implemented here for the same reason. | ||
* <p> | ||
* Since an {@link IntensityTile} hides the sub-tiles it contains, this reduces the parallelization overhead of the | ||
* optimizer, which would otherwise have to synchronize access to the (large number of) sub-tiles. | ||
*/ | ||
class IntensityTileOptimizer { | ||
|
||
private static final Logger LOG = LoggerFactory.getLogger(IntensityTileOptimizer.class); | ||
|
||
private final double maxAllowedError; | ||
private final int maxIterations; | ||
private final int maxPlateauWidth; | ||
private final double damp; | ||
private final int nThreads; | ||
|
||
public IntensityTileOptimizer( | ||
final double maxAllowedError, | ||
final int maxIterations, | ||
final int maxPlateauWidth, | ||
final double damp, | ||
final int nThreads | ||
) { | ||
this.maxAllowedError = maxAllowedError; | ||
this.maxIterations = maxIterations; | ||
this.maxPlateauWidth = maxPlateauWidth; | ||
this.damp = damp; | ||
this.nThreads = nThreads; | ||
} | ||
|
||
public void optimize( | ||
final List<IntensityTile> tiles, | ||
final List<IntensityTile> fixedTiles | ||
) { | ||
|
||
final ThreadPoolExecutor executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(nThreads); | ||
try { | ||
final long t0 = System.currentTimeMillis(); | ||
final ErrorStatistic observer = new ErrorStatistic(maxIterations + 1); | ||
|
||
final List<IntensityTile> freeTiles = new ArrayList<>(tiles); | ||
freeTiles.removeAll(fixedTiles); | ||
Collections.shuffle(freeTiles); | ||
|
||
final long t1 = System.currentTimeMillis(); | ||
LOG.debug("Shuffling took {} ms", t1 - t0); | ||
|
||
/* initialize the configuration with the current model of each tile */ | ||
applyAll(tiles, executor); | ||
|
||
final long t2 = System.currentTimeMillis(); | ||
LOG.debug("First apply took {} ms", t2 - t1); | ||
|
||
int i = 0; | ||
boolean proceed = i < maxIterations; | ||
final Set<IntensityTile> executingTiles = ConcurrentHashMap.newKeySet(); | ||
|
||
while (proceed) { | ||
Collections.shuffle(freeTiles); | ||
final Deque<IntensityTile> pending = new ConcurrentLinkedDeque<>(freeTiles); | ||
final List<Future<Void>> tasks = new ArrayList<>(nThreads); | ||
|
||
for (int j = 0; j < nThreads; j++) { | ||
final boolean cleanUp = (j == 0); | ||
tasks.add(executor.submit(() -> fitAndApplyWorker(pending, executingTiles, damp, cleanUp))); | ||
} | ||
|
||
for (final Future<Void> task : tasks) { | ||
try { | ||
task.get(); | ||
} catch (final InterruptedException | ExecutionException e) { | ||
throw new RuntimeException(e); | ||
} | ||
} | ||
|
||
final double error = computeErrors(tiles, executor); | ||
observer.add(error); | ||
|
||
LOG.debug("{}: {} {}", i, error, observer.max); | ||
|
||
if (i > maxPlateauWidth) { | ||
proceed = error > maxAllowedError; | ||
|
||
int d = maxPlateauWidth; | ||
while (!proceed && d >= 1) { | ||
try { | ||
proceed = Math.abs(observer.getWideSlope(d)) > 0.0001; | ||
} catch (final Exception e) { | ||
LOG.warn("Error while computing slope: {}", e.getMessage()); | ||
} | ||
d /= 2; | ||
} | ||
} | ||
|
||
proceed &= ++i < maxIterations; | ||
} | ||
|
||
final long t3 = System.currentTimeMillis(); | ||
LOG.info("Concurrent tile optimization loop took {} ms, total took {} ms", t3 - t2, t3 - t0); | ||
|
||
} finally { | ||
executor.shutdownNow(); | ||
} | ||
} | ||
|
||
private static void applyAll(final List<IntensityTile> tiles, final ThreadPoolExecutor executor) { | ||
final int nTiles = tiles.size(); | ||
final int nThreads = executor.getMaximumPoolSize(); | ||
final int tilesPerThread = nTiles / nThreads + (nTiles % nThreads == 0 ? 0 : 1); | ||
final List<Future<Void>> applyTasks = new ArrayList<>(nThreads); | ||
|
||
for (int j = 0; j < nThreads; j++) { | ||
final int start = j * tilesPerThread; | ||
final int end = Math.min((j + 1) * tilesPerThread, nTiles); | ||
applyTasks.add(executor.submit(() -> applyToRange(tiles, start, end))); | ||
} | ||
|
||
for (final Future<Void> task : applyTasks) { | ||
try { | ||
task.get(); | ||
} catch (final InterruptedException | ExecutionException e) { | ||
throw new RuntimeException(e); | ||
} | ||
} | ||
} | ||
|
||
private static Void applyToRange(final List<IntensityTile> tiles, final int start, final int end) { | ||
for (int i = start; i < end; i++) { | ||
final IntensityTile t = tiles.get(i); | ||
t.apply(); | ||
} | ||
return null; | ||
} | ||
|
||
private static double computeErrors(final List<IntensityTile> tiles, final ThreadPoolExecutor executor) { | ||
final int nTiles = tiles.size(); | ||
final int nThreads = executor.getMaximumPoolSize(); | ||
final int tilesPerThread = nTiles / nThreads + (nTiles % nThreads == 0 ? 0 : 1); | ||
final List<Future<DoubleSummaryStatistics>> applyTasks = new ArrayList<>(nThreads); | ||
|
||
for (int j = 0; j < nThreads; j++) { | ||
final int start = j * tilesPerThread; | ||
final int end = Math.min((j + 1) * tilesPerThread, nTiles); | ||
applyTasks.add(executor.submit(() -> computeErrorsOfRange(tiles, start, end))); | ||
} | ||
|
||
final DoubleSummaryStatistics totalStats = new DoubleSummaryStatistics(); | ||
for (final Future<DoubleSummaryStatistics> task : applyTasks) { | ||
try { | ||
final DoubleSummaryStatistics taskStats = task.get(); | ||
totalStats.combine(taskStats); | ||
} catch (final InterruptedException | ExecutionException e) { | ||
throw new RuntimeException(e); | ||
} | ||
} | ||
|
||
return totalStats.getAverage(); | ||
} | ||
|
||
private static DoubleSummaryStatistics computeErrorsOfRange(final List<IntensityTile> tiles, final int start, final int end) { | ||
final DoubleSummaryStatistics stats = new DoubleSummaryStatistics(); | ||
for (int i = start; i < end; i++) { | ||
final IntensityTile t = tiles.get(i); | ||
t.updateDistance(); | ||
stats.accept(t.getDistance()); | ||
} | ||
return stats; | ||
} | ||
|
||
private static Void fitAndApplyWorker( | ||
final Deque<IntensityTile> pendingTiles, | ||
final Set<IntensityTile> executingTiles, | ||
final double damp, | ||
final boolean cleanUp | ||
) throws NotEnoughDataPointsException, IllDefinedDataPointsException { | ||
|
||
final int n = pendingTiles.size(); | ||
for (int i = 0; (i < n) || cleanUp; i++){ | ||
// the polled tile can only be null if the deque is empty, i.e., there is no more work | ||
final IntensityTile tile = pendingTiles.pollFirst(); | ||
if (tile == null) | ||
return null; | ||
|
||
executingTiles.add(tile); | ||
final boolean canBeProcessed = Collections.disjoint(tile.getConnectedTiles(), executingTiles); | ||
|
||
if (canBeProcessed) { | ||
tile.fitAndApply(damp); | ||
executingTiles.remove(tile); | ||
} else { | ||
executingTiles.remove(tile); | ||
pendingTiles.addLast(tile); | ||
} | ||
} | ||
return null; | ||
} | ||
} |