Add agglomerative clustering
This commit is contained in:
parent
0eefc9fde0
commit
ae87945f46
49
main.py
49
main.py
@ -2,13 +2,58 @@
|
|||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
import numpy
|
||||||
|
import pandas as pd
|
||||||
|
# import scipy.cluster.hierarchy as sc
|
||||||
|
from matplotlib import pyplot as plt
|
||||||
|
from pandas import DataFrame
|
||||||
|
from sklearn.cluster import AgglomerativeClustering
|
||||||
|
from sklearn.decomposition import PCA
|
||||||
|
|
||||||
from src.main.df_loader import DfLoader
|
from src.main.df_loader import DfLoader
|
||||||
|
|
||||||
|
|
||||||
|
def __clustering(data: DataFrame, n_clusters: int = 3) -> None:
    """Cluster ``data`` agglomeratively and show the clusters on a 2-D PCA plot.

    The rows are grouped with Ward-linkage agglomerative clustering on the
    feature values as given. For display only, the features are min-max
    scaled and projected onto two principal components; every cluster is
    drawn as its own scatter series and the figure is shown.

    Args:
        data: Feature table with one row per sample. It is handed to
            scikit-learn unchanged, so it is expected to be numeric.
        n_clusters: Number of clusters to form. Defaults to 3, the value
            that used to be hard-coded.
    """
    model = AgglomerativeClustering(n_clusters=n_clusters, metric='euclidean', linkage='ward')
    model.fit(data)
    labels = model.labels_

    lowest = data.min()
    spread = data.max() - lowest
    # A constant column has zero spread; dividing by it would produce NaN and
    # PCA would reject the input. Scaling by 1 maps such a column to all zeros.
    spread = spread.replace(0, 1)
    data_norm = (data - lowest) / spread

    pca = PCA(n_components=2)  # 2-dimensional projection, used for plotting only
    transformed = pd.DataFrame(pca.fit_transform(data_norm))

    for cluster_id in range(n_clusters):
        # `transformed` has a fresh positional index, so a boolean mask built
        # from the label array selects exactly the rows of this cluster.
        members = transformed.loc[labels == cluster_id]
        plt.scatter(members[0], members[1], label=f'Cluster {cluster_id + 1}')
    plt.legend()
    plt.show()
|
||||||
|
|
||||||
|
|
||||||
def __main(json_file):
    """Load the clustering data from ``json_file``, print it, and cluster it."""
    loader: DfLoader = DfLoader(json_file)
    clustering_frame = loader.get_clustering_data()
    print(clustering_frame)
    __clustering(clustering_frame)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
@ -1,3 +1,7 @@
|
|||||||
pandas==2.0.1
|
pandas==2.0.1
|
||||||
geopy==2.3.0
|
geopy==2.3.0
|
||||||
numpy==1.24.3
|
numpy==1.24.3
|
||||||
|
scikit-learn==1.2.2
|
||||||
|
matplotlib==3.7.1
|
||||||
|
seaborn==0.12.2
|
||||||
|
scipy==1.10.1
|
Loading…
Reference in New Issue
Block a user