Module monk.tf_keras_1.models.models
Expand source code
from tf_keras_1.models.imports import *
from system.imports import *
from tf_keras_1.models.common import set_parameter_requires_grad
# Names of the classifier backbones this backend can build (6th classifier family).
# Every entry is already lower-case; "inception_resnet_v3" is the historical
# name under which Keras' InceptionResNetV2 is exposed.
set1 = [
    "mobilenet",
    "densenet121",
    "densenet169",
    "densenet201",
    "inception_v3",
    "inception_resnet_v3",
    "mobilenet_v2",
    "nasnet_mobile",
    "nasnet_large",
    "resnet50",
    "resnet101",
    "resnet152",
    "resnet50_v2",
    "resnet101_v2",
    "resnet152_v2",
    "vgg16",
    "vgg19",
    "xception",
]
# Alias (same list object) holding every supported name.
combined_list = set1
# Lower-cased copy used for name lookups.
combined_list_lower = [name.lower() for name in combined_list]
# Maps each supported (lower-case) model name to the class exported by
# ``keras.applications``. "inception_resnet_v3" is a legacy name that has
# always resolved to Keras' InceptionResNetV2.
_KERAS_APPLICATION_CLASSES = {
    "mobilenet": "MobileNet",
    "densenet121": "DenseNet121",
    "densenet169": "DenseNet169",
    "densenet201": "DenseNet201",
    "inception_v3": "InceptionV3",
    "inception_resnet_v3": "InceptionResNetV2",
    "mobilenet_v2": "MobileNetV2",
    "nasnet_mobile": "NASNetMobile",
    "nasnet_large": "NASNetLarge",
    "resnet50": "ResNet50",
    "resnet101": "ResNet101",
    "resnet152": "ResNet152",
    "resnet50_v2": "ResNet50V2",
    "resnet101_v2": "ResNet101V2",
    "resnet152_v2": "ResNet152V2",
    "vgg16": "VGG16",
    "vgg19": "VGG19",
    "xception": "Xception",
}


def get_base_model(model_name, use_pretrained, num_classes, freeze_base_network, input_size):
    '''
    Get base network for transfer learning based on parameters selected

    Args:
        model_name (str): Name of the model (case-insensitive). Check via List_Models() function
        use_pretrained (bool): If True, load ImageNet-pretrained weights (weights="imagenet");
                               else use randomly initialized weights.
        num_classes (int): Not used by this function; kept for interface compatibility.
        freeze_base_network (bool): If True, the base network's weights are frozen (cannot be trained)
        input_size (int): Height and width of the square input image; the network is built
                          with input_shape (input_size, input_size, 3).

    Returns:
        neural network: Base network, built without the classification top
        str: Canonical (lower-case) name of the model

    Raises:
        ValueError: If model_name is not one of the supported models.
    '''
    # Lower-case first so that e.g. "ResNet50" is accepted; unknown names fail
    # immediately instead of leaving the Keras class unbound further down.
    key = str(model_name).lower()
    if key not in _KERAS_APPLICATION_CLASSES:
        raise ValueError("Model name: {} not found".format(model_name))

    weights = "imagenet" if use_pretrained else None

    # Imported inside the function so keras is only loaded when a model is built.
    from keras import applications as keras_applications
    keras_model = getattr(keras_applications, _KERAS_APPLICATION_CLASSES[key])

    # include_top=False omits the network's original classifier head.
    finetune_net = keras_model(weights=weights, include_top=False,
                               input_shape=(input_size, input_size, 3))
    finetune_net = set_parameter_requires_grad(finetune_net, freeze_base_network)
    return finetune_net, key
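Functions

def get_base_model(model_name, use_pretrained, num_classes, freeze_base_network, input_size)

Get the base network for transfer learning based on the selected parameters.

Args
    model_name : str
        Name of the model to use. Check the available models via the List_Models() function.
    use_pretrained : bool
        If True, use weights pretrained on ImageNet (weights="imagenet"); otherwise, use randomly initialized weights.
    num_classes : int
        Accepted for interface compatibility; not used by this function.
    freeze_base_network : bool
        If True, the base network's weights are frozen (cannot be trained).
    input_size : int
        Height and width of the square input image; the network is built with input_shape (input_size, input_size, 3).
    use_gpu : bool
        Listed in the docstring ("If set as True, uses GPU"), but this is not a parameter of get_base_model.

Returns
    neural network
        The base network (built with include_top=False).
    str
        Name of the model.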
Expand source code
# Maps each supported (lower-case) model name to the class exported by
# ``keras.applications``. "inception_resnet_v3" is a legacy name that has
# always resolved to Keras' InceptionResNetV2.
_KERAS_APPLICATION_CLASSES = {
    "mobilenet": "MobileNet",
    "densenet121": "DenseNet121",
    "densenet169": "DenseNet169",
    "densenet201": "DenseNet201",
    "inception_v3": "InceptionV3",
    "inception_resnet_v3": "InceptionResNetV2",
    "mobilenet_v2": "MobileNetV2",
    "nasnet_mobile": "NASNetMobile",
    "nasnet_large": "NASNetLarge",
    "resnet50": "ResNet50",
    "resnet101": "ResNet101",
    "resnet152": "ResNet152",
    "resnet50_v2": "ResNet50V2",
    "resnet101_v2": "ResNet101V2",
    "resnet152_v2": "ResNet152V2",
    "vgg16": "VGG16",
    "vgg19": "VGG19",
    "xception": "Xception",
}


def get_base_model(model_name, use_pretrained, num_classes, freeze_base_network, input_size):
    '''
    Get base network for transfer learning based on parameters selected

    Args:
        model_name (str): Name of the model (case-insensitive). Check via List_Models() function
        use_pretrained (bool): If True, load ImageNet-pretrained weights (weights="imagenet");
                               else use randomly initialized weights.
        num_classes (int): Not used by this function; kept for interface compatibility.
        freeze_base_network (bool): If True, the base network's weights are frozen (cannot be trained)
        input_size (int): Height and width of the square input image; the network is built
                          with input_shape (input_size, input_size, 3).

    Returns:
        neural network: Base network, built without the classification top
        str: Canonical (lower-case) name of the model

    Raises:
        ValueError: If model_name is not one of the supported models.
    '''
    # Lower-case first so that e.g. "ResNet50" is accepted; unknown names fail
    # immediately instead of leaving the Keras class unbound further down.
    key = str(model_name).lower()
    if key not in _KERAS_APPLICATION_CLASSES:
        raise ValueError("Model name: {} not found".format(model_name))

    weights = "imagenet" if use_pretrained else None

    # Imported inside the function so keras is only loaded when a model is built.
    from keras import applications as keras_applications
    keras_model = getattr(keras_applications, _KERAS_APPLICATION_CLASSES[key])

    # include_top=False omits the network's original classifier head.
    finetune_net = keras_model(weights=weights, include_top=False,
                               input_shape=(input_size, input_size, 3))
    finetune_net = set_parameter_requires_grad(finetune_net, freeze_base_network)
    return finetune_net, key