Add agglomerative clustering

This commit is contained in:
Aleksey Filippov 2023-06-06 00:33:45 +04:00
parent 0eefc9fde0
commit ae87945f46
2 changed files with 52 additions and 3 deletions

49
main.py
View File

@ -2,13 +2,58 @@
import os import os
import sys import sys
import numpy
import pandas as pd
# import scipy.cluster.hierarchy as sc
from matplotlib import pyplot as plt
from pandas import DataFrame
from sklearn.cluster import AgglomerativeClustering
from sklearn.decomposition import PCA
from src.main.df_loader import DfLoader from src.main.df_loader import DfLoader
def __clustering(data: DataFrame) -> None:
# clusters = round(math.sqrt(len(data) / 2))
# plt.figure(figsize=(20, 7))
# plt.title("Dendrograms")
# # Create dendrogram
# sc.dendrogram(sc.linkage(data.to_numpy(), method='ward'))
# plt.title('Dendrogram')
# plt.xlabel('Sample index')
# plt.ylabel('Euclidean distance')
clusters = 3
model = AgglomerativeClustering(n_clusters=clusters, metric='euclidean', linkage='ward')
model.fit(data)
labels = model.labels_
data_norm = (data - data.min()) / (data.max() - data.min())
pca = PCA(n_components=2) # 2-dimensional PCA
transformed = pd.DataFrame(pca.fit_transform(data_norm))
# plt.scatter(x=transformed[:, 0], y=transformed[:, 1], c=labels, cmap='rainbow')
for i in range(clusters):
series = transformed.iloc[numpy.where(labels[:] == i)]
plt.scatter(series[0], series[1], label=f'Cluster {i + 1}')
plt.legend()
plt.show()
# fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))
# sns.scatterplot(ax=axes[0], data=data, x='location-la,location-lo', y='age,sex').set_title('Without clustering')
# sns.scatterplot(ax=axes[1], data=data, x='location-la,location-lo', y='age,sex', hue=labels) \
# .set_title('With clustering')
# plt.show()
# s = numpy.where(labels[:] == 34)
# print(labels)
def __main(json_file): def __main(json_file):
df_loader: DfLoader = DfLoader(json_file) df_loader: DfLoader = DfLoader(json_file)
df = df_loader.get_data_frame() data = df_loader.get_clustering_data()
print('done') print(data)
__clustering(data)
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -1,3 +1,7 @@
pandas==2.0.1 pandas==2.0.1
geopy==2.3.0 geopy==2.3.0
numpy==1.24.3 numpy==1.24.3
scikit-learn==1.2.2
matplotlib==3.7.1
seaborn==0.12.2
scipy==1.10.1