diff --git a/src/main/java/ru/ulstu/score/ScoreMethod.java b/src/main/java/ru/ulstu/score/ScoreMethod.java index 447d191..b85874f 100644 --- a/src/main/java/ru/ulstu/score/ScoreMethod.java +++ b/src/main/java/ru/ulstu/score/ScoreMethod.java @@ -11,6 +11,9 @@ import ru.ulstu.datamodel.exception.ModelingException; import ru.ulstu.datamodel.ts.TimeSeries; import ru.ulstu.datamodel.ts.TimeSeriesValue; +import java.time.LocalDateTime; +import java.util.Map; + public abstract class ScoreMethod { private final String name; @@ -18,11 +21,11 @@ public abstract class ScoreMethod { this.name = name; } - public Score getScore(TimeSeries original, TimeSeries model) throws ModelingException { - return new Score(this, evaluate(original, model)); + public Score getScore(Map tsValues, TimeSeries model) throws ModelingException { + return new Score(this, evaluate(tsValues, model)); } - public abstract Number evaluate(TimeSeries original, TimeSeries model) throws ModelingException; + public abstract Number evaluate(Map tsValues, TimeSeries model) throws ModelingException; public String getName() { return name; diff --git a/src/main/java/ru/ulstu/score/Smape.java b/src/main/java/ru/ulstu/score/Smape.java index 17a8b2c..1b1655d 100644 --- a/src/main/java/ru/ulstu/score/Smape.java +++ b/src/main/java/ru/ulstu/score/Smape.java @@ -10,6 +10,10 @@ import ru.ulstu.datamodel.exception.ModelingException; import ru.ulstu.datamodel.ts.TimeSeries; import ru.ulstu.datamodel.ts.TimeSeriesValue; +import java.time.LocalDateTime; +import java.util.Map; +import java.util.Optional; + import static java.lang.Math.abs; public class Smape extends ScoreMethod { @@ -18,10 +22,13 @@ public class Smape extends ScoreMethod { } @Override - public Number evaluate(TimeSeries original, TimeSeries model) throws ModelingException { + public Number evaluate(Map tsValues, TimeSeries model) throws ModelingException { double sum = 0; for (TimeSeriesValue modelValue : model.getValues()) { - double actualValue = getValueOnSameDate(modelValue, original).getValue(); + //double actualValue = getValueOnSameDate(modelValue, original).getValue(); + double actualValue = Optional.ofNullable(tsValues.get(modelValue.getDate())) + .orElseThrow(() -> new ModelingException("Значение модельного ряда не найдено в оригинальном ряде: " + + modelValue.getDate())); sum += abs(modelValue.getValue() - actualValue) / ((abs(actualValue) + abs(modelValue.getValue())) / 2); } diff --git a/src/main/java/ru/ulstu/service/MethodParamBruteForce.java b/src/main/java/ru/ulstu/service/MethodParamBruteForce.java index 1a4b8ec..dd7adff 100644 --- a/src/main/java/ru/ulstu/service/MethodParamBruteForce.java +++ b/src/main/java/ru/ulstu/service/MethodParamBruteForce.java @@ -11,6 +11,7 @@ import ru.ulstu.datamodel.Model; import ru.ulstu.datamodel.ModelingResult; import ru.ulstu.datamodel.exception.ModelingException; import ru.ulstu.datamodel.ts.TimeSeries; +import ru.ulstu.datamodel.ts.TimeSeriesValue; import ru.ulstu.method.Method; import ru.ulstu.method.MethodParamValue; import ru.ulstu.method.MethodParameter; @@ -18,6 +19,7 @@ import ru.ulstu.score.ScoreMethod; import ru.ulstu.score.Smape; import java.lang.reflect.InvocationTargetException; +import java.time.LocalDateTime; import java.util.ArrayList; import java.util.Comparator; import java.util.List; @@ -47,6 +49,10 @@ public class MethodParamBruteForce { final int countPoints = (countPointsForecast > timeSeries.getLength()) ? timeSeries.getLength() / 3 : countPointsForecast; TimeSeries reducedTimeSeries = new TimeSeries(timeSeries.getValues().stream().limit(timeSeries.getLength() - countPoints).collect(Collectors.toList()), "test part of " + timeSeries.getName()); + + Map tsValues = timeSeries.getValues().stream() + .collect(Collectors.toMap(TimeSeriesValue::getDate, TimeSeriesValue::getValue)); + for (Method method : methods) { List> availableParametersValues = getAvailableParametersValues(method.getAvailableParameters()); for (List parametersValues : availableParametersValues) { @@ -56,7 +62,7 @@ public class MethodParamBruteForce { TimeSeries forecast = methodInstance.getForecast(reducedTimeSeries, parametersValues, countPoints); return new ModelingResult(forecast, parametersValues, - scoreMethod.getScore(timeSeries, forecast), + scoreMethod.getScore(tsValues, forecast), methodInstance); })); } @@ -87,6 +93,10 @@ public class MethodParamBruteForce { public ModelingResult getSmoothedTimeSeries(TimeSeries timeSeries) throws ExecutionException, InterruptedException, NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException { List> results = new ArrayList<>(); List results2 = new CopyOnWriteArrayList<>(); + + Map tsValues = timeSeries.getValues().stream() + .collect(Collectors.toMap(TimeSeriesValue::getDate, TimeSeriesValue::getValue)); + for (Method method : methods) { List> availableParametersValues = getAvailableParametersValues(method.getAvailableParameters()); for (List parametersValues : availableParametersValues) { @@ -96,7 +106,7 @@ public class MethodParamBruteForce { Model model = methodInstance.getModel(timeSeries, parametersValues); return new ModelingResult(model.getTimeSeriesModel(), parametersValues, - scoreMethod.getScore(timeSeries, model.getTimeSeriesModel()), + scoreMethod.getScore(tsValues, model.getTimeSeriesModel()), methodInstance); })); }