Bugfix in benchmarks.knn

This commit is contained in:
Petrônio Cândido 2019-06-02 15:36:17 -03:00
parent 2149b4d041
commit f072d73d1b
2 changed files with 19 additions and 19 deletions

View File

@ -58,12 +58,12 @@ class KNearestNeighbors(fts.FTS):
def train(self, data, **kwargs): def train(self, data, **kwargs):
X,Y = self._prepare_xy(data) X,Y = self._prepare_xy(data)
self.kdtree = KDTree(X) self.kdtree = KDTree(np.array(X))
self.values = Y self.values = Y
def knn(self, sample): def knn(self, sample):
X = self._prepare_x(sample) X = self._prepare_x(sample)
_, ix = self.kdtree.query(X, self.k) _, ix = self.kdtree.query(np.array(X), self.k)
return [self.values[k] for k in ix.flatten() ] return [self.values[k] for k in ix.flatten() ]

View File

@ -19,7 +19,7 @@ from pyFTS.fcm import fts, common, GA
from pyFTS.data import TAIEX, NASDAQ, SP500 from pyFTS.data import TAIEX, NASDAQ, SP500
''' #'''
train = TAIEX.get_data()[:800] train = TAIEX.get_data()[:800]
test = TAIEX.get_data()[800:1000] test = TAIEX.get_data()[800:1000]
@ -61,19 +61,19 @@ methods_parameters = [
{'order':2 } {'order':2 }
] ]
#for dataset_name, dataset in datasets.items(): for dataset_name, dataset in datasets.items():
bchmk.sliding_window_benchmarks2(TAIEX.get_data()[:5000], 1000, train=0.8, inc=0.2, bchmk.sliding_window_benchmarks2(dataset, 1000, train=0.8, inc=0.2,
benchmark_models=False, benchmark_models=True,
benchmark_methods=methods, benchmark_methods=methods,
benchmark_methods_parameters=methods_parameters, benchmark_methods_parameters=methods_parameters,
methods=[ifts.IntervalFTS, ifts.WeightedIntervalFTS], methods=[],
methods_parameters=[{},{}], methods_parameters=[{},{}],
transformations=[None], transformations=[None],
orders=[1,2,3], orders=[],
steps_ahead=[10], steps_ahead=[10],
partitions=[33], partitions=[],
type='interval', type='distribution',
#distributed=True, nodes=['192.168.0.110', '192.168.0.107','192.168.0.106'], distributed=True, nodes=['192.168.0.110', '192.168.0.107','192.168.0.106'],
#file="tmp.db", dataset=dataset_name, tag="experiments") file="experiments.db", dataset=dataset_name, tag="experiments")
file="tmp.db", dataset='TAIEX', tag="experiments") # file="tmp.db", dataset='TAIEX', tag="experiments")
#''' '''