diff --git a/tests/test_model_api.py b/tests/test_model_api.py index 308777198..5957486cd 100644 --- a/tests/test_model_api.py +++ b/tests/test_model_api.py @@ -702,16 +702,6 @@ def test_deactivation_endpoint(self): strata=['y', 'o'], cartesian_control=True) - # Assert that there are old to young transitions - transition_list = [] - for template in age_strata.templates: - if hasattr(template, 'subject') and hasattr(template, 'outcome'): - subj, outc = template.subject.name, template.outcome.name - if subj.endswith('_o') and outc.endswith('_y') or \ - subj.endswith('_y') and outc.endswith('_o'): - transition_list.append((subj, outc)) - assert len(transition_list), "No old to young transitions found" - amr_sir = AskeNetPetriNetModel(Model(age_strata)).to_json() # Test the endpoint itself @@ -744,6 +734,15 @@ def test_deactivation_endpoint(self): self.assertEqual(422, response.status_code) # Actual Test + # Assert that there are old to young transitions + transition_list = [] + for template in age_strata.templates: + if hasattr(template, 'subject') and hasattr(template, 'outcome'): + subj, outc = template.subject.name, template.outcome.name + if subj.endswith('_o') and outc.endswith('_y') or \ + subj.endswith('_y') and outc.endswith('_o'): + transition_list.append((subj, outc)) + assert len(transition_list), "No old to young transitions found" response = self.client.post( "/api/deactivate_transitions", json={"model": amr_sir, "transitions": transition_list} @@ -772,10 +771,11 @@ def test_deactivation_endpoint(self): tm_deactivated_params = template_model_from_askenet_json( amr_sir_deactivated_params) for template in tm_deactivated_params.templates: - for symb in template.rate_law.atoms(): - if str(symb) == deactivate_key: - assert ( - template.rate_law.rate_law.args[0] == - sympy.core.numbers.Zero(), - template.rate_law - ) + # All rate laws must either be zero or not contain the deactivated + # parameter + if template.rate_law and not template.rate_law.is_zero: + for symb in template.rate_law.atoms(): + assert str(symb) != deactivate_key + else: + assert (template.rate_law.args[0] == sympy.core.numbers.Zero(), + template.rate_law)