import dataclasses
from collections.abc import Mapping
from typing import Any, Type, TypeVar

from backend.classification.dto import ClassificationTreeDto
from backend.dataset.dto import DatasetDto
from backend.dataset.model import DatasetParams
from backend.regression.dto import RegressionTreeDto
from backend.tree.model import DecisionTreeParams

_T = TypeVar("_T")


def _build_dataclass(params_cls: Type[_T], data: Mapping[str, Any]) -> _T:
    """Construct ``params_cls`` from the keys of ``data`` that it accepts.

    Keys in ``data`` that do not name an ``__init__`` parameter of the
    dataclass are silently ignored.

    Args:
        params_cls: A dataclass type.
        data: Mapping of field name to value.

    Returns:
        A new ``params_cls`` instance.

    Raises:
        TypeError: If ``params_cls`` is not a dataclass, or if a required
            field is missing from ``data``.
    """
    # Only fields that appear in the generated __init__ may be passed as
    # keyword arguments; an init=False field would make the call fail.
    init_names = {f.name for f in dataclasses.fields(params_cls) if f.init}
    return params_cls(**{k: v for k, v in data.items() if k in init_names})


class Desiralizer:
    """Mixin that turns plain mappings into parameter dataclasses.

    NOTE: the class name is a misspelling of "Deserializer"; it is kept
    as-is because subclasses and callers reference it by this name.
    """

    def get_dataset_params(self, data: Mapping[str, Any]) -> DatasetParams:
        """Build a ``DatasetParams`` from ``data``, ignoring unknown keys."""
        return _build_dataclass(DatasetParams, data)

    def get_tree_params(self, data: Mapping[str, Any]) -> DecisionTreeParams:
        """Build a ``DecisionTreeParams`` from ``data``, ignoring unknown keys."""
        return _build_dataclass(DecisionTreeParams, data)


class RegressionDto(DatasetDto, RegressionTreeDto, Desiralizer):
    """Combines ``DatasetDto``, ``RegressionTreeDto`` and the ``Desiralizer`` helpers."""


class ClassificationDto(DatasetDto, ClassificationTreeDto, Desiralizer):
    """Combines ``DatasetDto``, ``ClassificationTreeDto`` and the ``Desiralizer`` helpers."""