Skip to content

Commit

Permalink
Feat(plot_centers): add plot_centers support to PlotMapView and PlotC…
Browse files Browse the repository at this point in the history
…rossSection (#2318)

* Feat(plot_centers): add plot_centers support to PlotMapView and PlotCrossSection

fixes and updates included with feature:

* filter very short/small intersection segments from cross-sectional plotting routine
* sort and assure vertex order is correct for intersected cross-sectional segments
* improve geometry.project_point_onto_xc_line() calculation routine.
* update reproject_modpath_to_crosssection()

* linting

* fix filter_line_segments(), catch and filter single vertex intersections

* add testing for plot_centers()

* add projctr attribute to PlotCrossSection for testing purposes

* lint autotests
  • Loading branch information
jlarsen-usgs authored Oct 7, 2024
1 parent 5555d15 commit c6a41ab
Show file tree
Hide file tree
Showing 6 changed files with 387 additions and 37 deletions.
49 changes: 49 additions & 0 deletions autotest/test_plot_cross_section.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,3 +228,52 @@ def test_plot_limits():
raise AssertionError("PlotMapView auto extent setting not working")

plt.close(fig)


def test_plot_centers():
from matplotlib.collections import PathCollection

nlay = 1
nrow = 10
ncol = 10

delc = np.ones((nrow,))
delr = np.ones((ncol,))
top = np.ones((nrow, ncol))
botm = np.zeros((nlay, nrow, ncol))
idomain = np.ones(botm.shape, dtype=int)

idomain[0, :, 0:3] = 0

grid = flopy.discretization.StructuredGrid(
delc=delc, delr=delr, top=top, botm=botm, idomain=idomain
)

line = {"line": [(0, 0), (10, 10)]}
active_xc_cells = 7

pxc = flopy.plot.PlotCrossSection(modelgrid=grid, line=line)
pc = pxc.plot_centers()

if not isinstance(pc, PathCollection):
raise AssertionError(
"plot_centers() not returning PathCollection object"
)

verts = pc._offsets
if not verts.shape[0] == active_xc_cells:
raise AssertionError(
"plot_centers() not properly masking inactive cells"
)

center_dict = pxc.projctr
edge_dict = pxc.projpts

for node, center in center_dict.items():
verts = np.array(edge_dict[node]).T
xmin = np.min(verts[0])
xmax = np.max(verts[0])
if xmax < center < xmin:
raise AssertionError(
"Cell center not properly drawn on cross-section"
)
41 changes: 41 additions & 0 deletions autotest/test_plot_map_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,3 +276,44 @@ def test_plot_limits():
raise AssertionError("PlotMapView auto extent setting not working")

plt.close(fig)


def test_plot_centers():
nlay = 1
nrow = 10
ncol = 10

delc = np.ones((nrow,))
delr = np.ones((ncol,))
top = np.ones((nrow, ncol))
botm = np.zeros((nlay, nrow, ncol))
idomain = np.ones(botm.shape, dtype=int)

idomain[0, :, 0:3] = 0
active_cells = np.count_nonzero(idomain)

grid = flopy.discretization.StructuredGrid(
delc=delc, delr=delr, top=top, botm=botm, idomain=idomain
)

xcenters = grid.xcellcenters.ravel()
ycenters = grid.ycellcenters.ravel()
xycenters = list(zip(xcenters, ycenters))

pmv = flopy.plot.PlotMapView(modelgrid=grid)
pc = pmv.plot_centers()
if not isinstance(pc, PathCollection):
raise AssertionError(
"plot_centers() not returning PathCollection object"
)

verts = pc._offsets
if not verts.shape[0] == active_cells:
raise AssertionError(
"plot_centers() not properly masking inactive cells"
)

for vert in verts:
vert = tuple(vert)
if vert not in xycenters:
raise AssertionError("center location not properly plotted")
138 changes: 133 additions & 5 deletions flopy/plot/crosssection.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,13 @@ class PlotCrossSection:
(xmin, xmax, ymin, ymax) will be used to specify axes limits. If None
then these will be calculated based on grid, coordinates, and rotation.
geographic_coords : bool
boolean flag to allow the user to plot cross section lines in
geographic coordinates. If False (default), cross section is plotted
as the distance along the cross section line.
boolean flag to allow the user to plot cross-section lines in
geographic coordinates. If False (default), cross-section is plotted
as the distance along the cross-section line.
min_segment_length : float
minimum width of a grid cell polygon to be plotted. Cells with a
cross-sectional width less than min_segment_length will be ignored
and not included in the plot. Default is 1e-02.
"""

def __init__(
Expand All @@ -53,6 +56,7 @@ def __init__(
line=None,
extent=None,
geographic_coords=False,
min_segment_length=1e-02,
):
self.ax = ax
self.geographic_coords = geographic_coords
Expand Down Expand Up @@ -180,6 +184,22 @@ def __init__(
self.pts, self.xvertices, self.yvertices
)

self.xypts = plotutil.UnstructuredPlotUtilities.filter_line_segments(
self.xypts, threshold=min_segment_length
)
# need to ensure that the ordering of verticies in xypts is correct
# based on the projection. In certain cases vertices need to be sorted
# for the specific "projection"
for node, points in self.xypts.items():
if self.direction == "y":
if points[0][-1] < points[1][-1]:
points = points[::-1]
else:
if points[0][0] > points[1][0]:
points = points[::-1]

self.xypts[node] = points

if len(self.xypts) < 2:
if len(list(self.xypts.values())[0]) < 2:
s = (
Expand Down Expand Up @@ -238,6 +258,7 @@ def __init__(
self.idomain = np.ones(botm.shape, dtype=int)

self.projpts = self.set_zpts(None)
self.projctr = None

# Create cross-section extent
if extent is None:
Expand Down Expand Up @@ -926,6 +947,111 @@ def plot_bc(

return patches

def plot_centers(
self, a=None, s=None, masked_values=None, inactive=False, **kwargs
):
"""
Method to plot cell centers on cross-section using matplotlib
scatter. This method accepts an optional data array(s) for
coloring and scaling the cell centers. Cell centers in inactive
nodes are not plotted by default
Parameters
----------
a : None, np.ndarray
optional numpy nd.array of size modelgrid.nnodes
s : None, float, numpy array
optional point size parameter
masked_values : None, iteratable
optional list, tuple, or np array of array (a) values to mask
inactive : bool
boolean flag to include inactive cell centers in the plot.
Default is False
**kwargs :
matplotlib ax.scatter() keyword arguments
Returns
-------
matplotlib ax.scatter() object
"""
ax = kwargs.pop("ax", self.ax)

projpts = self.projpts
nodes = list(projpts.keys())
xcs = self.mg.xcellcenters.ravel()
ycs = self.mg.ycellcenters.ravel()
projctr = {}

if not self.geographic_coords:
xcs, ycs = geometry.transform(
xcs,
ycs,
self.mg.xoffset,
self.mg.yoffset,
self.mg.angrot_radians,
inverse=True,
)

for node, points in self.xypts.items():
projpt = projpts[node]
d0 = np.min(np.array(projpt).T[0])

xc_dist = geometry.project_point_onto_xc_line(
points[:2], [xcs[node], ycs[node]], d0=d0, calc_dist=True
)
projctr[node] = xc_dist

else:
projctr = {}
for node in nodes:
if self.direction == "x":
projctr[node] = xcs[node]
else:
projctr[node] = ycs[node]

# pop off any centers that are outside the "visual field"
# for a given cross-section.
removed = {}
for node, points in projpts.items():
center = projctr[node]
points = np.array(points[:2]).T
if np.min(points[0]) > center or np.max(points[0]) < center:
removed[node] = (np.min(points[0]), center, np.max(points[0]))
projctr.pop(node)

# filter out inactive cells
if not inactive:
idomain = self.mg.idomain.ravel()
for node, points in projpts.items():
if idomain[node] == 0:
if node in projctr:
projctr.pop(node)

self.projctr = projctr
nodes = list(projctr.keys())
xcenters = list(projctr.values())
zcenters = [np.mean(np.array(projpts[node]).T[1]) for node in nodes]

if a is not None:
if not isinstance(a, np.ndarray):
a = np.array(a)
a = a.ravel().astype(float)

if masked_values is not None:
self._masked_values.extend(list(masked_values))

for mval in self._masked_values:
a[a == mval] = np.nan

a = a[nodes]

if s is not None:
if not isinstance(s, (int, float)):
s = s[nodes]
print(len(xcenters))
scat = ax.scatter(xcenters, zcenters, c=a, s=s, **kwargs)
return scat

def plot_vector(
self,
vx,
Expand Down Expand Up @@ -1350,6 +1476,7 @@ def plot_endpoint(
self.xvertices,
self.yvertices,
self.direction,
self._ncpl,
method=method,
starting=istart,
)
Expand All @@ -1362,15 +1489,16 @@ def plot_endpoint(
self.xypts,
self.direction,
self.mg,
self._ncpl,
self.geographic_coords,
starting=istart,
)

arr = []
c = []
for node, epl in sorted(epdict.items()):
c.append(cd[node])
for xy in epl:
c.append(cd[node])
arr.append(xy)

arr = np.array(arr)
Expand Down
61 changes: 61 additions & 0 deletions flopy/plot/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,67 @@ def plot_shapes(self, obj, **kwargs):
ax = self._set_axes_limits(ax)
return patch_collection

def plot_centers(
self, a=None, s=None, masked_values=None, inactive=False, **kwargs
):
"""
Method to plot cell centers on cross-section using matplotlib
scatter. This method accepts an optional data array(s) for
coloring and scaling the cell centers. Cell centers in inactive
nodes are not plotted by default
Parameters
----------
a : None, np.ndarray
optional numpy nd.array of size modelgrid.nnodes
s : None, float, numpy array
optional point size parameter
masked_values : None, iteratable
optional list, tuple, or np array of array (a) values to mask
inactive : bool
boolean flag to include inactive cell centers in the plot.
Default is False
**kwargs :
matplotlib ax.scatter() keyword arguments
Returns
-------
matplotlib ax.scatter() object
"""
ax = kwargs.pop("ax", self.ax)

xcenters = self.mg.get_xcellcenters_for_layer(self.layer).ravel()
ycenters = self.mg.get_ycellcenters_for_layer(self.layer).ravel()
idomain = self.mg.get_plottable_layer_array(
self.mg.idomain, self.layer
).ravel()

active_ixs = list(range(len(xcenters)))
if not inactive:
active_ixs = np.where(idomain != 0)[0]

xcenters = xcenters[active_ixs]
ycenters = ycenters[active_ixs]

if a is not None:
a = self.mg.get_plottable_layer_array(a).ravel()

if masked_values is not None:
self._masked_values.extend(list(masked_values))

for mval in self._masked_values:
a[a == mval] = np.nan

a = a[active_ixs]

if s is not None:
if not isinstance(s, (int, float)):
s = self.mg.get_plottable_layer_array(s).ravel()
s = s[active_ixs]

scat = ax.scatter(xcenters, ycenters, c=a, s=s, **kwargs)
return scat

def plot_vector(
self,
vx,
Expand Down
Loading

0 comments on commit c6a41ab

Please sign in to comment.