Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 63 additions & 41 deletions src/eva/data/gsi_obs_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,6 @@
# --------------------------------------------------------------------------------------------------


def all_equal(iterable):

    """
    Check if all elements in an iterable are equal.

    An empty iterable is treated as all-equal and returns True.

    Args:
        iterable: An iterable object to check.

    Returns:
        bool: True if all elements are equal, False otherwise.
    """

    # Compare every remaining element against the first one; an exhausted
    # iterator (empty input) short-circuits to True, matching the original
    # groupby-based recipe.
    it = iter(iterable)
    try:
        first = next(it)
    except StopIteration:
        return True
    return all(element == first for element in it)


# --------------------------------------------------------------------------------------------------


def uv(group_vars):

"""
Expand Down Expand Up @@ -112,30 +93,63 @@ def subset_channels(ds, channels, logger, add_channels_variable=False):
# --------------------------------------------------------------------------------------------------


def satellite_dataset(ds):
def satellite_dataset(ds, group_vars, force_reshape_all):

"""
Build a new dataset to reshape satellite data.

Args:
ds (Dataset): The input xarray Dataset.

group_vars: all or use selected group_vars
force_reshape_all: bool that allows user to force all variables to be reshaped
vs. checking whether to thin or reshape
Returns:
Dataset: Reshaped xarray Dataset.
"""

nchans = ds.dims['nchans']
iters = int(ds.dims['nobs']/nchans)
nchans = ds.sizes['nchans']
iters = int(ds.sizes['nobs']/nchans)

coords = {
'nchans': (('nchans'), ds['sensor_chan'].data),
'nobs': (('nobs'), np.arange(0, iters)),
}

data_vars = {}

required_variables = group_vars[:]
if ('sensor_chan' not in required_variables):
required_variables.append('sensor_chan')

lat_lon = ['Latitude', 'Longitude']
# these inherently have a spectral dimension.
always_2d = ['QC_Flag',
'Observation',
'Obs_Minus_Forecast_adjusted',
'Obs_Minus_Forecast_unadjusted',
'Forecast_adjusted_clear',
'Forecast_unadjusted_clear',
'Forecast_adjusted',
'Forecast_unadjusted',
'Emissivity',
'BC_Scan_Angle',
'BC_Cloud_Liquid_Water',
'BC_Cosine_Latitude_times_Node',
'BC_Sine_Latitude',
'Bias_Correction',
'Bias_Correction_Constant',
'Bias_Correction_ScanAngle',
'Inverse_Observation_Error',
'Input_Observation_Error',
'BC_Constant',
'BC_Lapse_Rate_Squared',
'BC_Lapse_Rate',
'BC_Emissivity',
'BC_Fixed_Scan_Position']
# Loop through each variable
for var in ds.variables:

# Ignore everything that the user didn't ask for
if var not in required_variables:
continue
# Ignore geovals data
if var in ['air_temperature', 'air_pressure', 'air_pressure_levels',
'atmosphere_absorber_01', 'atmosphere_absorber_02', 'atmosphere_absorber_03']:
Expand All @@ -155,21 +169,27 @@ def satellite_dataset(ds):
out_var = var+pred
data_vars[out_var] = (('nobs', 'nchans'), data[:, :, ipred])

# Deals with how to handle nobs data
# Force all variables 2d except lat/lon (optional), or if it inherently has
# spectral dependence. Don't touch lat/lon because if they're 2d it makes it
# harder to make map. You'd need to do a per channel `accept_where` which
# would be cumbersome for hyperspectral, and annoying for any microwave that
# isn't AMSU.
#
# Recommend `gsi_obs_space_reshape_all` for any ancillary data for microwave sensors,
# because footprint size is used and calculated on a per channel basis (e.g., Land_Fraction,
# Water_Fraction, etc become inherently spectral).
elif (force_reshape_all and (var not in lat_lon)) or (var in always_2d):
data_vars[var] = (('nobs', 'nchans'), ds[var].data.reshape(iters, nchans))

# either thin or reshape based on whether or not all repeat values for the first
# slice of nchan out of the array. Always use for lat/lon because it will thin
# for most sensors except AMSU, which makes it easier to plot a map.
else:
# Check if values repeat over nchans
condition = all_equal(ds[var].data[0:nchans])

# If values are repeating over nchan iterations, keep as nobs
if condition:
data = ds[var].data[0::nchans]
data_vars[var] = (('nobs'), data)

# Else, reshape to be a 2d array
is_1d = (ds[var].isel(nobs=slice(0, nchans)) == ds[var].isel(nobs=0)).all()
if (is_1d):
data_vars[var] = (('nobs'), ds[var].thin(nobs=nchans).data)
else:
data = np.reshape(ds[var].data, (iters, nchans))
data_vars[var] = (('nobs', 'nchans'), data)

data_vars[var] = (('nobs', 'nchans'), ds[var].data.reshape(iters, nchans))
# create dataset_config
new_ds = Dataset(data_vars=data_vars,
coords=coords,
Expand Down Expand Up @@ -227,7 +247,10 @@ def execute(self, dataset_config, data_collections, timeing):
# Get the groups to be read
# -------------------------
groups = get(dataset_config, self.logger, 'groups')

if ('gsi_obs_space_reshape_all' in dataset_config.keys()):
force_reshape_all = bool(get(dataset_config, self.logger, 'gsi_obs_space_reshape_all'))
else:
force_reshape_all = False
# Loop over filenames
# -------------------
for filename in filenames:
Expand All @@ -251,9 +274,8 @@ def execute(self, dataset_config, data_collections, timeing):

# Reshape variables if satellite diag
if 'nchans' in ds.dims:
ds = satellite_dataset(ds)
ds = satellite_dataset(ds, group_vars, force_reshape_all)
ds = subset_channels(ds, channels, self.logger)

# Adjust variable names if uv
if 'variable' in locals():
if variable == 'uv':
Expand Down