Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

273 error required amr has no initials #395

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view

Large diffs are not rendered by default.

Large diffs are not rendered by default.

Large diffs are not rendered by default.

424 changes: 424 additions & 0 deletions notebook/Examples_for_TA2_Model_Representation/model1_no_header.json

Large diffs are not rendered by default.

424 changes: 424 additions & 0 deletions notebook/Examples_for_TA2_Model_Representation/model1_working.json

Large diffs are not rendered by default.

Large diffs are not rendered by default.

18 changes: 16 additions & 2 deletions notebook/april_ensemble/syndata_generate_and_test.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,23 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 3,
"metadata": {},
"outputs": [],
"outputs": [
{
"ename": "AttributeError",
"evalue": "module 'torch' has no attribute 'fx'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[3], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mpyciemss\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mutils\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpetri_utils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m deterministic\n",
"File \u001b[0;32m~/Projects/pyciemss/src/pyciemss/utils/__init__.py:1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpetri_utils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (seq_id_suffix,\n\u001b[1;32m 2\u001b[0m reparameterize,\n\u001b[1;32m 3\u001b[0m load_sim_result,\n\u001b[1;32m 4\u001b[0m load,\n\u001b[1;32m 5\u001b[0m draw_petri,\n\u001b[1;32m 6\u001b[0m natural_order,\n\u001b[1;32m 7\u001b[0m add_state_indicies,\n\u001b[1;32m 8\u001b[0m register_template,\n\u001b[1;32m 9\u001b[0m natural_conversion,\n\u001b[1;32m 10\u001b[0m natural_degradation,\n\u001b[1;32m 11\u001b[0m natural_order,\n\u001b[1;32m 12\u001b[0m controlled_conversion,\n\u001b[1;32m 13\u001b[0m grouped_controlled_conversion,\n\u001b[1;32m 14\u001b[0m deterministic,\n\u001b[1;32m 15\u001b[0m petri_to_ode,\n\u001b[1;32m 16\u001b[0m order_state,\n\u001b[1;32m 17\u001b[0m reparameterize,\n\u001b[1;32m 18\u001b[0m unorder_state,\n\u001b[1;32m 19\u001b[0m duplicate_petri_net,\n\u001b[1;32m 20\u001b[0m intervene_petri_net)\n\u001b[1;32m 21\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01minference_utils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (get_tspan,\n\u001b[1;32m 22\u001b[0m state_flux_constraint,\n\u001b[1;32m 23\u001b[0m run_inference,\n\u001b[1;32m 24\u001b[0m is_density_equal,\n\u001b[1;32m 25\u001b[0m is_intervention_density_equal)\n\u001b[1;32m 26\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mplot_utils\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m (setup_ax,\n\u001b[1;32m 27\u001b[0m plot_predictive,\n\u001b[1;32m 28\u001b[0m plot_trajectory,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 31\u001b[0m sideaxis,\n\u001b[1;32m 32\u001b[0m sideaxishist)\n",
"File \u001b[0;32m~/Projects/pyciemss/src/pyciemss/utils/petri_utils.py:264\u001b[0m\n\u001b[1;32m 259\u001b[0m \u001b[38;5;129m@register_template\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mGroupedControlledConversion\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 260\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mgrouped_controlled_conversion\u001b[39m(params: Dict[\u001b[38;5;28mstr\u001b[39m, T], t: T, states: Tuple[T, \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m]) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Tuple[T, \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m]:\n\u001b[1;32m 261\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mNotImplementedError\u001b[39;00m \u001b[38;5;66;03m# TODO\u001b[39;00m\n\u001b[0;32m--> 264\u001b[0m \u001b[38;5;129m@torch\u001b[39m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfx\u001b[49m\u001b[38;5;241m.\u001b[39mwrap\n\u001b[1;32m 265\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mdeterministic\u001b[39m(name, value, \u001b[38;5;241m*\u001b[39m, event_dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m):\n\u001b[1;32m 266\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m pyro\u001b[38;5;241m.\u001b[39mdeterministic(name, value, event_dim\u001b[38;5;241m=\u001b[39mevent_dim)\n\u001b[1;32m 269\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mpetri_to_ode\u001b[39m(\n\u001b[1;32m 270\u001b[0m G: nx\u001b[38;5;241m.\u001b[39mMultiDiGraph,\n\u001b[1;32m 271\u001b[0m funcs: Optional[Dict[\u001b[38;5;28mstr\u001b[39m, Callable[[Tuple[T, \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m], T], Tuple[T, \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m]]]] \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m,\n\u001b[1;32m 272\u001b[0m ) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Callable[[T, Tuple[T, \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m]], Tuple[T, \u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m\u001b[38;5;241m.\u001b[39m]]:\n",
"\u001b[0;31mAttributeError\u001b[0m: module 'torch' has no attribute 'fx'"
]
}
],
"source": [
"from pyciemss.utils.synth_data_utils import *"
]
Expand Down
16 changes: 14 additions & 2 deletions notebook/april_ensemble/synthetic_data_with_custom_models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,19 @@
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named 'pyro'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[1], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mos\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpyro\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpyro\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mdistributions\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mdist\u001b[39;00m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpandas\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mpd\u001b[39;00m\n",
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'pyro'"
]
}
],
"source": [
"import os\n",
"import torch\n",
Expand Down Expand Up @@ -1787,7 +1799,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
"version": "3.10.9"
}
},
"nbformat": 4,
Expand Down
45 changes: 43 additions & 2 deletions src/pyciemss/PetriNetODE/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from pyciemss.visuals import plots

import mira
import json

# Load base interfaces
from pyciemss.interfaces import (
Expand Down Expand Up @@ -53,6 +54,46 @@
PetriSolution = dict # NOTE: [str, torch.tensor] type argument removed because of issues with type-based dispatch.
PetriInferredParameters = pyro.nn.PyroModule

def run_petri_model_checks(petri_model_or_path: Union[str, mira.metamodel.TemplateModel, mira.modeling.Model]) -> bool:
"""
Check that the model AMR (a) is readable/formatted correctly, (b) that the number of transitions matches the number of rate laws, and (c) includes initial conditions for each state variable.
"""
amr_dict = {}
# Read the AMR as a dictionary and alert the user if the model path is incorrect, or the AMR in not properly formatted
try:
with open(petri_model_or_path, 'r') as file:
amr_dict = json.load(file)
except FileNotFoundError:
print("File not found. Make sure you have the correct model path.")
except json.JSONDecodeError:
print("Error decoding JSON data. Check that the AMR is formatted correctly.")

# Check that the AMR contains the required header, model, semantics, and metadata keys:
required_keys = ['header', 'model', 'semantics', 'metadata']
missing_keys = [key for key in required_keys if key not in amr_dict]

# Get the number of transitions and rate laws for comparison
num_transitions = len(amr_dict["model"]["transitions"])
num_rate_laws = len(amr_dict["semantics"]["ode"]["rates"])

# Check that every state variable is assigned an initial condition
missing_ids = [state['id'] for state in amr_dict["model"]["states"] if state['id'] not in set(item['target'] for item in amr_dict["semantics"]["ode"]["initials"])]

# Check that the AMR contains the required header, model, semantics, and metadata keys:
if missing_keys:
missing_keys_str = ', '.join(missing_keys)
raise ValueError(f"The AMR is missing: {missing_keys_str}")
# Check that the number of transitions matches the number of rate laws
elif num_transitions < num_rate_laws:
raise ValueError("At least one transition is missing. The number of transitions must equal the number of rate laws.")
elif num_transitions > num_rate_laws:
raise ValueError("There are more transitions than rate laws. The number of transitions must equal the number of rate laws.")
# Check that every state variable is assigned an initial condition
elif missing_ids:
missing_ids_str = ', '.join(missing_ids)
raise ValueError(f"The following state variables do not have corresponding initials: {missing_ids_str}")
else:
return True

def load_petri_model(
petri_model_or_path: Union[str, mira.metamodel.TemplateModel, mira.modeling.Model],
Expand All @@ -65,11 +106,11 @@ def load_petri_model(
"""
Load a petri net from a file and compile it into a probabilistic program.
"""
if noise_model == "scaled_beta":
if (noise_model == "scaled_beta") and run_petri_model_checks(petri_model_or_path):
return ScaledBetaNoisePetriNetODESystem.from_askenet(
petri_model_or_path, noise_scale=noise_scale, compile_rate_law_p=compile_rate_law_p, compile_observables_p=compile_observables_p, add_uncertainty=add_uncertainty
)
elif noise_model == "scaled_normal":
elif (noise_model == "scaled_normal") and run_petri_model_checks(petri_model_or_path):
return ScaledNormalNoisePetriNetODESystem.from_askenet(
petri_model_or_path, noise_scale=noise_scale, compile_rate_law_p=compile_rate_law_p, compile_observables_p=compile_observables_p, add_uncertainty=add_uncertainty
)
Expand Down
2 changes: 1 addition & 1 deletion src/pyciemss/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
natural_order,
controlled_conversion,
grouped_controlled_conversion,
deterministic,
# deterministic,
petri_to_ode,
order_state,
reparameterize,
Expand Down
6 changes: 3 additions & 3 deletions src/pyciemss/utils/petri_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,9 +261,9 @@ def grouped_controlled_conversion(params: Dict[str, T], t: T, states: Tuple[T, .
raise NotImplementedError # TODO


@torch.fx.wrap
def deterministic(name, value, *, event_dim=None):
return pyro.deterministic(name, value, event_dim=event_dim)
# @torch.fx.wrap
# def deterministic(name, value, *, event_dim=None):
# return pyro.deterministic(name, value, event_dim=event_dim)


def petri_to_ode(
Expand Down
Loading