26 lines
905 B
Python

import dataclasses
from backend.classification.dto import ClassificationTreeDto
from backend.dataset.dto import DatasetDto
from backend.dataset.model import DatasetParams
from backend.regression.dto import RegressionTreeDto
from backend.tree.model import DecisionTreeParams
class Desiralizer:
    """Mixin that builds parameter dataclasses from loosely-typed input data.

    The (misspelled) class name is kept unchanged so that any existing
    references to ``Desiralizer`` keep working.
    """

    @staticmethod
    def _params_from_mapping(params_cls, data):
        """Instantiate dataclass ``params_cls`` from the entries of ``data`` it accepts.

        Keys of ``data`` that are not ``__init__`` fields of ``params_cls`` are
        silently ignored.

        Args:
            params_cls: The dataclass type to instantiate.
            data: A mapping-like object exposing ``.items()``; its keys are
                matched against the dataclass field names.

        Returns:
            A new ``params_cls`` instance.

        Raises:
            TypeError: If ``params_cls`` is not a dataclass, or if a required
                field of ``params_cls`` is absent from ``data``.
        """
        # Only fields with init=True are parameters of the generated __init__;
        # forwarding an init=False field would raise TypeError.
        accepted = {f.name for f in dataclasses.fields(params_cls) if f.init}
        return params_cls(**{k: v for k, v in data.items() if k in accepted})

    def get_dataset_params(self, data) -> DatasetParams:
        """Build a ``DatasetParams`` from ``data``, ignoring keys it does not declare."""
        return self._params_from_mapping(DatasetParams, data)

    def get_tree_params(self, data) -> DecisionTreeParams:
        """Build a ``DecisionTreeParams`` from ``data``, ignoring keys it does not declare."""
        return self._params_from_mapping(DecisionTreeParams, data)
class RegressionDto(DatasetDto, RegressionTreeDto, Desiralizer):
    """Composite DTO for regression requests.

    Adds nothing of its own: it combines ``DatasetDto``, ``RegressionTreeDto``
    and the ``Desiralizer`` parameter-building helpers, resolved in that
    base-class order.
    """
class ClassificationDto(DatasetDto, ClassificationTreeDto, Desiralizer):
    """Composite DTO for classification requests.

    Adds nothing of its own: it combines ``DatasetDto``,
    ``ClassificationTreeDto`` and the ``Desiralizer`` parameter-building
    helpers, resolved in that base-class order.
    """