fix forecast by best model

This commit is contained in:
Anton Romanov 2021-06-23 14:42:02 +04:00
parent 518e01ca93
commit 90a107648d
9 changed files with 122 additions and 69 deletions

View File

@ -21,6 +21,9 @@ import ru.ulstu.datamodel.ts.TimeSeries;
import ru.ulstu.service.MethodParamBruteForce; import ru.ulstu.service.MethodParamBruteForce;
import ru.ulstu.service.TimeSeriesService; import ru.ulstu.service.TimeSeriesService;
import java.lang.reflect.InvocationTargetException;
import java.util.concurrent.ExecutionException;
@RestController @RestController
@RequestMapping(ApiConfiguration.API_1_0) @RequestMapping(ApiConfiguration.API_1_0)
public class TimeSeriesController { public class TimeSeriesController {
@ -36,14 +39,14 @@ public class TimeSeriesController {
@PostMapping("getForecast") @PostMapping("getForecast")
@ApiOperation("Получить прогноз временного ряда") @ApiOperation("Получить прогноз временного ряда")
public ResponseEntity<TimeSeries> getForecastTimeSeries(@RequestBody ForecastParams forecastParams) throws ModelingException { public ResponseEntity<ModelingResult> getForecastTimeSeries(@RequestBody ForecastParams forecastParams) throws ExecutionException, InterruptedException, InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException, ModelingException {
return new ResponseEntity<>(timeSeriesService.getForecast(forecastParams.getOriginalTimeSeries(), return new ResponseEntity<>(timeSeriesService.getForecast(forecastParams.getOriginalTimeSeries(),
forecastParams.getCountForecast()), HttpStatus.OK); forecastParams.getCountForecast()), HttpStatus.OK);
} }
@PostMapping("getSmoothed") @PostMapping("getSmoothed")
@ApiOperation("Получить сглаженный временной ряд") @ApiOperation("Получить сглаженный временной ряд")
public ResponseEntity<ModelingResult> getSmoothedTimeSeries(@RequestBody TimeSeries timeSeries) throws ModelingException { public ResponseEntity<ModelingResult> getSmoothedTimeSeries(@RequestBody TimeSeries timeSeries) throws ExecutionException, InterruptedException, InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException {
return new ResponseEntity<>(methodParamBruteForce.getSmoothedTimeSeries(timeSeries), HttpStatus.OK); return new ResponseEntity<>(methodParamBruteForce.getSmoothedTimeSeries(timeSeries), HttpStatus.OK);
} }
} }

View File

@ -48,7 +48,10 @@ public class TimeSeriesValue {
@Override @Override
public String toString() { public String toString() {
return value.toString(); return "TimeSeriesValue{" +
"date=" + date +
", value=" + value +
'}';
} }
@Override @Override
@ -64,4 +67,5 @@ public class TimeSeriesValue {
public int hashCode() { public int hashCode() {
return Objects.hash(date, value); return Objects.hash(date, value);
} }
} }

View File

@ -7,6 +7,8 @@
package ru.ulstu.method; package ru.ulstu.method;
import com.fasterxml.jackson.annotation.JsonIgnore; import com.fasterxml.jackson.annotation.JsonIgnore;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import ru.ulstu.TimeSeriesUtils; import ru.ulstu.TimeSeriesUtils;
import ru.ulstu.datamodel.Model; import ru.ulstu.datamodel.Model;
import ru.ulstu.datamodel.exception.ForecastValidateException; import ru.ulstu.datamodel.exception.ForecastValidateException;
@ -22,6 +24,8 @@ import java.util.List;
* Наиболее общая логика моделировании и прогнозирования временных рядов * Наиболее общая логика моделировании и прогнозирования временных рядов
*/ */
public abstract class Method { public abstract class Method {
private static final Logger LOGGER = LoggerFactory.getLogger(Method.class);
@JsonIgnore @JsonIgnore
public abstract List<MethodParameter> getAvailableParameters(); public abstract List<MethodParameter> getAvailableParameters();
@ -31,8 +35,7 @@ public abstract class Method {
* *
* @return модель временного ряда * @return модель временного ряда
*/ */
protected abstract Model getModelOfValidTimeSeries(TimeSeries timeSeries, protected abstract Model getModelOfValidTimeSeries(TimeSeries timeSeries, List<MethodParamValue> parameters);
List<MethodParamValue> parameters) throws ModelingException;
/** /**
* Возвращает модельное представление временного ряда: для тех же точек времени что и в параметре timeSeries * Возвращает модельное представление временного ряда: для тех же точек времени что и в параметре timeSeries
@ -58,9 +61,10 @@ public abstract class Method {
*/ */
protected abstract TimeSeries getForecastWithValidParams(Model model, TimeSeries forecast) throws ModelingException; protected abstract TimeSeries getForecastWithValidParams(Model model, TimeSeries forecast) throws ModelingException;
public boolean canMakeForecast(TimeSeries timeSeries, int countPoints) { public boolean canMakeForecast(TimeSeries timeSeries, List<MethodParamValue> parameters, int countPoints) {
try { try {
validateTimeSeries(timeSeries); validateTimeSeries(timeSeries);
validateAdditionalParams(timeSeries, parameters);
validateForecastParams(countPoints); validateForecastParams(countPoints);
} catch (ModelingException ex) { } catch (ModelingException ex) {
return false; return false;
@ -68,6 +72,16 @@ public abstract class Method {
return true; return true;
} }
public boolean canMakeModel(TimeSeries timeSeries, List<MethodParamValue> parameters) {
try {
validateTimeSeries(timeSeries);
validateAdditionalParams(timeSeries, parameters);
} catch (ModelingException ex) {
return false;
}
return true;
}
/** /**
* Выполняет построение модели и прогноза временного ряда. Даты спрогнозированных точек будут сгенерированы * Выполняет построение модели и прогноза временного ряда. Даты спрогнозированных точек будут сгенерированы
* по модельным точкам. * по модельным точкам.

View File

@ -9,7 +9,6 @@ package ru.ulstu.method.exponential.addtrendaddseason;
import org.springframework.stereotype.Component; import org.springframework.stereotype.Component;
import ru.ulstu.datamodel.Model; import ru.ulstu.datamodel.Model;
import ru.ulstu.datamodel.exception.ModelingException; import ru.ulstu.datamodel.exception.ModelingException;
import ru.ulstu.datamodel.exception.TimeSeriesValidateException;
import ru.ulstu.datamodel.ts.TimeSeries; import ru.ulstu.datamodel.ts.TimeSeries;
import ru.ulstu.method.Method; import ru.ulstu.method.Method;
import ru.ulstu.method.MethodParamValue; import ru.ulstu.method.MethodParamValue;
@ -67,11 +66,10 @@ public class AddTrendAddSeason extends Method {
for (MethodParamValue parameter : parameters) { for (MethodParamValue parameter : parameters) {
if (parameter.getParameter() instanceof Season) { if (parameter.getParameter() instanceof Season) {
if (ts.getLength() < parameter.getValue().intValue()) { if (ts.getLength() < parameter.getValue().intValue()) {
throw new TimeSeriesValidateException("Период больше чем длина ряда"); throw new ModelingException("Период больше чем длина ряда");
} }
} }
} }
} }
@Override @Override
@ -82,10 +80,10 @@ public class AddTrendAddSeason extends Method {
List<Double> iComponent = currentModel.getSeasonComponent(); List<Double> iComponent = currentModel.getSeasonComponent();
for (int t = 1; t < forecast.getLength(); t++) { for (int t = 1; t < forecast.getLength(); t++) {
iComponent.add(currentModel.getGamma().getDoubleValue() * forecast.getNumericValue(t - 1) / sComponent.get(sComponent.size() - 1) iComponent.add(currentModel.getGamma().getDoubleValue() * forecast.getNumericValue(t - 1) / sComponent.get(sComponent.size() - 1)
+ (1 - currentModel.getGamma().getDoubleValue()) * iComponent.get(t + model.getTimeSeriesModel().getLength() - currentModel.getSeason().getIntValue())); + (1 - currentModel.getGamma().getDoubleValue()) * iComponent.get(t + model.getTimeSeriesModel().getLength() - currentModel.getSeason().getIntValue() - 1));
forecast.getValues().get(t).setValue((sComponent.get(sComponent.size() - 1) + tComponent.get(tComponent.size() - 1) * t) forecast.getValues().get(t).setValue((sComponent.get(sComponent.size() - 1) + tComponent.get(tComponent.size() - 1) * t)
* iComponent.get(t + model.getTimeSeriesModel().getLength() - currentModel.getSeason().getIntValue())); * iComponent.get(t + model.getTimeSeriesModel().getLength() - currentModel.getSeason().getIntValue() - 1));
} }
return forecast; return forecast;
} }

View File

@ -24,7 +24,9 @@ import javax.annotation.PostConstruct;
import javax.faces.view.ViewScoped; import javax.faces.view.ViewScoped;
import javax.inject.Named; import javax.inject.Named;
import java.io.Serializable; import java.io.Serializable;
import java.lang.reflect.InvocationTargetException;
import java.time.format.DateTimeFormatter; import java.time.format.DateTimeFormatter;
import java.util.concurrent.ExecutionException;
@Named @Named
@ViewScoped @ViewScoped
@ -41,12 +43,12 @@ public class IndexView implements Serializable {
private String timeSeriesString; private String timeSeriesString;
@PostConstruct @PostConstruct
public void init() { public void init() throws ExecutionException, InterruptedException, InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException, ModelingException {
timeSeriesString = utilService.getTimeSeriesToString(utilService.getRandomTimeSeries(50)); timeSeriesString = utilService.getTimeSeriesToString(utilService.getRandomTimeSeries(50));
createChart(); createChart();
} }
private LineChartModel initLinearModel() { private LineChartModel initLinearModel() throws ExecutionException, InterruptedException, InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException, ModelingException {
LineChartModel model = new LineChartModel(); LineChartModel model = new LineChartModel();
LineChartSeries series1 = new LineChartSeries(); LineChartSeries series1 = new LineChartSeries();
@ -60,20 +62,18 @@ public class IndexView implements Serializable {
LineChartSeries series2 = new LineChartSeries(); LineChartSeries series2 = new LineChartSeries();
series2.setLabel("Сглаженный ряд"); series2.setLabel("Сглаженный ряд");
try { try {
for (TimeSeriesValue value : timeSeriesService.smoothTimeSeries(timeSeries).getValues()) { TimeSeries smoothedTimeSeries = timeSeriesService.smoothTimeSeries(timeSeries);
for (TimeSeriesValue value : smoothedTimeSeries.getValues()) {
series2.set(DateTimeFormatter.ISO_LOCAL_DATE.format(value.getDate()), value.getValue()); series2.set(DateTimeFormatter.ISO_LOCAL_DATE.format(value.getDate()), value.getValue());
} }
} catch (ModelingException ex) { } catch (Exception ex) {
LOG.warn(ex.getMessage()); LOG.warn(ex.getMessage());
} }
LineChartSeries series3 = new LineChartSeries(); LineChartSeries series3 = new LineChartSeries();
series3.setLabel("Прогноз"); series3.setLabel("Прогноз");
try { TimeSeries forecast = timeSeriesService.getForecast(timeSeries, 20).getTimeSeries();
for (TimeSeriesValue value : timeSeriesService.getForecast(timeSeries, 20).getValues()) { for (TimeSeriesValue value : forecast.getValues()) {
series3.set(DateTimeFormatter.ISO_LOCAL_DATE.format(value.getDate()), value.getValue()); series3.set(DateTimeFormatter.ISO_LOCAL_DATE.format(value.getDate()), value.getValue());
}
} catch (ModelingException ex) {
LOG.warn(ex.getMessage());
} }
model.addSeries(series1); model.addSeries(series1);
model.addSeries(series2); model.addSeries(series2);
@ -81,7 +81,7 @@ public class IndexView implements Serializable {
return model; return model;
} }
public void createChart() { public void createChart() throws ExecutionException, InterruptedException, InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException, ModelingException {
model = initLinearModel(); model = initLinearModel();
model.setTitle("Сглаживание временного ряда"); model.setTitle("Сглаживание временного ряда");
model.setLegendPosition("d"); model.setLegendPosition("d");
@ -95,7 +95,7 @@ public class IndexView implements Serializable {
return model; return model;
} }
public String getTimeSeriesString() { public String getTimeSeriesString() throws ExecutionException, InterruptedException, InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException, ModelingException {
createChart(); createChart();
return timeSeriesString; return timeSeriesString;
} }

View File

@ -33,6 +33,8 @@ public abstract class ScoreMethod {
.stream() .stream()
.filter(v -> v.getDate().equals(timeSeriesValueToFind.getDate())) .filter(v -> v.getDate().equals(timeSeriesValueToFind.getDate()))
.findAny() .findAny()
.orElseThrow(() -> new ModelingException("Значение модельного ряда не найдено в оригинальном ряде: " + timeSeriesValueToFind.getDate())); .orElseThrow(() -> new ModelingException("Значение модельного ряда не найдено в оригинальном ряде: "
+ timeSeriesValueToFind.getDate()
+ " " + timeSeries));
} }
} }

View File

@ -17,15 +17,18 @@ import ru.ulstu.method.MethodParameter;
import ru.ulstu.score.ScoreMethod; import ru.ulstu.score.ScoreMethod;
import ru.ulstu.score.Smape; import ru.ulstu.score.Smape;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Comparator; import java.util.Comparator;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.TreeMap; import java.util.TreeMap;
import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService; import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors; import java.util.concurrent.Executors;
import java.util.concurrent.Future; import java.util.concurrent.Future;
import java.util.stream.Collectors;
@Service @Service
public class MethodParamBruteForce { public class MethodParamBruteForce {
@ -38,52 +41,69 @@ public class MethodParamBruteForce {
this.methods = methods; this.methods = methods;
} }
public TimeSeries getForecast(TimeSeries timeSeries) { public ModelingResult getForecast(TimeSeries timeSeries, int countPointsForecast) throws ExecutionException, InterruptedException, NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException, ModelingException {
throw new RuntimeException("Not implemented"); List<Future<ModelingResult>> results = new ArrayList<>();
List<ModelingResult> results2 = new CopyOnWriteArrayList<>();
final int countPoints = (countPointsForecast > timeSeries.getLength()) ? timeSeries.getLength() / 3 : countPointsForecast;
TimeSeries reducedTimeSeries = new TimeSeries(timeSeries.getValues().stream().limit(timeSeries.getLength() - countPoints).collect(Collectors.toList()),
"test part of " + timeSeries.getName());
for (Method method : methods) {
List<List<MethodParamValue>> availableParametersValues = getAvailableParametersValues(method.getAvailableParameters());
for (List<MethodParamValue> parametersValues : availableParametersValues) {
Method methodInstance = method.getClass().getDeclaredConstructor().newInstance();
if (methodInstance.canMakeForecast(reducedTimeSeries, parametersValues, countPoints)) {
results.add(executors.submit(() -> {
TimeSeries forecast = methodInstance.getForecast(reducedTimeSeries, parametersValues, countPoints);
return new ModelingResult(forecast,
parametersValues,
scoreMethod.getScore(timeSeries, forecast),
methodInstance);
}));
}
}
}
for (Future<ModelingResult> futureModelingResult : results) {
results2.add(futureModelingResult.get());
}
ModelingResult bestResult = results2.stream()
.min(Comparator.comparing(modelingResult -> modelingResult.getScore().getDoubleValue()))
.orElse(null);
return new ModelingResult(bestResult.getTimeSeriesMethod().getForecast(timeSeries,
bestResult.getParamValues(),
countPoints),
bestResult.getParamValues(),
bestResult.getScore(),
bestResult.getTimeSeriesMethod());
} }
public TimeSeries getForecastWithOptimalLength(TimeSeries timeSeries) { public TimeSeries getForecastWithOptimalLength(TimeSeries timeSeries) {
throw new RuntimeException("Not implemented"); throw new RuntimeException("Not implemented");
} }
public ModelingResult getSmoothedTimeSeries(TimeSeries timeSeries) throws ModelingException { public ModelingResult getSmoothedTimeSeries(TimeSeries timeSeries) throws ExecutionException, InterruptedException, NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException {
List<Future<ModelingResult>> results = new ArrayList<>(); List<Future<ModelingResult>> results = new ArrayList<>();
List<ModelingResult> results2 = new CopyOnWriteArrayList<>(); List<ModelingResult> results2 = new CopyOnWriteArrayList<>();
try { for (Method method : methods) {
for (Method method : methods) { List<List<MethodParamValue>> availableParametersValues = getAvailableParametersValues(method.getAvailableParameters());
List<List<MethodParamValue>> availableParametersValues = getAvailableParametersValues(method.getAvailableParameters()); for (List<MethodParamValue> parametersValues : availableParametersValues) {
for (List<MethodParamValue> parametersValues : availableParametersValues) { Method methodInstance = method.getClass().getDeclaredConstructor().newInstance();
if (methodInstance.canMakeModel(timeSeries, parametersValues)) {
results.add(executors.submit(() -> { results.add(executors.submit(() -> {
Method methodInstance = method.getClass().getDeclaredConstructor().newInstance(); Model model = methodInstance.getModel(timeSeries, parametersValues);
try { return new ModelingResult(model.getTimeSeriesModel(),
Model model = methodInstance.getModel(timeSeries, parametersValues); parametersValues,
return new ModelingResult(model.getTimeSeriesModel(), scoreMethod.getScore(timeSeries, model.getTimeSeriesModel()),
parametersValues, methodInstance);
scoreMethod.getScore(timeSeries, model.getTimeSeriesModel()),
methodInstance);
} catch (ModelingException ex) {
ex.printStackTrace();
return null;
}
})); }));
} }
} }
results.forEach(modelingResultFuture -> {
try {
modelingResultFuture.get();
} catch (Exception e) {
e.printStackTrace();
}
});
for (Future<ModelingResult> futureModelingResult : results) {
results2.add(futureModelingResult.get());
}
return results2.stream()
.min(Comparator.comparing(modelingResult -> modelingResult.getScore().getDoubleValue()))
.orElse(null);
} catch (Exception e) {
throw new ModelingException(e.getMessage());
} }
for (Future<ModelingResult> futureModelingResult : results) {
results2.add(futureModelingResult.get());
}
return results2.stream()
.min(Comparator.comparing(modelingResult -> modelingResult.getScore().getDoubleValue()))
.orElse(null);
} }
private List<List<MethodParamValue>> getAvailableParametersValues(List<MethodParameter> availableParameters) { private List<List<MethodParamValue>> getAvailableParametersValues(List<MethodParameter> availableParameters) {

View File

@ -9,12 +9,12 @@ package ru.ulstu.service;
import org.slf4j.Logger; import org.slf4j.Logger;
import org.slf4j.LoggerFactory; import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service; import org.springframework.stereotype.Service;
import ru.ulstu.datamodel.ModelingResult;
import ru.ulstu.datamodel.exception.ModelingException; import ru.ulstu.datamodel.exception.ModelingException;
import ru.ulstu.datamodel.ts.TimeSeries; import ru.ulstu.datamodel.ts.TimeSeries;
import ru.ulstu.method.Method;
import ru.ulstu.method.exponential.addtrendaddseason.AddTrendAddSeason;
import java.util.Collections; import java.lang.reflect.InvocationTargetException;
import java.util.concurrent.ExecutionException;
@Service @Service
@ -26,13 +26,11 @@ public class TimeSeriesService {
this.methodParamBruteForce = methodParamBruteForce; this.methodParamBruteForce = methodParamBruteForce;
} }
public TimeSeries getForecast(TimeSeries timeSeries, int countPoints) throws ModelingException { public ModelingResult getForecast(TimeSeries timeSeries, int countPoints) throws ExecutionException, InterruptedException, InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException, ModelingException {
Method method; return methodParamBruteForce.getForecast(timeSeries, countPoints);
method = new AddTrendAddSeason();
return method.getForecast(timeSeries, Collections.emptyList(), countPoints);
} }
public TimeSeries smoothTimeSeries(TimeSeries timeSeries) throws ModelingException { public TimeSeries smoothTimeSeries(TimeSeries timeSeries) throws ExecutionException, InterruptedException, InvocationTargetException, NoSuchMethodException, InstantiationException, IllegalAccessException {
return methodParamBruteForce.getSmoothedTimeSeries(timeSeries).getTimeSeries(); return methodParamBruteForce.getSmoothedTimeSeries(timeSeries).getTimeSeries();
} }
} }

View File

@ -1,4 +1,10 @@
<?xml version="1.0" encoding="UTF-8"?> <?xml version="1.0" encoding="UTF-8"?>
<!--
- Copyright (C) 2021 Anton Romanov - All Rights Reserved
- You may use, distribute and modify this code, please write to: romanov73@gmail.com.
-
-->
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" <!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN"
"http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd"> "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
<html xmlns="http://www.w3.org/1999/xhtml" <html xmlns="http://www.w3.org/1999/xhtml"
@ -32,10 +38,18 @@
<div class="ui-g-12"> <div class="ui-g-12">
<p:outputPanel> <p:outputPanel>
<div class="ui-g"> <div class="ui-g">
<div class="ui-g-12 ui-md-12 ui-lg-4">Ulyanovsk State Technical University</div> <div class="ui-g-12 ui-md-12 ui-lg-3">Ulyanovsk State Technical University</div>
<div class="ui-g-6 ui-md-6 ui-lg-4"><h:outputLink <div class="ui-g-6 ui-md-6 ui-lg-3">
value="http://ulstu.ru">ulstu.ru</h:outputLink></div> <h:outputLink
<div class="ui-g-6 ui-md-6 ui-lg-4">2020</div> value="http://ulstu.ru">ulstu.ru
</h:outputLink>
</div>
<div class="ui-g-6 ui-md-6 ui-lg-3">
<h:outputLink target="_blank"
value="/swagger-ui.html">api for developers
</h:outputLink>
</div>
<div class="ui-g-6 ui-md-6 ui-lg-3">2020</div>
</div> </div>
</p:outputPanel> </p:outputPanel>
</div> </div>