
Commit 8cb4415

add mozambique runs
1 parent 52b31d1

3 files changed: 390 additions & 0 deletions
data/helios/v2_mozambique_lulc/finetune_s1_s2.yaml

Lines changed: 114 additions & 0 deletions
@@ -0,0 +1,114 @@
model:
  class_path: rslearn.train.lightning_module.RslearnLightningModule
  init_args:
    model:
      class_path: rslearn.models.multitask.MultiTaskModel
      init_args:
        encoder:
          - class_path: rslp.helios.model.Helios
            init_args:
              checkpoint_path: /weka/dfive-default/helios/checkpoints/henryh/base_v6.1_add_chm_cdl_worldcereal/step300000
              selector: ["encoder"]
              forward_kwargs:
                patch_size: 1
        decoders:
          crop_type_classification:
            - class_path: rslearn.models.pooling_decoder.PoolingDecoder
              init_args:
                in_channels: 768
                out_channels: 7
            - class_path: rslearn.train.tasks.classification.ClassificationHead
    lr: 0.0001
    scheduler:
      class_path: rslearn.train.scheduler.PlateauScheduler
      init_args:
        factor: 0.2
        patience: 2
        min_lr: 0
        cooldown: 10
data:
  class_path: rslearn.train.data_module.RslearnDataModule
  init_args:
    path: /weka/dfive-default/rslearn-eai/datasets/crop/mozambique_lulc
    inputs:
      sentinel2_l2a:
        data_type: "raster"
        layers: ["sentinel2"]
        bands: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"]
        passthrough: true
        dtype: FLOAT32
        load_all_item_groups: true
        load_all_layers: true
      sentinel1:
        data_type: "raster"
        layers: ["sentinel1_ascending"]
        bands: ["vv", "vh"]
        passthrough: true
        dtype: FLOAT32
        load_all_item_groups: true
        load_all_layers: true
      label:
        data_type: "vector"
        layers: ["label"]
        is_target: true
    task:
      class_path: rslearn.train.tasks.multi_task.MultiTask
      init_args:
        tasks:
          crop_type_classification:
            class_path: rslearn.train.tasks.classification.ClassificationTask
            init_args:
              property_name: "category"
              classes: ["Water", "Bare Ground", "Rangeland", "Flooded Vegetation", "Trees", "Cropland", "Buildings"]
              enable_f1_metric: true
              metric_kwargs:
                average: "micro"
        input_mapping:
          crop_type_classification:
            label: "targets"
    batch_size: 32
    num_workers: 32
    default_config:
      transforms:
        - class_path: rslp.helios.norm.HeliosNormalize
          init_args:
            config_fname: "/opt/helios/data/norm_configs/computed.json"
            band_names:
              sentinel2_l2a: ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"]
              sentinel1: ["vv", "vh"]
        - class_path: rslearn.train.transforms.pad.Pad
          init_args:
            size: 4
            mode: "center"
            image_selectors: ["sentinel2_l2a", "sentinel1"]
    train_config:
      groups: ["gaza"]
      tags:
        split: "train"
    val_config:
      groups: ["gaza"]
      tags:
        split: "test"
    test_config:
      groups: ["gaza"]
      tags:
        split: "test"
trainer:
  max_epochs: 100
  callbacks:
    - class_path: lightning.pytorch.callbacks.LearningRateMonitor
      init_args:
        logging_interval: "epoch"
    - class_path: lightning.pytorch.callbacks.ModelCheckpoint
      init_args:
        save_top_k: 1
        save_last: true
        monitor: val_loss
        mode: min
    - class_path: rslearn.train.callbacks.freeze_unfreeze.FreezeUnfreeze
      init_args:
        module_selector: ["model", "encoder", 0]
        unfreeze_at_epoch: 20
        unfreeze_lr_factor: 10
rslp_project: 2025_09_18_mozambique_lulc
rslp_experiment: mozambique_lulc_helios_base_S1_S2_ts_ws4_ps1_gaza
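A note on two couplings inside this config: the Sentinel-2 and Sentinel-1 band lists under `inputs` are repeated under `HeliosNormalize.band_names` (and presumably need to stay in the same order for the per-band statistics to line up with the input channels), and the decoder's `out_channels` has to cover every entry in the `ClassificationTask` `classes` list. A minimal sketch of a consistency check, with all literals copied from the YAML above (illustrative only, not part of the commit):

```python
# Illustrative consistency check; all literals are copied from the config above.
S2_INPUT_BANDS = ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"]
S2_NORM_BANDS = ["B02", "B03", "B04", "B08", "B05", "B06", "B07", "B8A", "B11", "B12", "B01", "B09"]
S1_INPUT_BANDS = ["vv", "vh"]
S1_NORM_BANDS = ["vv", "vh"]
CLASSES = ["Water", "Bare Ground", "Rangeland", "Flooded Vegetation", "Trees", "Cropland", "Buildings"]
OUT_CHANNELS = 7

# The raster inputs and the normalize transform should describe the same bands,
# and the decoder output width must match the number of classes the task emits.
assert S2_INPUT_BANDS == S2_NORM_BANDS
assert S1_INPUT_BANDS == S1_NORM_BANDS
assert OUT_CHANNELS == len(CLASSES)
```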

rslp/crop/mozambique/README.md

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
# Mozambique LULC and Crop Type Classification

This project has two main tasks:
1. Land Use/Land Cover (LULC) and cropland classification
2. Crop type classification

The annotations come from field surveys across three provinces in Mozambique: Gaza, Zambezia, and Manica.

For LULC classification, the train/test splits are:
- Gaza: 2,262 / 970
- Manica: 1,917 / 822
- Zambezia: 1,225 / 525

## LULC Classification

Create windows from the GPKG samples, prepare and materialize the dataset, then launch Helios fine-tuning:

```
python /weka/dfive-default/yawenz/rslearn_projects/rslp/crop/mozambique/create_windows_for_lulc.py --gpkg_dir /weka/dfive-default/yawenz/datasets/mozambique/train_test_samples --ds_path /weka/dfive-default/rslearn-eai/datasets/crop/mozambique_lulc --window_size 32

export DATASET_PATH=/weka/dfive-default/rslearn-eai/datasets/crop/mozambique_lulc
rslearn dataset prepare --root $DATASET_PATH --workers 64 --no-use-initial-job --retry-max-attempts 8 --retry-backoff-seconds 60
python -m rslp.main common launch_data_materialization_jobs --image favyen/rslp_image --ds_path $DATASET_PATH --clusters+=ai2/neptune-cirrascale --num_jobs 5

python -m rslp.main helios launch_finetune --image_name favyen/rslphelios10 --config_paths+=data/helios/v2_mozambique_lulc/finetune_s1_s2.yaml --cluster+=ai2/neptune --rslp_project 2025_09_18_mozambique_lulc --experiment_id mozambique_lulc_helios_base_S1_S2_ts_ws4_ps1_gaza
```
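The split sizes quoted in the README come from the per-province GPKG files. A quick way to re-derive them with geopandas, assuming the same `--gpkg_dir` layout that the windowing script expects (the path below is the one used in the commands above; this snippet is illustrative and not part of the commit):

```python
import geopandas as gpd
from pathlib import Path

# Directory passed as --gpkg_dir in the README command above.
gpkg_dir = Path("/weka/dfive-default/yawenz/datasets/mozambique/train_test_samples")

for province in ["gaza", "manica", "zambezia"]:
    # Each province is split into <province>_train.gpkg and <province>_test.gpkg.
    n_train = len(gpd.read_file(gpkg_dir / f"{province}_train.gpkg"))
    n_test = len(gpd.read_file(gpkg_dir / f"{province}_test.gpkg"))
    print(f"{province}: {n_train} / {n_test}")
```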
rslp/crop/mozambique/create_windows_for_lulc.py

Lines changed: 247 additions & 0 deletions
@@ -0,0 +1,247 @@
"""Create windows for crop type mapping from GPKG files (fixed splits)."""

import argparse
import multiprocessing
from datetime import datetime, timezone
from pathlib import Path
from typing import Iterable, Tuple

import geopandas as gpd
import shapely
import tqdm
from rslearn.const import WGS84_PROJECTION
from rslearn.dataset import Window
from rslearn.utils import Projection, STGeometry, get_utm_ups_crs
from rslearn.utils.feature import Feature
from rslearn.utils.mp import star_imap_unordered
from rslearn.utils.vector_format import GeojsonVectorFormat
from upath import UPath

from rslp.utils.windows import calculate_bounds

WINDOW_RESOLUTION = 10
LABEL_LAYER = "label"

CLASS_MAP = {
    0: "Water",
    1: "Bare Ground",
    2: "Rangeland",
    3: "Flooded Vegetation",
    4: "Trees",
    5: "Cropland",
    6: "Buildings",
}

# Per-province temporal coverage (UTC)
PROVINCE_TIME = {
    "gaza": (
        datetime(2024, 10, 23, tzinfo=timezone.utc),
        datetime(2025, 5, 7, tzinfo=timezone.utc),
    ),
    "manica": (
        datetime(2024, 11, 23, tzinfo=timezone.utc),
        datetime(2025, 6, 7, tzinfo=timezone.utc),
    ),
    "zambezia": (
        datetime(2024, 11, 23, tzinfo=timezone.utc),
        datetime(2025, 6, 7, tzinfo=timezone.utc),
    ),
}


def process_gpkg(gpkg_path: UPath) -> gpd.GeoDataFrame:
    """Load a GPKG and ensure lon/lat in WGS84; expect 'fid' and 'class' columns."""
    gdf = gpd.read_file(str(gpkg_path))

    # Normalize CRS to WGS84
    if gdf.crs is None:
        gdf = gdf.set_crs("EPSG:4326", allow_override=True)
    else:
        gdf = gdf.to_crs("EPSG:4326")

    required_cols = {"class", "geometry"}
    missing = [c for c in required_cols if c not in gdf.columns]
    if missing:
        raise ValueError(f"{gpkg_path}: missing required column(s): {missing}")

    return gdf


def iter_points(gdf: gpd.GeoDataFrame) -> Iterable[Tuple[int, float, float, int]]:
    """Yield (fid, latitude, longitude, category) per feature using centroid for polygons."""
    for fid, row in gdf.iterrows():
        geom = row.geometry
        if geom is None or geom.is_empty:
            continue
        if isinstance(geom, shapely.Point):
            pt = geom
        else:
            pt = geom.centroid
        lon, lat = float(pt.x), float(pt.y)
        category = int(row["class"])
        yield fid, lat, lon, category


def create_window(
    rec: Tuple[int, float, float, int],
    ds_path: UPath,
    group_name: str,
    split: str,
    window_size: int,
    start_time: datetime,
    end_time: datetime,
) -> None:
    """Create a single window and write label layer."""
    fid, latitude, longitude, category_id = rec
    category_label = CLASS_MAP.get(category_id, f"Unknown_{category_id}")

    # Geometry/projection
    src_point = shapely.Point(longitude, latitude)
    src_geometry = STGeometry(WGS84_PROJECTION, src_point, None)
    dst_crs = get_utm_ups_crs(longitude, latitude)
    dst_projection = Projection(dst_crs, WINDOW_RESOLUTION, -WINDOW_RESOLUTION)
    dst_geometry = src_geometry.to_projection(dst_projection)
    bounds = calculate_bounds(dst_geometry, window_size)

    # Group = province name; split is taken from file name (train/test)
    group = group_name
    window_name = f"{fid}_{latitude:.6f}_{longitude:.6f}"

    window = Window(
        path=Window.get_window_root(ds_path, group, window_name),
        group=group,
        name=window_name,
        projection=dst_projection,
        bounds=bounds,
        time_range=(start_time, end_time),
        options={
            "split": split,  # 'train' or 'test' as provided
            "category_id": category_id,
            "category": category_label,
            "fid": fid,
            "source": "gpkg",
        },
    )
    window.save()

    # Label layer (same as before, using window geometry)
    feature = Feature(
        window.get_geometry(),
        {
            "category_id": category_id,
            "category": category_label,
            "fid": fid,
            "split": split,
        },
    )
    layer_dir = window.get_layer_dir(LABEL_LAYER)
    GeojsonVectorFormat().encode_vector(layer_dir, [feature])
    window.mark_layer_completed(LABEL_LAYER)


def create_windows_from_gpkg(
    gpkg_path: UPath,
    ds_path: UPath,
    group_name: str,
    split: str,
    window_size: int,
    max_workers: int,
    start_time: datetime,
    end_time: datetime,
) -> None:
    """Create windows from a single GPKG file."""
    gdf = process_gpkg(gpkg_path)
    records = list(iter_points(gdf))

    jobs = [
        dict(
            rec=rec,
            ds_path=ds_path,
            group_name=group_name,
            split=split,
            window_size=window_size,
            start_time=start_time,
            end_time=end_time,
        )
        for rec in records
    ]

    print(
        f"[{group_name}:{split}] file={gpkg_path.name} features={len(jobs)} "
        f"time={start_time.date()} to {end_time.date()}"
    )

    if max_workers <= 1:
        for kw in tqdm.tqdm(jobs):
            create_window(**kw)
    else:
        p = multiprocessing.Pool(max_workers)
        outputs = star_imap_unordered(p, create_window, jobs)
        for _ in tqdm.tqdm(outputs, total=len(jobs)):
            pass
        p.close()


if __name__ == "__main__":
    multiprocessing.set_start_method("forkserver", force=True)

    parser = argparse.ArgumentParser(description="Create windows from GPKG files")
    parser.add_argument(
        "--gpkg_dir",
        type=str,
        required=True,
        help="Directory containing gaza_[train|test].gpkg, manica_[train|test].gpkg, zambezia_[train|test].gpkg",
    )
    parser.add_argument(
        "--ds_path",
        type=str,
        required=True,
        help="Path to the dataset root",
    )
    parser.add_argument(
        "--window_size",
        type=int,
        default=1,
        help="Window size (pixels per side in projected grid)",
    )
    parser.add_argument(
        "--max_workers",
        type=int,
        default=32,
        help="Worker processes (set 1 for single-process)",
    )
    args = parser.parse_args()

    gpkg_dir = Path(args.gpkg_dir)
    ds_path = UPath(args.ds_path)

    expected = [
        ("gaza", "train", gpkg_dir / "gaza_train.gpkg"),
        ("gaza", "test", gpkg_dir / "gaza_test.gpkg"),
        ("manica", "train", gpkg_dir / "manica_train.gpkg"),
        ("manica", "test", gpkg_dir / "manica_test.gpkg"),
        ("zambezia", "train", gpkg_dir / "zambezia_train.gpkg"),
        ("zambezia", "test", gpkg_dir / "zambezia_test.gpkg"),
    ]

    # Basic checks
    for province, _, path in expected:
        if province not in PROVINCE_TIME:
            raise ValueError(f"Unknown province '{province}'")
        if not path.exists():
            raise FileNotFoundError(f"Missing expected file: {path}")

    # Run per file
    for province, split, path in expected:
        start_time, end_time = PROVINCE_TIME[province]
        create_windows_from_gpkg(
            gpkg_path=UPath(path),
            ds_path=ds_path,
            group_name=province,  # group == province
            split=split,  # honor provided split
            window_size=args.window_size,
            max_workers=args.max_workers,
            start_time=start_time,
            end_time=end_time,
        )

    print("Done.")
