diff --git a/pyFTS/models/nonstationary/util.py b/pyFTS/models/nonstationary/util.py index b174195..4b0a90b 100644 --- a/pyFTS/models/nonstationary/util.py +++ b/pyFTS/models/nonstationary/util.py @@ -6,7 +6,7 @@ import matplotlib.pyplot as plt from pyFTS.common import Membership, Util -def plot_sets(sets, start=0, end=10, step=1, tam=[5, 5], colors=None, +def plot_sets(partitioner, start=0, end=10, step=1, tam=[5, 5], colors=None, save=False, file=None, axes=None, data=None, window_size = 1, only_lines=False): range = np.arange(start,end,step) @@ -14,7 +14,8 @@ def plot_sets(sets, start=0, end=10, step=1, tam=[5, 5], colors=None, if axes is None: fig, axes = plt.subplots(nrows=1, ncols=1, figsize=tam) - for ct, set in enumerate(sets): + for ct, key in enumerate(partitioner.ordered_sets): + set = partitioner.sets[key] if not only_lines: for t in range: tdisp = t - (t % window_size) @@ -61,20 +62,21 @@ def plot_sets_conditional(model, data, start=0, end=10, step=1, tam=[5, 5], colo if axes is None: fig, axes = plt.subplots(nrows=1, ncols=1, figsize=tam) - for ct, set in enumerate(model.sets): - for t in range: - tdisp = model.perturbation_factors(data[t]) - set.perturbate_parameters(tdisp[ct]) - param = set.perturbated_parameters[str(tdisp[ct])] + for ct, key in enumerate(model.partitioner.ordered_sets): + set = model.partitioner.sets[key] + for t in range: + tdisp = model.perturbation_factors(data[t]) + set.perturbate_parameters(tdisp[ct]) + param = set.perturbated_parameters[str(tdisp[ct])] - if set.mf == Membership.trimf: - if t == start: - line = axes.plot([t, t+1, t], param, label=set.name) - set.metadata['color'] = line[0].get_color() - else: - axes.plot([t, t + 1, t], param,c=set.metadata['color']) + if set.mf == Membership.trimf: + if t == start: + line = axes.plot([t, t+1, t], param, label=set.name) + set.metadata['color'] = line[0].get_color() + else: + axes.plot([t, t + 1, t], param,c=set.metadata['color']) - ticks.extend(["t+"+str(t),""]) + ticks.extend(["t+"+str(t),""]) axes.set_ylabel("Universe of Discourse") axes.set_xlabel("Time")