From 87a50c13423d77f09b10b92d33ffefcc09ef0c97 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Petr=C3=B4nio=20C=C3=A2ndido?= Date: Thu, 26 Dec 2019 08:05:35 -0300 Subject: [PATCH] Random Search in hyperparam --- pyFTS/hyperparam/mvfts.py | 32 +++++++++++++++++-------------- pyFTS/hyperparam/random_search.py | 4 +--- pyFTS/tests/hyperparam.py | 2 +- 3 files changed, 20 insertions(+), 18 deletions(-) diff --git a/pyFTS/hyperparam/mvfts.py b/pyFTS/hyperparam/mvfts.py index b4faf41..0777d4f 100644 --- a/pyFTS/hyperparam/mvfts.py +++ b/pyFTS/hyperparam/mvfts.py @@ -365,42 +365,46 @@ def mutation_random_search(individual, **kwargs): :param pmut: individual probability o :return: """ + import copy + + new = copy.deepcopy(individual) vars = kwargs.get('variables', None) tvar = kwargs.get('target_variable', None) l = len(vars) - il = len(individual['explanatory_variables']) + il = len(new['explanatory_variables']) # if il > 1: for l in range(il): - il = len(individual['explanatory_variables']) + il = len(new['explanatory_variables']) rnd = random.uniform(0, 1) if rnd > .5: rnd = random.randint(0, il-1) - val = individual['explanatory_variables'][rnd] - individual['explanatory_variables'].remove(val) - individual['explanatory_params'].pop(rnd) + if rnd < il and il > 1: + val = individual['explanatory_variables'][rnd] + new['explanatory_variables'].remove(val) + new['explanatory_params'].pop(rnd) else: rnd = random.randint(0, l-1) - while rnd in individual['explanatory_variables']: + while rnd in new['explanatory_variables']: rnd = random.randint(0, l-1) - individual['explanatory_variables'].append(rnd) - individual['explanatory_params'].append(random_param(vars[rnd])) + new['explanatory_variables'].append(rnd) + new['explanatory_params'].append(random_param(vars[rnd])) - for ct in np.arange(len(individual['explanatory_variables'])): + for ct in np.arange(len(new['explanatory_variables'])): rnd = random.uniform(0, 1) if rnd > .5: - mutate_variable_params(individual['explanatory_params'][ct], vars[ct]) + mutate_variable_params(new['explanatory_params'][ct], vars[ct]) rnd = random.uniform(0, 1) if rnd > .5: - mutate_variable_params(individual['target_params'], tvar) + mutate_variable_params(new['target_params'], tvar) - individual['f1'] = None - individual['f2'] = None + new['f1'] = None + new['f2'] = None - return individual + return new def mutate_variable_params(param, var): diff --git a/pyFTS/hyperparam/random_search.py b/pyFTS/hyperparam/random_search.py index 3f94995..17227a8 100644 --- a/pyFTS/hyperparam/random_search.py +++ b/pyFTS/hyperparam/random_search.py @@ -66,9 +66,7 @@ def execute( dataset, **kwargs): new[key] = ret[key] new_stat[key] = ret[key] - print(new) - - if new['f1'] <= individual['f1'] and new['f2'] <= individual['f2']: + if new['f1'] < individual['f1'] or (new['f1'] == individual['f1'] and new['f2'] < individual['f2']): individual = new no_improvement_count = 0 stat[i] = new_stat diff --git a/pyFTS/tests/hyperparam.py b/pyFTS/tests/hyperparam.py index b93c6bd..29fc653 100644 --- a/pyFTS/tests/hyperparam.py +++ b/pyFTS/tests/hyperparam.py @@ -58,7 +58,7 @@ target_variable = {'name': 'Load', 'data_label': 'load', 'type': 'common'} nodes=['192.168.28.38'] deho_mv.random_search(datsetname, dataset, - ngen=200, mgen=200, + ngen=200, mgen=70, window_size=2000, train_rate=.9, increment_rate=1, experiments=1, fts_method=wmvfts.WeightedMVFTS,