19 lines
753 B
Python
19 lines
753 B
Python
import dataclasses
|
|
|
|
from backend.dataset.dto import DatasetDto
|
|
from backend.dataset.model import DatasetParams
|
|
from backend.regression.dto import RegressionTreeDto
|
|
from backend.regression.model import RegressionTreeParams
|
|
|
|
|
|
class RegressionDto(DatasetDto, RegressionTreeDto):
    """DTO that splits one flat mapping into dataset and regression-tree params.

    Combines ``DatasetDto`` and ``RegressionTreeDto``. Each ``get_*_params``
    method selects the keys of ``data`` that name a constructor field of the
    corresponding parameter dataclass and ignores every other key.
    """

    @staticmethod
    def _params_from_mapping(params_cls, data):
        """Build ``params_cls`` from the entries of ``data`` it can accept.

        Args:
            params_cls: A dataclass type.
            data: An object with an ``items()`` method yielding
                ``(name, value)`` pairs (e.g. a ``dict``).

        Returns:
            A new ``params_cls`` instance. Keys of ``data`` that do not match
            an ``__init__`` field of ``params_cls`` are silently dropped.

        Raises:
            TypeError: if a required field of ``params_cls`` is missing from
                ``data`` (raised by the dataclass constructor itself).
        """
        # Only fields that appear in the generated __init__ may be passed as
        # keyword arguments; forwarding an ``init=False`` field would raise
        # TypeError even though dataclasses.fields() lists it.
        accepted = {f.name for f in dataclasses.fields(params_cls) if f.init}
        return params_cls(**{k: v for k, v in data.items() if k in accepted})

    def get_dataset_params(self, data) -> DatasetParams:
        """Return ``DatasetParams`` built from the matching keys of ``data``."""
        return self._params_from_mapping(DatasetParams, data)

    def get_tree_params(self, data) -> RegressionTreeParams:
        """Return ``RegressionTreeParams`` built from the matching keys of ``data``."""
        return self._params_from_mapping(RegressionTreeParams, data)
|