CVFTS improvements and bugfixes; FTS.fit bugfix for multivariate models; Util.plot_rules high order capability
This commit is contained in:
parent
a2002c20b1
commit
68a4a953b8
@ -8,42 +8,92 @@ import dill
|
||||
import numpy as np
|
||||
|
||||
|
||||
def plot_rules(model, size=[5, 5], axis=None):
|
||||
def plot_rules(model, size=[5, 5], axis=None, rules_by_axis=None, columns=1):
|
||||
if axis is None and rules_by_axis is None:
|
||||
fig, axis = plt.subplots(nrows=1, ncols=1, figsize=size)
|
||||
elif axis is None and rules_by_axis is not None:
|
||||
rows = (((len(model.flrgs.keys())//rules_by_axis)) // columns)+1
|
||||
fig, axis = plt.subplots(nrows=rows, ncols=columns, figsize=size)
|
||||
|
||||
if rules_by_axis is None:
|
||||
draw_sets_on_axis(axis, model, size)
|
||||
|
||||
_lhs = model.partitioner.ordered_sets if not model.is_high_order else model.flrgs.keys()
|
||||
|
||||
for ct, key in enumerate(_lhs):
|
||||
|
||||
if rules_by_axis is None:
|
||||
ax = axis
|
||||
else:
|
||||
colcount = (ct // rules_by_axis) % columns
|
||||
rowcount = (ct // rules_by_axis) // columns
|
||||
|
||||
ax = axis[rowcount, colcount] if columns > 1 else axis[rowcount]
|
||||
|
||||
if ct % rules_by_axis == 0:
|
||||
xticks = []
|
||||
xtickslabels = []
|
||||
draw_sets_on_axis(ax, model, size)
|
||||
|
||||
if not model.is_high_order:
|
||||
if key in model.flrgs:
|
||||
flrg = model.flrgs[key]
|
||||
orig = model.sets[key].centroid
|
||||
ax.plot([ct+1],[orig],'o')
|
||||
xticks.append(ct+1)
|
||||
xtickslabels.append(key)
|
||||
for rhs in flrg.RHS:
|
||||
dest = model.sets[rhs].centroid
|
||||
ax.arrow(ct+1.1, orig, 0.8, dest - orig, #length_includes_head=True,
|
||||
head_width=0.1, head_length=0.1, shape='full', overhang=0,
|
||||
fc='k', ec='k')
|
||||
else:
|
||||
flrg = model.flrgs[key]
|
||||
disp = (ct%rules_by_axis)*model.order + 1
|
||||
for ct2, lhs in enumerate(flrg.LHS):
|
||||
orig = model.sets[lhs].centroid
|
||||
ax.plot([disp+ct2], [orig], 'o')
|
||||
xticks.append(disp+ct2)
|
||||
xtickslabels.append(lhs)
|
||||
for ct2 in range(1, model.order):
|
||||
fs1 = flrg.LHS[ct2-1]
|
||||
fs2 = flrg.LHS[ct2]
|
||||
orig = model.sets[fs1].centroid
|
||||
dest = model.sets[fs2].centroid
|
||||
ax.plot([disp+ct2-1,disp+ct2], [orig,dest],'-')
|
||||
|
||||
orig = model.sets[flrg.LHS[-1]].centroid
|
||||
for rhs in flrg.RHS:
|
||||
dest = model.sets[rhs].centroid
|
||||
ax.arrow(disp + model.order -1 + .1, orig, 0.8, dest - orig, # length_includes_head=True,
|
||||
head_width=0.1, head_length=0.1, shape='full', overhang=0,
|
||||
fc='k', ec='k')
|
||||
|
||||
|
||||
ax.set_xticks(xticks)
|
||||
ax.set_xticklabels(xtickslabels)
|
||||
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
|
||||
|
||||
def draw_sets_on_axis(axis, model, size):
|
||||
if axis is None:
|
||||
fig, axis = plt.subplots(nrows=1, ncols=1, figsize=size)
|
||||
|
||||
for ct, key in enumerate(model.partitioner.ordered_sets):
|
||||
fs = model.sets[key]
|
||||
axis.plot([0, 1, 0], fs.parameters, label=fs.name)
|
||||
axis.axhline(fs.centroid, c="lightgray", alpha=0.5)
|
||||
|
||||
axis.set_xlim([0, len(model.partitioner.ordered_sets)])
|
||||
axis.set_xticks(range(0,len(model.partitioner.ordered_sets)))
|
||||
axis.set_xticks(range(0, len(model.partitioner.ordered_sets)))
|
||||
tmp = ['']
|
||||
tmp.extend(model.partitioner.ordered_sets)
|
||||
axis.set_xticklabels(tmp)
|
||||
axis.set_ylim([model.partitioner.min, model.partitioner.max])
|
||||
axis.set_yticks([model.sets[k].centroid for k in model.partitioner.ordered_sets])
|
||||
axis.set_yticklabels([str(round(model.sets[k].centroid,1)) + " - " + k
|
||||
axis.set_yticklabels([str(round(model.sets[k].centroid, 1)) + " - " + k
|
||||
for k in model.partitioner.ordered_sets])
|
||||
|
||||
if not model.is_high_order:
|
||||
for ct, key in enumerate(model.partitioner.ordered_sets):
|
||||
if key in model.flrgs:
|
||||
flrg = model.flrgs[key]
|
||||
orig = model.sets[key].centroid
|
||||
axis.plot([ct+1],[orig],'o')
|
||||
for rhs in flrg.RHS:
|
||||
dest = model.sets[rhs].centroid
|
||||
axis.arrow(ct+1.1, orig, 0.8, dest - orig, #length_includes_head=True,
|
||||
head_width=0.1, head_length=0.1, shape='full', overhang=0,
|
||||
fc='k', ec='k')
|
||||
plt.tight_layout()
|
||||
plt.show()
|
||||
print("fim")
|
||||
|
||||
|
||||
|
||||
|
||||
current_milli_time = lambda: int(round(time.time() * 1000))
|
||||
|
||||
|
@ -226,8 +226,8 @@ class FTS(object):
|
||||
else:
|
||||
data = self.apply_transformations(ndata)
|
||||
|
||||
self.original_min = np.nanmin(data)
|
||||
self.original_max = np.nanmax(data)
|
||||
self.original_min = np.nanmin(data)
|
||||
self.original_max = np.nanmax(data)
|
||||
|
||||
if 'sets' in kwargs:
|
||||
self.sets = kwargs.pop('sets')
|
||||
@ -235,7 +235,7 @@ class FTS(object):
|
||||
if 'partitioner' in kwargs:
|
||||
self.partitioner = kwargs.pop('partitioner')
|
||||
|
||||
if (self.sets is None or len(self.sets) == 0) and not self.benchmark_only:
|
||||
if (self.sets is None or len(self.sets) == 0) and not self.benchmark_only and not self.is_multivariate:
|
||||
if self.partitioner is not None:
|
||||
self.sets = self.partitioner.sets
|
||||
else:
|
||||
|
@ -1,10 +1,36 @@
|
||||
import numpy as np
|
||||
from pyFTS.models import chen
|
||||
from pyFTS.models import hofts
|
||||
from pyFTS.models.nonstationary import common,nsfts
|
||||
from pyFTS.common import FLR
|
||||
from pyFTS.common import FLR, flrg, tree
|
||||
|
||||
class HighOrderNonstationaryFLRG(hofts.HighOrderFTS):
|
||||
"""Conventional High Order Fuzzy Logical Relationship Group"""
|
||||
def __init__(self, order, **kwargs):
|
||||
super(HighOrderNonstationaryFLRG, self).__init__(order, **kwargs)
|
||||
self.LHS = []
|
||||
self.RHS = {}
|
||||
self.strlhs = ""
|
||||
|
||||
def append_rhs(self, c, **kwargs):
|
||||
if c not in self.RHS:
|
||||
self.RHS[c] = c
|
||||
|
||||
def append_lhs(self, c):
|
||||
self.LHS.append(c)
|
||||
|
||||
def __str__(self):
|
||||
tmp = ""
|
||||
for c in sorted(self.RHS):
|
||||
if len(tmp) > 0:
|
||||
tmp = tmp + ","
|
||||
tmp = tmp + c
|
||||
return self.get_key() + " -> " + tmp
|
||||
|
||||
|
||||
class ConditionalVarianceFTS(chen.ConventionalFTS):
|
||||
def __len__(self):
|
||||
return len(self.RHS)
|
||||
|
||||
class ConditionalVarianceFTS(hofts.HighOrderFTS):
|
||||
def __init__(self, **kwargs):
|
||||
super(ConditionalVarianceFTS, self).__init__(**kwargs)
|
||||
self.name = "Conditional Variance FTS"
|
||||
@ -17,6 +43,8 @@ class ConditionalVarianceFTS(chen.ConventionalFTS):
|
||||
self.min_stack = [0,0,0]
|
||||
self.max_stack = [0,0,0]
|
||||
self.uod_clip = False
|
||||
self.order = 1
|
||||
self.min_order = 1
|
||||
|
||||
def train(self, ndata, **kwargs):
|
||||
|
||||
@ -32,6 +60,7 @@ class ConditionalVarianceFTS(chen.ConventionalFTS):
|
||||
self.flrgs[flr.LHS.name] = nsfts.ConventionalNonStationaryFLRG(flr.LHS)
|
||||
self.flrgs[flr.LHS.name].append_rhs(flr.RHS)
|
||||
|
||||
|
||||
def _smooth(self, a):
|
||||
return .1 * a[0] + .3 * a[1] + .6 * a[2]
|
||||
|
||||
|
@ -28,11 +28,14 @@ from pyFTS.benchmarks import benchmarks as bchmk, Util as bUtil, Measures, knn,
|
||||
from pyFTS.models import pwfts, song, chen, ifts, hofts
|
||||
from pyFTS.models.ensemble import ensemble
|
||||
|
||||
model = chen.ConventionalFTS(partitioner=partitioner)
|
||||
#model = chen.ConventionalFTS(partitioner=partitioner)
|
||||
model = hofts.HighOrderFTS(partitioner=partitioner,order=2)
|
||||
model.append_transformation(tdiff)
|
||||
model.fit(dataset[:800])
|
||||
|
||||
cUtil.plot_rules(model)
|
||||
print(model)
|
||||
|
||||
cUtil.plot_rules(model, size=[20,20], rules_by_axis=6, columns=1)
|
||||
|
||||
'''
|
||||
model = knn.KNearestNeighbors(order=3)
|
||||
|
@ -59,7 +59,7 @@ model1.target_variable = vavg
|
||||
#model.fit(train, num_batches=60, save=True, batch_save=True, file_path='mvfts_sonda')
|
||||
|
||||
model1.fit(train, num_batches=200, save=True, batch_save=True, file_path='mvfts_sonda', distributed=True,
|
||||
nodes=['192.168.1.35'], batch_save_interval=10)
|
||||
nodes=['192.168.0.110'], batch_save_interval=10)
|
||||
|
||||
|
||||
#model = Util.load_obj('mvfts_sonda')
|
Loading…
Reference in New Issue
Block a user