import dataclasses
from collections.abc import Mapping
from typing import Any, Type, TypeVar

from backend.dataset.dto import DatasetDto
from backend.dataset.model import DatasetParams
from backend.regression.dto import RegressionTreeDto
from backend.regression.model import RegressionTreeParams

_ParamsT = TypeVar("_ParamsT")


class RegressionDto(DatasetDto, RegressionTreeDto):
    """DTO that extracts dataset and regression-tree parameters from one mapping.

    Each accessor builds its parameter dataclass from the same flat mapping,
    keeping only the keys that name a field of that dataclass.
    """

    @staticmethod
    def _build_params(params_cls: Type[_ParamsT], data: Mapping[str, Any]) -> _ParamsT:
        """Instantiate the dataclass ``params_cls`` from the matching entries of ``data``.

        Keys in ``data`` that are not fields of ``params_cls`` are ignored, so a
        single mapping can feed several parameter dataclasses.

        Raises:
            TypeError: if ``data`` lacks a field that ``params_cls`` requires
                (raised by the dataclass constructor).
        """
        field_names = {field.name for field in dataclasses.fields(params_cls)}
        return params_cls(**{k: v for k, v in data.items() if k in field_names})

    def get_dataset_params(self, data: Mapping[str, Any]) -> DatasetParams:
        """Return the ``DatasetParams`` built from the matching keys of ``data``."""
        return self._build_params(DatasetParams, data)

    def get_tree_params(self, data: Mapping[str, Any]) -> RegressionTreeParams:
        """Return the ``RegressionTreeParams`` built from the matching keys of ``data``."""
        return self._build_params(RegressionTreeParams, data)