GridSearch improvements

This commit is contained in:
Petrônio Cândido 2018-11-14 00:42:59 -02:00
parent 2db3b0311e
commit edceece6e2
2 changed files with 56 additions and 39 deletions

View File

@ -42,7 +42,7 @@ def metodo_cluster(individual, train, test):
partitioner = Entropy.EntropyPartitioner(data=train, npart=npart, func=mf)
model = hofts.HighOrderFTS(partitioner=partitioner,
model = hofts.WeightedHighOrderFTS(partitioner=partitioner,
lags=individual['lags'],
alpha_cut=individual['alpha'],
order=individual['order'])
@ -51,8 +51,32 @@ def metodo_cluster(individual, train, test):
rmse, mape, u = Measures.get_point_statistics(test, model)
return individual, rmse
size = len(model)
return individual, rmse, size
def process_jobs(jobs, datasetname, conn):
    """Wait for submitted cluster jobs and store their metrics.

    For every job that finished with a usable result, two rows are written
    through ``hUtil.insert_hyperparam``: one for the ``'rmse'`` measure and
    one for the ``'size'`` measure. Both rows share the same hyperparameter
    columns taken from the returned individual. For any other job the
    job's exception and captured stdout are printed instead.

    :param jobs: iterable of dispy job objects returned by ``cluster.submit``
    :param datasetname: name of the dataset, stored as the first column
    :param conn: connection handed to ``hUtil.insert_hyperparam``
    """
    for job in jobs:
        # Calling the job blocks until it ends and yields its return value.
        # Do NOT unpack it yet: a job that failed or was cancelled does not
        # carry the (individual, rmse, size) tuple (presumably None -- see
        # the dispy docs), and unpacking it here would raise TypeError
        # before the failure could be reported, aborting the whole batch.
        outcome = job()

        result = rmse = size = None
        if job.status == dispy.DispyJob.Finished and outcome is not None:
            result, rmse, size = outcome

        if result is None:
            print(job.exception)
            print(job.stdout)
            continue

        print(result)

        # Columns common to both measure rows.
        prefix = (datasetname, 'GridSearch', 'WHOFTS', None, result['mf'],
                  result['order'], result['partitioner'], result['npart'],
                  result['alpha'], str(result['lags']))

        for measure, value in (('rmse', rmse), ('size', size)):
            hUtil.insert_hyperparam(prefix + (measure, value), conn)
def execute(hyperparams, datasetname, train, test, **kwargs):
@ -75,20 +99,27 @@ def execute(hyperparams, datasetname, train, test, **kwargs):
[v for v in hyperparams[hp]]
for hp in keys_sorted
]
cluster, http_server = Util.start_dispy_cluster(metodo_cluster, nodes=nodes)
conn = hUtil.open_hyperparam_db('hyperparam.db')
for instance in product(*hp_values):
for ct, instance in enumerate(product(*hp_values)):
partitions = instance[index['partitions']]
partitioner = instance[index['partitioner']]
mf = instance[index['mf']]
alpha_cut = instance[index['alpha']]
order = instance[index['order']]
count = 0
for lag1 in lags: # o é o lag1
_lags = [lag1]
count += 1
if order > 1:
for lag2 in lags: # o é o lag1
_lags2 = [lag1, lag1+lag2]
count += 1
if order > 2:
for lag3 in lags: # o é o lag1
count += 1
_lags3 = [lag1, lag1 + lag2, lag1 + lag2+lag3 ]
individuals.append(dict_individual(mf, partitioner, partitions, order, _lags3, alpha_cut))
else:
@ -96,32 +127,18 @@ def execute(hyperparams, datasetname, train, test, **kwargs):
dict_individual(mf, partitioner, partitions, order, _lags2, alpha_cut))
else:
individuals.append(dict_individual(mf, partitioner, partitions, order, _lags, alpha_cut))
if count > 50:
jobs = []
for ind in individuals:
job = cluster.submit(ind, train, test)
jobs.append(job)
process_jobs(jobs, datasetname, conn)
count = 0
individuals = []
cluster, http_server = Util.start_dispy_cluster(metodo_cluster, nodes=nodes)
jobs = []
for ind in individuals:
job = cluster.submit(ind, train, test)
jobs.append(job)
conn = hUtil.open_hyperparam_db('hyperparam.db')
for job in jobs:
result, rmse = job()
if job.status == dispy.DispyJob.Finished and result is not None:
print(result)
record = (datasetname, 'GridSearch', 'HOFTS', None, result['mf'],
result['order'], result['partitioner'], result['npart'],
result['alpha'], str(result['lags']), 'rmse', rmse)
hUtil.insert_hyperparam(record, conn)
else:
print(job.exception)
print(job.stdout)
Util.stop_dispy_cluster(cluster, http_server)
Util.stop_dispy_cluster(cluster, http_server)

View File

@ -1,4 +1,4 @@
import numpy as np
from pyFTS.hyperparam import GridSearch
def get_train_test():
@ -12,16 +12,16 @@ def get_train_test():
return 'Malaysia.temperature', train, test
hyperparams = {
'order':[1],
'partitions':[10, 15],
'partitioner': [1],
'mf': [1],
'lags': [1, 2, 3],
'alpha': [.1, .2, .5]
'order':[1, 2, 3],
'partitions': np.arange(10,100,3),
'partitioner': [1,2],
'mf': [1, 2, 3, 4],
'lags': np.arange(1,35,2),
'alpha': np.arange(0,.5, .05)
}
nodes = ['192.168.0.110','192.168.0.106']
nodes = ['192.168.0.110','192.168.0.106', '192.168.0.107']
ds, train, test = get_train_test()
GridSearch.execute(hyperparams, ds, train, test, nodes=nodes)
GridSearch.execute(hyperparams, ds, train, test, nodes=nodes)