FTS.predict bugfix for multivariate FTS
This commit is contained in:
parent
68a4a953b8
commit
409f9d5b6b
@ -76,8 +76,8 @@ class FTS(object):
|
||||
else:
|
||||
ndata = self.apply_transformations(data)
|
||||
|
||||
if self.uod_clip:
|
||||
ndata = np.clip(ndata, self.original_min, self.original_max)
|
||||
if self.uod_clip:
|
||||
ndata = np.clip(ndata, self.original_min, self.original_max)
|
||||
|
||||
if 'distributed' in kwargs:
|
||||
distributed = kwargs.pop('distributed')
|
||||
|
@ -7,7 +7,7 @@ from pyFTS.common import Transformations
|
||||
from pyFTS.data import SONDA
|
||||
df = SONDA.get_dataframe()
|
||||
train = df.iloc[0:578241] #three years
|
||||
#test = df.iloc[1572480:2096640] #ears
|
||||
test = df.iloc[1572480:2096640] #one year
|
||||
del df
|
||||
|
||||
from pyFTS.partitioners import Grid, Util as pUtil
|
||||
@ -58,8 +58,10 @@ model1.target_variable = vavg
|
||||
|
||||
#model.fit(train, num_batches=60, save=True, batch_save=True, file_path='mvfts_sonda')
|
||||
|
||||
model1.fit(train, num_batches=200, save=True, batch_save=True, file_path='mvfts_sonda', distributed=True,
|
||||
nodes=['192.168.0.110'], batch_save_interval=10)
|
||||
#model1.fit(train, num_batches=200, save=True, batch_save=True, file_path='mvfts_sonda', distributed=False,
|
||||
# nodes=['192.168.0.110'], batch_save_interval=10)
|
||||
|
||||
|
||||
#model = Util.load_obj('mvfts_sonda')
|
||||
model = Util.load_obj('mvfts_sonda')
|
||||
|
||||
forecasts = model.predict(test)
|
Loading…
Reference in New Issue
Block a user