Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 29 additions & 30 deletions pygem/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import numpy as np
import pandas as pd
import xarray as xr
import os, types, json, cftime, collections
import os, types, json, cftime, collections, warnings
import pygem
from pygem.setup.config import ConfigManager
# instantiate ConfigManager
Expand Down Expand Up @@ -735,41 +735,40 @@ def _update_dicts(self):
'temporal_resolution': 'annual',
'comment': 'climatic mass balance is computed before dynamics so can theoretically exceed ice thickness'}


def calc_stats_array(data, stats_cns=pygem_prms['sim']['out']['sim_stats']):
    """
    Calculate ensemble statistics for a given variable.

    Each requested statistic is computed along axis 1 of `data`, ignoring NaNs,
    and the results are stacked as columns of the returned array.

    Parameters
    ----------
    data : np.array
        2D array with ensemble simulations (shape: [n_samples, n_ensembles])
    stats_cns : list, optional
        List of statistics to compute (e.g., ['mean', 'std', 'median']).
        Recognized names are 'mean', 'std', '2.5%', '25%', 'median', '75%',
        '97.5%' and 'mad'; any other name is skipped.
        Defaults to pygem_prms['sim']['out']['sim_stats'].

    Returns
    -------
    stats : np.array, or None
        2D array with one column per recognized statistic, in the order the
        names appear in `stats_cns`. None if `stats_cns` contains no
        recognized statistic.
    """
    # dictionary of functions to call for each stat in `stats_cns`
    stat_funcs = {
        'mean': lambda x: np.nanmean(x, axis=1),
        'std': lambda x: np.nanstd(x, axis=1),
        '2.5%': lambda x: np.nanpercentile(x, 2.5, axis=1),
        '25%': lambda x: np.nanpercentile(x, 25, axis=1),
        'median': lambda x: np.nanmedian(x, axis=1),
        '75%': lambda x: np.nanpercentile(x, 75, axis=1),
        '97.5%': lambda x: np.nanpercentile(x, 97.5, axis=1),
        'mad': lambda x: median_abs_deviation(x, axis=1, nan_policy='omit'),
    }

    # calculate statistics for each stat in `stats_cns`
    with warnings.catch_warnings():
        # Suppress All-NaN slice warnings; note this silences every
        # RuntimeWarning raised while the statistics are computed.
        warnings.simplefilter("ignore", RuntimeWarning)
        stats_list = [stat_funcs[stat](data) for stat in stats_cns if stat in stat_funcs]

    # stack stats_list to numpy array (one column per statistic)
    return np.column_stack(stats_list) if stats_list else None