Improve read dataset function
This commit is contained in:
parent
bc0c132a35
commit
e44671c259
@ -3,6 +3,7 @@ import uuid
|
||||
|
||||
import pandas as pd
|
||||
from pandas import DataFrame
|
||||
from pandas.errors import ParserError
|
||||
from sklearn.model_selection import train_test_split
|
||||
from werkzeug import utils
|
||||
|
||||
@ -33,10 +34,21 @@ class Dataset:
|
||||
return file_name
|
||||
|
||||
def read(self, params: DatasetParams) -> DataFrame:
    """Load the dataset CSV file and return the columns selected by *params*.

    The file at ``self.__file_name`` is parsed with ``params.sep`` as the
    field separator and ``params.decimal`` as the decimal mark.

    Side effect: when ``params.target`` is falsy it is set to the name of
    the last column of the parsed frame, so code that later reads
    ``params.target`` sees a concrete column name.

    Returns:
        The whole frame when ``params.input`` is ``None``; otherwise a frame
        with the ``params.input`` columns followed by the target column
        (the target is never included twice).

    Raises:
        Exception: if pandas cannot parse the file, or if the parsed frame
            has fewer than two columns (usually a wrong ``sep``).
    """
    try:
        df = pd.read_csv(self.__file_name, sep=params.sep, decimal=params.decimal)
    except ParserError as err:
        # Chain the pandas error explicitly so the root cause stays visible.
        raise Exception(
            "Can't parse dataset. Try to use correct 'sep' and 'decimal' values."
        ) from err

    # A single column almost always means the separator did not match.
    if len(df.columns) < 2:
        raise Exception(
            "Dataset contains less than 2 columns. "
            "Try to use correct 'sep' parameter value."
        )

    params.target = params.target or df.columns[-1]
    if params.input is None:
        return df

    # Drop the target from the input list before appending it, otherwise the
    # selection would contain the target column twice.
    columns = [name for name in params.input if name != params.target]
    columns.append(params.target)
    return df[columns]
|
||||
|
||||
def __split(
|
||||
@ -46,9 +58,8 @@ class Dataset:
|
||||
random_state: int,
|
||||
is_classification: bool = False,
|
||||
) -> SplittedDataset:
|
||||
target = params.target or data.columns[-1]
|
||||
X = data.drop([params.target], axis=1)
|
||||
y = data[[target]]
|
||||
y = data[[params.target]]
|
||||
stratify = None if not is_classification else y
|
||||
X_train, X_test, y_train, y_test = train_test_split(
|
||||
X,
|
||||
|
Loading…
x
Reference in New Issue
Block a user