图像迁移学习：3.PyTorch实现迁移学习（3.1数据集预处理、3.2构建模型、3.3模型训练与验证）
图像迁移学习
- 3.PyTorch实现迁移学习
- 3.1数据集预处理
- 3.2构建模型
- 3.3模型训练与验证
3.PyTorch实现迁移学习
文件目录
3.1数据集预处理
这里实现一个蚂蚁与蜜蜂的图像分类，所用数据集为 hymenoptera_data（data下载）。
dataset.py
import torch
from torchvision import datasets, transforms

# Per-channel mean/std applied to both splits; these are the normalization
# statistics torchvision's pretrained models expect.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Training-time preprocessing with random augmentation.
train = transforms.Compose([
    transforms.RandomResizedCrop(224),  # crop a random area/aspect ratio, then resize to 224x224
    transforms.RandomHorizontalFlip(),  # random horizontal flip
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])

# Deterministic evaluation preprocessing: resize to 256, keep the central 224x224 patch.
val = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])

trainset = datasets.ImageFolder(root='hymenoptera_data/train', transform=train)
valset = datasets.ImageFolder(root='hymenoptera_data/val', transform=val)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=4)
# Evaluation metrics do not depend on sample order, so keep validation deterministic.
valloader = torch.utils.data.DataLoader(valset, batch_size=4,
                                        shuffle=False, num_workers=4)
3.2构建模型
model.py
import torch.nn as nn
from torchvision import models
#初始化模型
#保证模型不改变的层的参数,不发生梯度变化
def set_parameter_requires_grad(model, feature_extracting):
    """Freeze all parameters of ``model`` when ``feature_extracting`` is truthy.

    Frozen parameters get ``requires_grad = False`` and therefore receive no
    gradients during backward. When ``feature_extracting`` is falsy the model
    is left untouched.
    """
    if not feature_extracting:
        return
    for weight in model.parameters():
        weight.requires_grad = False
def initialize_model(model_name, num_classes, feature_extract, use_pretrained=True):
    """Build a torchvision backbone and replace its classification head.

    Args:
        model_name: one of 'resnet', 'alexnet', 'vgg', 'squeezenet',
            'densenet', 'inception'.
        num_classes: number of outputs of the newly created head.
        feature_extract: if True, freeze every backbone parameter so only the
            new head is trained; if False, the whole network stays trainable.
        use_pretrained: passed to the torchvision constructor as
            ``pretrained=``; defaults to True (the previous behaviour).

    Returns:
        ``(model, input_size)``: the modified model and the square input
        resolution it should be fed (224, or 299 for inception).

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """
    # In every branch the backbone is frozen *before* the head is replaced, so
    # the freshly created head layer keeps requires_grad=True.
    if model_name == 'resnet':
        # resnet18
        model_ft = models.resnet18(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    elif model_name == 'alexnet':
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    elif model_name == 'vgg':
        # vgg11 with batch normalization
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    elif model_name == 'squeezenet':
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # SqueezeNet's head is a 1x1 convolution, not a Linear layer; reuse the
        # existing conv's input channel count instead of hard-coding it.
        in_channels = model_ft.classifier[1].in_channels
        model_ft.classifier[1] = nn.Conv2d(in_channels, num_classes,
                                           kernel_size=(1, 1), stride=(1, 1))
        model_ft.num_classes = num_classes
        input_size = 224
    elif model_name == 'densenet':
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)
        input_size = 224
    elif model_name == 'inception':
        model_ft = models.inception_v3(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        # Resize the auxiliary classifier's head as well as the primary one.
        num_ftrs = model_ft.AuxLogits.fc.in_features
        model_ft.AuxLogits.fc = nn.Linear(num_ftrs, num_classes)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)
        input_size = 299
    else:
        # Fail loudly instead of returning (None, 0), which only crashes later
        # with an unrelated AttributeError.
        raise ValueError("没有合适的模型...: {!r}".format(model_name))
    return model_ft, input_size
3.3模型训练与验证
run.py
from __future__ import print_function
from __future__ import division

import argparse
import copy
import time

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision import datasets as tv_datasets
from torchvision import transforms as tv_transforms

from model import initialize_model
from dataset import *

parser = argparse.ArgumentParser()
# Backbone to fine-tune.
parser.add_argument('-m', '--model_name', type=str,
                    choices=['resnet', 'alexnet', 'vgg', 'squeezenet', 'densenet', 'inception'],
                    help="input model_name", default='resnet')
# Number of target classes.
parser.add_argument('-n', '--num_classes', type=int, help="input num_classes", default=2)
# Samples per mini-batch; applied to both dataloaders below.
parser.add_argument('-b', '--batch_size', type=int, help="input batch_size", default=8)
# Number of training epochs.
parser.add_argument('-e', '--num_epochs', type=int, help="input num_epochs", default=25)
args = parser.parse_args()

# Feature-extraction flag: if False the whole model is fine-tuned,
# if True only the parameters of the reshaped head are updated.
feature_extract = True

model_ft, input_size = initialize_model(args.model_name, args.num_classes, feature_extract)

# dataset.py builds its transforms for 224x224 crops. Models that need another
# resolution (initialize_model reports 299 for inception) get both splits rebuilt
# at their own input size.
if input_size != 224:
    _normalize = tv_transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    trainset = tv_datasets.ImageFolder(
        root='hymenoptera_data/train',
        transform=tv_transforms.Compose([
            tv_transforms.RandomResizedCrop(input_size),
            tv_transforms.RandomHorizontalFlip(),
            tv_transforms.ToTensor(),
            _normalize,
        ]))
    valset = tv_datasets.ImageFolder(
        root='hymenoptera_data/val',
        transform=tv_transforms.Compose([
            tv_transforms.Resize(input_size),
            tv_transforms.CenterCrop(input_size),
            tv_transforms.ToTensor(),
            _normalize,
        ]))

# Phase name -> dataset / dataloader. The keys must be the strings 'train' and
# 'val' because train_model looks them up that way. The name `datasets` is kept
# for train_model, and it shadows torchvision's module (hence tv_datasets above).
datasets = {'train': trainset, 'val': valset}
# Rebuilt here so that --batch_size actually takes effect.
dataloaders = {
    'train': torch.utils.data.DataLoader(trainset, batch_size=args.batch_size,
                                         shuffle=True, num_workers=4),
    'val': torch.utils.data.DataLoader(valset, batch_size=args.batch_size,
                                       shuffle=False, num_workers=4),
}

criterion = nn.CrossEntropyLoss()
# Only optimize parameters that still require gradients: with
# feature_extract=True that is just the newly created head.
params_to_update = [p for p in model_ft.parameters() if p.requires_grad]
optimizer_ft = optim.SGD(params_to_update, lr=0.001, momentum=0.9)
# Multiply the learning rate by gamma=0.1 every 7 epochs.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=7, gamma=0.1)
def train_model(model, criterion, optimizer, scheduler, num_epochs, loaders=None):
    """Train and validate ``model``, returning it with its best validation weights.

    Each epoch runs a 'train' phase (forward, backward, optimizer step, then one
    scheduler step) followed by a 'val' phase (forward only). Per-phase loss and
    accuracy are printed. The state dict with the highest validation accuracy
    is loaded back into ``model`` before returning.

    Args:
        model: network to train (updated in place).
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer over the trainable parameters of ``model``.
        scheduler: learning-rate scheduler, stepped once per epoch.
        num_epochs: number of epochs to run.
        loaders: optional dict mapping 'train' and 'val' to DataLoaders;
            defaults to the module-level ``dataloaders``.

    Returns:
        ``model`` with the best-validation-accuracy weights loaded.
    """
    if loaders is None:
        loaders = dataloaders
    since = time.time()
    # Send each batch to the device the model's parameters live on.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    # Start from the initial weights in case validation accuracy never improves.
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
            else:
                model.eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in loaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                optimizer.zero_grad()

                # Build the autograd graph only while training.
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    if isinstance(outputs, tuple):
                        # inception_v3 in train mode returns (logits, aux_logits);
                        # add the auxiliary loss with weight 0.4.
                        outputs, aux_outputs = outputs
                        loss = criterion(outputs, labels) + 0.4 * criterion(aux_outputs, labels)
                    else:
                        loss = criterion(outputs, labels)
                    _, preds = torch.max(outputs, 1)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # loss is a batch mean; weight it by batch size for an epoch mean.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels).item()

            if phase == 'train':
                # Step the scheduler after this epoch's optimizer steps (the order
                # required since PyTorch 1.1; stepping first shifts the schedule).
                scheduler.step()

            num_samples = len(loaders[phase].dataset)
            epoch_loss = running_loss / num_samples
            epoch_acc = float(running_corrects) / num_samples
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))

            # Keep a deep copy of the best-performing weights seen so far.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.4f}'.format(best_acc))
    model.load_state_dict(best_model_wts)
    return model
# Guard the entry point. DataLoader workers (num_workers > 0) started with the
# 'spawn' method re-import the main module; without this guard they would
# re-run the training call. PyTorch documents this requirement for Windows.
if __name__ == '__main__':
    train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, args.num_epochs)