diff --git a/src/eva/data/gsi_obs_space.py b/src/eva/data/gsi_obs_space.py
index 2c29d203..6bebdac0 100644
--- a/src/eva/data/gsi_obs_space.py
+++ b/src/eva/data/gsi_obs_space.py
@@ -20,25 +20,6 @@
 # --------------------------------------------------------------------------------------------------
 
 
-def all_equal(iterable):
-
-    """
-    Check if all elements in an iterable are equal.
-
-    Args:
-        iterable: An iterable object to check.
-
-    Returns:
-        bool: True if all elements are equal, False otherwise.
-    """
-
-    g = groupby(iterable)
-    return next(g, True) and not next(g, False)
-
-
-# --------------------------------------------------------------------------------------------------
-
-
 def uv(group_vars):
 
     """
@@ -112,30 +93,63 @@ def subset_channels(ds, channels, logger, add_channels_variable=False):
 # --------------------------------------------------------------------------------------------------
 
 
-def satellite_dataset(ds):
+def satellite_dataset(ds, group_vars, force_reshape_all):
 
     """
     Build a new dataset to reshape satellite data.
 
     Args:
         ds (Dataset): The input xarray Dataset.
-
+        group_vars: all group variables or a selected subset to keep.
+        force_reshape_all (bool): force all variables to be reshaped instead of
+                                  checking whether to thin or reshape each one.
     Returns:
         Dataset: Reshaped xarray Dataset.
     """
 
-    nchans = ds.dims['nchans']
-    iters = int(ds.dims['nobs']/nchans)
+    nchans = ds.sizes['nchans']
+    iters = int(ds.sizes['nobs']/nchans)
 
     coords = {
         'nchans': (('nchans'), ds['sensor_chan'].data),
         'nobs': (('nobs'), np.arange(0, iters)),
    }
-
     data_vars = {}
+
+    required_variables = group_vars[:]
+    if ('sensor_chan' not in required_variables):
+        required_variables.append('sensor_chan')
+
+    lat_lon = ['Latitude', 'Longitude']
+    # These variables inherently have a spectral dimension.
+    always_2d = ['QC_Flag',
+                 'Observation',
+                 'Obs_Minus_Forecast_adjusted',
+                 'Obs_Minus_Forecast_unadjusted',
+                 'Forecast_adjusted_clear',
+                 'Forecast_unadjusted_clear',
+                 'Forecast_adjusted',
+                 'Forecast_unadjusted',
+                 'Emissivity',
+                 'BC_Scan_Angle',
+                 'BC_Cloud_Liquid_Water',
+                 'BC_Cosine_Latitude_times_Node',
+                 'BC_Sine_Latitude',
+                 'Bias_Correction',
+                 'Bias_Correction_Constant',
+                 'Bias_Correction_ScanAngle',
+                 'Inverse_Observation_Error',
+                 'Input_Observation_Error',
+                 'BC_Constant',
+                 'BC_Lapse_Rate_Squared',
+                 'BC_Lapse_Rate',
+                 'BC_Emissivity',
+                 'BC_Fixed_Scan_Position']
 
     # Loop through each variable
     for var in ds.variables:
-
+        # Ignore everything that the user didn't ask for
+        if var not in required_variables:
+            continue
         # Ignore geovals data
         if var in ['air_temperature', 'air_pressure', 'air_pressure_levels',
                    'atmosphere_absorber_01', 'atmosphere_absorber_02', 'atmosphere_absorber_03']:
@@ -155,21 +169,27 @@ def satellite_dataset(ds):
                 out_var = var+pred
                 data_vars[out_var] = (('nobs', 'nchans'), data[:, :, ipred])
 
-        # Deals with how to handle nobs data
+        # Force all variables to 2d except lat/lon (optional), or when the variable inherently
+        # has spectral dependence. Don't touch lat/lon because making them 2d makes it harder
+        # to plot a map: you would need a per-channel `accept_where`, which is cumbersome for
+        # hyperspectral sensors and annoying for any microwave sensor that isn't AMSU.
+        #
+        # `gsi_obs_space_reshape_all` is recommended for any ancillary data from microwave
+        # sensors, because footprint size is used and calculated on a per-channel basis
+        # (e.g. Land_Fraction, Water_Fraction, etc. become inherently spectral).
+        elif (force_reshape_all and (var not in lat_lon)) or (var in always_2d):
+            data_vars[var] = (('nobs', 'nchans'), ds[var].data.reshape(iters, nchans))
+
+        # Either thin or reshape, depending on whether all values repeat across the first
+        # slice of nchans in the array. Lat/lon always take this path because they thin for
+        # most sensors except AMSU, which makes it easier to plot a map.
         else:
-            # Check if values repeat over nchans
-            condition = all_equal(ds[var].data[0:nchans])
-
-            # If values are repeating over nchan iterations, keep as nobs
-            if condition:
-                data = ds[var].data[0::nchans]
-                data_vars[var] = (('nobs'), data)
-
-            # Else, reshape to be a 2d array
+            is_1d = (ds[var].isel(nobs=slice(0, nchans)) == ds[var].isel(nobs=0)).all()
+            if (is_1d):
+                data_vars[var] = (('nobs'), ds[var].thin(nobs=nchans).data)
             else:
-                data = np.reshape(ds[var].data, (iters, nchans))
-                data_vars[var] = (('nobs', 'nchans'), data)
-
+                data_vars[var] = (('nobs', 'nchans'), ds[var].data.reshape(iters, nchans))
 
     # create dataset_config
     new_ds = Dataset(data_vars=data_vars,
                      coords=coords,
@@ -227,7 +247,10 @@ def execute(self, dataset_config, data_collections, timeing):
         # Get the groups to be read
         # -------------------------
         groups = get(dataset_config, self.logger, 'groups')
-
+        if ('gsi_obs_space_reshape_all' in dataset_config.keys()):
+            force_reshape_all = bool(get(dataset_config, self.logger, 'gsi_obs_space_reshape_all'))
+        else:
+            force_reshape_all = False
         # Loop over filenames
         # -------------------
         for filename in filenames:
@@ -251,9 +274,8 @@ def execute(self, dataset_config, data_collections, timeing):
 
             # Reshape variables if satellite diag
             if 'nchans' in ds.dims:
-                ds = satellite_dataset(ds)
+                ds = satellite_dataset(ds, group_vars, force_reshape_all)
                 ds = subset_channels(ds, channels, self.logger)
-
             # Adjust variable names if uv
             if 'variable' in locals():
                 if variable == 'uv':
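Below is a minimal, self-contained sketch (not part of the diff above) of the thin-vs-reshape
check that the new `else` branch applies. The dataset, variable names, and dimension sizes here
are purely illustrative; 'Per_Channel_Value' is a hypothetical variable, while Latitude always
takes this path in the real code:

    import numpy as np
    import xarray as xr

    nchans = 3  # channels per observation location
    iters = 4   # observation locations

    # A per-location variable repeats its value across each block of nchans entries ...
    lat = np.repeat(np.array([10.0, 20.0, 30.0, 40.0]), nchans)
    # ... while a per-channel variable does not.
    per_chan = np.arange(iters * nchans, dtype=float)

    ds = xr.Dataset({'Latitude': ('nobs', lat),
                     'Per_Channel_Value': ('nobs', per_chan)})

    data_vars = {}
    for var in ds.variables:
        # True when the first nchans entries are identical, i.e. the variable only
        # varies per location and can be thinned down to length iters
        is_1d = (ds[var].isel(nobs=slice(0, nchans)) == ds[var].isel(nobs=0)).all()
        if is_1d:
            data_vars[var] = (('nobs'), ds[var].thin(nobs=nchans).data)
        else:
            data_vars[var] = (('nobs', 'nchans'), ds[var].data.reshape(iters, nchans))

    new_ds = xr.Dataset(data_vars=data_vars,
                        coords={'nchans': ('nchans', np.arange(1, nchans + 1)),
                                'nobs': ('nobs', np.arange(iters))})
    # new_ds['Latitude'] keeps shape (4,) while new_ds['Per_Channel_Value'] becomes (4, 3)

Setting gsi_obs_space_reshape_all to a true value in the dataset configuration bypasses this
check for everything except Latitude/Longitude, which the comments in the diff recommend for
microwave ancillary fields (e.g. Land_Fraction, Water_Fraction) that vary per channel.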