Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deactivation endpoint #226

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 151 additions & 3 deletions mira/dkg/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import uuid
from pathlib import Path
from textwrap import dedent
from typing import Any, Dict, List, Literal, Optional, Set, Type, Union
from typing import Any, Dict, List, Literal, Optional, Set, Type, Union, Tuple

import pystow
from fastapi import (
Expand All @@ -19,14 +19,14 @@
Request,
)
from fastapi.responses import FileResponse
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, root_validator, validator

from mira.examples.sir import sir_bilayer, sir, sir_parameterized_init
from mira.metamodel import (
NaturalConversion, Template, ControlledConversion,
stratify, Concept, ModelComparisonGraphdata, TemplateModelDelta,
TemplateModel, Parameter, simplify_rate_laws, aggregate_parameters,
counts_to_dimensionless
counts_to_dimensionless, deactivate_templates
)
from mira.modeling import Model
from mira.modeling.askenet.petrinet import AskeNetPetriNetModel, ModelSpecification
Expand Down Expand Up @@ -87,12 +87,20 @@
]
)

# Used as example in the deactivation endpoint
age_strata = stratify(sir_parameterized_init,
key='age',
strata=['young', 'old'],
cartesian_control=True)

#: PetriNetModel json example
petrinet_json = PetriNetModel(Model(sir)).to_pydantic()
askenet_petrinet_json = AskeNetPetriNetModel(Model(sir)).to_pydantic()
askenet_petrinet_json_units_values = AskeNetPetriNetModel(
Model(sir_parameterized_init)
).to_pydantic()
askenet_petrinet_json_deactivate = AskeNetPetriNetModel(Model(
age_strata)).to_pydantic()


@model_blueprint.post(
Expand Down Expand Up @@ -296,6 +304,146 @@ def model_stratification(
return template_model


# template deactivation
class DeactivationQuery(BaseModel):
model: Dict[str, Any] = Field(
...,
description="The model to deactivate transitions in",
example=askenet_petrinet_json_units_values
)
parameters: Optional[List[str]] = Field(
None,
description="Deactivates transitions that have a parameter from the "
"provided list in their rate law",
example=["beta"]
)
transitions: Optional[List[List[str]]] = Field(
None,
description="Deactivates transitions that have a source-target "
"pair from the provided list",
example=[
["infected_population_old", "infected_population_young"],
["infected_population_young", "infected_population_old"]
]
)
and_or: Literal["and", "or"] = Field(
"and",
description="If both transitions and parameters are provided, "
"whether to deactivate transitions that match both "
"or either of the provided conditions. If only one "
"of transitions or parameters is provided, this "
"parameter is has no effect.",
example="and"
)

@validator('transitions')
def check_transitions(cls, v):
# This enforces that the transitions are a list of lists of length 2
# (since we can't use tuples for JSON (or can we?))
if v is not None:
for transition in v:
if len(transition) != 2:
raise ValueError(
"Each transition must be a list of length 2"
)
return v

@root_validator(skip_on_failure=True)
def check_a_or_b(cls, values):
if (
values.get("parameters") is None or
values.get("parameters") == []
) and (
values.get("transitions") is None or
not any(values.get("transitions", []))
):
raise ValueError(
'At least one of "parameters" or "transitions" is required'
)
return values


@model_blueprint.post(
"/deactivate_transitions",
response_model=ModelSpecification,
tags=["modeling"],
)
def deactivate_transitions(
query: DeactivationQuery = Body(
...,
examples={
"With parameters": {
"model": askenet_petrinet_json_units_values,
"parameters": ["beta"],
},
"With transitions": {
"model": askenet_petrinet_json_deactivate,
"transitions": list(
[t.subject.name, t.outcome.name]
for t in age_strata.templates
if hasattr(t, "subject") and hasattr(t, "outcome") and
(
t.subject.name.endswith('_young') and
t.outcome.name.endswith('_old')
or
t.subject.name.endswith('_old') and
t.outcome.name.endswith('_young')
)
),
},
},
)
):
"""Deactivate transitions in a model"""
amr_json = query.model
tm = template_model_from_askenet_json(amr_json)

# Create callables for deactivating transitions
if query.parameters:
def deactivate_parameter(t: Template) -> bool:
"""Deactivate templates that have the given parameter(s) in
their rate law"""
if t.rate_law is None:
return False
for symb in t.rate_law.atoms():
if str(symb) in set(query.parameters):
return True
else:
deactivate_parameter = None

if query.transitions is not None:
def deactivate_transition(t: Template) -> bool:
"""Deactivate template if it is a transition-like template and it
matches the source-target pair"""
if hasattr(t, "subject") and hasattr(t, "outcome"):
for subject, outcome in query.transitions:
if t.subject.name == subject and t.outcome.name == outcome:
return True
return False
else:
deactivate_transition = None

def meta_deactivate(t: Template) -> bool:
if deactivate_parameter is not None and \
deactivate_transition is not None:
if query.and_or == "and":
return deactivate_parameter(t) and deactivate_transition(t)
else:
return deactivate_parameter(t) or deactivate_transition(t)
elif deactivate_parameter is None:
return deactivate_transition(t)
elif deactivate_transition is None:
return deactivate_parameter(t)
else:
raise ValueError(
"Need to provide either or both of parameters or transitions"
)

deactivate_templates(template_model=tm, condition=meta_deactivate)

return AskeNetPetriNetModel(Model(tm)).to_pydantic()


@model_blueprint.post(
"/counts_to_dimensionless_mira",
response_model=TemplateModel,
Expand Down
1 change: 0 additions & 1 deletion mira/examples/sir.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,6 @@

sir_parameterized_init.parameters['beta'].units = \
Unit(expression=1 / (sympy.Symbol('person') * sympy.Symbol('day')))
old_beta = sir_parameterized_init.parameters['beta'].value

for initial in sir_parameterized_init.initials.values():
initial.concept.units = Unit(expression=sympy.Symbol('person'))
85 changes: 85 additions & 0 deletions tests/test_model_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -694,3 +694,88 @@ def test_reconstruct_ode_semantics_endpoint(self):
assert len(flux_span_tm.parameters) == 11
assert all(t.rate_law for t in flux_span_tm.templates)

def test_deactivation_endpoint(self):
# Deliberately create a stratifiction that will lead to nonsense
# transitions, i.e. a transitions between age groups
age_strata = stratify(sir_parameterized_init,
key='age',
strata=['y', 'o'],
cartesian_control=True)

amr_sir = AskeNetPetriNetModel(Model(age_strata)).to_json()

# Test the endpoint itself
# Should fail with 422 because of missing transitions or parameters
response = self.client.post(
"/api/deactivate_transitions",
json={"model": amr_sir}
)
self.assertEqual(422, response.status_code)

# Should fail with 422 because of empty transition list
response = self.client.post(
"/api/deactivate_transitions",
json={"model": amr_sir, "transitions": [[]]}
)
self.assertEqual(422, response.status_code)

# Should fail with 422 because of transitions are triples
response = self.client.post(
"/api/deactivate_transitions",
json={"model": amr_sir, "transitions": [['a', 'b', 'c']]}
)
self.assertEqual(422, response.status_code)

# Should fail with 422 because of empty parameters list
response = self.client.post(
"/api/deactivate_transitions",
json={"model": amr_sir, "parameters": []}
)
self.assertEqual(422, response.status_code)

# Actual Test
# Assert that there are old to young transitions
transition_list = []
for template in age_strata.templates:
if hasattr(template, 'subject') and hasattr(template, 'outcome'):
subj, outc = template.subject.name, template.outcome.name
if subj.endswith('_o') and outc.endswith('_y') or \
subj.endswith('_y') and outc.endswith('_o'):
transition_list.append((subj, outc))
assert len(transition_list), "No old to young transitions found"
response = self.client.post(
"/api/deactivate_transitions",
json={"model": amr_sir, "transitions": transition_list}
)
self.assertEqual(200, response.status_code)

# Check that the transitions are deactivated
amr_sir_deactivated = response.json()
tm_deactivated = template_model_from_askenet_json(amr_sir_deactivated)
for template in tm_deactivated.templates:
if hasattr(template, 'subject') and hasattr(template, 'outcome'):
subj, outc = template.subject.name, template.outcome.name
if (subj, outc) in transition_list:
assert template.rate_law.args[0] == \
sympy.core.numbers.Zero(), \
template.rate_law

# Test using parameter names for deactivation
deactivate_key = list(age_strata.parameters.keys())[0]
response = self.client.post(
"/api/deactivate_transitions",
json={"model": amr_sir, "parameters": [deactivate_key]}
)
self.assertEqual(200, response.status_code)
amr_sir_deactivated_params = response.json()
tm_deactivated_params = template_model_from_askenet_json(
amr_sir_deactivated_params)
for template in tm_deactivated_params.templates:
# All rate laws must either be zero or not contain the deactivated
# parameter
if template.rate_law and not template.rate_law.is_zero:
for symb in template.rate_law.atoms():
assert str(symb) != deactivate_key
else:
assert (template.rate_law.args[0] == sympy.core.numbers.Zero(),
template.rate_law)
Loading