FTS.predict bugfix for multivariate FTS

This commit is contained in:
Petrônio Cândido 2018-06-07 12:51:55 -03:00
parent 68a4a953b8
commit 409f9d5b6b
2 changed files with 8 additions and 6 deletions

View File

@ -76,8 +76,8 @@ class FTS(object):
else: else:
ndata = self.apply_transformations(data) ndata = self.apply_transformations(data)
if self.uod_clip: if self.uod_clip:
ndata = np.clip(ndata, self.original_min, self.original_max) ndata = np.clip(ndata, self.original_min, self.original_max)
if 'distributed' in kwargs: if 'distributed' in kwargs:
distributed = kwargs.pop('distributed') distributed = kwargs.pop('distributed')

View File

@ -7,7 +7,7 @@ from pyFTS.common import Transformations
from pyFTS.data import SONDA from pyFTS.data import SONDA
df = SONDA.get_dataframe() df = SONDA.get_dataframe()
train = df.iloc[0:578241] #three years train = df.iloc[0:578241] #three years
#test = df.iloc[1572480:2096640] #ears test = df.iloc[1572480:2096640] #one year
del df del df
from pyFTS.partitioners import Grid, Util as pUtil from pyFTS.partitioners import Grid, Util as pUtil
@ -58,8 +58,10 @@ model1.target_variable = vavg
#model.fit(train, num_batches=60, save=True, batch_save=True, file_path='mvfts_sonda') #model.fit(train, num_batches=60, save=True, batch_save=True, file_path='mvfts_sonda')
model1.fit(train, num_batches=200, save=True, batch_save=True, file_path='mvfts_sonda', distributed=True, #model1.fit(train, num_batches=200, save=True, batch_save=True, file_path='mvfts_sonda', distributed=False,
nodes=['192.168.0.110'], batch_save_interval=10) # nodes=['192.168.0.110'], batch_save_interval=10)
#model = Util.load_obj('mvfts_sonda') model = Util.load_obj('mvfts_sonda')
forecasts = model.predict(test)