Bugfix in models.nonstationary.util.plot_sets_conditional
This commit is contained in:
parent
2055b512ca
commit
39e0c6aa88
@ -63,7 +63,7 @@ def plot_sets_conditional(model, data, step=1, size=[5, 5], colors=None,
|
||||
|
||||
for t in range:
|
||||
model.forecast([data[t]])
|
||||
perturb = model.perturbation_factors(data[t])
|
||||
perturb = model.conditional_perturbation_factors(data[t])
|
||||
|
||||
for ct, key in enumerate(model.partitioner.ordered_sets):
|
||||
set = model.partitioner.sets[key]
|
||||
|
@ -11,7 +11,7 @@ import pandas as pd
|
||||
from pyFTS.data import TAIEX, NASDAQ, SP500, artificial, mackey_glass
|
||||
|
||||
mackey_glass.get_data()
|
||||
'''
|
||||
|
||||
datasets = {
|
||||
"TAIEX": TAIEX.get_data()[:4000],
|
||||
"SP500": SP500.get_data()[10000:14000],
|
||||
@ -53,7 +53,38 @@ partitions = {'CMIV': {'BoxCox(0)': 36, 'Differential(1)': 11, 'None': 8},
|
||||
'SP500': {'BoxCox(0)': 33, 'Differential(1)': 7, 'None': 33},
|
||||
'TAIEX': {'BoxCox(0)': 39, 'Differential(1)': 31, 'None': 33}}
|
||||
|
||||
from pyFTS.models.nonstationary import partitioners as nspart, cvfts, util as nsUtil
|
||||
|
||||
|
||||
def model_details(ds, tf, train_split, test_split):
|
||||
data = datasets[ds]
|
||||
train = data[:train_split]
|
||||
test = data[train_split:test_split]
|
||||
transformation = transformations[tf]
|
||||
fs = nspart.simplenonstationary_gridpartitioner_builder(data=train, npart=partitions[ds][tf],
|
||||
transformation=transformation)
|
||||
model = nsfts.NonStationaryFTS(partitioner=fs)
|
||||
model.fit(train)
|
||||
print(model)
|
||||
forecasts = model.predict(test)
|
||||
residuals = np.array(test[1:]) - np.array(forecasts[:-1])
|
||||
|
||||
fig, axes = plt.subplots(nrows=2, ncols=1, figsize=[15, 10])
|
||||
|
||||
axes[0].plot(test[1:], label="Original")
|
||||
axes[0].plot(forecasts[:-1], label="Forecasts")
|
||||
|
||||
axes[1].set_title("Residuals")
|
||||
axes[1].plot(residuals)
|
||||
handles0, labels0 = axes[0].get_legend_handles_labels()
|
||||
lgd = axes[0].legend(handles0, labels0, loc=2)
|
||||
|
||||
nsUtil.plot_sets_conditional(model, test, step=10, size=[12, 5])
|
||||
|
||||
model_details('NASDAQ','None',200,2000)
|
||||
|
||||
|
||||
'''
|
||||
tag = 'benchmarks'
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user