Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow supplying custom workflow yaml for infer_with_model job #7902

Merged
merged 14 commits into from
Aug 5, 2024
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions app/controllers/AiModelController.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ case class RunInferenceParameters(annotationId: Option[ObjectId],
datasetName: String,
colorLayerName: String,
boundingBox: String,
newSegmentationLayerName: String,
newDatasetName: String,
maskAnnotationLayerName: Option[String])
maskAnnotationLayerName: Option[String],
workflowYaml: Option[String])

object RunInferenceParameters {
implicit val jsonFormat: OFormat[RunInferenceParameters] = Json.format[RunInferenceParameters]
Expand Down Expand Up @@ -168,7 +168,6 @@ class AiModelController @Inject()(
_ <- aiModelDAO.findOne(request.body.aiModelId) ?~> "aiModel.notFound"
_ <- datasetService.assertValidDatasetName(request.body.newDatasetName)
_ <- datasetService.assertNewDatasetName(request.body.newDatasetName, organization._id)
_ <- datasetService.assertValidLayerNameLax(request.body.newSegmentationLayerName)
jobCommand = JobCommand.infer_with_model
boundingBox <- BoundingBox.fromLiteral(request.body.boundingBox).toFox
commandArgs = Json.obj(
Expand All @@ -177,8 +176,8 @@ class AiModelController @Inject()(
"color_layer_name" -> request.body.colorLayerName,
"bounding_box" -> boundingBox.toLiteral,
"model_id" -> request.body.aiModelId,
"new_segmentation_layer_name" -> request.body.newSegmentationLayerName,
"new_dataset_name" -> request.body.newDatasetName
"new_dataset_name" -> request.body.newDatasetName,
"workflow_yaml" -> request.body.workflowYaml
)
newInferenceJob <- jobService.submitJob(jobCommand, commandArgs, request.identity, dataStore.name) ?~> "job.couldNotRunInferWithModel"
newAiInference = AiInference(
Expand All @@ -189,7 +188,7 @@ class AiModelController @Inject()(
_annotation = request.body.annotationId,
boundingBox = boundingBox,
_inferenceJob = newInferenceJob._id,
newSegmentationLayerName = request.body.newSegmentationLayerName,
newSegmentationLayerName = "segmentation",
maskAnnotationLayerName = request.body.maskAnnotationLayerName
)
_ <- aiInferenceDAO.insertOne(newAiInference)
Expand Down
6 changes: 0 additions & 6 deletions app/controllers/JobController.scala
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,6 @@ class JobController @Inject()(
datasetName: String,
layerName: String,
bbox: String,
outputSegmentationLayerName: String,
newDatasetName: String): Action[AnyContent] =
sil.SecuredAction.async { implicit request =>
log(Some(slackNotificationService.noticeFailedJobRequest)) {
Expand All @@ -234,7 +233,6 @@ class JobController @Inject()(
"dataset.notFound",
datasetName) ~> NOT_FOUND
_ <- datasetService.assertValidDatasetName(newDatasetName)
_ <- datasetService.assertValidLayerNameLax(outputSegmentationLayerName)
_ <- datasetService.assertValidLayerNameLax(layerName)
multiUser <- multiUserDAO.findOne(request.identity._multiUser)
_ <- Fox.runIf(!multiUser.isSuperUser)(jobService.assertBoundingBoxLimits(bbox, None))
Expand All @@ -244,7 +242,6 @@ class JobController @Inject()(
"dataset_name" -> datasetName,
"new_dataset_name" -> newDatasetName,
"layer_name" -> layerName,
"output_segmentation_layer_name" -> outputSegmentationLayerName,
"bbox" -> bbox,
)
job <- jobService.submitJob(command, commandArgs, request.identity, dataset._dataStore) ?~> "job.couldNotRunNeuronInferral"
Expand All @@ -257,7 +254,6 @@ class JobController @Inject()(
datasetName: String,
layerName: String,
bbox: String,
outputSegmentationLayerName: String,
newDatasetName: String): Action[AnyContent] =
sil.SecuredAction.async { implicit request =>
log(Some(slackNotificationService.noticeFailedJobRequest)) {
Expand All @@ -269,7 +265,6 @@ class JobController @Inject()(
"dataset.notFound",
datasetName) ~> NOT_FOUND
_ <- datasetService.assertValidDatasetName(newDatasetName)
_ <- datasetService.assertValidLayerNameLax(outputSegmentationLayerName)
_ <- datasetService.assertValidLayerNameLax(layerName)
multiUser <- multiUserDAO.findOne(request.identity._multiUser)
_ <- bool2Fox(multiUser.isSuperUser) ?~> "job.inferMitochondria.notAllowed.onlySuperUsers"
Expand All @@ -280,7 +275,6 @@ class JobController @Inject()(
"dataset_name" -> datasetName,
"new_dataset_name" -> newDatasetName,
"layer_name" -> layerName,
"output_segmentation_layer_name" -> outputSegmentationLayerName,
"bbox" -> bbox,
)
job <- jobService.submitJob(command, commandArgs, request.identity, dataset._dataStore) ?~> "job.couldNotRunInferMitochondria"
Expand Down
4 changes: 2 additions & 2 deletions conf/webknossos.latest.routes
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,8 @@ POST /jobs/run/computeMeshFile/:organizationName/:datasetName
POST /jobs/run/computeSegmentIndexFile/:organizationName/:datasetName controllers.JobController.runComputeSegmentIndexFileJob(organizationName: String, datasetName: String, layerName: String)
POST /jobs/run/exportTiff/:organizationName/:datasetName controllers.JobController.runExportTiffJob(organizationName: String, datasetName: String, bbox: String, layerName: Option[String], mag: Option[String], annotationLayerName: Option[String], annotationId: Option[String], asOmeTiff: Boolean)
POST /jobs/run/inferNuclei/:organizationName/:datasetName controllers.JobController.runInferNucleiJob(organizationName: String, datasetName: String, layerName: String, newDatasetName: String)
POST /jobs/run/inferNeurons/:organizationName/:datasetName controllers.JobController.runInferNeuronsJob(organizationName: String, datasetName: String, layerName: String, bbox: String, outputSegmentationLayerName: String, newDatasetName: String)
POST /jobs/run/inferMitochondria/:organizationName/:datasetName controllers.JobController.runInferMitochondriaJob(organizationName: String, datasetName: String, layerName: String, bbox: String, outputSegmentationLayerName: String, newDatasetName: String)
POST /jobs/run/inferNeurons/:organizationName/:datasetName controllers.JobController.runInferNeuronsJob(organizationName: String, datasetName: String, layerName: String, bbox: String, newDatasetName: String)
POST /jobs/run/inferMitochondria/:organizationName/:datasetName controllers.JobController.runInferMitochondriaJob(organizationName: String, datasetName: String, layerName: String, bbox: String, newDatasetName: String)
POST /jobs/run/alignSections/:organizationName/:datasetName controllers.JobController.runAlignSectionsJob(organizationName: String, datasetName: String, layerName: String, newDatasetName: String)
POST /jobs/run/materializeVolumeAnnotation/:organizationName/:datasetName controllers.JobController.runMaterializeVolumeAnnotationJob(organizationName: String, datasetName: String, fallbackLayerName: String, annotationId: String, annotationType: String, newDatasetName: String, outputSegmentationLayerName: String, mergeSegments: Boolean, volumeLayerName: Option[String])
POST /jobs/run/findLargestSegmentId/:organizationName/:datasetName controllers.JobController.runFindLargestSegmentIdJob(organizationName: String, datasetName: String, layerName: String)
Expand Down
1 change: 1 addition & 0 deletions frontend/javascripts/admin/api/jobs.ts
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ type RunInferenceParameters = {
boundingBox: Vector6;
newSegmentationLayerName: string;
newDatasetName: string;
workflowYaml?: string;
// maskAnnotationLayerName?: string | null
};

Expand Down
24 changes: 24 additions & 0 deletions frontend/javascripts/oxalis/view/action-bar/default-workflow.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
export default `predict:
MichaelBuessemeyer marked this conversation as resolved.
Show resolved Hide resolved
task: PredictTask
distribution:
default:
processes: 2
inputs:
model: TO_BE_SET_BY_WORKER
config:
name: predict
datasource_config: TO_BE_SET_BY_WORKER
# your additional config keys here

# your additional tasks here

publish_dataset_meshes:
task: PublishDatasetTask
inputs:
dataset: # your dataset here
config:
name: TO_BE_SET_BY_WORKER
public_directory: TO_BE_SET_BY_WORKER
webknossos_organization: TO_BE_SET_BY_WORKER
use_symlinks: False
move_dataset_symlink_artifact: True`;
Original file line number Diff line number Diff line change
Expand Up @@ -55,10 +55,11 @@ import { isBoundingBoxExportable } from "./download_modal_view";
import features from "features";
import { setAIJobModalStateAction } from "oxalis/model/actions/ui_actions";
import { InfoCircleOutlined } from "@ant-design/icons";
import { TrainAiModelTab } from "../jobs/train_ai_model";
import { TrainAiModelTab, CollapsibleWorkflowYamlEditor } from "../jobs/train_ai_model";
import { LayerSelectionFormItem } from "components/layer_selection";
import { useGuardedFetch } from "libs/react_helpers";
import _ from "lodash";
import DefaultPredictWorkflow from "./default-workflow";
MichaelBuessemeyer marked this conversation as resolved.
Show resolved Hide resolved

const { ThinSpace } = Unicode;

Expand Down Expand Up @@ -99,6 +100,7 @@ type JobApiCallArgsType = {
selectedLayer: APIDataLayer;
outputSegmentationLayerName?: string;
selectedBoundingBox: UserBoundingBox | null | undefined;
useCustomWorkflow?: boolean;
};
type StartJobFormProps = Props & {
jobApiCall: (arg0: JobApiCallArgsType, form: FormInstance<any>) => Promise<void | APIJob>;
Expand All @@ -110,6 +112,7 @@ type StartJobFormProps = Props & {
fixedSelectedLayer?: APIDataLayer | null | undefined;
title: string;
buttonLabel?: string | null;
showWorkflowYaml?: boolean;
};

type BoundingBoxSelectionProps = {
Expand Down Expand Up @@ -516,6 +519,7 @@ function StartJobForm(props: StartJobFormProps) {
const activeUser = useSelector((state: OxalisState) => state.activeUser);
const layers = chooseSegmentationLayer ? getSegmentationLayers(dataset) : getColorLayers(dataset);
const allLayers = getDataLayers(dataset);
const [useCustomWorkflow, setUseCustomWorkflow] = React.useState(false);
const defaultBBForLayers: UserBoundingBox[] = layers.map((layer, index) => {
return {
id: -1 * index,
Expand Down Expand Up @@ -555,6 +559,7 @@ function StartJobForm(props: StartJobFormProps) {
newDatasetName,
selectedLayer,
selectedBoundingBox,
useCustomWorkflow,
};
const apiJob = await jobApiCall(jobArgs, form);

Expand Down Expand Up @@ -616,6 +621,7 @@ function StartJobForm(props: StartJobFormProps) {
layerName: initialLayerName,
boundingBoxId: null,
outputSegmentationLayerName: initialOutputSegmentationLayerName,
workflowYaml: DefaultPredictWorkflow,
}}
form={form}
>
Expand All @@ -642,6 +648,14 @@ function StartJobForm(props: StartJobFormProps) {
onChangeSelectedBoundingBox={(bBoxId) => form.setFieldsValue({ boundingBoxId: bBoxId })}
value={form.getFieldValue("boundingBoxId")}
/>

{props.showWorkflowYaml ? (
<CollapsibleWorkflowYamlEditor
isActive={useCustomWorkflow}
setActive={setUseCustomWorkflow}
/>
) : null}

<div style={{ textAlign: "center" }}>
<Button type="primary" size="large" htmlType="submit">
{props.buttonLabel ? props.buttonLabel : title}
Expand Down Expand Up @@ -813,12 +827,14 @@ function CustomAiModelInferenceForm() {
title="AI Inference"
suggestedDatasetSuffix="with_custom_model"
isBoundingBoxConfigurable
showWorkflowYaml
jobApiCall={async (
{
newDatasetName,
selectedLayer: colorLayer,
selectedBoundingBox,
outputSegmentationLayerName,
useCustomWorkflow,
},
form,
) => {
Expand All @@ -832,6 +848,7 @@ function CustomAiModelInferenceForm() {
return runInferenceJob({
...maybeAnnotationId,
aiModelId: form.getFieldValue("aiModel"),
workflowYaml: useCustomWorkflow ? form.getFieldValue("workflowYaml") : undefined,
datasetName: dataset.name,
colorLayerName: colorLayer.name,
boundingBox,
Expand Down
63 changes: 39 additions & 24 deletions frontend/javascripts/oxalis/view/jobs/train_ai_model.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import React from "react";
import { Alert, Form, Row, Col, Input, Button, Select, Collapse, Tooltip } from "antd";
import { Alert, Form, Row, Col, Input, Button, Select, Collapse, Tooltip, Checkbox } from "antd";
import { useSelector } from "react-redux";
import { OxalisState, UserBoundingBox } from "oxalis/store";
import { getUserBoundingBoxesFromState } from "oxalis/model/accessors/tracing_accessor";
Expand Down Expand Up @@ -30,6 +30,7 @@ enum AiModelCategory {
export function TrainAiModelTab({ onClose }: { onClose: () => void }) {
const [form] = Form.useForm();

const [useCustomWorkflow, setUseCustomWorkflow] = React.useState(false);
const tracing = useSelector((state: OxalisState) => state.tracing);
const dataset = useSelector((state: OxalisState) => state.dataset);
const onFinish = async (values: any) => {
Expand All @@ -49,7 +50,7 @@ export function TrainAiModelTab({ onClose }: { onClose: () => void }) {
],
name: values.modelName,
aiModelCategory: values.modelCategory,
workflowYaml: values.workflowYaml,
workflowYaml: useCustomWorkflow ? values.workflowYaml : undefined,
comment: values.comment,
});
Toast.success("The training has successfully started.");
Expand Down Expand Up @@ -146,28 +147,9 @@ export function TrainAiModelTab({ onClose }: { onClose: () => void }) {
</FormItem>
</Col>
</Row>
<Collapse
style={{ marginBottom: 8 }}
items={[
{
key: "advanced",
label: "Advanced",
children: (
<FormItem name="workflowYaml" label="Workflow Description (yaml)">
<TextArea
className="input-monospace"
autoSize={{
minRows: 6,
}}
style={{
fontFamily: 'Monaco, Consolas, "Lucida Console", "Courier New", monospace',
}}
/>
</FormItem>
),
},
]}
defaultActiveKey={[]}
<CollapsibleWorkflowYamlEditor
isActive={useCustomWorkflow}
setActive={setUseCustomWorkflow}
/>

<FormItem hasFeedback name="dummy" label="Training Data">
Expand All @@ -194,6 +176,39 @@ export function TrainAiModelTab({ onClose }: { onClose: () => void }) {
);
}

export function CollapsibleWorkflowYamlEditor({
isActive = false,
setActive,
}: { isActive: boolean; setActive: (active: boolean) => void }) {
return (
<Collapse
style={{ marginBottom: 8 }}
onChange={() => setActive(!isActive)}
expandIcon={() => <Checkbox checked={isActive} onChange={() => setActive(!isActive)} />}
items={[
{
key: "advanced",
label: "Advanced",
children: (
<FormItem name="workflowYaml" label="Workflow Description (yaml)">
<TextArea
className="input-monospace"
autoSize={{
minRows: 6,
}}
style={{
fontFamily: 'Monaco, Consolas, "Lucida Console", "Courier New", monospace',
}}
/>
</FormItem>
),
},
]}
activeKey={isActive ? "advanced" : []}
/>
);
}

function areBoundingBoxesValid(userBoundingBoxes: UserBoundingBox[]): {
valid: boolean;
reason: string | null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,10 @@ trait AbstractRequestLogging extends LazyLogging {
def logTime(notifier: String => Unit, durationThreshold: FiniteDuration = 30 seconds)(
block: => Future[Result])(implicit request: Request[_], ec: ExecutionContext): Future[Result] = {
def logTimeFormatted(executionTime: FiniteDuration, request: Request[_], result: Result): Unit = {
val debugString = s"Request ${request.method} ${request.uri} took ${BigDecimal(executionTime.toMillis / 1000)
.setScale(2, BigDecimal.RoundingMode.HALF_UP)} seconds and was${if (result.header.status != 200) " not "
else " "}successful"
val debugString =
s"Request ${request.method} ${request.uri} took ${BigDecimal(executionTime.toMillis.toDouble / 1000)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

slipped in here: the decimals were lost due to integer division

.setScale(2, BigDecimal.RoundingMode.HALF_UP)} seconds and was${if (result.header.status != 200) " not "
else " "}successful"
logger.info(debugString)
notifier(debugString)
}
Expand Down