diff --git a/mira/dkg/model.py b/mira/dkg/model.py index 9758373a2..976d59902 100644 --- a/mira/dkg/model.py +++ b/mira/dkg/model.py @@ -6,7 +6,7 @@ import uuid from pathlib import Path from textwrap import dedent -from typing import Any, Dict, List, Literal, Optional, Set, Type, Union +from typing import Any, Dict, List, Literal, Optional, Set, Type, Union, Tuple import pystow from fastapi import ( @@ -19,14 +19,14 @@ Request, ) from fastapi.responses import FileResponse -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, root_validator, validator from mira.examples.sir import sir_bilayer, sir, sir_parameterized_init from mira.metamodel import ( NaturalConversion, Template, ControlledConversion, stratify, Concept, ModelComparisonGraphdata, TemplateModelDelta, TemplateModel, Parameter, simplify_rate_laws, aggregate_parameters, - counts_to_dimensionless + counts_to_dimensionless, deactivate_templates ) from mira.modeling import Model from mira.modeling.askenet.petrinet import AskeNetPetriNetModel, ModelSpecification @@ -87,12 +87,20 @@ ] ) +# Used as example in the deactivation endpoint +age_strata = stratify(sir_parameterized_init, + key='age', + strata=['young', 'old'], + cartesian_control=True) + #: PetriNetModel json example petrinet_json = PetriNetModel(Model(sir)).to_pydantic() askenet_petrinet_json = AskeNetPetriNetModel(Model(sir)).to_pydantic() askenet_petrinet_json_units_values = AskeNetPetriNetModel( Model(sir_parameterized_init) ).to_pydantic() +askenet_petrinet_json_deactivate = AskeNetPetriNetModel(Model( + age_strata)).to_pydantic() @model_blueprint.post( @@ -296,6 +304,146 @@ def model_stratification( return template_model +# template deactivation +class DeactivationQuery(BaseModel): + model: Dict[str, Any] = Field( + ..., + description="The model to deactivate transitions in", + example=askenet_petrinet_json_units_values + ) + parameters: Optional[List[str]] = Field( + None, + description="Deactivates transitions that have a parameter from the " + "provided list in their rate law", + example=["beta"] + ) + transitions: Optional[List[List[str]]] = Field( + None, + description="Deactivates transitions that have a source-target " + "pair from the provided list", + example=[ + ["infected_population_old", "infected_population_young"], + ["infected_population_young", "infected_population_old"] + ] + ) + and_or: Literal["and", "or"] = Field( + "and", + description="If both transitions and parameters are provided, " + "whether to deactivate transitions that match both " + "or either of the provided conditions. If only one " + "of transitions or parameters is provided, this " + "parameter is has no effect.", + example="and" + ) + + @validator('transitions') + def check_transitions(cls, v): + # This enforces that the transitions are a list of lists of length 2 + # (since we can't use tuples for JSON (or can we?)) + if v is not None: + for transition in v: + if len(transition) != 2: + raise ValueError( + "Each transition must be a list of length 2" + ) + return v + + @root_validator(skip_on_failure=True) + def check_a_or_b(cls, values): + if ( + values.get("parameters") is None or + values.get("parameters") == [] + ) and ( + values.get("transitions") is None or + not any(values.get("transitions", [])) + ): + raise ValueError( + 'At least one of "parameters" or "transitions" is required' + ) + return values + + +@model_blueprint.post( + "/deactivate_transitions", + response_model=ModelSpecification, + tags=["modeling"], +) +def deactivate_transitions( + query: DeactivationQuery = Body( + ..., + examples={ + "With parameters": { + "model": askenet_petrinet_json_units_values, + "parameters": ["beta"], + }, + "With transitions": { + "model": askenet_petrinet_json_deactivate, + "transitions": list( + [t.subject.name, t.outcome.name] + for t in age_strata.templates + if hasattr(t, "subject") and hasattr(t, "outcome") and + ( + t.subject.name.endswith('_young') and + t.outcome.name.endswith('_old') + or + t.subject.name.endswith('_old') and + t.outcome.name.endswith('_young') + ) + ), + }, + }, + ) +): + """Deactivate transitions in a model""" + amr_json = query.model + tm = template_model_from_askenet_json(amr_json) + + # Create callables for deactivating transitions + if query.parameters: + def deactivate_parameter(t: Template) -> bool: + """Deactivate templates that have the given parameter(s) in + their rate law""" + if t.rate_law is None: + return False + for symb in t.rate_law.atoms(): + if str(symb) in set(query.parameters): + return True + else: + deactivate_parameter = None + + if query.transitions is not None: + def deactivate_transition(t: Template) -> bool: + """Deactivate template if it is a transition-like template and it + matches the source-target pair""" + if hasattr(t, "subject") and hasattr(t, "outcome"): + for subject, outcome in query.transitions: + if t.subject.name == subject and t.outcome.name == outcome: + return True + return False + else: + deactivate_transition = None + + def meta_deactivate(t: Template) -> bool: + if deactivate_parameter is not None and \ + deactivate_transition is not None: + if query.and_or == "and": + return deactivate_parameter(t) and deactivate_transition(t) + else: + return deactivate_parameter(t) or deactivate_transition(t) + elif deactivate_parameter is None: + return deactivate_transition(t) + elif deactivate_transition is None: + return deactivate_parameter(t) + else: + raise ValueError( + "Need to provide either or both of parameters or transitions" + ) + + deactivate_templates(template_model=tm, condition=meta_deactivate) + + return AskeNetPetriNetModel(Model(tm)).to_pydantic() + + @model_blueprint.post( "/counts_to_dimensionless_mira", response_model=TemplateModel, diff --git a/mira/examples/sir.py b/mira/examples/sir.py index 0fff0cad0..44f52f49a 100644 --- a/mira/examples/sir.py +++ b/mira/examples/sir.py @@ -152,7 +152,6 @@ sir_parameterized_init.parameters['beta'].units = \ Unit(expression=1 / (sympy.Symbol('person') * sympy.Symbol('day'))) -old_beta = sir_parameterized_init.parameters['beta'].value for initial in sir_parameterized_init.initials.values(): initial.concept.units = Unit(expression=sympy.Symbol('person')) diff --git a/tests/test_model_api.py b/tests/test_model_api.py index 14fb83ff6..5957486cd 100644 --- a/tests/test_model_api.py +++ b/tests/test_model_api.py @@ -694,3 +694,88 @@ def test_reconstruct_ode_semantics_endpoint(self): assert len(flux_span_tm.parameters) == 11 assert all(t.rate_law for t in flux_span_tm.templates) + def test_deactivation_endpoint(self): + # Deliberately create a stratifiction that will lead to nonsense + # transitions, i.e. a transitions between age groups + age_strata = stratify(sir_parameterized_init, + key='age', + strata=['y', 'o'], + cartesian_control=True) + + amr_sir = AskeNetPetriNetModel(Model(age_strata)).to_json() + + # Test the endpoint itself + # Should fail with 422 because of missing transitions or parameters + response = self.client.post( + "/api/deactivate_transitions", + json={"model": amr_sir} + ) + self.assertEqual(422, response.status_code) + + # Should fail with 422 because of empty transition list + response = self.client.post( + "/api/deactivate_transitions", + json={"model": amr_sir, "transitions": [[]]} + ) + self.assertEqual(422, response.status_code) + + # Should fail with 422 because of transitions are triples + response = self.client.post( + "/api/deactivate_transitions", + json={"model": amr_sir, "transitions": [['a', 'b', 'c']]} + ) + self.assertEqual(422, response.status_code) + + # Should fail with 422 because of empty parameters list + response = self.client.post( + "/api/deactivate_transitions", + json={"model": amr_sir, "parameters": []} + ) + self.assertEqual(422, response.status_code) + + # Actual Test + # Assert that there are old to young transitions + transition_list = [] + for template in age_strata.templates: + if hasattr(template, 'subject') and hasattr(template, 'outcome'): + subj, outc = template.subject.name, template.outcome.name + if subj.endswith('_o') and outc.endswith('_y') or \ + subj.endswith('_y') and outc.endswith('_o'): + transition_list.append((subj, outc)) + assert len(transition_list), "No old to young transitions found" + response = self.client.post( + "/api/deactivate_transitions", + json={"model": amr_sir, "transitions": transition_list} + ) + self.assertEqual(200, response.status_code) + + # Check that the transitions are deactivated + amr_sir_deactivated = response.json() + tm_deactivated = template_model_from_askenet_json(amr_sir_deactivated) + for template in tm_deactivated.templates: + if hasattr(template, 'subject') and hasattr(template, 'outcome'): + subj, outc = template.subject.name, template.outcome.name + if (subj, outc) in transition_list: + assert template.rate_law.args[0] == \ + sympy.core.numbers.Zero(), \ + template.rate_law + + # Test using parameter names for deactivation + deactivate_key = list(age_strata.parameters.keys())[0] + response = self.client.post( + "/api/deactivate_transitions", + json={"model": amr_sir, "parameters": [deactivate_key]} + ) + self.assertEqual(200, response.status_code) + amr_sir_deactivated_params = response.json() + tm_deactivated_params = template_model_from_askenet_json( + amr_sir_deactivated_params) + for template in tm_deactivated_params.templates: + # All rate laws must either be zero or not contain the deactivated + # parameter + if template.rate_law and not template.rate_law.is_zero: + for symb in template.rate_law.atoms(): + assert str(symb) != deactivate_key + else: + assert (template.rate_law.args[0] == sympy.core.numbers.Zero(), + template.rate_law)