Skip to content

Commit

Permalink
refactor!: use TaskExecutor instead of numThreads and ExecutorService
Browse files Browse the repository at this point in the history
  • Loading branch information
cmhulbert committed Sep 12, 2023
1 parent 89a1adc commit ed2ca94
Show file tree
Hide file tree
Showing 12 changed files with 167 additions and 235 deletions.
51 changes: 20 additions & 31 deletions src/main/java/bdv/fx/viewer/ViewerPanelFX.java
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import net.imglib2.RealLocalizable;
import net.imglib2.RealPoint;
import net.imglib2.RealPositionable;
import net.imglib2.parallel.TaskExecutor;
import net.imglib2.realtransform.AffineTransform3D;
import net.imglib2.util.Intervals;
import org.janelia.saalfeldlab.fx.ObservablePosition;
Expand All @@ -61,8 +62,6 @@
import java.util.Collection;
import java.util.List;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
Expand Down Expand Up @@ -91,8 +90,6 @@ public class ViewerPanelFX

private ThreadGroup threadGroup;

private final ExecutorService renderingExecutorService;

private final CopyOnWriteArrayList<TransformListener<AffineTransform3D>> transformListeners;

private final ViewerOptions.Values options;
Expand All @@ -105,9 +102,10 @@ public ViewerPanelFX(
final List<SourceAndConverter<?>> sources,
final int numTimePoints,
final CacheControl cacheControl,
final Function<Source<?>, Interpolation> interpolation) {
final Function<Source<?>, Interpolation> interpolation,
final TaskExecutor taskExecutor) {

this(sources, numTimePoints, cacheControl, ViewerOptions.options(), interpolation);
this(sources, numTimePoints, cacheControl, ViewerOptions.options(), interpolation, taskExecutor);
}

/**
Expand All @@ -120,9 +118,10 @@ public ViewerPanelFX(
public ViewerPanelFX(
final CacheControl cacheControl,
final ViewerOptions optional,
final Function<Source<?>, Interpolation> interpolation) {
final Function<Source<?>, Interpolation> interpolation,
final TaskExecutor taskExecutor) {

this(1, cacheControl, optional, interpolation);
this(1, cacheControl, optional, interpolation, taskExecutor);
}

/**
Expand All @@ -137,9 +136,10 @@ public ViewerPanelFX(
final int numTimepoints,
final CacheControl cacheControl,
final ViewerOptions optional,
final Function<Source<?>, Interpolation> interpolation) {
final Function<Source<?>, Interpolation> interpolation,
final TaskExecutor taskExecutor) {

this(new ArrayList<>(), numTimepoints, cacheControl, optional, interpolation);
this(new ArrayList<>(), numTimepoints, cacheControl, optional, interpolation, taskExecutor);
}

/**
Expand All @@ -156,11 +156,11 @@ public ViewerPanelFX(
final int numTimepoints,
final CacheControl cacheControl,
final ViewerOptions optional,
final Function<Source<?>, Interpolation> interpolation) {
final Function<Source<?>, Interpolation> interpolation,
final TaskExecutor taskExecutor) {

super();
super.getChildren().setAll(canvasPane, overlayPane);
this.renderingExecutorService = Executors.newFixedThreadPool(optional.values.getNumRenderingThreads(), new RenderThreadFactory());
options = optional.values;

threadGroup = new ThreadGroup(this.toString());
Expand All @@ -177,15 +177,14 @@ public ViewerPanelFX(
options.getAccumulateProjectorFactory(),
cacheControl,
options.getTargetRenderNanos(),
options.getNumRenderingThreads(),
renderingExecutorService
taskExecutor
);

setRenderedImageListener();
setWidth(options.getWidth());
setHeight(options.getHeight());
this.widthProperty().addListener((obs, oldv, newv) -> this.renderUnit.setDimensions((long)getWidth(), (long)getHeight()));
this.heightProperty().addListener((obs, oldv, newv) -> this.renderUnit.setDimensions((long)getWidth(), (long)getHeight()));
this.widthProperty().addListener((obs, oldv, newv) -> this.renderUnit.setDimensions((long) getWidth(), (long) getHeight()));
this.heightProperty().addListener((obs, oldv, newv) -> this.renderUnit.setDimensions((long) getWidth(), (long) getHeight()));

transformListeners.add(tf -> Paintera.whenPaintable(getDisplay()::drawOverlays));

Expand Down Expand Up @@ -217,7 +216,8 @@ public void setFocusable(boolean focusable) {
this.focusable = focusable;
}

@Override public void requestFocus() {
@Override
public void requestFocus() {

if (this.focusable) {
super.requestFocus();
Expand Down Expand Up @@ -324,8 +324,8 @@ public void getMouseCoordinates(final Positionable p) {

assert p.numDimensions() >= 2;
synchronized (mouseTracker) {
p.setPosition((long)mouseTracker.getMouseX(), 0);
p.setPosition((long)mouseTracker.getMouseY(), 1);
p.setPosition((long) mouseTracker.getMouseX(), 0);
p.setPosition((long) mouseTracker.getMouseY(), 1);
}
}

Expand Down Expand Up @@ -434,16 +434,9 @@ public void removeTransformListener(final TransformListener<AffineTransform3D> l
}
}

/**
* Shutdown the {@link ExecutorService} used for rendering tiles onto the screen.
*/
public void stop() {

renderingExecutorService.shutdown();
}

private static final AtomicInteger panelNumber = new AtomicInteger(1);

// TODO: rendering
protected class RenderThreadFactory implements ThreadFactory {

private final String threadNameFormat;
Expand Down Expand Up @@ -547,10 +540,6 @@ public double[] getScreenScales() {
*/
public OverlayPane<?> getDisplay() {

var pos = new ObservablePosition(0, 0);
pos.getX();
pos.setX(0.0);

return this.overlayPane;
}

Expand Down
132 changes: 50 additions & 82 deletions src/main/java/bdv/fx/viewer/project/VolatileHierarchyProjector.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@
import net.imglib2.converter.Converter;
import net.imglib2.img.array.ArrayImgs;
import net.imglib2.loops.LoopBuilder;
import net.imglib2.parallel.Parallelization;
import net.imglib2.parallel.TaskExecutor;
import net.imglib2.parallel.TaskExecutors;
import net.imglib2.type.numeric.integer.ByteType;
import net.imglib2.type.operators.SetZero;
import net.imglib2.util.Intervals;
Expand All @@ -54,7 +52,6 @@
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
Expand Down Expand Up @@ -120,15 +117,10 @@ public class VolatileHierarchyProjector<A extends Volatile<?>, B extends SetZero
*/
protected final IterableInterval<B> iterableTarget;

/**
* Number of threads to use for rendering
*/
private final int numThreads;

/**
* Executor service to be used for rendering
*/
private final ExecutorService executorService;
protected final TaskExecutor taskExecutor;

/**
* Time needed for rendering the last frame, in nano-seconds.
Expand Down Expand Up @@ -156,19 +148,17 @@ public VolatileHierarchyProjector(
final List<? extends RandomAccessible<A>> sources,
final Converter<? super A, B> converter,
final RandomAccessibleInterval<B> target,
final int numThreads,
final ExecutorService executorService) {
final TaskExecutor taskExecutor) {

this(sources, converter, target, ArrayImgs.bytes(Intervals.dimensionsAsLongArray(target)), numThreads, executorService);
this(sources, converter, target, ArrayImgs.bytes(Intervals.dimensionsAsLongArray(target)), taskExecutor);
}

public VolatileHierarchyProjector(
final List<? extends RandomAccessible<A>> sources,
final Converter<? super A, B> converter,
final RandomAccessibleInterval<B> target,
final RandomAccessibleInterval<ByteType> mask,
final int numThreads,
final ExecutorService executorService) {
final TaskExecutor taskExecutor) {

this.converter = converter;
this.target = target;
Expand All @@ -187,8 +177,7 @@ public VolatileHierarchyProjector(
max[1] = target.max(1);
sourceInterval = new FinalInterval(min, max);

this.numThreads = numThreads;
this.executorService = executorService;
this.taskExecutor = taskExecutor;

lastFrameRenderNanoTime = -1;
clearMask();
Expand Down Expand Up @@ -224,7 +213,7 @@ public boolean isValid() {
public void clearMask() {

try {
LoopBuilder.setImages(mask).multiThreaded().forEachPixel(val -> val.set(Byte.MAX_VALUE));
LoopBuilder.setImages(mask).multiThreaded(taskExecutor).forEachPixel(val -> val.set(Byte.MAX_VALUE));
} catch (RuntimeException e) {
if (!e.getMessage().contains("Interrupted")) {
throw e;
Expand All @@ -238,20 +227,28 @@ public void clearMask() {
*/
private void clearUntouchedTargetPixels() {

final int[] data = ProjectorUtils.getARGBArrayImgData(target);
if (data != null) {
final Cursor<ByteType> maskCursor = Views.iterable(mask).cursor();
final int size = (int) Intervals.numElements(target);
for (int i = 0; i < size; ++i) {
if (maskCursor.next().get() == Byte.MAX_VALUE)
data[i] = 0;
}
} else {
final Cursor<ByteType> maskCursor = Views.iterable(mask).cursor();
for (final B t : iterableTarget) {
if (maskCursor.next().get() == Byte.MAX_VALUE)
t.setZero();
}
if (true) return;
try {
taskExecutor.getExecutorService().invokeAll(List.of(Executors.callable(() -> {
//TODO: Rendering; use loopbuilder and multithreading; Also just do this during the map task?
final int[] data = ProjectorUtils.getARGBArrayImgData(target);
if (data != null) {
final Cursor<ByteType> maskCursor = Views.iterable(mask).cursor();
final int size = (int) Intervals.numElements(target);
for (int i = 0; i < size; ++i) {
if (maskCursor.next().get() == Byte.MAX_VALUE)
data[i] = 0;
}
} else {
final Cursor<ByteType> maskCursor = Views.iterable(mask).cursor();
for (final B t : iterableTarget) {
if (maskCursor.next().get() == Byte.MAX_VALUE)
t.setZero();
}
}
}, null)));
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}

Expand All @@ -267,53 +264,37 @@ public boolean map(final boolean clearUntouchedTargetPixels) {
final long startTimeIo = iostat.getIoNanoTime();
final long startTimeIoCumulative = iostat.getCumulativeIoNanoTime();

final int targetHeight = (int) target.dimension(1);
final int numTasks = 1; //numThreads <= 1 ? 1 : Math.min(numThreads * 10, targetHeight);
final double taskHeight = (double) targetHeight / numTasks;
final int[] taskStartHeights = new int[numTasks + 1];
for (int i = 0; i < numTasks; ++i) {
taskStartHeights[i] = (int) (i * taskHeight);
}
taskStartHeights[numTasks] = targetHeight;

valid = false;

final boolean createExecutor = (executorService == null);
final ExecutorService ex = createExecutor ? Executors.newFixedThreadPool(numThreads) : executorService;
int resolutionLevel;
try {
/*
* After the for loop, resolutionLevel is the highest (coarsest)
* resolution for which all pixels could be filled from valid data. This
* means that in the next pass, i.e., map() call, levels up to
* resolutionLevel have to be re-rendered.
*/
for (resolutionLevel = 0; resolutionLevel < numInvalidLevels && !valid; ++resolutionLevel) {
final List<Callable<Void>> tasks = new ArrayList<>(numTasks);
valid = true;
numInvalidPixels.set(0);
for (int i = 0; i < numTasks; ++i) {
tasks.add(createMapTask((byte) resolutionLevel, taskStartHeights[i], taskStartHeights[i + 1]));
}
try {
ex.invokeAll(tasks);
} catch (final InterruptedException e) {
Thread.currentThread().interrupt();
}
if (canceled.get())
return false;
/*
* After the for loop, resolutionLevel is the highest (coarsest)
* resolution for which all pixels could be filled from valid data. This
* means that in the next pass, i.e., map() call, levels up to
* resolutionLevel have to be re-rendered.
*/
for (resolutionLevel = 0; resolutionLevel < numInvalidLevels && !valid; ++resolutionLevel) {
final List<Callable<Void>> tasks = new ArrayList<>();
valid = true;
numInvalidPixels.set(0);
final byte idx = (byte) resolutionLevel;
tasks.add(Executors.callable(() -> map(idx, -1, -1), null));

try {
taskExecutor.getExecutorService().invokeAll(tasks);
} catch (final InterruptedException e) {
canceled.set(true);
}
} finally {
if (createExecutor)
ex.shutdown();
if (canceled.get())
return false;
}

if (clearUntouchedTargetPixels && !canceled.get())
clearUntouchedTargetPixels();

final long lastFrameTime = stopWatch.nanoTime();
lastFrameIoNanoTime = iostat.getIoNanoTime() - startTimeIo;
lastFrameRenderNanoTime = lastFrameTime - (iostat.getCumulativeIoNanoTime() - startTimeIoCumulative) / numThreads;
lastFrameRenderNanoTime = lastFrameTime - (iostat.getCumulativeIoNanoTime() - startTimeIoCumulative) / taskExecutor.getParallelism();

if (valid)
numInvalidLevels = resolutionLevel - 1;
Expand All @@ -322,15 +303,6 @@ public boolean map(final boolean clearUntouchedTargetPixels) {
return !canceled.get();
}

/**
* @return a {@code Callable} that runs
* {@code map(resolutionIndex, startHeight, endHeight)}
*/
private Callable<Void> createMapTask(final byte resolutionIndex, final int startHeight, final int endHeight) {

return Executors.callable(() -> map(resolutionIndex, startHeight, endHeight), null);
}

/**
* Copy lines from {@code y = startHeight} up to {@code endHeight}
* (exclusive) from source {@code resolutionIndex} to target. Check after
Expand All @@ -355,18 +327,14 @@ protected void map(final byte resolutionIndex, final int startHeight, final int

final AtomicInteger myNumInvalidPixels = new AtomicInteger();

final TaskExecutor projectorExecutor = TaskExecutors.fixedThreadPool(Math.max(1, (Runtime.getRuntime().availableProcessors() / 3) - 1));

LoopBuilder.setImages(
Views.interval(new BundleView<>(target), sourceInterval),
Views.interval(new BundleView<>(sources.get(resolutionIndex)), sourceInterval),
Views.interval(mask, sourceInterval)
).multiThreaded(projectorExecutor)
).multiThreaded(taskExecutor)
.forEachChunk(chunk -> {
if (canceled.get()) {
if (!projectorExecutor.getExecutorService().isShutdown()) {
projectorExecutor.getExecutorService().shutdown();
}
Thread.currentThread().interrupt();
return null;
}
chunk.forEachPixel((targetVal, sourceVal, maskVal) -> {
Expand Down
Loading

0 comments on commit ed2ca94

Please sign in to comment.