"""Create windows for crop type mapping from GPKG files (fixed splits)."""

import argparse
import multiprocessing
from datetime import datetime, timezone
from pathlib import Path
from typing import Iterable, Tuple

import geopandas as gpd
import shapely
import tqdm
from rslearn.const import WGS84_PROJECTION
from rslearn.dataset import Window
from rslearn.utils import Projection, STGeometry, get_utm_ups_crs
from rslearn.utils.feature import Feature
from rslearn.utils.mp import star_imap_unordered
from rslearn.utils.vector_format import GeojsonVectorFormat
from upath import UPath

from rslp.utils.windows import calculate_bounds

# Pixel size of the window grid, in projection units per pixel (10 m in UTM).
WINDOW_RESOLUTION = 10
LABEL_LAYER = "label"

CLASS_MAP = {
    0: "Water",
    1: "Bare Ground",
    2: "Rangeland",
    3: "Flooded Vegetation",
    4: "Trees",
    5: "Cropland",
    6: "Buildings",
}

# Per-province temporal coverage (UTC)
PROVINCE_TIME = {
    "gaza": (
        datetime(2024, 10, 23, tzinfo=timezone.utc),
        datetime(2025, 5, 7, tzinfo=timezone.utc),
    ),
    "manica": (
        datetime(2024, 11, 23, tzinfo=timezone.utc),
        datetime(2025, 6, 7, tzinfo=timezone.utc),
    ),
    "zambezia": (
        datetime(2024, 11, 23, tzinfo=timezone.utc),
        datetime(2025, 6, 7, tzinfo=timezone.utc),
    ),
}


def process_gpkg(gpkg_path: UPath) -> gpd.GeoDataFrame:
    """Load a GPKG, reproject to WGS84, and require 'class' and 'geometry' columns."""
    gdf = gpd.read_file(str(gpkg_path))

    # Normalize CRS to WGS84.
    if gdf.crs is None:
        gdf = gdf.set_crs("EPSG:4326", allow_override=True)
    else:
        gdf = gdf.to_crs("EPSG:4326")

    required_cols = {"class", "geometry"}
    missing = [c for c in required_cols if c not in gdf.columns]
    if missing:
        raise ValueError(f"{gpkg_path}: missing required column(s): {missing}")

    return gdf


def iter_points(gdf: gpd.GeoDataFrame) -> Iterable[Tuple[int, float, float, int]]:
    """Yield (fid, latitude, longitude, category) per feature.

    The GeoDataFrame index is used as the fid, and non-point geometries are
    reduced to their centroid.
    """
    for fid, row in gdf.iterrows():
        geom = row.geometry
        if geom is None or geom.is_empty:
            continue
        if isinstance(geom, shapely.Point):
            pt = geom
        else:
            pt = geom.centroid
        lon, lat = float(pt.x), float(pt.y)
        category = int(row["class"])
        # Cast to a plain int so the fid serializes cleanly in window options.
        yield int(fid), lat, lon, category


def create_window(
    rec: Tuple[int, float, float, int],
    ds_path: UPath,
    group_name: str,
    split: str,
    window_size: int,
    start_time: datetime,
    end_time: datetime,
) -> None:
    """Create a single window and write its label layer."""
    fid, latitude, longitude, category_id = rec
    category_label = CLASS_MAP.get(category_id, f"Unknown_{category_id}")

    # Project the point from WGS84 into the local UTM/UPS grid and compute bounds.
    src_point = shapely.Point(longitude, latitude)
    src_geometry = STGeometry(WGS84_PROJECTION, src_point, None)
    dst_crs = get_utm_ups_crs(longitude, latitude)
    dst_projection = Projection(dst_crs, WINDOW_RESOLUTION, -WINDOW_RESOLUTION)
    dst_geometry = src_geometry.to_projection(dst_projection)
    bounds = calculate_bounds(dst_geometry, window_size)

    # Group = province name; split is taken from the file name (train/test).
    group = group_name
    window_name = f"{fid}_{latitude:.6f}_{longitude:.6f}"

    window = Window(
        path=Window.get_window_root(ds_path, group, window_name),
        group=group,
        name=window_name,
        projection=dst_projection,
        bounds=bounds,
        time_range=(start_time, end_time),
        options={
            "split": split,  # "train" or "test" as provided
            "category_id": category_id,
            "category": category_label,
            "fid": fid,
            "source": "gpkg",
        },
    )
    window.save()

    # Write the label layer: a single vector feature covering the window geometry.
    feature = Feature(
        window.get_geometry(),
        {
            "category_id": category_id,
            "category": category_label,
            "fid": fid,
            "split": split,
        },
    )
    layer_dir = window.get_layer_dir(LABEL_LAYER)
    GeojsonVectorFormat().encode_vector(layer_dir, [feature])
    window.mark_layer_completed(LABEL_LAYER)


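# Note: assuming rslearn's usual dataset layout, each window created above is
# written under <ds_path>/windows/<group>/<window_name>/, with the label feature
# stored as GeoJSON in the "label" layer directory.
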
def create_windows_from_gpkg(
    gpkg_path: UPath,
    ds_path: UPath,
    group_name: str,
    split: str,
    window_size: int,
    max_workers: int,
    start_time: datetime,
    end_time: datetime,
) -> None:
    """Create windows from a single GPKG file."""
    gdf = process_gpkg(gpkg_path)
    records = list(iter_points(gdf))

    jobs = [
        dict(
            rec=rec,
            ds_path=ds_path,
            group_name=group_name,
            split=split,
            window_size=window_size,
            start_time=start_time,
            end_time=end_time,
        )
        for rec in records
    ]

    print(
        f"[{group_name}:{split}] file={gpkg_path.name} features={len(jobs)} "
        f"time={start_time.date()}→{end_time.date()}"
    )

    if max_workers <= 1:
        for kw in tqdm.tqdm(jobs):
            create_window(**kw)
    else:
        # star_imap_unordered feeds each kwargs dict to create_window via the pool,
        # yielding results as workers finish so tqdm can track progress.
        p = multiprocessing.Pool(max_workers)
        outputs = star_imap_unordered(p, create_window, jobs)
        for _ in tqdm.tqdm(outputs, total=len(jobs)):
            pass
        p.close()
        p.join()


if __name__ == "__main__":
    multiprocessing.set_start_method("forkserver", force=True)

    parser = argparse.ArgumentParser(description="Create windows from GPKG files")
    parser.add_argument(
        "--gpkg_dir",
        type=str,
        required=True,
        help="Directory containing gaza_[train|test].gpkg, manica_[train|test].gpkg, zambezia_[train|test].gpkg",
    )
    parser.add_argument(
        "--ds_path",
        type=str,
        required=True,
        help="Path to the dataset root",
    )
    parser.add_argument(
        "--window_size",
        type=int,
        default=1,
        help="Window size (pixels per side in the projected grid)",
    )
    parser.add_argument(
        "--max_workers",
        type=int,
        default=32,
        help="Worker processes (set 1 for single-process)",
    )
    args = parser.parse_args()

    gpkg_dir = Path(args.gpkg_dir)
    ds_path = UPath(args.ds_path)

    expected = [
        ("gaza", "train", gpkg_dir / "gaza_train.gpkg"),
        ("gaza", "test", gpkg_dir / "gaza_test.gpkg"),
        ("manica", "train", gpkg_dir / "manica_train.gpkg"),
        ("manica", "test", gpkg_dir / "manica_test.gpkg"),
        ("zambezia", "train", gpkg_dir / "zambezia_train.gpkg"),
        ("zambezia", "test", gpkg_dir / "zambezia_test.gpkg"),
    ]

    # Basic checks
    for province, _, path in expected:
        if province not in PROVINCE_TIME:
            raise ValueError(f"Unknown province '{province}'")
        if not path.exists():
            raise FileNotFoundError(f"Missing expected file: {path}")

    # Run per file
    for province, split, path in expected:
        start_time, end_time = PROVINCE_TIME[province]
        create_windows_from_gpkg(
            gpkg_path=UPath(path),
            ds_path=ds_path,
            group_name=province,  # group == province
            split=split,  # honor the provided split
            window_size=args.window_size,
            max_workers=args.max_workers,
            start_time=start_time,
            end_time=end_time,
        )

    print("Done.")
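
# Example invocation (script name and paths are illustrative):
#   python create_crop_type_windows.py \
#       --gpkg_dir /data/mozambique/gpkg \
#       --ds_path /data/mozambique/dataset \
#       --max_workers 16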