social-clusters/main.py

80 lines
2.4 KiB
Python
Raw Normal View History

2023-05-26 10:33:54 +04:00
#!/usr/bin/env python3
import os
import sys
2023-06-07 15:24:49 +04:00
from typing import List
2023-05-26 10:33:54 +04:00
2023-06-06 00:33:45 +04:00
import numpy
2023-06-07 15:24:49 +04:00
import numpy as np
2023-06-06 00:33:45 +04:00
import pandas as pd
2023-06-07 15:24:49 +04:00
import scipy.cluster.hierarchy as sc
2023-06-06 00:33:45 +04:00
from matplotlib import pyplot as plt
2023-06-07 15:24:49 +04:00
from numpy import ndarray
from pandas import Series
2023-06-06 00:33:45 +04:00
from sklearn.cluster import AgglomerativeClustering
from sklearn.decomposition import PCA
2023-06-05 18:18:18 +04:00
from src.main.df_loader import DfLoader
2023-06-07 15:24:49 +04:00
from src.main.georeverse import Georeverse
2023-05-29 22:56:53 +04:00
2023-06-07 15:24:49 +04:00
# When True, __clustering also calls __plots (dendrogram + PCA scatter).
is_plots: bool = False
# Cluster count passed by __main to __clustering as n_clusters.
default_clusters: int = 3
# Shared helper, constructed once at import time; __print_center calls its
# get_city() to name each centroid's location (presumably reverse geocoding
# of coordinates — confirm in src.main.georeverse).
georeverse: Georeverse = Georeverse()
2023-05-29 22:56:53 +04:00
2023-06-06 00:33:45 +04:00
2023-06-07 15:24:49 +04:00
def __plots(data: ndarray, labels: ndarray) -> None:
    """Show a truncated Ward dendrogram and a 2-D PCA scatter of the clustering.

    Draws both panels side by side on one 12x6 figure, then calls ``plt.show()``.

    :param data: samples to plot, one row per sample.
    :param labels: cluster label of each row of ``data``; used as scatter colours.
    """
    plt.figure(figsize=(12, 6))

    # Left panel: Ward-linkage dendrogram, truncated to the top 4 merge levels
    # so it stays readable on larger datasets.
    plt.subplot(1, 2, 1)
    sc.dendrogram(sc.linkage(data, method='ward'), p=4, truncate_mode='level')
    plt.title('Dendrogram')

    # Right panel: project onto the first two principal components for display.
    # PCA.fit_transform already returns an ndarray, so no DataFrame round-trip
    # is needed before slicing columns.
    transformed = PCA(n_components=2).fit_transform(data)
    plt.subplot(1, 2, 2)
    plt.scatter(x=transformed[:, 0], y=transformed[:, 1], c=labels, cmap='rainbow')
    plt.title('Clustering')
    plt.show()
2023-06-06 00:33:45 +04:00
2023-06-07 15:24:49 +04:00
def __get_cluster_centers(data: ndarray, labels: ndarray) -> ndarray:
    """Return the centroid (column-wise mean) of every cluster.

    :param data: samples, one row per sample.
    :param labels: cluster label of each row of ``data`` (same length as ``data``).
    :return: array with one row per distinct label, ordered by ascending label;
        each row is the mean of the ``data`` rows carrying that label. Empty
        ``labels`` yield an empty array.
    """
    # np.unique returns the distinct labels sorted, so the row order of the
    # result is deterministic (iteration order of a Python set is not guaranteed).
    return np.array([data[labels == label].mean(axis=0) for label in np.unique(labels)])
2023-06-06 00:33:45 +04:00
2023-06-07 15:24:49 +04:00
def __print_center(center: ndarray) -> None:
    """Print one cluster centroid as a single readable line.

    Columns are read by position: 0 and 1 go to ``georeverse.get_city``
    (presumably latitude/longitude — confirm against DfLoader's column order),
    2 is sex, 3 is age, and 4-7 are the university / work / student / school
    indicators. Sex and age are rounded to integers; the indicator columns are
    rounded and then turned into booleans.
    """
    location: str = georeverse.get_city(center[0], center[1])
    sex, age = round(center[2]), round(center[3])
    # Columns 4..7, in order: university, work, student, schoolboy.
    is_university, is_work, is_student, is_schoolboy = (
        bool(round(center[col])) for col in range(4, 8)
    )
    print(f'location: {location}, sex: {sex}, age: {age},'
          f' univer: {is_university}, work: {is_work}, student: {is_student}, school: {is_schoolboy}')
def __clustering(data: ndarray, n_clusters: int = 3, plots: bool = False) -> None:
    """Cluster ``data`` with Ward-linkage agglomerative clustering and report it.

    :param data: samples, one row per sample.
    :param n_clusters: number of clusters to form.
    :param plots: when True, also show the dendrogram / PCA plots via ``__plots``.

    Prints one summary line per cluster centroid via ``__print_center``.
    """
    clusterer = AgglomerativeClustering(n_clusters=n_clusters, metric='euclidean', linkage='ward')
    # fit_predict fits the model and returns its labels_ in one call.
    cluster_ids = clusterer.fit_predict(data)
    if plots:
        __plots(data, cluster_ids)
    for centroid in __get_cluster_centers(data, cluster_ids):
        __print_center(centroid)
2023-06-06 00:33:45 +04:00
2023-05-29 22:56:53 +04:00
def __main(json_file):
    """Load the dataset from ``json_file`` and run the clustering report on it.

    Uses the module-level ``default_clusters`` and ``is_plots`` settings.
    """
    __clustering(DfLoader(json_file).get_data(), default_clusters, is_plots)
2023-05-26 10:33:54 +04:00
if __name__ == '__main__':
    # Exactly one CLI argument is expected: the path to the raw dataset JSON file.
    if len(sys.argv) != 2:
        print('You must specify the raw_dataset json file', file=sys.stderr)
        sys.exit(1)
    if not os.path.isfile(sys.argv[1]):
        print(f'File {sys.argv[1]} does not exist', file=sys.stderr)
        # Stop here; otherwise __main would try to load a missing file.
        sys.exit(1)
    __main(sys.argv[1])