# stage1_eval_roi.py
import os
import json
import numpy as np


def main(test_y_path="data/np_data/test_y.npy", out_dir="stage1_outputs"):
    """Score Stage-1 ROI masks against ground-truth tile masks.

    For every tile i, loads ``tile_{i:03d}_roi.npy`` from ``out_dir`` and
    compares it with ``y_test[i]``:

      - coverage: fraction of ground-truth positive pixels that fall inside
        the ROI (defined as 1.0 when the tile has no positives, so empty
        tiles never drag the aggregate down)
      - roi_frac: fraction of the tile covered by the ROI (cost proxy)

    Writes per-tile metrics plus aggregates to ``eval_summary.json`` in
    ``out_dir`` and prints the aggregate numbers.

    Parameters
    ----------
    test_y_path : str
        Path to the (N, H, W) ground-truth mask stack (.npy).
    out_dir : str
        Directory holding the per-tile ROI masks; summary is written here too.
    """
    # mmap keeps memory flat for large stacks (consistent with the infer script).
    y_test = np.load(test_y_path, mmap_mode="r")  # (N, H, W)

    results = []
    for i in range(y_test.shape[0]):
        gt = np.asarray(y_test[i]).astype(bool)
        roi = np.load(os.path.join(out_dir, f"tile_{i:03d}_roi.npy")).astype(bool)

        gt_pixels = int(gt.sum())
        roi_pixels = int(roi.sum())
        total_pixels = gt.size

        if gt_pixels == 0:
            coverage = 1.0  # nothing to cover
        else:
            coverage = float((roi & gt).sum()) / float(gt_pixels)

        roi_frac = float(roi_pixels) / float(total_pixels)

        results.append({
            "tile": i,
            "gt_pixels": gt_pixels,
            "roi_pixels": roi_pixels,
            "coverage": coverage,
            "roi_frac": roi_frac,
        })

    # Aggregate summary over all tiles.
    coverages = [r["coverage"] for r in results]
    roi_fracs = [r["roi_frac"] for r in results]

    summary = {
        "mean_coverage": float(np.mean(coverages)),
        "median_coverage": float(np.median(coverages)),
        "mean_roi_frac": float(np.mean(roi_fracs)),
        "median_roi_frac": float(np.median(roi_fracs)),
        "tiles": results,
    }

    # Print only the aggregates; the per-tile list goes to the JSON file.
    aggregate_keys = ("mean_coverage", "median_coverage",
                     "mean_roi_frac", "median_roi_frac")
    print(json.dumps({k: summary[k] for k in aggregate_keys}, indent=2))

    with open(os.path.join(out_dir, "eval_summary.json"), "w") as f:
        json.dump(summary, f, indent=2)

    print("Saved eval_summary.json in", out_dir)


if __name__ == "__main__":
    main()
+# """ +# tile_x: (H, W, C) +# returns prob_map: (H, W) +# """ +# H, W, C = tile_x.shape +# X = tile_x.reshape(-1, C) +# p = model.predict_proba(X)[:, 1] +# return p.reshape(H, W) + + +# def roi_from_prob(prob_map, threshold=0.15, min_blob_area=100, dilate_pixels=1, box_pad=2): +# """ +# Returns: +# roi_mask: (H, W) uint8 0/1 +# boxes: list of dicts: {x1,y1,x2,y2,area} +# Notes: +# - threshold is chosen for high recall; tune later +# - blob filtering removes speckle noise +# """ +# H, W = prob_map.shape + +# mask = (prob_map >= threshold) + +# # dilate (optional) +# if dilate_pixels > 0: +# struct = ndi.generate_binary_structure(2, 1) +# mask = ndi.binary_dilation(mask, structure=struct, iterations=dilate_pixels) + +# # connected components +# labeled, n = ndi.label(mask) +# if n == 0: +# return np.zeros((H, W), dtype=np.uint8), [] + +# boxes = [] +# roi_mask = np.zeros((H, W), dtype=np.uint8) + +# slices = ndi.find_objects(labeled) +# for comp_id, slc in enumerate(slices, start=1): +# if slc is None: +# continue +# comp = (labeled[slc] == comp_id) +# area = int(comp.sum()) +# if area < min_blob_area: +# continue + +# # add to final mask +# roi_mask[slc][comp] = 1 + +# y1, y2 = slc[0].start, slc[0].stop +# x1, x2 = slc[1].start, slc[1].stop + +# # pad + clamp +# x1p = max(0, x1 - box_pad) +# y1p = max(0, y1 - box_pad) +# x2p = min(W, x2 + box_pad) +# y2p = min(H, y2 + box_pad) + +# boxes.append({"x1": x1p, "y1": y1p, "x2": x2p, "y2": y2p, "area": area}) + +# return roi_mask, boxes + + +# def main(): +# test_x_path = "data/np_data/test_x.npy" +# out_dir = "stage1_outputs" +# os.makedirs(out_dir, exist_ok=True) + +# model = joblib.load("stage1_model/logreg.joblib") +# X_test = np.load(test_x_path) # (N,512,512,16) + +# # Tunable hyperparams +# threshold = 0.35 +# min_blob_area = 350 +# dilate_pixels = 1 +# box_pad = 3 + +# cfg = dict( +# threshold=threshold, +# min_blob_area=min_blob_area, +# dilate_pixels=dilate_pixels, +# box_pad=box_pad +# ) +# with 
# stage1_infer_roi.py
import os
import json
import numpy as np
from scipy import ndimage as ndi


def prob_map_for_tile(scaler, clf, tile_x):
    """Per-pixel plume probability for one tile.

    Parameters
    ----------
    scaler : fitted transformer with a ``transform`` method (frozen scaler)
    clf : fitted classifier with ``predict_proba``
    tile_x : (H, W, C) per-pixel feature array

    Returns
    -------
    (H, W) float array of P(class = 1) per pixel.
    """
    H, W, C = tile_x.shape
    X = tile_x.reshape(-1, C).astype(np.float32, copy=False)
    Xs = scaler.transform(X)
    p = clf.predict_proba(Xs)[:, 1]
    return p.reshape(H, W)


def roi_from_prob(
    prob_map,
    threshold=0.25,     # minimum threshold floor
    keep_frac=0.05,     # adaptive target: keep roughly the top keep_frac pixels
    min_blob_area=1200,
    dilate_pixels=1,
    box_pad=4
):
    """Turn a probability map into a binary ROI mask plus bounding boxes.

    Strategy:
      - adaptive threshold per tile: keep the top ``keep_frac`` fraction of
        pixels (value at the 1 - keep_frac quantile)
      - enforce ``threshold`` as a floor so a near-uniform noise tile does not
        promote its top quantile into an ROI
      - optional dilation, then connected-component filtering by area

    Returns
    -------
    roi_mask : (H, W) uint8 mask in {0, 1}
    boxes : list of ``{"x1", "y1", "x2", "y2", "area"}`` dicts; coordinates
        are padded by ``box_pad`` and clamped to tile bounds, while ``area``
        is the unpadded component pixel count.
    """
    H, W = prob_map.shape

    # Adaptive threshold, floored by the fixed threshold.
    t_adapt = float(np.quantile(prob_map, 1.0 - keep_frac))
    t = max(float(threshold), t_adapt)

    mask = (prob_map >= t)

    # Optional dilation to connect nearby fragments.
    if dilate_pixels > 0:
        struct = ndi.generate_binary_structure(2, 1)
        mask = ndi.binary_dilation(mask, structure=struct, iterations=dilate_pixels)

    labeled, n = ndi.label(mask)
    if n == 0:
        return np.zeros((H, W), dtype=np.uint8), []

    boxes = []
    roi_mask = np.zeros((H, W), dtype=np.uint8)

    slices = ndi.find_objects(labeled)
    for comp_id, slc in enumerate(slices, start=1):
        if slc is None:
            continue
        comp = (labeled[slc] == comp_id)
        area = int(comp.sum())
        if area < min_blob_area:
            continue  # drop speckle blobs

        # Basic slicing yields a view, so this writes through to roi_mask.
        roi_mask[slc][comp] = 1

        y1, y2 = slc[0].start, slc[0].stop
        x1, x2 = slc[1].start, slc[1].stop

        # pad + clamp to tile bounds
        x1p = max(0, x1 - box_pad)
        y1p = max(0, y1 - box_pad)
        x2p = min(W, x2 + box_pad)
        y2p = min(H, y2 + box_pad)

        boxes.append({"x1": x1p, "y1": y1p, "x2": x2p, "y2": y2p, "area": area})

    return roi_mask, boxes


def main(test_x_path="data/np_data/test_x.npy",
         out_dir="stage1_outputs",
         model_path="stage1_model_streaming/sgd_logreg.joblib"):
    """Run Stage-1 ROI inference over all test tiles and save outputs.

    Saves, per tile: the float32 probability map, the uint8 ROI mask, and the
    padded bounding boxes (JSON); also writes the hyperparameter config.
    """
    # Local import: joblib is only needed for model loading, so the analysis
    # helpers above remain importable without it.
    import joblib

    os.makedirs(out_dir, exist_ok=True)

    # Load streaming-trained model bundle ({"scaler": ..., "clf": ...}).
    bundle = joblib.load(model_path)
    scaler = bundle["scaler"]
    clf = bundle["clf"]

    # mmap keeps RAM flat while iterating tiles.
    X_test = np.load(test_x_path, mmap_mode="r")  # (N,512,512,16)

    # Tunable hyperparams
    threshold = 0.12       # floor (lets weak plumes in)
    keep_frac = 0.1        # keep the top ~10% of pixels before blob filtering
    min_blob_area = 1200   # slightly lower because mask is sparser now
    dilate_pixels = 1      # reduce dilation to avoid giant merge
    box_pad = 4

    cfg = dict(
        model_path=model_path,
        threshold=threshold,
        min_blob_area=min_blob_area,
        dilate_pixels=dilate_pixels,
        box_pad=box_pad,
        keep_frac=keep_frac
    )
    with open(os.path.join(out_dir, "config.json"), "w") as f:
        json.dump(cfg, f, indent=2)

    for i in range(X_test.shape[0]):
        tile = X_test[i]
        prob = prob_map_for_tile(scaler, clf, tile)
        roi_mask, boxes = roi_from_prob(
            prob,
            threshold=threshold,
            keep_frac=keep_frac,
            min_blob_area=min_blob_area,
            dilate_pixels=dilate_pixels,
            box_pad=box_pad
        )

        # save outputs
        np.save(os.path.join(out_dir, f"tile_{i:03d}_prob.npy"), prob.astype(np.float32))
        np.save(os.path.join(out_dir, f"tile_{i:03d}_roi.npy"), roi_mask.astype(np.uint8))
        with open(os.path.join(out_dir, f"tile_{i:03d}_boxes.json"), "w") as f:
            json.dump(boxes, f, indent=2)

        print(f"tile {i:03d}: boxes={len(boxes)} roi_pixels={int(roi_mask.sum())}")

    print("Done:", out_dir)


if __name__ == "__main__":
    main()
# stage1_sweep_configs.py
import os
import json
from itertools import product

import numpy as np
from scipy import ndimage as ndi


def prob_map_for_tile(scaler, clf, tile_x):
    """(H, W, C) per-pixel features -> (H, W) map of P(class = 1)."""
    H, W, C = tile_x.shape
    X = tile_x.reshape(-1, C).astype(np.float32, copy=False)
    p = clf.predict_proba(scaler.transform(X))[:, 1]
    return p.reshape(H, W)


def roi_from_prob(prob_map, mode, threshold, keep_frac, min_blob_area, dilate_pixels, box_pad):
    """Binary ROI mask for one probability map under a thresholding mode.

    mode:
      - "fixed": prob >= threshold
      - "topk": prob >= max(threshold, quantile(1 - keep_frac))
      - "hybrid_or": (prob >= threshold) OR (prob >= quantile(1 - keep_frac))

    Returns a (H, W) uint8 mask in {0, 1} after optional dilation and
    connected-component area filtering.  ``box_pad`` does NOT affect the
    returned mask; it is accepted only so sweep configs can be splatted in.
    """
    H, W = prob_map.shape

    if mode == "fixed":
        mask = (prob_map >= threshold)
    else:
        # adaptive threshold based on keeping the top keep_frac pixels
        t_adapt = float(np.quantile(prob_map, 1.0 - keep_frac))
        if mode == "topk":
            mask = (prob_map >= max(float(threshold), t_adapt))
        elif mode == "hybrid_or":
            mask = (prob_map >= float(threshold)) | (prob_map >= t_adapt)
        else:
            raise ValueError(f"Unknown mode: {mode}")

    # dilation
    if dilate_pixels > 0:
        struct = ndi.generate_binary_structure(2, 1)
        mask = ndi.binary_dilation(mask, structure=struct, iterations=dilate_pixels)

    labeled, n = ndi.label(mask)
    if n == 0:
        return np.zeros((H, W), dtype=np.uint8)

    roi = np.zeros((H, W), dtype=np.uint8)
    for comp_id, slc in enumerate(ndi.find_objects(labeled), start=1):
        if slc is None:
            continue
        comp = (labeled[slc] == comp_id)
        if int(comp.sum()) < min_blob_area:
            continue
        roi[slc][comp] = 1
    return roi


def eval_config(prob_maps, y_test, cfg):
    """Evaluate one ROI config over all tiles.

    Returns a dict of aggregate coverage / ROI-fraction metrics (mean, median,
    tail quantiles) plus the config itself under the "cfg" key.  Tiles with no
    ground-truth positives count as coverage 1.0.
    """
    N = y_test.shape[0]
    coverages = []
    roi_fracs = []

    for i in range(N):
        gt = y_test[i].astype(bool)
        prob = prob_maps[i]
        roi = roi_from_prob(prob, **cfg).astype(bool)

        gt_pixels = int(gt.sum())
        if gt_pixels == 0:
            coverage = 1.0
        else:
            coverage = float((roi & gt).sum()) / float(gt_pixels)

        roi_frac = float(roi.sum()) / float(roi.size)

        coverages.append(coverage)
        roi_fracs.append(roi_frac)

    coverages = np.array(coverages, dtype=np.float32)
    roi_fracs = np.array(roi_fracs, dtype=np.float32)

    out = {
        "mean_coverage": float(coverages.mean()),
        "median_coverage": float(np.median(coverages)),
        "p10_coverage": float(np.quantile(coverages, 0.10)),
        "mean_roi_frac": float(roi_fracs.mean()),
        "median_roi_frac": float(np.median(roi_fracs)),
        "p90_roi_frac": float(np.quantile(roi_fracs, 0.90)),
        "cfg": cfg,
    }
    return out


def pareto_frontier(results):
    """Return the non-dominated configs on the (coverage, ROI-fraction) plane.

    A config is kept iff no other config has both mean_coverage >= it and
    mean_roi_frac < it: after sorting by coverage descending (ROI ascending as
    tie-break), a config survives only by strictly improving the best ROI
    fraction seen so far.  Output is ordered by decreasing coverage.
    """
    r = sorted(results, key=lambda x: (-x["mean_coverage"], x["mean_roi_frac"]))
    frontier = []
    best_roi = float("inf")
    for item in r:
        roi = item["mean_roi_frac"]
        if roi < best_roi:
            frontier.append(item)
            best_roi = roi
    return frontier


def main():
    """Sweep ROI post-processing configs over cached test probability maps."""
    test_x_path = "data/np_data/test_x.npy"
    test_y_path = "data/np_data/test_y.npy"
    model_path = "stage1_model_streaming/sgd_logreg.joblib"

    cache_dir = "stage1_sweep_cache"
    os.makedirs(cache_dir, exist_ok=True)
    prob_cache_path = os.path.join(cache_dir, "test_prob_maps.npy")

    # Local import: joblib only needed for model loading.
    import joblib
    print("Loading model:", model_path)
    bundle = joblib.load(model_path)
    scaler, clf = bundle["scaler"], bundle["clf"]

    print("Loading test tiles (mmap)...")
    X_test = np.load(test_x_path, mmap_mode="r")  # (N,512,512,16)
    y_test = np.load(test_y_path, mmap_mode="r")  # (N,512,512)

    # ---- Compute prob maps once and cache ----
    if os.path.exists(prob_cache_path):
        print("Loading cached prob maps:", prob_cache_path)
        prob_maps = np.load(prob_cache_path, mmap_mode="r")
    else:
        print("Computing prob maps (one-time)...")
        prob_maps = np.zeros((X_test.shape[0], X_test.shape[1], X_test.shape[2]), dtype=np.float32)
        for i in range(X_test.shape[0]):
            prob_maps[i] = prob_map_for_tile(scaler, clf, X_test[i]).astype(np.float32)
            if (i + 1) % 5 == 0:
                print(f"  computed {i+1}/{X_test.shape[0]}")
        np.save(prob_cache_path, prob_maps)
        print("Saved prob cache:", prob_cache_path)

    # ---- Sweep space ----
    # Start moderately sized; expand if needed.
    modes = ["fixed", "topk", "hybrid_or"]
    thresholds = [0.10, 0.15, 0.20, 0.25, 0.30, 0.35, 0.45]
    keep_fracs = [0.03, 0.05, 0.08, 0.12, 0.18]  # only used for modes != fixed
    min_blob_areas = [50, 150, 300, 600, 1200]
    dilates = [0, 1, 2]
    box_pads = [2, 4]  # doesn't affect metrics; memoized below

    results = []
    total = 0

    # box_pad never changes the mask, so evaluating both box_pads would double
    # every eval_config call for identical metrics.  Memoize on the
    # metric-affecting parameters and just stamp each cfg onto a copy.
    memo = {}

    def _eval(cfg):
        key = (cfg["mode"], cfg["threshold"], cfg["keep_frac"],
               cfg["min_blob_area"], cfg["dilate_pixels"])
        if key not in memo:
            memo[key] = eval_config(prob_maps, y_test, cfg)
        r = dict(memo[key])
        r["cfg"] = cfg
        return r

    for mode in modes:
        for threshold, min_blob_area, dilate_pixels, box_pad in product(thresholds, min_blob_areas, dilates, box_pads):
            if mode == "fixed":
                cfg = dict(
                    mode=mode,
                    threshold=threshold,
                    keep_frac=0.05,  # unused
                    min_blob_area=min_blob_area,
                    dilate_pixels=dilate_pixels,
                    box_pad=box_pad,
                )
                results.append(_eval(cfg))
                total += 1
            else:
                for keep_frac in keep_fracs:
                    cfg = dict(
                        mode=mode,
                        threshold=threshold,
                        keep_frac=keep_frac,
                        min_blob_area=min_blob_area,
                        dilate_pixels=dilate_pixels,
                        box_pad=box_pad,
                    )
                    results.append(_eval(cfg))
                    total += 1

        print(f"Finished mode={mode}")

    print("Total configs evaluated:", total)

    # ---- Save all results ----
    all_path = os.path.join(cache_dir, "sweep_results.json")
    with open(all_path, "w") as f:
        json.dump(results, f, indent=2)
    print("Saved:", all_path)

    # ---- Report best by constraints ----
    target_cov = 0.97
    candidates = [r for r in results if r["mean_coverage"] >= target_cov]
    candidates = sorted(candidates, key=lambda r: r["mean_roi_frac"])
    print(f"\nBest configs with mean_coverage >= {target_cov}:")
    for r in candidates[:10]:
        print({
            "mean_coverage": round(r["mean_coverage"], 4),
            "p10_coverage": round(r["p10_coverage"], 4),
            "mean_roi_frac": round(r["mean_roi_frac"], 4),
            "cfg": r["cfg"],
        })

    # Alternatively: ROI budget constraint
    roi_budget = 0.10
    candidates2 = [r for r in results if r["mean_roi_frac"] <= roi_budget]
    candidates2 = sorted(candidates2, key=lambda r: -r["mean_coverage"])
    print(f"\nBest configs with mean_roi_frac <= {roi_budget}:")
    for r in candidates2[:10]:
        print({
            "mean_coverage": round(r["mean_coverage"], 4),
            "p10_coverage": round(r["p10_coverage"], 4),
            "mean_roi_frac": round(r["mean_roi_frac"], 4),
            "cfg": r["cfg"],
        })

    # ---- Pareto frontier ----
    frontier = pareto_frontier(results)
    front_path = os.path.join(cache_dir, "pareto_frontier.json")
    with open(front_path, "w") as f:
        json.dump(frontier, f, indent=2)
    print("\nSaved Pareto frontier:", front_path)
    print("Top 10 Pareto points (coverage vs ROI):")
    for r in frontier[:10]:
        print(round(r["mean_coverage"], 4), round(r["mean_roi_frac"], 4), r["cfg"])


if __name__ == "__main__":
    main()
main(): +# # Update paths if yours differ +# train_x_path = "data/np_data/train_x.npy" +# train_y_path = "data/np_data/train_y.npy" +# test_x_path = "data/np_data/test_x.npy" +# test_y_path = "data/np_data/test_y.npy" + +# out_dir = "stage1_model" +# os.makedirs(out_dir, exist_ok=True) + +# print("Loading npy...") +# X_train_tiles = np.load(train_x_path) +# y_train_tiles = np.load(train_y_path) +# X_test_tiles = np.load(test_x_path) +# y_test_tiles = np.load(test_y_path) + +# # ---- Build sampled pixel dataset ---- +# # Start with 20 negatives per positive. If too many false positives later, increase to 50. +# Xtr, ytr, info = build_pixel_dataset(X_train_tiles, y_train_tiles, neg_per_pos=75, seed=0) +# Xte, yte, _ = build_pixel_dataset(X_test_tiles, y_test_tiles, neg_per_pos=20, seed=1) + +# print("Pixel dataset info:", info) +# print("Train samples:", Xtr.shape, "pos rate:", ytr.mean()) +# print("Test samples:", Xte.shape, "pos rate:", yte.mean()) + +# # ---- Model: StandardScaler + Logistic Regression ---- +# # class_weight balanced helps a lot with this imbalance +# model = Pipeline([ +# ("scaler", StandardScaler(with_mean=True, with_std=True)), +# ("clf", LogisticRegression( +# max_iter=2000, +# solver="saga", +# n_jobs=-1, +# class_weight="balanced", +# penalty="l2", +# C=1.0 +# )) +# ]) + +# print("Training...") +# model.fit(Xtr, ytr) + +# # ---- Quick evaluation on sampled pixels (not final ROI metric, but sanity) ---- +# print("Evaluating on sampled pixels...") +# p = model.predict_proba(Xte)[:, 1] +# roc = roc_auc_score(yte, p) +# ap = average_precision_score(yte, p) +# print("Sampled pixel ROC-AUC:", roc) +# print("Sampled pixel PR-AUC :", ap) + +# # Save model as joblib +# import joblib +# joblib.dump(model, os.path.join(out_dir, "logreg.joblib")) + +# # Save metadata +# meta = { +# "model": "StandardScaler + LogisticRegression", +# "neg_per_pos": info["neg_per_pos"], +# "n_pos": info["n_pos"], +# "n_neg_take": info["n_neg_take"], +# "note": "Trained on 
'''
Streaming Stage-1 trainer: avoids fitting a large chunk of pixel data in
memory at once.  Two passes over the memory-mapped tiles: pass 1 fits a
StandardScaler incrementally; pass 2 trains an SGD logistic-regression
classifier on scaled minibatches with the scaler frozen.
'''
import os
import json

import numpy as np


def sample_pixels_from_tile(tile_x, tile_y, neg_per_pos=50, rng=None):
    """Sample a class-rebalanced pixel subset from one tile.

    Takes ALL positive pixels plus up to ``neg_per_pos`` negatives per
    positive.  If the tile has no positives, returns up to 5000 random
    negatives so background-only tiles still contribute.

    Parameters
    ----------
    tile_x : (H, W, C) feature array
    tile_y : (H, W) 0/1 label array
    neg_per_pos : int, negatives sampled per positive pixel
    rng : np.random.Generator or None (a default seeded generator is used)

    Returns
    -------
    (X, y): (n, C) float32 features and (n,) uint8 labels.
    """
    if rng is None:
        rng = np.random.default_rng(0)

    H, W, C = tile_x.shape
    X = tile_x.reshape(-1, C)
    y = tile_y.reshape(-1)

    pos_idx = np.where(y == 1)[0]
    neg_idx = np.where(y == 0)[0]
    n_pos = pos_idx.size

    if n_pos == 0:
        # Background-only tile: rng.choice without replacement already
        # returns the sample in random order.
        n_neg_take = min(5000, neg_idx.size)
        idx = rng.choice(neg_idx, size=n_neg_take, replace=False)
    else:
        n_neg_take = min(neg_idx.size, n_pos * neg_per_pos)
        neg_take = rng.choice(neg_idx, size=n_neg_take, replace=False)
        idx = np.concatenate([pos_idx, neg_take])
        # Shuffle so positives are not grouped at the front of every batch.
        rng.shuffle(idx)

    return X[idx].astype(np.float32, copy=False), y[idx].astype(np.uint8, copy=False)


def iter_train_batches(X_tiles, y_tiles, neg_per_pos, batch_cap, seed):
    """Yield (Xb, yb) minibatches of sampled pixels, tile by tile.

    Each tile is sampled via ``sample_pixels_from_tile`` and then split into
    chunks of at most ``batch_cap`` rows so downstream ``partial_fit`` calls
    stay bounded in memory.
    """
    rng = np.random.default_rng(seed)
    N = X_tiles.shape[0]
    for i in range(N):
        X_s, y_s = sample_pixels_from_tile(X_tiles[i], y_tiles[i], neg_per_pos=neg_per_pos, rng=rng)
        n = X_s.shape[0]
        for start in range(0, n, batch_cap):
            end = min(n, start + batch_cap)
            yield X_s[start:end], y_s[start:end]


def main(train_x_path="data/np_data/train_x.npy",
         train_y_path="data/np_data/train_y.npy",
         test_x_path="data/np_data/test_x.npy",
         test_y_path="data/np_data/test_y.npy",
         out_dir="stage1_model_streaming",
         neg_per_pos=50,
         batch_cap=100_000,   # lower if RAM tight
         epochs=2,
         seed=0):
    """Two-pass streaming training; saves {"scaler", "clf"} bundle + metadata."""
    # Local imports: keep the numpy-only sampling helpers importable even
    # when sklearn/joblib are not installed.
    import joblib
    from sklearn.linear_model import SGDClassifier
    from sklearn.preprocessing import StandardScaler
    from sklearn.metrics import roc_auc_score, average_precision_score

    os.makedirs(out_dir, exist_ok=True)

    print("Loading tiles (mmap)...")
    X_train = np.load(train_x_path, mmap_mode="r")
    y_train = np.load(train_y_path, mmap_mode="r")
    X_test = np.load(test_x_path, mmap_mode="r")
    y_test = np.load(test_y_path, mmap_mode="r")

    # ---- 1) FIRST PASS: fit scaler only (streaming) ----
    scaler = StandardScaler(with_mean=True, with_std=True)
    print("\nPass 1: fitting scaler...")
    b = 0
    for Xb, yb in iter_train_batches(X_train, y_train, neg_per_pos, batch_cap, seed):
        scaler.partial_fit(Xb)
        b += 1
        if b % 20 == 0:
            print(f"  scaler batches: {b}")

    # ---- 2) SECOND PASS: train classifier with frozen scaler ----
    # IMPORTANT: weights should match the sampled distribution, not the full
    # dataset.  With neg_per_pos negatives per positive, a simple stable
    # choice is to up-weight positives by that same factor.
    class_weight = {0: 1.0, 1: float(neg_per_pos)}
    print("\nUsing class_weight:", class_weight)

    clf = SGDClassifier(
        loss="log_loss",
        penalty="l2",
        alpha=1e-4,
        learning_rate="optimal",
        max_iter=1,          # we control epochs manually via partial_fit
        tol=None,
        class_weight=class_weight,
        average=True,        # stabilizes online SGD a lot
        random_state=seed
    )
    classes = np.array([0, 1], dtype=np.uint8)

    print("\nPass 2: training classifier...")
    first = True
    for ep in range(epochs):
        print(f"\n=== epoch {ep+1}/{epochs} ===")
        batch_num = 0
        # seed + ep: resample a fresh negative subset each epoch
        for Xb, yb in iter_train_batches(X_train, y_train, neg_per_pos, batch_cap, seed + ep):
            Xb_s = scaler.transform(Xb)
            if first:
                clf.partial_fit(Xb_s, yb, classes=classes)
                first = False
            else:
                clf.partial_fit(Xb_s, yb)

            batch_num += 1
            if batch_num % 20 == 0:
                print(f"  train batches: {batch_num}")

    # ---- quick eval on sampled pixels ----
    print("\nSampling test pixels for quick eval...")
    K = min(10, X_test.shape[0])
    rng = np.random.default_rng(123)

    X_eval_list, y_eval_list = [], []
    for i in range(K):
        Xs, ys = sample_pixels_from_tile(X_test[i], y_test[i], neg_per_pos=neg_per_pos, rng=rng)
        # cap per-tile eval sample to bound memory
        if Xs.shape[0] > 150_000:
            idx = rng.choice(Xs.shape[0], size=150_000, replace=False)
            Xs, ys = Xs[idx], ys[idx]
        X_eval_list.append(Xs)
        y_eval_list.append(ys)

    X_eval = np.concatenate(X_eval_list, axis=0)
    y_eval = np.concatenate(y_eval_list, axis=0)

    X_eval_s = scaler.transform(X_eval)
    p = clf.predict_proba(X_eval_s)[:, 1]
    roc = roc_auc_score(y_eval, p)
    ap = average_precision_score(y_eval, p)
    print("Sampled pixel ROC-AUC:", roc)
    print("Sampled pixel PR-AUC :", ap)

    joblib.dump({"scaler": scaler, "clf": clf}, os.path.join(out_dir, "sgd_logreg.joblib"))
    meta = dict(
        model="Two-pass streaming: StandardScaler.partial_fit then SGDClassifier(log_loss)",
        neg_per_pos=neg_per_pos,
        batch_cap=batch_cap,
        epochs=epochs,
        seed=seed,  # recorded: determines the sampled training subset
        class_weight=class_weight,
        note="Scaler frozen during classifier training; weights match sampled ratio.",
    )
    with open(os.path.join(out_dir, "meta.json"), "w") as f:
        json.dump(meta, f, indent=2)

    print("Saved:", out_dir)


if __name__ == "__main__":
    main()