#91 -- Fix fuzzy inference

This commit is contained in:
Anton Romanov 2023-04-11 16:36:05 +04:00
parent 5c6b388901
commit 4596a25a59
3 changed files with 28 additions and 63 deletions

View File

@ -8,7 +8,6 @@ import ru.ulstu.extractor.rule.service.AntecedentValueService;
import ru.ulstu.extractor.rule.service.DbRuleService; import ru.ulstu.extractor.rule.service.DbRuleService;
import ru.ulstu.extractor.rule.service.FuzzyInferenceService; import ru.ulstu.extractor.rule.service.FuzzyInferenceService;
import ru.ulstu.extractor.ts.model.TimeSeries; import ru.ulstu.extractor.ts.model.TimeSeries;
import ru.ulstu.extractor.ts.model.TimeSeriesValue;
import ru.ulstu.extractor.ts.service.TimeSeriesService; import ru.ulstu.extractor.ts.service.TimeSeriesService;
import java.util.ArrayList; import java.util.ArrayList;
@ -57,9 +56,7 @@ public class AssessmentService {
timeSeries.forEach(ts -> variableValues.put(ts.getTimeSeriesType().name(), timeSeriesService.getLastTimeSeriesTendency(ts))); timeSeries.forEach(ts -> variableValues.put(ts.getTimeSeriesType().name(), timeSeriesService.getLastTimeSeriesTendency(ts)));
return fuzzyInferenceService.getFuzzyInference(dbRules, return fuzzyInferenceService.getFuzzyInference(dbRules,
antecedentValueService.getList(), antecedentValueService.getList(),
variableValues, variableValues);
getTSsMin(timeSeries),
getTSsMax(timeSeries));
} }
private List<Assessment> getAssessmentsByTimeSeriesTendencies(List<DbRule> dbRules, List<TimeSeries> timeSeries) { private List<Assessment> getAssessmentsByTimeSeriesTendencies(List<DbRule> dbRules, List<TimeSeries> timeSeries) {
@ -75,9 +72,7 @@ public class AssessmentService {
.getLastTimeSeriesTendency(ts))); .getLastTimeSeriesTendency(ts)));
return fuzzyInferenceService.getFuzzyInference(List.of(dbRule), return fuzzyInferenceService.getFuzzyInference(List.of(dbRule),
antecedentValueService.getList(), antecedentValueService.getList(),
variableValues, variableValues).stream();
getTSsMin(timeSeries),
getTSsMax(timeSeries)).stream();
}) })
.sorted(Comparator.comparing(Assessment::getDegree)) .sorted(Comparator.comparing(Assessment::getDegree))
.collect(Collectors.toList()); .collect(Collectors.toList());
@ -88,44 +83,6 @@ public class AssessmentService {
timeSeries.forEach(ts -> variableValues.put(ts.getTimeSeriesType().name(), ts.getValues().get(ts.getValues().size() - 1).getValue())); timeSeries.forEach(ts -> variableValues.put(ts.getTimeSeriesType().name(), ts.getValues().get(ts.getValues().size() - 1).getValue()));
return fuzzyInferenceService.getFuzzyInference(dbRules, return fuzzyInferenceService.getFuzzyInference(dbRules,
antecedentValueService.getList(), antecedentValueService.getList(),
variableValues, variableValues);
getTSsMin(timeSeries),
getTSsMax(timeSeries));
}
private Double getMin(List<Double> values) {
return values.stream().mapToDouble(v -> v).min().getAsDouble();
}
private Map.Entry<String, Double> getTSMin(TimeSeries ts) {
return Map.entry(ts.getTimeSeriesType().name(),
getMin(ts.getValues().stream().map(TimeSeriesValue::getValue).collect(Collectors.toList())));
}
private Map<String, Double> getTSsMin(List<TimeSeries> tss) {
Map<String, Double> res = new HashMap<>();
tss.forEach(ts -> {
Map.Entry<String, Double> entry = getTSMin(ts);
res.put(entry.getKey(), entry.getValue());
});
return res;
}
private Double getMax(List<Double> values) {
return values.stream().mapToDouble(v -> v).max().getAsDouble();
}
private Map.Entry<String, Double> getTSMax(TimeSeries ts) {
return Map.entry(ts.getTimeSeriesType().name(),
getMax(ts.getValues().stream().map(TimeSeriesValue::getValue).collect(Collectors.toList())));
}
private Map<String, Double> getTSsMax(List<TimeSeries> tss) {
Map<String, Double> res = new HashMap<>();
tss.forEach(ts -> {
Map.Entry<String, Double> entry = getTSMax(ts);
res.put(entry.getKey(), entry.getValue());
});
return res;
} }
} }

View File

@ -34,8 +34,7 @@ public class FuzzyInferenceService {
+ " is %s"; + " is %s";
private final static String NO_RESULT = "Нет результата"; private final static String NO_RESULT = "Нет результата";
private List<String> getRulesFromDb(List<DbRule> dbRules, Map<String, Double> variableValues) { private List<String> mapRulesToString(List<DbRule> dbRules) {
validateVariables(variableValues, dbRules);
return dbRules.stream().map(this::getFuzzyRule).collect(Collectors.toList()); return dbRules.stream().map(this::getFuzzyRule).collect(Collectors.toList());
} }
@ -51,8 +50,6 @@ public class FuzzyInferenceService {
private RuleBlock getRuleBlock(Engine engine, private RuleBlock getRuleBlock(Engine engine,
List<DbRule> dbRules, List<DbRule> dbRules,
Map<String, Double> variableValues, Map<String, Double> variableValues,
Map<String, Double> min,
Map<String, Double> max,
List<AntecedentValue> antecedentValues, List<AntecedentValue> antecedentValues,
List<Integer> consequentValues) { List<Integer> consequentValues) {
variableValues.forEach((key, value) -> { variableValues.forEach((key, value) -> {
@ -61,16 +58,16 @@ public class FuzzyInferenceService {
input.setDescription(""); input.setDescription("");
input.setEnabled(true); input.setEnabled(true);
double delta = antecedentValues.size() > 1 double delta = antecedentValues.size() > 1
? (max.get(key) - min.get(key)) / (antecedentValues.size() - 1) ? 2.0 / (antecedentValues.size() - 1)
: (max.get(key) - min.get(key)); : 2.0;
input.setRange(min.get(key), max.get(key)); input.setRange(-1, 1);
input.setLockValueInRange(false); input.setLockValueInRange(false);
for (int i = 0; i < antecedentValues.size(); i++) { for (int i = 0; i < antecedentValues.size(); i++) {
input.addTerm( input.addTerm(
new Triangle( new Triangle(
antecedentValues.get(i).getAntecedentValue(), antecedentValues.get(i).getAntecedentValue(),
min.get(key) + i * delta - 0.5 * delta, -1 + i * delta - 0.5 * delta,
min.get(key) + i * delta + delta + 0.5 * delta -1 + i * delta + delta + 0.5 * delta
) )
); );
} }
@ -87,7 +84,7 @@ public class FuzzyInferenceService {
output.setDefaultValue(Double.NaN); output.setDefaultValue(Double.NaN);
output.setLockValueInRange(false); output.setLockValueInRange(false);
for (int i = 0; i < consequentValues.size(); i++) { for (int i = 0; i < consequentValues.size(); i++) {
output.addTerm(new Triangle(consequentValues.get(i).toString(), i, i + 2.1)); output.addTerm(new Triangle(consequentValues.get(i).toString(), i, i + 1));
} }
engine.addOutputVariable(output); engine.addOutputVariable(output);
@ -99,7 +96,7 @@ public class FuzzyInferenceService {
//mamdani.setDisjunction(null); //mamdani.setDisjunction(null);
mamdani.setImplication(new AlgebraicProduct()); mamdani.setImplication(new AlgebraicProduct());
mamdani.setActivation(new General()); mamdani.setActivation(new General());
getRulesFromDb(dbRules, variableValues).forEach(r -> { mapRulesToString(dbRules).forEach(r -> {
LOG.info(r); LOG.info(r);
mamdani.addRule(Rule.parse(r, engine)); mamdani.addRule(Rule.parse(r, engine));
}); });
@ -115,12 +112,11 @@ public class FuzzyInferenceService {
public List<Assessment> getFuzzyInference(List<DbRule> dbRules, public List<Assessment> getFuzzyInference(List<DbRule> dbRules,
List<AntecedentValue> antecedentValues, List<AntecedentValue> antecedentValues,
Map<String, Double> variableValues, Map<String, Double> variableValues) {
Map<String, Double> min, validateVariables(variableValues, dbRules);
Map<String, Double> max) {
Engine engine = getFuzzyEngine(); Engine engine = getFuzzyEngine();
List<Integer> consequentValues = dbRules.stream().map(DbRule::getId).collect(Collectors.toList()); List<Integer> consequentValues = dbRules.stream().map(DbRule::getId).collect(Collectors.toList());
engine.addRuleBlock(getRuleBlock(engine, dbRules, variableValues, min, max, antecedentValues, consequentValues)); engine.addRuleBlock(getRuleBlock(engine, dbRules, variableValues, antecedentValues, consequentValues));
Map.Entry<String, Double> consequent = getConsequent(engine, variableValues); Map.Entry<String, Double> consequent = getConsequent(engine, variableValues);
if (consequent.getKey().equals(NO_RESULT)) { if (consequent.getKey().equals(NO_RESULT)) {
return new ArrayList<>(); return new ArrayList<>();

View File

@ -127,17 +127,29 @@ public class TimeSeriesService {
public double getLastTimeSeriesTendency(TimeSeries ts) { public double getLastTimeSeriesTendency(TimeSeries ts) {
if (ts != null && ts.getValues().size() > MIN_TIME_SERIES_LENGTH) { if (ts != null && ts.getValues().size() > MIN_TIME_SERIES_LENGTH) {
JSONObject response = httpService.post(TIME_SERIES_TENDENCY_URL, new JSONObject(new SmoothingTimeSeries(ts))); JSONObject response = httpService.post(TIME_SERIES_TENDENCY_URL, new JSONObject(new SmoothingTimeSeries(normalizeTimeSeries(ts))));
LOG.debug("Успешно отправлен на сервис сглаживания"); LOG.debug("Успешно отправлен на сервис сглаживания");
if (response.has("response") && response.getString("response").equals("empty")) { if (response.has("response") && response.getString("response").equals("empty")) {
return DEFAULT_TIME_SERIES_TENDENCY; return DEFAULT_TIME_SERIES_TENDENCY;
} }
JSONArray jsonArray = response.getJSONObject("timeSeries").getJSONArray("values"); JSONArray jsonArray = response.getJSONObject("timeSeries").getJSONArray("values");
return jsonArray.getJSONObject(jsonArray.length() - 1).getDouble("value"); return jsonArray.getJSONObject(jsonArray.length() - 1).getDouble("value") -
jsonArray.getJSONObject(jsonArray.length() - 2).getDouble("value");
} }
return DEFAULT_TIME_SERIES_TENDENCY; return DEFAULT_TIME_SERIES_TENDENCY;
} }
private TimeSeries normalizeTimeSeries(TimeSeries ts) {
double sum = ts.getValues().stream().mapToDouble(TimeSeriesValue::getValue).sum();
if (sum > 0.0d) {
for (int i = 0; i < ts.getValues().size(); i++) {
ts.getValues().get(i).setValue(ts.getValues().get(i).getValue() / sum);
}
}
return ts;
}
public boolean isBranchContainsAllTimeSeries(Branch b) { public boolean isBranchContainsAllTimeSeries(Branch b) {
List<TimeSeries> timeSeries = getByBranch(b.getId()); List<TimeSeries> timeSeries = getByBranch(b.getId());
return Stream.of(TimeSeriesType.values()).allMatch(type -> timeSeries return Stream.of(TimeSeriesType.values()).allMatch(type -> timeSeries