
Provisional support for time_unit in TA4 interface dataframe (#221)
* Provisional support for time_unit in TA4 interface dataframe

* Missed format identifier

* Default time-unit

* Testing time_unit behaviors

* Updating tests for new TA4 interface behavior

* To preserve context in failure messages, changed  to
JosephCottam authored Jul 12, 2023
1 parent df5efba commit 533965d
Showing 6 changed files with 609 additions and 279 deletions.
66 changes: 42 additions & 24 deletions src/pyciemss/utils/interface_utils.py
@@ -11,9 +11,14 @@ def convert_to_output_format(
samples: Dict[str, torch.Tensor],
timepoints: Iterable[float],
interventions: Optional[Dict[str, torch.Tensor]] = None,
*,
time_unit: Optional[str] = "(unknown)",
) -> pd.DataFrame:
"""
Convert the samples from the Pyro model to a DataFrame in the TA4 requested format.
time_unit -- Label timepoints in a semantically relevant way as `timepoint_<time_unit>`.
If None, the `timepoint_<time_unit>` field is not provided; when omitted, the default label is "(unknown)".
"""

pyciemss_results = {"parameters": {}, "states": {}}
@@ -30,11 +35,7 @@ def convert_to_output_format(
n_models = sample.shape[1]
for i in range(n_models):
pyciemss_results["parameters"][f"model_{i}_weight"] = (
sample[:, i]
.data.detach()
.cpu()
.numpy()
.astype(np.float64)
sample[:, i].data.detach().cpu().numpy().astype(np.float64)
)
else:
pyciemss_results["states"][name] = (
@@ -59,7 +60,9 @@ def convert_to_output_format(
else:
d = {
**d,
**assign_interventions_to_timepoints(interventions, timepoints, pyciemss_results["parameters"])
**assign_interventions_to_timepoints(
interventions, timepoints, pyciemss_results["parameters"]
),
}

# Solution (state variables)
@@ -71,7 +74,12 @@ def convert_to_output_format(
},
}

return pd.DataFrame(d)
result = pd.DataFrame(d)
if time_unit is not None:
all_timepoints = result["timepoint_id"].map(lambda v: timepoints[v])
result = result.assign(**{f"timepoint_{time_unit}": all_timepoints})

return result
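A minimal sketch of the new behavior on toy data (the frame and `timepoints` below are hypothetical stand-ins, not the real TA4 output): a non-None `time_unit` label appends a `timepoint_<time_unit>` column mapped from `timepoint_id`, while `time_unit=None` leaves the frame unchanged.

    import pandas as pd

    # Toy stand-in for the frame built above: two samples over three timepoint ids.
    timepoints = [0.0, 2.5, 5.0]
    result = pd.DataFrame(
        {"timepoint_id": [0, 1, 2, 0, 1, 2], "sample_id": [0, 0, 0, 1, 1, 1]}
    )

    time_unit = "days"  # hypothetical label; the default is "(unknown)", and None suppresses the column
    if time_unit is not None:
        all_timepoints = result["timepoint_id"].map(lambda v: timepoints[v])
        result = result.assign(**{f"timepoint_{time_unit}": all_timepoints})

    print(list(result.columns))  # ['timepoint_id', 'sample_id', 'timepoint_days']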


def csv_to_list(filename):
@@ -88,35 +96,43 @@ def csv_to_list(filename):
return result


def interventions_and_sampled_params_to_interval(interventions: dict, sampled_params: dict) -> dict:
def interventions_and_sampled_params_to_interval(
interventions: dict, sampled_params: dict
) -> dict:
"""Convert interventions and sampled parameters to dict of intervals.
:param interventions: dict keyed by parameter name where each value is a tuple (intervention_time, value)
:param sampled_params: dict keyed by param where each value is an array of sampled parameter values
:return: dict keyed by param where the values are lists of intervals and values, sorted by start time
"""
# assign each sampled parameter to an infinite interval
param_dict = {param: [dict(start=-np.inf, end=np.inf, param_values=value)]
for param, value in sampled_params.items()}

param_dict = {
param: [dict(start=-np.inf, end=np.inf, param_values=value)]
for param, value in sampled_params.items()
}

# sort the interventions by start time
for start, param, intervention_value in sorted(interventions):

# update the end time of the previous interval
param_dict[f"{param}_param"][-1]['end'] = start
param_dict[f"{param}_param"][-1]["end"] = start

# add new interval and broadcast the intervention value to the size of the sampled parameters
param_dict[f"{param}_param"].append(
dict(start=start, end=np.inf, param_values=[intervention_value]*len(sampled_params[f"{param}_param"])))

dict(
start=start,
end=np.inf,
param_values=[intervention_value]
* len(sampled_params[f"{param}_param"]),
)
)

# sort intervals by start time
return {
k: sorted(v, key=lambda x: x['start'])
for k, v in param_dict.items()
}
return {k: sorted(v, key=lambda x: x["start"]) for k, v in param_dict.items()}
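A hedged usage sketch of this helper (the import path assumes the module layout matches the file shown above; the toy inputs follow the loop's `(start, param, value)` unpacking): one sampled parameter and one intervention produce two intervals per parameter.

    import numpy as np

    # Hypothetical import; assumes pyciemss.utils.interface_utils matches the file path above.
    from pyciemss.utils.interface_utils import interventions_and_sampled_params_to_interval

    # Two samples of "beta", plus one intervention setting beta to 0.5 at t = 2.0.
    sampled_params = {"beta_param": np.array([0.1, 0.2])}
    interventions = [(2.0, "beta", 0.5)]

    intervals = interventions_and_sampled_params_to_interval(interventions, sampled_params)
    # Expected, per the code above:
    # {"beta_param": [
    #     {"start": -inf, "end": 2.0, "param_values": array([0.1, 0.2])},
    #     {"start": 2.0,  "end": inf, "param_values": [0.5, 0.5]},
    # ]}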


def assign_interventions_to_timepoints(interventions: dict, timepoints: Iterable[float], sampled_params: dict) -> dict:
def assign_interventions_to_timepoints(
interventions: dict, timepoints: Iterable[float], sampled_params: dict
) -> dict:
"""Assign the value of each parameter to every timepoint, taking into account interventions.
:param interventions: dict keyed by parameter name where each value is a tuple (intervention_time, value)
@@ -125,11 +141,13 @@ def assign_interventions_to_timepoints(interventions: dict, timepoints: Iterable
:return: dict keyed by param where the values are sorted by sample then timepoint
"""
# transform interventions and sampled parameters into intervals
param_interval_dict = interventions_and_sampled_params_to_interval(interventions, sampled_params)
param_interval_dict = interventions_and_sampled_params_to_interval(
interventions, sampled_params
)
result = {}
for param, interval_dict in param_interval_dict.items():
intervals = [(d['start'], d['end']) for d in interval_dict]
param_values = [d['param_values'] for d in interval_dict]
intervals = [(d["start"], d["end"]) for d in interval_dict]
param_values = [d["param_values"] for d in interval_dict]

# generate list of parameter values at each timepoint
result[param] = []
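The rest of the loop is truncated in this view. Based on the docstring ("values are sorted by sample then timepoint") and the intervals built above, a plausible reading is that each timepoint is matched to the interval containing it and that interval's per-sample value is emitted. A standalone, hypothetical sketch of that idea (not the library's code):

    def assign_by_interval(intervals, param_values, timepoints, num_samples):
        # Toy version: for each sample, then each timepoint, take the value from
        # the interval whose [start, end) range contains that timepoint.
        out = []
        for s in range(num_samples):
            for t in timepoints:
                for (start, end), values in zip(intervals, param_values):
                    if start <= t < end:
                        out.append(values[s])
                        break
        return out

    intervals = [(float("-inf"), 2.0), (2.0, float("inf"))]
    param_values = [[0.1, 0.2], [0.5, 0.5]]
    print(assign_by_interval(intervals, param_values, [1.0, 3.0], num_samples=2))
    # [0.1, 0.5, 0.2, 0.5] -- sample 0 at t=1.0 and t=3.0, then sample 1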
139 changes: 94 additions & 45 deletions test/test_ensemble/test_ensemble_interfaces.py
@@ -1,21 +1,23 @@
import unittest
import os
import copy

import torch
import pandas as pd
import numpy as np

from pyciemss.PetriNetODE.base import MiraPetriNetODESystem, ScaledBetaNoisePetriNetODESystem
from pyciemss.PetriNetODE.events import Event, StartEvent, LoggingEvent, ObservationEvent, StaticParameterInterventionEvent
from pyciemss.Ensemble.base import EnsembleSystem
import pyciemss

from pyciemss.PetriNetODE.interfaces import load_petri_model
from pyciemss.Ensemble.interfaces import load_and_sample_petri_ensemble, load_and_calibrate_and_sample_ensemble_model, setup_model, reset_model, intervene, sample, calibrate, optimize
from pyciemss.Ensemble.interfaces import (
load_and_sample_petri_ensemble,
load_and_calibrate_and_sample_ensemble_model,
setup_model,
sample,
calibrate,
)

class Test_Samples_Format(unittest.TestCase):

class Test_Samples_Format(unittest.TestCase):
"""Tests for the output of PetriNetODE.interfaces.load_*_sample_petri_net_model."""

# Setup for the tests
@@ -29,7 +31,6 @@ def setUp(self):
def solution_mapping1(model1_solution: dict) -> dict:
return model1_solution


def solution_mapping2(model2_solution: dict) -> dict:
mapped_solution = {}
mapped_solution["S"] = (
@@ -59,11 +60,7 @@ def solution_mapping2(model2_solution: dict) -> dict:
self.num_timepoints = len(timepoints)

self.samples = load_and_sample_petri_ensemble(
ASKENET_PATHS,
weights,
solution_mappings,
self.num_samples,
timepoints
ASKENET_PATHS, weights, solution_mappings, self.num_samples, timepoints
)

data_path = os.path.join(DEMO_PATH, "data.csv")
@@ -76,7 +73,7 @@ def solution_mapping2(model2_solution: dict) -> dict:
self.num_samples,
timepoints,
total_population=1000,
num_iterations=5
num_iterations=5,
)

def test_samples_type(self):
Expand All @@ -95,7 +92,9 @@ def test_samples_column_names(self):
for s in [self.samples, self.calibrated_samples]:
self.assertEqual(list(s.columns)[:2], ["timepoint_id", "sample_id"])
for col_name in s.columns[2:]:
self.assertIn(col_name.split("_")[-1], ("param", "sol", "weight"))
self.assertIn(
col_name.split("_")[-1], ("param", "sol", "weight", "(unknown)")
)
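The "(unknown)" entry ties back to the new default `time_unit="(unknown)"` in convert_to_output_format: the output frame now carries a `timepoint_(unknown)` column, whose suffix after the last underscore is "(unknown)". A toy check mirroring the assertion above (hypothetical column names, not the real frame):

    accepted = ("param", "sol", "weight", "(unknown)")
    for col_name in ["beta_param", "Infected_sol", "model_0_weight", "timepoint_(unknown)"]:
        assert col_name.split("_")[-1] in accepted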

def test_samples_dtype(self):
"""Test that `samples` has the required data types"""
@@ -107,27 +106,33 @@ def test_samples_dtype(self):


class TestEnsembleInterfaces(unittest.TestCase):
'''Tests for the Ensemble interfaces.'''
"""Tests for the Ensemble interfaces."""

# Setup for the tests
def setUp(self):
MIRA_PATH = "test/models/april_ensemble_demo/"

filename1 = "BIOMD0000000955_template_model.json"
self.filename1 = os.path.join(MIRA_PATH, filename1)
self.model1 = load_petri_model(self.filename1, add_uncertainty=True)
self.start_state1 = {k[0]: v.data['initial_value'] for k, v in self.model1.G.variables.items()}
self.start_state1 = {
k[0]: v.data["initial_value"] for k, v in self.model1.G.variables.items()
}

filename2 = "BIOMD0000000960_template_model.json"
self.filename2 = os.path.join(MIRA_PATH, filename2)
self.model2 = load_petri_model(self.filename2, add_uncertainty=True)
self.start_state2 = {k[0]: v.data['initial_value'] for k, v in self.model2.G.variables.items()}
self.start_state2 = {
k[0]: v.data["initial_value"] for k, v in self.model2.G.variables.items()
}

self.initial_time = 0.0

solution_ratio = self.start_state2['Infectious'] / self.start_state1['Infected']
self.solution_mapping1 = lambda x : {"Infected": x["Infected"]}
self.solution_mapping2 = lambda x: {"Infected": x["Infectious"] / solution_ratio}
solution_ratio = self.start_state2["Infectious"] / self.start_state1["Infected"]
self.solution_mapping1 = lambda x: {"Infected": x["Infected"]}
self.solution_mapping2 = lambda x: {
"Infected": x["Infectious"] / solution_ratio
}

self.models = [self.model1, self.model2]
self.weights = [0.5, 0.5]
@@ -137,19 +142,32 @@ def setUp(self):
self.noise_pseudocount = 1.0

def test_solution_mapping(self):
'''Test the solution_mapping function.'''
"""Test the solution_mapping function."""

# Test that the solution_mapping results in the same variables.
self.assertEqual(set(self.solution_mapping1(self.start_state1).keys()),
set(self.solution_mapping2(self.start_state2).keys()))
self.assertEqual(
set(self.solution_mapping1(self.start_state1).keys()),
set(self.solution_mapping2(self.start_state2).keys()),
)

#Assert that the solution_mapping results in the same values.
self.assertAlmostEqual(self.solution_mapping1(self.start_state1)["Infected"],
self.solution_mapping2(self.start_state2)["Infected"])
# Assert that the solution_mapping results in the same values.
self.assertAlmostEqual(
self.solution_mapping1(self.start_state1)["Infected"],
self.solution_mapping2(self.start_state2)["Infected"],
)

def test_setup_model(self):
'''Test the setup_model function.'''
ensemble = setup_model(self.models, self.weights, self.solution_mappings, self.initial_time, [self.start_state1, self.start_state2], self.total_population, self.noise_pseudocount, self.dirichlet_concentration)
"""Test the setup_model function."""
ensemble = setup_model(
self.models,
self.weights,
self.solution_mappings,
self.initial_time,
[self.start_state1, self.start_state2],
self.total_population,
self.noise_pseudocount,
self.dirichlet_concentration,
)

# Test that the model is an EnsembleSystem
self.assertIsInstance(ensemble, EnsembleSystem)
@@ -161,8 +179,13 @@ def test_setup_model(self):
self.assertEqual(len(ensemble.dirichlet_alpha), len(self.weights))

# Test that the model had the correct weight values
self.assertTrue(torch.all(ensemble.dirichlet_alpha == torch.as_tensor(self.weights) * self.dirichlet_concentration))
self.assertTrue(torch.all(ensemble.dirichlet_alpha == torch.tensor([0.5, 0.5])))
self.assertTrue(
torch.all(
ensemble.dirichlet_alpha
== torch.as_tensor(self.weights) * self.dirichlet_concentration
)
)
self.assertTrue(torch.all(ensemble.dirichlet_alpha == torch.tensor([0.5, 0.5])))

# Test that the model has the correct number of solution_mappings
self.assertEqual(len(ensemble.solution_mappings), len(self.solution_mappings))
@@ -172,33 +195,59 @@ def test_calibrate(self):
self.assertEqual(len(ensemble.models), len(ensemble.solution_mappings))

def test_calibrate(self):
'''Test the calibrate function.'''
ensemble = setup_model(self.models, self.weights, self.solution_mappings, self.initial_time, [self.start_state1, self.start_state2], self.total_population, self.noise_pseudocount, self.dirichlet_concentration)

"""Test the calibrate function."""
ensemble = setup_model(
self.models,
self.weights,
self.solution_mappings,
self.initial_time,
[self.start_state1, self.start_state2],
self.total_population,
self.noise_pseudocount,
self.dirichlet_concentration,
)

data = [(1.1, {"Infected": 0.003}), (1.2, {"Infected": 0.005})]
parameters = calibrate(ensemble, data, num_iterations=2)

self.assertIsNotNone(parameters)

def test_sample(self):
'''Test the sample function.'''
ensemble = setup_model(self.models, self.weights, self.solution_mappings, self.initial_time, [self.start_state1, self.start_state2], self.total_population, self.noise_pseudocount, self.dirichlet_concentration)

"""Test the sample function."""
ensemble = setup_model(
self.models,
self.weights,
self.solution_mappings,
self.initial_time,
[self.start_state1, self.start_state2],
self.total_population,
self.noise_pseudocount,
self.dirichlet_concentration,
)

timepoints = [1.0, 5.0, 10.0]
num_samples = 10
# Test that sample works without inferred parameters
simulation = sample(ensemble, timepoints, num_samples)

self.assertEqual(simulation['Infected_sol'].shape[0], num_samples)
self.assertEqual(simulation['Infected_sol'].shape[1], len(timepoints))

data = [(0.2, {"Infected": 0.1}), (0.4, {"Infected": 0.2}), (0.6, {"Infected": 0.3})]

self.assertEqual(simulation["Infected_sol"].shape[0], num_samples)
self.assertEqual(simulation["Infected_sol"].shape[1], len(timepoints))

data = [
(0.2, {"Infected": 0.1}),
(0.4, {"Infected": 0.2}),
(0.6, {"Infected": 0.3}),
]
parameters = calibrate(ensemble, data, num_iterations=2)
# Test that sample works with inferred parameters
simulation = sample(ensemble, timepoints, num_samples, parameters)

self.assertEqual(simulation['Infected_sol'].shape[0], num_samples)
self.assertEqual(simulation['Infected_sol'].shape[1], len(timepoints))
self.assertEqual(simulation["Infected_sol"].shape[0], num_samples)
self.assertEqual(simulation["Infected_sol"].shape[1], len(timepoints))

# Test that samples are different when num_samples > 1
self.assertTrue(torch.all(simulation['Infected_sol'][0, :] != simulation['Infected_sol'][1, :]))
self.assertTrue(
torch.all(
simulation["Infected_sol"][0, :] != simulation["Infected_sol"][1, :]
)
)