common.Util.plot_rules bugfix

This commit is contained in:
Petrônio Cândido 2018-06-13 10:14:18 -03:00
parent 409f9d5b6b
commit bd4c0c432e
2 changed files with 24 additions and 17 deletions

View File

@ -22,6 +22,9 @@ def plot_rules(model, size=[5, 5], axis=None, rules_by_axis=None, columns=1):
for ct, key in enumerate(_lhs): for ct, key in enumerate(_lhs):
xticks = []
xtickslabels = []
if rules_by_axis is None: if rules_by_axis is None:
ax = axis ax = axis
else: else:
@ -31,47 +34,47 @@ def plot_rules(model, size=[5, 5], axis=None, rules_by_axis=None, columns=1):
ax = axis[rowcount, colcount] if columns > 1 else axis[rowcount] ax = axis[rowcount, colcount] if columns > 1 else axis[rowcount]
if ct % rules_by_axis == 0: if ct % rules_by_axis == 0:
xticks = []
xtickslabels = []
draw_sets_on_axis(ax, model, size) draw_sets_on_axis(ax, model, size)
if not model.is_high_order: if not model.is_high_order:
if key in model.flrgs: if key in model.flrgs:
x = (ct % rules_by_axis) + 1
flrg = model.flrgs[key] flrg = model.flrgs[key]
orig = model.sets[key].centroid y = model.sets[key].centroid
ax.plot([ct+1],[orig],'o') ax.plot([x],[y],'o')
xticks.append(ct+1) xticks.append(x)
xtickslabels.append(key) xtickslabels.append(key)
for rhs in flrg.RHS: for rhs in flrg.RHS:
dest = model.sets[rhs].centroid dest = model.sets[rhs].centroid
ax.arrow(ct+1.1, orig, 0.8, dest - orig, #length_includes_head=True, ax.arrow(x+.1, y, 0.8, dest - y, #length_includes_head=True,
head_width=0.1, head_length=0.1, shape='full', overhang=0, head_width=0.1, head_length=0.1, shape='full', overhang=0,
fc='k', ec='k') fc='k', ec='k')
else: else:
flrg = model.flrgs[key] flrg = model.flrgs[key]
disp = (ct%rules_by_axis)*model.order + 1 x = (ct%rules_by_axis)*model.order + 1
for ct2, lhs in enumerate(flrg.LHS): for ct2, lhs in enumerate(flrg.LHS):
orig = model.sets[lhs].centroid y = model.sets[lhs].centroid
ax.plot([disp+ct2], [orig], 'o') ax.plot([x+ct2], [y], 'o')
xticks.append(disp+ct2) xticks.append(x+ct2)
xtickslabels.append(lhs) xtickslabels.append(lhs)
for ct2 in range(1, model.order): for ct2 in range(1, model.order):
fs1 = flrg.LHS[ct2-1] fs1 = flrg.LHS[ct2-1]
fs2 = flrg.LHS[ct2] fs2 = flrg.LHS[ct2]
orig = model.sets[fs1].centroid y = model.sets[fs1].centroid
dest = model.sets[fs2].centroid dest = model.sets[fs2].centroid
ax.plot([disp+ct2-1,disp+ct2], [orig,dest],'-') ax.plot([x+ct2-1,x+ct2], [y,dest],'-')
orig = model.sets[flrg.LHS[-1]].centroid y = model.sets[flrg.LHS[-1]].centroid
for rhs in flrg.RHS: for rhs in flrg.RHS:
dest = model.sets[rhs].centroid dest = model.sets[rhs].centroid
ax.arrow(disp + model.order -1 + .1, orig, 0.8, dest - orig, # length_includes_head=True, ax.arrow(x + model.order -1 + .1, y, 0.8, dest - y, # length_includes_head=True,
head_width=0.1, head_length=0.1, shape='full', overhang=0, head_width=0.1, head_length=0.1, shape='full', overhang=0,
fc='k', ec='k') fc='k', ec='k')
ax.set_xticks(xticks) ax.set_xticks(xticks)
ax.set_xticklabels(xtickslabels) ax.set_xticklabels(xtickslabels)
ax.set_xlim([0,rules_by_axis*model.order+1])
plt.tight_layout() plt.tight_layout()
plt.show() plt.show()

View File

@ -19,7 +19,7 @@ dataset = TAIEX.get_data()
#print(len(dataset)) #print(len(dataset))
from pyFTS.partitioners import Grid, Util as pUtil from pyFTS.partitioners import Grid, Util as pUtil
partitioner = Grid.GridPartitioner(data=dataset[:800], npart=10, transformation=tdiff) partitioner = Grid.GridPartitioner(data=dataset[:800], npart=10)#, transformation=tdiff)
from pyFTS.common import Util as cUtil from pyFTS.common import Util as cUtil
@ -30,12 +30,16 @@ from pyFTS.models.ensemble import ensemble
#model = chen.ConventionalFTS(partitioner=partitioner) #model = chen.ConventionalFTS(partitioner=partitioner)
model = hofts.HighOrderFTS(partitioner=partitioner,order=2) model = hofts.HighOrderFTS(partitioner=partitioner,order=2)
model.append_transformation(tdiff) #model.append_transformation(tdiff)
model.fit(dataset[:800]) model.fit(dataset[:800])
cUtil.plot_rules(model, size=[20,20], rules_by_axis=5, columns=1)
print(model) print(model)
cUtil.plot_rules(model, size=[20,20], rules_by_axis=6, columns=1) print("fim")
''' '''
model = knn.KNearestNeighbors(order=3) model = knn.KNearestNeighbors(order=3)