#91 -- Fix fuzzy inference
This commit is contained in:
parent
5c6b388901
commit
4596a25a59
@ -8,7 +8,6 @@ import ru.ulstu.extractor.rule.service.AntecedentValueService;
|
||||
import ru.ulstu.extractor.rule.service.DbRuleService;
|
||||
import ru.ulstu.extractor.rule.service.FuzzyInferenceService;
|
||||
import ru.ulstu.extractor.ts.model.TimeSeries;
|
||||
import ru.ulstu.extractor.ts.model.TimeSeriesValue;
|
||||
import ru.ulstu.extractor.ts.service.TimeSeriesService;
|
||||
|
||||
import java.util.ArrayList;
|
||||
@ -57,9 +56,7 @@ public class AssessmentService {
|
||||
timeSeries.forEach(ts -> variableValues.put(ts.getTimeSeriesType().name(), timeSeriesService.getLastTimeSeriesTendency(ts)));
|
||||
return fuzzyInferenceService.getFuzzyInference(dbRules,
|
||||
antecedentValueService.getList(),
|
||||
variableValues,
|
||||
getTSsMin(timeSeries),
|
||||
getTSsMax(timeSeries));
|
||||
variableValues);
|
||||
}
|
||||
|
||||
private List<Assessment> getAssessmentsByTimeSeriesTendencies(List<DbRule> dbRules, List<TimeSeries> timeSeries) {
|
||||
@ -75,9 +72,7 @@ public class AssessmentService {
|
||||
.getLastTimeSeriesTendency(ts)));
|
||||
return fuzzyInferenceService.getFuzzyInference(List.of(dbRule),
|
||||
antecedentValueService.getList(),
|
||||
variableValues,
|
||||
getTSsMin(timeSeries),
|
||||
getTSsMax(timeSeries)).stream();
|
||||
variableValues).stream();
|
||||
})
|
||||
.sorted(Comparator.comparing(Assessment::getDegree))
|
||||
.collect(Collectors.toList());
|
||||
@ -88,44 +83,6 @@ public class AssessmentService {
|
||||
timeSeries.forEach(ts -> variableValues.put(ts.getTimeSeriesType().name(), ts.getValues().get(ts.getValues().size() - 1).getValue()));
|
||||
return fuzzyInferenceService.getFuzzyInference(dbRules,
|
||||
antecedentValueService.getList(),
|
||||
variableValues,
|
||||
getTSsMin(timeSeries),
|
||||
getTSsMax(timeSeries));
|
||||
}
|
||||
|
||||
private Double getMin(List<Double> values) {
|
||||
return values.stream().mapToDouble(v -> v).min().getAsDouble();
|
||||
}
|
||||
|
||||
private Map.Entry<String, Double> getTSMin(TimeSeries ts) {
|
||||
return Map.entry(ts.getTimeSeriesType().name(),
|
||||
getMin(ts.getValues().stream().map(TimeSeriesValue::getValue).collect(Collectors.toList())));
|
||||
}
|
||||
|
||||
private Map<String, Double> getTSsMin(List<TimeSeries> tss) {
|
||||
Map<String, Double> res = new HashMap<>();
|
||||
tss.forEach(ts -> {
|
||||
Map.Entry<String, Double> entry = getTSMin(ts);
|
||||
res.put(entry.getKey(), entry.getValue());
|
||||
});
|
||||
return res;
|
||||
}
|
||||
|
||||
private Double getMax(List<Double> values) {
|
||||
return values.stream().mapToDouble(v -> v).max().getAsDouble();
|
||||
}
|
||||
|
||||
private Map.Entry<String, Double> getTSMax(TimeSeries ts) {
|
||||
return Map.entry(ts.getTimeSeriesType().name(),
|
||||
getMax(ts.getValues().stream().map(TimeSeriesValue::getValue).collect(Collectors.toList())));
|
||||
}
|
||||
|
||||
private Map<String, Double> getTSsMax(List<TimeSeries> tss) {
|
||||
Map<String, Double> res = new HashMap<>();
|
||||
tss.forEach(ts -> {
|
||||
Map.Entry<String, Double> entry = getTSMax(ts);
|
||||
res.put(entry.getKey(), entry.getValue());
|
||||
});
|
||||
return res;
|
||||
variableValues);
|
||||
}
|
||||
}
|
||||
|
@ -34,8 +34,7 @@ public class FuzzyInferenceService {
|
||||
+ " is %s";
|
||||
private final static String NO_RESULT = "Нет результата";
|
||||
|
||||
private List<String> getRulesFromDb(List<DbRule> dbRules, Map<String, Double> variableValues) {
|
||||
validateVariables(variableValues, dbRules);
|
||||
private List<String> mapRulesToString(List<DbRule> dbRules) {
|
||||
return dbRules.stream().map(this::getFuzzyRule).collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@ -51,8 +50,6 @@ public class FuzzyInferenceService {
|
||||
private RuleBlock getRuleBlock(Engine engine,
|
||||
List<DbRule> dbRules,
|
||||
Map<String, Double> variableValues,
|
||||
Map<String, Double> min,
|
||||
Map<String, Double> max,
|
||||
List<AntecedentValue> antecedentValues,
|
||||
List<Integer> consequentValues) {
|
||||
variableValues.forEach((key, value) -> {
|
||||
@ -61,16 +58,16 @@ public class FuzzyInferenceService {
|
||||
input.setDescription("");
|
||||
input.setEnabled(true);
|
||||
double delta = antecedentValues.size() > 1
|
||||
? (max.get(key) - min.get(key)) / (antecedentValues.size() - 1)
|
||||
: (max.get(key) - min.get(key));
|
||||
input.setRange(min.get(key), max.get(key));
|
||||
? 2.0 / (antecedentValues.size() - 1)
|
||||
: 2.0;
|
||||
input.setRange(-1, 1);
|
||||
input.setLockValueInRange(false);
|
||||
for (int i = 0; i < antecedentValues.size(); i++) {
|
||||
input.addTerm(
|
||||
new Triangle(
|
||||
antecedentValues.get(i).getAntecedentValue(),
|
||||
min.get(key) + i * delta - 0.5 * delta,
|
||||
min.get(key) + i * delta + delta + 0.5 * delta
|
||||
-1 + i * delta - 0.5 * delta,
|
||||
-1 + i * delta + delta + 0.5 * delta
|
||||
)
|
||||
);
|
||||
}
|
||||
@ -87,7 +84,7 @@ public class FuzzyInferenceService {
|
||||
output.setDefaultValue(Double.NaN);
|
||||
output.setLockValueInRange(false);
|
||||
for (int i = 0; i < consequentValues.size(); i++) {
|
||||
output.addTerm(new Triangle(consequentValues.get(i).toString(), i, i + 2.1));
|
||||
output.addTerm(new Triangle(consequentValues.get(i).toString(), i, i + 1));
|
||||
}
|
||||
engine.addOutputVariable(output);
|
||||
|
||||
@ -99,7 +96,7 @@ public class FuzzyInferenceService {
|
||||
//mamdani.setDisjunction(null);
|
||||
mamdani.setImplication(new AlgebraicProduct());
|
||||
mamdani.setActivation(new General());
|
||||
getRulesFromDb(dbRules, variableValues).forEach(r -> {
|
||||
mapRulesToString(dbRules).forEach(r -> {
|
||||
LOG.info(r);
|
||||
mamdani.addRule(Rule.parse(r, engine));
|
||||
});
|
||||
@ -115,12 +112,11 @@ public class FuzzyInferenceService {
|
||||
|
||||
public List<Assessment> getFuzzyInference(List<DbRule> dbRules,
|
||||
List<AntecedentValue> antecedentValues,
|
||||
Map<String, Double> variableValues,
|
||||
Map<String, Double> min,
|
||||
Map<String, Double> max) {
|
||||
Map<String, Double> variableValues) {
|
||||
validateVariables(variableValues, dbRules);
|
||||
Engine engine = getFuzzyEngine();
|
||||
List<Integer> consequentValues = dbRules.stream().map(DbRule::getId).collect(Collectors.toList());
|
||||
engine.addRuleBlock(getRuleBlock(engine, dbRules, variableValues, min, max, antecedentValues, consequentValues));
|
||||
engine.addRuleBlock(getRuleBlock(engine, dbRules, variableValues, antecedentValues, consequentValues));
|
||||
Map.Entry<String, Double> consequent = getConsequent(engine, variableValues);
|
||||
if (consequent.getKey().equals(NO_RESULT)) {
|
||||
return new ArrayList<>();
|
||||
|
@ -127,17 +127,29 @@ public class TimeSeriesService {
|
||||
|
||||
public double getLastTimeSeriesTendency(TimeSeries ts) {
|
||||
if (ts != null && ts.getValues().size() > MIN_TIME_SERIES_LENGTH) {
|
||||
JSONObject response = httpService.post(TIME_SERIES_TENDENCY_URL, new JSONObject(new SmoothingTimeSeries(ts)));
|
||||
JSONObject response = httpService.post(TIME_SERIES_TENDENCY_URL, new JSONObject(new SmoothingTimeSeries(normalizeTimeSeries(ts))));
|
||||
LOG.debug("Успешно отправлен на сервис сглаживания");
|
||||
if (response.has("response") && response.getString("response").equals("empty")) {
|
||||
return DEFAULT_TIME_SERIES_TENDENCY;
|
||||
}
|
||||
JSONArray jsonArray = response.getJSONObject("timeSeries").getJSONArray("values");
|
||||
return jsonArray.getJSONObject(jsonArray.length() - 1).getDouble("value");
|
||||
return jsonArray.getJSONObject(jsonArray.length() - 1).getDouble("value") -
|
||||
jsonArray.getJSONObject(jsonArray.length() - 2).getDouble("value");
|
||||
}
|
||||
return DEFAULT_TIME_SERIES_TENDENCY;
|
||||
}
|
||||
|
||||
private TimeSeries normalizeTimeSeries(TimeSeries ts) {
|
||||
double sum = ts.getValues().stream().mapToDouble(TimeSeriesValue::getValue).sum();
|
||||
|
||||
if (sum > 0.0d) {
|
||||
for (int i = 0; i < ts.getValues().size(); i++) {
|
||||
ts.getValues().get(i).setValue(ts.getValues().get(i).getValue() / sum);
|
||||
}
|
||||
}
|
||||
return ts;
|
||||
}
|
||||
|
||||
public boolean isBranchContainsAllTimeSeries(Branch b) {
|
||||
List<TimeSeries> timeSeries = getByBranch(b.getId());
|
||||
return Stream.of(TimeSeriesType.values()).allMatch(type -> timeSeries
|
||||
|
Loading…
Reference in New Issue
Block a user