CVFTS improvements and bugfixes; FTS.fit bugfix for multivariate models; Util.plot_rules high order capability

This commit is contained in:
Petrônio Cândido 2018-06-07 09:58:34 -03:00
parent a2002c20b1
commit 68a4a953b8
5 changed files with 113 additions and 31 deletions

View File

@ -8,15 +8,82 @@ import dill
import numpy as np
def plot_rules(model, size=[5, 5], axis=None):
def plot_rules(model, size=[5, 5], axis=None, rules_by_axis=None, columns=1):
if axis is None and rules_by_axis is None:
fig, axis = plt.subplots(nrows=1, ncols=1, figsize=size)
elif axis is None and rules_by_axis is not None:
rows = (((len(model.flrgs.keys())//rules_by_axis)) // columns)+1
fig, axis = plt.subplots(nrows=rows, ncols=columns, figsize=size)
if rules_by_axis is None:
draw_sets_on_axis(axis, model, size)
_lhs = model.partitioner.ordered_sets if not model.is_high_order else model.flrgs.keys()
for ct, key in enumerate(_lhs):
if rules_by_axis is None:
ax = axis
else:
colcount = (ct // rules_by_axis) % columns
rowcount = (ct // rules_by_axis) // columns
ax = axis[rowcount, colcount] if columns > 1 else axis[rowcount]
if ct % rules_by_axis == 0:
xticks = []
xtickslabels = []
draw_sets_on_axis(ax, model, size)
if not model.is_high_order:
if key in model.flrgs:
flrg = model.flrgs[key]
orig = model.sets[key].centroid
ax.plot([ct+1],[orig],'o')
xticks.append(ct+1)
xtickslabels.append(key)
for rhs in flrg.RHS:
dest = model.sets[rhs].centroid
ax.arrow(ct+1.1, orig, 0.8, dest - orig, #length_includes_head=True,
head_width=0.1, head_length=0.1, shape='full', overhang=0,
fc='k', ec='k')
else:
flrg = model.flrgs[key]
disp = (ct%rules_by_axis)*model.order + 1
for ct2, lhs in enumerate(flrg.LHS):
orig = model.sets[lhs].centroid
ax.plot([disp+ct2], [orig], 'o')
xticks.append(disp+ct2)
xtickslabels.append(lhs)
for ct2 in range(1, model.order):
fs1 = flrg.LHS[ct2-1]
fs2 = flrg.LHS[ct2]
orig = model.sets[fs1].centroid
dest = model.sets[fs2].centroid
ax.plot([disp+ct2-1,disp+ct2], [orig,dest],'-')
orig = model.sets[flrg.LHS[-1]].centroid
for rhs in flrg.RHS:
dest = model.sets[rhs].centroid
ax.arrow(disp + model.order -1 + .1, orig, 0.8, dest - orig, # length_includes_head=True,
head_width=0.1, head_length=0.1, shape='full', overhang=0,
fc='k', ec='k')
ax.set_xticks(xticks)
ax.set_xticklabels(xtickslabels)
plt.tight_layout()
plt.show()
def draw_sets_on_axis(axis, model, size):
if axis is None:
fig, axis = plt.subplots(nrows=1, ncols=1, figsize=size)
for ct, key in enumerate(model.partitioner.ordered_sets):
fs = model.sets[key]
axis.plot([0, 1, 0], fs.parameters, label=fs.name)
axis.axhline(fs.centroid, c="lightgray", alpha=0.5)
axis.set_xlim([0, len(model.partitioner.ordered_sets)])
axis.set_xticks(range(0, len(model.partitioner.ordered_sets)))
tmp = ['']
@ -27,23 +94,6 @@ def plot_rules(model, size=[5, 5], axis=None):
axis.set_yticklabels([str(round(model.sets[k].centroid, 1)) + " - " + k
for k in model.partitioner.ordered_sets])
if not model.is_high_order:
for ct, key in enumerate(model.partitioner.ordered_sets):
if key in model.flrgs:
flrg = model.flrgs[key]
orig = model.sets[key].centroid
axis.plot([ct+1],[orig],'o')
for rhs in flrg.RHS:
dest = model.sets[rhs].centroid
axis.arrow(ct+1.1, orig, 0.8, dest - orig, #length_includes_head=True,
head_width=0.1, head_length=0.1, shape='full', overhang=0,
fc='k', ec='k')
plt.tight_layout()
plt.show()
print("fim")
current_milli_time = lambda: int(round(time.time() * 1000))

View File

@ -235,7 +235,7 @@ class FTS(object):
if 'partitioner' in kwargs:
self.partitioner = kwargs.pop('partitioner')
if (self.sets is None or len(self.sets) == 0) and not self.benchmark_only:
if (self.sets is None or len(self.sets) == 0) and not self.benchmark_only and not self.is_multivariate:
if self.partitioner is not None:
self.sets = self.partitioner.sets
else:

View File

@ -1,10 +1,36 @@
import numpy as np
from pyFTS.models import chen
from pyFTS.models import hofts
from pyFTS.models.nonstationary import common,nsfts
from pyFTS.common import FLR
from pyFTS.common import FLR, flrg, tree
class HighOrderNonstationaryFLRG(hofts.HighOrderFTS):
"""Conventional High Order Fuzzy Logical Relationship Group"""
def __init__(self, order, **kwargs):
super(HighOrderNonstationaryFLRG, self).__init__(order, **kwargs)
self.LHS = []
self.RHS = {}
self.strlhs = ""
def append_rhs(self, c, **kwargs):
if c not in self.RHS:
self.RHS[c] = c
def append_lhs(self, c):
self.LHS.append(c)
def __str__(self):
tmp = ""
for c in sorted(self.RHS):
if len(tmp) > 0:
tmp = tmp + ","
tmp = tmp + c
return self.get_key() + " -> " + tmp
class ConditionalVarianceFTS(chen.ConventionalFTS):
def __len__(self):
return len(self.RHS)
class ConditionalVarianceFTS(hofts.HighOrderFTS):
def __init__(self, **kwargs):
super(ConditionalVarianceFTS, self).__init__(**kwargs)
self.name = "Conditional Variance FTS"
@ -17,6 +43,8 @@ class ConditionalVarianceFTS(chen.ConventionalFTS):
self.min_stack = [0,0,0]
self.max_stack = [0,0,0]
self.uod_clip = False
self.order = 1
self.min_order = 1
def train(self, ndata, **kwargs):
@ -32,6 +60,7 @@ class ConditionalVarianceFTS(chen.ConventionalFTS):
self.flrgs[flr.LHS.name] = nsfts.ConventionalNonStationaryFLRG(flr.LHS)
self.flrgs[flr.LHS.name].append_rhs(flr.RHS)
def _smooth(self, a):
return .1 * a[0] + .3 * a[1] + .6 * a[2]

View File

@ -28,11 +28,14 @@ from pyFTS.benchmarks import benchmarks as bchmk, Util as bUtil, Measures, knn,
from pyFTS.models import pwfts, song, chen, ifts, hofts
from pyFTS.models.ensemble import ensemble
model = chen.ConventionalFTS(partitioner=partitioner)
#model = chen.ConventionalFTS(partitioner=partitioner)
model = hofts.HighOrderFTS(partitioner=partitioner,order=2)
model.append_transformation(tdiff)
model.fit(dataset[:800])
cUtil.plot_rules(model)
print(model)
cUtil.plot_rules(model, size=[20,20], rules_by_axis=6, columns=1)
'''
model = knn.KNearestNeighbors(order=3)

View File

@ -59,7 +59,7 @@ model1.target_variable = vavg
#model.fit(train, num_batches=60, save=True, batch_save=True, file_path='mvfts_sonda')
model1.fit(train, num_batches=200, save=True, batch_save=True, file_path='mvfts_sonda', distributed=True,
nodes=['192.168.1.35'], batch_save_interval=10)
nodes=['192.168.0.110'], batch_save_interval=10)
#model = Util.load_obj('mvfts_sonda')