Bugfix in EntropyPartitioner
This commit is contained in:
parent
4eaeb90e4a
commit
72b34217f8
@ -40,7 +40,7 @@ def informationGain(data, thres1, thres2):
|
|||||||
|
|
||||||
def bestSplit(data, npart):
|
def bestSplit(data, npart):
|
||||||
if len(data) < 2:
|
if len(data) < 2:
|
||||||
return None
|
return []
|
||||||
count = 1
|
count = 1
|
||||||
ndata = list(set(np.array(data).flatten()))
|
ndata = list(set(np.array(data).flatten()))
|
||||||
ndata.sort()
|
ndata.sort()
|
||||||
|
@ -17,36 +17,13 @@ from pyFTS.common import Transformations
|
|||||||
tdiff = Transformations.Differential(1)
|
tdiff = Transformations.Differential(1)
|
||||||
|
|
||||||
|
|
||||||
from pyFTS.data import TAIEX, SP500, NASDAQ
|
from pyFTS.data import TAIEX, SP500, NASDAQ, Malaysia
|
||||||
|
|
||||||
dataset = TAIEX.get_data()
|
dataset = Malaysia.get_data('temperature')[:1000]
|
||||||
|
|
||||||
from pyFTS.models.incremental import Retrainer
|
p = Entropy.EntropyPartitioner(data=dataset, npart=19)
|
||||||
|
|
||||||
from pyFTS.models.incremental import Retrainer
|
print(p)
|
||||||
from pyFTS.benchmarks import benchmarks as bchmk
|
|
||||||
|
|
||||||
models = []
|
|
||||||
for method in bchmk.get_point_methods():
|
|
||||||
model = Retrainer.Retrainer(partitioner_params = {'npart': 30},
|
|
||||||
fts_method=method,
|
|
||||||
window_length = 500, batch_size = 100)
|
|
||||||
models.append(model)
|
|
||||||
|
|
||||||
#model.predict(dataset)
|
|
||||||
|
|
||||||
from pyFTS.partitioners import Grid, Util as pUtil
|
|
||||||
from pyFTS.benchmarks import benchmarks as bchmk, naive
|
|
||||||
|
|
||||||
tag = 'benchmarks_retrainer'
|
|
||||||
|
|
||||||
bchmk.sliding_window_benchmarks(dataset, 2000, train=.1, inc=0.1,
|
|
||||||
models=[model],
|
|
||||||
build_methods = False,
|
|
||||||
benchmark_models=False,
|
|
||||||
partitions=[35],
|
|
||||||
progress=False, type='point',
|
|
||||||
file="nsfts_benchmarks.db", dataset='teste', tag=tag)
|
|
||||||
|
|
||||||
'''
|
'''
|
||||||
#dataset = SP500.get_data()[11500:16000]
|
#dataset = SP500.get_data()[11500:16000]
|
||||||
|
Loading…
Reference in New Issue
Block a user