Improve read dataset function
This commit is contained in:
parent
bc0c132a35
commit
e44671c259
@ -3,6 +3,7 @@ import uuid
|
|||||||
|
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
from pandas import DataFrame
|
from pandas import DataFrame
|
||||||
|
from pandas.errors import ParserError
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
from werkzeug import utils
|
from werkzeug import utils
|
||||||
|
|
||||||
@ -33,10 +34,21 @@ class Dataset:
|
|||||||
return file_name
|
return file_name
|
||||||
|
|
||||||
def read(self, params: DatasetParams) -> DataFrame:
|
def read(self, params: DatasetParams) -> DataFrame:
|
||||||
df = pd.read_csv(self.__file_name, sep=params.sep, decimal=params.decimal)
|
df = None
|
||||||
|
try:
|
||||||
|
df = pd.read_csv(self.__file_name, sep=params.sep, decimal=params.decimal)
|
||||||
|
except ParserError:
|
||||||
|
raise Exception(
|
||||||
|
"Can't parse dataset. Try to use correct 'sep' and 'decimal' values."
|
||||||
|
)
|
||||||
|
if df.columns.size < 2:
|
||||||
|
raise Exception(
|
||||||
|
"Dataset contains less than 2 columns. "
|
||||||
|
"Try to use correct 'sep' parameter value."
|
||||||
|
)
|
||||||
|
params.target = params.target or df.columns[-1]
|
||||||
if params.input is not None:
|
if params.input is not None:
|
||||||
return df[params.input + [params.target]]
|
return df[params.input + [params.target]]
|
||||||
|
|
||||||
return df
|
return df
|
||||||
|
|
||||||
def __split(
|
def __split(
|
||||||
@ -46,9 +58,8 @@ class Dataset:
|
|||||||
random_state: int,
|
random_state: int,
|
||||||
is_classification: bool = False,
|
is_classification: bool = False,
|
||||||
) -> SplittedDataset:
|
) -> SplittedDataset:
|
||||||
target = params.target or data.columns[-1]
|
|
||||||
X = data.drop([params.target], axis=1)
|
X = data.drop([params.target], axis=1)
|
||||||
y = data[[target]]
|
y = data[[params.target]]
|
||||||
stratify = None if not is_classification else y
|
stratify = None if not is_classification else y
|
||||||
X_train, X_test, y_train, y_test = train_test_split(
|
X_train, X_test, y_train, y_test = train_test_split(
|
||||||
X,
|
X,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user