Skip to content

Commit

Permalink
Issue #72 and #77 redraw line data with time change (#76)
Browse files Browse the repository at this point in the history
* First version of fix with code duplication

* Refactor to use Mixins

* fix docstring

* Refactor duplicate code away in SupportsTemporalMixin and fix bug in its requires_loading() method, which would return False if different cross-section was drawn but same timestep selected

* Update changelog

* Add comment
  • Loading branch information
JoerivanEngelen authored Mar 18, 2024
1 parent 3a29dd4 commit e5f128c
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 58 deletions.
135 changes: 78 additions & 57 deletions imodqgis/cross_section/cross_section_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
#
import abc
import pathlib
from typing import List, Tuple
from typing import Any, List, Tuple

import numpy as np
from PyQt5.QtCore import pyqtSignal
from PyQt5.QtGui import QColor
from PyQt5.QtWidgets import QWidget
from qgis import processing
from qgis.core import (
QgsDateTimeRange,
QgsFeature,
QgsGeometry,
QgsMeshDatasetIndex,
Expand Down Expand Up @@ -73,9 +74,6 @@ def plot(self, plot_widget):
def clear(self):
pass

def requires_loading(self, **kwargs):
return self.x is None

def add_to_legend(self, legend):
for color, name in zip(self.colors().values(), self.labels().values()):
item = pg.BarGraphItem(x=0, y=0, brush=color)
Expand All @@ -97,10 +95,10 @@ def colors(self):
def color_ramp(self):
return self.color_widget.color_ramp_button.colorRamp()

def requires_static_index(self, datetime_range):
def requires_static_index(self, datetime_range: QgsDateTimeRange):
"""
Check if data requires static indexing, meaning temporal manager inactive
(datetime_range is None) or layer has not time data.
(datetime_range is None) or layer has no time data.
"""

# This works for Raster, Mesh and Vector data, as they all have this
Expand Down Expand Up @@ -133,6 +131,38 @@ def edit_colors(self):
raise ValueError("Invalid render style")
self.colors_changed.emit()

def _is_undefined(x: Any) -> bool:
return x is None

class StaticOnlyMixin():
def requires_loading(self, **kwargs) -> bool:
return _is_undefined(self.x)

class SupportsTemporalMixin():
def requires_loading(self, datetime_range: QgsDateTimeRange) -> bool:
time_and_group_index = self.get_time_and_group_index(datetime_range)
if time_and_group_index == self.time_and_group_index:
return _is_undefined(self.x)
else:
return True

def get_time_and_group_index(self, datetime_range: QgsDateTimeRange) -> Tuple[int, int]:
group_index = self.variables_indexes[self.variable][self.layer_numbers[0]]
if self.requires_static_index(datetime_range):
# Just take the first one in such a case
time_index = QgsMeshDatasetIndex(dataset=0, group=group_index)
else:
time_index = self.layer.datasetIndexAtTime(datetime_range, group_index)
return time_index.dataset(), time_index.group()

def get_plot_datetime_range(self, datetime_range: QgsDateTimeRange) -> QgsDateTimeRange:
if self.requires_static_index(datetime_range):
# Fix datetime_range of cross_section_y_data to None
return None
else:
return datetime_range



class AbstractLineData(AbstractCrossSectionData):
def plot(self, plot_widget):
Expand All @@ -151,14 +181,17 @@ def plot(self, plot_widget):
def clear(self):
self.x = None
self.y = None
self.cache = {}
self.plot_item = None

def add_to_legend(self, legend):
for item, name in zip(self.plot_item, self.labels()):
legend.addItem(item, name)
# self.plot_item can be None after clearing
if self.plot_item:
for item, name in zip(self.plot_item, self.labels()):
legend.addItem(item, name)


class MeshLineData(AbstractLineData):
class MeshLineData(AbstractLineData, SupportsTemporalMixin):
def __init__(self, layer, variables_indexes, variable, layer_numbers):
self.layer = layer
self.variables_indexes = variables_indexes
Expand All @@ -179,32 +212,40 @@ def __init__(self, layer, variables_indexes, variable, layer_numbers):
self.color_widget = self.unique_color_widget
self.legend_items = []
self.styling_data = np.array(self.variables)
# Cache cross-section lines drawn by storing their x,y values based on
# dataset and group index. Cache is cleared upon calling ``clear()``
# method.
self.cache = {}
self.time_and_group_index = (None, None)
self.dummy_widget = DummyWidget()

def load(self, geometry, resolution, datetime_range, **_):
if self.requires_static_index(
datetime_range
): # Just take the first one in such a case
plot_datetime_range = (
None # Fix datetime_range of cross_section_y_data to None
)

def load(self, geometry, resolution, datetime_range: QgsDateTimeRange, **_):
index = self.get_time_and_group_index(datetime_range)
plot_datetime_range = self.get_plot_datetime_range(datetime_range)

result = self.cache.get(index, None)
if result is not None:
x, y = result
else:
plot_datetime_range = datetime_range
n_lines = len(self.layer_numbers)
x = cross_section_x_data(self.layer, geometry, resolution)
y = np.empty((n_lines, x.size))
for i, k in enumerate(self.layer_numbers):
dataset_index = self.variables_indexes[self.variable][k]
y[i, :] = cross_section_y_data(
self.layer, geometry, dataset_index, x, plot_datetime_range
)
# Store in cache
self.cache[index] = (x, y)
self.time_and_group_index = index

n_lines = len(self.layer_numbers)
x = cross_section_x_data(self.layer, geometry, resolution)
y = np.empty((n_lines, x.size))
for i, k in enumerate(self.layer_numbers):
dataset_index = self.variables_indexes[self.variable][k]
y[i, :] = cross_section_y_data(
self.layer, geometry, dataset_index, x, plot_datetime_range
)
self.x = x
self.y = y
self.set_color_data()


class RasterLineData(AbstractLineData):
class RasterLineData(AbstractLineData, StaticOnlyMixin):
def __init__(self, layer, variables, variables_indexes):
self.layer = layer
self.variables = variables
Expand All @@ -217,6 +258,7 @@ def __init__(self, layer, variables, variables_indexes):
self.color_widget = self.unique_color_widget
self.legend_items = []
self.styling_data = np.array(variables)
self.cache = {}
self.dummy_widget = DummyWidget()

def load(self, geometry, resolution, **_):
Expand All @@ -236,7 +278,7 @@ def load(self, geometry, resolution, **_):
self.set_color_data()


class PointCrossSectionData(AbstractCrossSectionData):
class PointCrossSectionData(AbstractCrossSectionData, StaticOnlyMixin):
def select_geometry(self, geometry: QgsGeometry, buffer_distance: float):
buffered = geometry.buffer(buffer_distance, 4)
tmp_layer = QgsVectorLayer("Polygon", "temp", "memory")
Expand Down Expand Up @@ -417,7 +459,7 @@ def clear(self):
self.plot_item = None


class MeshData(AbstractCrossSectionData):
class MeshData(AbstractCrossSectionData, SupportsTemporalMixin):
def __init__(self, layer, variables_indexes, variable, layer_numbers):
self.layer = layer
self.variables_indexes = variables_indexes
Expand All @@ -434,39 +476,18 @@ def __init__(self, layer, variables_indexes, variable, layer_numbers):
self.color_widget = self.pseudocolor_widget
self.legend_items = []
self.styling_data = None
# Cache cross-sections drawn by storing their x,top,bottom,and z values
# based on dataset and group index. Cache is cleared upon calling
# ``clear()`` method.
self.cache = {}
self.sample_index = (None, None)
self.time_and_group_index = (None, None)
self.dummy_widget = DummyWidget()

def requires_loading(self, datetime_range):
def load(self, geometry, resolution, datetime_range: QgsDateTimeRange, **_):
group_index = self.variables_indexes[self.variable][self.layer_numbers[0]]
if self.requires_static_index(
datetime_range
): # Just take the first one in such a case
sample_index = (group_index, 0)
else:
index = self.layer.datasetIndexAtTime(datetime_range, group_index)
sample_index = (index.group(), index.dataset())

if sample_index == self.sample_index:
return False
else:
return True

def load(self, geometry, resolution, datetime_range, **_):
group_index = self.variables_indexes[self.variable][self.layer_numbers[0]]

if self.requires_static_index(
datetime_range
): # Just take the first one in such a case
sample_index = QgsMeshDatasetIndex(group=group_index, dataset=0)
plot_datetime_range = (
None # Fix datetime_range of cross_section_y_data to None
)
else:
sample_index = self.layer.datasetIndexAtTime(datetime_range, group_index)
plot_datetime_range = datetime_range
index = (sample_index.dataset(), sample_index.group())
index = self.get_time_and_group_index(datetime_range)
plot_datetime_range = self.get_plot_datetime_range(datetime_range)

# Get result from cache if available.
result = self.cache.get(index, None)
Expand Down Expand Up @@ -497,7 +518,7 @@ def load(self, geometry, resolution, datetime_range, **_):
)
# Store in cache
self.cache[index] = (x, top, bottom, z)
self.sample_index = index
self.time_and_group_index = index

self.x = x
self.y_top = top
Expand Down
6 changes: 5 additions & 1 deletion imodqgis/metadata.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ repository=https://github.com/Deltares/imod-qgis
hasProcessingProvider=no

# Uncomment the following line and add your changelog:
changelog= <p>0.5.1 - Bug fixes
changelog= <p>Unreleased - Bug fixes
- Cross-section: "As line(s)" button reactivated again after deactivation
- Cross-section: Don't crash when drawing linedata for different location on the same timestep
- Cross-section: Redraw mesh line data when time changes in temporal controller
<p>0.5.1 - Bug fixes
- Timeseries: Don't crash when Arrow file doesn't exist
- Timeseries: Don't crash when Arrow data table is empty
- Timeseries: Only list float variables in Arrow data
Expand Down

0 comments on commit e5f128c

Please sign in to comment.