80 lines
2.4 KiB
Python
80 lines
2.4 KiB
Python
#!/usr/bin/env python3
|
|
import os
|
|
import sys
|
|
from typing import List
|
|
|
|
import numpy
|
|
import numpy as np
|
|
import pandas as pd
|
|
import scipy.cluster.hierarchy as sc
|
|
from matplotlib import pyplot as plt
|
|
from numpy import ndarray
|
|
from pandas import Series
|
|
from sklearn.cluster import AgglomerativeClustering
|
|
from sklearn.decomposition import PCA
|
|
|
|
from src.main.df_loader import DfLoader
|
|
from src.main.georeverse import Georeverse
|
|
|
|
# Number of clusters requested from AgglomerativeClustering by __main.
default_clusters: int = 3
# When True, __clustering also shows the dendrogram / PCA scatter figure.
is_plots: bool = False

# Shared reverse-geocoder; __print_center calls its get_city() for each cluster center.
georeverse: Georeverse = Georeverse()
|
|
|
|
|
|
def __plots(data: ndarray, labels: ndarray) -> None:
    """Show a truncated Ward dendrogram next to a 2-D PCA scatter of the clusters.

    :param data: Sample matrix that was clustered (passed to ``sc.linkage`` and ``PCA``,
        both of which expect one row per sample).
    :param labels: Cluster label for each row of ``data``; used to colour the scatter points.
    """
    _, (dendro_ax, scatter_ax) = plt.subplots(1, 2, figsize=(12, 6))

    # Ward linkage mirrors the linkage='ward' used by AgglomerativeClustering in
    # __clustering; truncate to 4 levels so the tree stays readable.
    sc.dendrogram(sc.linkage(data, method='ward'), p=4, truncate_mode='level', ax=dendro_ax)
    dendro_ax.set_title('Dendrogram')

    # Project onto two principal components purely for visualisation.
    # fit_transform already returns an ndarray, so no DataFrame round-trip is needed.
    transformed = PCA(n_components=2).fit_transform(data)
    scatter_ax.scatter(transformed[:, 0], transformed[:, 1], c=labels, cmap='rainbow')
    scatter_ax.set_title('Clustering')

    plt.show()
|
|
|
|
|
|
def __get_cluster_centers(data: ndarray, labels: ndarray) -> ndarray:
    """Return the mean feature vector (centroid) of every cluster.

    :param data: 2-D array with one row per sample.
    :param labels: One cluster label per row of ``data`` (array or sequence).
    :return: Array of shape ``(n_clusters, n_features)``. Row ``i`` is the centroid of the
        ``i``-th smallest label, because ``np.unique`` returns its values sorted, so the
        order is deterministic. Empty input yields an empty ``(0, n_features)`` array.
    """
    data = np.asarray(data)
    labels = np.asarray(labels)

    unique_labels = np.unique(labels)
    if unique_labels.size == 0:
        # Keep the trailing feature dimension so callers get a consistently shaped array.
        return np.empty((0,) + data.shape[1:])

    # A boolean mask selects the rows of each cluster; averaging over axis 0 gives its centroid.
    return np.array([data[labels == label].mean(axis=0) for label in unique_labels])
|
|
|
|
|
|
def __print_center(center: ndarray) -> None:
    """Print a human-readable summary of one cluster center.

    Elements 0 and 1 are passed to ``georeverse.get_city``. Elements 2 and 3 are
    rounded to whole numbers and reported as sex and age. Elements 4 to 7 are rounded
    and turned into booleans for the university / work / student / school flags.
    """
    location: str = georeverse.get_city(center[0], center[1])

    sex, age = (round(center[idx]) for idx in (2, 3))
    is_university, is_work, is_student, is_schoolboy = (
        bool(round(center[idx])) for idx in range(4, 8)
    )

    print(f'location: {location}, sex: {sex}, age: {age},'
          f' univer: {is_university}, work: {is_work}, student: {is_student}, school: {is_schoolboy}')
|
|
|
|
|
|
def __clustering(data: ndarray, n_clusters: int = 3, plots: bool = False) -> None:
    """Run Ward agglomerative clustering on ``data`` and print every cluster center.

    :param data: Sample matrix to cluster.
    :param n_clusters: Number of clusters for AgglomerativeClustering.
    :param plots: When True, show the dendrogram / PCA figure before printing.
    """
    # fit() returns the estimator itself, so labels_ can be read straight off the chain.
    labels = AgglomerativeClustering(
        n_clusters=n_clusters,
        metric='euclidean',
        linkage='ward',
    ).fit(data).labels_

    if plots:
        __plots(data, labels)

    for center in __get_cluster_centers(data, labels):
        __print_center(center)
|
|
|
|
|
|
def __main(json_file):
    """Load ``json_file`` through DfLoader and cluster the resulting data.

    The cluster count and plotting flag come from the module-level
    ``default_clusters`` and ``is_plots`` settings.
    """
    __clustering(
        DfLoader(json_file).get_data(),
        n_clusters=default_clusters,
        plots=is_plots,
    )
|
|
|
|
|
|
if __name__ == '__main__':
    # Exactly one argument is expected: the path to the raw dataset JSON file.
    if len(sys.argv) != 2:
        print('You must specify the raw_dataset json file', file=sys.stderr)
        sys.exit(1)

    if not os.path.isfile(sys.argv[1]):
        print(f'File {sys.argv[1]} is not exists', file=sys.stderr)
        # Stop here rather than letting __main fail on a file that does not exist.
        sys.exit(1)

    __main(sys.argv[1])
|