Compare commits

..

2 Commits

6 changed files with 144079 additions and 2 deletions

2187
cardio.ipynb Normal file

File diff suppressed because one or more lines are too long

2897
cardio_fuzzy.ipynb Normal file

File diff suppressed because one or more lines are too long

Binary file not shown.

68986
data-cardio/cardio_clear.csv Normal file

File diff suppressed because it is too large Load Diff

70001
data-cardio/cardio_train.csv Normal file

File diff suppressed because it is too large Load Diff

View File

@ -71,7 +71,7 @@ class Rule:
# https://mljar.com/blog/extract-rules-decision-tree/
def get_rules(tree, feature_names) -> List[Rule]:
def get_rules(tree, feature_names, classes=None) -> List[Rule]:
tree_ = tree.tree_
feature_name = [
feature_names[i] if i != TREE_UNDEFINED else "undefined!" for i in tree_.feature
@ -91,7 +91,11 @@ def get_rules(tree, feature_names) -> List[Rule]:
p2.append(RuleAtom(name, ComparisonType.GREATER, threshold))
recurse(tree_.children_right[node], p2, rules)
else:
rules.append(Rule(antecedent, tree_.value[node][0][0]))
if classes is None:
rules.append(Rule(antecedent, tree_.value[node][0][0]))
else:
value = np.argmax(tree_.value[node][0])
rules.append(Rule(antecedent, classes[value]))
recurse(0, antecedent, rules)
@ -302,6 +306,8 @@ def _get_fuzzy_rules(
for rule in rules:
antecedent = []
for atom in rule.get_antecedent():
if fuzzy_variables.get(atom.get_varaible(), None) is None:
continue
antecedent.append(
_get_fuzzy_rule_atom(
fuzzy_variables[atom.get_varaible()], atom.get_value()