Update Activations.py - weights
This commit is contained in:
parent
b8cfea5278
commit
560dcfacf6
@ -6,14 +6,14 @@ import numpy as np
|
||||
import math
|
||||
from pyFTS import *
|
||||
|
||||
def scale(dist : dict) -> dict:
|
||||
def scale(dist : dict, weights : dict) -> dict:
|
||||
norm = np.sum([v for v in dist.values()])
|
||||
return {k : (v / norm) for k,v in dist.items() }
|
||||
return {k : ((v * weights[k]) / norm) for k,v in dist.items() }
|
||||
|
||||
def softmax(dist : dict) -> dict:
|
||||
def softmax(dist : dict, weights : dict) -> dict:
|
||||
norm = np.sum([np.exp(v) for v in dist.values()])
|
||||
return {k : (np.exp(v) / norm) for k,v in dist.items() }
|
||||
return {k : (np.exp(v * weights[k]) / norm) for k,v in dist.items() }
|
||||
|
||||
def argmax(dist : dict) -> str:
|
||||
mx = np.max([v for v in dist.values()])
|
||||
return [k for k,v in dist.items() if v == mx ][0]
|
||||
def argmax(dist : dict, weights : dict) -> str:
|
||||
mx = np.max([v * weights[k] for k,v in dist.items()])
|
||||
return [k for k,v in dist.items() if v * weights[k] == mx ][0]
|
||||
|
Loading…
Reference in New Issue
Block a user