#!/usr/bin/env python3
"""Hierarchical agglomerative clustering of a loaded dataset, shown as a tree.

The dataset is split level by level (see ``__clustering``); every cluster
becomes a ``TreeNode``.  The resulting tree is printed to the console and
displayed in a Tk window.

Usage: pass the path of the raw dataset JSON file as the only argument.
"""
import os
import sys
from tkinter import Tk, Frame, Scrollbar, RIGHT, Y
from typing import List

import numpy
import numpy as np
import pandas as pd
import scipy.cluster.hierarchy as sc
from anytree import RenderTree
from matplotlib import pyplot as plt
from numpy import ndarray
from pandas import Series
from sklearn.cluster import AgglomerativeClustering
from sklearn.decomposition import PCA

from src.main.df_loader import DfLoader
from src.main.georeverse import Georeverse
from src.main.tree_node import TreeNode
from src.main.tree_view import TreeView

# Recursion stops as soon as this level is reached.
MAX_LEVEL: int = 5

# Human-readable labels for the rounded gender cluster center (level 2).
GENDERS: dict = {
    0: 'не указан',  # not specified
    1: 'женский',  # female
    2: 'мужской'  # male
}

georeverse: Georeverse = Georeverse()

# Set to True to show a dendrogram / scatter plot for every clustering step.
is_plots: bool = False


def __plots(data: ndarray, labels: ndarray, level: int) -> None:
    """Show a dendrogram and, for multi-column data, a 2-D PCA scatter plot.

    :param data: feature matrix that was clustered (rows are samples).
    :param labels: cluster label of every row of ``data``.
    :param level: current clustering level; only printed to the console.
    """
    print(level)
    # Both ``linkage`` and a 2-component PCA need at least two samples.
    if len(data) < 2:
        return

    multi_column: bool = data.shape[1] > 1
    if multi_column:
        plt.figure(figsize=(12, 6))
        plt.subplot(1, 2, 1)
    else:
        plt.figure(figsize=(6, 6))

    sc.dendrogram(sc.linkage(data, method='ward'), p=4, truncate_mode='level')
    plt.title('Dendrogram')

    if multi_column:
        # fit_transform already returns an ndarray; no DataFrame round trip.
        transformed: ndarray = PCA(n_components=2).fit_transform(data)
        plt.subplot(1, 2, 2)
        plt.scatter(x=transformed[:, 0], y=transformed[:, 1], c=labels, cmap='rainbow')
        plt.title('Clustering')

    plt.show()


def __get_cluster_centers(data: ndarray, labels: ndarray) -> ndarray:
    """Return the mean point of every cluster.

    :param data: feature matrix (rows are samples).
    :param labels: cluster label of every row of ``data``.
    :return: array with one row per distinct label, ordered by ascending
        label value.
    """
    # np.unique yields the labels sorted, so the row order is deterministic
    # (iteration order of a ``set`` is an implementation detail).
    centers: List[ndarray] = [
        data[labels == label].mean(axis=0) for label in np.unique(labels)
    ]
    return np.array(centers)


def __format_center(center: ndarray, level: int) -> str:
    """Build the text shown on a tree node from its cluster center.

    :param center: mean point of the cluster.
    :param level: clustering level the center belongs to (1, 2 or 3).
    :return: city name (level 1), gender label (level 2) or rounded age
        (level 3).
    :raises ValueError: for any other level.
    """
    if level == 1:
        # presumably center holds (latitude, longitude) - confirm against
        # Georeverse.get_city
        return str(georeverse.get_city(center[0], center[1]))
    if level == 2:
        return GENDERS[round(center[0])]
    if level == 3:
        return str(round(center[0]))
    # NOTE: level 4 never reaches this function; __clustering builds the
    # level-4 nodes itself from plain counts.
    raise ValueError(f'Unknown level {level}')


def __clustering(data: ndarray, root_node: TreeNode, plots: bool = False, level: int = 1) -> None:
    """Recursively cluster ``data`` and attach the clusters to ``root_node``.

    Levels:
      1 - columns 0-1, up to 3 clusters (formatted as a city);
      2 - column 2, up to 2 clusters (formatted as a gender);
      3 - column 3, up to 7 clusters (formatted as an age);
      4 - no clustering: rows with value 1 in columns 4-7 are counted and
          added as four leaf nodes.

    :param data: rows belonging to ``root_node``.
    :param root_node: parent for the nodes created on this level.
    :param plots: when True, plot every clustering step.
    :param level: current level, starting at 1.
    :raises ValueError: if ``level`` is not one of the levels above.
    """
    # Nothing to do past the last level or for an empty subset.
    if level == MAX_LEVEL or len(data) == 0:
        return

    if level == 4:
        flags = (
            ('Высшее образование', 4),  # higher education
            ('Работает', 5),  # employed
            ('Студент', 6),  # student
            ('Школьник', 7),  # schoolchild
        )
        for node_id, (title, column) in enumerate(flags, start=1):
            count: int = int((data[:, column] == 1).sum())
            TreeNode(node_id, title, count, parent=root_node)
        return

    if level == 1:
        cl_data, clusters = data[:, 0:2], 3
    elif level == 2:
        cl_data, clusters = data[:, 2].reshape(-1, 1), 2
    elif level == 3:
        cl_data, clusters = data[:, 3].reshape(-1, 1), 7
    else:
        raise ValueError(f'Unknown level {level}')

    if len(cl_data) > 1:
        # The model cannot extract more clusters than there are samples,
        # so clamp the requested count for small subsets.
        model = AgglomerativeClustering(n_clusters=min(clusters, len(cl_data)),
                                        metric='euclidean', linkage='ward')
        model.fit(cl_data)
        labels = model.labels_
    else:
        # A single sample forms its own cluster.
        labels = np.array([0])

    if plots:
        __plots(cl_data, labels, level)

    centers = __get_cluster_centers(cl_data, labels)
    nodes: dict = {}
    # Centers are ordered by ascending label, same as np.unique(labels).
    for label, center in zip(np.unique(labels), centers):
        size: int = int((labels == label).sum())
        nodes[int(label)] = TreeNode(int(label), __format_center(center, level), size,
                                     parent=root_node)

    # A lone cluster is not split any further.
    if len(nodes) == 1:
        return

    for label, node in nodes.items():
        __clustering(data[labels == label], node, plots, level + 1)


def __tree_sort(items):
    """Order sibling nodes by their ``data`` attribute for rendering."""
    return sorted(items, key=lambda item: item.data)


def __main(json_file):
    """Cluster the dataset in ``json_file``, print the tree and show it in Tk.

    :param json_file: path handed to ``DfLoader``.
    """
    data: ndarray = DfLoader(json_file).get_data()
    tree_root: TreeNode = TreeNode(0, 'ROOT', len(data))
    __clustering(data, tree_root, is_plots)

    rendered = RenderTree(tree_root, childiter=__tree_sort)
    print('\n'.join(f'{pre}{node}' for pre, _, node in rendered))

    root = Tk(className='Clustering')
    root.geometry('800x600')

    frame = Frame(root)
    frame.grid(column=0, row=0, sticky="nsew")

    scrollbar = Scrollbar(frame)
    scrollbar.pack(side=RIGHT, fill=Y)

    tv = TreeView(frame, tree_root)
    tv.generate()
    scrollbar.config(command=tv.get().yview)

    # Let the frame grow together with the window.
    root.rowconfigure(0, weight=1)
    root.columnconfigure(0, weight=1)
    root.mainloop()


if __name__ == '__main__':
    if len(sys.argv) != 2:
        print('You must specify the raw_dataset json file')
        sys.exit(1)
    if not os.path.isfile(sys.argv[1]):
        print(f'File {sys.argv[1]} is not exists')
        # Do not go on with a path that cannot be loaded.
        sys.exit(1)
    __main(sys.argv[1])