diff --git a/test.ipynb b/test.ipynb
new file mode 100644
index 0000000000..86276ecaae
--- /dev/null
+++ b/test.ipynb
@@ -0,0 +1,1195 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Populating index\n"
+ ]
+ },
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "100%|██████████| 16/16 [00:00<00:00, 76.25it/s]\n"
+ ]
+ }
+ ],
+ "source": [
+ "from torchgeo.datasets import Sentinel2\n",
+ "\n",
+ "data_dir = r\"tests\\data\\sentinel2\"\n",
+ "\n",
+ "ds = Sentinel2(data_dir, bands=[\"B02\", \"B03\", \"B04\", \"B08\"], cache=False, res=10)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "BoundingBox(minx=399960.0, maxx=401240.0, miny=4498720.0, maxy=4500000.0, mint=1555079321.0, maxt=1649927271.999999)"
+ ]
+ },
+ "execution_count": 2,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "ds.bounds"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "torch.Size([1, 4, 13, 13])\n",
+ "[[datetime.datetime(2022, 4, 12, 16, 28, 41), datetime.datetime(2019, 4, 12, 16, 28, 41), datetime.datetime(2019, 4, 14, 11, 7, 51), datetime.datetime(2022, 4, 14, 11, 7, 51)]]\n"
+ ]
+ }
+ ],
+ "source": [
+ "from torchgeo.datasets.utils import BoundingBox\n",
+ "full_t_query = BoundingBox(minx=399960.0, maxx=400088.0, miny=4498720.0, maxy=4498848.0, mint=1555079321.0, maxt=1649927271.999999)\n",
+ "sample = ds[[full_t_query]]\n",
+ "print(sample[\"image\"].shape)\n",
+ "print(sample[\"dates\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "torch.Size([2, 4, 13, 13])\n",
+ "[[datetime.datetime(2019, 4, 12, 16, 28, 41), datetime.datetime(2019, 4, 14, 11, 7, 51)], [datetime.datetime(2022, 4, 12, 16, 28, 41), datetime.datetime(2022, 4, 14, 11, 7, 51)]]\n"
+ ]
+ }
+ ],
+ "source": [
+ "multi_t_query = [BoundingBox(minx=399960.0, maxx=400088.0, miny=4498720.0, maxy=4498848.0, mint=1555079321.0, maxt=1605264929),\n",
+ " BoundingBox(minx=399960.0, maxx=400088.0, miny=4498720.0, maxy=4498848.0, mint=1605264929.0, maxt=1649927272),\n",
+ " ]\n",
+ "sample = ds[multi_t_query]\n",
+ "print(sample[\"image\"].shape)\n",
+ "print(sample[\"dates\"])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import plotly.express as px\n",
+ "\n",
+ "def plot(\n",
+ " sample: dict,\n",
+ " indices_to_plot,\n",
+ " show = False,\n",
+ " **kwargs,\n",
+ "):\n",
+ " \"\"\"Plots the image data from the given sample.\n",
+ "\n",
+ " Args:\n",
+ " sample (dict): A dictionary containing the image data returned by the dataset's __getitem__. Must contain the key \"image\", and the key \"dates\" when the image is a time series.\n",
+ " indices_to_plot (list): A list of channel indices of the image to plot, e.g. [2, 1, 0].\n",
+ " show (bool, optional): Whether to display the plot. Defaults to False.\n",
+ " **kwargs (dict): Additional keyword arguments to be passed to `px.imshow`.\n",
+ "\n",
+ " Returns:\n",
+ " fig: The plotly figure object.\n",
+ " \"\"\"\n",
+ " image = sample[\"image\"]\n",
+ "\n",
+ " # Reorder and rescale the image\n",
+ " if (sample[\"image\"].ndim == 4) and (sample[\"image\"].shape[0] > 1):\n",
+ " # Shape of image = [d, c, h, w]\n",
+ " image = image[:, indices_to_plot, :, :].permute(0, 2, 3, 1)\n",
+ " if image.shape[-1] == 1:\n",
+ " image = image.squeeze(-1)\n",
+ " image = torch.clamp(image / 10000, min=0, max=1).numpy()\n",
+ "\n",
+ " fig = px.imshow(\n",
+ " image, animation_frame=0, labels={\"animation_frame\": \"Date\"}, **kwargs\n",
+ " )\n",
+ " # TODO: currently labels each frame with only its first date; need to handle multiple dates per frame\n",
+ " date_labels = [\n",
+ " dates[0].strftime(\"%m/%d/%Y, %H:%M:%S\") for dates in sample[\"dates\"]\n",
+ " ]\n",
+ " for i, label in enumerate(date_labels):\n",
+ " fig.layout.sliders[0].steps[i].label = label\n",
+ "\n",
+ " else:\n",
+ " image = image[indices_to_plot].permute(1, 2, 0)\n",
+ " image = torch.clamp(image / 10000, min=0, max=1).numpy()\n",
+ "\n",
+ " # Plot the image\n",
+ " fig = px.imshow(image, **kwargs)\n",
+ "\n",
+ " fig.update_xaxes(showticklabels=False).update_yaxes(showticklabels=False)\n",
+ " if show:\n",
+ " fig.show()\n",
+ " return fig"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "application/vnd.plotly.v1+json": {
+ "config": {
+ "plotlyServerURL": "https://plot.ly"
+ },
+ "data": [
+ {
+ "hovertemplate": "x: %{x}<br>y: %{y}<extra></extra>",
+ "name": "0",
+ "source": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAA0AAAANCAIAAAD9iXMrAAAAfElEQVR4XpWRsQ0CMQxFP0hXs8wtwRD0twk912aJXE/NOqGF4lEQRbHPkeBV9rflb8uip5B5AXAjmcoB0A8cvTDg/75HpwaYbRsXL7Q73tLkRvQ037hpk6RV+vpmZxJQ9yuQoNgaVzjXML5j9sLgH0/pZNL7bl7/1rS08AMeovHfi9t0qgAAAABJRU5ErkJggg==",
+ "type": "image",
+ "xaxis": "x",
+ "yaxis": "y"
+ }
+ ],
+ "frames": [
+ {
+ "data": [
+ {
+ "name": "0",
+ "source": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAA0AAAANCAIAAAD9iXMrAAAAfElEQVR4XpWRsQ0CMQxFP0hXs8wtwRD0twk912aJXE/NOqGF4lEQRbHPkeBV9rflb8uip5B5AXAjmcoB0A8cvTDg/75HpwaYbRsXL7Q73tLkRvQ037hpk6RV+vpmZxJQ9yuQoNgaVzjXML5j9sLgH0/pZNL7bl7/1rS08AMeovHfi9t0qgAAAABJRU5ErkJggg==",
+ "type": "image"
+ }
+ ],
+ "layout": {
+ "margin": {
+ "t": 60
+ }
+ },
+ "name": "0"
+ },
+ {
+ "data": [
+ {
+ "name": "1",
+ "source": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAA0AAAANCAIAAAD9iXMrAAAAjUlEQVR4XnWQuxXDMAhFXzJLqtRawptoEw+QNku4zybyEm7T3BRGgHyU2wBPfARiQgXaKN014aldjywcUi+obLn+iucBn+QbPj3nTamnsbyY2QDeEVpHAW1x2Yi8zg3Im/0j7rK7d5zm5YIUd+lc7rsCbFQNf1vwBY2+4NhvBb4exREKSg0qxV/MNaHwA3n7fU3SxPVDAAAAAElFTkSuQmCC",
+ "type": "image"
+ }
+ ],
+ "layout": {
+ "margin": {
+ "t": 60
+ }
+ },
+ "name": "1"
+ }
+ ],
+ "layout": {
+ "margin": {
+ "t": 60
+ },
+ "sliders": [
+ {
+ "active": 0,
+ "currentvalue": {
+ "prefix": "Date="
+ },
+ "len": 0.9,
+ "pad": {
+ "b": 10,
+ "t": 60
+ },
+ "steps": [
+ {
+ "args": [
+ [
+ "0"
+ ],
+ {
+ "frame": {
+ "duration": 0,
+ "redraw": true
+ },
+ "fromcurrent": true,
+ "mode": "immediate",
+ "transition": {
+ "duration": 0,
+ "easing": "linear"
+ }
+ }
+ ],
+ "label": "04/12/2019, 16:28:41",
+ "method": "animate"
+ },
+ {
+ "args": [
+ [
+ "1"
+ ],
+ {
+ "frame": {
+ "duration": 0,
+ "redraw": true
+ },
+ "fromcurrent": true,
+ "mode": "immediate",
+ "transition": {
+ "duration": 0,
+ "easing": "linear"
+ }
+ }
+ ],
+ "label": "04/12/2022, 16:28:41",
+ "method": "animate"
+ }
+ ],
+ "x": 0.1,
+ "xanchor": "left",
+ "y": 0,
+ "yanchor": "top"
+ }
+ ],
+ "template": {
+ "data": {
+ "bar": [
+ {
+ "error_x": {
+ "color": "#2a3f5f"
+ },
+ "error_y": {
+ "color": "#2a3f5f"
+ },
+ "marker": {
+ "line": {
+ "color": "#E5ECF6",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "bar"
+ }
+ ],
+ "barpolar": [
+ {
+ "marker": {
+ "line": {
+ "color": "#E5ECF6",
+ "width": 0.5
+ },
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "barpolar"
+ }
+ ],
+ "carpet": [
+ {
+ "aaxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "white",
+ "linecolor": "white",
+ "minorgridcolor": "white",
+ "startlinecolor": "#2a3f5f"
+ },
+ "baxis": {
+ "endlinecolor": "#2a3f5f",
+ "gridcolor": "white",
+ "linecolor": "white",
+ "minorgridcolor": "white",
+ "startlinecolor": "#2a3f5f"
+ },
+ "type": "carpet"
+ }
+ ],
+ "choropleth": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "choropleth"
+ }
+ ],
+ "contour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "contour"
+ }
+ ],
+ "contourcarpet": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "contourcarpet"
+ }
+ ],
+ "heatmap": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmap"
+ }
+ ],
+ "heatmapgl": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "heatmapgl"
+ }
+ ],
+ "histogram": [
+ {
+ "marker": {
+ "pattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ }
+ },
+ "type": "histogram"
+ }
+ ],
+ "histogram2d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2d"
+ }
+ ],
+ "histogram2dcontour": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "histogram2dcontour"
+ }
+ ],
+ "mesh3d": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "type": "mesh3d"
+ }
+ ],
+ "parcoords": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "parcoords"
+ }
+ ],
+ "pie": [
+ {
+ "automargin": true,
+ "type": "pie"
+ }
+ ],
+ "scatter": [
+ {
+ "fillpattern": {
+ "fillmode": "overlay",
+ "size": 10,
+ "solidity": 0.2
+ },
+ "type": "scatter"
+ }
+ ],
+ "scatter3d": [
+ {
+ "line": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatter3d"
+ }
+ ],
+ "scattercarpet": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattercarpet"
+ }
+ ],
+ "scattergeo": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergeo"
+ }
+ ],
+ "scattergl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattergl"
+ }
+ ],
+ "scattermapbox": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scattermapbox"
+ }
+ ],
+ "scatterpolar": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolar"
+ }
+ ],
+ "scatterpolargl": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterpolargl"
+ }
+ ],
+ "scatterternary": [
+ {
+ "marker": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "type": "scatterternary"
+ }
+ ],
+ "surface": [
+ {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ },
+ "colorscale": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "type": "surface"
+ }
+ ],
+ "table": [
+ {
+ "cells": {
+ "fill": {
+ "color": "#EBF0F8"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "header": {
+ "fill": {
+ "color": "#C8D4E3"
+ },
+ "line": {
+ "color": "white"
+ }
+ },
+ "type": "table"
+ }
+ ]
+ },
+ "layout": {
+ "annotationdefaults": {
+ "arrowcolor": "#2a3f5f",
+ "arrowhead": 0,
+ "arrowwidth": 1
+ },
+ "autotypenumbers": "strict",
+ "coloraxis": {
+ "colorbar": {
+ "outlinewidth": 0,
+ "ticks": ""
+ }
+ },
+ "colorscale": {
+ "diverging": [
+ [
+ 0,
+ "#8e0152"
+ ],
+ [
+ 0.1,
+ "#c51b7d"
+ ],
+ [
+ 0.2,
+ "#de77ae"
+ ],
+ [
+ 0.3,
+ "#f1b6da"
+ ],
+ [
+ 0.4,
+ "#fde0ef"
+ ],
+ [
+ 0.5,
+ "#f7f7f7"
+ ],
+ [
+ 0.6,
+ "#e6f5d0"
+ ],
+ [
+ 0.7,
+ "#b8e186"
+ ],
+ [
+ 0.8,
+ "#7fbc41"
+ ],
+ [
+ 0.9,
+ "#4d9221"
+ ],
+ [
+ 1,
+ "#276419"
+ ]
+ ],
+ "sequential": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ],
+ "sequentialminus": [
+ [
+ 0,
+ "#0d0887"
+ ],
+ [
+ 0.1111111111111111,
+ "#46039f"
+ ],
+ [
+ 0.2222222222222222,
+ "#7201a8"
+ ],
+ [
+ 0.3333333333333333,
+ "#9c179e"
+ ],
+ [
+ 0.4444444444444444,
+ "#bd3786"
+ ],
+ [
+ 0.5555555555555556,
+ "#d8576b"
+ ],
+ [
+ 0.6666666666666666,
+ "#ed7953"
+ ],
+ [
+ 0.7777777777777778,
+ "#fb9f3a"
+ ],
+ [
+ 0.8888888888888888,
+ "#fdca26"
+ ],
+ [
+ 1,
+ "#f0f921"
+ ]
+ ]
+ },
+ "colorway": [
+ "#636efa",
+ "#EF553B",
+ "#00cc96",
+ "#ab63fa",
+ "#FFA15A",
+ "#19d3f3",
+ "#FF6692",
+ "#B6E880",
+ "#FF97FF",
+ "#FECB52"
+ ],
+ "font": {
+ "color": "#2a3f5f"
+ },
+ "geo": {
+ "bgcolor": "white",
+ "lakecolor": "white",
+ "landcolor": "#E5ECF6",
+ "showlakes": true,
+ "showland": true,
+ "subunitcolor": "white"
+ },
+ "hoverlabel": {
+ "align": "left"
+ },
+ "hovermode": "closest",
+ "mapbox": {
+ "style": "light"
+ },
+ "paper_bgcolor": "white",
+ "plot_bgcolor": "#E5ECF6",
+ "polar": {
+ "angularaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "bgcolor": "#E5ECF6",
+ "radialaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ }
+ },
+ "scene": {
+ "xaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ },
+ "yaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ },
+ "zaxis": {
+ "backgroundcolor": "#E5ECF6",
+ "gridcolor": "white",
+ "gridwidth": 2,
+ "linecolor": "white",
+ "showbackground": true,
+ "ticks": "",
+ "zerolinecolor": "white"
+ }
+ },
+ "shapedefaults": {
+ "line": {
+ "color": "#2a3f5f"
+ }
+ },
+ "ternary": {
+ "aaxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "baxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ },
+ "bgcolor": "#E5ECF6",
+ "caxis": {
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": ""
+ }
+ },
+ "title": {
+ "x": 0.05
+ },
+ "xaxis": {
+ "automargin": true,
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "white",
+ "zerolinewidth": 2
+ },
+ "yaxis": {
+ "automargin": true,
+ "gridcolor": "white",
+ "linecolor": "white",
+ "ticks": "",
+ "title": {
+ "standoff": 15
+ },
+ "zerolinecolor": "white",
+ "zerolinewidth": 2
+ }
+ }
+ },
+ "updatemenus": [
+ {
+ "buttons": [
+ {
+ "args": [
+ null,
+ {
+ "frame": {
+ "duration": 500,
+ "redraw": true
+ },
+ "fromcurrent": true,
+ "mode": "immediate",
+ "transition": {
+ "duration": 500,
+ "easing": "linear"
+ }
+ }
+ ],
+ "label": "▶",
+ "method": "animate"
+ },
+ {
+ "args": [
+ [
+ null
+ ],
+ {
+ "frame": {
+ "duration": 0,
+ "redraw": true
+ },
+ "fromcurrent": true,
+ "mode": "immediate",
+ "transition": {
+ "duration": 0,
+ "easing": "linear"
+ }
+ }
+ ],
+ "label": "◼",
+ "method": "animate"
+ }
+ ],
+ "direction": "left",
+ "pad": {
+ "r": 10,
+ "t": 70
+ },
+ "showactive": false,
+ "type": "buttons",
+ "x": 0.1,
+ "xanchor": "right",
+ "y": 0,
+ "yanchor": "top"
+ }
+ ],
+ "xaxis": {
+ "anchor": "y",
+ "domain": [
+ 0,
+ 1
+ ],
+ "showticklabels": false
+ },
+ "yaxis": {
+ "anchor": "x",
+ "domain": [
+ 0,
+ 1
+ ],
+ "showticklabels": false
+ }
+ }
+ }
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "plot(sample, show=False, indices_to_plot=[2, 1, 0])"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "cca",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.14"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/torchgeo/datasets/geo.py b/torchgeo/datasets/geo.py
index 26a035d427..9ea6de0774 100644
--- a/torchgeo/datasets/geo.py
+++ b/torchgeo/datasets/geo.py
@@ -12,11 +12,15 @@
import sys
import warnings
from collections.abc import Callable, Iterable, Sequence
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from datetime import datetime
+from re import Pattern
from typing import Any, ClassVar, cast
import fiona
import fiona.transform
import numpy as np
+import pandas as pd
import pyproj
import rasterio
import rasterio.merge
@@ -31,6 +35,7 @@
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
from torchvision.datasets.folder import default_loader as pil_loader
+from tqdm import tqdm
from .errors import DatasetNotFoundError
from .utils import (
@@ -125,7 +130,7 @@ def __init__(
self.index = Index(interleaved=False, properties=Property(dimension=3))
@abc.abstractmethod
- def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
+ def __getitem__(self, query: Iterable[BoundingBox]) -> dict[str, Any]:
"""Retrieve image/mask and metadata indexed by query.
Args:
@@ -377,6 +382,9 @@ class RasterDataset(GeoDataset):
#: Color map for the dataset, used for plotting
cmap: ClassVar[dict[int, tuple[int, int, int, int]]] = {}
+ #: Nodata value for the dataset
+ nodata_value: int | None = None
+
@property
def dtype(self) -> torch.dtype:
"""The dtype of the dataset (overrides the dtype of the data file via a cast).
@@ -420,6 +428,7 @@ def __init__(
bands: Sequence[str] | None = None,
transforms: Callable[[dict[str, Any]], dict[str, Any]] | None = None,
cache: bool = True,
+ drop_nodata: bool = True,
) -> None:
"""Initialize a new RasterDataset instance.
@@ -433,6 +442,7 @@ def __init__(
transforms: a function/transform that takes an input sample
and returns a transformed version
cache: if True, cache file handle to speed up repeated sampling
+ drop_nodata: if True, drop the data of a bounding box if any of its pixels equals the nodata value
Raises:
DatasetNotFoundError: If dataset is not found.
@@ -445,50 +455,10 @@ def __init__(
self.paths = paths
self.bands = bands or self.all_bands
self.cache = cache
+ self.drop_nodata = drop_nodata
- # Populate the dataset index
- i = 0
- filename_regex = re.compile(self.filename_regex, re.VERBOSE)
- for filepath in self.files:
- match = re.match(filename_regex, os.path.basename(filepath))
- if match is not None:
- try:
- with rasterio.open(filepath) as src:
- # See if file has a color map
- if len(self.cmap) == 0:
- try:
- self.cmap = src.colormap(1) # type: ignore[misc]
- except ValueError:
- pass
-
- if crs is None:
- crs = src.crs
-
- with WarpedVRT(src, crs=crs) as vrt:
- minx, miny, maxx, maxy = vrt.bounds
- if res is None:
- res = vrt.res[0]
- except rasterio.errors.RasterioIOError:
- # Skip files that rasterio is unable to read
- continue
- else:
- mint = self.mint
- maxt = self.maxt
- if 'date' in match.groupdict():
- date = match.group('date')
- mint, maxt = disambiguate_timestamp(date, self.date_format)
- elif 'start' in match.groupdict() and 'stop' in match.groupdict():
- start = match.group('start')
- stop = match.group('stop')
- mint, _ = disambiguate_timestamp(start, self.date_format)
- _, maxt = disambiguate_timestamp(stop, self.date_format)
-
- coords = (minx, maxx, miny, maxy, mint, maxt)
- self.index.insert(i, coords, filepath)
- i += 1
-
- if i == 0:
- raise DatasetNotFoundError(self)
+ crs, res = self.try_set_metadata(crs, res)
+ self._populate_index(crs)
if not self.separate_files:
self.band_indexes = None
@@ -505,50 +475,255 @@ def __init__(
raise AssertionError(msg)
self._crs = cast(CRS, crs)
- self._res = cast(float, res)
+ self._res = res
- def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
- """Retrieve image/mask and metadata indexed by query.
+ def try_set_metadata(self, crs: CRS, res: float | None) -> tuple[CRS, float]:
+ """Try to set the CRS, resolution, cmap and nodata value from the first file in the dataset.
Args:
- query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index
+ crs: The coordinate reference system (CRS) to use.
+ res: The resolution of the dataset in units of CRS.
Returns:
- sample of image/mask and metadata at that index
+ tuple: The coordinate reference system (CRS) and resolution of the dataset.
+ """
+ with rasterio.open(self.files[0]) as src:
+ # See if file has a color map
+ if len(self.cmap) == 0:
+ try:
+ self.cmap = src.colormap(1) # type: ignore[misc]
+ except ValueError:
+ pass
+
+ if crs is None:
+ crs = src.crs
+ if self.nodata_value is None:
+ src_nodata_value = src.nodata
+ if src_nodata_value is not None:
+ self.nodata_value = src_nodata_value
+ elif self.drop_nodata:
+ raise ValueError(
+ 'drop_nodata is True but nodata is not set in the dataset and could not be read from the file.'
+ )
- Raises:
- IndexError: if query is not found in the index
+ with WarpedVRT(src, crs=crs) as vrt:
+ if res is None:
+ res = vrt.res[0]
+ return crs, res
+
+ def _get_bounds(self, filepath: str, crs: CRS) -> tuple[tuple[float], str]:
+ """Retrieves the bounds for a given file path and coordinate reference system (CRS).
+
+ Args:
+ filepath (str): The path to the file.
+ crs (CRS): The coordinate reference system (CRS) to use.
+
+ Returns:
+ tuple[tuple[float], str]: A tuple containing the bbox coordinates and the filepath.
+ The bbox coordinates are represented as a tuple of floats in the following order:
+ (minx, maxx, miny, maxy, mint, maxt).
+ """
+ filename_regex = re.compile(self.filename_regex, re.VERBOSE)
+
+ try:
+ with rasterio.open(filepath) as src:
+ with WarpedVRT(src, crs=crs) as vrt:
+ minx, miny, maxx, maxy = vrt.bounds
+ match = re.match(filename_regex, os.path.basename(filepath))
+ if not match:
+ raise ValueError(f'No match found for {os.path.basename(filepath)}')
+ except rasterio.errors.RasterioIOError as e:
+ raise FileNotFoundError(f'Error reading {filepath}') from e
+ else:
+ mint = self.mint
+ maxt = self.maxt
+ if 'date' in match.groupdict():
+ date = match.group('date')
+ mint, maxt = disambiguate_timestamp(date, self.date_format)
+ elif 'start' in match.groupdict() and 'stop' in match.groupdict():
+ start = match.group('start')
+ stop = match.group('stop')
+ mint, _ = disambiguate_timestamp(start, self.date_format)
+ _, maxt = disambiguate_timestamp(stop, self.date_format)
+ else:
+ # TODO: Optionally, revert to no_date option if date is not found
+ pass
+
+ bbox = (
+ float(minx),
+ float(maxx),
+ float(miny),
+ float(maxy),
+ float(mint),
+ float(maxt),
+ )
+ return bbox, filepath
+
+ def _compile_and_check_filename_regex(self) -> Pattern:
+ """Compile the filename regex and check that it mentions `band` when separate_files is True.
+
+ Returns:
+ Pattern: The compiled regex pattern.
+
+ """
+ if 'band' not in self.filename_regex and self.separate_files:
+ raise ValueError(
+ 'The term `band` is not in the filename_regex, but separate_files=True. At least provide a regex pattern to distinguish bands.'
+ )
+ return re.compile(self.filename_regex, re.VERBOSE)
+
+ def _populate_index(self, crs: CRS) -> None:
+ """Populates the dataset index by retrieving index parameters for each filepath in the dataset paths.
+
+ This method uses a ThreadPoolExecutor to concurrently retrieve index parameters for each file that matches the
+ filename regex. The retrieved parameters are then inserted into the dataset index.
+
+ Args:
+ crs (CRS): The coordinate reference system used for warping while opening the file.
+
+ Returns:
+ None
+ """
+ print('Populating index')
+ filename_regex = self._compile_and_check_filename_regex()
+
+ # Populate the dataset index
+ def has_match(filepath: str) -> bool:
+ """Check if the given filepath matches the specified filename regex and its band is included in `self.bands`.
+
+ Args:
+ filepath (str): The path to the file to be checked.
+
+ Returns:
+ bool: True if the filepath matches the filename regex and conditions, False otherwise.
+ """
+ match = re.match(filename_regex, os.path.basename(filepath))
+ if match is not None:
+ if self.separate_files:
+ return match.group('band') in self.bands
+ else:
+ return True
+ else:
+ return False
+
+ with ThreadPoolExecutor(max_workers=8) as executor:
+ futures = [
+ executor.submit(self._get_bounds, filepath, crs)
+ for filepath in self.files
+ if has_match(filepath)
+ ]
+ i = 0
+ for f in tqdm(as_completed(futures), total=len(futures)):
+ bbox, filepath = f.result()
+ self.index.insert(i, bbox, filepath)
+ i += 1
+
+ # TODO: Sequential version: choose which to use.
+ # i = 0
+ # for filepath in self.files:
+ # if has_match(filepath):
+ # bbox, filepath = self._get_bounds(filepath, crs)
+ # self.index.insert(i, bbox, filepath)
+ # i += 1
+
+ if i == 0:
+ raise DatasetNotFoundError(self)
+
+ def _get_regex_groups_as_df(self, filepaths: list[str]) -> pd.DataFrame:
+ """Extracts the regex metadata from a list of filepaths.
+
+ Args:
+ filepaths (list): A list of filepaths.
+
+ Returns:
+ pandas.DataFrame: A DataFrame containing the extracted file metadata.
+ """
+ filename_regex = re.compile(self.filename_regex, re.VERBOSE)
+ file_metadata = []
+ for filepath in filepaths:
+ match = re.match(filename_regex, os.path.basename(filepath))
+ if match:
+ meta = match.groupdict()
+ meta.update({'filepath': filepath})
+ file_metadata.append(meta)
+
+ return pd.DataFrame(file_metadata)
+
+ def __merge_single_bbox(
+ self, query: BoundingBox
+ ) -> tuple[torch.Tensor | None, list[str]]:
+ """Merge all files that intersect with a single bounding box.
+
+ Args:
+ query: (BoundingBox) Bounds of the query
+
+ Returns:
+ tuple[torch.Tensor | None, list[datetime]]: The merged tensor and the dates parsed from the intersecting files, or (None, []) if the result is dropped.
"""
hits = self.index.intersection(tuple(query), objects=True)
filepaths = cast(list[str], [hit.object for hit in hits])
-
if not filepaths:
raise IndexError(
f'query: {query} not found in index with bounds: {self.bounds}'
)
+ file_df = self._get_regex_groups_as_df(filepaths)
+
if self.separate_files:
- data_list: list[Tensor] = []
- filename_regex = re.compile(self.filename_regex, re.VERBOSE)
- for band in self.bands:
- band_filepaths = []
- for filepath in filepaths:
- filename = os.path.basename(filepath)
- directory = os.path.dirname(filepath)
- match = re.match(filename_regex, filename)
- if match:
- if 'band' in match.groupdict():
- start = match.start('band')
- end = match.end('band')
- filename = filename[:start] + band + filename[end:]
- filepath = os.path.join(directory, filename)
- band_filepaths.append(filepath)
- data_list.append(self._merge_files(band_filepaths, query))
- data = torch.cat(data_list)
+ grouped = file_df.groupby(['band']).agg(list)
+ res_for_bbox = []
+ for band, filepaths in grouped.sort_values('band')['filepath'].items():
+ single_bbox_single_band = self._merge_files(filepaths, query)
+ res_for_bbox.append(single_bbox_single_band)
+
+ res_single_bbox = (
+ torch.cat(res_for_bbox).unsqueeze(0)
+ if len(res_for_bbox) == len(self.bands)
+ else None
+ )
else:
- data = self._merge_files(filepaths, query, self.band_indexes)
+ res_single_bbox = self._merge_files(filepaths, query, self.band_indexes)
+ # TODO ideally, we want feedback from rasterio.merge.merge to know which dates were merged
+ dates = file_df['date'].unique().tolist() # TODO what if no dates?
+ dates = [datetime.strptime(date, self.date_format) for date in dates]
+
+ if res_single_bbox is not None:
+ # Only return res_single_bbox if nodata is not being dropped or it contains no nodata values
+ if not self.drop_nodata or not torch.any(
+ res_single_bbox == self.nodata_value
+ ):
+ return res_single_bbox, dates
+ return None, []
+
+ def __merge_query(
+ self, query: Iterable[BoundingBox]
+ ) -> tuple[torch.Tensor, list[str]]:
+ res = []
+ valid_dates = []
+ for bbox in query:
+ res_single_bbox, dates = self.__merge_single_bbox(bbox)
+ if res_single_bbox is not None:
+ res.append(res_single_bbox)
+ valid_dates.append(dates)
+ return torch.cat(res), valid_dates
+
+ def __getitem__(self, query: BoundingBox | Iterable[BoundingBox]) -> dict[str, Any]:
+ """Retrieve image/mask and metadata indexed by query.
+
+ Args:
+ query: one or more (minx, maxx, miny, maxy, mint, maxt) coordinates to index
+
+ Returns:
+ sample of image/mask and metadata at that index
+
+ Raises:
+ IndexError: if query is not found in the index
+ """
+ if isinstance(query, BoundingBox):
+ query = [query]
+ data, valid_dates = self.__merge_query(query)
- sample = {'crs': self.crs, 'bounds': query}
+ sample = {'crs': self.crs, 'bounds': query, 'dates': valid_dates}
data = data.to(self.dtype)
if self.is_image:
diff --git a/torchgeo/datasets/sentinel.py b/torchgeo/datasets/sentinel.py
index 79637931ad..b7d09b891a 100644
--- a/torchgeo/datasets/sentinel.py
+++ b/torchgeo/datasets/sentinel.py
@@ -266,7 +266,7 @@ class Sentinel2(Sentinel):
# https://sentinels.copernicus.eu/web/sentinel/user-guides/sentinel-2-msi/naming-convention
# https://sentinel.esa.int/documents/247904/685211/Sentinel-2-MSI-L2A-Product-Format-Specifications.pdf
- filename_glob = 'T*_*_{}*.*'
+ filename_glob = 'T*_*_*.*'
filename_regex = r"""
^T(?P<tile>\d{{2}}[A-Z]{{3}})
_(?P<date>\d{{8}}T\d{{6}})
@@ -295,6 +295,7 @@ class Sentinel2(Sentinel):
rgb_bands = ('B04', 'B03', 'B02')
separate_files = True
+ nodata_value = 0
def __init__(
self,
@@ -325,7 +326,7 @@ def __init__(
*root* was renamed to *paths*
"""
bands = bands or self.all_bands
- self.filename_glob = self.filename_glob.format(bands[0])
+ # self.filename_glob = self.filename_glob.format(bands[0])
self.filename_regex = self.filename_regex.format(res)
super().__init__(paths, crs, res, bands, transforms, cache)