Bugfix in benchmarks.mv_run methods
This commit is contained in:
parent
715772bbcc
commit
dbfa1ac86e
@ -770,7 +770,9 @@ def mv_run_point2(mfts, train_data, test_data, window_key=None, **kwargs):
|
||||
_end = time.time()
|
||||
times += _end - _start
|
||||
|
||||
eval = Measures.get_point_ahead_statistics(test_data[mfts.order:mfts.order + steps_ahead], forecasts)
|
||||
tmp_test = test_data[mfts.target_variable.data_label].values[mfts.order:mfts.order + steps_ahead]
|
||||
|
||||
eval = Measures.get_point_ahead_statistics(tmp_test, forecasts)
|
||||
|
||||
for key in eval.keys():
|
||||
eval[key]["time"] = times
|
||||
@ -898,7 +900,9 @@ def mv_run_interval2(mfts,train_data, test_data, window_key=None, **kwargs):
|
||||
_end = time.time()
|
||||
times += _end - _start
|
||||
|
||||
eval = Measures.get_interval_ahead_statistics(test_data[mfts.order:mfts.order+steps_ahead], intervals)
|
||||
tmp_test = test_data[mfts.target_variable.data_label].values[mfts.order:mfts.order + steps_ahead]
|
||||
|
||||
eval = Measures.get_interval_ahead_statistics(tmp_test, intervals)
|
||||
|
||||
for key in eval.keys():
|
||||
eval[key]["time"] = times
|
||||
@ -1017,7 +1021,9 @@ def mv_run_probabilistic2(mfts, train_data, test_data, window_key=None, **kwargs
|
||||
_end = time.time()
|
||||
times += _end - _start
|
||||
|
||||
eval = Measures.get_distribution_ahead_statistics(test_data[mfts.order:mfts.order+steps_ahead], distributions)
|
||||
tmp_test = test_data[mfts.target_variable.data_label].values[mfts.order:mfts.order + steps_ahead]
|
||||
|
||||
eval = Measures.get_distribution_ahead_statistics(tmp_test, distributions)
|
||||
|
||||
for key in eval.keys():
|
||||
eval[key]["time"] = times
|
||||
|
@ -69,9 +69,9 @@ detrend = trend.apply(data)
|
||||
plt.plot(trend.inverse(detrend, data))
|
||||
'''
|
||||
|
||||
dataset = pd.read_csv('https://query.data.world/s/nxst4hzhjrqld4bxhbpn6twmjbwqk7')
|
||||
dataset['data'] = pd.to_datetime([str(y)+'-'+str(m) for y,m in zip(dataset['Ano'].values, dataset['Mes'].values)],
|
||||
format='%Y-%m')
|
||||
#dataset = pd.read_csv('https://query.data.world/s/nxst4hzhjrqld4bxhbpn6twmjbwqk7')
|
||||
#dataset['data'] = pd.to_datetime([str(y)+'-'+str(m) for y,m in zip(dataset['Ano'].values, dataset['Mes'].values)],
|
||||
# format='%Y-%m')
|
||||
roi = Transformations.ROI()
|
||||
|
||||
'''
|
||||
@ -93,6 +93,8 @@ model.fit(train)
|
||||
fig, ax = plt.subplots(nrows=1, ncols=1, figsize=[10,5])
|
||||
ax.plot(test)
|
||||
|
||||
'''
|
||||
|
||||
'''
|
||||
train = dataset.iloc[:30]
|
||||
test = dataset.iloc[30:]
|
||||
@ -129,3 +131,47 @@ ax.plot(forecast)
|
||||
plt.show()
|
||||
|
||||
print(dataset)
|
||||
'''
|
||||
|
||||
eto = pd.read_csv('https://raw.githubusercontent.com/PatriciaLucas/Evapotranspiracao/master/ETo_setelagoas.csv', sep=',')
|
||||
eto['Data'] = pd.to_datetime(eto["Data"], format='%Y-%m-%d')
|
||||
|
||||
from pyFTS.models.multivariate import common, variable, mvfts, wmvfts, granular
|
||||
from pyFTS.models import hofts, pwfts
|
||||
from pyFTS.partitioners import Grid, Entropy
|
||||
from pyFTS.common import Membership
|
||||
from pyFTS.models.seasonal.common import DateTime
|
||||
from pyFTS.models.seasonal import partitioner as seasonal
|
||||
from pyFTS.benchmarks import Measures
|
||||
from pyFTS.benchmarks import arima, quantreg, knn, benchmarks as bchmk
|
||||
|
||||
variables = {
|
||||
"Month": dict(data_label="Data", partitioner=seasonal.TimeGridPartitioner, npart=6),
|
||||
"Eto": dict(data_label="Eto", alias='eto',
|
||||
partitioner=Grid.GridPartitioner, npart=50)
|
||||
}
|
||||
|
||||
methods = [mvfts.MVFTS, wmvfts.WeightedMVFTS, granular.GranularWMVFTS]
|
||||
|
||||
time_generator = lambda x : pd.to_datetime(x) + pd.to_timedelta(1, unit='d')
|
||||
|
||||
parameters = [
|
||||
{},{},
|
||||
dict(fts_method=pwfts.ProbabilisticWeightedFTS, fuzzyfy_mode='both',
|
||||
order=1, knn=3)
|
||||
]
|
||||
|
||||
|
||||
|
||||
bchmk.multivariate_sliding_window_benchmarks2(eto, 2000, train=0.8, inc=0.2,
|
||||
methods=methods,
|
||||
methods_parameters=parameters,
|
||||
variables=variables,
|
||||
target_variable='Eto',
|
||||
type='point',
|
||||
steps_ahead=[7],
|
||||
file="hyperparam.db", dataset='Eto',
|
||||
tag="experiments",
|
||||
generators= {'Data': time_generator}
|
||||
)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user