Skip to content

Commit

Permalink
Model to latex with mira task (#4831)
Browse files Browse the repository at this point in the history
Co-authored-by: Yohann Paris <[email protected]>
  • Loading branch information
Tom-Szendrey and YohannParis authored Sep 18, 2024
1 parent c09f3ea commit c8e6da7
Show file tree
Hide file tree
Showing 11 changed files with 182 additions and 9 deletions.
10 changes: 2 additions & 8 deletions packages/client/hmi-client/src/services/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -174,14 +174,8 @@ export async function getModelEquation(model: Model): Promise<string> {
return '';
}

/* TODO - Replace the GET with the POST when the backend is ready,
* see PR https://github.com/DARPA-ASKEM/sciml-service/pull/167
*/
const response = await API.get(`/transforms/model-to-latex/${model.id}`);
// const response = await API.post(`/transforms/model-to-latex/`, model);
const latex = response?.data?.latex;
if (!latex) return '';
return latex ?? '';
const response = await API.post(`/mira/model-to-latex`, model);
return response?.data?.response ?? '';
}

export const getUnitsFromModelParts = (model: Model) => {
Expand Down
1 change: 1 addition & 0 deletions packages/client/hmi-client/src/types/Types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1059,6 +1059,7 @@ export enum ClientEventType {
TaskFunmanValidation = "TASK_FUNMAN_VALIDATION",
TaskGollmEnrichAmr = "TASK_GOLLM_ENRICH_AMR",
TaskMiraAmrToMmt = "TASK_MIRA_AMR_TO_MMT",
TaskMiraGenerateModelLatex = "TASK_MIRA_GENERATE_MODEL_LATEX",
TaskEnrichAmr = "TASK_ENRICH_AMR",
WorkflowUpdate = "WORKFLOW_UPDATE",
WorkflowDelete = "WORKFLOW_DELETE",
Expand Down
Empty file.
1 change: 1 addition & 0 deletions packages/mira/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"mira_task:mdl_to_stockflow=tasks.mdl_to_stockflow:main",
"mira_task:stella_to_stockflow=tasks.stella_to_stockflow:main",
"mira_task:amr_to_mmt=tasks.amr_to_mmt:main",
"mira_task:generate_model_latex=tasks.generate_model_latex:main",
],
},
python_requires=">=3.10",
Expand Down
Empty file added packages/mira/tasks/__init__.py
Empty file.
75 changes: 75 additions & 0 deletions packages/mira/tasks/generate_model_latex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
import sys
import json
import traceback
from taskrunner import TaskRunnerInterface
import sympy
from mira.sources.amr import model_from_json

def cleanup():
pass

def main():
exitCode = 0

try:
taskrunner = TaskRunnerInterface(description="Generate latex")
taskrunner.on_cancellation(cleanup)

data = taskrunner.read_input_str_with_timeout()
amr = json.loads(data)
model = model_from_json(amr)

odeterms = {var: 0 for var in model.get_concepts_name_map().keys()}

for t in model.templates:
if hasattr(t, "subject"):
var = t.subject.name
odeterms[var] -= t.rate_law.args[0]

if hasattr(t, "outcome"):
var = t.outcome.name
odeterms[var] += t.rate_law.args[0]

# Time
if model.time and model.time.name:
time = model.time.name
else:
time = "t"

t = sympy.Symbol(time)

# Observables
if len(model.observables) != 0:
obs_eqs = [
f"{{{obs.name}}}(t) = " + sympy.latex(obs.expression.args[0])
for obs in model.observables.values()
]

# Construct Sympy equations
odesys = [
sympy.latex(sympy.Eq(sympy.diff(sympy.Function(var)(t), t), terms))
for var, terms in odeterms.items()
]

#add observables.
odesys += obs_eqs
#Reformat:
odesys = "\\begin{align} \n " + " \\\\ \n ".join([eq for eq in odesys]) + "\n\\end{align}"

taskrunner.write_output_dict_with_timeout({"response": odesys})
print("Generate latex succeeded")

except Exception as e:
sys.stderr.write(f"Error: {str(e)}\n")
sys.stderr.write(traceback.format_exc())
sys.stderr.flush()
exitCode = 1


taskrunner.log("Shutting down")
taskrunner.shutdown()
sys.exit(exitCode)


if __name__ == "__main__":
main()
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import software.uncharted.terarium.hmiserver.service.data.ModelConfigurationService;
import software.uncharted.terarium.hmiserver.service.data.ProjectService;
import software.uncharted.terarium.hmiserver.service.tasks.AMRToMMTResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.GenerateModelLatexResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.MdlToStockflowResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.SbmlToPetrinetResponseHandler;
import software.uncharted.terarium.hmiserver.service.tasks.StellaToStockflowResponseHandler;
Expand Down Expand Up @@ -167,6 +168,66 @@ public ResponseEntity<JsonNode> convertAMRtoMMT(@RequestBody final JsonNode mode
return ResponseEntity.ok().body(mmtInfo);
}

@PostMapping("/model-to-latex")
@Secured(Roles.USER)
@Operation(summary = "Generate latex from a model id")
@ApiResponses(
value = {
@ApiResponse(
responseCode = "200",
description = "Dispatched successfully",
content = @Content(
mediaType = "application/json",
schema = @io.swagger.v3.oas.annotations.media.Schema(implementation = TaskResponse.class)
)
),
@ApiResponse(responseCode = "500", description = "There was an issue dispatching the request", content = @Content)
}
)
public ResponseEntity<JsonNode> generateModelLatex(@RequestBody final JsonNode model) {
//create request:
final TaskRequest req = new TaskRequest();
req.setType(TaskType.MIRA);

try {
req.setInput(objectMapper.treeToValue(model, Model.class).serializeWithoutTerariumFields().getBytes());
} catch (final Exception e) {
log.error("Unable to serialize input", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("generic.io-error.write"));
}

req.setScript(GenerateModelLatexResponseHandler.NAME);
req.setUserId(currentUserService.get().getId());

// send the request
final TaskResponse resp;
try {
resp = taskService.runTaskSync(req);
} catch (final JsonProcessingException e) {
log.error("Unable to serialize input", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.mira.json-processing"));
} catch (final TimeoutException e) {
log.warn("Timeout while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.SERVICE_UNAVAILABLE, messages.get("task.mira.timeout"));
} catch (final InterruptedException e) {
log.warn("Interrupted while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.UNPROCESSABLE_ENTITY, messages.get("task.mira.interrupted"));
} catch (final ExecutionException e) {
log.error("Error while waiting for task response", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("task.mira.execution-failure"));
}

final JsonNode latexResponse;
try {
latexResponse = objectMapper.readValue(resp.getOutput(), JsonNode.class);
} catch (final IOException e) {
log.error("Unable to deserialize output", e);
throw new ResponseStatusException(HttpStatus.INTERNAL_SERVER_ERROR, messages.get("generic.io-error.read"));
}

return ResponseEntity.ok().body(latexResponse);
}

@PostMapping("/convert-and-create-model")
@Secured(Roles.USER)
@Operation(summary = "Dispatch a MIRA conversion task")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ public enum ClientEventType {
TASK_FUNMAN_VALIDATION,
TASK_GOLLM_ENRICH_AMR,
TASK_MIRA_AMR_TO_MMT,
TASK_MIRA_GENERATE_MODEL_LATEX,
TASK_ENRICH_AMR,
WORKFLOW_UPDATE,
WORKFLOW_DELETE
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package software.uncharted.terarium.hmiserver.service.tasks;

import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Component;

@Component
@RequiredArgsConstructor
@Slf4j
public class GenerateModelLatexResponseHandler extends TaskResponseHandler {

public static final String NAME = "mira_task:generate_model_latex";

@Override
public String getName() {
return NAME;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ public class TaskNotificationEventTypes {
EnrichAmrResponseHandler.NAME,
ClientEventType.TASK_GOLLM_ENRICH_AMR,
AMRToMMTResponseHandler.NAME,
ClientEventType.TASK_MIRA_AMR_TO_MMT
ClientEventType.TASK_MIRA_AMR_TO_MMT,
GenerateModelLatexResponseHandler.NAME,
ClientEventType.TASK_MIRA_GENERATE_MODEL_LATEX
);

public static ClientEventType getTypeFor(final String taskName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -314,6 +314,26 @@ public void testItCanSendAmrToMmtRequest() throws Exception {
log.info(new String(resp.getOutput()));
}

// @Test
@WithUserDetails(MockUser.URSULA)
public void testItCanSendGenerateModelLatexRequest() throws Exception {
final UUID taskId = UUID.randomUUID();

final ClassPathResource resource = new ClassPathResource("mira/problem.json");
final String content = new String(Files.readAllBytes(resource.getFile().toPath()));

final TaskRequest req = new TaskRequest();
req.setType(TaskType.MIRA);
req.setScript("mira_task:generate_model_latex");
req.setInput(content.getBytes());

final TaskResponse resp = taskService.runTaskSync(req);

Assertions.assertEquals(taskId, resp.getId());

log.info(new String(resp.getOutput()));
}

// @Test
@WithUserDetails(MockUser.URSULA)
public void testItCanCacheSuccess() throws Exception {
Expand Down

0 comments on commit c8e6da7

Please sign in to comment.