Fix max_clusters calculation
This commit is contained in:
parent
d97f465c5e
commit
d30c1005e9
File diff suppressed because one or more lines are too long
@ -1,3 +1,4 @@
|
||||
import math
|
||||
from typing import Dict
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
@ -7,8 +8,9 @@ from sklearn import cluster, metrics
|
||||
|
||||
|
||||
def get_best_clusters_num(
|
||||
X: DataFrame, random_state: int, max_clusters: int = 10
|
||||
X: DataFrame, random_state: int, max_clusters: int | None = None
|
||||
) -> Dict[int, float]:
|
||||
max_clusters = int(math.sqrt(len(X)) + 0.5) + 1
|
||||
silhouette_scores: Dict[int, float] = {}
|
||||
for cluster_num in range(2, max_clusters + 1):
|
||||
kmeans = cluster.KMeans(n_clusters=cluster_num, random_state=random_state)
|
||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user