@@ -64,7 +64,8 @@ class KMedoids(BaseEstimator, ClusterMixin, TransformerMixin):
6464 method : {'alternate', 'pam'}, default: 'alternate'
6565 Which algorithm to use. 'alternate' is faster while 'pam' is more accurate.
6666
67- init : {'random', 'heuristic', 'k-medoids++', 'build'}, optional, default: 'heuristic'
67+ init : {'random', 'heuristic', 'k-medoids++', 'build'}, or array-like of shape
68+ (n_clusters, n_features), optional, default: 'heuristic'
6869 Specify medoid initialization method. 'random' selects n_clusters
6970 elements from the dataset. 'heuristic' picks the n_clusters points
7071 with the smallest sum distance to every other point. 'k-medoids++'
@@ -74,6 +75,8 @@ class KMedoids(BaseEstimator, ClusterMixin, TransformerMixin):
7475 algorithm. Often 'build' is more efficient but slower than other
7576 initializations on big datasets and it is also very non-robust,
7677 if there are outliers in the dataset, use another initialization.
78+ If an array is passed, it should be of shape (n_clusters, n_features)
79+ and gives the initial centers.
7780
7881 .. _k-means++: https://theory.stanford.edu/~sergei/papers/kMeansPP-soda.pdf
7982
@@ -181,13 +184,29 @@ def _check_init_args(self):
181184
182185 # Check init
183186 init_methods = ["random" , "heuristic" , "k-medoids++" , "build" ]
184- if self .init not in init_methods :
187+ if not (
188+ hasattr (self .init , "__array__" )
189+ or (isinstance (self .init , str ) and self .init in init_methods )
190+ ):
185191 raise ValueError (
186192 "init needs to be one of "
187193 + "the following: "
188- + "%s" % init_methods
194+ + "%s" % ( init_methods + [ "array-like" ])
189195 )
190196
197+ # Check n_clusters
198+ if (
199+ hasattr (self .init , "__array__" )
200+ and self .n_clusters != self .init .shape [0 ]
201+ ):
202+ warnings .warn (
203+ "n_clusters should be equal to size of array-like if init "
204+ "is array-like setting n_clusters to {}." .format (
205+ self .init .shape [0 ]
206+ )
207+ )
208+ self .n_clusters = self .init .shape [0 ]
209+
191210 def fit (self , X , y = None ):
192211 """Fit K-Medoids to the provided data.
193212
@@ -219,7 +238,7 @@ def fit(self, X, y=None):
219238 D = pairwise_distances (X , metric = self .metric )
220239
221240 medoid_idxs = self ._initialize_medoids (
222- D , self .n_clusters , random_state_
241+ D , self .n_clusters , random_state_ , X
223242 )
224243 labels = None
225244
@@ -407,10 +426,14 @@ def predict(self, X):
407426
408427 return pd_argmin
409428
410- def _initialize_medoids (self , D , n_clusters , random_state_ ):
429+ def _initialize_medoids (self , D , n_clusters , random_state_ , X = None ):
411430 """Select initial mediods when beginning clustering."""
412431
413- if self .init == "random" : # Random initialization
432+ if hasattr (self .init , "__array__" ): # Pre assign cluster
433+ medoids = np .hstack (
434+ [np .where ((X == c ).all (axis = 1 )) for c in self .init ]
435+ ).ravel ()
436+ elif self .init == "random" : # Random initialization
414437 # Pick random k medoids as the initial ones.
415438 medoids = random_state_ .choice (len (D ), n_clusters , replace = False )
416439 elif self .init == "k-medoids++" :
0 commit comments