diff --git a/mira/metamodel/ops.py b/mira/metamodel/ops.py index a7b11b58f..a765644a5 100644 --- a/mira/metamodel/ops.py +++ b/mira/metamodel/ops.py @@ -10,6 +10,7 @@ from .template_model import TemplateModel, Initial, Parameter, Observable from .templates import * +from .comparison import get_dkg_refinement_closure from .units import Unit from .utils import SympyExprStr @@ -748,7 +749,8 @@ def deactivate_templates( def add_observable_pattern( template_model: TemplateModel, concept_pattern: Concept, - name: str + name: str, + refinement_func=None, ): """Add an observable for a pattern of concepts. @@ -766,10 +768,11 @@ def add_observable_pattern( : A template model with the observable added. """ - + if refinement_func is None: + refinement_func = get_dkg_refinement_closure().is_ontological_child observable_concepts = [] - for key, concept in template_model.get_concepts_map(): - if concept.refinement_of(concept_pattern): + for key, concept in template_model.get_concepts_map().items(): + if concept.refinement_of(concept_pattern, refinement_func): observable_concepts.append(concept) obs = get_observable_for_concepts(observable_concepts, name) template_model.observables[name] = obs diff --git a/tests/test_ops.py b/tests/test_ops.py index f2151077c..a469c08e0 100644 --- a/tests/test_ops.py +++ b/tests/test_ops.py @@ -8,7 +8,9 @@ import sympy from mira.metamodel import * -from mira.metamodel.ops import stratify, simplify_rate_law, counts_to_dimensionless +from mira.metamodel.ops import stratify, simplify_rate_law, \ + counts_to_dimensionless, add_observable_pattern, \ + get_observable_for_concepts from mira.examples.sir import cities, sir, sir_2_city, sir_parameterized from mira.examples.concepts import infected, susceptible from mira.examples.chime import sviivr @@ -582,3 +584,30 @@ def test_stratify_parameter_consistency(): # be the case when parameters would be incrementally numbered for each # new template assert len(tm.parameters) == 2 + + +def test_get_observable_for_concepts(): + concepts = [ + Concept(name='A'), + Concept(name='B'), + Concept(name='C'), + ] + obs = get_observable_for_concepts(concepts, 'obs') + assert obs.name == 'obs' + assert obs.expression.args[0] == sum([sympy.Symbol(c.name) for c in concepts]) + + +def test_add_observable_pattern(): + templates = [ + NaturalDegradation(subject=Concept(name='A', identifiers={'ido': '0000514'}), + rate_law=sympy.Symbol('alpha') * sympy.Symbol('A')), + NaturalDegradation(subject=Concept(name='B', identifiers={'ido': '0000515'}), + rate_law=sympy.Symbol('alpha') * sympy.Symbol('B')), + ] + tm = TemplateModel(templates=templates, + parameters={'alpha': Parameter(name='alpha', value=0.1)}) + tm = stratify(tm, key='age', strata=['young', 'old'], structure=[]) + add_observable_pattern(tm, Concept(name='A', identifiers={'ido': '0000514'}), 'obs') + assert 'obs' in tm.observables + obs = tm.observables['obs'] + assert obs.expression.args[0] == sympy.Symbol('A_old') + sympy.Symbol('A_young')