Optimize interface for TA4 integration (#204)
* Optimize TA4 interface

* Update interfaces.py

* Change setup to fix MIRA name error (#205) (#206)

* Changed mira version to the one before the bug

* fix error

Co-authored-by: Sam Witty <[email protected]>

* Update demo notebook

* Update demo.ipynb

* Update MIRA loading (#209)

* Change setup to fix MIRA name error (#205)

* Changed mira version to the one before the bug

* fix error

* Add utilities for loading distributions from AMR (#200)

* added mira distribution loading

* added normal distribution

* fixed Normal2 and Normal3

* nit

* added minimal mira_distribution_to_pyro test
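For context, a minimal sketch of such a loader follows. The ProbOnto-style type codes and parameter names (Normal1 = mean/stdev, Normal2 = mean/variance, Normal3 = mean/precision) and the `mira_dist.type` / `mira_dist.parameters` attributes are assumptions for illustration, not the exact pyciemss implementation:

```python
import torch
import pyro.distributions as dist

def mira_distribution_to_pyro(mira_dist):
    # Hypothetical sketch: map a MIRA distribution spec onto a Pyro distribution.
    params = {k: torch.as_tensor(v) for k, v in mira_dist.parameters.items()}
    if mira_dist.type == "StandardUniform1":
        return dist.Uniform(params["minimum"], params["maximum"])
    if mira_dist.type == "Normal1":  # mean and standard deviation
        return dist.Normal(params["mean"], params["stdev"])
    if mira_dist.type == "Normal2":  # mean and variance
        return dist.Normal(params["mean"], params["variance"].sqrt())
    if mira_dist.type == "Normal3":  # mean and precision
        return dist.Normal(params["mean"], params["precision"].rsqrt())
    raise NotImplementedError(f"Unsupported distribution type: {mira_dist.type}")
```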

---------

Co-authored-by: Sam Witty <[email protected]>

* Updated interface with string inputs for QOI
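In other words, the quantity of interest is now named by strings rather than passed as a callable. The tuple semantics below are inferred from the demo notebook diff further down and are illustrative only:

```python
# QOI as (function name, solution key, ndays) instead of a lambda;
# the "2" is read here as "average over the last 2 days" -- an inference,
# not a documented contract.
QOI = ("scenario2dec_nday_average", "I_sol", 2)
INTERVENTION = [(0.1, "beta")]  # interventions as (time, parameter) pairs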

* Interface to optimize after calibrating

* Update demo.ipynb

* Update sample csv file to avoid merge conflicts

* Updating file to avoid merge conflict with main

* Updating from main to fix model loading errors (#211)

* Change setup to fix MIRA name error (#205)

* Changed mira version to the one before the bug

* fix error

* Add utilities for loading distributions from AMR (#200)

* added mira distribution loading

* added normal distribution

* fixed Normal2 and Normal3

* nit

* added minimal mira_distribution_to_pyro test

* Symbolic Rate law to Pytorch Rate law (#201)

* I believe I wrote the correct code, based on experiments in the notebook. Will test next.

* FAILED test/test_mira/test_rate_law.py::TestRateLaw::test_rate_law_compilation - AttributeError: 'ScaledBetaNoisePetriNetODESystem' object has no attribute 'beta'

* Added Symbolic_Deriv_Experiments notebook

* Something weird is happening. I can confirm that 'beta' is an attribute of ScaledBetaNoisePetriNetODESystem after setting up the model, but then it can't be found at sample time

* Clarified the bug in the Symbolic derivatives notebook

* Expected and actual derivative match

* Time varying parameter rate law correctly read

* Thought we added this already

* Added kwargs to from_askenet and from_mira and compile_rate_law_p to load_petri_net

* Blocked on gyorilab/mira#189 but tests pass by making compile_rate_law_p False by default

* Removed unnecessary pygraphviz dependency

* Added a unit test that fails when a concept name does not match the rate law symbols

* All tests pass with default compile_rate_law_p = False
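A minimal sketch of the underlying technique (compiling a symbolic rate law into a PyTorch-callable function with SymPy's lambdify); the SIR-style rate law and variable names are invented for illustration and are not the pyciemss internals:

```python
import sympy
import torch

# Hypothetical mass-action rate law, as MIRA might supply it symbolically.
S, I, beta, N = sympy.symbols("S I beta N")
rate_law = beta * S * I / N

# lambdify with the torch module emits tensor ops, so the compiled rate law
# stays differentiable and can be evaluated inside the ODE right-hand side.
compiled_rate = sympy.lambdify([S, I, beta, N], rate_law, modules=[torch])

flux = compiled_rate(torch.tensor(990.0), torch.tensor(10.0),
                     torch.tensor(0.3), torch.tensor(1000.0))
print(flux)  # tensor(2.9700)
```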

* Merged from main; removed the dependency on the older version of mira

* point mira to the github repo main branch

* point mira to the github repo main branch

* load_and_calibrate_and_sample(..., compile_rate_law_p=True) works with the caveat that the ScaledBetaNoisePetriNetODESystem solution was returning very slightly negative values, so I set mean = torch.abs(solution[var_name]) to address the issue

* Merged changes to MiraPetriNetODESystem and ScaledBetaNoisePetriNetODESystem from main; ScaledBetaNoisePetriNetODESystem now defaults to compile_rate_law_p=True

* observation_function for ScaledBetaNoisePetriNetODESystem now uses torch.maximum(solution[var_name], torch.tensor(1e-9)) to deal with overshooting derivatives
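A sketch of the two fixes described above, not the verbatim pyciemss method: solver overshoot can make state values very slightly negative, which is invalid as the mean of the scaled-Beta noise model, so the observation either reflects the value with torch.abs or, as here, floors it with torch.maximum:

```python
import torch

def positive_mean(solution: dict, var_name: str) -> torch.Tensor:
    # Hypothetical helper mirroring the commits above.
    # Earlier fix: mean = torch.abs(solution[var_name])
    # Current fix: floor at 1e-9 so the Beta mean stays strictly positive.
    return torch.maximum(solution[var_name], torch.tensor(1e-9))

sol = {"I_sol": torch.tensor([3.2, -1.4e-7, 0.5])}
print(positive_mean(sol, "I_sol"))  # values floored at 1e-9: [3.2, 1e-9, 0.5]
```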

* Aggregating parameters is now opt-out by default, and AMR models with multiple parameters per transition can be interpreted using compile_rate_law

---------

Co-authored-by: Sam Witty <[email protected]>
Co-authored-by: Jeremy Zucker <[email protected]>

* Update notebooks to avoid merge conflicts

* Update NB to resolve conflicts

* Update from main to avoid conflicts (#212)

* Change setup to fix MIRA name error (#205)

* Changed mira version to the one before the bug

* fix error

* Add utilities for loading distributions from AMR (#200)

* added mira distribution loading

* added normal distribution

* fixed Normal2 and Normal3

* nit

* added minimal mira_distribution_to_pyro test

* Symbolic Rate law to Pytorch Rate law (#201)

* I believe I wrote the correct code, based on experiments in the notebook. Will test next.

* FAILED test/test_mira/test_rate_law.py::TestRateLaw::test_rate_law_compilation - AttributeError: 'ScaledBetaNoisePetriNetODESystem' object has no attribute 'beta'

* Added Symbolic_Deriv_Experiments notebook

* Something weird is happening. I can confirm that 'beta' is an attribute of ScaledBetaNoisePetriNetODESystem after setting up the model, but then it can't be found at sample time

* Clarified the bug in the Symbolic derivatives notebook

* Expected and actual derivative match

* Time varying parameter rate law correctly read

* Thought we added this already

* Added kwargs to from_askenet and from_mira and compile_rate_law_p to load_petri_net

* Blocked on gyorilab/mira#189 but tests pass by making compile_rate_law_p False by default

* Removed unnecessary pygraphviz dependency

* Added a unit test that fails when a concept name does not match the rate law symbols

* All tests pass with default compile_rate_law_p = False

* Merged from main; removed the dependency on the older version of mira

* point mira to the github repo main branch

* point mira to the github repo main branch

* load_and_calibrate_and_sample(..., compile_rate_law_p=True) works with the caveat that the ScaledBetaNoisePetriNetODESystem solution was returning very slightly negative values, so I set mean = torch.abs(solution[var_name]) to address the issue

* Merged changes to MiraPetriNetODESystem and ScaledBetaNoisePetriNetODESystem from main; ScaledBetaNoisePetriNetODESystem now defaults to compile_rate_law_p=True

* observation_function for ScaledBetaNoisePetriNetODESystem now uses torch.maximum(solution[var_name], torch.tensor(1e-9)) to deal with overshooting derivatives

* Aggregating parameters is now opt-out by default, and AMR models with multiple parameters per transition can be interpreted using compile_rate_law

* 12-Month Hackathon Notebooks (#207)

* started hackathon prep scenario notebooks

* more on hackathon scenarios

* more work on hackathon notebooks

* created AMR for scenario1 with constant beta

* Vs hackathon prep (#203)

* beginning of ensemble challenge ipynb

* Update ensemble_challenge.ipynb

* added to scenario2 notebook, ready for scenario3

* added to scenario2 notebook, ready for scenario3

* Update scenario1.ipynb

* Update scenario1

* changes to scenario3, and new AMR

* pre-hackathon prep update

- scenario1 task 1 almost done (pending AMR changes)

- ensemble challenge layout started

* more on scenario notebooks, added AMR to scenario2

* Fixed the NaN in the intervened parameters column output

* updates to scenario3 notebook

* merged from main

---------

Co-authored-by: vsella <[email protected]>
Co-authored-by: Jeremy Zucker <[email protected]>

---------

Co-authored-by: Sam Witty <[email protected]>
Co-authored-by: Jeremy Zucker <[email protected]>
Co-authored-by: sabinala <[email protected]>
Co-authored-by: vsella <[email protected]>
Co-authored-by: Jeremy Zucker <[email protected]>

* Update demo.ipynb

* Update demo.ipynb

* updated demo with new MIRA model

* Fixed setup.cfg

* Update demo.ipynb

* Updated solution key in demo

* finished demo notebook run

* Added test for samples obtained after OUU

* Update output format and description

* Fixed the mira dependency in setup to help tests pass

* Added `self.intervened_samples` to `test_ode_interfaces`
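A hedged sketch of what that check can look like; the attribute name comes from the commit message above, while the unittest scaffolding around it is assumed:

```python
import unittest
import pandas as pd

class TestODEInterfaces(unittest.TestCase):
    # Hypothetical scaffolding: in the real suite, setUp would run the
    # optimize-and-sample interface and store its outputs on self.
    def test_intervened_samples(self):
        self.assertIsInstance(self.intervened_samples, pd.DataFrame)
        self.assertFalse(self.intervened_samples.empty)
```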

---------

Co-authored-by: Sam Witty <[email protected]>
Co-authored-by: Jeremy Zucker <[email protected]>
Co-authored-by: sabinala <[email protected]>
Co-authored-by: vsella <[email protected]>
Co-authored-by: Jeremy Zucker <[email protected]>
6 people authored Jul 12, 2023
1 parent df2d47a commit 3327ede
Showing 10 changed files with 2,167 additions and 624 deletions.
2 changes: 1 addition & 1 deletion notebook/december_demo/scenario2.ipynb
@@ -608,7 +608,7 @@
],
"source": [
"OBJFUN = lambda x: np.abs(x)\n",
"INTERVENTION= {\"VaccinationParam\": [7.5, \"nu\"]}\n",
"INTERVENTION= [(7.5, \"nu\")]\n",
"QOI = lambda y: scenario2dec_nday_average(y, contexts=[\"I_obs\"], ndays=7)\n",
"timepoints_qoi = range(83,90)\n",
"ouu_policy = optimize(initialized_petri_net_ode_model,\n",
257 changes: 207 additions & 50 deletions notebook/integration_demo/demo.ipynb
@@ -2,20 +2,23 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from pyciemss.PetriNetODE.interfaces import (\n",
" load_and_sample_petri_model,\n",
" load_and_calibrate_and_sample_petri_model\n",
")"
" load_and_calibrate_and_sample_petri_model,\n",
" load_and_optimize_and_sample_petri_model,\n",
" load_and_calibrate_and_optimize_and_sample_petri_model\n",
")\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
@@ -33,7 +36,7 @@
},
{
"cell_type": "code",
"execution_count": 5,
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
@@ -49,7 +52,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 10,
"metadata": {
"scrolled": true
},
@@ -79,53 +82,53 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"iteration 0: loss = 62.12709617614746\n",
"iteration 25: loss = 60.099618911743164\n",
"iteration 50: loss = 58.367534935474396\n",
"iteration 75: loss = 57.764896750450134\n",
"iteration 100: loss = 56.60384750366211\n",
"iteration 125: loss = 56.626110792160034\n",
"iteration 150: loss = 56.81937026977539\n",
"iteration 175: loss = 56.734710931777954\n",
"iteration 200: loss = 56.507239818573\n",
"iteration 225: loss = 56.81278920173645\n",
"iteration 250: loss = 56.38322448730469\n",
"iteration 275: loss = 56.98680853843689\n",
"iteration 300: loss = 56.82800793647766\n",
"iteration 325: loss = 56.91287016868591\n",
"iteration 350: loss = 56.571919679641724\n",
"iteration 375: loss = 56.62337040901184\n",
"iteration 400: loss = 56.986464977264404\n",
"iteration 425: loss = 56.08261704444885\n",
"iteration 450: loss = 56.555689573287964\n",
"iteration 475: loss = 56.814436197280884\n",
"iteration 500: loss = 56.36864924430847\n",
"iteration 525: loss = 56.35692620277405\n",
"iteration 550: loss = 56.53825569152832\n",
"iteration 575: loss = 56.52056694030762\n",
"iteration 600: loss = 56.57140803337097\n",
"iteration 625: loss = 56.3572359085083\n",
"iteration 650: loss = 56.788116693496704\n",
"iteration 675: loss = 56.779704332351685\n",
"iteration 700: loss = 56.59728670120239\n",
"iteration 725: loss = 56.58475613594055\n",
"iteration 750: loss = 56.811803579330444\n",
"iteration 775: loss = 56.95957016944885\n",
"iteration 800: loss = 56.53050971031189\n",
"iteration 825: loss = 57.138880014419556\n",
"iteration 850: loss = 56.93205904960632\n",
"iteration 875: loss = 56.596068143844604\n",
"iteration 900: loss = 56.303569078445435\n",
"iteration 925: loss = 56.353702545166016\n",
"iteration 950: loss = 56.812726736068726\n",
"iteration 975: loss = 56.6675009727478\n"
"iteration 0: loss = 61.70457148551941\n",
"iteration 25: loss = 56.963383436203\n",
"iteration 50: loss = 57.868717432022095\n",
"iteration 75: loss = 57.586416482925415\n",
"iteration 100: loss = 56.53365778923035\n",
"iteration 125: loss = 56.85279989242554\n",
"iteration 150: loss = 56.778077602386475\n",
"iteration 175: loss = 56.841498613357544\n",
"iteration 200: loss = 56.82657432556152\n",
"iteration 225: loss = 57.287031173706055\n",
"iteration 250: loss = 56.763044357299805\n",
"iteration 275: loss = 56.49668288230896\n",
"iteration 300: loss = 57.06786513328552\n",
"iteration 325: loss = 56.646276235580444\n",
"iteration 350: loss = 56.41763639450073\n",
"iteration 375: loss = 56.383132457733154\n",
"iteration 400: loss = 56.655441999435425\n",
"iteration 425: loss = 56.522589683532715\n",
"iteration 450: loss = 56.92046332359314\n",
"iteration 475: loss = 56.76862335205078\n",
"iteration 500: loss = 56.490626096725464\n",
"iteration 525: loss = 56.579216718673706\n",
"iteration 550: loss = 56.69615292549133\n",
"iteration 575: loss = 56.92038869857788\n",
"iteration 600: loss = 56.410447120666504\n",
"iteration 625: loss = 56.608824491500854\n",
"iteration 650: loss = 56.606239557266235\n",
"iteration 675: loss = 56.5149290561676\n",
"iteration 700: loss = 56.694045066833496\n",
"iteration 725: loss = 56.6541702747345\n",
"iteration 750: loss = 56.73468089103699\n",
"iteration 775: loss = 56.48121380805969\n",
"iteration 800: loss = 56.6811363697052\n",
"iteration 825: loss = 56.99781060218811\n",
"iteration 850: loss = 56.98245882987976\n",
"iteration 875: loss = 56.80254626274109\n",
"iteration 900: loss = 56.61562776565552\n",
"iteration 925: loss = 56.61105751991272\n",
"iteration 950: loss = 56.783122062683105\n",
"iteration 975: loss = 56.69991850852966\n"
]
}
],
@@ -148,13 +151,167 @@
" os.path.join(DEMO_PATH, \"results_petri/calibrated_sample_results.csv\"), index=False\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## load_and_optimize_and_sample_petri_model"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Time taken: (1.76e-01 seconds per model evaluation).\n",
"Performing risk-based optimization under uncertainty (using alpha-superquantile)\n",
"Estimated wait time 1321.9 seconds...\n",
"Optimization completed in time 485.78 seconds.\n",
"Optimal policy:\t0.0\n",
"Post-processing optimal policy...\n",
"Estimated risk at optimal policy [0.7027968883514404]\n",
"Optimal policy: 0.0\n",
"Estimated risk at optimal policy [0.7027968883514404]\n"
]
}
],
"source": [
"num_samples = 100\n",
"timepoints = [0.0, 1.0, 2.0, 3.0, 4.0]\n",
"OBJFUN = lambda x: np.abs(x)\n",
"INTERVENTION = [(0.1, \"beta\")]\n",
"QOI = (\"scenario2dec_nday_average\", \"I_sol\", 2)\n",
"# Run the optimization and sampling\n",
"ouu_samples, opt_policy = load_and_optimize_and_sample_petri_model(\n",
" ASKENET_PATH,\n",
" num_samples,\n",
" timepoints=timepoints,\n",
" interventions=INTERVENTION,\n",
" qoi=QOI,\n",
" risk_bound=10.,\n",
" objfun=OBJFUN,\n",
" initial_guess=0.02,\n",
" bounds=[[0.],[3.]],\n",
" verbose=True,\n",
")\n",
"\n",
"# Save results\n",
"ouu_samples.to_csv(\n",
" os.path.join(DEMO_PATH, \"results_petri/optimize_sample_results.csv\"), index=False\n",
")\n",
"print(\"Optimal policy:\", opt_policy[\"policy\"])\n",
"print(\"Estimated risk at optimal policy\", opt_policy[\"risk\"])"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## load_and_calibrate_and_optimize_and_sample_petri_model"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"iteration 0: loss = 37.5659122467041\n",
"iteration 25: loss = 35.41392183303833\n",
"iteration 50: loss = 33.67425513267517\n",
"iteration 75: loss = 32.595038652420044\n",
"iteration 100: loss = 33.10420513153076\n",
"iteration 125: loss = 32.42216110229492\n",
"iteration 150: loss = 33.292160987854004\n",
"iteration 175: loss = 31.957093477249146\n",
"iteration 200: loss = 32.40791869163513\n",
"iteration 225: loss = 32.1630494594574\n",
"iteration 250: loss = 32.64620661735535\n",
"iteration 275: loss = 32.447582960128784\n",
"iteration 300: loss = 32.42112612724304\n",
"iteration 325: loss = 32.70798373222351\n",
"iteration 350: loss = 32.1295063495636\n",
"iteration 375: loss = 32.78153920173645\n",
"iteration 400: loss = 32.15399241447449\n",
"iteration 425: loss = 32.18069672584534\n",
"iteration 450: loss = 32.34887766838074\n",
"iteration 475: loss = 32.39561104774475\n",
"iteration 500: loss = 32.51458191871643\n",
"iteration 525: loss = 32.285157680511475\n",
"iteration 550: loss = 32.21916651725769\n",
"iteration 575: loss = 32.54395031929016\n",
"iteration 600: loss = 32.487563133239746\n",
"iteration 625: loss = 32.54312562942505\n",
"iteration 650: loss = 31.965022087097168\n",
"iteration 675: loss = 32.316070318222046\n",
"iteration 700: loss = 32.36382842063904\n",
"iteration 725: loss = 32.52105975151062\n",
"iteration 750: loss = 32.83406376838684\n",
"iteration 775: loss = 32.393730878829956\n",
"iteration 800: loss = 32.43019080162048\n",
"iteration 825: loss = 32.3959174156189\n",
"iteration 850: loss = 32.112430572509766\n",
"iteration 875: loss = 32.36790990829468\n",
"iteration 900: loss = 32.66983437538147\n",
"iteration 925: loss = 32.438520669937134\n",
"iteration 950: loss = 32.27924203872681\n",
"iteration 975: loss = 32.38777709007263\n",
"Time taken: (2.09e-01 seconds per model evaluation).\n",
"Performing risk-based optimization under uncertainty (using alpha-superquantile)\n",
"Estimated wait time 1567.9 seconds...\n",
"Optimization completed in time 395.12 seconds.\n",
"Optimal policy:\t0.0\n",
"Post-processing optimal policy...\n",
"Estimated risk at optimal policy [0.6986928939819336]\n",
"Optimal policy after calibration: 0.0\n",
"Estimated risk at optimal policy after calibration [0.7027968883514404]\n"
]
}
],
"source": [
"data_path = os.path.join(DEMO_PATH, \"data.csv\")\n",
"num_samples = 100\n",
"timepoints = [0.0, 1.0, 2.0, 3.0, 4.0]\n",
"OBJFUN = lambda x: np.abs(x)\n",
"INTERVENTION = [(0.1, \"beta\")]\n",
"QOI = (\"scenario2dec_nday_average\", \"I_sol\", 2)\n",
"# Run the calibration, optimization, and sampling\n",
"ouu_cal_samples, opt_cal_policy = load_and_calibrate_and_optimize_and_sample_petri_model(\n",
" ASKENET_PATH,\n",
" data_path,\n",
" num_samples,\n",
" timepoints=timepoints,\n",
" interventions=INTERVENTION,\n",
" qoi=QOI,\n",
" risk_bound=10.,\n",
" objfun=OBJFUN,\n",
" initial_guess=0.02,\n",
" bounds=[[0.],[3.]],\n",
" verbose=True,\n",
")\n",
"\n",
"# Save results\n",
"ouu_cal_samples.to_csv(\n",
" os.path.join(DEMO_PATH, \"results_petri/calibrate_optimize_sample_results.csv\"), index=False\n",
")\n",
"print(\"Optimal policy after calibration:\", opt_policy[\"policy\"])\n",
"print(\"Estimated risk at optimal policy after calibration\", opt_policy[\"risk\"])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "pyciemss-main",
"display_name": "askemv2",
"language": "python",
"name": "pyciemss-main"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
@@ -166,7 +323,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.9"
"version": "3.11.4"
}
},
"nbformat": 4,
