Skip to content

Commit

Permalink
Implement refinement and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bgyori committed Sep 18, 2024
1 parent 9ba12ae commit a84a9b5
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 5 deletions.
11 changes: 7 additions & 4 deletions mira/metamodel/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from .template_model import TemplateModel, Initial, Parameter, Observable
from .templates import *
from .comparison import get_dkg_refinement_closure
from .units import Unit
from .utils import SympyExprStr

Expand Down Expand Up @@ -748,7 +749,8 @@ def deactivate_templates(
def add_observable_pattern(
template_model: TemplateModel,
concept_pattern: Concept,
name: str
name: str,
refinement_func=None,
):
"""Add an observable for a pattern of concepts.
Expand All @@ -766,10 +768,11 @@ def add_observable_pattern(
:
A template model with the observable added.
"""

if refinement_func is None:
refinement_func = get_dkg_refinement_closure().is_ontological_child
observable_concepts = []
for key, concept in template_model.get_concepts_map():
if concept.refinement_of(concept_pattern):
for key, concept in template_model.get_concepts_map().items():
if concept.refinement_of(concept_pattern, refinement_func):
observable_concepts.append(concept)
obs = get_observable_for_concepts(observable_concepts, name)
template_model.observables[name] = obs
Expand Down
31 changes: 30 additions & 1 deletion tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
import sympy

from mira.metamodel import *
from mira.metamodel.ops import stratify, simplify_rate_law, counts_to_dimensionless
from mira.metamodel.ops import stratify, simplify_rate_law, \
counts_to_dimensionless, add_observable_pattern, \
get_observable_for_concepts
from mira.examples.sir import cities, sir, sir_2_city, sir_parameterized
from mira.examples.concepts import infected, susceptible
from mira.examples.chime import sviivr
Expand Down Expand Up @@ -582,3 +584,30 @@ def test_stratify_parameter_consistency():
# be the case when parameters would be incrementally numbered for each
# new template
assert len(tm.parameters) == 2


def test_get_observable_for_concepts():
concepts = [
Concept(name='A'),
Concept(name='B'),
Concept(name='C'),
]
obs = get_observable_for_concepts(concepts, 'obs')
assert obs.name == 'obs'
assert obs.expression.args[0] == sum([sympy.Symbol(c.name) for c in concepts])


def test_add_observable_pattern():
templates = [
NaturalDegradation(subject=Concept(name='A', identifiers={'ido': '0000514'}),
rate_law=sympy.Symbol('alpha') * sympy.Symbol('A')),
NaturalDegradation(subject=Concept(name='B', identifiers={'ido': '0000515'}),
rate_law=sympy.Symbol('alpha') * sympy.Symbol('B')),
]
tm = TemplateModel(templates=templates,
parameters={'alpha': Parameter(name='alpha', value=0.1)})
tm = stratify(tm, key='age', strata=['young', 'old'], structure=[])
add_observable_pattern(tm, Concept(name='A', identifiers={'ido': '0000514'}), 'obs')
assert 'obs' in tm.observables
obs = tm.observables['obs']
assert obs.expression.args[0] == sympy.Symbol('A_old') + sympy.Symbol('A_young')

0 comments on commit a84a9b5

Please sign in to comment.