Update Activations.py - weights

This commit is contained in:
Petrônio Cândido de Lima e Silva 2023-05-26 14:21:12 -03:00 committed by GitHub
parent b8cfea5278
commit 560dcfacf6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -6,14 +6,14 @@ import numpy as np
import math import math
from pyFTS import * from pyFTS import *
def scale(dist : dict) -> dict: def scale(dist : dict, weights : dict) -> dict:
norm = np.sum([v for v in dist.values()]) norm = np.sum([v for v in dist.values()])
return {k : (v / norm) for k,v in dist.items() } return {k : ((v * weights[k]) / norm) for k,v in dist.items() }
def softmax(dist : dict) -> dict: def softmax(dist : dict, weights : dict) -> dict:
norm = np.sum([np.exp(v) for v in dist.values()]) norm = np.sum([np.exp(v) for v in dist.values()])
return {k : (np.exp(v) / norm) for k,v in dist.items() } return {k : (np.exp(v * weights[k]) / norm) for k,v in dist.items() }
def argmax(dist : dict) -> str: def argmax(dist : dict, weights : dict) -> str:
mx = np.max([v for v in dist.values()]) mx = np.max([v * weights[k] for k,v in dist.items()])
return [k for k,v in dist.items() if v == mx ][0] return [k for k,v in dist.items() if v * weights[k] == mx ][0]