#!/usr/bin/env python3
"""Agglomerative (Ward) clustering of a dataset loaded from a JSON file.

Usage: pass the path of the raw dataset JSON file as the only command-line
argument. The script clusters the rows, optionally plots the result and
prints one summary line per cluster centre.
"""
import os
import sys
from typing import List

import numpy
import numpy as np
import pandas as pd
import scipy.cluster.hierarchy as sc
from matplotlib import pyplot as plt
from numpy import ndarray
from pandas import Series
from sklearn.cluster import AgglomerativeClustering
from sklearn.decomposition import PCA

from src.main.df_loader import DfLoader
from src.main.georeverse import Georeverse

# Whether to show the dendrogram / scatter plots after clustering.
is_plots: bool = False
# Number of clusters requested when the script is run from the command line.
default_clusters: int = 3
# Shared reverse-geocoding helper used when printing cluster centres.
georeverse: Georeverse = Georeverse()


def __plots(data: ndarray, labels: ndarray) -> None:
    """Show a Ward dendrogram and a 2-component PCA scatter of the clusters.

    :param data: matrix that was clustered (one row per sample).
    :param labels: cluster label of every row of ``data``; used as the
        colour of the scatter points.
    """
    plt.figure(figsize=(12, 6))

    # Left panel: dendrogram truncated to the top 4 levels.
    plt.subplot(1, 2, 1)
    sc.dendrogram(sc.linkage(data, method='ward'), p=4, truncate_mode='level')
    plt.title('Dendrogram')

    # Right panel: samples projected onto the first two principal components.
    pca = PCA(n_components=2)
    # asarray keeps positional indexing valid whatever container the
    # transformer returns.
    transformed = np.asarray(pca.fit_transform(data))
    plt.subplot(1, 2, 2)
    plt.scatter(x=transformed[:, 0], y=transformed[:, 1], c=labels, cmap='rainbow')
    plt.title('Clustering')

    plt.show()


def __get_cluster_centers(data: ndarray, labels: ndarray) -> ndarray:
    """Return the per-column mean of the rows belonging to each cluster.

    :param data: matrix that was clustered (one row per sample).
    :param labels: cluster label of every row of ``data``.
    :return: array with one row per distinct label, ordered by ascending
        label value.
    """
    # np.unique yields the labels sorted, so the order of the centres is
    # deterministic (iterating a plain ``set`` gives no such guarantee).
    centers: List[List[float]] = [
        list(data[labels == label].mean(axis=0))
        for label in np.unique(labels)
    ]
    return np.array(centers)


def __print_center(center: ndarray) -> None:
    """Print a one-line, human-readable summary of a cluster centre.

    The centre is read positionally: elements 0 and 1 are handed to
    ``georeverse.get_city``, elements 2 and 3 are rounded and shown as sex
    and age, and elements 4-7 are rounded to booleans for the university /
    work / student / school flags.

    :param center: one row produced by ``__get_cluster_centers``.
    """
    # NOTE(review): elements 0/1 are presumably coordinates - confirm
    # against DfLoader / Georeverse.
    location: str = georeverse.get_city(center[0], center[1])
    sex = round(center[2])
    age = round(center[3])
    is_university = bool(round(center[4]))
    is_work = bool(round(center[5]))
    is_student = bool(round(center[6]))
    is_schoolboy = bool(round(center[7]))
    print(f'location: {location}, sex: {sex}, age: {age},'
          f' univer: {is_university}, work: {is_work}, student: {is_student}, school: {is_schoolboy}')


def __clustering(data: ndarray, n_clusters: int = 3, plots: bool = False) -> None:
    """Cluster ``data`` with Ward agglomerative clustering and print the centres.

    :param data: matrix to cluster (one row per sample).
    :param n_clusters: number of clusters to build.
    :param plots: when true, show the dendrogram / PCA plots before printing.
    """
    model = AgglomerativeClustering(n_clusters=n_clusters, metric='euclidean', linkage='ward')
    model.fit(data)
    labels = model.labels_

    if plots:
        __plots(data, labels)

    for center in __get_cluster_centers(data, labels):
        __print_center(center)


def __main(json_file: str) -> None:
    """Load the dataset from ``json_file`` and cluster it with the module defaults.

    :param json_file: path of the raw dataset JSON file.
    """
    data: ndarray = DfLoader(json_file).get_data()
    __clustering(data, default_clusters, is_plots)


if __name__ == '__main__':
    if len(sys.argv) != 2:
        print('You must specify the raw_dataset json file')
        # sys.exit instead of the site-provided ``exit`` helper, which is
        # not guaranteed to exist (e.g. under ``python -S``).
        sys.exit(1)

    if not os.path.isfile(sys.argv[1]):
        print(f'File {sys.argv[1]} is not exists')
        # Stop here: previously execution fell through and tried to load
        # the missing file anyway.
        sys.exit(1)

    __main(sys.argv[1])