package ru.ulstu.service; import org.springframework.stereotype.Service; import ru.ulstu.datamodel.Model; import ru.ulstu.datamodel.ModelingResult; import ru.ulstu.datamodel.exception.ModelingException; import ru.ulstu.datamodel.ts.TimeSeries; import ru.ulstu.datamodel.ts.TimeSeriesValue; import ru.ulstu.method.Method; import ru.ulstu.method.MethodParamValue; import ru.ulstu.method.MethodParameter; import ru.ulstu.score.ScoreMethod; import ru.ulstu.score.Smape; import java.lang.reflect.InvocationTargetException; import java.time.LocalDateTime; import java.util.ArrayList; import java.util.Comparator; import java.util.List; import java.util.Map; import java.util.TreeMap; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.Future; import java.util.stream.Collectors; @Service class MethodParamBruteForce { private final int DEFAULT_THREAD_COUNT = 50; private final List methods; private final ScoreMethod scoreMethod = new Smape(); private final ExecutorService executors = Executors.newFixedThreadPool(DEFAULT_THREAD_COUNT); public MethodParamBruteForce(List methods) { this.methods = methods; } private ModelingResult getForecastByMethods(TimeSeries timeSeries, List methods, int countPointsForecast) throws ExecutionException, InterruptedException, NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException, ModelingException { List> futureModelingResults = new ArrayList<>(); List modelingResults = new CopyOnWriteArrayList<>(); final int countPoints = (countPointsForecast > timeSeries.getLength()) ? timeSeries.getLength() / 3 : countPointsForecast; TimeSeries reducedTimeSeries = new TimeSeries(timeSeries.getValues().stream().limit(timeSeries.getLength() - countPoints).collect(Collectors.toList()), "test part of " + timeSeries.getKey()); try { ValidationUtils.validateTimeSeries(reducedTimeSeries); } catch (ModelingException ex) { throw new ModelingException("Тестовая часть временного ряда не прошла валидацию: " + ex.getMessage()); } Map tsValues = timeSeries.getValues().stream() .collect(Collectors.toMap(TimeSeriesValue::getDate, TimeSeriesValue::getValue)); for (Method method : methods) { List> availableParametersValues = getAvailableParametersValues(timeSeries, method.getAvailableParameters()); for (List parametersValues : availableParametersValues) { Method methodInstance = method.getClass().getDeclaredConstructor().newInstance(); if (methodInstance.canMakeForecast(reducedTimeSeries, parametersValues, countPoints)) { futureModelingResults.add(executors.submit(() -> { TimeSeries forecast = syncDates(methodInstance.getForecast(reducedTimeSeries, parametersValues, countPoints), timeSeries); return new ModelingResult(forecast, null, parametersValues, scoreMethod.getScore(tsValues, forecast), methodInstance); })); } } } for (Future futureModelingResult : futureModelingResults) { modelingResults.add(futureModelingResult.get()); } return getBestResultForecast(modelingResults, timeSeries, countPoints); } public ModelingResult getForecast(TimeSeries timeSeries, String methodClassName, int countPointsForecast) throws ExecutionException, InterruptedException, NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException, ModelingException { Method method = methods.stream() .filter(m -> m.getClass().getSimpleName().equals(methodClassName)) .findAny() .orElseThrow(() -> new ModelingException("Неизвестный метод прогнозирования")); method.validateForForecast(timeSeries, countPointsForecast); return getForecastByMethods(timeSeries, List.of(method), countPointsForecast); } public ModelingResult getForecast(TimeSeries timeSeries, Method method, int countPointsForecast) throws ExecutionException, InterruptedException, NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException, ModelingException { return getForecastByMethods(timeSeries, List.of(method), countPointsForecast); } public ModelingResult getForecast(TimeSeries timeSeries, int countPointsForecast) throws ExecutionException, InterruptedException, NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException, ModelingException { return getForecastByMethods(timeSeries, methods, countPointsForecast); } private ModelingResult getBestResultForecast(List modelingResults, TimeSeries timeSeries, int countPoints) throws ModelingException { if (modelingResults.size() == 0) { throw new ModelingException("Нет результатов моделирования"); } ModelingResult bestResult = modelingResults.stream() .min(Comparator.comparing(modelingResult -> modelingResult.getScore().getDoubleValue())) .orElseThrow(() -> new ModelingException("Лучший метод не найден")); TimeSeries forecast = bestResult.getTimeSeriesMethod().getForecast(timeSeries, bestResult.getParamValues(), countPoints); return new ModelingResult(forecast, bestResult.getTimeSeries(), bestResult.getParamValues(), bestResult.getScore(), bestResult.getTimeSeriesMethod()); } private TimeSeries syncDates(TimeSeries forecast, TimeSeries timeSeries) { List forecastValues = forecast.getValues(); for (int i = 1; i <= forecastValues.size(); i++) { forecastValues.get(forecastValues.size() - i) .setDate(timeSeries.getValues().get(timeSeries.getValues().size() - i).getDate()); } return forecast; } /* TODO: public TimeSeries getForecastWithOptimalLength(TimeSeries timeSeries) { throw new RuntimeException("Not implemented"); } */ public ModelingResult getSmoothedTimeSeries(TimeSeries timeSeries, List methods) throws ExecutionException, InterruptedException, NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException { List> results = new ArrayList<>(); List results2 = new CopyOnWriteArrayList<>(); Map tsValues = timeSeries.getValues().stream() .collect(Collectors.toMap(TimeSeriesValue::getDate, TimeSeriesValue::getValue)); for (Method method : methods) { List> availableParametersValues = getAvailableParametersValues(timeSeries, method.getAvailableParameters()); for (List parametersValues : availableParametersValues) { Method methodInstance = method.getClass().getDeclaredConstructor().newInstance(); if (methodInstance.canMakeModel(timeSeries, parametersValues)) { results.add(executors.submit(() -> { Model model = methodInstance.getModel(timeSeries, parametersValues); return new ModelingResult(model.getTimeSeriesModel(), null, parametersValues, scoreMethod.getScore(tsValues, model.getTimeSeriesModel()), methodInstance); })); } } } for (Future futureModelingResult : results) { results2.add(futureModelingResult.get()); } return results2.stream() .min(Comparator.comparing(modelingResult -> modelingResult.getScore().getDoubleValue())) .orElse(null); } public ModelingResult getSmoothedTimeSeries(TimeSeries timeSeries, String methodClassName) throws ExecutionException, InterruptedException, NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException, ModelingException { Method method = methods.stream() .filter(m -> m.getClass().getSimpleName().equals(methodClassName)) .findAny() .orElseThrow(() -> new ModelingException("Неизвестный метод прогнозирования")); return getSmoothedTimeSeries(timeSeries, List.of(method)); } public ModelingResult getSmoothedTimeSeries(TimeSeries timeSeries) throws ExecutionException, InterruptedException, NoSuchMethodException, InvocationTargetException, InstantiationException, IllegalAccessException { return getSmoothedTimeSeries(timeSeries, methods); } private List> getAvailableParametersValues(TimeSeries timeSeries, List availableParameters) { List> result = new ArrayList<>(); Map parameterOffset = new TreeMap<>(); Map> parameterValues = new TreeMap<>(); for (MethodParameter methodParameter : availableParameters) { parameterOffset.put(methodParameter, 0); parameterValues.put(methodParameter, methodParameter.getAvailableValues(timeSeries)); } while (isNotAllParameterValuesUsed(parameterOffset, parameterValues)) { List resultRow = new ArrayList<>(); for (MethodParameter methodParameter : parameterOffset.keySet()) { resultRow.add(new MethodParamValue(methodParameter, parameterValues.get(methodParameter).get(parameterOffset.get(methodParameter)))); } incrementOffset(parameterOffset, parameterValues); result.add(resultRow); } return result; } private void incrementOffset(Map parameterOffset, Map> parameterValues) { List parameters = new ArrayList<>(parameterOffset.keySet()); int i = 0; while (i < parameters.size() && isNotAllParameterValuesUsed(parameterOffset, parameterValues)) { if (parameterOffset.get(parameters.get(i)) == parameterValues.get(parameters.get(i)).size() - 1) { parameterOffset.put(parameters.get(i), 0); i++; continue; } if (parameterOffset.get(parameters.get(i)) < parameterValues.get(parameters.get(i)).size()) { parameterOffset.put(parameters.get(i), parameterOffset.get(parameters.get(i)) + 1); return; } i++; } } private boolean isNotAllParameterValuesUsed(Map parameterOffset, Map> parameterValues) { for (MethodParameter methodParameter : parameterOffset.keySet()) { if (parameterOffset.get(methodParameter) != parameterValues.get(methodParameter).size() - 1) { return true; } } return false; } public List getAvailableMethods() { return methods; } }