From ae87945f4687b8b2a75af3289e1c4a2c9b3b9b5a Mon Sep 17 00:00:00 2001 From: Aleksey Filippov Date: Tue, 6 Jun 2023 00:33:45 +0400 Subject: [PATCH] Add agglomerative clustering --- main.py | 49 ++++++++++++++++++++++++++++++++++++++++++++++-- requirements.txt | 6 +++++- 2 files changed, 52 insertions(+), 3 deletions(-) diff --git a/main.py b/main.py index 15b2373..78dd017 100644 --- a/main.py +++ b/main.py @@ -2,13 +2,58 @@ import os import sys +import numpy +import pandas as pd +# import scipy.cluster.hierarchy as sc +from matplotlib import pyplot as plt +from pandas import DataFrame +from sklearn.cluster import AgglomerativeClustering +from sklearn.decomposition import PCA + from src.main.df_loader import DfLoader +def __clustering(data: DataFrame) -> None: + # clusters = round(math.sqrt(len(data) / 2)) + # plt.figure(figsize=(20, 7)) + # plt.title("Dendrograms") + # # Create dendrogram + # sc.dendrogram(sc.linkage(data.to_numpy(), method='ward')) + # plt.title('Dendrogram') + # plt.xlabel('Sample index') + # plt.ylabel('Euclidean distance') + + clusters = 3 + model = AgglomerativeClustering(n_clusters=clusters, metric='euclidean', linkage='ward') + model.fit(data) + labels = model.labels_ + + data_norm = (data - data.min()) / (data.max() - data.min()) + + pca = PCA(n_components=2) # 2-dimensional PCA + transformed = pd.DataFrame(pca.fit_transform(data_norm)) + # plt.scatter(x=transformed[:, 0], y=transformed[:, 1], c=labels, cmap='rainbow') + for i in range(clusters): + series = transformed.iloc[numpy.where(labels[:] == i)] + plt.scatter(series[0], series[1], label=f'Cluster {i + 1}') + plt.legend() + plt.show() + + # fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(15, 5)) + # sns.scatterplot(ax=axes[0], data=data, x='location-la,location-lo', y='age,sex').set_title('Without clustering') + # sns.scatterplot(ax=axes[1], data=data, x='location-la,location-lo', y='age,sex', hue=labels) \ + # .set_title('With clustering') + # plt.show() + + # s = numpy.where(labels[:] == 34) + # print(labels) + + def __main(json_file): df_loader: DfLoader = DfLoader(json_file) - df = df_loader.get_data_frame() - print('done') + data = df_loader.get_clustering_data() + print(data) + __clustering(data) if __name__ == '__main__': diff --git a/requirements.txt b/requirements.txt index 069d67a..a4da13b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,7 @@ pandas==2.0.1 geopy==2.3.0 -numpy==1.24.3 \ No newline at end of file +numpy==1.24.3 +scikit-learn==1.2.2 +matplotlib==3.7.1 +seaborn==0.12.2 +scipy==1.10.1 \ No newline at end of file