Complete main algorithm, add GUI
This commit is contained in:
parent
bb539ade32
commit
7993a7cf19
133
main.py
133
main.py
@ -1,12 +1,14 @@
|
||||
#!/usr/bin/env python3
|
||||
import os
|
||||
import sys
|
||||
from tkinter import Tk, Frame, Scrollbar, RIGHT, Y
|
||||
from typing import List
|
||||
|
||||
import numpy
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import scipy.cluster.hierarchy as sc
|
||||
from anytree import RenderTree
|
||||
from matplotlib import pyplot as plt
|
||||
from numpy import ndarray
|
||||
from pandas import Series
|
||||
@ -15,22 +17,35 @@ from sklearn.decomposition import PCA
|
||||
|
||||
from src.main.df_loader import DfLoader
|
||||
from src.main.georeverse import Georeverse
|
||||
from src.main.tree_node import TreeNode
|
||||
from src.main.tree_view import TreeView
|
||||
|
||||
is_plots: bool = False
|
||||
default_clusters: int = 3
|
||||
MAX_LEVEL: int = 5
|
||||
GENDERS: dict = {
|
||||
0: 'не указан',
|
||||
1: 'женский',
|
||||
2: 'мужской'
|
||||
}
|
||||
georeverse: Georeverse = Georeverse()
|
||||
|
||||
is_plots: bool = False
|
||||
|
||||
def __plots(data: ndarray, labels: ndarray) -> None:
|
||||
plt.figure(figsize=(12, 6))
|
||||
plt.subplot(1, 2, 1)
|
||||
|
||||
def __plots(data: ndarray, labels: ndarray, level: int) -> None:
|
||||
print(level)
|
||||
if data.shape[1] > 1:
|
||||
plt.figure(figsize=(12, 6))
|
||||
plt.subplot(1, 2, 1)
|
||||
else:
|
||||
plt.figure(figsize=(6, 6))
|
||||
sc.dendrogram(sc.linkage(data, method='ward'), p=4, truncate_mode='level')
|
||||
plt.title('Dendrogram')
|
||||
pca = PCA(n_components=2)
|
||||
transformed = pd.DataFrame(pca.fit_transform(data)).to_numpy()
|
||||
plt.subplot(1, 2, 2)
|
||||
plt.scatter(x=transformed[:, 0], y=transformed[:, 1], c=labels, cmap='rainbow')
|
||||
plt.title('Clustering')
|
||||
if data.shape[1] > 1:
|
||||
pca = PCA(n_components=2)
|
||||
transformed = pd.DataFrame(pca.fit_transform(data)).to_numpy()
|
||||
plt.subplot(1, 2, 2)
|
||||
plt.scatter(x=transformed[:, 0], y=transformed[:, 1], c=labels, cmap='rainbow')
|
||||
plt.title('Clustering')
|
||||
plt.show()
|
||||
|
||||
|
||||
@ -42,32 +57,92 @@ def __get_cluster_centers(data: ndarray, labels: ndarray) -> ndarray:
|
||||
return np.array(centers)
|
||||
|
||||
|
||||
def __print_center(center: ndarray) -> None:
|
||||
location: str = georeverse.get_city(center[0], center[1])
|
||||
sex = round(center[2])
|
||||
age = round(center[3])
|
||||
is_university = bool(round(center[4]))
|
||||
is_work = bool(round(center[5]))
|
||||
is_student = bool(round(center[6]))
|
||||
is_schoolboy = bool(round(center[7]))
|
||||
print(f'location: {location}, sex: {sex}, age: {age},'
|
||||
f' univer: {is_university}, work: {is_work}, student: {is_student}, school: {is_schoolboy}')
|
||||
def __format_center(center: ndarray, level: int) -> str:
|
||||
if level == 1:
|
||||
location: str = str(georeverse.get_city(center[0], center[1]))
|
||||
return location
|
||||
if level == 3:
|
||||
age = round(center[0])
|
||||
return str(age)
|
||||
if level == 2:
|
||||
sex = round(center[0])
|
||||
return GENDERS[sex]
|
||||
# if level == 4:
|
||||
# is_university = bool(round(center[0]))
|
||||
# is_work = bool(round(center[1]))
|
||||
# is_student = bool(round(center[2]))
|
||||
# is_schoolboy = bool(round(center[3]))
|
||||
# return f'univer: {is_university}, work: {is_work}, student: {is_student}, school: {is_schoolboy}'
|
||||
raise Exception(f'Unknown level {level}')
|
||||
|
||||
|
||||
def __clustering(data: ndarray, n_clusters: int = 3, plots: bool = False) -> None:
|
||||
model = AgglomerativeClustering(n_clusters=n_clusters, metric='euclidean', linkage='ward')
|
||||
model.fit(data)
|
||||
labels = model.labels_
|
||||
def __clustering(data: ndarray, root_node: TreeNode, plots: bool = False, level: int = 1) -> None:
|
||||
if level == MAX_LEVEL:
|
||||
return
|
||||
cl_data = None
|
||||
clusters = 0
|
||||
if level == 1:
|
||||
cl_data = data[:, 0:2]
|
||||
clusters = 3
|
||||
if level == 3:
|
||||
cl_data = data[:, 3].reshape(-1, 1)
|
||||
clusters = 7
|
||||
if level == 2:
|
||||
cl_data = data[:, 2].reshape(-1, 1)
|
||||
clusters = 2
|
||||
if level == 4:
|
||||
univer: int = len(data[np.where(data[:, 4] == 1)])
|
||||
work: int = len(data[np.where(data[:, 5] == 1)])
|
||||
student: int = len(data[np.where(data[:, 6] == 1)])
|
||||
schoolboy: int = len(data[np.where(data[:, 7] == 1)])
|
||||
TreeNode(1, 'Высшее образование', univer, parent=root_node)
|
||||
TreeNode(2, 'Работает', work, parent=root_node)
|
||||
TreeNode(3, 'Студент', student, parent=root_node)
|
||||
TreeNode(4, 'Школьник', schoolboy, parent=root_node)
|
||||
return
|
||||
if cl_data is None:
|
||||
raise Exception(f'Unknown level {level}')
|
||||
model = AgglomerativeClustering(n_clusters=clusters, metric='euclidean', linkage='ward')
|
||||
if len(cl_data) > 1:
|
||||
model.fit(cl_data)
|
||||
labels = model.labels_
|
||||
else:
|
||||
labels = np.array([0])
|
||||
if plots:
|
||||
__plots(data, labels)
|
||||
centers = __get_cluster_centers(data, labels)
|
||||
for center in centers:
|
||||
__print_center(center)
|
||||
__plots(cl_data, labels, level)
|
||||
centers = __get_cluster_centers(cl_data, labels)
|
||||
nodes: dict = {}
|
||||
for index, center in enumerate(centers):
|
||||
size: int = len(cl_data[numpy.where(labels[:] == index)])
|
||||
node: TreeNode = TreeNode(index, __format_center(center, level), size, parent=root_node)
|
||||
nodes[index] = node
|
||||
if len(centers) == 1:
|
||||
return
|
||||
for cluster in range(clusters):
|
||||
__clustering(data[numpy.where(labels[:] == cluster)], nodes[cluster], plots, level + 1)
|
||||
|
||||
|
||||
def __tree_sort(items):
|
||||
return sorted(items, key=lambda item: item.data)
|
||||
|
||||
|
||||
def __main(json_file):
|
||||
data: ndarray = DfLoader(json_file).get_data()
|
||||
__clustering(data, default_clusters, is_plots)
|
||||
tree_root: TreeNode = TreeNode(0, f'ROOT', len(data))
|
||||
__clustering(data, tree_root, is_plots)
|
||||
print(print('\n'.join([f'{pre}{node}' for pre, fill, node in RenderTree(tree_root, childiter=__tree_sort)])))
|
||||
root = Tk(className='Clustering')
|
||||
root.geometry('800x600')
|
||||
frame = Frame(root)
|
||||
frame.grid(column=0, row=0, sticky="nsew")
|
||||
scrollbar = Scrollbar(frame)
|
||||
scrollbar.pack(side=RIGHT, fill=Y)
|
||||
tv = TreeView(frame, tree_root)
|
||||
tv.generate()
|
||||
scrollbar.config(command=tv.get().yview)
|
||||
root.rowconfigure(0, weight=1)
|
||||
root.columnconfigure(0, weight=1)
|
||||
root.mainloop()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
18
src/main/tree_node.py
Normal file
18
src/main/tree_node.py
Normal file
@ -0,0 +1,18 @@
|
||||
from anytree import NodeMixin
|
||||
|
||||
|
||||
class TreeNode(NodeMixin):
|
||||
def __init__(self, index: int, data: str, size: int, parent: NodeMixin = None) -> None:
|
||||
super(TreeNode, self).__init__()
|
||||
self.index = index
|
||||
self.data = data
|
||||
self.size = size
|
||||
self.parent = parent
|
||||
|
||||
def __get_percent(self) -> float:
|
||||
if self.parent is None:
|
||||
return 100
|
||||
return round(self.size / self.parent.size * 100, 2)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f'{self.data} [{self.size}/{self.__get_percent()}%]'
|
26
src/main/tree_view.py
Normal file
26
src/main/tree_view.py
Normal file
@ -0,0 +1,26 @@
|
||||
import re
|
||||
from tkinter import Frame, ttk, LEFT, BOTH, TRUE
|
||||
|
||||
from anytree import LevelOrderGroupIter
|
||||
|
||||
|
||||
class TreeView(Frame):
|
||||
def __init__(self, parent, tree):
|
||||
super().__init__()
|
||||
self.tree = tree
|
||||
self.treeview = ttk.Treeview(parent, height=30)
|
||||
|
||||
def get(self):
|
||||
return self.treeview
|
||||
|
||||
@staticmethod
|
||||
def alphanum_key(key):
|
||||
return [int(s) if s.isdigit() else s.lower() for s in re.split("([0-9]+)", key)]
|
||||
|
||||
def generate(self):
|
||||
self.treeview.pack(side=LEFT, fill=BOTH, expand=TRUE)
|
||||
nodes_d = {}
|
||||
for nodes in LevelOrderGroupIter(self.tree):
|
||||
for index, node in enumerate(sorted(nodes, key=lambda item: self.alphanum_key(item.data))):
|
||||
idd = self.treeview.insert('' if node.parent is None else nodes_d[node.parent], index, text=node)
|
||||
nodes_d[node] = idd
|
Loading…
Reference in New Issue
Block a user