From 560dcfacf64bd0f552b2d8761271a6bc95fcc103 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Petr=C3=B4nio=20C=C3=A2ndido=20de=20Lima=20e=20Silva?= Date: Fri, 26 May 2023 14:21:12 -0300 Subject: [PATCH] Update Activations.py - weights --- pyFTS/common/Activations.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/pyFTS/common/Activations.py b/pyFTS/common/Activations.py index 1ca0aba..b5b8ad1 100644 --- a/pyFTS/common/Activations.py +++ b/pyFTS/common/Activations.py @@ -6,14 +6,14 @@ import numpy as np import math from pyFTS import * -def scale(dist : dict) -> dict: +def scale(dist : dict, weights : dict) -> dict: norm = np.sum([v for v in dist.values()]) - return {k : (v / norm) for k,v in dist.items() } + return {k : ((v * weights[k]) / norm) for k,v in dist.items() } -def softmax(dist : dict) -> dict: +def softmax(dist : dict, weights : dict) -> dict: norm = np.sum([np.exp(v) for v in dist.values()]) - return {k : (np.exp(v) / norm) for k,v in dist.items() } + return {k : (np.exp(v * weights[k]) / norm) for k,v in dist.items() } -def argmax(dist : dict) -> str: - mx = np.max([v for v in dist.values()]) - return [k for k,v in dist.items() if v == mx ][0] \ No newline at end of file +def argmax(dist : dict, weights : dict) -> str: + mx = np.max([v * weights[k] for k,v in dist.items()]) + return [k for k,v in dist.items() if v * weights[k] == mx ][0]