diff --git a/requirements-github.txt b/requirements-github.txt index 50629176..340f3742 100644 --- a/requirements-github.txt +++ b/requirements-github.txt @@ -12,4 +12,4 @@ bokeh<3.6.0,>=3.5.0 geopandas>=0.13.2 geoviews>=1.10.0 nbsite>=0.8.1 -git+https://github.com/NOAA-EMC/emcpy.git@1764011794f84c76488b0876e8b577d6af74df20#egg=emcpy +git+https://github.com/NOAA-EMC/emcpy.git@7794574611e760475d61eb5d9458af2d3d2191d8#egg=emcpy diff --git a/requirements.txt b/requirements.txt index c5e05297..3709f91d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,7 +10,7 @@ pandas numpy # Additional packages -git+https://github.com/NOAA-EMC/emcpy.git@92aa62f34a1f413d8cb1646bca0e81f267b61365#egg=emcpy +git+https://github.com/NOAA-EMC/emcpy.git@7794574611e760475d61eb5d9458af2d3d2191d8#egg=emcpy scikit-learn seaborn hvplot diff --git a/requirements_sles15.txt b/requirements_sles15.txt index 39ce5a93..9038881b 100644 --- a/requirements_sles15.txt +++ b/requirements_sles15.txt @@ -18,7 +18,7 @@ contourpy msgpack>=1.0.0 # Additional packages -git+https://github.com/NOAA-EMC/emcpy.git@92aa62f34a1f413d8cb1646bca0e81f267b61365#egg=emcpy +git+https://github.com/NOAA-EMC/emcpy.git@7794574611e760475d61eb5d9458af2d3d2191d8#egg=emcpy scikit-learn seaborn hvplot diff --git a/src/eva/plotting/batch/emcpy/diagnostics/emcpy_map_gridded.py b/src/eva/plotting/batch/emcpy/diagnostics/emcpy_map_gridded.py index 271b2f08..22668fd0 100644 --- a/src/eva/plotting/batch/emcpy/diagnostics/emcpy_map_gridded.py +++ b/src/eva/plotting/batch/emcpy/diagnostics/emcpy_map_gridded.py @@ -3,6 +3,7 @@ from eva.utilities.utils import get_schema, update_object import emcpy.plots.map_plots import os +import numpy as np from eva.plotting.batch.base.diagnostics.map_gridded import MapGridded @@ -11,32 +12,178 @@ class EmcpyMapGridded(MapGridded): """ - EmcpyMapGridded class is a subclass of the MapGridded class, tailored for - configuring and plotting gridded map visualizations using the emcpy library. + EMCPy backend for gridded maps. + Option A: if latitude/longitude are 1-D centers, convert them to 2-D center grids + with np.meshgrid and reduce data to a single 2-D level before plotting. + """ - Attributes: - Inherits attributes from the MapGridded class. + def _to_2d_centers(self, latvar, lonvar, datavar): + """ + Normalize inputs for EMCPy MapGridded by ensuring 2-D lat/lon and 2-D data. + Handles: + - 3-D curvilinear lat/lon with a 'tile' dim: selects one tile (tile_index, default 0) + - 2-D curvilinear lat/lon: pass-through, squeeze data to 2-D + - 1-D centers: meshgrid to 2-D and orient data to (Nlat, Nlon) + """ + lat = np.asarray(latvar) + lon = np.asarray(lonvar) + A = np.asarray(datavar) - Methods: - configure_plot(): Configures the plotting settings for the gridded map. - """ + # --- CASE 1: 3-D curvilinear (lon, lat, tile) → choose one tile to get 2-D --- + if lat.ndim == 3 and lon.ndim == 3 and lat.shape == lon.shape: + tile_count = lat.shape[2] + tile_idx = int(self.config.get("tile_index", 0)) + if not (0 <= tile_idx < tile_count): + raise ValueError(f"tile_index {tile_idx} out of range [0,{tile_count-1}]") + + # slice lat/lon to 2-D for this tile + lat2d = lat[:, :, tile_idx] + lon2d = lon[:, :, tile_idx] + + # slice data on its tile axis (if present) + A2 = A + if A2.ndim >= 3: + # find axis whose size == tile_count + # find all axes whose size == tile_count + tile_axes = [i for i, s in enumerate(A2.shape) if s == tile_count] + tile_axis_config = self.config.get("tile_axis_index", None) + if tile_axis_config is not None: + tile_axis = int(tile_axis_config) + if tile_axis < 0 or tile_axis >= A2.ndim or A2.shape[tile_axis] != tile_count: + raise ValueError( + f"Configured tile_axis_index {tile_axis} is invalid for data " + f"shape {A2.shape} and tile_count {tile_count}" + ) + else: + if len(tile_axes) == 1: + tile_axis = tile_axes[0] + elif len(tile_axes) == 0: + tile_axis = None + else: + raise ValueError( + f"Ambiguous tile axis: multiple axes {tile_axes} in data shape " + f"{A2.shape} have size equal to tile_count ({tile_count}). " + "Please specify 'tile_axis_index' in the config." + ) + + # if any extra leading level/ensemble dim remains, pick first or configured + if A2.ndim == 3: + lev_idx = int(self.config.get("level_index", 0)) + # choose an axis that is not lat or lon sized + lat_size, lon_size = lat2d.shape + # bring to (lat, lon, extra) or (lon, lat, extra) + if A2.shape[:2] == (lat_size, lon_size): + A2 = A2[:, :, lev_idx] + elif A2.shape[:2] == (lon_size, lat_size): + A2 = A2[:, :, lev_idx].T + else: + # try matching by transpose + if A2.transpose(1, 0, 2).shape[:2] == (lat_size, lon_size): + A2 = A2.transpose(1, 0, 2)[:, :, lev_idx] + else: + # last resort: squeeze to 2-D + A2 = np.squeeze(A2) + + A2 = np.squeeze(A2) + # Ensure shapes match lat2d/lon2d + if A2.shape != lat2d.shape: + if A2.T.shape == lat2d.shape: + A2 = A2.T + else: + raise ValueError(f"Data shape {A2.shape} incompatible " + f"with tile lat/lon {lat2d.shape}") + + return lat2d, lon2d, A2 + + # --- CASE 2: 2-D curvilinear lat/lon: pass-through, squeeze data to 2-D --- + if lat.ndim == 2 and lon.ndim == 2: + A2 = A + if A2.ndim == 3: + lev_idx = int(self.config.get("level_index", 0)) + # assume first axis is level/extra + A2 = A2[lev_idx, ...] + return lat, lon, np.squeeze(A2) + + # --- CASE 3: 1-D centers → build 2-D centers via meshgrid --- + if lat.ndim == 1 and lon.ndim == 1: + lat1d = lat.squeeze() + lon1d = lon.squeeze() + A2 = A + if A2.ndim == 3: + # identify lat/lon axes robustly, avoiding ambiguity + shape = A2.shape + # Try to get axis indices from config first + lat_axis = self.config.get("lat_axis", None) + lon_axis = self.config.get("lon_axis", None) + if lat_axis is not None and lon_axis is not None: + lat_axis = int(lat_axis) + lon_axis = int(lon_axis) + else: + # Find all axes matching lat/lon sizes + lat_axes = [i for i, s in enumerate(shape) if s == lat1d.size] + lon_axes = [i for i, s in enumerate(shape) if s == lon1d.size] + if len(lat_axes) != 1 or len(lon_axes) != 1: + raise ValueError( + f"Ambiguous axis identification: " + f"Found lat_axes={lat_axes} for size {lat1d.size}, " + f"lon_axes={lon_axes} for size {lon1d.size}. " + "Please specify 'lat_axis' and 'lon_axis' in config." + ) + lat_axis = lat_axes[0] + lon_axis = lon_axes[0] + if lat_axis != lon_axis: + axes = (lat_axis, lon_axis) + extra = [ + i for i in range(A2.ndim) + if i not in axes + ] + order = [*axes, *extra] + A2 = np.transpose(A2, order) + lev_idx = int(self.config.get("level_index", 0)) + if A2.ndim == 3: + A2 = A2[:, :, lev_idx] + else: + lev_idx = int(self.config.get("level_index", 0)) + A2 = A2[lev_idx, ...] + A2 = np.squeeze(A2) + + # orient to (Nlat, Nlon) + if A2.shape != (lat1d.size, lon1d.size): + if A2.T.shape == (lat1d.size, lon1d.size): + A2 = A2.T + else: + raise ValueError( + f"Data shape {A2.shape} incompatible with " + f"lat {lat1d.size} / lon {lon1d.size}" + ) + + LAT2D, LON2D = np.meshgrid(lat1d, lon1d, indexing="ij") + return LAT2D, LON2D, A2 + + # Otherwise unsupported + raise ValueError(f"Expected 1-D or 2-D lat/lon; got lat {lat.shape}, lon {lon.shape}") def configure_plot(self): """ Configures the plotting settings for the gridded map. - Returns: - plotobj: The configured plot object for emcpy gridded maps. + plotobj: The configured plot object for EMCPy gridded maps. """ + # Convert to 2-D centers + 2-D data if needed + lat2d, lon2d, data2d = self._to_2d_centers(self.latvar, self.lonvar, self.datavar) + + # Create EMCPy MapGridded object + self.plotobj = emcpy.plots.map_plots.MapGridded(lat2d, lon2d, data2d) - # create declarative plotting MapGridded object - self.plotobj = emcpy.plots.map_plots.MapGridded(self.latvar, self.lonvar, self.datavar) - # get defaults from schema - layer_schema = self.config.get('schema', os.path.join(return_eva_path(), 'plotting', - 'batch', 'emcpy', 'defaults', 'map_gridded.yaml')) + # Apply schema defaults/overrides + layer_schema = self.config.get( + "schema", + os.path.join( + return_eva_path(), "plotting", "batch", "emcpy", "defaults", "map_gridded.yaml" + ), + ) new_config = get_schema(layer_schema, self.config, self.logger) - delvars = ['longitude', 'latitude', 'data', 'type', 'schema'] - for d in delvars: + for d in ["longitude", "latitude", "data", "type", "schema", "level_index"]: new_config.pop(d, None) self.plotobj = update_object(self.plotobj, new_config, self.logger) return self.plotobj