Add cluster centers extraction
This commit is contained in:
parent
f4a32bf57f
commit
155b350e1e
88
main.py
88
main.py
@ -1,59 +1,73 @@
|
||||
#!/usr/bin/env python3
|
||||
import os
|
||||
import sys
|
||||
from typing import List
|
||||
|
||||
import numpy
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
# import scipy.cluster.hierarchy as sc
|
||||
import scipy.cluster.hierarchy as sc
|
||||
from matplotlib import pyplot as plt
|
||||
from pandas import DataFrame
|
||||
from numpy import ndarray
|
||||
from pandas import Series
|
||||
from sklearn.cluster import AgglomerativeClustering
|
||||
from sklearn.decomposition import PCA
|
||||
|
||||
from src.main.df_loader import DfLoader
|
||||
from src.main.georeverse import Georeverse
|
||||
|
||||
is_plots: bool = False
|
||||
default_clusters: int = 3
|
||||
georeverse: Georeverse = Georeverse()
|
||||
|
||||
|
||||
def __clustering(data: DataFrame) -> None:
|
||||
# clusters = round(math.sqrt(len(data) / 2))
|
||||
# plt.figure(figsize=(20, 7))
|
||||
# plt.title("Dendrograms")
|
||||
# # Create dendrogram
|
||||
# sc.dendrogram(sc.linkage(data.to_numpy(), method='ward'))
|
||||
# plt.title('Dendrogram')
|
||||
# plt.xlabel('Sample index')
|
||||
# plt.ylabel('Euclidean distance')
|
||||
|
||||
clusters = 3
|
||||
model = AgglomerativeClustering(n_clusters=clusters, metric='euclidean', linkage='ward')
|
||||
model.fit(data)
|
||||
labels = model.labels_
|
||||
|
||||
data_norm = (data - data.min()) / (data.max() - data.min())
|
||||
|
||||
pca = PCA(n_components=2) # 2-dimensional PCA
|
||||
transformed = pd.DataFrame(pca.fit_transform(data_norm))
|
||||
# plt.scatter(x=transformed[:, 0], y=transformed[:, 1], c=labels, cmap='rainbow')
|
||||
for i in range(clusters):
|
||||
series = transformed.iloc[numpy.where(labels[:] == i)]
|
||||
plt.scatter(series[0], series[1], label=f'Cluster {i + 1}')
|
||||
plt.legend()
|
||||
def __plots(data: ndarray, labels: ndarray) -> None:
|
||||
plt.figure(figsize=(12, 6))
|
||||
plt.subplot(1, 2, 1)
|
||||
sc.dendrogram(sc.linkage(data, method='ward'), p=4, truncate_mode='level')
|
||||
plt.title('Dendrogram')
|
||||
pca = PCA(n_components=2)
|
||||
transformed = pd.DataFrame(pca.fit_transform(data)).to_numpy()
|
||||
plt.subplot(1, 2, 2)
|
||||
plt.scatter(x=transformed[:, 0], y=transformed[:, 1], c=labels, cmap='rainbow')
|
||||
plt.title('Clustering')
|
||||
plt.show()
|
||||
|
||||
# fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))
|
||||
# sns.scatterplot(ax=axes[0], data=data, x='location-la,location-lo', y='age,sex').set_title('Without clustering')
|
||||
# sns.scatterplot(ax=axes[1], data=data, x='location-la,location-lo', y='age,sex', hue=labels) \
|
||||
# .set_title('With clustering')
|
||||
# plt.show()
|
||||
|
||||
# s = numpy.where(labels[:] == 34)
|
||||
# print(labels)
|
||||
def __get_cluster_centers(data: ndarray, labels: ndarray) -> ndarray:
|
||||
centers: List[List[float]] = list()
|
||||
for label in set(labels):
|
||||
center: Series = data[numpy.where(labels[:] == label)].mean(axis=0)
|
||||
centers.append(list(center))
|
||||
return np.array(centers)
|
||||
|
||||
|
||||
def __print_center(center: ndarray) -> None:
|
||||
location: str = georeverse.get_city(center[0], center[1])
|
||||
sex = round(center[2])
|
||||
age = round(center[3])
|
||||
is_university = bool(round(center[4]))
|
||||
is_work = bool(round(center[5]))
|
||||
is_student = bool(round(center[6]))
|
||||
is_schoolboy = bool(round(center[7]))
|
||||
print(f'location: {location}, sex: {sex}, age: {age},'
|
||||
f' univer: {is_university}, work: {is_work}, student: {is_student}, school: {is_schoolboy}')
|
||||
|
||||
|
||||
def __clustering(data: ndarray, n_clusters: int = 3, plots: bool = False) -> None:
|
||||
model = AgglomerativeClustering(n_clusters=n_clusters, metric='euclidean', linkage='ward')
|
||||
model.fit(data)
|
||||
labels = model.labels_
|
||||
if plots:
|
||||
__plots(data, labels)
|
||||
centers = __get_cluster_centers(data, labels)
|
||||
for center in centers:
|
||||
__print_center(center)
|
||||
|
||||
|
||||
def __main(json_file):
|
||||
df_loader: DfLoader = DfLoader(json_file)
|
||||
data = df_loader.get_clustering_data()
|
||||
print(data)
|
||||
__clustering(data)
|
||||
data: ndarray = DfLoader(json_file).get_data()
|
||||
__clustering(data, default_clusters, is_plots)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
12
src/main/georeverse.py
Normal file
12
src/main/georeverse.py
Normal file
@ -0,0 +1,12 @@
|
||||
from functools import partial
|
||||
|
||||
from geopy import Nominatim
|
||||
|
||||
|
||||
class Georeverse:
|
||||
def __init__(self) -> None:
|
||||
geolocator: Nominatim = Nominatim(user_agent="MyApp")
|
||||
self.__reverse = partial(geolocator.reverse, language="ru")
|
||||
|
||||
def get_city(self, latitude: float, longitude: float) -> str:
|
||||
return self.__reverse(f'{latitude}, {longitude}')
|
Loading…
Reference in New Issue
Block a user