fix forecast by best model

This commit is contained in:
Anton Romanov 2021-06-23 14:42:02 +04:00
parent 518e01ca93
commit 90a107648d
9 changed files with 122 additions and 69 deletions

View File

@ -21,6 +21,9 @@ import ru.ulstu.datamodel.ts.TimeSeries;
import ru.ulstu.service.MethodParamBruteForce;
import ru.ulstu.service.TimeSeriesService;
import java.lang.reflect.InvocationTargetException;
import java.util.concurrent.ExecutionException;
@RestController
@RequestMapping(ApiConfiguration.API_1_0)
public class TimeSeriesController {
@ -36,14 +39,14 @@ public class TimeSeriesController {
@PostMapping("getForecast")
@ApiOperation("Получить прогноз временного ряда")
public ResponseEntity<TimeSeries> getForecastTimeSeries(@RequestBody ForecastParams forecastParams) throws ModelingException {
public ResponseEntity<ModelingResult> getForecastTimeSeries(@RequestBody ForecastParams forecastParams) throws ExecutionException, InterruptedException, InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException, ModelingException {
return new ResponseEntity<>(timeSeriesService.getForecast(forecastParams.getOriginalTimeSeries(),
forecastParams.getCountForecast()), HttpStatus.OK);
}
@PostMapping("getSmoothed")
@ApiOperation("Получить сглаженный временной ряд")
public ResponseEntity<ModelingResult> getSmoothedTimeSeries(@RequestBody TimeSeries timeSeries) throws ModelingException {
public ResponseEntity<ModelingResult> getSmoothedTimeSeries(@RequestBody TimeSeries timeSeries) throws ExecutionException, InterruptedException, InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException {
return new ResponseEntity<>(methodParamBruteForce.getSmoothedTimeSeries(timeSeries), HttpStatus.OK);
}
}

View File

@ -48,7 +48,10 @@ public class TimeSeriesValue {
@Override
public String toString() {
return value.toString();
return "TimeSeriesValue{" +
"date=" + date +
", value=" + value +
'}';
}
@Override
@ -64,4 +67,5 @@ public class TimeSeriesValue {
public int hashCode() {
return Objects.hash(date, value);
}
}

View File

@ -7,6 +7,8 @@
package ru.ulstu.method;
import com.fasterxml.jackson.annotation.JsonIgnore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import ru.ulstu.TimeSeriesUtils;
import ru.ulstu.datamodel.Model;
import ru.ulstu.datamodel.exception.ForecastValidateException;
@ -22,6 +24,8 @@ import java.util.List;
* Наиболее общая логика моделировании и прогнозирования временных рядов
*/
public abstract class Method {
private static final Logger LOGGER = LoggerFactory.getLogger(Method.class);
@JsonIgnore
public abstract List<MethodParameter> getAvailableParameters();
@ -31,8 +35,7 @@ public abstract class Method {
*
* @return модель временного ряда
*/
protected abstract Model getModelOfValidTimeSeries(TimeSeries timeSeries,
List<MethodParamValue> parameters) throws ModelingException;
protected abstract Model getModelOfValidTimeSeries(TimeSeries timeSeries, List<MethodParamValue> parameters);
/**
* Возвращает модельное представление временного ряда: для тех же точек времени что и в параметре timeSeries
@ -58,9 +61,10 @@ public abstract class Method {
*/
protected abstract TimeSeries getForecastWithValidParams(Model model, TimeSeries forecast) throws ModelingException;
public boolean canMakeForecast(TimeSeries timeSeries, int countPoints) {
public boolean canMakeForecast(TimeSeries timeSeries, List<MethodParamValue> parameters, int countPoints) {
try {
validateTimeSeries(timeSeries);
validateAdditionalParams(timeSeries, parameters);
validateForecastParams(countPoints);
} catch (ModelingException ex) {
return false;
@ -68,6 +72,16 @@ public abstract class Method {
return true;
}
public boolean canMakeModel(TimeSeries timeSeries, List<MethodParamValue> parameters) {
try {
validateTimeSeries(timeSeries);
validateAdditionalParams(timeSeries, parameters);
} catch (ModelingException ex) {
return false;
}
return true;
}
/**
* Выполняет построение модели и прогноза временного ряда. Даты спрогнозированных точек будут сгенерированы
* по модельным точкам.

View File

@ -9,7 +9,6 @@ package ru.ulstu.method.exponential.addtrendaddseason;
import org.springframework.stereotype.Component;
import ru.ulstu.datamodel.Model;
import ru.ulstu.datamodel.exception.ModelingException;
import ru.ulstu.datamodel.exception.TimeSeriesValidateException;
import ru.ulstu.datamodel.ts.TimeSeries;
import ru.ulstu.method.Method;
import ru.ulstu.method.MethodParamValue;
@ -67,11 +66,10 @@ public class AddTrendAddSeason extends Method {
for (MethodParamValue parameter : parameters) {
if (parameter.getParameter() instanceof Season) {
if (ts.getLength() < parameter.getValue().intValue()) {
throw new TimeSeriesValidateException("Период больше чем длина ряда");
throw new ModelingException("Период больше чем длина ряда");
}
}
}
}
@Override
@ -82,10 +80,10 @@ public class AddTrendAddSeason extends Method {
List<Double> iComponent = currentModel.getSeasonComponent();
for (int t = 1; t < forecast.getLength(); t++) {
iComponent.add(currentModel.getGamma().getDoubleValue() * forecast.getNumericValue(t - 1) / sComponent.get(sComponent.size() - 1)
+ (1 - currentModel.getGamma().getDoubleValue()) * iComponent.get(t + model.getTimeSeriesModel().getLength() - currentModel.getSeason().getIntValue()));
+ (1 - currentModel.getGamma().getDoubleValue()) * iComponent.get(t + model.getTimeSeriesModel().getLength() - currentModel.getSeason().getIntValue() - 1));
forecast.getValues().get(t).setValue((sComponent.get(sComponent.size() - 1) + tComponent.get(tComponent.size() - 1) * t)
* iComponent.get(t + model.getTimeSeriesModel().getLength() - currentModel.getSeason().getIntValue()));
* iComponent.get(t + model.getTimeSeriesModel().getLength() - currentModel.getSeason().getIntValue() - 1));
}
return forecast;
}

View File

@ -24,7 +24,9 @@ import javax.annotation.PostConstruct;
import javax.faces.view.ViewScoped;
import javax.inject.Named;
import java.io.Serializable;
import java.lang.reflect.InvocationTargetException;
import java.time.format.DateTimeFormatter;
import java.util.concurrent.ExecutionException;
@Named
@ViewScoped
@ -41,12 +43,12 @@ public class IndexView implements Serializable {
private String timeSeriesString;
@PostConstruct
public void init() {
public void init() throws ExecutionException, InterruptedException, InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException, ModelingException {
timeSeriesString = utilService.getTimeSeriesToString(utilService.getRandomTimeSeries(50));
createChart();
}
private LineChartModel initLinearModel() {
private LineChartModel initLinearModel() throws ExecutionException, InterruptedException, InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException, ModelingException {
LineChartModel model = new LineChartModel();
LineChartSeries series1 = new LineChartSeries();
@ -60,20 +62,18 @@ public class IndexView implements Serializable {
LineChartSeries series2 = new LineChartSeries();
series2.setLabel("Сглаженный ряд");
try {
for (TimeSeriesValue value : timeSeriesService.smoothTimeSeries(timeSeries).getValues()) {
TimeSeries smoothedTimeSeries = timeSeriesService.smoothTimeSeries(timeSeries);
for (TimeSeriesValue value : smoothedTimeSeries.getValues()) {
series2.set(DateTimeFormatter.ISO_LOCAL_DATE.format(value.getDate()), value.getValue());
}
} catch (ModelingException ex) {
} catch (Exception ex) {
LOG.warn(ex.getMessage());
}
LineChartSeries series3 = new LineChartSeries();
series3.setLabel("Прогноз");
try {
for (TimeSeriesValue value : timeSeriesService.getForecast(timeSeries, 20).getValues()) {
series3.set(DateTimeFormatter.ISO_LOCAL_DATE.format(value.getDate()), value.getValue());
}
} catch (ModelingException ex) {
LOG.warn(ex.getMessage());
TimeSeries forecast = timeSeriesService.getForecast(timeSeries, 20).getTimeSeries();
for (TimeSeriesValue value : forecast.getValues()) {
series3.set(DateTimeFormatter.ISO_LOCAL_DATE.format(value.getDate()), value.getValue());
}
model.addSeries(series1);
model.addSeries(series2);
@ -81,7 +81,7 @@ public class IndexView implements Serializable {
return model;
}
public void createChart() {
public void createChart() throws ExecutionException, InterruptedException, InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException, ModelingException {
model = initLinearModel();
model.setTitle("Сглаживание временного ряда");
model.setLegendPosition("d");
@ -95,7 +95,7 @@ public class IndexView implements Serializable {
return model;
}
public String getTimeSeriesString() {
public String getTimeSeriesString() throws ExecutionException, InterruptedException, InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException, ModelingException {
createChart();
return timeSeriesString;
}

View File

@ -33,6 +33,8 @@ public abstract class ScoreMethod {
.stream()
.filter(v -> v.getDate().equals(timeSeriesValueToFind.getDate()))
.findAny()
.orElseThrow(() -> new ModelingException("Значение модельного ряда не найдено в оригинальном ряде: " + timeSeriesValueToFind.getDate()));
.orElseThrow(() -> new ModelingException("Значение модельного ряда не найдено в оригинальном ряде: "
+ timeSeriesValueToFind.getDate()
+ " " + timeSeries));
}
}

View File

@ -17,15 +17,18 @@ import ru.ulstu.method.MethodParameter;
import ru.ulstu.score.ScoreMethod;
import ru.ulstu.score.Smape;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
@Service
public class MethodParamBruteForce {
@ -38,52 +41,69 @@ public class MethodParamBruteForce {
this.methods = methods;
}
public TimeSeries getForecast(TimeSeries timeSeries) {
throw new RuntimeException("Not implemented");
public ModelingResult getForecast(TimeSeries timeSeries, int countPointsForecast) throws ExecutionException, InterruptedException, NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException, ModelingException {
List<Future<ModelingResult>> results = new ArrayList<>();
List<ModelingResult> results2 = new CopyOnWriteArrayList<>();
final int countPoints = (countPointsForecast > timeSeries.getLength()) ? timeSeries.getLength() / 3 : countPointsForecast;
TimeSeries reducedTimeSeries = new TimeSeries(timeSeries.getValues().stream().limit(timeSeries.getLength() - countPoints).collect(Collectors.toList()),
"test part of " + timeSeries.getName());
for (Method method : methods) {
List<List<MethodParamValue>> availableParametersValues = getAvailableParametersValues(method.getAvailableParameters());
for (List<MethodParamValue> parametersValues : availableParametersValues) {
Method methodInstance = method.getClass().getDeclaredConstructor().newInstance();
if (methodInstance.canMakeForecast(reducedTimeSeries, parametersValues, countPoints)) {
results.add(executors.submit(() -> {
TimeSeries forecast = methodInstance.getForecast(reducedTimeSeries, parametersValues, countPoints);
return new ModelingResult(forecast,
parametersValues,
scoreMethod.getScore(timeSeries, forecast),
methodInstance);
}));
}
}
}
for (Future<ModelingResult> futureModelingResult : results) {
results2.add(futureModelingResult.get());
}
ModelingResult bestResult = results2.stream()
.min(Comparator.comparing(modelingResult -> modelingResult.getScore().getDoubleValue()))
.orElse(null);
return new ModelingResult(bestResult.getTimeSeriesMethod().getForecast(timeSeries,
bestResult.getParamValues(),
countPoints),
bestResult.getParamValues(),
bestResult.getScore(),
bestResult.getTimeSeriesMethod());
}
public TimeSeries getForecastWithOptimalLength(TimeSeries timeSeries) {
throw new RuntimeException("Not implemented");
}
public ModelingResult getSmoothedTimeSeries(TimeSeries timeSeries) throws ModelingException {
public ModelingResult getSmoothedTimeSeries(TimeSeries timeSeries) throws ExecutionException, InterruptedException, NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException {
List<Future<ModelingResult>> results = new ArrayList<>();
List<ModelingResult> results2 = new CopyOnWriteArrayList<>();
try {
for (Method method : methods) {
List<List<MethodParamValue>> availableParametersValues = getAvailableParametersValues(method.getAvailableParameters());
for (List<MethodParamValue> parametersValues : availableParametersValues) {
for (Method method : methods) {
List<List<MethodParamValue>> availableParametersValues = getAvailableParametersValues(method.getAvailableParameters());
for (List<MethodParamValue> parametersValues : availableParametersValues) {
Method methodInstance = method.getClass().getDeclaredConstructor().newInstance();
if (methodInstance.canMakeModel(timeSeries, parametersValues)) {
results.add(executors.submit(() -> {
Method methodInstance = method.getClass().getDeclaredConstructor().newInstance();
try {
Model model = methodInstance.getModel(timeSeries, parametersValues);
return new ModelingResult(model.getTimeSeriesModel(),
parametersValues,
scoreMethod.getScore(timeSeries, model.getTimeSeriesModel()),
methodInstance);
} catch (ModelingException ex) {
ex.printStackTrace();
return null;
}
Model model = methodInstance.getModel(timeSeries, parametersValues);
return new ModelingResult(model.getTimeSeriesModel(),
parametersValues,
scoreMethod.getScore(timeSeries, model.getTimeSeriesModel()),
methodInstance);
}));
}
}
results.forEach(modelingResultFuture -> {
try {
modelingResultFuture.get();
} catch (Exception e) {
e.printStackTrace();
}
});
for (Future<ModelingResult> futureModelingResult : results) {
results2.add(futureModelingResult.get());
}
return results2.stream()
.min(Comparator.comparing(modelingResult -> modelingResult.getScore().getDoubleValue()))
.orElse(null);
} catch (Exception e) {
throw new ModelingException(e.getMessage());
}
for (Future<ModelingResult> futureModelingResult : results) {
results2.add(futureModelingResult.get());
}
return results2.stream()
.min(Comparator.comparing(modelingResult -> modelingResult.getScore().getDoubleValue()))
.orElse(null);
}
private List<List<MethodParamValue>> getAvailableParametersValues(List<MethodParameter> availableParameters) {

View File

@ -9,12 +9,12 @@ package ru.ulstu.service;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import ru.ulstu.datamodel.ModelingResult;
import ru.ulstu.datamodel.exception.ModelingException;
import ru.ulstu.datamodel.ts.TimeSeries;
import ru.ulstu.method.Method;
import ru.ulstu.method.exponential.addtrendaddseason.AddTrendAddSeason;
import java.util.Collections;
import java.lang.reflect.InvocationTargetException;
import java.util.concurrent.ExecutionException;
@Service
@ -26,13 +26,11 @@ public class TimeSeriesService {
this.methodParamBruteForce = methodParamBruteForce;
}
public TimeSeries getForecast(TimeSeries timeSeries, int countPoints) throws ModelingException {
Method method;
method = new AddTrendAddSeason();
return method.getForecast(timeSeries, Collections.emptyList(), countPoints);
public ModelingResult getForecast(TimeSeries timeSeries, int countPoints) throws ExecutionException, InterruptedException, InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException, ModelingException {
return methodParamBruteForce.getForecast(timeSeries, countPoints);
}
public TimeSeries smoothTimeSeries(TimeSeries timeSeries) throws ModelingException {
public TimeSeries smoothTimeSeries(TimeSeries timeSeries) throws ExecutionException, InterruptedException, InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException {
return methodParamBruteForce.getSmoothedTimeSeries(timeSeries).getTimeSeries();
}
}

View File

@ -1,4 +1,10 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
- Copyright (C) 2021 Anton Romanov - All Rights Reserved
- You may use, distribute and modify this code, please write to: romanov73@gmail.com.
-
-->
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN"
"http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
<html xmlns="http://www.w3.org/1999/xhtml"
@ -32,10 +38,18 @@
<div class="ui-g-12">
<p:outputPanel>
<div class="ui-g">
<div class="ui-g-12 ui-md-12 ui-lg-4">Ulyanovsk State Technical University</div>
<div class="ui-g-6 ui-md-6 ui-lg-4"><h:outputLink
value="http://ulstu.ru">ulstu.ru</h:outputLink></div>
<div class="ui-g-6 ui-md-6 ui-lg-4">2020</div>
<div class="ui-g-12 ui-md-12 ui-lg-3">Ulyanovsk State Technical University</div>
<div class="ui-g-6 ui-md-6 ui-lg-3">
<h:outputLink
value="http://ulstu.ru">ulstu.ru
</h:outputLink>
</div>
<div class="ui-g-6 ui-md-6 ui-lg-3">
<h:outputLink target="_blank"
value="/swagger-ui.html">api for developers
</h:outputLink>
</div>
<div class="ui-g-6 ui-md-6 ui-lg-3">2020</div>
</div>
</p:outputPanel>
</div>