diff --git a/packages/client/hmi-client/src/services/model.ts b/packages/client/hmi-client/src/services/model.ts index c221639cda..42f8919e11 100644 --- a/packages/client/hmi-client/src/services/model.ts +++ b/packages/client/hmi-client/src/services/model.ts @@ -174,14 +174,8 @@ export async function getModelEquation(model: Model): Promise { return ''; } - /* TODO - Replace the GET with the POST when the backend is ready, - * see PR https://github.com/DARPA-ASKEM/sciml-service/pull/167 - */ - const response = await API.get(`/transforms/model-to-latex/${model.id}`); - // const response = await API.post(`/transforms/model-to-latex/`, model); - const latex = response?.data?.latex; - if (!latex) return ''; - return latex ?? ''; + const response = await API.post(`/mira/model-to-latex`, model); + return response?.data?.response ?? ''; } export const getUnitsFromModelParts = (model: Model) => { diff --git a/packages/client/hmi-client/src/types/Types.ts b/packages/client/hmi-client/src/types/Types.ts index dfa28bf3f7..d1306c304e 100644 --- a/packages/client/hmi-client/src/types/Types.ts +++ b/packages/client/hmi-client/src/types/Types.ts @@ -1059,6 +1059,7 @@ export enum ClientEventType { TaskFunmanValidation = "TASK_FUNMAN_VALIDATION", TaskGollmEnrichAmr = "TASK_GOLLM_ENRICH_AMR", TaskMiraAmrToMmt = "TASK_MIRA_AMR_TO_MMT", + TaskMiraGenerateModelLatex = "TASK_MIRA_GENERATE_MODEL_LATEX", TaskEnrichAmr = "TASK_ENRICH_AMR", WorkflowUpdate = "WORKFLOW_UPDATE", WorkflowDelete = "WORKFLOW_DELETE", diff --git a/packages/funman/tasks/__init__.py b/packages/funman/tasks/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/packages/mira/setup.py b/packages/mira/setup.py index 44e9952efd..b4a891991d 100644 --- a/packages/mira/setup.py +++ b/packages/mira/setup.py @@ -11,6 +11,7 @@ "mira_task:mdl_to_stockflow=tasks.mdl_to_stockflow:main", "mira_task:stella_to_stockflow=tasks.stella_to_stockflow:main", "mira_task:amr_to_mmt=tasks.amr_to_mmt:main", + "mira_task:generate_model_latex=tasks.generate_model_latex:main", ], }, python_requires=">=3.10", diff --git a/packages/mira/tasks/__init__.py b/packages/mira/tasks/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/packages/mira/tasks/generate_model_latex.py b/packages/mira/tasks/generate_model_latex.py new file mode 100644 index 0000000000..1dad400e5d --- /dev/null +++ b/packages/mira/tasks/generate_model_latex.py @@ -0,0 +1,75 @@ +import sys +import json +import traceback +from taskrunner import TaskRunnerInterface +import sympy +from mira.sources.amr import model_from_json + +def cleanup(): + pass + +def main(): + exitCode = 0 + + try: + taskrunner = TaskRunnerInterface(description="Generate latex") + taskrunner.on_cancellation(cleanup) + + data = taskrunner.read_input_str_with_timeout() + amr = json.loads(data) + model = model_from_json(amr) + + odeterms = {var: 0 for var in model.get_concepts_name_map().keys()} + + for t in model.templates: + if hasattr(t, "subject"): + var = t.subject.name + odeterms[var] -= t.rate_law.args[0] + + if hasattr(t, "outcome"): + var = t.outcome.name + odeterms[var] += t.rate_law.args[0] + + # Time + if model.time and model.time.name: + time = model.time.name + else: + time = "t" + + t = sympy.Symbol(time) + + # Observables + if len(model.observables) != 0: + obs_eqs = [ + f"{{{obs.name}}}(t) = " + sympy.latex(obs.expression.args[0]) + for obs in model.observables.values() + ] + + # Construct Sympy equations + odesys = [ + sympy.latex(sympy.Eq(sympy.diff(sympy.Function(var)(t), t), terms)) + for var, terms in odeterms.items() + ] + + #add observables. + odesys += obs_eqs + #Reformat: + odesys = "\\begin{align} \n " + " \\\\ \n ".join([eq for eq in odesys]) + "\n\\end{align}" + + taskrunner.write_output_dict_with_timeout({"response": odesys}) + print("Generate latex succeeded") + + except Exception as e: + sys.stderr.write(f"Error: {str(e)}\n") + sys.stderr.write(traceback.format_exc()) + sys.stderr.flush() + exitCode = 1 + + + taskrunner.log("Shutting down") + taskrunner.shutdown() + sys.exit(exitCode) + + +if __name__ == "__main__": + main() diff --git a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/mira/MiraController.java b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/mira/MiraController.java index 9f93710c0c..cd2b45dc0e 100644 --- a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/mira/MiraController.java +++ b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/controller/mira/MiraController.java @@ -46,6 +46,7 @@ import software.uncharted.terarium.hmiserver.service.data.ModelConfigurationService; import software.uncharted.terarium.hmiserver.service.data.ProjectService; import software.uncharted.terarium.hmiserver.service.tasks.AMRToMMTResponseHandler; +import software.uncharted.terarium.hmiserver.service.tasks.GenerateModelLatexResponseHandler; import software.uncharted.terarium.hmiserver.service.tasks.MdlToStockflowResponseHandler; import software.uncharted.terarium.hmiserver.service.tasks.SbmlToPetrinetResponseHandler; import software.uncharted.terarium.hmiserver.service.tasks.StellaToStockflowResponseHandler; @@ -167,6 +168,66 @@ public ResponseEntity convertAMRtoMMT(@RequestBody final JsonNode mode return ResponseEntity.ok().body(mmtInfo); } + @PostMapping("/model-to-latex") + @Secured(Roles.USER) + @Operation(summary = "Generate latex from a model id") + @ApiResponses( + value = { + @ApiResponse( + responseCode = "200", + description = "Dispatched successfully", + content = @Content( + mediaType = "application/json", + schema = @io.swagger.v3.oas.annotations.media.Schema(implementation = TaskResponse.class) + ) + ), + @ApiResponse(responseCode = "500", description = "There was an issue dispatching the request", content = @Content) + } + ) + public ResponseEntity generateModelLatex(@RequestBody final JsonNode model) { + //create request: + final TaskRequest req = new TaskRequest(); + req.setType(TaskType.MIRA); + + try { + req.setInput(objectMapper.treeToValue(model, Model.class).serializeWithoutTerariumFields().getBytes()); + } catch (final Exception e) { + log.error("Unable to serialize input", e); + throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("generic.io-error.write")); + } + + req.setScript(GenerateModelLatexResponseHandler.NAME); + req.setUserId(currentUserService.get().getId()); + + // send the request + final TaskResponse resp; + try { + resp = taskService.runTaskSync(req); + } catch (final JsonProcessingException e) { + log.error("Unable to serialize input", e); + throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.mira.json-processing")); + } catch (final TimeoutException e) { + log.warn("Timeout while waiting for task response", e); + throw new ResponseStatusException(HttpStatus.SERVICE_UNAVAILABLE, messages.get("task.mira.timeout")); + } catch (final InterruptedException e) { + log.warn("Interrupted while waiting for task response", e); + throw new ResponseStatusException(HttpStatus.UNPROCESSABLE_ENTITY, messages.get("task.mira.interrupted")); + } catch (final ExecutionException e) { + log.error("Error while waiting for task response", e); + throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.mira.execution-failure")); + } + + final JsonNode latexResponse; + try { + latexResponse = objectMapper.readValue(resp.getOutput(), JsonNode.class); + } catch (final IOException e) { + log.error("Unable to deserialize output", e); + throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("generic.io-error.read")); + } + + return ResponseEntity.ok().body(latexResponse); + } + @PostMapping("/convert-and-create-model") @Secured(Roles.USER) @Operation(summary = "Dispatch a MIRA conversion task") diff --git a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/models/ClientEventType.java b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/models/ClientEventType.java index f92f20f4a8..b6a9440bb2 100644 --- a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/models/ClientEventType.java +++ b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/models/ClientEventType.java @@ -19,6 +19,7 @@ public enum ClientEventType { TASK_FUNMAN_VALIDATION, TASK_GOLLM_ENRICH_AMR, TASK_MIRA_AMR_TO_MMT, + TASK_MIRA_GENERATE_MODEL_LATEX, TASK_ENRICH_AMR, WORKFLOW_UPDATE, WORKFLOW_DELETE diff --git a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/GenerateModelLatexResponseHandler.java b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/GenerateModelLatexResponseHandler.java new file mode 100644 index 0000000000..d83297806f --- /dev/null +++ b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/GenerateModelLatexResponseHandler.java @@ -0,0 +1,18 @@ +package software.uncharted.terarium.hmiserver.service.tasks; + +import lombok.RequiredArgsConstructor; +import lombok.extern.slf4j.Slf4j; +import org.springframework.stereotype.Component; + +@Component +@RequiredArgsConstructor +@Slf4j +public class GenerateModelLatexResponseHandler extends TaskResponseHandler { + + public static final String NAME = "mira_task:generate_model_latex"; + + @Override + public String getName() { + return NAME; + } +} diff --git a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/TaskNotificationEventTypes.java b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/TaskNotificationEventTypes.java index 3b33ef4574..391468da52 100644 --- a/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/TaskNotificationEventTypes.java +++ b/packages/server/src/main/java/software/uncharted/terarium/hmiserver/service/tasks/TaskNotificationEventTypes.java @@ -23,7 +23,9 @@ public class TaskNotificationEventTypes { EnrichAmrResponseHandler.NAME, ClientEventType.TASK_GOLLM_ENRICH_AMR, AMRToMMTResponseHandler.NAME, - ClientEventType.TASK_MIRA_AMR_TO_MMT + ClientEventType.TASK_MIRA_AMR_TO_MMT, + GenerateModelLatexResponseHandler.NAME, + ClientEventType.TASK_MIRA_GENERATE_MODEL_LATEX ); public static ClientEventType getTypeFor(final String taskName) { diff --git a/packages/server/src/test/java/software/uncharted/terarium/hmiserver/service/tasks/TaskServiceTest.java b/packages/server/src/test/java/software/uncharted/terarium/hmiserver/service/tasks/TaskServiceTest.java index 8402333992..319f3d1370 100644 --- a/packages/server/src/test/java/software/uncharted/terarium/hmiserver/service/tasks/TaskServiceTest.java +++ b/packages/server/src/test/java/software/uncharted/terarium/hmiserver/service/tasks/TaskServiceTest.java @@ -314,6 +314,26 @@ public void testItCanSendAmrToMmtRequest() throws Exception { log.info(new String(resp.getOutput())); } + // @Test + @WithUserDetails(MockUser.URSULA) + public void testItCanSendGenerateModelLatexRequest() throws Exception { + final UUID taskId = UUID.randomUUID(); + + final ClassPathResource resource = new ClassPathResource("mira/problem.json"); + final String content = new String(Files.readAllBytes(resource.getFile().toPath())); + + final TaskRequest req = new TaskRequest(); + req.setType(TaskType.MIRA); + req.setScript("mira_task:generate_model_latex"); + req.setInput(content.getBytes()); + + final TaskResponse resp = taskService.runTaskSync(req); + + Assertions.assertEquals(taskId, resp.getId()); + + log.info(new String(resp.getOutput())); + } + // @Test @WithUserDetails(MockUser.URSULA) public void testItCanCacheSuccess() throws Exception {