time-series-smoothing/src/main/java/ru/ulstu/service/MethodParamBruteForce.java

217 lines
11 KiB
Java

package ru.ulstu.service;
import org.springframework.stereotype.Service;
import ru.ulstu.datamodel.Model;
import ru.ulstu.datamodel.ModelingResult;
import ru.ulstu.datamodel.exception.ModelingException;
import ru.ulstu.datamodel.ts.TimeSeries;
import ru.ulstu.datamodel.ts.TimeSeriesValue;
import ru.ulstu.method.Method;
import ru.ulstu.method.MethodParamValue;
import ru.ulstu.method.MethodParameter;
import ru.ulstu.score.ScoreMethod;
import ru.ulstu.score.Smape;
import java.lang.reflect.InvocationTargetException;
import java.time.LocalDateTime;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.stream.Collectors;
@Service
class MethodParamBruteForce {
private final int DEFAULT_THREAD_COUNT = 50;
private final List<Method> methods;
private final ScoreMethod scoreMethod = new Smape();
private final ExecutorService executors = Executors.newFixedThreadPool(DEFAULT_THREAD_COUNT);
public MethodParamBruteForce(List<Method> methods) {
this.methods = methods;
}
private ModelingResult getForecastByMethods(TimeSeries timeSeries, List<Method> methods, int countPointsForecast) throws ExecutionException, InterruptedException, NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException, ModelingException {
List<Future<ModelingResult>> futureModelingResults = new ArrayList<>();
List<ModelingResult> modelingResults = new CopyOnWriteArrayList<>();
final int countPoints = (countPointsForecast > timeSeries.getLength()) ? timeSeries.getLength() / 3 : countPointsForecast;
TimeSeries reducedTimeSeries = new TimeSeries(timeSeries.getValues().stream().limit(timeSeries.getLength() - countPoints).collect(Collectors.toList()),
"test part of " + timeSeries.getKey());
Map<LocalDateTime, Double> tsValues = timeSeries.getValues().stream()
.collect(Collectors.toMap(TimeSeriesValue::getDate, TimeSeriesValue::getValue));
for (Method method : methods) {
List<List<MethodParamValue>> availableParametersValues = getAvailableParametersValues(timeSeries, method.getAvailableParameters());
for (List<MethodParamValue> parametersValues : availableParametersValues) {
Method methodInstance = method.getClass().getDeclaredConstructor().newInstance();
if (methodInstance.canMakeForecast(reducedTimeSeries, parametersValues, countPoints)) {
futureModelingResults.add(executors.submit(() -> {
TimeSeries forecast = syncDates(methodInstance.getForecast(reducedTimeSeries, parametersValues, countPoints), timeSeries);
return new ModelingResult(forecast, null,
parametersValues,
scoreMethod.getScore(tsValues, forecast),
methodInstance);
}));
}
}
}
for (Future<ModelingResult> futureModelingResult : futureModelingResults) {
modelingResults.add(futureModelingResult.get());
}
return getBestResultForecast(modelingResults, timeSeries, countPoints);
}
public ModelingResult getForecast(TimeSeries timeSeries, String methodClassName, int countPointsForecast) throws ExecutionException, InterruptedException, NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException, ModelingException {
Method method = methods.stream()
.filter(m -> m.getClass().getSimpleName().equals(methodClassName))
.findAny()
.orElseThrow(() -> new ModelingException("Неизвестный метод прогнозирования"));
return getForecastByMethods(timeSeries, List.of(method), countPointsForecast);
}
public ModelingResult getForecast(TimeSeries timeSeries, Method method, int countPointsForecast) throws ExecutionException, InterruptedException, NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException, ModelingException {
return getForecastByMethods(timeSeries, List.of(method), countPointsForecast);
}
public ModelingResult getForecast(TimeSeries timeSeries, int countPointsForecast) throws ExecutionException, InterruptedException, NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException, ModelingException {
return getForecastByMethods(timeSeries, methods, countPointsForecast);
}
private ModelingResult getBestResultForecast(List<ModelingResult> modelingResults,
TimeSeries timeSeries,
int countPoints) throws ModelingException {
ModelingResult bestResult = modelingResults.stream()
.min(Comparator.comparing(modelingResult -> modelingResult.getScore().getDoubleValue()))
.orElseThrow(() -> new ModelingException("Лучший метод не найден"));
TimeSeries forecast = bestResult.getTimeSeriesMethod().getForecast(timeSeries,
bestResult.getParamValues(),
countPoints);
return new ModelingResult(forecast,
bestResult.getTimeSeries(),
bestResult.getParamValues(),
bestResult.getScore(),
bestResult.getTimeSeriesMethod());
}
private TimeSeries syncDates(TimeSeries forecast, TimeSeries timeSeries) {
List<TimeSeriesValue> forecastValues = forecast.getValues();
for (int i = 1; i <= forecastValues.size(); i++) {
forecastValues.get(forecastValues.size() - i)
.setDate(timeSeries.getValues().get(timeSeries.getValues().size() - i).getDate());
}
return forecast;
}
/*
TODO:
public TimeSeries getForecastWithOptimalLength(TimeSeries timeSeries) {
throw new RuntimeException("Not implemented");
}
*/
public ModelingResult getSmoothedTimeSeries(TimeSeries timeSeries, List<Method> methods) throws ExecutionException, InterruptedException, NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException {
List<Future<ModelingResult>> results = new ArrayList<>();
List<ModelingResult> results2 = new CopyOnWriteArrayList<>();
Map<LocalDateTime, Double> tsValues = timeSeries.getValues().stream()
.collect(Collectors.toMap(TimeSeriesValue::getDate, TimeSeriesValue::getValue));
for (Method method : methods) {
List<List<MethodParamValue>> availableParametersValues = getAvailableParametersValues(timeSeries, method.getAvailableParameters());
for (List<MethodParamValue> parametersValues : availableParametersValues) {
Method methodInstance = method.getClass().getDeclaredConstructor().newInstance();
if (methodInstance.canMakeModel(timeSeries, parametersValues)) {
results.add(executors.submit(() -> {
Model model = methodInstance.getModel(timeSeries, parametersValues);
return new ModelingResult(model.getTimeSeriesModel(),
null,
parametersValues,
scoreMethod.getScore(tsValues, model.getTimeSeriesModel()),
methodInstance);
}));
}
}
}
for (Future<ModelingResult> futureModelingResult : results) {
results2.add(futureModelingResult.get());
}
return results2.stream()
.min(Comparator.comparing(modelingResult -> modelingResult.getScore().getDoubleValue()))
.orElse(null);
}
public ModelingResult getSmoothedTimeSeries(TimeSeries timeSeries, String methodClassName) throws ExecutionException, InterruptedException, NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException, ModelingException {
Method method = methods.stream()
.filter(m -> m.getClass().getSimpleName().equals(methodClassName))
.findAny()
.orElseThrow(() -> new ModelingException("Неизвестный метод прогнозирования"));
return getSmoothedTimeSeries(timeSeries, List.of(method));
}
public ModelingResult getSmoothedTimeSeries(TimeSeries timeSeries) throws ExecutionException, InterruptedException, NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException {
return getSmoothedTimeSeries(timeSeries, methods);
}
private List<List<MethodParamValue>> getAvailableParametersValues(TimeSeries timeSeries, List<MethodParameter> availableParameters) {
List<List<MethodParamValue>> result = new ArrayList<>();
Map<MethodParameter, Integer> parameterOffset = new TreeMap<>();
Map<MethodParameter, List<Number>> parameterValues = new TreeMap<>();
for (MethodParameter methodParameter : availableParameters) {
parameterOffset.put(methodParameter, 0);
parameterValues.put(methodParameter, methodParameter.getAvailableValues(timeSeries));
}
while (isNotAllParameterValuesUsed(parameterOffset, parameterValues)) {
List<MethodParamValue> resultRow = new ArrayList<>();
for (MethodParameter methodParameter : parameterOffset.keySet()) {
resultRow.add(new MethodParamValue(methodParameter,
parameterValues.get(methodParameter).get(parameterOffset.get(methodParameter))));
}
incrementOffset(parameterOffset, parameterValues);
result.add(resultRow);
}
return result;
}
private void incrementOffset(Map<MethodParameter, Integer> parameterOffset,
Map<MethodParameter, List<Number>> parameterValues) {
List<MethodParameter> parameters = new ArrayList<>(parameterOffset.keySet());
int i = 0;
while (i < parameters.size() && isNotAllParameterValuesUsed(parameterOffset, parameterValues)) {
if (parameterOffset.get(parameters.get(i)) == parameterValues.get(parameters.get(i)).size() - 1) {
parameterOffset.put(parameters.get(i), 0);
i++;
continue;
}
if (parameterOffset.get(parameters.get(i)) < parameterValues.get(parameters.get(i)).size()) {
parameterOffset.put(parameters.get(i), parameterOffset.get(parameters.get(i)) + 1);
return;
}
i++;
}
}
private boolean isNotAllParameterValuesUsed(Map<MethodParameter, Integer> parameterOffset,
Map<MethodParameter, List<Number>> parameterValues) {
for (MethodParameter methodParameter : parameterOffset.keySet()) {
if (parameterOffset.get(methodParameter) != parameterValues.get(methodParameter).size() - 1) {
return true;
}
}
return false;
}
public List<Method> getAvailableMethods() {
return methods;
}
}