Module monk.pytorch.models.models
Expand source code
from pytorch.models.imports import *
from system.imports import *
from pytorch.models.common import set_parameter_requires_grad
# Supported architecture names, grouped (per the original inline notes) by
# where the final classification layer lives on the torchvision model.

# Final layer noted as "classifier 6" (presumably `classifier[6]` - confirm).
set1 = [
    "alexnet",
    "vgg11", "vgg11_bn",
    "vgg13", "vgg13_bn",
    "vgg16", "vgg16_bn",
    "vgg19", "vgg19_bn",
]

# Final layer noted as "classifier".
set2 = ["densenet121", "densenet161", "densenet169", "densenet201"]

# Final layer noted as "fc".
# NOTE: "shufflenet_v2_x1_0" and "shufflenet_v2_x1_5" must be two separate
# entries; they were previously fused into one string by a missing quote pair,
# which hid both models from every lookup against these lists.
set3 = [
    "googlenet", "inception_v3",
    "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
    "resnext50_32x4d", "resnext101_32x8d",
    "shufflenet_v2_x0_5", "shufflenet_v2_x1_0",
    "shufflenet_v2_x1_5", "shufflenet_v2_x2_0",
    "wide_resnet101_2", "wide_resnet50_2",
]

# Final layer noted as "classifier 1" (presumably `classifier[1]` - confirm).
set4 = [
    "mnasnet0_5", "mnasnet0_75", "mnasnet1_0", "mnasnet1_3",
    "mobilenet_v2",
    "squeezenet1_0", "squeezenet1_1",
]

# Every supported name, plus a lower-cased parallel list sharing the same
# indices so a lower-case match can be mapped back to its canonical spelling.
combined_list = set1 + set2 + set3 + set4
combined_list_lower = [name.lower() for name in combined_list]
# Architectures this function can build. Each name is also the name of the
# constructor function looked up on `torchvision.models`.
_SUPPORTED_MODELS = (
    "alexnet",
    "vgg11", "vgg11_bn", "vgg13", "vgg13_bn",
    "vgg16", "vgg16_bn", "vgg19", "vgg19_bn",
    "densenet121", "densenet161", "densenet169", "densenet201",
    "googlenet", "inception_v3",
    "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
    "resnext50_32x4d", "resnext101_32x8d",
    "shufflenet_v2_x0_5", "shufflenet_v2_x1_0",
    "shufflenet_v2_x1_5", "shufflenet_v2_x2_0",
    "wide_resnet101_2", "wide_resnet50_2",
    "mnasnet0_5", "mnasnet0_75", "mnasnet1_0", "mnasnet1_3",
    "mobilenet_v2",
    "squeezenet1_0", "squeezenet1_1",
)

# Architectures that are always built with `pretrained=False`; a warning is
# issued when the caller asked for pretrained weights.
_NO_PRETRAINED_WEIGHTS = (
    "shufflenet_v2_x1_5", "shufflenet_v2_x2_0",
    "mnasnet0_75", "mnasnet1_3",
)


def get_base_model(model_name, use_pretrained, num_classes, freeze_base_network):
    '''
    Get base network for transfer learning based on parameters selected

    Args:
        model_name (str): Select from available models. Check via List_Models() function.
                          Matched case-insensitively.
        use_pretrained (bool): If set as True, use weights trained on imagenet and coco like dataset.
                               Else, use randomly initialized weights.
        num_classes (int): Accepted for interface compatibility; not used by this function.
        freeze_base_network (bool): If set as True, then base network's weights are frozen
                                    (cannot be trained). Forwarded to set_parameter_requires_grad.

    Returns:
        neural network: Base network
        str: Name of the model (lower-case canonical spelling)

    Raises:
        ValueError: If model_name is not one of the supported architectures.
    '''
    key = str(model_name).lower()

    # Fail loudly on unknown names. Previously this only printed a message and
    # then crashed further down on an unbound `finetune_net`.
    if key not in _SUPPORTED_MODELS:
        raise ValueError("Model name: {} not found".format(model_name))

    if key in _NO_PRETRAINED_WEIGHTS:
        if use_pretrained:
            msg = "Pretrained model Unavailable for {}.\n".format(key)
            msg += "Using xavier initialization"
            ConstraintWarning(msg)
        load_pretrained = False
    else:
        load_pretrained = use_pretrained

    # The constructor shares its name with the architecture, so "vgg13" and
    # "vgg13_bn" now build the matching network (they used to build
    # vgg11 / vgg11_bn by mistake).
    builder = getattr(torchvision.models, key)
    finetune_net = builder(pretrained=load_pretrained)

    finetune_net = set_parameter_requires_grad(finetune_net, freeze_base_network)
    return finetune_net, key
Functions
def get_base_model(model_name, use_pretrained, num_classes, freeze_base_network)
-
Get base network for transfer learning based on parameters selected
Args
model_name
:str
- Select from available models. Check via List_Models() function
freeze_base_network
:bool
- If set as True, then base network's weights are frozen (cannot be trained)
use_gpu
:bool
- If set as True, uses GPU
use_pretrained
:bool
- If set as True, use weights trained on imagenet and coco like dataset. Else, use randomly initialized weights.
Returns
neural network
- Base network
str
- Name of the model
Expand source code
# Architectures this function can build. Each name is also the name of the
# constructor function looked up on `torchvision.models`.
_SUPPORTED_MODELS = (
    "alexnet",
    "vgg11", "vgg11_bn", "vgg13", "vgg13_bn",
    "vgg16", "vgg16_bn", "vgg19", "vgg19_bn",
    "densenet121", "densenet161", "densenet169", "densenet201",
    "googlenet", "inception_v3",
    "resnet18", "resnet34", "resnet50", "resnet101", "resnet152",
    "resnext50_32x4d", "resnext101_32x8d",
    "shufflenet_v2_x0_5", "shufflenet_v2_x1_0",
    "shufflenet_v2_x1_5", "shufflenet_v2_x2_0",
    "wide_resnet101_2", "wide_resnet50_2",
    "mnasnet0_5", "mnasnet0_75", "mnasnet1_0", "mnasnet1_3",
    "mobilenet_v2",
    "squeezenet1_0", "squeezenet1_1",
)

# Architectures that are always built with `pretrained=False`; a warning is
# issued when the caller asked for pretrained weights.
_NO_PRETRAINED_WEIGHTS = (
    "shufflenet_v2_x1_5", "shufflenet_v2_x2_0",
    "mnasnet0_75", "mnasnet1_3",
)


def get_base_model(model_name, use_pretrained, num_classes, freeze_base_network):
    '''
    Get base network for transfer learning based on parameters selected

    Args:
        model_name (str): Select from available models. Check via List_Models() function.
                          Matched case-insensitively.
        use_pretrained (bool): If set as True, use weights trained on imagenet and coco like dataset.
                               Else, use randomly initialized weights.
        num_classes (int): Accepted for interface compatibility; not used by this function.
        freeze_base_network (bool): If set as True, then base network's weights are frozen
                                    (cannot be trained). Forwarded to set_parameter_requires_grad.

    Returns:
        neural network: Base network
        str: Name of the model (lower-case canonical spelling)

    Raises:
        ValueError: If model_name is not one of the supported architectures.
    '''
    key = str(model_name).lower()

    # Fail loudly on unknown names. Previously this only printed a message and
    # then crashed further down on an unbound `finetune_net`.
    if key not in _SUPPORTED_MODELS:
        raise ValueError("Model name: {} not found".format(model_name))

    if key in _NO_PRETRAINED_WEIGHTS:
        if use_pretrained:
            msg = "Pretrained model Unavailable for {}.\n".format(key)
            msg += "Using xavier initialization"
            ConstraintWarning(msg)
        load_pretrained = False
    else:
        load_pretrained = use_pretrained

    # The constructor shares its name with the architecture, so "vgg13" and
    # "vgg13_bn" now build the matching network (they used to build
    # vgg11 / vgg11_bn by mistake).
    builder = getattr(torchvision.models, key)
    finetune_net = builder(pretrained=load_pretrained)

    finetune_net = set_parameter_requires_grad(finetune_net, freeze_base_network)
    return finetune_net, key