Bugfixes on hofts and pwfts

This commit is contained in:
Petrônio Cândido 2018-12-11 13:15:44 -02:00
parent fc2c779266
commit 8d9d0e09c1
3 changed files with 15 additions and 13 deletions

View File

@ -458,7 +458,8 @@ class FTS(object):
return data
def get_UoD(self):
return [self.original_min, self.original_max]
#return [self.original_min, self.original_max]
return [self.partitioner.min, self.partitioner.max]
def __str__(self):
"""String representation of the model"""

View File

@ -88,13 +88,13 @@ class HighOrderFTS(fts.FTS):
self.detail = "Severiano, Silva, Sadaei and Guimarães"
self.is_high_order = True
self.min_order = 1
self.order= kwargs.get("order", 2)
self.order= kwargs.get("order", self.min_order)
self.lags = kwargs.get("lags", None)
self.configure_lags(**kwargs)
def configure_lags(self, **kwargs):
if "order" in kwargs:
self.order = kwargs.get("order", 2)
self.order = kwargs.get("order", self.min_order)
if "lags" in kwargs:
self.lags = kwargs.get("lags", None)

View File

@ -17,21 +17,22 @@ from pyFTS.common import Transformations
tdiff = Transformations.Differential(1)
from pyFTS.data import TAIEX, SP500, NASDAQ, Malaysia
from pyFTS.data import TAIEX, SP500, NASDAQ, Malaysia, Enrollments
dataset = Malaysia.get_data('temperature')[:1000]
train_split = 2000
test_length = 200
p = Grid.GridPartitioner(data=dataset, npart=20)
dataset = TAIEX.get_data()
print(p)
partitioner = Grid.GridPartitioner(data=dataset[:train_split], npart=35)
partitioner_diff = Grid.GridPartitioner(data=dataset[:train_split], npart=5, transformation=tdiff)
model = pwfts.ProbabilisticWeightedFTS(partitioner=p, order=2)
pfts1_taiex = pwfts.ProbabilisticWeightedFTS(partitioner=partitioner)
pfts1_taiex.fit(dataset[:train_split], save_model=True, file_path='pwfts', order=1)
pfts1_taiex.shortname = "1st Order"
#print(pfts1_taiex)
model.fit(dataset) #[22, 22, 23, 23, 24])
print(model)
Measures.get_point_statistics(dataset, model)
tmp = pfts1_taiex.predict(dataset[train_split:train_split+200], type='distribution')
'''
#dataset = SP500.get_data()[11500:16000]