From 7993a7cf191f89a820886e351b051010708f2077 Mon Sep 17 00:00:00 2001 From: Aleksey Filippov Date: Thu, 8 Jun 2023 01:05:04 +0400 Subject: [PATCH] Complete main algorithm, add GUI --- main.py | 133 +++++++++++++++++++++++++++++++++--------- src/main/tree_node.py | 18 ++++++ src/main/tree_view.py | 26 +++++++++ 3 files changed, 148 insertions(+), 29 deletions(-) create mode 100644 src/main/tree_node.py create mode 100644 src/main/tree_view.py diff --git a/main.py b/main.py index 568fff4..7bffd67 100644 --- a/main.py +++ b/main.py @@ -1,12 +1,14 @@ #!/usr/bin/env python3 import os import sys +from tkinter import Tk, Frame, Scrollbar, RIGHT, Y from typing import List import numpy import numpy as np import pandas as pd import scipy.cluster.hierarchy as sc +from anytree import RenderTree from matplotlib import pyplot as plt from numpy import ndarray from pandas import Series @@ -15,22 +17,35 @@ from sklearn.decomposition import PCA from src.main.df_loader import DfLoader from src.main.georeverse import Georeverse +from src.main.tree_node import TreeNode +from src.main.tree_view import TreeView -is_plots: bool = False -default_clusters: int = 3 +MAX_LEVEL: int = 5 +GENDERS: dict = { + 0: 'не указан', + 1: 'женский', + 2: 'мужской' +} georeverse: Georeverse = Georeverse() +is_plots: bool = False -def __plots(data: ndarray, labels: ndarray) -> None: - plt.figure(figsize=(12, 6)) - plt.subplot(1, 2, 1) + +def __plots(data: ndarray, labels: ndarray, level: int) -> None: + print(level) + if data.shape[1] > 1: + plt.figure(figsize=(12, 6)) + plt.subplot(1, 2, 1) + else: + plt.figure(figsize=(6, 6)) sc.dendrogram(sc.linkage(data, method='ward'), p=4, truncate_mode='level') plt.title('Dendrogram') - pca = PCA(n_components=2) - transformed = pd.DataFrame(pca.fit_transform(data)).to_numpy() - plt.subplot(1, 2, 2) - plt.scatter(x=transformed[:, 0], y=transformed[:, 1], c=labels, cmap='rainbow') - plt.title('Clustering') + if data.shape[1] > 1: + pca = PCA(n_components=2) + transformed = pd.DataFrame(pca.fit_transform(data)).to_numpy() + plt.subplot(1, 2, 2) + plt.scatter(x=transformed[:, 0], y=transformed[:, 1], c=labels, cmap='rainbow') + plt.title('Clustering') plt.show() @@ -42,32 +57,92 @@ def __get_cluster_centers(data: ndarray, labels: ndarray) -> ndarray: return np.array(centers) -def __print_center(center: ndarray) -> None: - location: str = georeverse.get_city(center[0], center[1]) - sex = round(center[2]) - age = round(center[3]) - is_university = bool(round(center[4])) - is_work = bool(round(center[5])) - is_student = bool(round(center[6])) - is_schoolboy = bool(round(center[7])) - print(f'location: {location}, sex: {sex}, age: {age},' - f' univer: {is_university}, work: {is_work}, student: {is_student}, school: {is_schoolboy}') +def __format_center(center: ndarray, level: int) -> str: + if level == 1: + location: str = str(georeverse.get_city(center[0], center[1])) + return location + if level == 3: + age = round(center[0]) + return str(age) + if level == 2: + sex = round(center[0]) + return GENDERS[sex] + # if level == 4: + # is_university = bool(round(center[0])) + # is_work = bool(round(center[1])) + # is_student = bool(round(center[2])) + # is_schoolboy = bool(round(center[3])) + # return f'univer: {is_university}, work: {is_work}, student: {is_student}, school: {is_schoolboy}' + raise Exception(f'Unknown level {level}') -def __clustering(data: ndarray, n_clusters: int = 3, plots: bool = False) -> None: - model = AgglomerativeClustering(n_clusters=n_clusters, metric='euclidean', linkage='ward') - model.fit(data) - labels = model.labels_ +def __clustering(data: ndarray, root_node: TreeNode, plots: bool = False, level: int = 1) -> None: + if level == MAX_LEVEL: + return + cl_data = None + clusters = 0 + if level == 1: + cl_data = data[:, 0:2] + clusters = 3 + if level == 3: + cl_data = data[:, 3].reshape(-1, 1) + clusters = 7 + if level == 2: + cl_data = data[:, 2].reshape(-1, 1) + clusters = 2 + if level == 4: + univer: int = len(data[np.where(data[:, 4] == 1)]) + work: int = len(data[np.where(data[:, 5] == 1)]) + student: int = len(data[np.where(data[:, 6] == 1)]) + schoolboy: int = len(data[np.where(data[:, 7] == 1)]) + TreeNode(1, 'Высшее образование', univer, parent=root_node) + TreeNode(2, 'Работает', work, parent=root_node) + TreeNode(3, 'Студент', student, parent=root_node) + TreeNode(4, 'Школьник', schoolboy, parent=root_node) + return + if cl_data is None: + raise Exception(f'Unknown level {level}') + model = AgglomerativeClustering(n_clusters=clusters, metric='euclidean', linkage='ward') + if len(cl_data) > 1: + model.fit(cl_data) + labels = model.labels_ + else: + labels = np.array([0]) if plots: - __plots(data, labels) - centers = __get_cluster_centers(data, labels) - for center in centers: - __print_center(center) + __plots(cl_data, labels, level) + centers = __get_cluster_centers(cl_data, labels) + nodes: dict = {} + for index, center in enumerate(centers): + size: int = len(cl_data[numpy.where(labels[:] == index)]) + node: TreeNode = TreeNode(index, __format_center(center, level), size, parent=root_node) + nodes[index] = node + if len(centers) == 1: + return + for cluster in range(clusters): + __clustering(data[numpy.where(labels[:] == cluster)], nodes[cluster], plots, level + 1) + + +def __tree_sort(items): + return sorted(items, key=lambda item: item.data) def __main(json_file): data: ndarray = DfLoader(json_file).get_data() - __clustering(data, default_clusters, is_plots) + tree_root: TreeNode = TreeNode(0, f'ROOT', len(data)) + __clustering(data, tree_root, is_plots) + print(print('\n'.join([f'{pre}{node}' for pre, fill, node in RenderTree(tree_root, childiter=__tree_sort)]))) + root = Tk(className='Clustering') + root.geometry('800x600') + frame = Frame(root) + frame.grid(column=0, row=0, sticky="nsew") + scrollbar = Scrollbar(frame) + scrollbar.pack(side=RIGHT, fill=Y) + tv = TreeView(frame, tree_root) + tv.generate() + scrollbar.config(command=tv.get().yview) + root.rowconfigure(0, weight=1) + root.columnconfigure(0, weight=1) + root.mainloop() if __name__ == '__main__': diff --git a/src/main/tree_node.py b/src/main/tree_node.py new file mode 100644 index 0000000..4cb76a1 --- /dev/null +++ b/src/main/tree_node.py @@ -0,0 +1,18 @@ +from anytree import NodeMixin + + +class TreeNode(NodeMixin): + def __init__(self, index: int, data: str, size: int, parent: NodeMixin = None) -> None: + super(TreeNode, self).__init__() + self.index = index + self.data = data + self.size = size + self.parent = parent + + def __get_percent(self) -> float: + if self.parent is None: + return 100 + return round(self.size / self.parent.size * 100, 2) + + def __repr__(self) -> str: + return f'{self.data} [{self.size}/{self.__get_percent()}%]' diff --git a/src/main/tree_view.py b/src/main/tree_view.py new file mode 100644 index 0000000..889b859 --- /dev/null +++ b/src/main/tree_view.py @@ -0,0 +1,26 @@ +import re +from tkinter import Frame, ttk, LEFT, BOTH, TRUE + +from anytree import LevelOrderGroupIter + + +class TreeView(Frame): + def __init__(self, parent, tree): + super().__init__() + self.tree = tree + self.treeview = ttk.Treeview(parent, height=30) + + def get(self): + return self.treeview + + @staticmethod + def alphanum_key(key): + return [int(s) if s.isdigit() else s.lower() for s in re.split("([0-9]+)", key)] + + def generate(self): + self.treeview.pack(side=LEFT, fill=BOTH, expand=TRUE) + nodes_d = {} + for nodes in LevelOrderGroupIter(self.tree): + for index, node in enumerate(sorted(nodes, key=lambda item: self.alphanum_key(item.data))): + idd = self.treeview.insert('' if node.parent is None else nodes_d[node.parent], index, text=node) + nodes_d[node] = idd