diff --git a/pyFTS/common/Activations.py b/pyFTS/common/Activations.py index 1ca0aba..b5b8ad1 100644 --- a/pyFTS/common/Activations.py +++ b/pyFTS/common/Activations.py @@ -6,14 +6,14 @@ import numpy as np import math from pyFTS import * -def scale(dist : dict) -> dict: +def scale(dist : dict, weights : dict) -> dict: norm = np.sum([v for v in dist.values()]) - return {k : (v / norm) for k,v in dist.items() } + return {k : ((v * weights[k]) / norm) for k,v in dist.items() } -def softmax(dist : dict) -> dict: +def softmax(dist : dict, weights : dict) -> dict: norm = np.sum([np.exp(v) for v in dist.values()]) - return {k : (np.exp(v) / norm) for k,v in dist.items() } + return {k : (np.exp(v * weights[k]) / norm) for k,v in dist.items() } -def argmax(dist : dict) -> str: - mx = np.max([v for v in dist.values()]) - return [k for k,v in dist.items() if v == mx ][0] \ No newline at end of file +def argmax(dist : dict, weights : dict) -> str: + mx = np.max([v * weights[k] for k,v in dist.items()]) + return [k for k,v in dist.items() if v * weights[k] == mx ][0]