-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathtest_workflow_raster_stream.py
156 lines (117 loc) · 4.66 KB
/
test_workflow_raster_stream.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
'''Tests for raster streaming workflows'''
import asyncio
from typing import List
import unittest
import unittest.mock
from uuid import UUID
from datetime import datetime
import json
import rioxarray
import pyarrow as pa
import xarray as xr
from geoengine.types import RasterBandDescriptor
import geoengine as ge
from . import UrllibMocker
class MockWebsocket:
    '''Mock for websockets.client.connect

    The mock is pre-filled with Arrow IPC encoded tiles: every part returned by
    `read_data()` is encoded once for each of two time steps (2014-01-01 and
    2014-01-02), always as band 0. It can be used as an async context manager
    and exposes the `open`, `recv`, `send` and `close` members.
    '''

    def __init__(self):
        '''Create a mock websocket with some data'''
        # The raster parts do not depend on the time step, so the file is read
        # only once instead of once per time step.
        parts = read_data()
        time_steps = (datetime(2014, 1, 1, 0, 0, 0), datetime(2014, 1, 2, 0, 0, 0))

        self.__tiles = [
            arrow_bytes(part, ge.TimeInterval(start=time_step), 0)
            for time_step in time_steps
            for part in parts
        ]

    async def __aenter__(self):
        '''Enter the async context; the mock itself acts as the connection'''
        return self

    async def __aexit__(self, *args):
        '''Leave the async context; there is nothing to clean up'''

    @property
    def open(self) -> bool:
        '''Mock open impl: the connection counts as open while tiles are left'''
        return bool(self.__tiles)

    async def recv(self):
        '''Return the next queued tile

        Tiles are taken from the end of the internal list, i.e. they are handed
        out in reverse order of their creation. Raises `IndexError` when no
        tile is left.
        '''
        return self.__tiles.pop()

    async def send(self, *args):
        '''Accept and discard any outgoing message'''

    async def close(self):
        '''Closing the mock is a no-op'''
def read_data() -> List[xr.DataArray]:
    '''Open the NDVI test raster and cut its first band into 4 parts

    Each of the two remaining axes is split at index 4. The parts are returned
    in the order `[:4, :4]`, `[4:, :4]`, `[:4, 4:]`, `[4:, 4:]`.

    Raises a `TypeError` if `rioxarray.open_rasterio` returns a list.
    '''
    raster = rioxarray.open_rasterio("tests/responses/ndvi.tiff")

    if isinstance(raster, list):
        raise TypeError("Expected Dataset not List")

    first_band = raster.isel(band=0)

    # Both axes are split at the same position
    halves = (slice(None, 4), slice(4, None))

    # The second axis varies slowest so that the original part order is kept
    return [
        first_band[first_axis, second_axis]
        for second_axis in halves
        for first_axis in halves
    ]
def arrow_bytes(data: xr.DataArray, time: ge.TimeInterval, band: int) -> bytes:
    '''Convert a xarray.DataArray into an Arrow record batch within an IPC file

    The flattened values of `data` are stored in a single column named `data`.
    Tile information is attached as schema metadata:

    - `geoTransform`: origin taken from `data.rio.bounds()` (element 0 as x and
      element 3 as y) with a fixed pixel size of 45.0 x -22.5
    - `xSize` / `ySize`: fixed to 4
    - `spatialReference`: fixed to EPSG:4326
    - `time`: `time.start` in milliseconds, used for both `start` and `end`;
      no other field of `time` is read
    - `band`: the given band index as a string

    Returns the serialized IPC file as a Python `bytes` object, matching the
    declared return type (`BufferOutputStream.getvalue()` alone yields a
    `pyarrow.Buffer`).
    '''
    values = pa.array(data.to_numpy().reshape(-1))
    batch = pa.RecordBatch.from_arrays([values], ["data"])

    bounds = data.rio.bounds()
    start_ms = int(time.start.astype('datetime64[ms]').astype(int))

    schema = batch.schema.with_metadata({
        "geoTransform": json.dumps({
            "originCoordinate": {
                "x": bounds[0],
                "y": bounds[3],
            },
            "xPixelSize": 45.0,
            "yPixelSize": -22.5,
        }),
        "xSize": "4",
        "ySize": "4",
        "spatialReference": "EPSG:4326",
        "time": json.dumps({
            "start": start_ms,
            "end": start_ms,
        }),
        "band": str(band),
    })

    sink = pa.BufferOutputStream()
    with pa.ipc.new_file(sink, schema) as writer:
        writer.write_batch(batch)

    # Hand out real bytes instead of a pyarrow.Buffer
    return sink.getvalue().to_pybytes()
class WorkflowRasterStreamTests(unittest.TestCase):
    '''Test methods for retrieving raster workflows as data streams'''

    def setUp(self) -> None:
        '''Call `ge.reset(False)` before every test'''
        ge.reset(False)

    def test_streaming_workflow(self):
        '''Stream the mocked tiles one by one and merged into a single xarray'''
        with UrllibMocker() as m:
            m.get("http://localhost:3030/session", json={
                "id": "00000000-0000-0000-0000-000000000000",
            })

            ge.initialize("http://localhost:3030", token="no_token")

        # The result descriptor lookup is patched with a fixed raster descriptor
        with unittest.mock.patch(
            "geoengine.Workflow._Workflow__query_result_descriptor",
            return_value=ge.RasterResultDescriptor(
                "U8",
                [RasterBandDescriptor("band", ge.UnitlessMeasurement())],
                "EPSG:4326",
                spatial_bounds=ge.SpatialPartition2D(-180.0, -90.0, 180.0, 90.0),
                spatial_resolution=ge.SpatialResolution(45.0, 22.5)
            ),
        ):
            workflow = ge.Workflow(UUID("00000000-0000-0000-0000-000000000000"))

        query_rect = ge.QueryRectangle(
            spatial_bounds=ge.BoundingBox2D(-180.0, -90.0, 180.0, 90.0),
            time_interval=ge.TimeInterval(datetime(2014, 1, 1, 0, 0, 0), datetime(2014, 1, 3, 0, 0, 0)),
            resolution=ge.SpatialResolution(45.0, 22.5),
        )

        async def collect_tiles():
            '''Gather every tile of the raster stream in a list'''
            return [tile async for tile in workflow.raster_stream(query_rect)]

        async def collect_array():
            '''Let the workflow merge the raster stream into one array'''
            return await workflow.raster_stream_into_xarray(query_rect)

        with unittest.mock.patch("websockets.client.connect", return_value=MockWebsocket()):
            tiles = asyncio.run(collect_tiles())

        # 4 parts for each of the 2 time steps served by the mock
        self.assertEqual(len(tiles), 8)

        with unittest.mock.patch("websockets.client.connect", return_value=MockWebsocket()):
            array = asyncio.run(collect_array())

        self.assertEqual(array.shape, (2, 1, 8, 8))  # time, band, y, x

        original_array = rioxarray.open_rasterio("tests/responses/ndvi.tiff").isel(band=0, drop=True)

        # Let's check that the output is the same as if we would
        # have read the whole raster with rioxarray
        self.assertTrue(
            array.isel({'band': 0, 'time': 0}, drop=True).equals(original_array)
        )