From 155b350e1eaf1821d23b4c286e74abe7438ba499 Mon Sep 17 00:00:00 2001 From: Aleksey Filippov Date: Wed, 7 Jun 2023 15:24:49 +0400 Subject: [PATCH] Add cluster centers extraction --- main.py | 88 ++++++++++++++++++++++++------------------ src/main/georeverse.py | 12 ++++++ 2 files changed, 63 insertions(+), 37 deletions(-) create mode 100644 src/main/georeverse.py diff --git a/main.py b/main.py index 78dd017..568fff4 100644 --- a/main.py +++ b/main.py @@ -1,59 +1,73 @@ #!/usr/bin/env python3 import os import sys +from typing import List import numpy +import numpy as np import pandas as pd -# import scipy.cluster.hierarchy as sc +import scipy.cluster.hierarchy as sc from matplotlib import pyplot as plt -from pandas import DataFrame +from numpy import ndarray +from pandas import Series from sklearn.cluster import AgglomerativeClustering from sklearn.decomposition import PCA from src.main.df_loader import DfLoader +from src.main.georeverse import Georeverse + +is_plots: bool = False +default_clusters: int = 3 +georeverse: Georeverse = Georeverse() -def __clustering(data: DataFrame) -> None: - # clusters = round(math.sqrt(len(data) / 2)) - # plt.figure(figsize=(20, 7)) - # plt.title("Dendrograms") - # # Create dendrogram - # sc.dendrogram(sc.linkage(data.to_numpy(), method='ward')) - # plt.title('Dendrogram') - # plt.xlabel('Sample index') - # plt.ylabel('Euclidean distance') - - clusters = 3 - model = AgglomerativeClustering(n_clusters=clusters, metric='euclidean', linkage='ward') - model.fit(data) - labels = model.labels_ - - data_norm = (data - data.min()) / (data.max() - data.min()) - - pca = PCA(n_components=2) # 2-dimensional PCA - transformed = pd.DataFrame(pca.fit_transform(data_norm)) - # plt.scatter(x=transformed[:, 0], y=transformed[:, 1], c=labels, cmap='rainbow') - for i in range(clusters): - series = transformed.iloc[numpy.where(labels[:] == i)] - plt.scatter(series[0], series[1], label=f'Cluster {i + 1}') - plt.legend() +def __plots(data: ndarray, labels: ndarray) -> None: + plt.figure(figsize=(12, 6)) + plt.subplot(1, 2, 1) + sc.dendrogram(sc.linkage(data, method='ward'), p=4, truncate_mode='level') + plt.title('Dendrogram') + pca = PCA(n_components=2) + transformed = pd.DataFrame(pca.fit_transform(data)).to_numpy() + plt.subplot(1, 2, 2) + plt.scatter(x=transformed[:, 0], y=transformed[:, 1], c=labels, cmap='rainbow') + plt.title('Clustering') plt.show() - # fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(15, 5)) - # sns.scatterplot(ax=axes[0], data=data, x='location-la,location-lo', y='age,sex').set_title('Without clustering') - # sns.scatterplot(ax=axes[1], data=data, x='location-la,location-lo', y='age,sex', hue=labels) \ - # .set_title('With clustering') - # plt.show() - # s = numpy.where(labels[:] == 34) - # print(labels) +def __get_cluster_centers(data: ndarray, labels: ndarray) -> ndarray: + centers: List[List[float]] = list() + for label in set(labels): + center: Series = data[numpy.where(labels[:] == label)].mean(axis=0) + centers.append(list(center)) + return np.array(centers) + + +def __print_center(center: ndarray) -> None: + location: str = georeverse.get_city(center[0], center[1]) + sex = round(center[2]) + age = round(center[3]) + is_university = bool(round(center[4])) + is_work = bool(round(center[5])) + is_student = bool(round(center[6])) + is_schoolboy = bool(round(center[7])) + print(f'location: {location}, sex: {sex}, age: {age},' + f' univer: {is_university}, work: {is_work}, student: {is_student}, school: {is_schoolboy}') + + +def __clustering(data: ndarray, n_clusters: int = 3, plots: bool = False) -> None: + model = AgglomerativeClustering(n_clusters=n_clusters, metric='euclidean', linkage='ward') + model.fit(data) + labels = model.labels_ + if plots: + __plots(data, labels) + centers = __get_cluster_centers(data, labels) + for center in centers: + __print_center(center) def __main(json_file): - df_loader: DfLoader = DfLoader(json_file) - data = df_loader.get_clustering_data() - print(data) - __clustering(data) + data: ndarray = DfLoader(json_file).get_data() + __clustering(data, default_clusters, is_plots) if __name__ == '__main__': diff --git a/src/main/georeverse.py b/src/main/georeverse.py new file mode 100644 index 0000000..4ca8e10 --- /dev/null +++ b/src/main/georeverse.py @@ -0,0 +1,12 @@ +from functools import partial + +from geopy import Nominatim + + +class Georeverse: + def __init__(self) -> None: + geolocator: Nominatim = Nominatim(user_agent="MyApp") + self.__reverse = partial(geolocator.reverse, language="ru") + + def get_city(self, latitude: float, longitude: float) -> str: + return self.__reverse(f'{latitude}, {longitude}')