Compare commits

..

No commits in common. "7993a7cf191f89a820886e351b051010708f2077" and "155b350e1eaf1821d23b4c286e74abe7438ba499" have entirely different histories.

5 changed files with 32 additions and 152 deletions

133
main.py
View File

@ -1,14 +1,12 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import os import os
import sys import sys
from tkinter import Tk, Frame, Scrollbar, RIGHT, Y
from typing import List from typing import List
import numpy import numpy
import numpy as np import numpy as np
import pandas as pd import pandas as pd
import scipy.cluster.hierarchy as sc import scipy.cluster.hierarchy as sc
from anytree import RenderTree
from matplotlib import pyplot as plt from matplotlib import pyplot as plt
from numpy import ndarray from numpy import ndarray
from pandas import Series from pandas import Series
@ -17,35 +15,22 @@ from sklearn.decomposition import PCA
from src.main.df_loader import DfLoader from src.main.df_loader import DfLoader
from src.main.georeverse import Georeverse from src.main.georeverse import Georeverse
from src.main.tree_node import TreeNode
from src.main.tree_view import TreeView
MAX_LEVEL: int = 5
GENDERS: dict = {
0: 'не указан',
1: 'женский',
2: 'мужской'
}
georeverse: Georeverse = Georeverse()
is_plots: bool = False is_plots: bool = False
default_clusters: int = 3
georeverse: Georeverse = Georeverse()
def __plots(data: ndarray, labels: ndarray, level: int) -> None: def __plots(data: ndarray, labels: ndarray) -> None:
print(level) plt.figure(figsize=(12, 6))
if data.shape[1] > 1: plt.subplot(1, 2, 1)
plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
else:
plt.figure(figsize=(6, 6))
sc.dendrogram(sc.linkage(data, method='ward'), p=4, truncate_mode='level') sc.dendrogram(sc.linkage(data, method='ward'), p=4, truncate_mode='level')
plt.title('Dendrogram') plt.title('Dendrogram')
if data.shape[1] > 1: pca = PCA(n_components=2)
pca = PCA(n_components=2) transformed = pd.DataFrame(pca.fit_transform(data)).to_numpy()
transformed = pd.DataFrame(pca.fit_transform(data)).to_numpy() plt.subplot(1, 2, 2)
plt.subplot(1, 2, 2) plt.scatter(x=transformed[:, 0], y=transformed[:, 1], c=labels, cmap='rainbow')
plt.scatter(x=transformed[:, 0], y=transformed[:, 1], c=labels, cmap='rainbow') plt.title('Clustering')
plt.title('Clustering')
plt.show() plt.show()
@ -57,92 +42,32 @@ def __get_cluster_centers(data: ndarray, labels: ndarray) -> ndarray:
return np.array(centers) return np.array(centers)
def __format_center(center: ndarray, level: int) -> str: def __print_center(center: ndarray) -> None:
if level == 1: location: str = georeverse.get_city(center[0], center[1])
location: str = str(georeverse.get_city(center[0], center[1])) sex = round(center[2])
return location age = round(center[3])
if level == 3: is_university = bool(round(center[4]))
age = round(center[0]) is_work = bool(round(center[5]))
return str(age) is_student = bool(round(center[6]))
if level == 2: is_schoolboy = bool(round(center[7]))
sex = round(center[0]) print(f'location: {location}, sex: {sex}, age: {age},'
return GENDERS[sex] f' univer: {is_university}, work: {is_work}, student: {is_student}, school: {is_schoolboy}')
# if level == 4:
# is_university = bool(round(center[0]))
# is_work = bool(round(center[1]))
# is_student = bool(round(center[2]))
# is_schoolboy = bool(round(center[3]))
# return f'univer: {is_university}, work: {is_work}, student: {is_student}, school: {is_schoolboy}'
raise Exception(f'Unknown level {level}')
def __clustering(data: ndarray, root_node: TreeNode, plots: bool = False, level: int = 1) -> None: def __clustering(data: ndarray, n_clusters: int = 3, plots: bool = False) -> None:
if level == MAX_LEVEL: model = AgglomerativeClustering(n_clusters=n_clusters, metric='euclidean', linkage='ward')
return model.fit(data)
cl_data = None labels = model.labels_
clusters = 0
if level == 1:
cl_data = data[:, 0:2]
clusters = 3
if level == 3:
cl_data = data[:, 3].reshape(-1, 1)
clusters = 7
if level == 2:
cl_data = data[:, 2].reshape(-1, 1)
clusters = 2
if level == 4:
univer: int = len(data[np.where(data[:, 4] == 1)])
work: int = len(data[np.where(data[:, 5] == 1)])
student: int = len(data[np.where(data[:, 6] == 1)])
schoolboy: int = len(data[np.where(data[:, 7] == 1)])
TreeNode(1, 'Высшее образование', univer, parent=root_node)
TreeNode(2, 'Работает', work, parent=root_node)
TreeNode(3, 'Студент', student, parent=root_node)
TreeNode(4, 'Школьник', schoolboy, parent=root_node)
return
if cl_data is None:
raise Exception(f'Unknown level {level}')
model = AgglomerativeClustering(n_clusters=clusters, metric='euclidean', linkage='ward')
if len(cl_data) > 1:
model.fit(cl_data)
labels = model.labels_
else:
labels = np.array([0])
if plots: if plots:
__plots(cl_data, labels, level) __plots(data, labels)
centers = __get_cluster_centers(cl_data, labels) centers = __get_cluster_centers(data, labels)
nodes: dict = {} for center in centers:
for index, center in enumerate(centers): __print_center(center)
size: int = len(cl_data[numpy.where(labels[:] == index)])
node: TreeNode = TreeNode(index, __format_center(center, level), size, parent=root_node)
nodes[index] = node
if len(centers) == 1:
return
for cluster in range(clusters):
__clustering(data[numpy.where(labels[:] == cluster)], nodes[cluster], plots, level + 1)
def __tree_sort(items):
return sorted(items, key=lambda item: item.data)
def __main(json_file): def __main(json_file):
data: ndarray = DfLoader(json_file).get_data() data: ndarray = DfLoader(json_file).get_data()
tree_root: TreeNode = TreeNode(0, f'ROOT', len(data)) __clustering(data, default_clusters, is_plots)
__clustering(data, tree_root, is_plots)
print(print('\n'.join([f'{pre}{node}' for pre, fill, node in RenderTree(tree_root, childiter=__tree_sort)])))
root = Tk(className='Clustering')
root.geometry('800x600')
frame = Frame(root)
frame.grid(column=0, row=0, sticky="nsew")
scrollbar = Scrollbar(frame)
scrollbar.pack(side=RIGHT, fill=Y)
tv = TreeView(frame, tree_root)
tv.generate()
scrollbar.config(command=tv.get().yview)
root.rowconfigure(0, weight=1)
root.columnconfigure(0, weight=1)
root.mainloop()
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -3,5 +3,5 @@ geopy==2.3.0
numpy==1.24.3 numpy==1.24.3
scikit-learn==1.2.2 scikit-learn==1.2.2
matplotlib==3.7.1 matplotlib==3.7.1
scipy==1.10.1 seaborn==0.12.2
anytree==2.8.0 scipy==1.10.1

View File

@ -1,13 +1,12 @@
from functools import partial from functools import partial
from geopy import Nominatim from geopy import Nominatim
from geopy.extra.rate_limiter import RateLimiter
class Georeverse: class Georeverse:
def __init__(self) -> None: def __init__(self) -> None:
geolocator: Nominatim = Nominatim(user_agent="MyApp") geolocator: Nominatim = Nominatim(user_agent="MyApp")
self.__reverse = RateLimiter(partial(geolocator.reverse, language="ru"), min_delay_seconds=1) self.__reverse = partial(geolocator.reverse, language="ru")
def get_city(self, latitude: float, longitude: float) -> str: def get_city(self, latitude: float, longitude: float) -> str:
return self.__reverse(f'{latitude}, {longitude}') return self.__reverse(f'{latitude}, {longitude}')

View File

@ -1,18 +0,0 @@
from anytree import NodeMixin
class TreeNode(NodeMixin):
def __init__(self, index: int, data: str, size: int, parent: NodeMixin = None) -> None:
super(TreeNode, self).__init__()
self.index = index
self.data = data
self.size = size
self.parent = parent
def __get_percent(self) -> float:
if self.parent is None:
return 100
return round(self.size / self.parent.size * 100, 2)
def __repr__(self) -> str:
return f'{self.data} [{self.size}/{self.__get_percent()}%]'

View File

@ -1,26 +0,0 @@
import re
from tkinter import Frame, ttk, LEFT, BOTH, TRUE
from anytree import LevelOrderGroupIter
class TreeView(Frame):
def __init__(self, parent, tree):
super().__init__()
self.tree = tree
self.treeview = ttk.Treeview(parent, height=30)
def get(self):
return self.treeview
@staticmethod
def alphanum_key(key):
return [int(s) if s.isdigit() else s.lower() for s in re.split("([0-9]+)", key)]
def generate(self):
self.treeview.pack(side=LEFT, fill=BOTH, expand=TRUE)
nodes_d = {}
for nodes in LevelOrderGroupIter(self.tree):
for index, node in enumerate(sorted(nodes, key=lambda item: self.alphanum_key(item.data))):
idd = self.treeview.insert('' if node.parent is None else nodes_d[node.parent], index, text=node)
nodes_d[node] = idd