@@ -18,16 +18,18 @@ kmeans(Elkan(), X, 3) # 3 clusters, Elkan algorithm
1818"""
1919struct Elkan <: AbstractKMeansAlg end
2020
21- function kmeans! (alg:: Elkan , containers, X, k, weights;
21+
22+ function kmeans! (alg:: Elkan , containers, X, k, weights= nothing , metric= Euclidean ();
2223 n_threads = Threads. nthreads (),
2324 k_init = " k-means++" , max_iters = 300 ,
2425 tol = eltype (X)(1e-6 ), verbose = false ,
2526 init = nothing , rng = Random. GLOBAL_RNG)
27+
2628 nrow, ncol = size (X)
2729 centroids = init == nothing ? smart_init (X, k, n_threads, weights, rng, init= k_init). centroids : deepcopy (init)
2830
29- update_containers (alg, containers, centroids, n_threads)
30- @parallelize n_threads ncol chunk_initialize (alg, containers, centroids, X, weights)
31+ update_containers (alg, containers, centroids, n_threads, metric )
32+ @parallelize n_threads ncol chunk_initialize (alg, containers, centroids, X, weights, metric )
3133
3234 T = eltype (X)
3335 converged = false
@@ -38,7 +40,7 @@ function kmeans!(alg::Elkan, containers, X, k, weights;
3840 while niters < max_iters
3941 niters += 1
4042 # Core iteration
41- @parallelize n_threads ncol chunk_update_centroids (alg, containers, centroids, X, weights)
43+ @parallelize n_threads ncol chunk_update_centroids (alg, containers, centroids, X, weights, metric )
4244
4345 # Collect distributed containers (such as centroids_new, centroids_cnt)
4446 # in paper it is step 4
@@ -47,10 +49,10 @@ function kmeans!(alg::Elkan, containers, X, k, weights;
4749 J = sum (containers. ub)
4850
4951 # auxiliary calculation, in paper it's d(c, m(c))
50- calculate_centroids_movement (alg, containers, centroids)
52+ calculate_centroids_movement (alg, containers, centroids, metric )
5153
5254 # lower and ounds update, in paper it's steps 5 and 6
53- @parallelize n_threads ncol chunk_update_bounds (alg, containers, centroids)
55+ @parallelize n_threads ncol chunk_update_bounds (alg, containers, centroids, metric )
5456
5557 # Step 7, final assignment of new centroids
5658 centroids .= containers. centroids_new[end ]
@@ -67,11 +69,11 @@ function kmeans!(alg::Elkan, containers, X, k, weights;
6769 end
6870
6971 # Step 1 in original paper, calulation of distance d(c, c')
70- update_containers (alg, containers, centroids, n_threads)
72+ update_containers (alg, containers, centroids, n_threads, metric )
7173 J_previous = J
7274 end
7375
74- @parallelize n_threads ncol sum_of_squares (containers, X, containers. labels, centroids, weights)
76+ @parallelize n_threads ncol sum_of_squares (containers, X, containers. labels, centroids, weights, metric )
7577 totalcost = sum (containers. sum_of_squares)
7678
7779 # Terminate algorithm with the assumption that K-means has converged
@@ -85,6 +87,7 @@ function kmeans!(alg::Elkan, containers, X, k, weights;
8587 return KmeansResult (centroids, containers. labels, T[], Int[], T[], totalcost, niters, converged)
8688end
8789
90+
8891function create_containers (alg:: Elkan , X, k, nrow, ncol, n_threads)
8992 T = eltype (X)
9093 lng = n_threads + 1
@@ -128,7 +131,8 @@ function create_containers(alg::Elkan, X, k, nrow, ncol, n_threads)
128131 )
129132end
130133
131- function chunk_initialize (:: Elkan , containers, centroids, X, weights, r, idx)
134+
135+ function chunk_initialize (:: Elkan , containers, centroids, X, weights, metric, r, idx)
132136 ub = containers. ub
133137 lb = containers. lb
134138 centroids_dist = containers. centroids_dist
@@ -138,15 +142,15 @@ function chunk_initialize(::Elkan, containers, centroids, X, weights, r, idx)
138142 T = eltype (X)
139143
140144 @inbounds for i in r
141- min_dist = distance (X, centroids, i, 1 )
145+ min_dist = distance (metric, X, centroids, i, 1 )
142146 label = 1
143147 lb[label, i] = min_dist
144148 for j in 2 : size (centroids, 2 )
145149 # triangular inequality
146150 if centroids_dist[j, label] > min_dist
147151 lb[j, i] = min_dist
148152 else
149- dist = distance (X, centroids, i, j)
153+ dist = distance (metric, X, centroids, i, j)
150154 label = dist < min_dist ? j : label
151155 min_dist = dist < min_dist ? dist : min_dist
152156 lb[j, i] = dist
@@ -161,7 +165,8 @@ function chunk_initialize(::Elkan, containers, centroids, X, weights, r, idx)
161165 end
162166end
163167
164- function update_containers (:: Elkan , containers, centroids, n_threads)
168+
169+ function update_containers (:: Elkan , containers, centroids, n_threads, metric)
165170 # unpack containers for easier manipulations
166171 centroids_dist = containers. centroids_dist
167172 T = eltype (centroids)
@@ -170,7 +175,7 @@ function update_containers(::Elkan, containers, centroids, n_threads)
170175 @inbounds for j in axes (centroids_dist, 2 )
171176 min_dist = T (Inf )
172177 for i in j + 1 : k
173- d = distance (centroids, centroids, i, j)
178+ d = distance (metric, centroids, centroids, i, j)
174179 centroids_dist[i, j] = d
175180 centroids_dist[j, i] = d
176181 min_dist = min_dist < d ? min_dist : d
@@ -189,7 +194,8 @@ function update_containers(::Elkan, containers, centroids, n_threads)
189194 return centroids_dist
190195end
191196
192- function chunk_update_centroids (:: Elkan , containers, centroids, X, weights, r, idx)
197+
198+ function chunk_update_centroids (:: Elkan , containers, centroids, X, weights, metric, r, idx)
193199 # unpack
194200 ub = containers. ub
195201 lb = containers. lb
@@ -214,14 +220,14 @@ function chunk_update_centroids(::Elkan, containers, centroids, X, weights, r, i
214220
215221 # one calculation per iteration is enough
216222 if stale[i]
217- min_dist = distance (X, centroids, i, label)
223+ min_dist = distance (metric, X, centroids, i, label)
218224 lb[label, i] = min_dist
219225 ub[i] = min_dist
220226 stale[i] = false
221227 end
222228
223229 if (min_dist > lb[j, i]) | (min_dist > centroids_dist[j, label])
224- dist = distance (X, centroids, i, j)
230+ dist = distance (metric, X, centroids, i, j)
225231 lb[j, i] = dist
226232 if dist < min_dist
227233 min_dist = dist
@@ -242,16 +248,18 @@ function chunk_update_centroids(::Elkan, containers, centroids, X, weights, r, i
242248 end
243249end
244250
245- function calculate_centroids_movement (alg:: Elkan , containers, centroids)
251+
252+ function calculate_centroids_movement (alg:: Elkan , containers, centroids, metric)
246253 p = containers. p
247254 centroids_new = containers. centroids_new[end ]
248255
249256 for i in axes (centroids, 2 )
250- p[i] = distance (centroids, centroids_new, i, i)
257+ p[i] = distance (metric, centroids, centroids_new, i, i)
251258 end
252259end
253260
254- function chunk_update_bounds (alg, containers, centroids, r, idx)
261+
262+ function chunk_update_bounds (alg, containers, centroids, metric:: Euclidean , r, idx)
255263 p = containers. p
256264 lb = containers. lb
257265 ub = containers. ub
@@ -267,3 +275,22 @@ function chunk_update_bounds(alg, containers, centroids, r, idx)
267275 ub[i] += p[labels[i]] + T (2 )* sqrt (abs (ub[i]* p[labels[i]]))
268276 end
269277end
278+
279+
280+ function chunk_update_bounds (alg, containers, centroids, metric:: Metric , r, idx)
281+ p = containers. p
282+ lb = containers. lb
283+ ub = containers. ub
284+ stale = containers. stale
285+ labels = containers. labels
286+ T = eltype (centroids)
287+
288+ @inbounds for i in r
289+ for j in axes (centroids, 2 )
290+ lb[j, i] -= p[j]
291+ end
292+ stale[i] = true
293+ ub[i] += p[labels[i]]
294+ end
295+
296+ end
0 commit comments