Skip to content

Commit

Permalink
Merge pull request #364 from gyorilab/obs_pattern
Browse files Browse the repository at this point in the history
Implement adding observables by pattern
  • Loading branch information
bgyori authored Sep 20, 2024
2 parents a3cd2a9 + 11beeb5 commit 8e39c17
Show file tree
Hide file tree
Showing 6 changed files with 489 additions and 6 deletions.
2 changes: 1 addition & 1 deletion mira/dkg/askemo/generate_site.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def main():
resource=ASKEMO,
metaregistry_name="MIRA Epi Metaregistry",
metaregistry_metaprefix="askemr",
metaregistry_base_url="http://34.230.33.149:8772/",
metaregistry_base_url="http://mira-metaregistry-lb-be8a34d7051f5236.elb.us-east-1.amazonaws.com/",
)


Expand Down
2 changes: 1 addition & 1 deletion mira/dkg/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
WIKIDATA_API = "https://www.wikidata.org/w/api.php"

#: Base URL for the metaregistry, used in creating links
METAREGISTRY_BASE = "http://34.230.33.149:8772"
METAREGISTRY_BASE = "http://mira-metaregistry-lb-be8a34d7051f5236.elb.us-east-1.amazonaws.com"

Node: TypeAlias = Mapping[str, Any]

Expand Down
62 changes: 60 additions & 2 deletions mira/metamodel/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@

import sympy

from .template_model import TemplateModel, Initial, Parameter
from .template_model import TemplateModel, Initial, Parameter, Observable
from .templates import *
from .comparison import get_dkg_refinement_closure
from .units import Unit
from .utils import SympyExprStr

Expand All @@ -19,7 +20,8 @@
"aggregate_parameters",
"get_term_roles",
"counts_to_dimensionless",
"deactivate_templates"
"deactivate_templates",
"add_observable_pattern",
]


Expand Down Expand Up @@ -743,3 +745,59 @@ def deactivate_templates(
for template in template_model.templates:
if condition(template):
template.deactivate()


def add_observable_pattern(
template_model: TemplateModel,
name: str,
identifiers: Mapping = None,
context: Mapping = None,
):
"""Add an observable for a pattern of concepts.
Parameters
----------
template_model :
A template model.
name :
The name of the observable.
identifiers :
Identifiers serving as a pattern for concepts to observe.
context :
Context serving as a pattern for concepts to observe.
"""
observable_concepts = []
identifiers = set(identifiers.items() if identifiers else {})
contexts = set(context.items() if context else {})
for key, concept in template_model.get_concepts_map().items():
if (not identifiers) or identifiers.issubset(
set(concept.identifiers.items())):
if (not contexts) or contexts.issubset(
set(concept.context.items())):
observable_concepts.append(concept)
obs = get_observable_for_concepts(observable_concepts, name)
template_model.observables[name] = obs


def get_observable_for_concepts(concepts: List[Concept], name: str):
"""Return an observable expressing a sum of a set of concepts.
Parameters
----------
concepts :
A list of concepts.
name :
The name of the observable.
Returns
-------
:
An observable that sums the given concepts.
"""
expr = None
for concept in concepts:
if expr is None:
expr = sympy.Symbol(concept.name)
else:
expr += sympy.Symbol(concept.name)
return Observable(name=name, expression=SympyExprStr(expr))
391 changes: 391 additions & 0 deletions notebooks/observable_patterns.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion notebooks/viz_strat_petri.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.10.14"
}
},
"nbformat": 4,
Expand Down
36 changes: 35 additions & 1 deletion tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import sympy

from mira.metamodel import *
from mira.metamodel.ops import stratify, simplify_rate_law, counts_to_dimensionless
from mira.metamodel.ops import stratify, simplify_rate_law, \
counts_to_dimensionless, add_observable_pattern, \
get_observable_for_concepts
from mira.examples.sir import cities, sir, sir_2_city, sir_parameterized
from mira.examples.concepts import infected, susceptible
from mira.examples.chime import sviivr
Expand Down Expand Up @@ -582,3 +584,35 @@ def test_stratify_parameter_consistency():
# be the case when parameters would be incrementally numbered for each
# new template
assert len(tm.parameters) == 2


def test_get_observable_for_concepts():
concepts = [
Concept(name='A'),
Concept(name='B'),
Concept(name='C'),
]
obs = get_observable_for_concepts(concepts, 'obs')
assert obs.name == 'obs'
assert obs.expression.args[0] == sum([sympy.Symbol(c.name) for c in concepts])


def test_add_observable_pattern():
templates = [
NaturalDegradation(subject=Concept(name='A', identifiers={'ido': '0000514'}),
rate_law=sympy.Symbol('alpha') * sympy.Symbol('A')),
NaturalDegradation(subject=Concept(name='B', identifiers={'ido': '0000515'}),
rate_law=sympy.Symbol('alpha') * sympy.Symbol('B')),
]
tm = TemplateModel(templates=templates,
parameters={'alpha': Parameter(name='alpha', value=0.1)})
tm = stratify(tm, key='age', strata=['young', 'old'], structure=[])
add_observable_pattern(tm, name='A', identifiers={'ido': '0000514'})
assert 'A' in tm.observables
obs = tm.observables['A']
assert obs.expression.args[0] == sympy.Symbol('A_old') + sympy.Symbol('A_young')

add_observable_pattern(tm, 'young', context={'age': 'young'})
assert 'young' in tm.observables
obs = tm.observables['young']
assert obs.expression.args[0] == sympy.Symbol('A_young') + sympy.Symbol('B_young')

0 comments on commit 8e39c17

Please sign in to comment.