social-clusters/main.py

155 lines
5.1 KiB
Python

#!/usr/bin/env python3
import os
import sys
from tkinter import Tk, Frame, Scrollbar, RIGHT, Y
from typing import List
import numpy
import numpy as np
import pandas as pd
import scipy.cluster.hierarchy as sc
from anytree import RenderTree
from matplotlib import pyplot as plt
from numpy import ndarray
from pandas import Series
from sklearn.cluster import AgglomerativeClustering
from sklearn.decomposition import PCA
from src.main.df_loader import DfLoader
from src.main.georeverse import Georeverse
from src.main.tree_node import TreeNode
from src.main.tree_view import TreeView
# Depth at which __clustering stops recursing.
MAX_LEVEL: int = 5
# Labels shown for the rounded sex code at clustering level 2
# (0 = "not specified", 1 = "female", 2 = "male").
GENDERS: dict = dict(enumerate(('не указан', 'женский', 'мужской')))
# Shared reverse geocoder that names level-1 cluster centers (see __format_center).
georeverse: Georeverse = Georeverse()
# Passed by __main to __clustering; when True every clustering step is plotted.
is_plots: bool = False
def __plots(data: ndarray, labels: ndarray, level: int) -> None:
    """Display diagnostics for one clustering step.

    Always draws a truncated Ward dendrogram of ``data``. When ``data`` has more
    than one column, a second panel shows the rows projected to 2-D with PCA,
    coloured by ``labels``.

    :param data: the samples that were clustered, one row per sample.
    :param labels: cluster label of each row (scatter colours).
    :param level: clustering level; printed to stdout before plotting.
    """
    print(level)
    has_many_features = data.shape[1] > 1

    plt.figure(figsize=(12, 6) if has_many_features else (6, 6))
    if has_many_features:
        plt.subplot(1, 2, 1)
    linkage_matrix = sc.linkage(data, method='ward')
    sc.dendrogram(linkage_matrix, p=4, truncate_mode='level')
    plt.title('Dendrogram')

    if has_many_features:
        projected = pd.DataFrame(PCA(n_components=2).fit_transform(data)).to_numpy()
        plt.subplot(1, 2, 2)
        plt.scatter(x=projected[:, 0], y=projected[:, 1], c=labels, cmap='rainbow')
        plt.title('Clustering')

    plt.show()
def __get_cluster_centers(data: ndarray, labels: ndarray) -> ndarray:
    """Return the mean point (centroid) of every cluster.

    Rows of the result are ordered by ascending label value, so when the
    labels are ``0..k-1`` row ``i`` is the centroid of cluster ``i``. The
    caller relies on this to pair centers with labels by position.

    :param data: array with one row per sample.
    :param labels: 1-D array holding the cluster label of each row of ``data``.
    :return: array with one row per distinct label, holding that cluster's mean.
    """
    # np.unique yields the labels sorted. Iterating a ``set`` carries no
    # ordering guarantee, which could silently misalign centers and labels.
    return np.array([data[labels == label].mean(axis=0) for label in np.unique(labels)])
def __format_center(center: ndarray, level: int) -> str:
    """Turn a cluster center into the text shown on its tree node.

    :param center: centroid of one cluster at ``level``.
    :param level: 1 - the first two coordinates are passed to
        ``georeverse.get_city``; 2 - the sex code is rounded and looked up in
        ``GENDERS``; 3 - the age is rounded to a whole number.
    :return: the label text.
    :raises ValueError: for any other level.
    """
    if level == 1:
        return str(georeverse.get_city(center[0], center[1]))
    if level == 2:
        # The mean of the codes is rounded back to the nearest code.
        sex = int(round(center[0]))
        return GENDERS[sex]
    if level == 3:
        # int() keeps the label as "25" rather than "25.0" even on NumPy
        # versions where round() of a NumPy float returns a float.
        age = int(round(center[0]))
        return str(age)
    raise ValueError(f'Unknown level {level}')
def __clustering(data: ndarray, root_node: TreeNode, plots: bool = False, level: int = 1) -> None:
    """Recursively split ``data`` into clusters and attach them under ``root_node``.

    Each level looks at different columns of ``data``:

    * level 1 - columns 0-1, up to 3 clusters;
    * level 2 - column 2, up to 2 clusters;
    * level 3 - column 3, up to 7 clusters;
    * level 4 - no clustering: four child nodes holding, for columns 4-7,
      the number of rows whose value in that column equals 1.

    Each cluster becomes a ``TreeNode``. Its id is the cluster label, its text
    comes from ``__format_center``, and its size is the number of rows. Each
    cluster is then clustered again at ``level + 1`` until ``MAX_LEVEL`` is
    reached. A level that yields a single cluster is not split further.

    :param data: array with one row per sample (column layout as listed above).
    :param root_node: node that receives this level's children.
    :param plots: show a dendrogram/scatter plot for each clustering step.
    :param level: current depth, starting at 1.
    :raises ValueError: if ``level`` is not one of the levels above.
    """
    if level >= MAX_LEVEL:
        return

    if level == 4:
        # Count rows equal to 1 in each of columns 4-7.
        univer: int = int(np.count_nonzero(data[:, 4] == 1))
        work: int = int(np.count_nonzero(data[:, 5] == 1))
        student: int = int(np.count_nonzero(data[:, 6] == 1))
        schoolboy: int = int(np.count_nonzero(data[:, 7] == 1))
        TreeNode(1, 'Высшее образование', univer, parent=root_node)  # "Higher education"
        TreeNode(2, 'Работает', work, parent=root_node)  # "Employed"
        TreeNode(3, 'Студент', student, parent=root_node)  # "Student"
        TreeNode(4, 'Школьник', schoolboy, parent=root_node)  # "Schoolchild"
        return

    # Columns to cluster on and the requested number of clusters per level.
    if level == 1:
        cl_data, clusters = data[:, 0:2], 3
    elif level == 2:
        cl_data, clusters = data[:, 2].reshape(-1, 1), 2
    elif level == 3:
        cl_data, clusters = data[:, 3].reshape(-1, 1), 7
    else:
        raise ValueError(f'Unknown level {level}')

    if len(cl_data) == 0:
        # Nothing to cluster (e.g. an empty dataset at the root).
        return

    # AgglomerativeClustering cannot produce more clusters than samples.
    n_clusters = min(clusters, len(cl_data))
    if n_clusters > 1:
        model = AgglomerativeClustering(n_clusters=n_clusters, metric='euclidean', linkage='ward')
        model.fit(cl_data)
        labels = model.labels_
    else:
        labels = np.zeros(len(cl_data), dtype=int)

    if plots and len(cl_data) > 1:
        # scipy's linkage (used for the dendrogram) needs at least two samples.
        __plots(cl_data, labels, level)

    centers = __get_cluster_centers(cl_data, labels)
    nodes: dict = {}
    # Labels are 0..k-1; centers[i] is taken to be the center of label i.
    for index, center in enumerate(centers):
        size: int = int(np.count_nonzero(labels == index))
        nodes[index] = TreeNode(index, __format_center(center, level), size, parent=root_node)

    if len(nodes) == 1:
        return
    # Recurse over the clusters actually produced (may be fewer than requested).
    for index, node in nodes.items():
        __clustering(data[labels == index], node, plots, level + 1)
def __tree_sort(items):
    """Return ``items`` as a new list ordered by each item's ``data`` attribute.

    Used as the ``childiter`` of ``RenderTree`` so siblings print in a stable
    order; the sort is stable, so items with equal ``data`` keep their order.
    """
    def by_data(node):
        return node.data

    ordered = list(items)
    ordered.sort(key=by_data)
    return ordered
def __main(json_file):
    """Load ``json_file``, build the cluster tree, print it and show it in a Tk window.

    :param json_file: path handed to ``DfLoader``.
    """
    data: ndarray = DfLoader(json_file).get_data()
    tree_root: TreeNode = TreeNode(0, 'ROOT', len(data))
    __clustering(data, tree_root, is_plots)

    # Text rendering of the tree, siblings ordered by __tree_sort.
    # (The previous print(print(...)) also emitted a stray "None" line.)
    print('\n'.join(f'{pre}{node}' for pre, _fill, node in RenderTree(tree_root, childiter=__tree_sort)))

    root = Tk(className='Clustering')
    root.geometry('800x600')
    frame = Frame(root)
    frame.grid(column=0, row=0, sticky="nsew")
    scrollbar = Scrollbar(frame)
    scrollbar.pack(side=RIGHT, fill=Y)
    tv = TreeView(frame, tree_root)
    tv.generate()
    scrollbar.config(command=tv.get().yview)
    # Let the frame grow with the window.
    root.rowconfigure(0, weight=1)
    root.columnconfigure(0, weight=1)
    root.mainloop()
if __name__ == '__main__':
    # Exactly one argument is expected: the path of the raw dataset JSON file.
    if len(sys.argv) != 2:
        print('You must specify the raw_dataset json file', file=sys.stderr)
        sys.exit(1)
    if not os.path.isfile(sys.argv[1]):
        print(f'File {sys.argv[1]} does not exist', file=sys.stderr)
        # Stop here instead of letting __main fail on a missing path.
        sys.exit(1)
    __main(sys.argv[1])