diff --git a/pygem/output.py b/pygem/output.py index 579947dd..1d428720 100644 --- a/pygem/output.py +++ b/pygem/output.py @@ -17,7 +17,7 @@ import numpy as np import pandas as pd import xarray as xr -import os, types, json, cftime, collections +import os, types, json, cftime, collections, warnings import pygem from pygem.setup.config import ConfigManager # instantiate ConfigManager @@ -735,41 +735,40 @@ def _update_dicts(self): 'temporal_resolution': 'annual', 'comment': 'climatic mass balance is computed before dynamics so can theoretically exceed ice thickness'} + def calc_stats_array(data, stats_cns=pygem_prms['sim']['out']['sim_stats']): """ - Calculate stats for a given variable + Calculate stats for a given variable. Parameters ---------- - vn : str - variable name - ds : xarray dataset - dataset of output with all ensemble simulations + data : np.array + 2D array with ensemble simulations (shape: [n_samples, n_ensembles]) + stats_cns : list, optional + List of statistics to compute (e.g., ['mean', 'std', 'median']) Returns ------- - stats : np.array - Statistics related to a given variable + stats : np.array, or None + Statistics related to a given variable. """ - stats = None - if 'mean' in stats_cns: - if stats is None: - stats = np.nanmean(data,axis=1)[:,np.newaxis] - if 'std' in stats_cns: - stats = np.append(stats, np.nanstd(data,axis=1)[:,np.newaxis], axis=1) - if '2.5%' in stats_cns: - stats = np.append(stats, np.nanpercentile(data, 2.5, axis=1)[:,np.newaxis], axis=1) - if '25%' in stats_cns: - stats = np.append(stats, np.nanpercentile(data, 25, axis=1)[:,np.newaxis], axis=1) - if 'median' in stats_cns: - if stats is None: - stats = np.nanmedian(data, axis=1)[:,np.newaxis] - else: - stats = np.append(stats, np.nanmedian(data, axis=1)[:,np.newaxis], axis=1) - if '75%' in stats_cns: - stats = np.append(stats, np.nanpercentile(data, 75, axis=1)[:,np.newaxis], axis=1) - if '97.5%' in stats_cns: - stats = np.append(stats, np.nanpercentile(data, 97.5, axis=1)[:,np.newaxis], axis=1) - if 'mad' in stats_cns: - stats = np.append(stats, median_abs_deviation(data, axis=1, nan_policy='omit')[:,np.newaxis], axis=1) - return stats \ No newline at end of file + + # dictionary of functions to call for each stat in `stats_cns` + stat_funcs = { + 'mean': lambda x: np.nanmean(x, axis=1), + 'std': lambda x: np.nanstd(x, axis=1), + '2.5%': lambda x: np.nanpercentile(x, 2.5, axis=1), + '25%': lambda x: np.nanpercentile(x, 25, axis=1), + 'median': lambda x: np.nanmedian(x, axis=1), + '75%': lambda x: np.nanpercentile(x, 75, axis=1), + '97.5%': lambda x: np.nanpercentile(x, 97.5, axis=1), + 'mad': lambda x: median_abs_deviation(x, axis=1, nan_policy='omit') + } + + # calculate statustics for each stat in `stats_cns` + with warnings.catch_warnings(): + warnings.simplefilter("ignore", RuntimeWarning) # Suppress All-NaN Slice Warnings + stats_list = [stat_funcs[stat](data) for stat in stats_cns if stat in stat_funcs] + + # stack stats_list to numpy array + return np.column_stack(stats_list) if stats_list else None \ No newline at end of file