Module monk.tf_keras_1.models.initializers
Expand source code
from tf_keras_1.models.imports import *
from system.imports import *
def get_initializer(initializer):
    '''
    Translate a requested weight-initializer type into the initializer name
    string used for custom network weight initialization.

    Args:
        initializer (str): Requested initializer type, e.g. "xavier_normal".

    Returns:
        str: The mapped initializer name (for example "xavier_normal" maps to
             "glorot_normal"). None is returned when the requested type is not
             one of the names in the table below.
    '''
    # Requested type -> returned name. Most names pass through unchanged;
    # only the xavier_* aliases and variance_scaling are renamed.
    name_table = {
        "xavier_normal": "glorot_normal",
        "xavier_uniform": "glorot_uniform",
        "random_uniform": "random_uniform",
        "random_normal": "random_normal",
        "lecun_uniform": "lecun_uniform",
        "lecun_normal": "lecun_normal",
        "he_normal": "he_normal",
        "he_uniform": "he_uniform",
        "truncated_normal": "truncated_normal",
        "orthogonal": "orthogonal",
        "variance_scaling": "VarianceScaling",
    }
    try:
        return name_table.get(initializer)
    except TypeError:
        # An unhashable argument cannot match any string key, so it is
        # treated the same as any other unrecognized value.
        return None
Functions
def get_initializer(initializer)
-
Get the right initializer for custom network weight initialization
Args
initializer
:str
- The type of initializer
Returns
str
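- The initializer name mapped from the requested type (e.g. "xavier_normal" maps to "glorot_normal"); nothing is returned for an unrecognized type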
Expand source code
def get_initializer(initializer):
    '''
    Translate a requested weight-initializer type into the initializer name
    string used for custom network weight initialization.

    Args:
        initializer (str): Requested initializer type, e.g. "xavier_normal".

    Returns:
        str: The mapped initializer name (for example "xavier_normal" maps to
             "glorot_normal"). None is returned when the requested type is not
             one of the names in the table below.
    '''
    # Requested type -> returned name. Most names pass through unchanged;
    # only the xavier_* aliases and variance_scaling are renamed.
    name_table = {
        "xavier_normal": "glorot_normal",
        "xavier_uniform": "glorot_uniform",
        "random_uniform": "random_uniform",
        "random_normal": "random_normal",
        "lecun_uniform": "lecun_uniform",
        "lecun_normal": "lecun_normal",
        "he_normal": "he_normal",
        "he_uniform": "he_uniform",
        "truncated_normal": "truncated_normal",
        "orthogonal": "orthogonal",
        "variance_scaling": "VarianceScaling",
    }
    try:
        return name_table.get(initializer)
    except TypeError:
        # An unhashable argument cannot match any string key, so it is
        # treated the same as any other unrecognized value.
        return None