diff --git a/pyFTS/benchmarks/Measures.py b/pyFTS/benchmarks/Measures.py index 6aadfc5..b07fd47 100644 --- a/pyFTS/benchmarks/Measures.py +++ b/pyFTS/benchmarks/Measures.py @@ -323,6 +323,9 @@ def get_point_statistics(data, model, **kwargs): if steps_ahead == 1: forecasts = model.predict(ndata, **kwargs) + + if model.is_multivariate: + ndata = ndata[model1.target_variable.data_label].values if not isinstance(forecasts, (list, np.ndarray)): forecasts = [forecasts]