66 lines
2.0 KiB
Python
66 lines
2.0 KiB
Python
#!/usr/bin/env python3
|
|
import os
|
|
import sys
|
|
|
|
import numpy
|
|
import pandas as pd
|
|
# import scipy.cluster.hierarchy as sc
|
|
from matplotlib import pyplot as plt
|
|
from pandas import DataFrame
|
|
from sklearn.cluster import AgglomerativeClustering
|
|
from sklearn.decomposition import PCA
|
|
|
|
from src.main.df_loader import DfLoader
|
|
|
|
|
|
def __clustering(data: DataFrame, clusters: int = 3) -> None:
    """Cluster the rows of ``data`` and show them on a 2-D PCA scatter plot.

    The rows are grouped with Ward-linkage agglomerative clustering fitted on
    the raw feature values. For display only, the features are min-max
    scaled, projected onto two principal components, and drawn as one
    scatter series per cluster.

    Args:
        data: Feature table with one row per sample. It is passed directly
            to scikit-learn, so presumably every column is numeric.
        clusters: Number of clusters to form. Defaults to 3.
    """
    model = AgglomerativeClustering(n_clusters=clusters, metric='euclidean', linkage='ward')
    labels = model.fit(data).labels_

    # Min-max scale each column to [0, 1] before PCA. A constant column has a
    # zero range; dividing by it would produce NaN, which PCA rejects, so use
    # a divisor of 1 there (the column simply becomes all zeros).
    lowest = data.min()
    span = data.max() - lowest
    span = span.where(span != 0, 1)
    data_norm = (data - lowest) / span

    pca = PCA(n_components=2)  # 2-dimensional PCA
    transformed = pd.DataFrame(pca.fit_transform(data_norm))

    # Draw one scatter series per cluster so each gets its own legend entry.
    for i in range(clusters):
        series = transformed[labels == i]
        plt.scatter(series[0], series[1], label=f'Cluster {i + 1}')
    plt.legend()
    plt.show()
|
|
|
|
|
|
def __main(json_file):
    """Load the clustering data from ``json_file``, print it, and plot its clusters."""
    data = DfLoader(json_file).get_clustering_data()
    print(data)
    __clustering(data)
|
|
|
|
|
|
if __name__ == '__main__':
    # Exactly one argument is expected: the path to the raw dataset JSON file.
    if len(sys.argv) != 2:
        print('You must specify the raw_dataset json file')
        sys.exit(1)
    # Stop here when the file is missing instead of handing a bad path on.
    if not os.path.isfile(sys.argv[1]):
        print(f'File {sys.argv[1]} does not exist')
        sys.exit(1)
    __main(sys.argv[1])
|