Bugfixes on nonstationary methods

This commit is contained in:
Petrônio Cândido 2018-04-11 15:52:11 -03:00
parent 4ba6c16a2f
commit d980078a8e

View File

@ -6,7 +6,7 @@ import matplotlib.pyplot as plt
from pyFTS.common import Membership, Util from pyFTS.common import Membership, Util
def plot_sets(sets, start=0, end=10, step=1, tam=[5, 5], colors=None, def plot_sets(partitioner, start=0, end=10, step=1, tam=[5, 5], colors=None,
save=False, file=None, axes=None, data=None, window_size = 1, only_lines=False): save=False, file=None, axes=None, data=None, window_size = 1, only_lines=False):
range = np.arange(start,end,step) range = np.arange(start,end,step)
@ -14,7 +14,8 @@ def plot_sets(sets, start=0, end=10, step=1, tam=[5, 5], colors=None,
if axes is None: if axes is None:
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=tam) fig, axes = plt.subplots(nrows=1, ncols=1, figsize=tam)
for ct, set in enumerate(sets): for ct, key in enumerate(partitioner.ordered_sets):
set = partitioner.sets[key]
if not only_lines: if not only_lines:
for t in range: for t in range:
tdisp = t - (t % window_size) tdisp = t - (t % window_size)
@ -61,20 +62,21 @@ def plot_sets_conditional(model, data, start=0, end=10, step=1, tam=[5, 5], colo
if axes is None: if axes is None:
fig, axes = plt.subplots(nrows=1, ncols=1, figsize=tam) fig, axes = plt.subplots(nrows=1, ncols=1, figsize=tam)
for ct, set in enumerate(model.sets): for ct, key in enumerate(model.partitioner.ordered_sets):
for t in range: set = model.partitioner.sets[key]
tdisp = model.perturbation_factors(data[t]) for t in range:
set.perturbate_parameters(tdisp[ct]) tdisp = model.perturbation_factors(data[t])
param = set.perturbated_parameters[str(tdisp[ct])] set.perturbate_parameters(tdisp[ct])
param = set.perturbated_parameters[str(tdisp[ct])]
if set.mf == Membership.trimf: if set.mf == Membership.trimf:
if t == start: if t == start:
line = axes.plot([t, t+1, t], param, label=set.name) line = axes.plot([t, t+1, t], param, label=set.name)
set.metadata['color'] = line[0].get_color() set.metadata['color'] = line[0].get_color()
else: else:
axes.plot([t, t + 1, t], param,c=set.metadata['color']) axes.plot([t, t + 1, t], param,c=set.metadata['color'])
ticks.extend(["t+"+str(t),""]) ticks.extend(["t+"+str(t),""])
axes.set_ylabel("Universe of Discourse") axes.set_ylabel("Universe of Discourse")
axes.set_xlabel("Time") axes.set_xlabel("Time")