Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements-github.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@ bokeh<3.6.0,>=3.5.0
geopandas>=0.13.2
geoviews>=1.10.0
nbsite>=0.8.1
git+https://github.com/NOAA-EMC/emcpy.git@1764011794f84c76488b0876e8b577d6af74df20#egg=emcpy
git+https://github.com/NOAA-EMC/emcpy.git@7794574611e760475d61eb5d9458af2d3d2191d8#egg=emcpy
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pandas
numpy

# Additional packages
git+https://github.com/NOAA-EMC/emcpy.git@92aa62f34a1f413d8cb1646bca0e81f267b61365#egg=emcpy
git+https://github.com/NOAA-EMC/emcpy.git@7794574611e760475d61eb5d9458af2d3d2191d8#egg=emcpy
scikit-learn
seaborn
hvplot
Expand Down
2 changes: 1 addition & 1 deletion requirements_sles15.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ contourpy
msgpack>=1.0.0

# Additional packages
git+https://github.com/NOAA-EMC/emcpy.git@92aa62f34a1f413d8cb1646bca0e81f267b61365#egg=emcpy
git+https://github.com/NOAA-EMC/emcpy.git@7794574611e760475d61eb5d9458af2d3d2191d8#egg=emcpy
scikit-learn
seaborn
hvplot
Expand Down
179 changes: 163 additions & 16 deletions src/eva/plotting/batch/emcpy/diagnostics/emcpy_map_gridded.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from eva.utilities.utils import get_schema, update_object
import emcpy.plots.map_plots
import os
import numpy as np

from eva.plotting.batch.base.diagnostics.map_gridded import MapGridded

Expand All @@ -11,32 +12,178 @@

class EmcpyMapGridded(MapGridded):
"""
EmcpyMapGridded class is a subclass of the MapGridded class, tailored for
configuring and plotting gridded map visualizations using the emcpy library.
EMCPy backend for gridded maps.
Option A: if latitude/longitude are 1-D centers, convert them to 2-D center grids
with np.meshgrid and reduce data to a single 2-D level before plotting.
"""

Attributes:
Inherits attributes from the MapGridded class.
def _to_2d_centers(self, latvar, lonvar, datavar):
"""
Normalize inputs for EMCPy MapGridded by ensuring 2-D lat/lon and 2-D data.
Handles:
- 3-D curvilinear lat/lon with a 'tile' dim: selects one tile (tile_index, default 0)
- 2-D curvilinear lat/lon: pass-through, squeeze data to 2-D
- 1-D centers: meshgrid to 2-D and orient data to (Nlat, Nlon)
"""
lat = np.asarray(latvar)
lon = np.asarray(lonvar)
A = np.asarray(datavar)

Methods:
configure_plot(): Configures the plotting settings for the gridded map.
"""
# --- CASE 1: 3-D curvilinear (lon, lat, tile) → choose one tile to get 2-D ---
if lat.ndim == 3 and lon.ndim == 3 and lat.shape == lon.shape:
tile_count = lat.shape[2]
tile_idx = int(self.config.get("tile_index", 0))
if not (0 <= tile_idx < tile_count):
raise ValueError(f"tile_index {tile_idx} out of range [0,{tile_count-1}]")

# slice lat/lon to 2-D for this tile
lat2d = lat[:, :, tile_idx]
lon2d = lon[:, :, tile_idx]

# slice data on its tile axis (if present)
A2 = A
if A2.ndim >= 3:
# find axis whose size == tile_count
# find all axes whose size == tile_count
tile_axes = [i for i, s in enumerate(A2.shape) if s == tile_count]
tile_axis_config = self.config.get("tile_axis_index", None)
if tile_axis_config is not None:
tile_axis = int(tile_axis_config)
if tile_axis < 0 or tile_axis >= A2.ndim or A2.shape[tile_axis] != tile_count:
raise ValueError(
f"Configured tile_axis_index {tile_axis} is invalid for data "
f"shape {A2.shape} and tile_count {tile_count}"
)
else:
if len(tile_axes) == 1:
tile_axis = tile_axes[0]
elif len(tile_axes) == 0:
tile_axis = None
else:
raise ValueError(
f"Ambiguous tile axis: multiple axes {tile_axes} in data shape "
f"{A2.shape} have size equal to tile_count ({tile_count}). "
"Please specify 'tile_axis_index' in the config."
)

# if any extra leading level/ensemble dim remains, pick first or configured
if A2.ndim == 3:
lev_idx = int(self.config.get("level_index", 0))
# choose an axis that is not lat or lon sized
lat_size, lon_size = lat2d.shape
# bring to (lat, lon, extra) or (lon, lat, extra)
if A2.shape[:2] == (lat_size, lon_size):
A2 = A2[:, :, lev_idx]
elif A2.shape[:2] == (lon_size, lat_size):
A2 = A2[:, :, lev_idx].T
else:
# try matching by transpose
if A2.transpose(1, 0, 2).shape[:2] == (lat_size, lon_size):
A2 = A2.transpose(1, 0, 2)[:, :, lev_idx]
else:
# last resort: squeeze to 2-D
A2 = np.squeeze(A2)

A2 = np.squeeze(A2)
# Ensure shapes match lat2d/lon2d
if A2.shape != lat2d.shape:
if A2.T.shape == lat2d.shape:
A2 = A2.T
else:
raise ValueError(f"Data shape {A2.shape} incompatible "
f"with tile lat/lon {lat2d.shape}")

return lat2d, lon2d, A2

# --- CASE 2: 2-D curvilinear lat/lon: pass-through, squeeze data to 2-D ---
if lat.ndim == 2 and lon.ndim == 2:
A2 = A
if A2.ndim == 3:
lev_idx = int(self.config.get("level_index", 0))
# assume first axis is level/extra
A2 = A2[lev_idx, ...]
return lat, lon, np.squeeze(A2)

# --- CASE 3: 1-D centers → build 2-D centers via meshgrid ---
if lat.ndim == 1 and lon.ndim == 1:
lat1d = lat.squeeze()
lon1d = lon.squeeze()
A2 = A
if A2.ndim == 3:
# identify lat/lon axes robustly, avoiding ambiguity
shape = A2.shape
# Try to get axis indices from config first
lat_axis = self.config.get("lat_axis", None)
lon_axis = self.config.get("lon_axis", None)
if lat_axis is not None and lon_axis is not None:
lat_axis = int(lat_axis)
lon_axis = int(lon_axis)
else:
# Find all axes matching lat/lon sizes
lat_axes = [i for i, s in enumerate(shape) if s == lat1d.size]
lon_axes = [i for i, s in enumerate(shape) if s == lon1d.size]
if len(lat_axes) != 1 or len(lon_axes) != 1:
raise ValueError(
f"Ambiguous axis identification: "
f"Found lat_axes={lat_axes} for size {lat1d.size}, "
f"lon_axes={lon_axes} for size {lon1d.size}. "
"Please specify 'lat_axis' and 'lon_axis' in config."
)
lat_axis = lat_axes[0]
lon_axis = lon_axes[0]
if lat_axis != lon_axis:
axes = (lat_axis, lon_axis)
extra = [
i for i in range(A2.ndim)
if i not in axes
]
order = [*axes, *extra]
A2 = np.transpose(A2, order)
lev_idx = int(self.config.get("level_index", 0))
if A2.ndim == 3:
A2 = A2[:, :, lev_idx]
else:
lev_idx = int(self.config.get("level_index", 0))
A2 = A2[lev_idx, ...]
A2 = np.squeeze(A2)

# orient to (Nlat, Nlon)
if A2.shape != (lat1d.size, lon1d.size):
if A2.T.shape == (lat1d.size, lon1d.size):
A2 = A2.T
else:
raise ValueError(
f"Data shape {A2.shape} incompatible with "
f"lat {lat1d.size} / lon {lon1d.size}"
)

LAT2D, LON2D = np.meshgrid(lat1d, lon1d, indexing="ij")
return LAT2D, LON2D, A2

# Otherwise unsupported
raise ValueError(f"Expected 1-D or 2-D lat/lon; got lat {lat.shape}, lon {lon.shape}")

def configure_plot(self):
"""
Configures the plotting settings for the gridded map.

Returns:
plotobj: The configured plot object for emcpy gridded maps.
plotobj: The configured plot object for EMCPy gridded maps.
"""
# Convert to 2-D centers + 2-D data if needed
lat2d, lon2d, data2d = self._to_2d_centers(self.latvar, self.lonvar, self.datavar)

# Create EMCPy MapGridded object
self.plotobj = emcpy.plots.map_plots.MapGridded(lat2d, lon2d, data2d)

# create declarative plotting MapGridded object
self.plotobj = emcpy.plots.map_plots.MapGridded(self.latvar, self.lonvar, self.datavar)
# get defaults from schema
layer_schema = self.config.get('schema', os.path.join(return_eva_path(), 'plotting',
'batch', 'emcpy', 'defaults', 'map_gridded.yaml'))
# Apply schema defaults/overrides
layer_schema = self.config.get(
"schema",
os.path.join(
return_eva_path(), "plotting", "batch", "emcpy", "defaults", "map_gridded.yaml"
),
)
new_config = get_schema(layer_schema, self.config, self.logger)
delvars = ['longitude', 'latitude', 'data', 'type', 'schema']
for d in delvars:
for d in ["longitude", "latitude", "data", "type", "schema", "level_index"]:
new_config.pop(d, None)
self.plotobj = update_object(self.plotobj, new_config, self.logger)
return self.plotobj
Expand Down