Fix max_clusters calculation
This commit is contained in:
parent
d97f465c5e
commit
d30c1005e9
File diff suppressed because one or more lines are too long
@ -1,3 +1,4 @@
|
|||||||
|
import math
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
|
||||||
import matplotlib.pyplot as plt
|
import matplotlib.pyplot as plt
|
||||||
@ -7,8 +8,9 @@ from sklearn import cluster, metrics
|
|||||||
|
|
||||||
|
|
||||||
def get_best_clusters_num(
|
def get_best_clusters_num(
|
||||||
X: DataFrame, random_state: int, max_clusters: int = 10
|
X: DataFrame, random_state: int, max_clusters: int | None = None
|
||||||
) -> Dict[int, float]:
|
) -> Dict[int, float]:
|
||||||
|
max_clusters = int(math.sqrt(len(X)) + 0.5) + 1
|
||||||
silhouette_scores: Dict[int, float] = {}
|
silhouette_scores: Dict[int, float] = {}
|
||||||
for cluster_num in range(2, max_clusters + 1):
|
for cluster_num in range(2, max_clusters + 1):
|
||||||
kmeans = cluster.KMeans(n_clusters=cluster_num, random_state=random_state)
|
kmeans = cluster.KMeans(n_clusters=cluster_num, random_state=random_state)
|
||||||
|
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue
Block a user