19 lines
753 B
Python

import dataclasses
from backend.dataset.dto import DatasetDto
from backend.dataset.model import DatasetParams
from backend.regression.dto import RegressionTreeDto
from backend.regression.model import RegressionTreeParams
class RegressionDto(DatasetDto, RegressionTreeDto):
    """Combine the ``DatasetDto`` and ``RegressionTreeDto`` bases.

    Provides helpers that split one flat mapping of values into a
    ``DatasetParams`` and a ``RegressionTreeParams`` instance. Each helper
    keeps only the keys that the target dataclass declares as fields.
    """

    @staticmethod
    def _fields_kwargs(params_cls, data):
        """Return the entries of *data* whose keys are fields of *params_cls*.

        Keys that are not dataclass fields of *params_cls* are dropped, so the
        result can be splatted into the dataclass constructor without
        ``unexpected keyword`` errors. *params_cls* must be a dataclass (or an
        instance of one); ``dataclasses.fields`` raises ``TypeError`` otherwise.
        """
        field_names = {f.name for f in dataclasses.fields(params_cls)}
        return {k: v for k, v in data.items() if k in field_names}

    def get_dataset_params(self, data) -> DatasetParams:
        """Build a ``DatasetParams`` from the matching keys of *data*.

        :param data: mapping of values (anything exposing ``.items()``).
        :returns: ``DatasetParams`` constructed from the keys of *data* that
            are its dataclass fields; all other keys are ignored. Missing
            required fields surface as the dataclass constructor's ``TypeError``.
        """
        return DatasetParams(**self._fields_kwargs(DatasetParams, data))

    def get_tree_params(self, data) -> RegressionTreeParams:
        """Build a ``RegressionTreeParams`` from the matching keys of *data*.

        :param data: mapping of values (anything exposing ``.items()``).
        :returns: ``RegressionTreeParams`` constructed from the keys of *data*
            that are its dataclass fields; all other keys are ignored. Missing
            required fields surface as the dataclass constructor's ``TypeError``.
        """
        return RegressionTreeParams(
            **self._fields_kwargs(RegressionTreeParams, data)
        )