diff --git a/damnit/gui/main_window.py b/damnit/gui/main_window.py index 2a97c039..a466e25c 100644 --- a/damnit/gui/main_window.py +++ b/damnit/gui/main_window.py @@ -592,17 +592,29 @@ def inspect_data(self, index): pp.show() return + if variable.type_hint() is DataType.Dataset: + QMessageBox.warning(self, "Can't inspect variable", + f"'{quantity}' is a Xarray Dataset (not supported).") + try: - data = xr.DataArray(variable.read()) + data = variable.read() except KeyError: log.warning(f'"{quantity}" not found in {variable.file}...') return + if variable.type_hint() is DataType.DataArray: + canvas = Canvas(self, dataarray=data, title=f'{variable.title} (run {run})') + self._canvas_inspect.append(canvas) + canvas.show() + return + + data = xr.DataArray(data) + if data.ndim == 2 or (data.ndim == 3 and data.shape[-1] in (3, 4)): canvas = Canvas( self, image=data.data, - title=f"{quantity_title} (run {run})", + title=f"{variable.title} (run {run})", ) else: if data.ndim == 0: @@ -615,22 +627,15 @@ def inspect_data(self, index): f"'{quantity}' with {data.ndim} dimensions (not supported).") return - # Use the train ID if it's been saved, otherwise generate an X axis - if "trainId" in data.coords: - x = data.trainId - else: - x = np.arange(len(data)) - canvas = Canvas( self, - x=[fix_data_for_plotting(x)], + x=[np.arange(len(data))], y=[fix_data_for_plotting(data)], xlabel=f"Event (run {run})", - ylabel=quantity_title, + ylabel=variable.title, fmt="o", ) - self._canvas_inspect.append(canvas) canvas.show() diff --git a/damnit/gui/plot.py b/damnit/gui/plot.py index 9e48f33c..07f6c9b8 100644 --- a/damnit/gui/plot.py +++ b/damnit/gui/plot.py @@ -3,7 +3,6 @@ import pandas as pd import tempfile import xarray as xr -from pandas.api.types import is_numeric_dtype from PyQt5.QtCore import Qt from PyQt5 import QtCore, QtWidgets, QtGui @@ -30,6 +29,7 @@ def __init__( x=[], y=[], image=None, + dataarray=None, xlabel="", ylabel="", title=None, @@ -64,6 +64,7 @@ def __init__( self._axis.set_xlabel(xlabel) self._axis.set_ylabel(ylabel if not is_histogram else "Probability density") if title is not None: + self.setWindowTitle(title) self._axis.set_title(title) elif is_histogram: self._axis.set_title(f"Probability density of {xlabel}") @@ -154,7 +155,7 @@ def __init__( self._zoom_factory = None self._panmanager = PanManager(self.figure, MouseButton.LEFT) - self.update_canvas(x, y, image, legend=legend) + self.update_canvas(x, y, image, dataarray, legend=legend) # Take a guess at a good aspect ratio if it's an image if image is not None: @@ -232,11 +233,16 @@ def set_dynamic_aspect(self, is_dynamic): self._axis.set_aspect(aspect) self.figure.canvas.draw() - def update_canvas(self, xs=None, ys=None, image=None, legend=None, series_names=["default"]): + def update_canvas(self, xs=None, ys=None, image=None, dataarray=None, legend=None, series_names=["default"]): cmap = matplotlib.colormaps["tab20"] self._nan_warning_label.hide() - if (xs is None and ys is None) and self.plot_type == "histogram1D": + if dataarray is not None: + if dataarray.ndim == 3 and dataarray.shape[-1] in (3, 4): + dataarray.plot.imshow(ax=self._axis) + else: + dataarray.plot(ax=self._axis) + elif (xs is None and ys is None) and self.plot_type == "histogram1D": xs, ys = [], [] for series in self._lines.keys():