Update Activations.py - weights
This commit is contained in:
parent
b8cfea5278
commit
560dcfacf6
@ -6,14 +6,14 @@ import numpy as np
|
|||||||
import math
|
import math
|
||||||
from pyFTS import *
|
from pyFTS import *
|
||||||
|
|
||||||
def scale(dist : dict) -> dict:
|
def scale(dist : dict, weights : dict) -> dict:
|
||||||
norm = np.sum([v for v in dist.values()])
|
norm = np.sum([v for v in dist.values()])
|
||||||
return {k : (v / norm) for k,v in dist.items() }
|
return {k : ((v * weights[k]) / norm) for k,v in dist.items() }
|
||||||
|
|
||||||
def softmax(dist : dict) -> dict:
|
def softmax(dist : dict, weights : dict) -> dict:
|
||||||
norm = np.sum([np.exp(v) for v in dist.values()])
|
norm = np.sum([np.exp(v) for v in dist.values()])
|
||||||
return {k : (np.exp(v) / norm) for k,v in dist.items() }
|
return {k : (np.exp(v * weights[k]) / norm) for k,v in dist.items() }
|
||||||
|
|
||||||
def argmax(dist : dict) -> str:
|
def argmax(dist : dict, weights : dict) -> str:
|
||||||
mx = np.max([v for v in dist.values()])
|
mx = np.max([v * weights[k] for k,v in dist.items()])
|
||||||
return [k for k,v in dist.items() if v == mx ][0]
|
return [k for k,v in dist.items() if v * weights[k] == mx ][0]
|
||||||
|
Loading…
Reference in New Issue
Block a user