diff --git a/pyFTS/benchmarks/knn.py b/pyFTS/benchmarks/knn.py index ca3ac5c..e51ba18 100644 --- a/pyFTS/benchmarks/knn.py +++ b/pyFTS/benchmarks/knn.py @@ -58,12 +58,12 @@ class KNearestNeighbors(fts.FTS): def train(self, data, **kwargs): X,Y = self._prepare_xy(data) - self.kdtree = KDTree(X) + self.kdtree = KDTree(np.array(X)) self.values = Y def knn(self, sample): X = self._prepare_x(sample) - _, ix = self.kdtree.query(X, self.k) + _, ix = self.kdtree.query(np.array(X), self.k) return [self.values[k] for k in ix.flatten() ] diff --git a/pyFTS/tests/general.py b/pyFTS/tests/general.py index 091e2b6..2799453 100644 --- a/pyFTS/tests/general.py +++ b/pyFTS/tests/general.py @@ -19,7 +19,7 @@ from pyFTS.fcm import fts, common, GA from pyFTS.data import TAIEX, NASDAQ, SP500 -''' +#''' train = TAIEX.get_data()[:800] test = TAIEX.get_data()[800:1000] @@ -61,19 +61,19 @@ methods_parameters = [ {'order':2 } ] -#for dataset_name, dataset in datasets.items(): -bchmk.sliding_window_benchmarks2(TAIEX.get_data()[:5000], 1000, train=0.8, inc=0.2, - benchmark_models=False, - benchmark_methods=methods, - benchmark_methods_parameters=methods_parameters, - methods=[ifts.IntervalFTS, ifts.WeightedIntervalFTS], - methods_parameters=[{},{}], - transformations=[None], - orders=[1,2,3], - steps_ahead=[10], - partitions=[33], - type='interval', - #distributed=True, nodes=['192.168.0.110', '192.168.0.107','192.168.0.106'], - #file="tmp.db", dataset=dataset_name, tag="experiments") - file="tmp.db", dataset='TAIEX', tag="experiments") -#''' \ No newline at end of file +for dataset_name, dataset in datasets.items(): + bchmk.sliding_window_benchmarks2(dataset, 1000, train=0.8, inc=0.2, + benchmark_models=True, + benchmark_methods=methods, + benchmark_methods_parameters=methods_parameters, + methods=[], + methods_parameters=[{},{}], + transformations=[None], + orders=[], + steps_ahead=[10], + partitions=[], + type='distribution', + distributed=True, nodes=['192.168.0.110', '192.168.0.107','192.168.0.106'], + file="experiments.db", dataset=dataset_name, tag="experiments") +# file="tmp.db", dataset='TAIEX', tag="experiments") +''' \ No newline at end of file