Skip to content

Commit

Permalink
add input.cluster.get_representative_points
Browse files Browse the repository at this point in the history
  • Loading branch information
jornbr committed Feb 6, 2025
1 parent 3fbbfb1 commit c87d9a7
Showing 1 changed file with 33 additions and 7 deletions.
40 changes: 33 additions & 7 deletions src/fabmos/input/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,16 +216,42 @@ def interpolate(values: xr.DataArray, lon: np.ndarray, lat: np.ndarray):
def split(clusters: npt.ArrayLike) -> np.ma.MaskedArray:
from skimage.segmentation import flood_fill

masked = np.ma.getmaskarray(clusters)
unmasked = ~np.ma.getmaskarray(clusters)
clusters = np.asarray(clusters)
unique_clusters = np.unique(clusters[~masked])
n = -1
for c in unique_clusters:
unique_clusters = np.unique(clusters[unmasked])
fill_value = len(unique_clusters)
cluster_indices = np.full(clusters.shape, fill_value, dtype=int)
for icluster, cluster_id in enumerate(unique_clusters):
cluster_indices[(clusters == cluster_id) & unmasked] = icluster
n = 0
for icluster in range(len(unique_clusters)):
while True:
indices = (clusters == c).nonzero()
indices = (cluster_indices == icluster).nonzero()
if indices[0].size == 0:
break
n += 1
seed_point = (indices[0][0], indices[1][0])
flood_fill(clusters, seed_point, -n, connectivity=1, in_place=True)
return np.ma.array(-clusters, mask=masked)
flood_fill(cluster_indices, seed_point, -n, connectivity=1, in_place=True)
return np.ma.masked_equal(-cluster_indices, -fill_value, copy=False)


def get_representative_points(clusters: npt.ArrayLike, mincount: int = 1):
clusters = np.asanyarray(clusters)
split_clusters = split(clusters).filled(0)
jj, ii = np.indices(split_clusters.shape)
for split_id in range(1, split_clusters.max() + 1):
sel = split_clusters == split_id
if sel.sum() < mincount:
continue
ii_sel = ii[sel]
jj_sel = jj[sel]
cluster_id = clusters[jj_sel[0], ii_sel[0]]
dist = (ii_sel - ii_sel[:, np.newaxis]) ** 2
dist += (jj_sel - jj_sel[:, np.newaxis]) ** 2
dist.sort(axis=1)
best_dist = dist.min(axis=0)
has_best_dist = dist == best_dist
still_best_dist = np.logical_and.accumulate(has_best_dist, axis=1)
longest_best = still_best_dist.sum(axis=1)
ind = np.argmax(longest_best)
yield cluster_id, ii_sel[ind], jj_sel[ind]

0 comments on commit c87d9a7

Please sign in to comment.