Skip to content

Commit

Permalink
Merge pull request #542 from saalfeldlab/1.3.3
Browse files Browse the repository at this point in the history
1.3.3
  • Loading branch information
cmhulbert authored Jun 24, 2024
2 parents 6c47668 + 9f14173 commit 78a086c
Show file tree
Hide file tree
Showing 8 changed files with 104 additions and 48 deletions.
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@

<!-- JavaFx Version-->
<javafx.version>22.0.1</javafx.version>
<saalfx.version>1.4.1</saalfx.version>
<saalfx.version>1.4.2</saalfx.version>
<testfx.version>4.0.16-alpha</testfx.version>

<alphanumeric-comparator.version>1.4.1</alphanumeric-comparator.version>
Expand Down
23 changes: 23 additions & 0 deletions src/main/kotlin/org/janelia/saalfeldlab/paintera/BindingKeys.kt
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,29 @@ enum class LabelSourceStateKeys(lateInitNamedKeyCombo : LateInitNamedKeyCombinat
}
}

enum class RawSourceStateKeys(lateInitNamedKeyCombo : LateInitNamedKeyCombination) : NamedKeyBinding by lateInitNamedKeyCombo {
RESET_MIN_MAX_INTENSITY_THRESHOLD ( SHIFT_DOWN + Y, "Reset Min / Max Intensity Threshold"),
AUTO_MIN_MAX_INTENSITY_THRESHOLD ( Y, "Auto Min / Max Intensity Threshold"),
;


private val formattedName = name.lowercase()
.replace("__", ": ")
.replace("_", " ")

constructor(keys : KeyCombination, name : String? = null) : this(LateInitNamedKeyCombination(keys, name))
constructor(key : KeyCode, name : String? = null) : this(LateInitNamedKeyCombination(key.asCombination(), name))
constructor(key : Modifier, name : String? = null) : this(LateInitNamedKeyCombination(key.asCombination(), name))

init {
lateInitNamedKeyCombo.setName(formattedName)
}

companion object {
fun namedCombinationsCopy() = NamedKeyCombination.CombinationMap(*entries.map { it.deepCopy }.toTypedArray())
}
}

object NavigationKeys {
const val BUTTON_TRANSLATE_ALONG_NORMAL_FORWARD = "translate along normal forward"
const val BUTTON_TRANSLATE_ALONG_NORMAL_FORWARD_FAST = "translate along normal forward fast"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import org.janelia.saalfeldlab.fx.actions.painteraActionSet
import org.janelia.saalfeldlab.fx.ortho.OrthogonalViews
import org.janelia.saalfeldlab.fx.ui.ScaleView
import org.janelia.saalfeldlab.net.imglib2.converter.ARGBColorConverter
import org.janelia.saalfeldlab.paintera.RawSourceStateKeys
import org.janelia.saalfeldlab.paintera.control.actions.AllowedActions
import org.janelia.saalfeldlab.paintera.control.tools.Tool
import org.janelia.saalfeldlab.paintera.paintera
Expand All @@ -46,14 +47,14 @@ object RawSourceMode : AbstractToolMode() {

private val minMaxIntensityThreshold = painteraActionSet("Min/Max Intensity Threshold") {
verifyAll(KEY_PRESSED, "Source State is Raw Source State ") { activeSourceStateProperty.get() is ConnectomicsRawState<*, *> }
KEY_PRESSED(KeyCode.SHIFT, KeyCode.Y) {
KEY_PRESSED(RawSourceStateKeys.RESET_MIN_MAX_INTENSITY_THRESHOLD) {
graphic = { ScaleView().apply { styleClass += "intensity-reset-min-max" } }
onAction {
val rawSource = activeSourceStateProperty.get() as ConnectomicsRawState<*, *>
resetIntensityMinMax(rawSource)
}
}
KEY_PRESSED(KeyCode.Y) {
KEY_PRESSED(RawSourceStateKeys.AUTO_MIN_MAX_INTENSITY_THRESHOLD) {
lateinit var viewer: ViewerPanelFX
graphic = { ScaleView().apply { styleClass += "intensity-auto-min-max" } }
verify("Last focused viewer found") { paintera.baseView.lastFocusHolder.value?.viewer()?.also { viewer = it } != null }
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package org.janelia.saalfeldlab.paintera.control.modes

import bdv.fx.viewer.render.RenderUnitState
import bdv.util.BdvFunctions
import de.jensd.fx.glyphs.fontawesome.FontAwesomeIconView
import io.github.oshai.kotlinlogging.KotlinLogging
import javafx.beans.value.ChangeListener
Expand All @@ -18,39 +17,45 @@ import net.imglib2.loops.LoopBuilder
import net.imglib2.realtransform.AffineTransform3D
import net.imglib2.type.logic.BoolType
import net.imglib2.type.numeric.IntegerType
import net.imglib2.type.numeric.integer.UnsignedIntType
import net.imglib2.type.numeric.integer.UnsignedLongType
import net.imglib2.type.volatiles.VolatileUnsignedIntType
import net.imglib2.util.Intervals
import net.imglib2.view.IntervalView
import net.imglib2.view.Views
import org.janelia.saalfeldlab.control.mcu.MCUButtonControl
import org.janelia.saalfeldlab.fx.actions.*
import org.janelia.saalfeldlab.fx.actions.ActionSet
import org.janelia.saalfeldlab.fx.actions.ActionSet.Companion.installActionSet
import org.janelia.saalfeldlab.fx.actions.ActionSet.Companion.removeActionSet
import org.janelia.saalfeldlab.fx.actions.NamedKeyBinding
import org.janelia.saalfeldlab.fx.actions.painteraActionSet
import org.janelia.saalfeldlab.fx.actions.painteraMidiActionSet
import org.janelia.saalfeldlab.fx.extensions.addTriggeredWithListener
import org.janelia.saalfeldlab.fx.midi.MidiButtonEvent
import org.janelia.saalfeldlab.fx.midi.MidiToggleEvent
import org.janelia.saalfeldlab.fx.midi.ToggleAction
import org.janelia.saalfeldlab.fx.ortho.OrthogonalViews
import org.janelia.saalfeldlab.fx.ui.GlyphScaleView
import org.janelia.saalfeldlab.fx.util.InvokeOnJavaFXApplicationThread
import org.janelia.saalfeldlab.labels.Label
import org.janelia.saalfeldlab.net.imglib2.view.BundleView
import org.janelia.saalfeldlab.paintera.DeviceManager
import org.janelia.saalfeldlab.paintera.LabelSourceStateKeys.*
import org.janelia.saalfeldlab.paintera.cache.HashableTransform.Companion.hashable
import org.janelia.saalfeldlab.paintera.cache.SamEmbeddingLoaderCache
import org.janelia.saalfeldlab.paintera.cache.SamEmbeddingLoaderCache.calculateTargetSamScreenScaleFactor
import org.janelia.saalfeldlab.paintera.control.ShapeInterpolationController
import org.janelia.saalfeldlab.paintera.control.ShapeInterpolationController.ControllerState.Moving
import org.janelia.saalfeldlab.paintera.control.ShapeInterpolationController.EditSelectionChoice
import org.janelia.saalfeldlab.paintera.control.actions.AllowedActions
import org.janelia.saalfeldlab.paintera.control.actions.MenuActionType
import org.janelia.saalfeldlab.paintera.control.actions.NavigationActionType
import org.janelia.saalfeldlab.paintera.control.actions.PaintActionType
import org.janelia.saalfeldlab.paintera.cache.SamEmbeddingLoaderCache
import org.janelia.saalfeldlab.paintera.cache.SamEmbeddingLoaderCache.calculateTargetSamScreenScaleFactor
import org.janelia.saalfeldlab.paintera.control.paint.ViewerMask
import org.janelia.saalfeldlab.paintera.control.paint.ViewerMask.Companion.createViewerMask
import org.janelia.saalfeldlab.paintera.control.tools.Tool
import org.janelia.saalfeldlab.paintera.control.tools.paint.*
import org.janelia.saalfeldlab.paintera.control.tools.paint.Fill2DTool
import org.janelia.saalfeldlab.paintera.control.tools.paint.PaintBrushTool
import org.janelia.saalfeldlab.paintera.control.tools.paint.SamPredictor
import org.janelia.saalfeldlab.paintera.control.tools.paint.SamTool
import org.janelia.saalfeldlab.paintera.control.tools.shapeinterpolation.ShapeInterpolationFillTool
import org.janelia.saalfeldlab.paintera.control.tools.shapeinterpolation.ShapeInterpolationPaintBrushTool
import org.janelia.saalfeldlab.paintera.control.tools.shapeinterpolation.ShapeInterpolationSAMTool
Expand All @@ -59,8 +64,6 @@ import org.janelia.saalfeldlab.paintera.data.mask.MaskInfo
import org.janelia.saalfeldlab.paintera.data.mask.MaskedSource
import org.janelia.saalfeldlab.paintera.paintera
import org.janelia.saalfeldlab.util.*
import org.janelia.saalfeldlab.net.imglib2.view.BundleView
import kotlin.collections.forEach
import kotlin.collections.set

class ShapeInterpolationMode<D : IntegerType<D>>(val controller: ShapeInterpolationController<D>, private val previousMode: ControlMode) : AbstractToolMode() {
Expand Down Expand Up @@ -165,7 +168,7 @@ class ShapeInterpolationMode<D : IntegerType<D>>(val controller: ShapeInterpolat
with(controller) {
verifyAll(KEY_PRESSED, "Shape Interpolation Controller is Active ") { isControllerActive }
verifyAll(Event.ANY, "Shape Interpolation Tool is Active") { activeTool is ShapeInterpolationTool }
val exitMode = { _ : Event? ->
val exitMode = { _: Event? ->
exitShapeInterpolation(false)
paintera.baseView.changeMode(previousMode)
}
Expand All @@ -190,7 +193,12 @@ class ShapeInterpolationMode<D : IntegerType<D>>(val controller: ShapeInterpolat
filter = true
verify("Fill2DTool is active") { activeTool is Fill2DTool }
onAction {
switchTool(shapeInterpolationTool)
fill2DTool.fillIsRunningProperty.addTriggeredWithListener { obs, _, isRunning ->
if (!isRunning) {
switchTool(shapeInterpolationTool)
obs?.removeListener(this)
}
}
}
}
},
Expand All @@ -202,7 +210,7 @@ class ShapeInterpolationMode<D : IntegerType<D>>(val controller: ShapeInterpolat
}
KEY_PRESSED(fill2DTool.keyTrigger) {
name = "switch to fill2d tool"
verify("Active source is MaskedSource") { activeSourceStateProperty.get()?.dataSource is MaskedSource<*, *> }
verify("Active source is MaskedSource") { activeSourceStateProperty.get()?.dataSource is MaskedSource<*, *> }
onAction { switchTool(fill2DTool) }
}
KEY_PRESSED(samTool.keyTrigger) {
Expand Down Expand Up @@ -343,7 +351,7 @@ class ShapeInterpolationMode<D : IntegerType<D>>(val controller: ShapeInterpolat

private fun ActionSet.keyPressEditSelectionAction(choice: EditSelectionChoice, namedKey: NamedKeyBinding) =
with(controller) {
KEY_PRESSED ( namedKey) {
KEY_PRESSED(namedKey) {
graphic = when (choice) {
EditSelectionChoice.First -> {
{ GlyphScaleView(FontAwesomeIconView().also { it.styleClass += "interpolation-first-slice" }) }
Expand Down Expand Up @@ -398,7 +406,7 @@ class ShapeInterpolationMode<D : IntegerType<D>>(val controller: ShapeInterpolat
.toList()

val renderState = RenderUnitState(mask.initialGlobalToViewerTransform.copy(), mask.info.time, sources, width.toLong(), height.toLong())
val predictionRequest = SamPredictor.SparsePrediction(maxDistancePositions.map { (x,y) -> renderState.getSamPoint(x,y, SamPredictor.SparseLabel.IN) })
val predictionRequest = SamPredictor.SparsePrediction(maxDistancePositions.map { (x, y) -> renderState.getSamPoint(x, y, SamPredictor.SparseLabel.IN) })

SamSliceInfo(renderState, mask, predictionRequest, null, false).also {
SamEmbeddingLoaderCache.load(renderState)
Expand Down Expand Up @@ -465,7 +473,7 @@ class ShapeInterpolationMode<D : IntegerType<D>>(val controller: ShapeInterpolat
selectionIntervalOverMask: Interval,
globalTransform: AffineTransform3D = paintera.baseView.manager().transform,
viewerMask: ViewerMask = controller.currentViewerMask!!,
replaceExistingSlice : Boolean = false
replaceExistingSlice: Boolean = false
): SamSliceInfo? {
val globalToViewerTransform = viewerMask.initialGlobalToMaskTransform
val sliceDepth = controller.depthAt(globalToViewerTransform)
Expand Down Expand Up @@ -527,7 +535,7 @@ internal fun IntervalView<UnsignedLongType>.getComponentMaxDistancePosition(): L

DistanceTransform.binaryTransform(invertedBinaryImg, distances, DistanceTransform.DISTANCE_TYPE.EUCLIDIAN)

val distancePerComponent = mutableMapOf<Int, Pair<Double, LongArray>>()
val distancePerComponent = mutableMapOf<Int, Pair<Double, LongArray>>()

var backgroundId = -1;

Expand Down Expand Up @@ -601,7 +609,7 @@ internal data class SamSliceInfo(val renderState: RenderUnitState, val mask: Vie
prediction = SamPredictor.SparsePrediction(listOf(renderState.getSamPoint(viewerX, viewerY, label)))
}

fun updatePrediction(viewerPositions : List<DoubleArray>, label: SamPredictor.SparseLabel = SamPredictor.SparseLabel.IN) {
prediction = SamPredictor.SparsePrediction(viewerPositions.map { (x,y) -> renderState.getSamPoint(x,y, label) })
fun updatePrediction(viewerPositions: List<DoubleArray>, label: SamPredictor.SparseLabel = SamPredictor.SparseLabel.IN) {
prediction = SamPredictor.SparsePrediction(viewerPositions.map { (x, y) -> renderState.getSamPoint(x, y, label) })
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import javafx.beans.property.SimpleObjectProperty
import javafx.beans.value.ObservableValue
import javafx.scene.Cursor
import javafx.scene.input.*
import javafx.util.Subscription
import net.imglib2.Interval
import net.imglib2.util.Intervals
import org.janelia.saalfeldlab.fx.UtilityTask
Expand Down Expand Up @@ -110,14 +111,16 @@ open class Fill2DTool(activeSourceStateProperty: SimpleObjectProperty<SourceStat
graphic = { GlyphScaleView(FontAwesomeIconView().apply { styleClass += "reject" }).apply { styleClass += "ignore-disable"} }
filter = true
onAction {
fillTask?.run { if (!isCancelled) cancel() } ?: mode?.switchTool(mode.defaultTool)
cancelFloodFill() ?: mode?.switchTool(mode.defaultTool)
}
}
}
)
}

private val fillIsRunningProperty = SimpleBooleanProperty(false, "Fill2D is Running")
fun cancelFloodFill() = fillTask?.run { if (!isCancelled) cancel() }

val fillIsRunningProperty = SimpleBooleanProperty(false, "Fill2D is Running")

internal fun executeFill2DAction(x: Double, y: Double, afterFill: (Interval) -> Unit = {}): UtilityTask<*>? {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import javafx.scene.layout.Pane
import javafx.scene.shape.Circle
import javafx.scene.shape.Rectangle
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import net.imglib2.FinalInterval
import net.imglib2.Interval
import net.imglib2.RandomAccessibleInterval
Expand Down Expand Up @@ -92,7 +93,6 @@ import org.janelia.saalfeldlab.paintera.util.IntervalHelpers.Companion.smallestC
import org.janelia.saalfeldlab.paintera.util.algorithms.otsuThresholdPrediction
import org.janelia.saalfeldlab.util.*
import java.util.concurrent.CancellationException
import java.util.concurrent.LinkedBlockingQueue
import kotlin.collections.List
import kotlin.collections.MutableList
import kotlin.collections.addAll
Expand Down Expand Up @@ -204,7 +204,7 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty<SourceState<*
originalWritableVolatileBackingImage = field?.volatileViewerImg?.writableSource
}

private var predictionJob : Job = Job().apply { complete() }
private var predictionJob: Job = Job().apply { complete() }

internal val lastPredictionProperty = SimpleObjectProperty<SamTaskInfo?>(null)
var lastPrediction by lastPredictionProperty.nullable()
Expand Down Expand Up @@ -232,13 +232,13 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty<SourceState<*

private var screenScale by Delegates.notNull<Double>()

private val predictionQueue = LinkedBlockingQueue<Pair<SamPredictor.PredictionRequest, Boolean>>(1)
private val predictionChannel = Channel<Pair<SamPredictor.PredictionRequest, Boolean>>(1)

private var currentPredictionRequest: Pair<SamPredictor.PredictionRequest, Boolean>? = null
set(value) = synchronized(predictionQueue) {
predictionQueue.clear()
set(value) = runBlocking {
predictionChannel.tryReceive() /* capacity 1, so this will always either do nothing, or empty the channel */
value?.let { (request, _) ->
predictionQueue.put(value)
predictionChannel.send(value)
if (!temporaryPrompt)
request.drawPrompt()
}
Expand Down Expand Up @@ -297,6 +297,7 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty<SourceState<*
clearPromptDrawings()
currentLabelToPaint = Label.INVALID
predictionJob.cancel()
predictionChannel.tryReceive() /*clear the channel if not empty */
if (unwrapResult) {
if (!maskProvided) {
maskedSource?.resetMasks()
Expand Down Expand Up @@ -560,7 +561,7 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty<SourceState<*
KEY_PRESSED(CANCEL) {
name = "exit SAM tool"
graphic = { GlyphScaleView(FontAwesomeIconView().apply { styleClass += "reject" }).apply { styleClass += "ignore-disable" } }
onAction { cancelAction() }
onAction { mode?.switchTool(mode.defaultTool) }
}
},

Expand Down Expand Up @@ -651,10 +652,6 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty<SourceState<*
)
}

protected open fun cancelAction() {
mode?.switchTool(mode.defaultTool)
}

private fun resetPromptAndPrediction() {
clearPromptDrawings()
currentPredictionRequest = null
Expand Down Expand Up @@ -797,7 +794,7 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty<SourceState<*
maskedSource.applyMask(currentMask, sourceInterval.smallestContainingInterval, MaskedSource.VALID_LABEL_CHECK)
viewerMask = null
} else {
val predictionMaxInterval = originalWritableBackingImage!!.intersect( maskInterval)
val predictionMaxInterval = originalWritableBackingImage!!.intersect(maskInterval)
LoopBuilder
.setImages(originalWritableBackingImage!!.interval(predictionMaxInterval), currentMask.viewerImg.wrappedSource.interval(predictionMaxInterval))
.multiThreaded()
Expand Down Expand Up @@ -829,6 +826,9 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty<SourceState<*
}

open fun requestPrediction(promptPoints: List<SamPoint>, estimateThreshold: Boolean = true) {
if (promptPoints.isEmpty())
temporaryPrompt = true

if (!predictionJob.isActive) {
startPredictionJob()
}
Expand All @@ -845,9 +845,24 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty<SourceState<*
private var embeddingRequest: Deferred<OnnxTensor>? = null

private var currentPrediction: SamPredictor.SamPrediction? = null

private val resetSAMTaskOnException = CoroutineExceptionHandler { _, exception ->
LOG.error(exception) { "Error during SAM Prediction " }
isBusy = false
deactivate()
mode?.apply {
InvokeOnJavaFXApplicationThread {
switchTool(defaultTool)
}
}
SAM_TASK_SCOPE = CoroutineScope(Dispatchers.IO + Job())
}

protected open var currentDisplay = false

private fun startPredictionJob() {
val maskSource = maskedSource ?: return
predictionJob = SAM_TASK_SCOPE.launch {
predictionJob = SAM_TASK_SCOPE.launch(resetSAMTaskOnException) {
val session = createOrtSessionTask.get()
val imageEmbedding = try {
runBlocking {
Expand All @@ -863,9 +878,8 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty<SourceState<*
isBusy = false
}
val predictor = SamPredictor(ortEnv, session, imageEmbedding, imgWidth to imgHeight)
while (coroutineContext.isActive) {
val predictionPair = predictionQueue.take()
val (predictionRequest, estimateThreshold) = predictionPair
while (predictionJob.isActive) {
val (predictionRequest, estimateThreshold) = predictionChannel.receive()
val points = (predictionRequest as SparsePrediction).points

val newPredictionRequest = estimateThreshold || currentPrediction == null
Expand Down Expand Up @@ -929,7 +943,7 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty<SourceState<*
} catch (e: InterruptedException) {
System.setErr(stdErr)
LOG.debug(e) { "Connected Components Interrupted During SAM" }
cancel("Connected Components Interrupted During SAM" )
cancel("Connected Components Interrupted During SAM")
continue
} finally {
System.setErr(stdErr)
Expand Down Expand Up @@ -1134,7 +1148,7 @@ open class SamTool(activeSourceStateProperty: SimpleObjectProperty<SourceState<*

private val LOG = KotlinLogging.logger { }

internal val SAM_TASK_SCOPE = CoroutineScope(Dispatchers.IO + Job())
private var SAM_TASK_SCOPE = CoroutineScope(Dispatchers.IO + Job())


private fun calculateTargetScreenScaleFactor(viewer: ViewerPanelFX): Double {
Expand Down
Loading

0 comments on commit 78a086c

Please sign in to comment.