Skip to content

Commit

Permalink
Merge pull request #338 from gyorilab/stratify_name
Browse files Browse the repository at this point in the history
Rename templates during stratification
  • Loading branch information
bgyori authored Jun 13, 2024
2 parents 8fa0de9 + 638c85e commit 7f7dc2a
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
12 changes: 9 additions & 3 deletions mira/metamodel/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,8 @@ def stratify(
for stratum, stratum_idx in stratum_index_map.items():
template_strata = []
new_template = deepcopy(template)
new_template.name = \
f"{template.name if template.name else 't'}_{stratum}"
# We have to make sure that we only add the stratum to the
# list of template strata if we stratified any of the non-controllers
# in this first for loop
Expand Down Expand Up @@ -227,6 +229,7 @@ def stratify(
# We now apply the stratum assigned to each controller in this particular
# tuple to the controller
for controller, c_stratum in zip(stratified_controllers, c_strata_tuple):
stratified_template.name += f"_{c_stratum}"
controller.with_context(do_rename=modify_names, inplace=True,
**{key: c_stratum})
template_strata.append(c_stratum if param_renaming_uses_strata_names
Expand Down Expand Up @@ -297,7 +300,8 @@ def stratify(
observables[observable_key].expression = SympyExprStr(expr)

# Generate a conversion between each concept of each strata based on the network structure
for (source_stratum, target_stratum), concept in itt.product(structure, concept_map.values()):
for idx, ((source_stratum, target_stratum), concept) in \
enumerate(itt.product(structure, concept_map.values())):
if concept.name in exclude_concepts:
continue
# Get stratum names from map if provided, otherwise use the stratum
Expand All @@ -317,14 +321,16 @@ def stratify(
curie_to_name_map=strata_curie_to_name,
**{key: target_stratum})
# todo will need to generalize for different kwargs for different conversions
template = conversion_cls(subject=subject, outcome=outcome)
template = conversion_cls(subject=subject, outcome=outcome,
name=f't{idx}_{source_stratum_name}_{target_stratum_name}')
template.set_mass_action_rate_law(param_name)
templates.append(template)
if not directed:
param_name = f"p_{target_stratum_name}_{source_stratum_name}"
if param_name not in parameters:
parameters[param_name] = Parameter(name=param_name, value=0.1)
reverse_template = conversion_cls(subject=outcome, outcome=subject)
reverse_template = conversion_cls(subject=outcome, outcome=subject,
name=f't{idx}_{target_stratum_name}_{source_stratum_name}')
reverse_template.set_mass_action_rate_law(param_name)
templates.append(reverse_template)

Expand Down
4 changes: 4 additions & 0 deletions tests/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def test_stratify_full(self):
)

expected_0 = ControlledConversion(
name="t_unvaccinated_unvaccinated",
subject=susceptible.with_context(vaccination_status="unvaccinated",
do_rename=True),
outcome=infected.with_context(vaccination_status="unvaccinated",
Expand All @@ -54,6 +55,7 @@ def test_stratify_full(self):
)
)
expected_1 = ControlledConversion(
name="t_unvaccinated_vaccinated",
subject=susceptible.with_context(vaccination_status="unvaccinated",
do_rename=True),
outcome=infected.with_context(vaccination_status="unvaccinated",
Expand All @@ -66,6 +68,7 @@ def test_stratify_full(self):
)
)
expected_2 = ControlledConversion(
name="t_vaccinated_unvaccinated",
subject=susceptible.with_context(vaccination_status="vaccinated",
do_rename=True),
outcome=infected.with_context(vaccination_status="vaccinated",
Expand All @@ -78,6 +81,7 @@ def test_stratify_full(self):
)
)
expected_3 = ControlledConversion(
name="t_vaccinated_vaccinated",
subject=susceptible.with_context(vaccination_status="vaccinated",
do_rename=True),
outcome=infected.with_context(vaccination_status="vaccinated",
Expand Down

0 comments on commit 7f7dc2a

Please sign in to comment.