Complete main algorithm, add GUI
This commit is contained in:
parent
bb539ade32
commit
7993a7cf19
133
main.py
133
main.py
@ -1,12 +1,14 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
from tkinter import Tk, Frame, Scrollbar, RIGHT, Y
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import scipy.cluster.hierarchy as sc
|
import scipy.cluster.hierarchy as sc
|
||||||
|
from anytree import RenderTree
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from numpy import ndarray
|
from numpy import ndarray
|
||||||
from pandas import Series
|
from pandas import Series
|
||||||
@ -15,22 +17,35 @@ from sklearn.decomposition import PCA
|
|||||||
|
|
||||||
from src.main.df_loader import DfLoader
|
from src.main.df_loader import DfLoader
|
||||||
from src.main.georeverse import Georeverse
|
from src.main.georeverse import Georeverse
|
||||||
|
from src.main.tree_node import TreeNode
|
||||||
|
from src.main.tree_view import TreeView
|
||||||
|
|
||||||
is_plots: bool = False
|
MAX_LEVEL: int = 5
|
||||||
default_clusters: int = 3
|
GENDERS: dict = {
|
||||||
|
0: 'не указан',
|
||||||
|
1: 'женский',
|
||||||
|
2: 'мужской'
|
||||||
|
}
|
||||||
georeverse: Georeverse = Georeverse()
|
georeverse: Georeverse = Georeverse()
|
||||||
|
|
||||||
|
is_plots: bool = False
|
||||||
|
|
||||||
def __plots(data: ndarray, labels: ndarray) -> None:
|
|
||||||
plt.figure(figsize=(12, 6))
|
def __plots(data: ndarray, labels: ndarray, level: int) -> None:
|
||||||
plt.subplot(1, 2, 1)
|
print(level)
|
||||||
|
if data.shape[1] > 1:
|
||||||
|
plt.figure(figsize=(12, 6))
|
||||||
|
plt.subplot(1, 2, 1)
|
||||||
|
else:
|
||||||
|
plt.figure(figsize=(6, 6))
|
||||||
sc.dendrogram(sc.linkage(data, method='ward'), p=4, truncate_mode='level')
|
sc.dendrogram(sc.linkage(data, method='ward'), p=4, truncate_mode='level')
|
||||||
plt.title('Dendrogram')
|
plt.title('Dendrogram')
|
||||||
pca = PCA(n_components=2)
|
if data.shape[1] > 1:
|
||||||
transformed = pd.DataFrame(pca.fit_transform(data)).to_numpy()
|
pca = PCA(n_components=2)
|
||||||
plt.subplot(1, 2, 2)
|
transformed = pd.DataFrame(pca.fit_transform(data)).to_numpy()
|
||||||
plt.scatter(x=transformed[:, 0], y=transformed[:, 1], c=labels, cmap='rainbow')
|
plt.subplot(1, 2, 2)
|
||||||
plt.title('Clustering')
|
plt.scatter(x=transformed[:, 0], y=transformed[:, 1], c=labels, cmap='rainbow')
|
||||||
|
plt.title('Clustering')
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
|
|
||||||
@ -42,32 +57,92 @@ def __get_cluster_centers(data: ndarray, labels: ndarray) -> ndarray:
|
|||||||
return np.array(centers)
|
return np.array(centers)
|
||||||
|
|
||||||
|
|
||||||
def __print_center(center: ndarray) -> None:
|
def __format_center(center: ndarray, level: int) -> str:
|
||||||
location: str = georeverse.get_city(center[0], center[1])
|
if level == 1:
|
||||||
sex = round(center[2])
|
location: str = str(georeverse.get_city(center[0], center[1]))
|
||||||
age = round(center[3])
|
return location
|
||||||
is_university = bool(round(center[4]))
|
if level == 3:
|
||||||
is_work = bool(round(center[5]))
|
age = round(center[0])
|
||||||
is_student = bool(round(center[6]))
|
return str(age)
|
||||||
is_schoolboy = bool(round(center[7]))
|
if level == 2:
|
||||||
print(f'location: {location}, sex: {sex}, age: {age},'
|
sex = round(center[0])
|
||||||
f' univer: {is_university}, work: {is_work}, student: {is_student}, school: {is_schoolboy}')
|
return GENDERS[sex]
|
||||||
|
# if level == 4:
|
||||||
|
# is_university = bool(round(center[0]))
|
||||||
|
# is_work = bool(round(center[1]))
|
||||||
|
# is_student = bool(round(center[2]))
|
||||||
|
# is_schoolboy = bool(round(center[3]))
|
||||||
|
# return f'univer: {is_university}, work: {is_work}, student: {is_student}, school: {is_schoolboy}'
|
||||||
|
raise Exception(f'Unknown level {level}')
|
||||||
|
|
||||||
|
|
||||||
def __clustering(data: ndarray, n_clusters: int = 3, plots: bool = False) -> None:
|
def __clustering(data: ndarray, root_node: TreeNode, plots: bool = False, level: int = 1) -> None:
|
||||||
model = AgglomerativeClustering(n_clusters=n_clusters, metric='euclidean', linkage='ward')
|
if level == MAX_LEVEL:
|
||||||
model.fit(data)
|
return
|
||||||
labels = model.labels_
|
cl_data = None
|
||||||
|
clusters = 0
|
||||||
|
if level == 1:
|
||||||
|
cl_data = data[:, 0:2]
|
||||||
|
clusters = 3
|
||||||
|
if level == 3:
|
||||||
|
cl_data = data[:, 3].reshape(-1, 1)
|
||||||
|
clusters = 7
|
||||||
|
if level == 2:
|
||||||
|
cl_data = data[:, 2].reshape(-1, 1)
|
||||||
|
clusters = 2
|
||||||
|
if level == 4:
|
||||||
|
univer: int = len(data[np.where(data[:, 4] == 1)])
|
||||||
|
work: int = len(data[np.where(data[:, 5] == 1)])
|
||||||
|
student: int = len(data[np.where(data[:, 6] == 1)])
|
||||||
|
schoolboy: int = len(data[np.where(data[:, 7] == 1)])
|
||||||
|
TreeNode(1, 'Высшее образование', univer, parent=root_node)
|
||||||
|
TreeNode(2, 'Работает', work, parent=root_node)
|
||||||
|
TreeNode(3, 'Студент', student, parent=root_node)
|
||||||
|
TreeNode(4, 'Школьник', schoolboy, parent=root_node)
|
||||||
|
return
|
||||||
|
if cl_data is None:
|
||||||
|
raise Exception(f'Unknown level {level}')
|
||||||
|
model = AgglomerativeClustering(n_clusters=clusters, metric='euclidean', linkage='ward')
|
||||||
|
if len(cl_data) > 1:
|
||||||
|
model.fit(cl_data)
|
||||||
|
labels = model.labels_
|
||||||
|
else:
|
||||||
|
labels = np.array([0])
|
||||||
if plots:
|
if plots:
|
||||||
__plots(data, labels)
|
__plots(cl_data, labels, level)
|
||||||
centers = __get_cluster_centers(data, labels)
|
centers = __get_cluster_centers(cl_data, labels)
|
||||||
for center in centers:
|
nodes: dict = {}
|
||||||
__print_center(center)
|
for index, center in enumerate(centers):
|
||||||
|
size: int = len(cl_data[numpy.where(labels[:] == index)])
|
||||||
|
node: TreeNode = TreeNode(index, __format_center(center, level), size, parent=root_node)
|
||||||
|
nodes[index] = node
|
||||||
|
if len(centers) == 1:
|
||||||
|
return
|
||||||
|
for cluster in range(clusters):
|
||||||
|
__clustering(data[numpy.where(labels[:] == cluster)], nodes[cluster], plots, level + 1)
|
||||||
|
|
||||||
|
|
||||||
|
def __tree_sort(items):
|
||||||
|
return sorted(items, key=lambda item: item.data)
|
||||||
|
|
||||||
|
|
||||||
def __main(json_file):
|
def __main(json_file):
|
||||||
data: ndarray = DfLoader(json_file).get_data()
|
data: ndarray = DfLoader(json_file).get_data()
|
||||||
__clustering(data, default_clusters, is_plots)
|
tree_root: TreeNode = TreeNode(0, f'ROOT', len(data))
|
||||||
|
__clustering(data, tree_root, is_plots)
|
||||||
|
print(print('\n'.join([f'{pre}{node}' for pre, fill, node in RenderTree(tree_root, childiter=__tree_sort)])))
|
||||||
|
root = Tk(className='Clustering')
|
||||||
|
root.geometry('800x600')
|
||||||
|
frame = Frame(root)
|
||||||
|
frame.grid(column=0, row=0, sticky="nsew")
|
||||||
|
scrollbar = Scrollbar(frame)
|
||||||
|
scrollbar.pack(side=RIGHT, fill=Y)
|
||||||
|
tv = TreeView(frame, tree_root)
|
||||||
|
tv.generate()
|
||||||
|
scrollbar.config(command=tv.get().yview)
|
||||||
|
root.rowconfigure(0, weight=1)
|
||||||
|
root.columnconfigure(0, weight=1)
|
||||||
|
root.mainloop()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
18
src/main/tree_node.py
Normal file
18
src/main/tree_node.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
from anytree import NodeMixin
|
||||||
|
|
||||||
|
|
||||||
|
class TreeNode(NodeMixin):
|
||||||
|
def __init__(self, index: int, data: str, size: int, parent: NodeMixin = None) -> None:
|
||||||
|
super(TreeNode, self).__init__()
|
||||||
|
self.index = index
|
||||||
|
self.data = data
|
||||||
|
self.size = size
|
||||||
|
self.parent = parent
|
||||||
|
|
||||||
|
def __get_percent(self) -> float:
|
||||||
|
if self.parent is None:
|
||||||
|
return 100
|
||||||
|
return round(self.size / self.parent.size * 100, 2)
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f'{self.data} [{self.size}/{self.__get_percent()}%]'
|
26
src/main/tree_view.py
Normal file
26
src/main/tree_view.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
import re
|
||||||
|
from tkinter import Frame, ttk, LEFT, BOTH, TRUE
|
||||||
|
|
||||||
|
from anytree import LevelOrderGroupIter
|
||||||
|
|
||||||
|
|
||||||
|
class TreeView(Frame):
|
||||||
|
def __init__(self, parent, tree):
|
||||||
|
super().__init__()
|
||||||
|
self.tree = tree
|
||||||
|
self.treeview = ttk.Treeview(parent, height=30)
|
||||||
|
|
||||||
|
def get(self):
|
||||||
|
return self.treeview
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def alphanum_key(key):
|
||||||
|
return [int(s) if s.isdigit() else s.lower() for s in re.split("([0-9]+)", key)]
|
||||||
|
|
||||||
|
def generate(self):
|
||||||
|
self.treeview.pack(side=LEFT, fill=BOTH, expand=TRUE)
|
||||||
|
nodes_d = {}
|
||||||
|
for nodes in LevelOrderGroupIter(self.tree):
|
||||||
|
for index, node in enumerate(sorted(nodes, key=lambda item: self.alphanum_key(item.data))):
|
||||||
|
idd = self.treeview.insert('' if node.parent is None else nodes_d[node.parent], index, text=node)
|
||||||
|
nodes_d[node] = idd
|
Loading…
Reference in New Issue
Block a user