Skip to content

Commit

Permalink
Plot xarray (#289)
Browse files Browse the repository at this point in the history
plot xarray dataarray
  • Loading branch information
tmichela authored Jul 19, 2024
1 parent 3c3d8d5 commit 1a364da
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 15 deletions.
27 changes: 16 additions & 11 deletions damnit/gui/main_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,17 +592,29 @@ def inspect_data(self, index):
pp.show()
return

if variable.type_hint() is DataType.Dataset:
QMessageBox.warning(self, "Can't inspect variable",
f"'{quantity}' is a Xarray Dataset (not supported).")

try:
data = xr.DataArray(variable.read())
data = variable.read()
except KeyError:
log.warning(f'"{quantity}" not found in {variable.file}...')
return

if variable.type_hint() is DataType.DataArray:
canvas = Canvas(self, dataarray=data, title=f'{variable.title} (run {run})')
self._canvas_inspect.append(canvas)
canvas.show()
return

data = xr.DataArray(data)

if data.ndim == 2 or (data.ndim == 3 and data.shape[-1] in (3, 4)):
canvas = Canvas(
self,
image=data.data,
title=f"{quantity_title} (run {run})",
title=f"{variable.title} (run {run})",
)
else:
if data.ndim == 0:
Expand All @@ -615,22 +627,15 @@ def inspect_data(self, index):
f"'{quantity}' with {data.ndim} dimensions (not supported).")
return

# Use the train ID if it's been saved, otherwise generate an X axis
if "trainId" in data.coords:
x = data.trainId
else:
x = np.arange(len(data))

canvas = Canvas(
self,
x=[fix_data_for_plotting(x)],
x=[np.arange(len(data))],
y=[fix_data_for_plotting(data)],
xlabel=f"Event (run {run})",
ylabel=quantity_title,
ylabel=variable.title,
fmt="o",
)


self._canvas_inspect.append(canvas)
canvas.show()

Expand Down
14 changes: 10 additions & 4 deletions damnit/gui/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import pandas as pd
import tempfile
import xarray as xr
from pandas.api.types import is_numeric_dtype

from PyQt5.QtCore import Qt
from PyQt5 import QtCore, QtWidgets, QtGui
Expand All @@ -30,6 +29,7 @@ def __init__(
x=[],
y=[],
image=None,
dataarray=None,
xlabel="",
ylabel="",
title=None,
Expand Down Expand Up @@ -64,6 +64,7 @@ def __init__(
self._axis.set_xlabel(xlabel)
self._axis.set_ylabel(ylabel if not is_histogram else "Probability density")
if title is not None:
self.setWindowTitle(title)
self._axis.set_title(title)
elif is_histogram:
self._axis.set_title(f"Probability density of {xlabel}")
Expand Down Expand Up @@ -154,7 +155,7 @@ def __init__(
self._zoom_factory = None
self._panmanager = PanManager(self.figure, MouseButton.LEFT)

self.update_canvas(x, y, image, legend=legend)
self.update_canvas(x, y, image, dataarray, legend=legend)

# Take a guess at a good aspect ratio if it's an image
if image is not None:
Expand Down Expand Up @@ -232,11 +233,16 @@ def set_dynamic_aspect(self, is_dynamic):
self._axis.set_aspect(aspect)
self.figure.canvas.draw()

def update_canvas(self, xs=None, ys=None, image=None, legend=None, series_names=["default"]):
def update_canvas(self, xs=None, ys=None, image=None, dataarray=None, legend=None, series_names=["default"]):
cmap = matplotlib.colormaps["tab20"]
self._nan_warning_label.hide()

if (xs is None and ys is None) and self.plot_type == "histogram1D":
if dataarray is not None:
if dataarray.ndim == 3 and dataarray.shape[-1] in (3, 4):
dataarray.plot.imshow(ax=self._axis)
else:
dataarray.plot(ax=self._axis)
elif (xs is None and ys is None) and self.plot_type == "histogram1D":
xs, ys = [], []

for series in self._lines.keys():
Expand Down

0 comments on commit 1a364da

Please sign in to comment.