diff --git a/src/main/df_loader.py b/src/main/df_loader.py
index 0c8fd0a..48a6942 100644
--- a/src/main/df_loader.py
+++ b/src/main/df_loader.py
@@ -2,7 +2,6 @@ from datetime import date
 
 import numpy as np
 import pandas as pd
-from numpy import ndarray
 from pandas import DataFrame
 
 from src.main.constants import Constants as const
@@ -49,27 +48,33 @@ class DfLoader:
     def __prepare_dataset_status(self) -> None:
         is_univer_mask = ((self.__df['age'] >= const.university_gr_age()) | (self.__df['age'] == const.empty_age())) & \
                          ((self.__df['universities'].str.len() > 0) | (self.__df['occupation_type'] == 'university'))
-        self.__df['is_university'] = np.where(is_univer_mask, True, False)
+        self.__df['is_university'] = np.where(is_univer_mask, 1, 0)
 
         is_work_mask = ((self.__df['age'] > const.school_gr_age()) | (self.__df['age'] == const.empty_age())) & \
-                       ((self.__df['is_university']) | (self.__df['occupation_type'] == 'work')) | \
+                       ((self.__df['is_university'] == 1) | (self.__df['occupation_type'] == 'work')) | \
                        (self.__df['age'] > const.university_gr_age())
-        self.__df['is_work'] = np.where(is_work_mask, True, False)
+        self.__df['is_work'] = np.where(is_work_mask, 1, 0)
 
         is_student_mask = ((self.__df['occupation_type'] == 'university') &
                            ((self.__df['age'] >= const.school_gr_age()) &
                             (self.__df['age'] <= const.university_gr_age())))
-        self.__df['is_student'] = np.where(is_student_mask, True, False)
+        self.__df['is_student'] = np.where(is_student_mask, 1, 0)
 
         is_schoolboy_mask = ((self.__df['age'] < const.school_gr_age()) & (self.__df['age'] != const.empty_age())) | \
                             ((self.__df['age'] == const.empty_age()) & (self.__df['occupation_type'] == 'school'))
-        self.__df['is_schoolboy'] = np.where(is_schoolboy_mask, True, False)
+        self.__df['is_schoolboy'] = np.where(is_schoolboy_mask, 1, 0)
 
     def __prepare_dataset_location(self) -> None:
         self.__geocache.update_geo_cache(self.__df['city'].unique().tolist())
         self.__df['location'] = self.__df['city'] \
             .apply(lambda val: '' if Utils.is_empty_str(val) else self.__geocache.get_location(val))
-
-    def get_clustering_data(self) -> ndarray:
-        columns: [] = ['location', 'sex', 'age', 'is_university', 'is_work', 'is_student', 'is_schoolboy']
-        return self.__df[columns].to_numpy()
+        self.__df['location-la'] = self.__df.loc[:, 'location'] \
+            .apply(lambda val: 0 if Utils.is_empty_collection(val) else val[0])
+        self.__df['location-lo'] = self.__df.loc[:, 'location'] \
+            .apply(lambda val: 0 if Utils.is_empty_collection(val) else val[1])
+
+    def get_clustering_data(self) -> DataFrame:
+        columns: [] = ['location-la', 'location-lo',
+                       'sex', 'age', 'is_university', 'is_work', 'is_student', 'is_schoolboy']
+        df = self.__df
+        return df[columns]