Complete main algorithm, add GUI

This commit is contained in:
Aleksey Filippov 2023-06-08 01:05:04 +04:00
parent bb539ade32
commit 7993a7cf19
3 changed files with 148 additions and 29 deletions

133
main.py
View File

@ -1,12 +1,14 @@
#!/usr/bin/env python3
import os
import sys
from tkinter import Tk, Frame, Scrollbar, RIGHT, Y
from typing import List
import numpy
import numpy as np
import pandas as pd
import scipy.cluster.hierarchy as sc
from anytree import RenderTree
from matplotlib import pyplot as plt
from numpy import ndarray
from pandas import Series
@ -15,22 +17,35 @@ from sklearn.decomposition import PCA
from src.main.df_loader import DfLoader
from src.main.georeverse import Georeverse
from src.main.tree_node import TreeNode
from src.main.tree_view import TreeView
is_plots: bool = False
default_clusters: int = 3
MAX_LEVEL: int = 5
GENDERS: dict = {
0: 'не указан',
1: 'женский',
2: 'мужской'
}
georeverse: Georeverse = Georeverse()
is_plots: bool = False
def __plots(data: ndarray, labels: ndarray) -> None:
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
def __plots(data: ndarray, labels: ndarray, level: int) -> None:
print(level)
if data.shape[1] > 1:
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
else:
plt.figure(figsize=(6, 6))
sc.dendrogram(sc.linkage(data, method='ward'), p=4, truncate_mode='level')
plt.title('Dendrogram')
pca = PCA(n_components=2)
transformed = pd.DataFrame(pca.fit_transform(data)).to_numpy()
plt.subplot(1, 2, 2)
plt.scatter(x=transformed[:, 0], y=transformed[:, 1], c=labels, cmap='rainbow')
plt.title('Clustering')
if data.shape[1] > 1:
pca = PCA(n_components=2)
transformed = pd.DataFrame(pca.fit_transform(data)).to_numpy()
plt.subplot(1, 2, 2)
plt.scatter(x=transformed[:, 0], y=transformed[:, 1], c=labels, cmap='rainbow')
plt.title('Clustering')
plt.show()
@ -42,32 +57,92 @@ def __get_cluster_centers(data: ndarray, labels: ndarray) -> ndarray:
return np.array(centers)
def __print_center(center: ndarray) -> None:
location: str = georeverse.get_city(center[0], center[1])
sex = round(center[2])
age = round(center[3])
is_university = bool(round(center[4]))
is_work = bool(round(center[5]))
is_student = bool(round(center[6]))
is_schoolboy = bool(round(center[7]))
print(f'location: {location}, sex: {sex}, age: {age},'
f' univer: {is_university}, work: {is_work}, student: {is_student}, school: {is_schoolboy}')
def __format_center(center: ndarray, level: int) -> str:
if level == 1:
location: str = str(georeverse.get_city(center[0], center[1]))
return location
if level == 3:
age = round(center[0])
return str(age)
if level == 2:
sex = round(center[0])
return GENDERS[sex]
# if level == 4:
# is_university = bool(round(center[0]))
# is_work = bool(round(center[1]))
# is_student = bool(round(center[2]))
# is_schoolboy = bool(round(center[3]))
# return f'univer: {is_university}, work: {is_work}, student: {is_student}, school: {is_schoolboy}'
raise Exception(f'Unknown level {level}')
def __clustering(data: ndarray, n_clusters: int = 3, plots: bool = False) -> None:
model = AgglomerativeClustering(n_clusters=n_clusters, metric='euclidean', linkage='ward')
model.fit(data)
labels = model.labels_
def __clustering(data: ndarray, root_node: TreeNode, plots: bool = False, level: int = 1) -> None:
if level == MAX_LEVEL:
return
cl_data = None
clusters = 0
if level == 1:
cl_data = data[:, 0:2]
clusters = 3
if level == 3:
cl_data = data[:, 3].reshape(-1, 1)
clusters = 7
if level == 2:
cl_data = data[:, 2].reshape(-1, 1)
clusters = 2
if level == 4:
univer: int = len(data[np.where(data[:, 4] == 1)])
work: int = len(data[np.where(data[:, 5] == 1)])
student: int = len(data[np.where(data[:, 6] == 1)])
schoolboy: int = len(data[np.where(data[:, 7] == 1)])
TreeNode(1, 'Высшее образование', univer, parent=root_node)
TreeNode(2, 'Работает', work, parent=root_node)
TreeNode(3, 'Студент', student, parent=root_node)
TreeNode(4, 'Школьник', schoolboy, parent=root_node)
return
if cl_data is None:
raise Exception(f'Unknown level {level}')
model = AgglomerativeClustering(n_clusters=clusters, metric='euclidean', linkage='ward')
if len(cl_data) > 1:
model.fit(cl_data)
labels = model.labels_
else:
labels = np.array([0])
if plots:
__plots(data, labels)
centers = __get_cluster_centers(data, labels)
for center in centers:
__print_center(center)
__plots(cl_data, labels, level)
centers = __get_cluster_centers(cl_data, labels)
nodes: dict = {}
for index, center in enumerate(centers):
size: int = len(cl_data[numpy.where(labels[:] == index)])
node: TreeNode = TreeNode(index, __format_center(center, level), size, parent=root_node)
nodes[index] = node
if len(centers) == 1:
return
for cluster in range(clusters):
__clustering(data[numpy.where(labels[:] == cluster)], nodes[cluster], plots, level + 1)
def __tree_sort(items):
return sorted(items, key=lambda item: item.data)
def __main(json_file):
data: ndarray = DfLoader(json_file).get_data()
__clustering(data, default_clusters, is_plots)
tree_root: TreeNode = TreeNode(0, f'ROOT', len(data))
__clustering(data, tree_root, is_plots)
print(print('\n'.join([f'{pre}{node}' for pre, fill, node in RenderTree(tree_root, childiter=__tree_sort)])))
root = Tk(className='Clustering')
root.geometry('800x600')
frame = Frame(root)
frame.grid(column=0, row=0, sticky="nsew")
scrollbar = Scrollbar(frame)
scrollbar.pack(side=RIGHT, fill=Y)
tv = TreeView(frame, tree_root)
tv.generate()
scrollbar.config(command=tv.get().yview)
root.rowconfigure(0, weight=1)
root.columnconfigure(0, weight=1)
root.mainloop()
if __name__ == '__main__':

18
src/main/tree_node.py Normal file
View File

@ -0,0 +1,18 @@
from anytree import NodeMixin
class TreeNode(NodeMixin):
def __init__(self, index: int, data: str, size: int, parent: NodeMixin = None) -> None:
super(TreeNode, self).__init__()
self.index = index
self.data = data
self.size = size
self.parent = parent
def __get_percent(self) -> float:
if self.parent is None:
return 100
return round(self.size / self.parent.size * 100, 2)
def __repr__(self) -> str:
return f'{self.data} [{self.size}/{self.__get_percent()}%]'

26
src/main/tree_view.py Normal file
View File

@ -0,0 +1,26 @@
import re
from tkinter import Frame, ttk, LEFT, BOTH, TRUE
from anytree import LevelOrderGroupIter
class TreeView(Frame):
def __init__(self, parent, tree):
super().__init__()
self.tree = tree
self.treeview = ttk.Treeview(parent, height=30)
def get(self):
return self.treeview
@staticmethod
def alphanum_key(key):
return [int(s) if s.isdigit() else s.lower() for s in re.split("([0-9]+)", key)]
def generate(self):
self.treeview.pack(side=LEFT, fill=BOTH, expand=TRUE)
nodes_d = {}
for nodes in LevelOrderGroupIter(self.tree):
for index, node in enumerate(sorted(nodes, key=lambda item: self.alphanum_key(item.data))):
idd = self.treeview.insert('' if node.parent is None else nodes_d[node.parent], index, text=node)
nodes_d[node] = idd