FlowingClusters.jl performs unsupervised clustering of arbitrary real-valued data that have first been deformed into a base space by a FFJORD normalizing flow (Grathwohl et al. 2018).
The generative model in the base space consists of a non-parametric Chinese Restaurant Process (CRP) prior over cluster assignments (Pitman 1995; Aldous 1985; Frigyik et al. 2010 tutorial). The base distribution of the CRP is a normal-inverse-Wishart distribution and the data likelihood is multivariate normal. For the hyperprior over the parameters of the normal-inverse-Wishart distribution and the concentration parameter of the CRP we use independence Jeffreys priors. If a neural network is present, the observed data are first mapped into the base space by the FFJORD flow, the likelihood acquires the corresponding change-of-variables Jacobian term, and the flow parameters are inferred jointly with the rest of the model.
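Schematically, and in our own notation rather than the package's exact parameterization (z_i the cluster assignment of observation i, α the CRP concentration, (μ₀, λ₀, Ψ₀, ν₀) the normal-inverse-Wishart parameters, f_θ the FFJORD flow with neural-network parameters θ), the hierarchy can be sketched as

```math
\begin{aligned}
z_1, \dots, z_n \mid \alpha &\sim \mathrm{CRP}(\alpha), \\
(\mu_k, \Sigma_k) \mid \mu_0, \lambda_0, \Psi_0, \nu_0 &\sim \mathrm{NIW}(\mu_0, \lambda_0, \Psi_0, \nu_0), \\
y_i \mid z_i = k,\ \mu_k, \Sigma_k &\sim \mathcal{N}(\mu_k, \Sigma_k), \qquad y_i = f_\theta(x_i),
\end{aligned}
```

with the log-density of an observation in data space picking up the FFJORD change-of-variables log-Jacobian term (Grathwohl et al. 2018).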
We use Adaptive-Metropolis-within-Gibbs (AMWG) for the hyperparameters of the clustering part of the model and Adaptive-Metropolis (AM) for the parameters of the neural network. The Chinese Restaurant Process is sampled using a mix of traditional Gibbs moves (Neal 2000, Algorithm 1) and sequentially allocated split-merge proposals (Dahl & Newcomb 2022).
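As a rough, self-contained sketch of the kind of collapsed Gibbs scan this involves (illustrative only: `crp_gibbs_scan!` and `logpred` are placeholder names, not FlowingClusters.jl functions; in the package the predictive is the normal-inverse-Wishart/multivariate-normal posterior predictive):

```julia
# One Gibbs sweep over cluster assignments under a CRP(alpha) prior.
# Each point is reassigned to an existing cluster with weight proportional to the
# cluster size times its predictive density, or to a new cluster with weight
# proportional to alpha times the prior predictive density.
function crp_gibbs_scan!(assignments::Vector{Int}, data::AbstractVector, alpha::Real, logpred)
    for i in eachindex(data)
        x = data[i]
        assignments[i] = 0                                   # remove x_i from its current cluster
        labels = sort!(unique(filter(!=(0), assignments)))   # remaining non-empty clusters
        logw = [log(count(==(k), assignments)) + logpred(data[assignments .== k], x) for k in labels]
        push!(logw, log(alpha) + logpred(empty(data), x))    # option: open a new cluster
        w = exp.(logw .- maximum(logw))                      # numerically safe normalization
        w ./= sum(w)
        j = something(findfirst(cumsum(w) .>= rand()), length(w))
        assignments[i] = j <= length(labels) ? labels[j] : maximum(assignments) + 1
    end
    return assignments
end
```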
using FlowingClusters
using SplitMaskStandardize
# Load the dataset, split it into three equal parts, and take a subsample of 3000 rows
dataset = SMSDataset("data/ebird_data/ebird_bioclim_landcover.csv", splits=[1, 1, 1], subsample=3000)
# Initialize a chain on standardized presence records of species :sp1, with a sample buffer of size 200
chain = MNCRPChain(dataset.training.presence(:sp1).standardize(:BIO1, :BIO12), nb_samples=200)
# Advance the chain indefinitely with 150 split-merge moves and 2 hyperparameter moves per iteration
advance_chain!(chain, Inf, nb_splitmerge=150, nb_hyperparams=2)
Progress: 64 Time: 0:00:17 ( 0.27 s/it)
step (hyperparams per, gibbs per, splitmerge per): 64/Inf (2, 1, 150.0)
chain length: 5146
conv largestcluster chain (burn 50%): ess=386.3, rhat=1.0
#chain samples (oldest, latest, eta) convergence: 108/200 (3678, 5140, 8) ess=85.3 rhat=0.999 (trim if ess<72.0)
logprob (max, q95, max minus nn): -2603.6 (-2315.1, -2463.2, -2315.1)
nb clusters, nb>1, smallest(>1), median, mean, largest: 11, 9, 2, 42.0, 79.0, 262
split #succ/#tot, merge #succ/#tot: 2/30, 1/117
split/step, merge/step: 4.83, 4.75
MAP #attempts/#successes: 5/0
nb clusters, nb>1, smallest(>1), median, mean, largest: 7, 7, 10, 23.0, 124.0, 506
last MAP logprob (minus nn): -2202.3 (-2202.3)
last MAP at: 1572
last checkpoint at: -1
Here we save samples of the state of the chain into a buffer of size 200. A sample is taken each time the number of iterations since the last sample exceeds twice the autocorrelation time of the last 50% of the chain of largest-cluster sizes, a typical proxy for mixing in non-parametric clustering. Once the buffer is full and the effective sample size (ESS) of the samples is approximately equal to the size of the buffer, we can stop the chain by invoking `touch stop` in the same working directory. The oldest chain samples are automatically dropped when the ESS falls below 50% of the number of samples in the buffer.
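Equivalently, the sentinel file can be created from a second Julia session running in the same directory (this simply mirrors the shell command `touch stop`):

```julia
# Create the "stop" file that the running chain watches for
touch("stop")
```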
- It's ok to stop the chain once the sample buffer is full and its ESS is roughly equal to the buffer size.
- Set the number of split-merge moves per iteration `nb_splitmerge` such that the number of accepted splits and merges per iteration stays above 3 to prevent overfitting. This is especially important when using a neural network.
- The number of Gibbs moves `nb_gibbs` can be set to 1, and the number `nb_hyperparams` of AMWG and AM moves to 1 or 2 (see the example call below).
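For example, a call consistent with these recommendations (illustrative values; we assume `nb_gibbs` is passed to `advance_chain!` like the other two keywords):

```julia
# Run until stopped: 150 split-merge proposals, 1 Gibbs scan,
# and 2 AMWG/AM hyperparameter moves per iteration
advance_chain!(chain, Inf, nb_splitmerge=150, nb_gibbs=1, nb_hyperparams=2)
```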
We can plot the chain, which shows a very approximate MAP state of the clusters (obtained by greedy Gibbs moves plus a partial hyperparameter optimization) and the current state of the clusters in the chain, together with traces of various quantities such as the log-probability, the number of clusters, the size of the largest cluster, and all the hyperparameters.
plot(chain)
We can also obtain predictions of tail probabilities (suitability in SDM terms, or probability of presence) at a set of test points for a single state of the chain, in this case the approximate MAP state, as well as statistical summaries of those tail probabilities computed over the 200 chain samples we have collected.
# Tail-probability function for the approximate MAP state of the chain
tp = tail_probability(chain.map_clusters, chain.map_hyperparams);
tp(dataset.test.presence(:sp1).standardize(:BIO1, :BIO12))
841-element Vector{Float64}:
0.2255
0.0671
0.0004
⋮
0.9112
0.7221
# Summaries of the tail probabilities over the collected chain samples
tpsumm = tail_probability_summary(chain.clusters_samples, chain.hyperparams_samples);
summstats = tpsumm(dataset.test.presence(:sp1).standardize(:BIO1, :BIO12))
(median = [0.1604, 0.1793, 0.0013, ⋯],
mean = [ ⋯ ],
std = [ ⋯ ],
iqr = [ ⋯ ],
CI95 = [ ⋯ ],
CI90 = [ ⋯ ],
quantile = q -> ⋯,
modefd = [ ⋯ ],
modedoane = [ ⋯ ])
# The quantile field is a function that returns the
# q quantile of the tail probabilities at each point
# in the dataset
summstats.quantile(0.3)
841-element Vector{Float64}:
0.13695
0.14554
0.0009
⋮
0.9115
0.58434
The tail probability is a generalization of the way suitability is determined in BIOCLIM. For a given point, it is the probability mass of the predictive distribution lying outside the isodensity contour that passes through that point, i.e. the probability that a draw from the predictive distribution falls in a region of lower density than the point itself.
In higher dimensions the same principle applies, only the isocontours become lines, surfaces, volumes, etc.
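In symbols (our notation, with p̂ the predictive density implied by a set of clusters and hyperparameters), one way to write this is

```math
\mathrm{tp}(x) = \Pr_{Y \sim \hat p}\left[\hat p(Y) \le \hat p(x)\right],
```

so a point near the core of the predictive distribution gets a tail probability close to 1, while a point far out in the tails gets a value close to 0.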