diff --git a/src/main/java/ru/ulstu/controller/TimeSeriesController.java b/src/main/java/ru/ulstu/controller/TimeSeriesController.java index 915d920..2c5f0b1 100644 --- a/src/main/java/ru/ulstu/controller/TimeSeriesController.java +++ b/src/main/java/ru/ulstu/controller/TimeSeriesController.java @@ -20,6 +20,7 @@ import ru.ulstu.method.Method; import ru.ulstu.service.TimeSeriesService; import javax.servlet.http.HttpServletRequest; +import javax.validation.Valid; import java.lang.reflect.InvocationTargetException; import java.util.List; import java.util.concurrent.ExecutionException; @@ -37,7 +38,7 @@ public class TimeSeriesController { @PostMapping("getForecast") @Operation(description = "Получить прогноз временного ряда любым методом") - public ResponseEntity getForecastTimeSeries(@RequestBody ForecastParams forecastParams, HttpServletRequest request) throws ExecutionException, InterruptedException, InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException, ModelingException { + public ResponseEntity getForecastTimeSeries(@RequestBody @Valid ForecastParams forecastParams, HttpServletRequest request) throws ExecutionException, InterruptedException, InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException, ModelingException { LOGGER.info("User ip: " + HttpUtils.getUserIp(request)); LOGGER.info("Forecast: " + forecastParams); ResponseEntity result = new ResponseEntity<>(timeSeriesService.getForecast(forecastParams.getOriginalTimeSeries(), @@ -58,10 +59,11 @@ public class TimeSeriesController { @PostMapping("getSpecificMethodForecast") @Operation(description = "Получить прогноз временного ряда указанным методом") - public ResponseEntity getForecastTimeSeriesSpecificMethod(@RequestBody ForecastParams forecastParams, HttpServletRequest request) throws ExecutionException, InterruptedException, InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException, ModelingException { + public ResponseEntity getForecastTimeSeriesSpecificMethod(@RequestBody @Valid ForecastParams forecastParams, HttpServletRequest request) throws ExecutionException, InterruptedException, InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException, ModelingException { LOGGER.info("User ip: " + HttpUtils.getUserIp(request)); LOGGER.info("Forecast: " + forecastParams); ResponseEntity result = new ResponseEntity<>(timeSeriesService.getForecast(forecastParams.getOriginalTimeSeries(), + forecastParams.getMethodClassName(), forecastParams.getCountForecast()), HttpStatus.OK); LOGGER.info("Forecast result complete"); return result; diff --git a/src/main/java/ru/ulstu/datamodel/ForecastParams.java b/src/main/java/ru/ulstu/datamodel/ForecastParams.java index c56d497..265dfe6 100644 --- a/src/main/java/ru/ulstu/datamodel/ForecastParams.java +++ b/src/main/java/ru/ulstu/datamodel/ForecastParams.java @@ -2,9 +2,14 @@ package ru.ulstu.datamodel; import ru.ulstu.datamodel.ts.TimeSeries; +import javax.validation.constraints.NotNull; + public class ForecastParams { + @NotNull private TimeSeries originalTimeSeries; + @NotNull private int countForecast; + private String methodClassName; public TimeSeries getOriginalTimeSeries() { return originalTimeSeries; @@ -22,6 +27,14 @@ public class ForecastParams { this.countForecast = countForecast; } + public String getMethodClassName() { + return methodClassName; + } + + public void setMethodClassName(String methodClassName) { + this.methodClassName = methodClassName; + } + @Override public String toString() { return "ForecastParams{" + diff --git a/src/main/java/ru/ulstu/db/model/TimeSeriesSet.java b/src/main/java/ru/ulstu/db/model/TimeSeriesSet.java index b7cdb62..5dbbfbf 100644 --- a/src/main/java/ru/ulstu/db/model/TimeSeriesSet.java +++ b/src/main/java/ru/ulstu/db/model/TimeSeriesSet.java @@ -3,7 +3,7 @@ package ru.ulstu.db.model; import java.io.File; public class TimeSeriesSet { - private String key; + private final String key; public TimeSeriesSet(File dir) { this.key = dir.getName(); diff --git a/src/main/java/ru/ulstu/method/Method.java b/src/main/java/ru/ulstu/method/Method.java index 7f8a755..ca04a08 100644 --- a/src/main/java/ru/ulstu/method/Method.java +++ b/src/main/java/ru/ulstu/method/Method.java @@ -86,7 +86,9 @@ public abstract class Method { Model model = getModel(timeSeries, parameters); TimeSeries forecast = generateEmptyForecastPoints(model.getTimeSeriesModel(), countPoints); forecast.getFirstValue().setValue(model.getTimeSeriesModel().getLastValue().getValue()); - return getForecastWithValidParams(model, forecast); + forecast = getForecastWithValidParams(model, forecast); + forecast.getFirstValue().setValue(timeSeries.getLastValue().getValue()); + return forecast; } protected TimeSeries generateEmptyForecastPoints(TimeSeries model, int countPointForecast) { diff --git a/src/main/java/ru/ulstu/service/MethodParamBruteForce.java b/src/main/java/ru/ulstu/service/MethodParamBruteForce.java index eab23f8..ec4a5af 100644 --- a/src/main/java/ru/ulstu/service/MethodParamBruteForce.java +++ b/src/main/java/ru/ulstu/service/MethodParamBruteForce.java @@ -37,9 +37,9 @@ class MethodParamBruteForce { this.methods = methods; } - public ModelingResult getForecast(TimeSeries timeSeries, int countPointsForecast) throws ExecutionException, InterruptedException, NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException, ModelingException { - List> results = new ArrayList<>(); - List results2 = new CopyOnWriteArrayList<>(); + private ModelingResult getForecastByMethods(TimeSeries timeSeries, List methods, int countPointsForecast) throws ExecutionException, InterruptedException, NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException, ModelingException { + List> futureModelingResults = new ArrayList<>(); + List modelingResults = new CopyOnWriteArrayList<>(); final int countPoints = (countPointsForecast > timeSeries.getLength()) ? timeSeries.getLength() / 3 : countPointsForecast; TimeSeries reducedTimeSeries = new TimeSeries(timeSeries.getValues().stream().limit(timeSeries.getLength() - countPoints).collect(Collectors.toList()), "test part of " + timeSeries.getKey()); @@ -52,7 +52,7 @@ class MethodParamBruteForce { for (List parametersValues : availableParametersValues) { Method methodInstance = method.getClass().getDeclaredConstructor().newInstance(); if (methodInstance.canMakeForecast(reducedTimeSeries, parametersValues, countPoints)) { - results.add(executors.submit(() -> { + futureModelingResults.add(executors.submit(() -> { TimeSeries forecast = syncDates(methodInstance.getForecast(reducedTimeSeries, parametersValues, countPoints), timeSeries); return new ModelingResult(forecast, null, parametersValues, @@ -62,17 +62,39 @@ class MethodParamBruteForce { } } } - for (Future futureModelingResult : results) { - results2.add(futureModelingResult.get()); + for (Future futureModelingResult : futureModelingResults) { + modelingResults.add(futureModelingResult.get()); } - ModelingResult bestResult = results2.stream() + + return getBestResultForecast(modelingResults, timeSeries, countPoints); + } + + public ModelingResult getForecast(TimeSeries timeSeries, String methodClassName, int countPointsForecast) throws ExecutionException, InterruptedException, NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException, ModelingException { + Method method = methods.stream() + .filter(m -> m.getClass().getSimpleName().equals(methodClassName)) + .findAny() + .orElseThrow(() -> new ModelingException("Неизвестный метод прогнозирования")); + return getForecastByMethods(timeSeries, List.of(method), countPointsForecast); + } + + public ModelingResult getForecast(TimeSeries timeSeries, Method method, int countPointsForecast) throws ExecutionException, InterruptedException, NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException, ModelingException { + return getForecastByMethods(timeSeries, List.of(method), countPointsForecast); + } + + public ModelingResult getForecast(TimeSeries timeSeries, int countPointsForecast) throws ExecutionException, InterruptedException, NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException, ModelingException { + return getForecastByMethods(timeSeries, methods, countPointsForecast); + } + + private ModelingResult getBestResultForecast(List modelingResults, + TimeSeries timeSeries, + int countPoints) throws ModelingException { + ModelingResult bestResult = modelingResults.stream() .min(Comparator.comparing(modelingResult -> modelingResult.getScore().getDoubleValue())) .orElseThrow(() -> new ModelingException("Лучший метод не найден")); TimeSeries forecast = bestResult.getTimeSeriesMethod().getForecast(timeSeries, bestResult.getParamValues(), countPoints); - forecast.getValue(0).setValue(timeSeries.getNumericValue(timeSeries.getLength() - 1)); return new ModelingResult(forecast, bestResult.getTimeSeries(), @@ -90,9 +112,12 @@ class MethodParamBruteForce { return forecast; } + /* + TODO: public TimeSeries getForecastWithOptimalLength(TimeSeries timeSeries) { throw new RuntimeException("Not implemented"); } + */ public ModelingResult getSmoothedTimeSeries(TimeSeries timeSeries) throws ExecutionException, InterruptedException, NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException { List> results = new ArrayList<>(); @@ -133,7 +158,7 @@ class MethodParamBruteForce { parameterOffset.put(methodParameter, 0); parameterValues.put(methodParameter, methodParameter.getAvailableValues()); } - while (!isAllValuesUsed(parameterOffset, parameterValues)) { + while (isNotAllParameterValuesUsed(parameterOffset, parameterValues)) { List resultRow = new ArrayList<>(); for (MethodParameter methodParameter : parameterOffset.keySet()) { resultRow.add(new MethodParamValue(methodParameter, @@ -149,7 +174,7 @@ class MethodParamBruteForce { Map> parameterValues) { List parameters = new ArrayList<>(parameterOffset.keySet()); int i = 0; - while (i < parameters.size() && !isAllValuesUsed(parameterOffset, parameterValues)) { + while (i < parameters.size() && isNotAllParameterValuesUsed(parameterOffset, parameterValues)) { if (parameterOffset.get(parameters.get(i)) == parameterValues.get(parameters.get(i)).size() - 1) { parameterOffset.put(parameters.get(i), 0); i++; @@ -163,14 +188,14 @@ class MethodParamBruteForce { } } - private boolean isAllValuesUsed(Map parameterOffset, - Map> parameterValues) { + private boolean isNotAllParameterValuesUsed(Map parameterOffset, + Map> parameterValues) { for (MethodParameter methodParameter : parameterOffset.keySet()) { if (parameterOffset.get(methodParameter) != parameterValues.get(methodParameter).size() - 1) { - return false; + return true; } } - return true; + return false; } public List getAvailableMethods() { diff --git a/src/main/java/ru/ulstu/service/TimeSeriesService.java b/src/main/java/ru/ulstu/service/TimeSeriesService.java index 1b84572..4b80bc7 100644 --- a/src/main/java/ru/ulstu/service/TimeSeriesService.java +++ b/src/main/java/ru/ulstu/service/TimeSeriesService.java @@ -1,5 +1,6 @@ package ru.ulstu.service; +import org.springframework.context.ApplicationContext; import org.springframework.stereotype.Service; import ru.ulstu.datamodel.ModelingResult; import ru.ulstu.datamodel.exception.ModelingException; @@ -14,15 +15,22 @@ import java.util.concurrent.ExecutionException; @Service public class TimeSeriesService { private final MethodParamBruteForce methodParamBruteForce; + private final ApplicationContext applicationContext; - public TimeSeriesService(MethodParamBruteForce methodParamBruteForce) { + public TimeSeriesService(MethodParamBruteForce methodParamBruteForce, + ApplicationContext applicationContext) { this.methodParamBruteForce = methodParamBruteForce; + this.applicationContext = applicationContext; } public ModelingResult getForecast(TimeSeries timeSeries, int countPoints) throws ExecutionException, InterruptedException, InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException, ModelingException { return methodParamBruteForce.getForecast(timeSeries, countPoints); } + public ModelingResult getForecast(TimeSeries timeSeries, String methodClassName, int countPoints) throws ExecutionException, InterruptedException, InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException, ModelingException { + return methodParamBruteForce.getForecast(timeSeries, methodClassName, countPoints); + } + public ModelingResult smoothTimeSeries(TimeSeries timeSeries) throws ExecutionException, InterruptedException, InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException { return methodParamBruteForce.getSmoothedTimeSeries(timeSeries); }