Commit 8405a47

Merge pull request #75 from PyDataBlog/improved-distance

Support for distance metrics

2 parents 7cc6969 + 117b2d8; commit 8405a47
18 files changed: +320 / -119 lines

Project.toml
Lines changed: 2 additions & 0 deletions

@@ -8,12 +8,14 @@ Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
 MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
+UnsafeArrays = "c4a57d5a-5b31-53a6-b365-19f8c011fbd6"

 [compat]
 Distances = "0.8.2"
 MLJModelInterface = "0.2.1"
 StatsBase = "0.32, 0.33"
 julia = "1.3"
+UnsafeArrays = "1"

 [extras]
 MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"

README.md
Lines changed: 5 additions & 2 deletions

@@ -5,12 +5,13 @@
 [![Build Status](https://www.travis-ci.org/PyDataBlog/ParallelKMeans.jl.svg?branch=master)](https://www.travis-ci.org/PyDataBlog/ParallelKMeans.jl)
 [![Coverage Status](https://coveralls.io/repos/github/PyDataBlog/ParallelKMeans.jl/badge.svg?branch=master)](https://coveralls.io/github/PyDataBlog/ParallelKMeans.jl?branch=master)
 [![FOSSA Status](https://app.fossa.com/api/projects/git%2Bgithub.com%2FPyDataBlog%2FParallelKMeans.jl.svg?type=shield)](https://app.fossa.com/projects/git%2Bgithub.com%2FPyDataBlog%2FParallelKMeans.jl?ref=badge_shield)
+[![Binder](https://mybinder.org/badge_logo.svg)](https://mybinder.org/v2/gh/PyDataBlog/ParallelKMeans.jl/master)
 _________________________________________________________________________________________________________
 **Authors:** [Bernard Brenyah](https://www.linkedin.com/in/bbrenyah/) & [Andrey Oskin](https://www.linkedin.com/in/andrej-oskin-b2b03959/)
 _________________________________________________________________________________________________________

 <div align="center">
-<b>Classic & Contemporary Variants Of K-Means In Sonic Mode<b>
+<b>Classic & Contemporary Variants Of K-Means In Sonic Mode</b>
 </div>

 <p align="center">
@@ -64,7 +65,9 @@
 - Lightening fast implementation of K-Means clustering algorithm even on a single thread in native Julia.
 - Support for multi-theading implementation of K-Means clustering algorithm.
 - Kmeans++ initialization for faster and better convergence.
-- Implementation of all the variants of the K-Means algorithm.
+- Implementation of all available variants of the K-Means algorithm.
+- Support for all distance metrics available at [Distances.jl](https://github.com/JuliaStats/Distances.jl)
+- Supported interface as an [MLJ](https://github.com/alan-turing-institute/MLJ.jl#available-models) model.

 _________________________________________________________________________________________________________
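
The two new feature bullets (Distances.jl metrics and the MLJ interface) correspond to the `metric` argument threaded through the source files below. A minimal usage sketch, assuming the keyword form visible in the `src/coreset.jl` hunk (`metric = ...`) and using `Elkan()`, which this commit extends to general metrics; the data here is made up for illustration:

```julia
using ParallelKMeans, Distances

X = rand(5, 1_000)                          # observations are columns
res = kmeans(Elkan(), X, 3; metric = Cityblock(), max_iters = 200)
res.totalcost, res.iterations
```

`Euclidean()` remains the default, matching the `metric = Euclidean()` defaults in the new signatures.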

benchmark/bench01_distance.jl
Lines changed: 4 additions & 4 deletions

@@ -9,13 +9,13 @@ suite = BenchmarkGroup()
 Random.seed!(2020)
 X = rand(3, 100_000)
 centroids = rand(3, 2)
-d = Vector{Float64}(undef, 100_000)
-suite["100kx3"] = @benchmarkable ParallelKMeans.chunk_colwise($d, $X, $centroids, 1, nothing, 1:100_000, 1)
+d = fill(-Inf, 100_000)
+suite["100kx3"] = @benchmarkable ParallelKMeans.chunk_colwise(d1, $X, $centroids, 1, nothing, Euclidean(), 1:100_000, 1) setup=(d1 = copy(d))

 X = rand(10, 100_000)
 centroids = rand(10, 2)
-d = Vector{Float64}(undef, 100_000)
-suite["100kx10"] = @benchmarkable ParallelKMeans.chunk_colwise($d, $X, $centroids, 1, nothing, 1:100_000, 1)
+d = fill(-Inf, 100_000)
+suite["100kx10"] = @benchmarkable ParallelKMeans.chunk_colwise(d1, $X, $centroids, 1, nothing, Euclidean(), 1:100_000, 1) setup=(d1 = copy(d))

 end # module
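
Two things change in this benchmark: `chunk_colwise` now receives a metric argument (`Euclidean()`), and the distance buffer is no longer interpolated directly but recreated per sample via `setup=(d1 = copy(d))`, so one sample's mutations cannot leak into the next. A standalone sketch of that BenchmarkTools.jl pattern; `work!` is a made-up stand-in, not a ParallelKMeans function:

```julia
using BenchmarkTools

work!(buf) = (buf .= 0.0; buf)              # mutates its argument, like chunk_colwise does

d = fill(-Inf, 1_000)                       # template buffer kept pristine
b = @benchmarkable work!(d1) setup=(d1 = copy(d))   # fresh copy for every sample
run(b)
```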

docs/src/index.md
Lines changed: 7 additions & 5 deletions

@@ -51,13 +51,15 @@ pkg> add ParallelKMeans
 The few (and selected) brave ones can simply grab the current experimental features by simply adding the experimental branch to your development environment after invoking the package manager with `]`:

 ```julia
-dev git@github.com:PyDataBlog/ParallelKMeans.jl.git
+pkg> add ParallelKMeans#experimental
 ```

-Don't forget to checkout the experimental branch and you are good to go with bleeding edge features and breakages!
+You are good to go with bleeding edge features and breakages!

-```bash
-git checkout experimental
+To revert to a stable version, you can simply run:
+
+```julia
+pkg> free ParallelKMeans
 ```

 ## Features
@@ -207,7 +209,7 @@
 - 0.1.4 Bug fixes.
 - 0.1.5 Added `Yinyang` algorithm.
 - 0.1.6 Added support for weighted k-means; Added `Coreset` algorithm; improved support for different types of the design matrix.
-- 0.1.7 Added `Yinyang` and `Coreset` support in MLJ interface; added `weights` support in MLJ; added RNG seed support in MLJ interface and through all algorithms.
+- 0.1.7 Added `Yinyang` and `Coreset` support in MLJ interface; added `weights` support in MLJ; added RNG seed support in MLJ interface and through all algorithms; added metric support.

 ## Contributing

src/ParallelKMeans.jl
Lines changed: 3 additions & 1 deletion

@@ -2,9 +2,11 @@ module ParallelKMeans

 using StatsBase
 using Random
+using UnsafeArrays
+using Distances
 import MLJModelInterface
 import Base.Threads: @spawn
-import Distances
+

 include("kmeans.jl")
 include("seeding.jl")

src/coreset.jl
Lines changed: 17 additions & 10 deletions

@@ -35,18 +35,20 @@ Coreset(; m = 100, alg = Lloyd()) = Coreset(m, alg)
 Coreset(m::Int) = Coreset(m, Lloyd())
 Coreset(alg::AbstractKMeansAlg) = Coreset(100, alg)

-function kmeans!(alg::Coreset, containers, X, k, weights;
+function kmeans!(alg::Coreset, containers, X, k, weights, metric::Euclidean = Euclidean();
                 n_threads = Threads.nthreads(),
                 k_init = "k-means++", max_iters = 300,
                 tol = eltype(design_matrix)(1e-6), verbose = false,
                 init = nothing, rng = Random.GLOBAL_RNG)
+
     nrow, ncol = size(X)
+
     centroids = isnothing(init) ? smart_init(X, k, n_threads, weights, rng, init=k_init).centroids : deepcopy(init)

     T = eltype(X)
     # Steps 2-4 of the paper's algorithm 3
     # We distribute points over the centers and calculate weights of each cluster
-    @parallelize n_threads ncol chunk_fit(alg, containers, centroids, X, weights)
+    @parallelize n_threads ncol chunk_fit(alg, containers, centroids, X, weights, metric)

     # after this step, containers.centroids_new
     collect_containers(alg, containers, n_threads)
@@ -62,15 +64,16 @@ function kmeans!(alg::Coreset, containers, X, k, weights;

     # run usual kmeans for new set with new weights.
     res = kmeans(alg.alg, coreset, k, weights = coreset_weights, tol = tol, max_iters = max_iters,
-                verbose = verbose, init = centroids, n_threads = n_threads, rng = rng)
+                verbose = verbose, init = centroids, n_threads = n_threads, rng = rng, metric = metric)

-    @parallelize n_threads ncol chunk_apply(alg, containers, res.centers, X, weights)
+    @parallelize n_threads ncol chunk_apply(alg, containers, res.centers, X, weights, metric)

     totalcost = sum(containers.totalcost)

     return KmeansResult(res.centers, containers.labels, T[], Int[], T[], totalcost, res.iterations, res.converged)
 end

+
 function create_containers(alg::Coreset, X, k, nrow, ncol, n_threads)
     T = eltype(X)

@@ -109,7 +112,8 @@ function create_containers(alg::Coreset, X, k, nrow, ncol, n_threads)
     )
 end

-function chunk_fit(alg::Coreset, containers, centroids, X, weights, r, idx)
+
+function chunk_fit(alg::Coreset, containers, centroids, X, weights, metric, r, idx)
     centroids_cnt = containers.centroids_cnt[idx]
     centroids_dist = containers.centroids_dist[idx]
     labels = containers.labels
@@ -118,10 +122,10 @@ function chunk_fit(alg::Coreset, containers, centroids, X, weights, r, idx)

     J = zero(T)
     for i in r
-        dist = distance(X, centroids, i, 1)
+        dist = distance(metric, X, centroids, i, 1)
         label = 1
         for j in 2:size(centroids, 2)
-            new_dist = distance(X, centroids, i, j)
+            new_dist = distance(metric, X, centroids, i, j)

             # calculation of the closest center (steps 2-3 of the paper's algorithm 3)
             label = new_dist < dist ? j : label
@@ -144,6 +148,7 @@ function chunk_fit(alg::Coreset, containers, centroids, X, weights, r, idx)
     containers.J[idx] = J
 end

+
 function collect_containers(::Coreset, containers, n_threads)
     # Here we transform formula of the step 6
     # By multiplying both sides of equation on $c_\phi / \alpha$ we obtain
@@ -172,6 +177,7 @@ function collect_containers(::Coreset, containers, n_threads)
     end
 end

+
 function chunk_update_sensitivity(alg::Coreset, containers, r, idx)
     labels = containers.labels
     centroids_const = containers.centroids_const
@@ -182,18 +188,19 @@ function chunk_update_sensitivity(alg::Coreset, containers, r, idx)
     end
 end

-function chunk_apply(alg::Coreset, containers, centroids, X, weights, r, idx)
+
+function chunk_apply(alg::Coreset, containers, centroids, X, weights, metric, r, idx)
     centroids_cnt = containers.centroids_cnt[idx]
     centroids_dist = containers.centroids_dist[idx]
     labels = containers.labels
     T = eltype(X)

     J = zero(T)
     for i in r
-        dist = distance(X, centroids, i, 1)
+        dist = distance(metric, X, centroids, i, 1)
         label = 1
         for j in 2:size(centroids, 2)
-            new_dist = distance(X, centroids, i, j)
+            new_dist = distance(metric, X, centroids, i, j)

             # calculation of the closest center (steps 2-3 of the paper's algorithm 3)
             label = new_dist < dist ? j : label
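
Throughout this file the inner kernel changes from `distance(X, centroids, i, j)` to `distance(metric, X, centroids, i, j)`, i.e. a column-of-`X` versus column-of-`centroids` distance under the chosen metric; note that `Coreset`'s `kmeans!` constrains `metric::Euclidean`, so only the Euclidean path is accepted here. A rough sketch of what such a column-pair helper can look like on top of Distances.jl; `pairdist` is a hypothetical name, and the package's real kernel is internal and hand-tuned (its Euclidean path appears to work with squared distances, judging by the bound update in `src/elkan.jl` below):

```julia
using Distances

# Hypothetical helper mirroring the shape of distance(metric, X1, X2, i1, i2)
@inline pairdist(metric::Distances.PreMetric, X1, X2, i1, i2) =
    Distances.evaluate(metric, view(X1, :, i1), view(X2, :, i2))

X = rand(3, 10); C = rand(3, 2)
pairdist(Euclidean(), X, C, 1, 2)   # distance between column 1 of X and column 2 of C
```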

src/elkan.jl
Lines changed: 46 additions & 19 deletions

@@ -18,16 +18,18 @@ kmeans(Elkan(), X, 3) # 3 clusters, Elkan algorithm
 """
 struct Elkan <: AbstractKMeansAlg end

-function kmeans!(alg::Elkan, containers, X, k, weights;
+
+function kmeans!(alg::Elkan, containers, X, k, weights=nothing, metric=Euclidean();
                 n_threads = Threads.nthreads(),
                 k_init = "k-means++", max_iters = 300,
                 tol = eltype(X)(1e-6), verbose = false,
                 init = nothing, rng = Random.GLOBAL_RNG)
+
     nrow, ncol = size(X)
     centroids = init == nothing ? smart_init(X, k, n_threads, weights, rng, init=k_init).centroids : deepcopy(init)

-    update_containers(alg, containers, centroids, n_threads)
-    @parallelize n_threads ncol chunk_initialize(alg, containers, centroids, X, weights)
+    update_containers(alg, containers, centroids, n_threads, metric)
+    @parallelize n_threads ncol chunk_initialize(alg, containers, centroids, X, weights, metric)

     T = eltype(X)
     converged = false
@@ -38,7 +40,7 @@ function kmeans!(alg::Elkan, containers, X, k, weights;
     while niters < max_iters
         niters += 1
         # Core iteration
-        @parallelize n_threads ncol chunk_update_centroids(alg, containers, centroids, X, weights)
+        @parallelize n_threads ncol chunk_update_centroids(alg, containers, centroids, X, weights, metric)

         # Collect distributed containers (such as centroids_new, centroids_cnt)
         # in paper it is step 4
@@ -47,10 +49,10 @@ function kmeans!(alg::Elkan, containers, X, k, weights;
         J = sum(containers.ub)

         # auxiliary calculation, in paper it's d(c, m(c))
-        calculate_centroids_movement(alg, containers, centroids)
+        calculate_centroids_movement(alg, containers, centroids, metric)

         # lower and ounds update, in paper it's steps 5 and 6
-        @parallelize n_threads ncol chunk_update_bounds(alg, containers, centroids)
+        @parallelize n_threads ncol chunk_update_bounds(alg, containers, centroids, metric)

         # Step 7, final assignment of new centroids
         centroids .= containers.centroids_new[end]
@@ -67,11 +69,11 @@ function kmeans!(alg::Elkan, containers, X, k, weights;
         end

         # Step 1 in original paper, calulation of distance d(c, c')
-        update_containers(alg, containers, centroids, n_threads)
+        update_containers(alg, containers, centroids, n_threads, metric)
         J_previous = J
     end

-    @parallelize n_threads ncol sum_of_squares(containers, X, containers.labels, centroids, weights)
+    @parallelize n_threads ncol sum_of_squares(containers, X, containers.labels, centroids, weights, metric)
     totalcost = sum(containers.sum_of_squares)

     # Terminate algorithm with the assumption that K-means has converged
@@ -85,6 +87,7 @@ function kmeans!(alg::Elkan, containers, X, k, weights;
     return KmeansResult(centroids, containers.labels, T[], Int[], T[], totalcost, niters, converged)
 end

+
 function create_containers(alg::Elkan, X, k, nrow, ncol, n_threads)
     T = eltype(X)
     lng = n_threads + 1
@@ -128,7 +131,8 @@ function create_containers(alg::Elkan, X, k, nrow, ncol, n_threads)
     )
 end

-function chunk_initialize(::Elkan, containers, centroids, X, weights, r, idx)
+
+function chunk_initialize(::Elkan, containers, centroids, X, weights, metric, r, idx)
     ub = containers.ub
     lb = containers.lb
     centroids_dist = containers.centroids_dist
@@ -138,15 +142,15 @@ function chunk_initialize(::Elkan, containers, centroids, X, weights, r, idx)
     T = eltype(X)

     @inbounds for i in r
-        min_dist = distance(X, centroids, i, 1)
+        min_dist = distance(metric, X, centroids, i, 1)
         label = 1
         lb[label, i] = min_dist
         for j in 2:size(centroids, 2)
             # triangular inequality
             if centroids_dist[j, label] > min_dist
                 lb[j, i] = min_dist
             else
-                dist = distance(X, centroids, i, j)
+                dist = distance(metric, X, centroids, i, j)
                 label = dist < min_dist ? j : label
                 min_dist = dist < min_dist ? dist : min_dist
                 lb[j, i] = dist
@@ -161,7 +165,8 @@ function chunk_initialize(::Elkan, containers, centroids, X, weights, r, idx)
     end
 end

-function update_containers(::Elkan, containers, centroids, n_threads)
+
+function update_containers(::Elkan, containers, centroids, n_threads, metric)
     # unpack containers for easier manipulations
     centroids_dist = containers.centroids_dist
     T = eltype(centroids)
@@ -170,7 +175,7 @@ function update_containers(::Elkan, containers, centroids, n_threads)
     @inbounds for j in axes(centroids_dist, 2)
         min_dist = T(Inf)
         for i in j + 1:k
-            d = distance(centroids, centroids, i, j)
+            d = distance(metric, centroids, centroids, i, j)
             centroids_dist[i, j] = d
             centroids_dist[j, i] = d
             min_dist = min_dist < d ? min_dist : d
@@ -189,7 +194,8 @@ function update_containers(::Elkan, containers, centroids, n_threads)
     return centroids_dist
 end

-function chunk_update_centroids(::Elkan, containers, centroids, X, weights, r, idx)
+
+function chunk_update_centroids(::Elkan, containers, centroids, X, weights, metric, r, idx)
     # unpack
     ub = containers.ub
     lb = containers.lb
@@ -214,14 +220,14 @@ function chunk_update_centroids(::Elkan, containers, centroids, X, weights, r, idx)

             # one calculation per iteration is enough
             if stale[i]
-                min_dist = distance(X, centroids, i, label)
+                min_dist = distance(metric, X, centroids, i, label)
                 lb[label, i] = min_dist
                 ub[i] = min_dist
                 stale[i] = false
             end

             if (min_dist > lb[j, i]) | (min_dist > centroids_dist[j, label])
-                dist = distance(X, centroids, i, j)
+                dist = distance(metric, X, centroids, i, j)
                 lb[j, i] = dist
                 if dist < min_dist
                     min_dist = dist
@@ -242,16 +248,18 @@ function chunk_update_centroids(::Elkan, containers, centroids, X, weights, r, idx)
     end
 end

-function calculate_centroids_movement(alg::Elkan, containers, centroids)
+
+function calculate_centroids_movement(alg::Elkan, containers, centroids, metric)
     p = containers.p
     centroids_new = containers.centroids_new[end]

     for i in axes(centroids, 2)
-        p[i] = distance(centroids, centroids_new, i, i)
+        p[i] = distance(metric, centroids, centroids_new, i, i)
     end
 end

-function chunk_update_bounds(alg, containers, centroids, r, idx)
+
+function chunk_update_bounds(alg, containers, centroids, metric::Euclidean, r, idx)
     p = containers.p
     lb = containers.lb
     ub = containers.ub
@@ -267,3 +275,22 @@ function chunk_update_bounds(alg, containers, centroids, r, idx)
         ub[i] += p[labels[i]] + T(2)*sqrt(abs(ub[i]*p[labels[i]]))
     end
 end
+
+
+function chunk_update_bounds(alg, containers, centroids, metric::Metric, r, idx)
+    p = containers.p
+    lb = containers.lb
+    ub = containers.ub
+    stale = containers.stale
+    labels = containers.labels
+    T = eltype(centroids)
+
+    @inbounds for i in r
+        for j in axes(centroids, 2)
+            lb[j, i] -= p[j]
+        end
+        stale[i] = true
+        ub[i] += p[labels[i]]
+    end
+
+end
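
The key change in this file is that `chunk_update_bounds` is now split into two methods: the `metric::Euclidean` method keeps the existing update with the `2*sqrt(...)` term, which is what the triangle inequality looks like when bounds and centroid movements are stored as squared distances, while the new `metric::Metric` method applies the plain triangle inequality to ordinary distances. A small numeric check of that correspondence (illustrative only, assuming the squared-distance convention for the Euclidean path):

```julia
# With b a squared upper bound and p a squared centroid movement,
# sqrt(b') <= sqrt(b) + sqrt(p) squares to b' <= b + p + 2*sqrt(b*p).
b, p = 4.0, 0.25
squared_form = b + p + 2 * sqrt(abs(b * p))   # mirrors the Euclidean branch
plain_form   = sqrt(b) + sqrt(p)              # plain triangle-inequality update
@assert sqrt(squared_form) ≈ plain_form
```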
