#91 -- Fix fuzzy inference

This commit is contained in:
Anton Romanov 2023-04-11 16:36:05 +04:00
parent 5c6b388901
commit 4596a25a59
3 changed files with 28 additions and 63 deletions

View File

@ -8,7 +8,6 @@ import ru.ulstu.extractor.rule.service.AntecedentValueService;
import ru.ulstu.extractor.rule.service.DbRuleService;
import ru.ulstu.extractor.rule.service.FuzzyInferenceService;
import ru.ulstu.extractor.ts.model.TimeSeries;
import ru.ulstu.extractor.ts.model.TimeSeriesValue;
import ru.ulstu.extractor.ts.service.TimeSeriesService;
import java.util.ArrayList;
@ -57,9 +56,7 @@ public class AssessmentService {
timeSeries.forEach(ts -> variableValues.put(ts.getTimeSeriesType().name(), timeSeriesService.getLastTimeSeriesTendency(ts)));
return fuzzyInferenceService.getFuzzyInference(dbRules,
antecedentValueService.getList(),
variableValues,
getTSsMin(timeSeries),
getTSsMax(timeSeries));
variableValues);
}
private List<Assessment> getAssessmentsByTimeSeriesTendencies(List<DbRule> dbRules, List<TimeSeries> timeSeries) {
@ -75,9 +72,7 @@ public class AssessmentService {
.getLastTimeSeriesTendency(ts)));
return fuzzyInferenceService.getFuzzyInference(List.of(dbRule),
antecedentValueService.getList(),
variableValues,
getTSsMin(timeSeries),
getTSsMax(timeSeries)).stream();
variableValues).stream();
})
.sorted(Comparator.comparing(Assessment::getDegree))
.collect(Collectors.toList());
@ -88,44 +83,6 @@ public class AssessmentService {
timeSeries.forEach(ts -> variableValues.put(ts.getTimeSeriesType().name(), ts.getValues().get(ts.getValues().size() - 1).getValue()));
return fuzzyInferenceService.getFuzzyInference(dbRules,
antecedentValueService.getList(),
variableValues,
getTSsMin(timeSeries),
getTSsMax(timeSeries));
}
private Double getMin(List<Double> values) {
return values.stream().mapToDouble(v -> v).min().getAsDouble();
}
private Map.Entry<String, Double> getTSMin(TimeSeries ts) {
return Map.entry(ts.getTimeSeriesType().name(),
getMin(ts.getValues().stream().map(TimeSeriesValue::getValue).collect(Collectors.toList())));
}
private Map<String, Double> getTSsMin(List<TimeSeries> tss) {
Map<String, Double> res = new HashMap<>();
tss.forEach(ts -> {
Map.Entry<String, Double> entry = getTSMin(ts);
res.put(entry.getKey(), entry.getValue());
});
return res;
}
private Double getMax(List<Double> values) {
return values.stream().mapToDouble(v -> v).max().getAsDouble();
}
private Map.Entry<String, Double> getTSMax(TimeSeries ts) {
return Map.entry(ts.getTimeSeriesType().name(),
getMax(ts.getValues().stream().map(TimeSeriesValue::getValue).collect(Collectors.toList())));
}
private Map<String, Double> getTSsMax(List<TimeSeries> tss) {
Map<String, Double> res = new HashMap<>();
tss.forEach(ts -> {
Map.Entry<String, Double> entry = getTSMax(ts);
res.put(entry.getKey(), entry.getValue());
});
return res;
variableValues);
}
}

View File

@ -34,8 +34,7 @@ public class FuzzyInferenceService {
+ " is %s";
private final static String NO_RESULT = "Нет результата";
private List<String> getRulesFromDb(List<DbRule> dbRules, Map<String, Double> variableValues) {
validateVariables(variableValues, dbRules);
private List<String> mapRulesToString(List<DbRule> dbRules) {
return dbRules.stream().map(this::getFuzzyRule).collect(Collectors.toList());
}
@ -51,8 +50,6 @@ public class FuzzyInferenceService {
private RuleBlock getRuleBlock(Engine engine,
List<DbRule> dbRules,
Map<String, Double> variableValues,
Map<String, Double> min,
Map<String, Double> max,
List<AntecedentValue> antecedentValues,
List<Integer> consequentValues) {
variableValues.forEach((key, value) -> {
@ -61,16 +58,16 @@ public class FuzzyInferenceService {
input.setDescription("");
input.setEnabled(true);
double delta = antecedentValues.size() > 1
? (max.get(key) - min.get(key)) / (antecedentValues.size() - 1)
: (max.get(key) - min.get(key));
input.setRange(min.get(key), max.get(key));
? 2.0 / (antecedentValues.size() - 1)
: 2.0;
input.setRange(-1, 1);
input.setLockValueInRange(false);
for (int i = 0; i < antecedentValues.size(); i++) {
input.addTerm(
new Triangle(
antecedentValues.get(i).getAntecedentValue(),
min.get(key) + i * delta - 0.5 * delta,
min.get(key) + i * delta + delta + 0.5 * delta
-1 + i * delta - 0.5 * delta,
-1 + i * delta + delta + 0.5 * delta
)
);
}
@ -87,7 +84,7 @@ public class FuzzyInferenceService {
output.setDefaultValue(Double.NaN);
output.setLockValueInRange(false);
for (int i = 0; i < consequentValues.size(); i++) {
output.addTerm(new Triangle(consequentValues.get(i).toString(), i, i + 2.1));
output.addTerm(new Triangle(consequentValues.get(i).toString(), i, i + 1));
}
engine.addOutputVariable(output);
@ -99,7 +96,7 @@ public class FuzzyInferenceService {
//mamdani.setDisjunction(null);
mamdani.setImplication(new AlgebraicProduct());
mamdani.setActivation(new General());
getRulesFromDb(dbRules, variableValues).forEach(r -> {
mapRulesToString(dbRules).forEach(r -> {
LOG.info(r);
mamdani.addRule(Rule.parse(r, engine));
});
@ -115,12 +112,11 @@ public class FuzzyInferenceService {
public List<Assessment> getFuzzyInference(List<DbRule> dbRules,
List<AntecedentValue> antecedentValues,
Map<String, Double> variableValues,
Map<String, Double> min,
Map<String, Double> max) {
Map<String, Double> variableValues) {
validateVariables(variableValues, dbRules);
Engine engine = getFuzzyEngine();
List<Integer> consequentValues = dbRules.stream().map(DbRule::getId).collect(Collectors.toList());
engine.addRuleBlock(getRuleBlock(engine, dbRules, variableValues, min, max, antecedentValues, consequentValues));
engine.addRuleBlock(getRuleBlock(engine, dbRules, variableValues, antecedentValues, consequentValues));
Map.Entry<String, Double> consequent = getConsequent(engine, variableValues);
if (consequent.getKey().equals(NO_RESULT)) {
return new ArrayList<>();

View File

@ -127,17 +127,29 @@ public class TimeSeriesService {
public double getLastTimeSeriesTendency(TimeSeries ts) {
if (ts != null && ts.getValues().size() > MIN_TIME_SERIES_LENGTH) {
JSONObject response = httpService.post(TIME_SERIES_TENDENCY_URL, new JSONObject(new SmoothingTimeSeries(ts)));
JSONObject response = httpService.post(TIME_SERIES_TENDENCY_URL, new JSONObject(new SmoothingTimeSeries(normalizeTimeSeries(ts))));
LOG.debug("Успешно отправлен на сервис сглаживания");
if (response.has("response") && response.getString("response").equals("empty")) {
return DEFAULT_TIME_SERIES_TENDENCY;
}
JSONArray jsonArray = response.getJSONObject("timeSeries").getJSONArray("values");
return jsonArray.getJSONObject(jsonArray.length() - 1).getDouble("value");
return jsonArray.getJSONObject(jsonArray.length() - 1).getDouble("value") -
jsonArray.getJSONObject(jsonArray.length() - 2).getDouble("value");
}
return DEFAULT_TIME_SERIES_TENDENCY;
}
private TimeSeries normalizeTimeSeries(TimeSeries ts) {
double sum = ts.getValues().stream().mapToDouble(TimeSeriesValue::getValue).sum();
if (sum > 0.0d) {
for (int i = 0; i < ts.getValues().size(); i++) {
ts.getValues().get(i).setValue(ts.getValues().get(i).getValue() / sum);
}
}
return ts;
}
public boolean isBranchContainsAllTimeSeries(Branch b) {
List<TimeSeries> timeSeries = getByBranch(b.getId());
return Stream.of(TimeSeriesType.values()).allMatch(type -> timeSeries