Skip to content

Commit

Permalink
Moved another function from h3 to common
Browse files Browse the repository at this point in the history
  • Loading branch information
ndemaio committed Oct 24, 2023
1 parent 5c1f201 commit e253247
Show file tree
Hide file tree
Showing 2 changed files with 137 additions and 128 deletions.
133 changes: 133 additions & 0 deletions raster2dggs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,29 @@
import errno
import tempfile
import logging
import threading
import rioxarray
import dask
import click_log

import rasterio as rio
import pandas as pd
import pyarrow.parquet as pq

from typing import Union, Callable
from pathlib import Path
from rasterio import crs
from rasterio.vrt import WarpedVRT
from rasterio.enums import Resampling
from tqdm import tqdm
from tqdm.dask import TqdmCallback
import dask.dataframe as dd
import xarray as xr

from concurrent.futures import ThreadPoolExecutor, as_completed

from urllib.parse import urlparse
from rasterio.warp import calculate_default_transform

import raster2dggs.constants as const

Expand Down Expand Up @@ -188,3 +199,125 @@ def address_boundary_issues(
)

return output


def initial_index(
    dggs: str,
    dggsfunc: Callable,
    parent_groupby: Callable,
    raster_input: Union[Path, str],
    output: Path,
    resolution: int,
    parent_res: Union[None, int],
    warp_args: dict,
    **kwargs,
) -> Path:
    """
    Open raster_input and perform DGGS indexing per window of a WarpedVRT.

    A WarpedVRT is used to reproject the input on the fly according to
    ``warp_args`` (for example to https://epsg.io/4326, which is required for
    H3 indexing). It also allows on-the-fly resampling of the input, which is
    useful if the target DGGS resolution exceeds the resolution of the input.

    The per-window results ("stage 1") are written to a temporary directory,
    whose path is then passed to ``address_boundary_issues``, which addresses
    issues at the boundaries of raster windows.

    Args:
        dggs: Name of the DGGS (e.g. "h3"); forwarded to ``get_parent_res``
            and ``address_boundary_issues`` and upper-cased for logging.
        dggsfunc: Indexing callable, invoked once per raster window as
            ``dggsfunc(window_data, resolution, parent_res, nodata,
            band_labels=...)``. Its return value is handed to
            ``pyarrow.parquet.write_to_dataset``.
        parent_groupby: Callable forwarded unchanged to
            ``address_boundary_issues``.
        raster_input: Path or URL of the raster, opened with ``rasterio.open``.
        output: Output location, forwarded to ``address_boundary_issues``.
        resolution: Target DGGS resolution.
        parent_res: Parent resolution, or None; resolved through
            ``get_parent_res`` before use.
        warp_args: Keyword arguments for ``WarpedVRT``. Must contain ``"crs"``
            when upscaling is requested.
        **kwargs: Must contain ``"upscale"``, ``"compression"`` and
            ``"threads"``; all of kwargs is also forwarded to
            ``address_boundary_issues``.

    Returns:
        Whatever ``address_boundary_issues`` returns.

    Raises:
        KeyError: If a required key is missing from ``kwargs`` (or ``"crs"``
            from ``warp_args`` when upscaling).
        Exception: Any exception raised while processing a window is
            re-raised here.
    """
    parent_res = get_parent_res(dggs, parent_res, resolution)
    LOGGER.info(
        "Indexing %s at %s resolution %d, parent resolution %d",
        raster_input,
        dggs.upper(),
        resolution,
        parent_res,
    )

    with tempfile.TemporaryDirectory() as tmpdir:
        LOGGER.debug("Create temporary directory %s", tmpdir)

        # https://rasterio.readthedocs.io/en/latest/api/rasterio.warp.html#rasterio.warp.calculate_default_transform
        with rio.Env(CHECK_WITH_INVERT_PROJ=True):
            with rio.open(raster_input) as src:
                LOGGER.debug("Source CRS: %s", src.crs)
                # VRT used to avoid additional disk use given the potential
                # for reprojection prior to DGGS indexing
                band_names = src.descriptions

                upscale_factor = kwargs["upscale"]
                if upscale_factor > 1:
                    # Ask rasterio for a destination grid that is
                    # upscale_factor times larger than the source grid.
                    dst_crs = warp_args["crs"]
                    transform, width, height = calculate_default_transform(
                        src.crs,
                        dst_crs,
                        src.width,
                        src.height,
                        *src.bounds,
                        dst_width=src.width * upscale_factor,
                        dst_height=src.height * upscale_factor,
                    )
                    upsample_args = {
                        "transform": transform,
                        "width": width,
                        "height": height,
                    }
                    LOGGER.debug(upsample_args)
                else:
                    upsample_args = {}

                with WarpedVRT(
                    src, src_crs=src.crs, **warp_args, **upsample_args
                ) as vrt:
                    LOGGER.debug("VRT CRS: %s", vrt.crs)
                    # NOTE(review): previously annotated as xr.Dataset; for a
                    # single raster this is presumably an xr.DataArray --
                    # confirm against rioxarray.open_rasterio.
                    da = rioxarray.open_rasterio(
                        vrt,
                        lock=dask.utils.SerializableLock(),
                        masked=True,
                        default_name=const.DEFAULT_NAME,
                    ).chunk(**{"y": "auto", "x": "auto"})

                    windows = [window for _, window in vrt.block_windows()]
                    LOGGER.debug(
                        "%d windows (the same number of partitions will be created)",
                        len(windows),
                    )

                    # Only one worker at a time writes into the shared
                    # temporary dataset directory.
                    write_lock = threading.Lock()

                    def process(window):
                        """Index a single raster window and persist the result."""
                        sdf = da.rio.isel_window(window)

                        result = dggsfunc(
                            sdf,
                            resolution,
                            parent_res,
                            vrt.nodata,
                            band_labels=band_names,
                        )

                        with write_lock:
                            pq.write_to_dataset(
                                result,
                                root_path=tmpdir,
                                compression=kwargs["compression"],
                            )

                    with tqdm(total=len(windows), desc="Raster windows") as pbar:
                        with ThreadPoolExecutor(
                            max_workers=kwargs["threads"]
                        ) as executor:
                            futures = [
                                executor.submit(process, window) for window in windows
                            ]
                            for future in as_completed(futures):
                                # The return value is always None; result() is
                                # called so that any exception raised in the
                                # worker propagates here.
                                future.result()
                                pbar.update(1)

        LOGGER.debug("Stage 1 (primary indexing) complete")
        # Must run while tmpdir still exists, i.e. inside the
        # TemporaryDirectory context.
        return address_boundary_issues(
            dggs,
            parent_groupby,
            tmpdir,
            output,
            resolution,
            parent_res,
            **kwargs,
        )
132 changes: 4 additions & 128 deletions raster2dggs/h3.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,15 @@
from concurrent.futures import ThreadPoolExecutor, as_completed
from numbers import Number
import numpy as np
from pathlib import Path
import tempfile
import threading
from typing import Callable, Tuple, Union

import click
import click_log
import dask
import h3pandas # Necessary import despite lack of explicit use
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import rasterio as rio
from rasterio.enums import Resampling
from rasterio.vrt import WarpedVRT
from rasterio.warp import calculate_default_transform
import rioxarray
from tqdm import tqdm
import xarray as xr

import raster2dggs.constants as const
Expand Down Expand Up @@ -63,124 +54,6 @@ def _h3func(
return pa.Table.from_pandas(h3index)


def _initial_index(
    raster_input: Union[Path, str],
    output: Path,
    resolution: int,
    parent_res: Union[None, int],
    warp_args: dict,
    **kwargs,
) -> Path:
    """
    Stage 1 of H3 indexing: index the raster window by window.

    The raster at raster_input is opened and wrapped in a WarpedVRT, which
    enforces reprojection to https://epsg.io/4326 (required for H3 indexing)
    and permits on-the-fly resampling when the target H3 resolution exceeds
    the resolution of the input. Each block window of the VRT is indexed with
    ``_h3func`` and written to a temporary directory; that directory is then
    handed to ``common.address_boundary_issues``, which addresses issues at
    the boundaries of raster windows and whose result is returned.

    ``kwargs`` must provide "upscale", "compression" and "threads".
    """
    parent_res = common.get_parent_res("h3", parent_res, resolution)
    common.LOGGER.info(
        "Indexing %s at H3 resolution %d, parent resolution %d",
        raster_input,
        resolution,
        parent_res,
    )

    with tempfile.TemporaryDirectory() as tmpdir:
        common.LOGGER.debug(f"Create temporary directory {tmpdir}")

        # https://rasterio.readthedocs.io/en/latest/api/rasterio.warp.html#rasterio.warp.calculate_default_transform
        with rio.Env(CHECK_WITH_INVERT_PROJ=True), rio.open(raster_input) as src:
            common.LOGGER.debug("Source CRS: %s", src.crs)
            # A VRT avoids extra disk use for the possible reprojection to
            # 4326 ahead of H3 indexing
            labels = src.descriptions

            # Extra WarpedVRT arguments; only populated when upscaling.
            vrt_overrides = {}
            scale = kwargs["upscale"]
            if scale > 1:
                target_crs = warp_args["crs"]
                scaled_grid = calculate_default_transform(
                    src.crs,
                    target_crs,
                    src.width,
                    src.height,
                    *src.bounds,
                    dst_width=src.width * scale,
                    dst_height=src.height * scale,
                )
                transform, width, height = scaled_grid
                vrt_overrides["transform"] = transform
                vrt_overrides["width"] = width
                vrt_overrides["height"] = height
                common.LOGGER.debug(vrt_overrides)

            with WarpedVRT(
                src, src_crs=src.crs, **warp_args, **vrt_overrides
            ) as vrt:
                common.LOGGER.debug("VRT CRS: %s", vrt.crs)
                opened = rioxarray.open_rasterio(
                    vrt,
                    lock=dask.utils.SerializableLock(),
                    masked=True,
                    default_name=const.DEFAULT_NAME,
                )
                da: xr.Dataset = opened.chunk(y="auto", x="auto")

                windows = []
                for _, block_window in vrt.block_windows():
                    windows.append(block_window)
                common.LOGGER.debug(
                    "%d windows (the same number of partitions will be created)",
                    len(windows),
                )

                # Guards writes into the shared temporary directory.
                write_lock = threading.Lock()

                def process(window):
                    """Index one window and append it to the temporary dataset."""
                    indexed = _h3func(
                        da.rio.isel_window(window),
                        resolution,
                        parent_res,
                        vrt.nodata,
                        band_labels=labels,
                    )
                    with write_lock:
                        pq.write_to_dataset(
                            indexed,
                            root_path=tmpdir,
                            compression=kwargs["compression"],
                        )

                progress = tqdm(total=len(windows), desc="Raster windows")
                pool = ThreadPoolExecutor(max_workers=kwargs["threads"])
                with progress as pbar, pool as executor:
                    pending = []
                    for window in windows:
                        pending.append(executor.submit(process, window))
                    for finished in as_completed(pending):
                        # Surfaces any exception raised inside the worker.
                        finished.result()
                        pbar.update(1)

        common.LOGGER.debug("Stage 1 (primary indexing) complete")
        return common.address_boundary_issues(
            "h3",
            _h3_parent_groupby,
            tmpdir,
            output,
            resolution,
            parent_res,
            **kwargs,
        )


def _h3_parent_groupby(
df, resolution: int, aggfunc: Union[str, Callable], decimals: int
):
Expand Down Expand Up @@ -313,7 +186,10 @@ def h3(
overwrite,
)

_initial_index(
common.initial_index(
"h3",
_h3func,
_h3_parent_groupby,
raster_input,
output_directory,
int(resolution),
Expand Down

0 comments on commit e253247

Please sign in to comment.