From 90a107648d4d847b817cd97af329197024744da6 Mon Sep 17 00:00:00 2001 From: Anton Romanov Date: Wed, 23 Jun 2021 14:42:02 +0400 Subject: [PATCH] fix forecast by best model --- .../controller/TimeSeriesController.java | 7 +- .../ulstu/datamodel/ts/TimeSeriesValue.java | 6 +- src/main/java/ru/ulstu/method/Method.java | 20 ++++- .../addtrendaddseason/AddTrendAddSeason.java | 8 +- src/main/java/ru/ulstu/page/IndexView.java | 24 +++--- src/main/java/ru/ulstu/score/ScoreMethod.java | 4 +- .../ulstu/service/MethodParamBruteForce.java | 86 ++++++++++++------- .../ru/ulstu/service/TimeSeriesService.java | 14 ++- .../META-INF/resources/basicTemplate.xhtml | 22 ++++- 9 files changed, 122 insertions(+), 69 deletions(-) diff --git a/src/main/java/ru/ulstu/controller/TimeSeriesController.java b/src/main/java/ru/ulstu/controller/TimeSeriesController.java index 154e5d2..6667b64 100644 --- a/src/main/java/ru/ulstu/controller/TimeSeriesController.java +++ b/src/main/java/ru/ulstu/controller/TimeSeriesController.java @@ -21,6 +21,9 @@ import ru.ulstu.datamodel.ts.TimeSeries; import ru.ulstu.service.MethodParamBruteForce; import ru.ulstu.service.TimeSeriesService; +import java.lang.reflect.InvocationTargetException; +import java.util.concurrent.ExecutionException; + @RestController @RequestMapping(ApiConfiguration.API_1_0) public class TimeSeriesController { @@ -36,14 +39,14 @@ public class TimeSeriesController { @PostMapping("getForecast") @ApiOperation("Получить прогноз временного ряда") - public ResponseEntity getForecastTimeSeries(@RequestBody ForecastParams forecastParams) throws ModelingException { + public ResponseEntity getForecastTimeSeries(@RequestBody ForecastParams forecastParams) throws ExecutionException, InterruptedException, InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException, ModelingException { return new ResponseEntity<>(timeSeriesService.getForecast(forecastParams.getOriginalTimeSeries(), forecastParams.getCountForecast()), HttpStatus.OK); } @PostMapping("getSmoothed") @ApiOperation("Получить сглаженный временной ряд") - public ResponseEntity getSmoothedTimeSeries(@RequestBody TimeSeries timeSeries) throws ModelingException { + public ResponseEntity getSmoothedTimeSeries(@RequestBody TimeSeries timeSeries) throws ExecutionException, InterruptedException, InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException { return new ResponseEntity<>(methodParamBruteForce.getSmoothedTimeSeries(timeSeries), HttpStatus.OK); } } diff --git a/src/main/java/ru/ulstu/datamodel/ts/TimeSeriesValue.java b/src/main/java/ru/ulstu/datamodel/ts/TimeSeriesValue.java index fef77f9..f57b01e 100644 --- a/src/main/java/ru/ulstu/datamodel/ts/TimeSeriesValue.java +++ b/src/main/java/ru/ulstu/datamodel/ts/TimeSeriesValue.java @@ -48,7 +48,10 @@ public class TimeSeriesValue { @Override public String toString() { - return value.toString(); + return "TimeSeriesValue{" + + "date=" + date + + ", value=" + value + + '}'; } @Override @@ -64,4 +67,5 @@ public class TimeSeriesValue { public int hashCode() { return Objects.hash(date, value); } + } diff --git a/src/main/java/ru/ulstu/method/Method.java b/src/main/java/ru/ulstu/method/Method.java index 7422c9e..54d9281 100644 --- a/src/main/java/ru/ulstu/method/Method.java +++ b/src/main/java/ru/ulstu/method/Method.java @@ -7,6 +7,8 @@ package ru.ulstu.method; import com.fasterxml.jackson.annotation.JsonIgnore; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import ru.ulstu.TimeSeriesUtils; import ru.ulstu.datamodel.Model; import ru.ulstu.datamodel.exception.ForecastValidateException; @@ -22,6 +24,8 @@ import java.util.List; * Наиболее общая логика моделировании и прогнозирования временных рядов */ public abstract class Method { + private static final Logger LOGGER = LoggerFactory.getLogger(Method.class); + @JsonIgnore public abstract List getAvailableParameters(); @@ -31,8 +35,7 @@ public abstract class Method { * * @return модель временного ряда */ - protected abstract Model getModelOfValidTimeSeries(TimeSeries timeSeries, - List parameters) throws ModelingException; + protected abstract Model getModelOfValidTimeSeries(TimeSeries timeSeries, List parameters); /** * Возвращает модельное представление временного ряда: для тех же точек времени что и в параметре timeSeries @@ -58,9 +61,10 @@ public abstract class Method { */ protected abstract TimeSeries getForecastWithValidParams(Model model, TimeSeries forecast) throws ModelingException; - public boolean canMakeForecast(TimeSeries timeSeries, int countPoints) { + public boolean canMakeForecast(TimeSeries timeSeries, List parameters, int countPoints) { try { validateTimeSeries(timeSeries); + validateAdditionalParams(timeSeries, parameters); validateForecastParams(countPoints); } catch (ModelingException ex) { return false; @@ -68,6 +72,16 @@ public abstract class Method { return true; } + public boolean canMakeModel(TimeSeries timeSeries, List parameters) { + try { + validateTimeSeries(timeSeries); + validateAdditionalParams(timeSeries, parameters); + } catch (ModelingException ex) { + return false; + } + return true; + } + /** * Выполняет построение модели и прогноза временного ряда. Даты спрогнозированных точек будут сгенерированы * по модельным точкам. diff --git a/src/main/java/ru/ulstu/method/exponential/addtrendaddseason/AddTrendAddSeason.java b/src/main/java/ru/ulstu/method/exponential/addtrendaddseason/AddTrendAddSeason.java index 1cbcdb5..a14df9f 100644 --- a/src/main/java/ru/ulstu/method/exponential/addtrendaddseason/AddTrendAddSeason.java +++ b/src/main/java/ru/ulstu/method/exponential/addtrendaddseason/AddTrendAddSeason.java @@ -9,7 +9,6 @@ package ru.ulstu.method.exponential.addtrendaddseason; import org.springframework.stereotype.Component; import ru.ulstu.datamodel.Model; import ru.ulstu.datamodel.exception.ModelingException; -import ru.ulstu.datamodel.exception.TimeSeriesValidateException; import ru.ulstu.datamodel.ts.TimeSeries; import ru.ulstu.method.Method; import ru.ulstu.method.MethodParamValue; @@ -67,11 +66,10 @@ public class AddTrendAddSeason extends Method { for (MethodParamValue parameter : parameters) { if (parameter.getParameter() instanceof Season) { if (ts.getLength() < parameter.getValue().intValue()) { - throw new TimeSeriesValidateException("Период больше чем длина ряда"); + throw new ModelingException("Период больше чем длина ряда"); } } } - } @Override @@ -82,10 +80,10 @@ public class AddTrendAddSeason extends Method { List iComponent = currentModel.getSeasonComponent(); for (int t = 1; t < forecast.getLength(); t++) { iComponent.add(currentModel.getGamma().getDoubleValue() * forecast.getNumericValue(t - 1) / sComponent.get(sComponent.size() - 1) - + (1 - currentModel.getGamma().getDoubleValue()) * iComponent.get(t + model.getTimeSeriesModel().getLength() - currentModel.getSeason().getIntValue())); + + (1 - currentModel.getGamma().getDoubleValue()) * iComponent.get(t + model.getTimeSeriesModel().getLength() - currentModel.getSeason().getIntValue() - 1)); forecast.getValues().get(t).setValue((sComponent.get(sComponent.size() - 1) + tComponent.get(tComponent.size() - 1) * t) - * iComponent.get(t + model.getTimeSeriesModel().getLength() - currentModel.getSeason().getIntValue())); + * iComponent.get(t + model.getTimeSeriesModel().getLength() - currentModel.getSeason().getIntValue() - 1)); } return forecast; } diff --git a/src/main/java/ru/ulstu/page/IndexView.java b/src/main/java/ru/ulstu/page/IndexView.java index 539c06c..6f4f500 100644 --- a/src/main/java/ru/ulstu/page/IndexView.java +++ b/src/main/java/ru/ulstu/page/IndexView.java @@ -24,7 +24,9 @@ import javax.annotation.PostConstruct; import javax.faces.view.ViewScoped; import javax.inject.Named; import java.io.Serializable; +import java.lang.reflect.InvocationTargetException; import java.time.format.DateTimeFormatter; +import java.util.concurrent.ExecutionException; @Named @ViewScoped @@ -41,12 +43,12 @@ public class IndexView implements Serializable { private String timeSeriesString; @PostConstruct - public void init() { + public void init() throws ExecutionException, InterruptedException, InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException, ModelingException { timeSeriesString = utilService.getTimeSeriesToString(utilService.getRandomTimeSeries(50)); createChart(); } - private LineChartModel initLinearModel() { + private LineChartModel initLinearModel() throws ExecutionException, InterruptedException, InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException, ModelingException { LineChartModel model = new LineChartModel(); LineChartSeries series1 = new LineChartSeries(); @@ -60,20 +62,18 @@ public class IndexView implements Serializable { LineChartSeries series2 = new LineChartSeries(); series2.setLabel("Сглаженный ряд"); try { - for (TimeSeriesValue value : timeSeriesService.smoothTimeSeries(timeSeries).getValues()) { + TimeSeries smoothedTimeSeries = timeSeriesService.smoothTimeSeries(timeSeries); + for (TimeSeriesValue value : smoothedTimeSeries.getValues()) { series2.set(DateTimeFormatter.ISO_LOCAL_DATE.format(value.getDate()), value.getValue()); } - } catch (ModelingException ex) { + } catch (Exception ex) { LOG.warn(ex.getMessage()); } LineChartSeries series3 = new LineChartSeries(); series3.setLabel("Прогноз"); - try { - for (TimeSeriesValue value : timeSeriesService.getForecast(timeSeries, 20).getValues()) { - series3.set(DateTimeFormatter.ISO_LOCAL_DATE.format(value.getDate()), value.getValue()); - } - } catch (ModelingException ex) { - LOG.warn(ex.getMessage()); + TimeSeries forecast = timeSeriesService.getForecast(timeSeries, 20).getTimeSeries(); + for (TimeSeriesValue value : forecast.getValues()) { + series3.set(DateTimeFormatter.ISO_LOCAL_DATE.format(value.getDate()), value.getValue()); } model.addSeries(series1); model.addSeries(series2); @@ -81,7 +81,7 @@ public class IndexView implements Serializable { return model; } - public void createChart() { + public void createChart() throws ExecutionException, InterruptedException, InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException, ModelingException { model = initLinearModel(); model.setTitle("Сглаживание временного ряда"); model.setLegendPosition("d"); @@ -95,7 +95,7 @@ public class IndexView implements Serializable { return model; } - public String getTimeSeriesString() { + public String getTimeSeriesString() throws ExecutionException, InterruptedException, InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException, ModelingException { createChart(); return timeSeriesString; } diff --git a/src/main/java/ru/ulstu/score/ScoreMethod.java b/src/main/java/ru/ulstu/score/ScoreMethod.java index 5c227f2..447d191 100644 --- a/src/main/java/ru/ulstu/score/ScoreMethod.java +++ b/src/main/java/ru/ulstu/score/ScoreMethod.java @@ -33,6 +33,8 @@ public abstract class ScoreMethod { .stream() .filter(v -> v.getDate().equals(timeSeriesValueToFind.getDate())) .findAny() - .orElseThrow(() -> new ModelingException("Значение модельного ряда не найдено в оригинальном ряде: " + timeSeriesValueToFind.getDate())); + .orElseThrow(() -> new ModelingException("Значение модельного ряда не найдено в оригинальном ряде: " + + timeSeriesValueToFind.getDate() + + " " + timeSeries)); } } diff --git a/src/main/java/ru/ulstu/service/MethodParamBruteForce.java b/src/main/java/ru/ulstu/service/MethodParamBruteForce.java index 0f1e6ad..753a7cc 100644 --- a/src/main/java/ru/ulstu/service/MethodParamBruteForce.java +++ b/src/main/java/ru/ulstu/service/MethodParamBruteForce.java @@ -17,15 +17,18 @@ import ru.ulstu.method.MethodParameter; import ru.ulstu.score.ScoreMethod; import ru.ulstu.score.Smape; +import java.lang.reflect.InvocationTargetException; import java.util.ArrayList; import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.TreeMap; import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; +import java.util.stream.Collectors; @Service public class MethodParamBruteForce { @@ -38,52 +41,69 @@ public class MethodParamBruteForce { this.methods = methods; } - public TimeSeries getForecast(TimeSeries timeSeries) { - throw new RuntimeException("Not implemented"); + public ModelingResult getForecast(TimeSeries timeSeries, int countPointsForecast) throws ExecutionException, InterruptedException, NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException, ModelingException { + List> results = new ArrayList<>(); + List results2 = new CopyOnWriteArrayList<>(); + final int countPoints = (countPointsForecast > timeSeries.getLength()) ? timeSeries.getLength() / 3 : countPointsForecast; + TimeSeries reducedTimeSeries = new TimeSeries(timeSeries.getValues().stream().limit(timeSeries.getLength() - countPoints).collect(Collectors.toList()), + "test part of " + timeSeries.getName()); + for (Method method : methods) { + List> availableParametersValues = getAvailableParametersValues(method.getAvailableParameters()); + for (List parametersValues : availableParametersValues) { + Method methodInstance = method.getClass().getDeclaredConstructor().newInstance(); + if (methodInstance.canMakeForecast(reducedTimeSeries, parametersValues, countPoints)) { + results.add(executors.submit(() -> { + TimeSeries forecast = methodInstance.getForecast(reducedTimeSeries, parametersValues, countPoints); + return new ModelingResult(forecast, + parametersValues, + scoreMethod.getScore(timeSeries, forecast), + methodInstance); + })); + } + } + } + for (Future futureModelingResult : results) { + results2.add(futureModelingResult.get()); + } + ModelingResult bestResult = results2.stream() + .min(Comparator.comparing(modelingResult -> modelingResult.getScore().getDoubleValue())) + .orElse(null); + return new ModelingResult(bestResult.getTimeSeriesMethod().getForecast(timeSeries, + bestResult.getParamValues(), + countPoints), + bestResult.getParamValues(), + bestResult.getScore(), + bestResult.getTimeSeriesMethod()); } public TimeSeries getForecastWithOptimalLength(TimeSeries timeSeries) { throw new RuntimeException("Not implemented"); } - public ModelingResult getSmoothedTimeSeries(TimeSeries timeSeries) throws ModelingException { + public ModelingResult getSmoothedTimeSeries(TimeSeries timeSeries) throws ExecutionException, InterruptedException, NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException { List> results = new ArrayList<>(); List results2 = new CopyOnWriteArrayList<>(); - try { - for (Method method : methods) { - List> availableParametersValues = getAvailableParametersValues(method.getAvailableParameters()); - for (List parametersValues : availableParametersValues) { + for (Method method : methods) { + List> availableParametersValues = getAvailableParametersValues(method.getAvailableParameters()); + for (List parametersValues : availableParametersValues) { + Method methodInstance = method.getClass().getDeclaredConstructor().newInstance(); + if (methodInstance.canMakeModel(timeSeries, parametersValues)) { results.add(executors.submit(() -> { - Method methodInstance = method.getClass().getDeclaredConstructor().newInstance(); - try { - Model model = methodInstance.getModel(timeSeries, parametersValues); - return new ModelingResult(model.getTimeSeriesModel(), - parametersValues, - scoreMethod.getScore(timeSeries, model.getTimeSeriesModel()), - methodInstance); - } catch (ModelingException ex) { - ex.printStackTrace(); - return null; - } + Model model = methodInstance.getModel(timeSeries, parametersValues); + return new ModelingResult(model.getTimeSeriesModel(), + parametersValues, + scoreMethod.getScore(timeSeries, model.getTimeSeriesModel()), + methodInstance); })); } } - results.forEach(modelingResultFuture -> { - try { - modelingResultFuture.get(); - } catch (Exception e) { - e.printStackTrace(); - } - }); - for (Future futureModelingResult : results) { - results2.add(futureModelingResult.get()); - } - return results2.stream() - .min(Comparator.comparing(modelingResult -> modelingResult.getScore().getDoubleValue())) - .orElse(null); - } catch (Exception e) { - throw new ModelingException(e.getMessage()); } + for (Future futureModelingResult : results) { + results2.add(futureModelingResult.get()); + } + return results2.stream() + .min(Comparator.comparing(modelingResult -> modelingResult.getScore().getDoubleValue())) + .orElse(null); } private List> getAvailableParametersValues(List availableParameters) { diff --git a/src/main/java/ru/ulstu/service/TimeSeriesService.java b/src/main/java/ru/ulstu/service/TimeSeriesService.java index 1758f21..e1fff5d 100644 --- a/src/main/java/ru/ulstu/service/TimeSeriesService.java +++ b/src/main/java/ru/ulstu/service/TimeSeriesService.java @@ -9,12 +9,12 @@ package ru.ulstu.service; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.stereotype.Service; +import ru.ulstu.datamodel.ModelingResult; import ru.ulstu.datamodel.exception.ModelingException; import ru.ulstu.datamodel.ts.TimeSeries; -import ru.ulstu.method.Method; -import ru.ulstu.method.exponential.addtrendaddseason.AddTrendAddSeason; -import java.util.Collections; +import java.lang.reflect.InvocationTargetException; +import java.util.concurrent.ExecutionException; @Service @@ -26,13 +26,11 @@ public class TimeSeriesService { this.methodParamBruteForce = methodParamBruteForce; } - public TimeSeries getForecast(TimeSeries timeSeries, int countPoints) throws ModelingException { - Method method; - method = new AddTrendAddSeason(); - return method.getForecast(timeSeries, Collections.emptyList(), countPoints); + public ModelingResult getForecast(TimeSeries timeSeries, int countPoints) throws ExecutionException, InterruptedException, InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException, ModelingException { + return methodParamBruteForce.getForecast(timeSeries, countPoints); } - public TimeSeries smoothTimeSeries(TimeSeries timeSeries) throws ModelingException { + public TimeSeries smoothTimeSeries(TimeSeries timeSeries) throws ExecutionException, InterruptedException, InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException { return methodParamBruteForce.getSmoothedTimeSeries(timeSeries).getTimeSeries(); } } diff --git a/src/main/resources/META-INF/resources/basicTemplate.xhtml b/src/main/resources/META-INF/resources/basicTemplate.xhtml index 25646b8..2d8d6d3 100644 --- a/src/main/resources/META-INF/resources/basicTemplate.xhtml +++ b/src/main/resources/META-INF/resources/basicTemplate.xhtml @@ -1,4 +1,10 @@ + +
-
Ulyanovsk State Technical University
-
ulstu.ru
-
2020
+
Ulyanovsk State Technical University
+
+ ulstu.ru + +
+
+ api for developers + +
+
2020