Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add KMedoids #298

Merged
merged 8 commits into from
Mar 28, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -18,6 +18,7 @@ PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
TableDistances = "e5d66e97-8c70-46bb-8b66-04a2d73ad782"
Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
TransformsBase = "28dd2a49-a57a-4bfb-84ca-1a49db9b96b8"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
@@ -37,6 +38,7 @@ PrettyTables = "2"
Random = "1.9"
Statistics = "1.9"
StatsBase = "0.33, 0.34"
TableDistances = "1.0"
Tables = "1.6"
TransformsBase = "1.5"
Unitful = "1.17"
6 changes: 6 additions & 0 deletions docs/src/transforms.md
Original file line number Diff line number Diff line change
@@ -242,6 +242,12 @@ SDS
ProjectionPursuit
```

## KMedoids

```@docs
KMedoids
```

## Closure

```@docs
6 changes: 4 additions & 2 deletions src/TableTransforms.jl
Original file line number Diff line number Diff line change
@@ -6,12 +6,13 @@ module TableTransforms

using Tables
using Unitful
using Statistics
using PrettyTables
using AbstractTrees
using LinearAlgebra
using TableDistances
using DataScienceTraits
using CategoricalArrays
using LinearAlgebra
using Statistics
using Random
using CoDa

@@ -90,6 +91,7 @@ export
DRS,
SDS,
ProjectionPursuit,
KMedoids,
Closure,
Remainder,
Compose,
1 change: 1 addition & 0 deletions src/transforms.jl
Original file line number Diff line number Diff line change
@@ -286,6 +286,7 @@ include("transforms/quantile.jl")
include("transforms/functional.jl")
include("transforms/eigenanalysis.jl")
include("transforms/projectionpursuit.jl")
include("transforms/kmedoids.jl")
include("transforms/closure.jl")
include("transforms/remainder.jl")
include("transforms/compose.jl")
142 changes: 142 additions & 0 deletions src/transforms/kmedoids.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# ------------------------------------------------------------------
# Licensed under the MIT License. See LICENSE in the project root.
# ------------------------------------------------------------------

"""
KMedoids(k; tol=1e-4, maxiter=10, weights=nothing, rng=Random.default_rng())

Assign labels to rows of table using the `k`-medoids algorithm.

The iterative algorithm is interrupted if the relative change on
the average distance to medoids is smaller than a tolerance `tol`
or if the number of iterations exceeds the maximum number of
iterations `maxiter`.

Optionally, specify a dictionary of `weights` for each column to
affect the underlying table distance from TableDistances.jl, and
a random number generator `rng` to obtain reproducible results.

## Examples

```julia
KMedoids(3)
KMedoids(4, maxiter=20)
KMedoids(5, weights=Dict(:col1 => 1.0, :col2 => 2.0))
```

## References

* Kaufman, L. & Rousseeuw, P. J. 1990. [Partitioning Around Medoids (Program PAM)]
(https://onlinelibrary.wiley.com/doi/10.1002/9780470316801.ch2)

* Kaufman, L. & Rousseeuw, P. J. 1991. [Finding Groups in Data: An Introduction to Cluster Analysis]
(https://www.jstor.org/stable/2532178)
"""
struct KMedoids{W,RNG} <: StatelessFeatureTransform
k::Int
tol::Float64
maxiter::Int
weights::W
rng::RNG
end

function KMedoids(k; tol=1e-4, maxiter=10, weights=nothing, rng=Random.default_rng())
# sanity checks
_assert(k > 0, "number of clusters must be positive")
_assert(tol > 0, "tolerance on relative change must be positive")
_assert(maxiter > 0, "maximum number of iterations must be positive")
KMedoids(k, tol, maxiter, weights, rng)
end

parameters(transform::KMedoids) = (; k=transform.k)

function applyfeat(transform::KMedoids, feat, prep)
# retrieve parameters
k = transform.k
tol = transform.tol
maxiter = transform.maxiter
weights = transform.weights
rng = transform.rng

# number of observations
nobs = _nrows(feat)

# sanity checks
k > nobs && throw(ArgumentError("requested number of clusters > number of observations"))

# normalize variables
stdfeat = feat |> StdFeats()

# define table distance
td = TableDistance(normalize=false, weights=weights)

# initialize medoids
medoids = sample(rng, 1:nobs, k, replace=false)

# retrieve distance type
s = Tables.subset(stdfeat, 1:1)
D = eltype(pairwise(td, s))

# pre-allocate memory for labels and distances
labels = fill(0, nobs)
dists = fill(typemax(D), nobs)

# main loop
iter = 0
δcur = mean(dists)
while iter < maxiter
# update labels and medoids
_updatelabels!(td, stdfeat, medoids, labels, dists)
_updatemedoids!(td, stdfeat, medoids, labels)

# average distance to medoids
δnew = mean(dists)

# break upon convergence
abs(δnew - δcur) / δcur < tol && break

# update and continue
δcur = δnew
iter += 1
end

newfeat = (; cluster=labels) |> Tables.materializer(feat)

newfeat, nothing
end

function _updatelabels!(td, table, medoids, labels, dists)
for (k, mₖ) in enumerate(medoids)
inds = 1:_nrows(table)

X = Tables.subset(table, inds)
μ = Tables.subset(table, [mₖ])

δ = pairwise(td, X, μ)

@inbounds for i in inds
if δ[i] < dists[i]
dists[i] = δ[i]
labels[i] = k
end
end
end
end

function _updatemedoids!(td, table, medoids, labels)
for k in eachindex(medoids)
inds = findall(isequal(k), labels)

X = Tables.subset(table, inds)

j = _medoid(td, X)

@inbounds medoids[k] = inds[j]
end
end

function _medoid(td, table)
Δ = pairwise(td, table)
_, j = findmin(sum, eachcol(Δ))
j
end
1 change: 1 addition & 0 deletions test/transforms.jl
Original file line number Diff line number Diff line change
@@ -31,6 +31,7 @@ transformfiles = [
"functional.jl",
"eigenanalysis.jl",
"projectionpursuit.jl",
"kmedoids.jl",
"closure.jl",
"remainder.jl",
"compose.jl",
25 changes: 25 additions & 0 deletions test/transforms/kmedoids.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
@testset "KMedoids" begin
@test !isrevertible(KMedoids(3))

@test TT.parameters(KMedoids(3)) == (; k=3)

# basic test with continuous variables
a = [randn(100); 10 .+ randn(100)]
b = [randn(100); 10 .+ randn(100)]
t = Table(; a, b)
n = t |> KMedoids(2; rng)
i1 = findall(isequal(1), n.cluster)
i2 = findall(isequal(2), n.cluster)
@test mean(t.a[i1]) > 5
@test mean(t.b[i1]) > 5
@test mean(t.a[i2]) < 5
@test mean(t.b[i2]) < 5

# test with mixed variables
a = [1, 2, 3]
b = [1.0, 2.0, 3.0]
c = ["a", "b", "c"]
t = Table(; a, b, c)
n = t |> KMedoids(3; rng)
@test sort(n.cluster) == [1, 2, 3]
end