Add cluster centers extraction
This commit is contained in:
parent
f4a32bf57f
commit
155b350e1e
88
main.py
88
main.py
@ -1,59 +1,73 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
from typing import List
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
# import scipy.cluster.hierarchy as sc
|
import scipy.cluster.hierarchy as sc
|
||||||
from matplotlib import pyplot as plt
|
from matplotlib import pyplot as plt
|
||||||
from pandas import DataFrame
|
from numpy import ndarray
|
||||||
|
from pandas import Series
|
||||||
from sklearn.cluster import AgglomerativeClustering
|
from sklearn.cluster import AgglomerativeClustering
|
||||||
from sklearn.decomposition import PCA
|
from sklearn.decomposition import PCA
|
||||||
|
|
||||||
from src.main.df_loader import DfLoader
|
from src.main.df_loader import DfLoader
|
||||||
|
from src.main.georeverse import Georeverse
|
||||||
|
|
||||||
|
is_plots: bool = False
|
||||||
|
default_clusters: int = 3
|
||||||
|
georeverse: Georeverse = Georeverse()
|
||||||
|
|
||||||
|
|
||||||
def __clustering(data: DataFrame) -> None:
|
def __plots(data: ndarray, labels: ndarray) -> None:
|
||||||
# clusters = round(math.sqrt(len(data) / 2))
|
plt.figure(figsize=(12, 6))
|
||||||
# plt.figure(figsize=(20, 7))
|
plt.subplot(1, 2, 1)
|
||||||
# plt.title("Dendrograms")
|
sc.dendrogram(sc.linkage(data, method='ward'), p=4, truncate_mode='level')
|
||||||
# # Create dendrogram
|
plt.title('Dendrogram')
|
||||||
# sc.dendrogram(sc.linkage(data.to_numpy(), method='ward'))
|
pca = PCA(n_components=2)
|
||||||
# plt.title('Dendrogram')
|
transformed = pd.DataFrame(pca.fit_transform(data)).to_numpy()
|
||||||
# plt.xlabel('Sample index')
|
plt.subplot(1, 2, 2)
|
||||||
# plt.ylabel('Euclidean distance')
|
plt.scatter(x=transformed[:, 0], y=transformed[:, 1], c=labels, cmap='rainbow')
|
||||||
|
plt.title('Clustering')
|
||||||
clusters = 3
|
|
||||||
model = AgglomerativeClustering(n_clusters=clusters, metric='euclidean', linkage='ward')
|
|
||||||
model.fit(data)
|
|
||||||
labels = model.labels_
|
|
||||||
|
|
||||||
data_norm = (data - data.min()) / (data.max() - data.min())
|
|
||||||
|
|
||||||
pca = PCA(n_components=2) # 2-dimensional PCA
|
|
||||||
transformed = pd.DataFrame(pca.fit_transform(data_norm))
|
|
||||||
# plt.scatter(x=transformed[:, 0], y=transformed[:, 1], c=labels, cmap='rainbow')
|
|
||||||
for i in range(clusters):
|
|
||||||
series = transformed.iloc[numpy.where(labels[:] == i)]
|
|
||||||
plt.scatter(series[0], series[1], label=f'Cluster {i + 1}')
|
|
||||||
plt.legend()
|
|
||||||
plt.show()
|
plt.show()
|
||||||
|
|
||||||
# fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(15, 5))
|
|
||||||
# sns.scatterplot(ax=axes[0], data=data, x='location-la,location-lo', y='age,sex').set_title('Without clustering')
|
|
||||||
# sns.scatterplot(ax=axes[1], data=data, x='location-la,location-lo', y='age,sex', hue=labels) \
|
|
||||||
# .set_title('With clustering')
|
|
||||||
# plt.show()
|
|
||||||
|
|
||||||
# s = numpy.where(labels[:] == 34)
|
def __get_cluster_centers(data: ndarray, labels: ndarray) -> ndarray:
|
||||||
# print(labels)
|
centers: List[List[float]] = list()
|
||||||
|
for label in set(labels):
|
||||||
|
center: Series = data[numpy.where(labels[:] == label)].mean(axis=0)
|
||||||
|
centers.append(list(center))
|
||||||
|
return np.array(centers)
|
||||||
|
|
||||||
|
|
||||||
|
def __print_center(center: ndarray) -> None:
|
||||||
|
location: str = georeverse.get_city(center[0], center[1])
|
||||||
|
sex = round(center[2])
|
||||||
|
age = round(center[3])
|
||||||
|
is_university = bool(round(center[4]))
|
||||||
|
is_work = bool(round(center[5]))
|
||||||
|
is_student = bool(round(center[6]))
|
||||||
|
is_schoolboy = bool(round(center[7]))
|
||||||
|
print(f'location: {location}, sex: {sex}, age: {age},'
|
||||||
|
f' univer: {is_university}, work: {is_work}, student: {is_student}, school: {is_schoolboy}')
|
||||||
|
|
||||||
|
|
||||||
|
def __clustering(data: ndarray, n_clusters: int = 3, plots: bool = False) -> None:
|
||||||
|
model = AgglomerativeClustering(n_clusters=n_clusters, metric='euclidean', linkage='ward')
|
||||||
|
model.fit(data)
|
||||||
|
labels = model.labels_
|
||||||
|
if plots:
|
||||||
|
__plots(data, labels)
|
||||||
|
centers = __get_cluster_centers(data, labels)
|
||||||
|
for center in centers:
|
||||||
|
__print_center(center)
|
||||||
|
|
||||||
|
|
||||||
def __main(json_file):
|
def __main(json_file):
|
||||||
df_loader: DfLoader = DfLoader(json_file)
|
data: ndarray = DfLoader(json_file).get_data()
|
||||||
data = df_loader.get_clustering_data()
|
__clustering(data, default_clusters, is_plots)
|
||||||
print(data)
|
|
||||||
__clustering(data)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
12
src/main/georeverse.py
Normal file
12
src/main/georeverse.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from geopy import Nominatim
|
||||||
|
|
||||||
|
|
||||||
|
class Georeverse:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
geolocator: Nominatim = Nominatim(user_agent="MyApp")
|
||||||
|
self.__reverse = partial(geolocator.reverse, language="ru")
|
||||||
|
|
||||||
|
def get_city(self, latitude: float, longitude: float) -> str:
|
||||||
|
return self.__reverse(f'{latitude}, {longitude}')
|
Loading…
Reference in New Issue
Block a user