Skip to content

Commit

Permalink
Dropping all other states/observables except the ones in `solution_st…
Browse files Browse the repository at this point in the history
…ring_mapping`
  • Loading branch information
anirban-chaudhuri committed Apr 15, 2024
1 parent aef588c commit b000bec
Show file tree
Hide file tree
Showing 2 changed files with 312 additions and 309 deletions.
574 changes: 285 additions & 289 deletions docs/source/interfaces.ipynb

Large diffs are not rendered by default.

47 changes: 27 additions & 20 deletions pyciemss/integration_utils/result_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def make_quantiles(
) -> Union[pd.DataFrame, None]:
"""Make quantiles for each timepoint"""
_, num_timepoints = next(iter(pyciemss_results["states"].values())).shape
key_list = ["timepoint_id", "inc_cum", "output", "type", "quantile", "value"]
key_list = ["timepoint_id", "output", "type", "quantile", "value"]
q: Dict[str, List] = {k: [] for k in key_list}
if alpha_qs is not None:
num_quantiles = len(alpha_qs)
Expand All @@ -207,10 +207,10 @@ def make_quantiles(
)
)
)
if "cum" in k.lower():
q["inc_cum"].extend(["cum"] * num_timepoints * num_quantiles)
else:
q["inc_cum"].extend(["inc"] * num_timepoints * num_quantiles)
# if "cum" in k.lower():
# q["inc_cum"].extend(["cum"] * num_timepoints * num_quantiles)
# else:
# q["inc_cum"].extend(["inc"] * num_timepoints * num_quantiles)
elif stacking_order == "quantiles":
# Keeping quantiles together
q["timepoint_id"].extend(
Expand All @@ -224,10 +224,10 @@ def make_quantiles(
np.squeeze(q_vals.reshape((num_timepoints * num_quantiles, 1)))
)
)
if "cum" in k.lower():
q["inc_cum"].extend(["cum"] * num_timepoints * num_quantiles)
else:
q["inc_cum"].extend(["inc"] * num_timepoints * num_quantiles)
# if "cum" in k.lower():
# q["inc_cum"].extend(["cum"] * num_timepoints * num_quantiles)
# else:
# q["inc_cum"].extend(["inc"] * num_timepoints * num_quantiles)
else:
raise Exception("Incorrect input for stacking_order.")

Expand All @@ -241,7 +241,7 @@ def make_quantiles(
[
"timepoint_id",
f"number_{time_unit}",
"inc_cum",
# "inc_cum",
"output",
"type",
"quantile",
Expand All @@ -255,20 +255,22 @@ def make_quantiles(

def cdc_format(
q_ensemble_input: pd.DataFrame,
solution_string_mapping: Dict[str, str],
*,
time_unit: Optional[str] = None,
solution_string_mapping: Optional[Dict[str, str]] = None,
forecast_start_date: Optional[str] = None,
location: Optional[str] = None,
drop_column_names: List[str] = [
"timepoint_id",
"inc_cum",
"output",
],
train_end_point: Optional[float] = None,
) -> pd.DataFrame:
"""
Reformat the quantiles pandas DataFrame to the CDC ensemble forecast format.
Note that solution_string_mapping maps the name of each state/observable given as a dictionary key to the
corresponding dictionary value, and also drops any states/observables not available in the dictionary keys.
forecast_start_date is the date of the last observed data.
"""
q_ensemble_data = deepcopy(q_ensemble_input)
if time_unit != "days" or time_unit is None:
Expand All @@ -283,19 +285,20 @@ def cdc_format(

if train_end_point is None:
q_ensemble_data["Forecast_Backcast"] = "Forecast"
number_data_days = 0.0
else:
q_ensemble_data["Forecast_Backcast"] = np.where(
q_ensemble_data[f"number_{time_unit}"] > train_end_point,
"Forecast",
"Backcast",
)
drop_column_names.extend(["Forecast_Backcast"])
# Number of days for which data is available
number_data_days = max(
q_ensemble_data[q_ensemble_data["Forecast_Backcast"].str.contains("Backcast")][
f"number_{time_unit}"
]
)
# Number of days for which data is available
number_data_days = max(
q_ensemble_data[q_ensemble_data["Forecast_Backcast"].str.contains("Backcast")][
f"number_{time_unit}"
]
)
drop_column_names.extend(["Forecast_Backcast"])
# Subtracting number of backcast days from number_days
q_ensemble_data[f"number_{time_unit}"] = (
q_ensemble_data[f"number_{time_unit}"] - number_data_days
Expand All @@ -306,14 +309,18 @@ def cdc_format(
]
# Changing name of state according to user provided strings
if solution_string_mapping:
# Drop rows that are not present in the solution_string_mapping keys
q_ensemble_data = q_ensemble_data[
q_ensemble_data["output"].str.contains("|".join(solution_string_mapping.keys()))
]
for k, v in solution_string_mapping.items():
q_ensemble_data["output"] = q_ensemble_data["output"].replace(k, v)

# Creating target column
q_ensemble_data["target"] = (
q_ensemble_data[f"number_{time_unit}"].astype("string")
+ " days ahead "
+ q_ensemble_data["inc_cum"]
# + q_ensemble_data["inc_cum"]
+ " "
+ q_ensemble_data["output"]
)
Expand Down

0 comments on commit b000bec

Please sign in to comment.