social-clusters/main.py

155 lines
5.1 KiB
Python
Raw Permalink Normal View History

2023-05-26 10:33:54 +04:00
#!/usr/bin/env python3
import os
import sys
2023-06-08 01:05:04 +04:00
from tkinter import Tk, Frame, Scrollbar, RIGHT, Y
2023-06-07 15:24:49 +04:00
from typing import List
2023-05-26 10:33:54 +04:00
2023-06-06 00:33:45 +04:00
import numpy
2023-06-07 15:24:49 +04:00
import numpy as np
2023-06-06 00:33:45 +04:00
import pandas as pd
2023-06-07 15:24:49 +04:00
import scipy.cluster.hierarchy as sc
2023-06-08 01:05:04 +04:00
from anytree import RenderTree
2023-06-06 00:33:45 +04:00
from matplotlib import pyplot as plt
2023-06-07 15:24:49 +04:00
from numpy import ndarray
from pandas import Series
2023-06-06 00:33:45 +04:00
from sklearn.cluster import AgglomerativeClustering
from sklearn.decomposition import PCA
2023-06-05 18:18:18 +04:00
from src.main.df_loader import DfLoader
2023-06-07 15:24:49 +04:00
from src.main.georeverse import Georeverse
2023-06-08 01:05:04 +04:00
from src.main.tree_node import TreeNode
from src.main.tree_view import TreeView
2023-05-29 22:56:53 +04:00
2023-06-08 01:05:04 +04:00
# Depth at which __clustering stops recursing.
MAX_LEVEL: int = 5

# Display names for the numeric sex codes 0, 1 and 2
# (Russian for "not specified", "female", "male").
GENDERS: dict = dict(enumerate(('не указан', 'женский', 'мужской')))

# Shared reverse-geocoder used to name location clusters.
georeverse: Georeverse = Georeverse()

# Switch to True to show diagnostic plots for every clustering step.
is_plots: bool = False
2023-06-06 00:33:45 +04:00
2023-06-08 01:05:04 +04:00
def __plots(data: ndarray, labels: ndarray, level: int) -> None:
    """Show diagnostic plots for one clustering step.

    Draws a truncated Ward dendrogram of ``data``. When ``data`` has more
    than one column, it also draws a 2-D PCA projection of the samples
    coloured by ``labels``. Blocks in ``plt.show()`` until the window is
    closed.

    :param data: the samples that were clustered, one row per sample.
    :param labels: cluster label of every row of ``data``.
    :param level: clustering level, shown as the figure title.
    """
    # Ward linkage and a 2-component PCA both need at least two samples;
    # with fewer there is nothing meaningful to draw, so skip instead of raising.
    if len(data) < 2:
        return

    has_projection = data.shape[1] > 1
    plt.figure(figsize=(12, 6) if has_projection else (6, 6))
    plt.suptitle(f'Level {level}')

    if has_projection:
        plt.subplot(1, 2, 1)
    sc.dendrogram(sc.linkage(data, method='ward'), p=4, truncate_mode='level')
    plt.title('Dendrogram')

    if has_projection:
        # fit_transform already returns an ndarray; no DataFrame round-trip needed.
        transformed = PCA(n_components=2).fit_transform(data)
        plt.subplot(1, 2, 2)
        plt.scatter(x=transformed[:, 0], y=transformed[:, 1], c=labels, cmap='rainbow')
        plt.title('Clustering')

    plt.show()
2023-06-06 00:33:45 +04:00
2023-06-07 15:24:49 +04:00
def __get_cluster_centers(data: ndarray, labels: ndarray) -> ndarray:
    """Return the mean row of ``data`` for every distinct cluster label.

    Centers are ordered by ascending label (``np.unique`` returns sorted
    values). When the labels are 0..k-1, row ``i`` of the result is
    therefore the center of cluster ``i``, which is how callers index it.

    :param data: 2-D array, one row per sample.
    :param labels: cluster label of every row of ``data``.
    :return: 2-D array with one center row per distinct label.
    """
    return np.array([data[labels == label].mean(axis=0) for label in np.unique(labels)])
2023-06-06 00:33:45 +04:00
2023-06-08 01:05:04 +04:00
def __format_center(center: ndarray, level: int) -> str:
    """Render a cluster center as the label of its tree node.

    :param center: cluster center in the feature space used at ``level``.
    :param level: clustering level; 1 = location, 2 = sex, 3 = age.
    :return: city name (level 1), gender name from ``GENDERS`` (level 2),
        or the rounded age as an integer string (level 3).
    :raises ValueError: for any other level.
    """
    if level == 1:
        # NOTE(review): assumes center[0], center[1] are in the coordinate
        # order Georeverse.get_city expects — confirm.
        return str(georeverse.get_city(center[0], center[1]))
    if level == 2:
        # Mean of the sex codes, rounded to the nearest code; int() gives an exact key.
        return GENDERS[int(round(center[0]))]
    if level == 3:
        # int() so the label reads "25" even if round() yields a NumPy float.
        return str(int(round(center[0])))
    raise ValueError(f'Unknown level {level}')
2023-06-07 15:24:49 +04:00
2023-06-08 01:05:04 +04:00
def __clustering(data: ndarray, root_node: TreeNode, plots: bool = False, level: int = 1) -> None:
    """Recursively cluster ``data`` and attach one child node per cluster to ``root_node``.

    Every level works on its own column(s) of ``data``:

    * level 1 - columns 0-1, up to 3 clusters, nodes labelled by city;
    * level 2 - column 2, up to 2 clusters, nodes labelled by gender;
    * level 3 - column 3, up to 7 clusters, nodes labelled by rounded age;
    * level 4 - no clustering: four leaf nodes counting the rows whose
      columns 4, 5, 6 and 7 equal 1.

    Recursion stops at ``MAX_LEVEL`` or as soon as a level yields a single cluster.

    :param data: feature matrix of the rows that belong to ``root_node``.
    :param root_node: tree node that receives the new child nodes.
    :param plots: show diagnostic plots (via ``__plots``) for every clustering step.
    :param level: current recursion depth, starting at 1.
    :raises ValueError: if ``level`` has no feature configuration.
    """
    if level == MAX_LEVEL:
        return

    if level == 4:
        # Education/occupation flags: one leaf per flag column, sized by how
        # many rows have that flag set to 1.
        flags = (
            (1, 4, 'Высшее образование'),
            (2, 5, 'Работает'),
            (3, 6, 'Студент'),
            (4, 7, 'Школьник'),
        )
        for node_id, column, title in flags:
            TreeNode(node_id, title, int(np.count_nonzero(data[:, column] == 1)), parent=root_node)
        return

    if level == 1:
        cl_data, clusters = data[:, 0:2], 3
    elif level == 2:
        cl_data, clusters = data[:, 2].reshape(-1, 1), 2
    elif level == 3:
        cl_data, clusters = data[:, 3].reshape(-1, 1), 7
    else:
        raise ValueError(f'Unknown level {level}')

    if len(cl_data) == 0:
        # Nothing to cluster; the single-label fallback below would index an empty array.
        return

    if len(cl_data) > 1:
        # Agglomerative clustering cannot extract more clusters than there are
        # samples, so cap the count for small subsets (e.g. 3 rows at the age level).
        model = AgglomerativeClustering(n_clusters=min(clusters, len(cl_data)), metric='euclidean', linkage='ward')
        labels = model.fit(cl_data).labels_
    else:
        labels = np.array([0])

    if plots and len(cl_data) > 1:
        # The dendrogram/PCA plots need at least two samples.
        __plots(cl_data, labels, level)

    # Centers are expected in ascending label order, so the row index is the label.
    centers = __get_cluster_centers(cl_data, labels)
    nodes: dict = {}
    for label, center in enumerate(centers):
        size = int(np.count_nonzero(labels == label))
        nodes[label] = TreeNode(label, __format_center(center, level), size, parent=root_node)

    if len(nodes) == 1:
        return
    # Iterate the clusters actually produced, which may be fewer than requested.
    for label, node in nodes.items():
        __clustering(data[labels == label], node, plots, level + 1)
def __tree_sort(items):
    """Return ``items`` as a new list ordered by each node's ``data`` attribute.

    Used as the ``childiter`` callback of ``RenderTree`` so that siblings are
    printed in sorted order. The sort is stable: nodes with equal ``data``
    keep their original relative order.
    """
    def by_data(node):
        return node.data

    ordered = list(items)
    ordered.sort(key=by_data)
    return ordered
2023-06-06 00:33:45 +04:00
2023-05-29 22:56:53 +04:00
def __main(json_file):
    """Build the cluster tree for ``json_file``, print it, then show it in a Tk window.

    :param json_file: path to the raw dataset JSON file, loaded through ``DfLoader``.
    """
    data: ndarray = DfLoader(json_file).get_data()
    tree_root: TreeNode = TreeNode(0, 'ROOT', len(data))
    __clustering(data, tree_root, is_plots)
    # Print the rendered tree exactly once; siblings are ordered by __tree_sort.
    print('\n'.join(f'{pre}{node}' for pre, _, node in RenderTree(tree_root, childiter=__tree_sort)))
    __show_tree_window(tree_root)


def __show_tree_window(tree_root) -> None:
    """Display ``tree_root`` in a scrollable ``TreeView`` and run the Tk main loop (blocks)."""
    root = Tk(className='Clustering')
    root.geometry('800x600')
    frame = Frame(root)
    frame.grid(column=0, row=0, sticky="nsew")
    scrollbar = Scrollbar(frame)
    scrollbar.pack(side=RIGHT, fill=Y)
    tv = TreeView(frame, tree_root)
    tv.generate()
    scrollbar.config(command=tv.get().yview)
    # Let the single grid cell (and so the frame) grow with the window.
    root.rowconfigure(0, weight=1)
    root.columnconfigure(0, weight=1)
    root.mainloop()
2023-05-26 10:33:54 +04:00
if __name__ == '__main__':
    # Expect exactly one argument: the path to the dataset JSON file.
    if len(sys.argv) != 2:
        print('You must specify the raw_dataset json file', file=sys.stderr)
        sys.exit(1)
    if not os.path.isfile(sys.argv[1]):
        print(f'File {sys.argv[1]} is not exists', file=sys.stderr)
        # Stop here rather than letting the loader fail on a missing file.
        sys.exit(1)
    __main(sys.argv[1])