155 lines
5.1 KiB
Python
155 lines
5.1 KiB
Python
#!/usr/bin/env python3
|
|
import os
|
|
import sys
|
|
from tkinter import Tk, Frame, Scrollbar, RIGHT, Y
|
|
from typing import List
|
|
|
|
import numpy
|
|
import numpy as np
|
|
import pandas as pd
|
|
import scipy.cluster.hierarchy as sc
|
|
from anytree import RenderTree
|
|
from matplotlib import pyplot as plt
|
|
from numpy import ndarray
|
|
from pandas import Series
|
|
from sklearn.cluster import AgglomerativeClustering
|
|
from sklearn.decomposition import PCA
|
|
|
|
from src.main.df_loader import DfLoader
|
|
from src.main.georeverse import Georeverse
|
|
from src.main.tree_node import TreeNode
|
|
from src.main.tree_view import TreeView
|
|
|
|
# Depth at which __clustering stops recursing.
MAX_LEVEL: int = 5

# Sex code -> label shown in the tree (looked up from a rounded cluster centre).
GENDERS: dict = dict(enumerate(('не указан', 'женский', 'мужской')))

# Shared reverse-geocoder used to turn level-1 cluster centres into city names.
georeverse: Georeverse = Georeverse()

# When True, __clustering shows diagnostic plots for every clustering step.
is_plots: bool = False
|
|
|
|
|
|
def __plots(data: ndarray, labels: ndarray, level: int) -> None:
    """Show diagnostic plots for one clustering step.

    Always draws a Ward-linkage dendrogram of ``data`` (truncated to 4 levels). When ``data``
    has more than one column, it also draws, next to the dendrogram, a 2-D PCA projection of
    the rows coloured by ``labels``.

    :param data: 2-D array (rows are samples) that was clustered.
    :param labels: cluster label of each row, used as scatter colours.
    :param level: tree level being clustered; echoed to stdout.
    """
    print(level)
    # Ward linkage and a 2-component PCA both need at least two samples, and __clustering
    # can hand us a single-row group; there is nothing meaningful to plot in that case.
    if data.shape[0] < 2:
        return

    has_many_features = data.shape[1] > 1
    plt.figure(figsize=(12, 6) if has_many_features else (6, 6))
    if has_many_features:
        plt.subplot(1, 2, 1)
    sc.dendrogram(sc.linkage(data, method='ward'), p=4, truncate_mode='level')
    plt.title('Dendrogram')

    if has_many_features:
        # Project onto two principal components so multi-feature data can be scattered in 2-D.
        transformed = PCA(n_components=2).fit_transform(data)
        plt.subplot(1, 2, 2)
        plt.scatter(x=transformed[:, 0], y=transformed[:, 1], c=labels, cmap='rainbow')
        plt.title('Clustering')

    plt.show()
|
|
|
|
|
|
def __get_cluster_centers(data: ndarray, labels: ndarray) -> ndarray:
    """Return the mean point of every cluster, one row per distinct label in ascending order.

    :param data: 2-D array of the clustered samples (one row per sample).
    :param labels: cluster label of each row of ``data``.
    :return: array with one row per distinct label; row ``i`` is the centre of the ``i``-th
        smallest label (an empty 1-D array when ``labels`` is empty).
    """
    # np.unique returns the labels sorted. Callers pair row i of the result with label i,
    # and iterating a set() does not guarantee that order.
    return np.array([data[labels == label].mean(axis=0) for label in np.unique(labels)])
|
|
|
|
|
|
def __format_center(center: ndarray, level: int) -> str:
    """Render a cluster centre as the label of its tree node.

    * level 1: ``center[0]`` and ``center[1]`` are passed to ``georeverse.get_city``; the
      result is converted to ``str``.
    * level 2: ``center[0]`` is rounded to the nearest integer and looked up in ``GENDERS``.
    * level 3: ``center[0]`` (age) is rounded to the nearest integer.

    Level 4 has no centres; ``__clustering`` builds its nodes directly.

    :param center: cluster centre for the given level.
    :param level: tree level the centre belongs to.
    :return: human-readable label.
    :raises ValueError: for any other level.
    """
    if level == 1:
        return str(georeverse.get_city(center[0], center[1]))
    if level == 2:
        return GENDERS[int(round(center[0]))]
    if level == 3:
        # int() keeps the label "25" rather than "25.0" if round() yields a NumPy float.
        return str(int(round(center[0])))
    raise ValueError(f'Unknown level {level}')
|
|
|
|
|
|
def __clustering(data: ndarray, root_node: TreeNode, plots: bool = False, level: int = 1) -> None:
    """Recursively cluster ``data`` and attach the resulting groups as children of ``root_node``.

    Each level clusters on different columns of ``data``:

    * level 1: columns 0:2 (later passed to ``georeverse.get_city``), up to 3 clusters;
    * level 2: column 2 (code looked up in ``GENDERS``), up to 2 clusters;
    * level 3: column 3 (age), up to 7 clusters;
    * level 4: no clustering; one child per flag column 4..7, sized by how many rows equal 1.

    Every cluster becomes a ``TreeNode``, labelled by ``__format_center`` and sized by its row
    count. The rows of each cluster are then clustered again at ``level + 1``. Recursion stops
    at ``MAX_LEVEL``, after level 4, or when a level yields a single cluster.

    :param data: 2-D array holding all feature columns, one row per record.
    :param root_node: node that receives this level's children.
    :param plots: if True, show diagnostic plots for every clustering step.
    :param level: current tree depth, starting at 1.
    :raises ValueError: for a level with no clustering configuration.
    """
    if level == MAX_LEVEL:
        return

    if level == 4:
        # Leaf level: count rows whose flag column equals 1 instead of clustering.
        flag_columns = ((4, 'Высшее образование'), (5, 'Работает'), (6, 'Студент'), (7, 'Школьник'))
        for node_id, (column, title) in enumerate(flag_columns, start=1):
            TreeNode(node_id, title, int(np.count_nonzero(data[:, column] == 1)), parent=root_node)
        return

    if level == 1:
        cl_data, clusters = data[:, 0:2], 3
    elif level == 2:
        cl_data, clusters = data[:, 2].reshape(-1, 1), 2
    elif level == 3:
        cl_data, clusters = data[:, 3].reshape(-1, 1), 7
    else:
        raise ValueError(f'Unknown level {level}')

    n_samples = len(cl_data)
    if n_samples == 0:
        # Nothing to group (e.g. an empty dataset); leave root_node without children.
        return

    if n_samples > 1:
        # AgglomerativeClustering rejects n_clusters > n_samples. Small groups deep in the tree
        # hit this (e.g. 5 records reaching the 7-cluster age level), so cap the cluster count.
        model = AgglomerativeClustering(n_clusters=min(clusters, n_samples), metric='euclidean', linkage='ward')
        labels = model.fit(cl_data).labels_
    else:
        labels = np.array([0])

    if plots:
        __plots(cl_data, labels, level)

    centers = __get_cluster_centers(cl_data, labels)
    # Keyed by cluster label; row i of `centers` is taken to be the centre of label i.
    nodes: dict = {}
    for index, center in enumerate(centers):
        size = int(np.count_nonzero(labels == index))
        nodes[index] = TreeNode(index, __format_center(center, level), size, parent=root_node)

    if len(centers) == 1:
        return

    # Recurse only into clusters that were actually produced (may be fewer than `clusters`).
    for index, node in nodes.items():
        __clustering(data[labels == index], node, plots, level + 1)
|
|
|
|
|
|
def __tree_sort(items):
    """Return a new list of ``items`` ordered by each node's ``data`` attribute.

    The sort is stable, so nodes with equal ``data`` keep their relative order. This function
    is used as ``RenderTree``'s ``childiter`` to order children when printing the tree.
    """
    ordered = list(items)
    ordered.sort(key=lambda node: node.data)
    return ordered
|
|
|
|
|
|
def __main(json_file):
    """Load the dataset, build the cluster tree, print it, and show it in a Tk window.

    :param json_file: path to the raw dataset JSON, handed to ``DfLoader``.
    """
    data: ndarray = DfLoader(json_file).get_data()
    tree_root: TreeNode = TreeNode(0, 'ROOT', len(data))
    __clustering(data, tree_root, is_plots)

    # Console dump of the tree; children are ordered by __tree_sort.
    print('\n'.join(f'{prefix}{node}' for prefix, _, node in RenderTree(tree_root, childiter=__tree_sort)))

    root = Tk(className='Clustering')
    root.geometry('800x600')
    # Let the single grid cell (and therefore the frame) absorb window resizes.
    root.rowconfigure(0, weight=1)
    root.columnconfigure(0, weight=1)

    frame = Frame(root)
    frame.grid(column=0, row=0, sticky="nsew")

    scrollbar = Scrollbar(frame)
    scrollbar.pack(side=RIGHT, fill=Y)

    tv = TreeView(frame, tree_root)
    tv.generate()
    tree_widget = tv.get()
    # Wire both directions: the scrollbar drives the view, and the view moves the scrollbar thumb.
    # NOTE(review): assumes tv.get() is a standard scrollable Tk widget (it exposes yview).
    scrollbar.config(command=tree_widget.yview)
    tree_widget.configure(yscrollcommand=scrollbar.set)

    root.mainloop()
|
|
|
|
|
|
if __name__ == '__main__':
    # Expect exactly one CLI argument: the path to the raw dataset JSON file.
    if len(sys.argv) != 2:
        print('You must specify the raw_dataset json file', file=sys.stderr)
        sys.exit(1)

    # Stop here rather than letting __main fail later on a missing file.
    if not os.path.isfile(sys.argv[1]):
        print(f'File {sys.argv[1]} is not exists', file=sys.stderr)
        sys.exit(1)

    __main(sys.argv[1])
|