diff --git a/src/main/df_loader.py b/src/main/df_loader.py index 0cd77a4..0c8fd0a 100644 --- a/src/main/df_loader.py +++ b/src/main/df_loader.py @@ -2,6 +2,7 @@ from datetime import date import numpy as np import pandas as pd +from numpy import ndarray from pandas import DataFrame from src.main.constants import Constants as const @@ -10,6 +11,7 @@ from src.main.utils import Utils class DfLoader: + def __init__(self, json_file: str) -> None: self.__geocache: Geocache = Geocache() print(f'Try to load data from the {json_file} file') @@ -68,5 +70,6 @@ class DfLoader: self.__df['location'] = self.__df['city'] \ .apply(lambda val: '' if Utils.is_empty_str(val) else self.__geocache.get_location(val)) - def get_clustering_data(self) -> DataFrame: - return self.__df + def get_clustering_data(self) -> ndarray: + columns: [] = ['location', 'sex', 'age', 'is_university', 'is_work', 'is_student', 'is_schoolboy'] + return self.__df[columns].to_numpy()