当前位置 : 主页 > 网络编程 > 其它编程 >

Pytorch图像处理篇:使用pytorch搭建ResNet并基于迁移学习训练

来源:互联网 收集:自由互联 发布时间:2023-07-02
摘要:model.py —— `import torch.nn as nn`、`import torch`;首先定义34层网络所用的残差结构 `class BasicBlock(nn.Module)`,其中 `expansion = 1` 对应主分支中卷积核的个数有没有发生变化……

model.py

import torch.nn as nn
import torch


class BasicBlock(nn.Module):
    """Residual block used by the 18/34-layer ResNets (two 3x3 convolutions).

    ``expansion`` is the ratio between the block's output channels and
    ``out_channel``; it is 1 because the main branch keeps its width.
    """

    expansion = 1

    def __init__(self, in_channel, out_channel, stride=1, downsample=None, **kwargs):
        """
        Args:
            in_channel: depth of the input feature map.
            out_channel: number of kernels in each main-branch convolution.
            stride: stride of the first 3x3 convolution (2 halves H and W).
            downsample: optional module applied to the input to produce the
                shortcut branch; needed whenever the main branch changes the
                spatial size or channel count.
            **kwargs: lets ``ResNet._make_layer`` pass ``groups`` /
                ``width_per_group`` uniformly to either block type.

        Raises:
            ValueError: if grouped / widened convolutions are requested, which
                this block does not implement (use ``Bottleneck`` instead).
        """
        super(BasicBlock, self).__init__()
        # Fail loudly instead of silently building a plain (non-grouped) block.
        if kwargs.get("groups", 1) != 1 or kwargs.get("width_per_group", 64) != 64:
            raise ValueError("BasicBlock only supports groups=1 and width_per_group=64")
        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=out_channel,
                               kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channel)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(in_channels=out_channel, out_channels=out_channel,
                               kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channel)
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        identity = x
        if self.downsample is not None:
            # Shortcut branch must match the main branch's output shape.
            identity = self.downsample(x)

        # Main branch.
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        # Add the shortcut to the main branch, then apply the final activation.
        out += identity
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    """Bottleneck residual block used by the 50/101/152-layer ResNets and ResNeXt.

    Note: in the original paper, on the main branch of the downsampling
    (dashed-line) residual block, the first 1x1 convolution has stride 2 and
    the 3x3 convolution has stride 1. The official PyTorch implementation
    instead uses stride 1 for the 1x1 convolution and stride 2 for the 3x3
    convolution, which improves top-1 accuracy by roughly 0.5%.
    See ResNet v1.5: https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch
    """

    # Output channels are ``out_channel * 4`` (the final 1x1 conv expands them).
    expansion = 4

    def __init__(self, in_channel, out_channel, stride=1, downsample=None,
                 groups=1, width_per_group=64):
        """
        Args:
            in_channel: depth of the input feature map.
            out_channel: base channel count; the block outputs
                ``out_channel * expansion`` channels.
            stride: stride of the 3x3 convolution.
            downsample: optional module producing the shortcut branch.
            groups: number of groups in the 3x3 convolution (ResNeXt).
            width_per_group: channels per group; 64 with groups=1 is plain ResNet.
        """
        super(Bottleneck, self).__init__()

        width = int(out_channel * (width_per_group / 64.)) * groups

        self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,
                               kernel_size=1, stride=1, bias=False)  # squeeze channels
        self.bn1 = nn.BatchNorm2d(width)
        # -----------------------------------------
        # Grouped 3x3 convolution; input and output depth are both ``width``.
        self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,
                               kernel_size=3, stride=stride, bias=False, padding=1)
        self.bn2 = nn.BatchNorm2d(width)
        # -----------------------------------------
        self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel * self.expansion,
                               kernel_size=1, stride=1, bias=False)  # unsqueeze channels
        self.bn3 = nn.BatchNorm2d(out_channel * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    def forward(self, x):
        """Return ``relu(main_branch(x) + shortcut(x))``."""
        identity = x
        if self.downsample is not None:
            identity = self.downsample(x)

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        out += identity
        out = self.relu(out)

        return out


class ResNet(nn.Module):
    """ResNet / ResNeXt backbone with an optional classification head."""

    def __init__(self,
                 block,
                 blocks_num,
                 num_classes=1000,
                 include_top=True,
                 groups=1,
                 width_per_group=64):
        """
        Args:
            block: residual block class (``BasicBlock`` or ``Bottleneck``).
            blocks_num: number of blocks in each of the four stages.
            num_classes: size of the final classifier output.
            include_top: if False, skip pooling + fc and return the last
                feature map (useful as a backbone for larger networks).
            groups: grouped-convolution count forwarded to every block.
            width_per_group: channels per group forwarded to every block.
        """
        super(ResNet, self).__init__()
        self.include_top = include_top
        self.in_channel = 64  # depth of the feature map entering layer1

        self.groups = groups
        self.width_per_group = width_per_group

        # Stem: 3 is the depth of the RGB input image.
        self.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,
                               padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channel)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Four stages of residual blocks.
        self.layer1 = self._make_layer(block, 64, blocks_num[0])
        self.layer2 = self._make_layer(block, 128, blocks_num[1], stride=2)
        self.layer3 = self._make_layer(block, 256, blocks_num[2], stride=2)
        self.layer4 = self._make_layer(block, 512, blocks_num[3], stride=2)
        if self.include_top:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)
            self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')

    def _make_layer(self, block, channel, block_num, stride=1):
        """Build one stage of ``block_num`` blocks with base width ``channel``.

        Only the first block may change stride/depth; the rest keep shape.
        """
        downsample = None
        # A projection shortcut is needed whenever the first block changes the
        # spatial size or the channel count (e.g. not needed for BasicBlock in layer1).
        if stride != 1 or self.in_channel != channel * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(channel * block.expansion))

        layers = [block(self.in_channel,
                        channel,
                        downsample=downsample,
                        stride=stride,
                        groups=self.groups,
                        width_per_group=self.width_per_group)]
        self.in_channel = channel * block.expansion

        # Remaining blocks: identity (solid-line) shortcuts.
        for _ in range(1, block_num):
            layers.append(block(self.in_channel,
                                channel,
                                groups=self.groups,
                                width_per_group=self.width_per_group))

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return class logits, or the layer4 feature map if ``include_top`` is False."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        if self.include_top:
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.fc(x)

        return x


# Factory functions: the block type and per-stage block counts select the variant.
def resnet34(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet34-333f7ec4.pth
    return ResNet(BasicBlock, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet50(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet50-19c8e357.pth
    return ResNet(Bottleneck, [3, 4, 6, 3], num_classes=num_classes, include_top=include_top)


def resnet101(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnet101-5d3b4d8f.pth
    return ResNet(Bottleneck, [3, 4, 23, 3], num_classes=num_classes, include_top=include_top)


# ResNeXt variants: same Bottleneck stages with grouped 3x3 convolutions.
def resnext50_32x4d(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth
    groups = 32
    width_per_group = 4
    return ResNet(Bottleneck, [3, 4, 6, 3],
                  num_classes=num_classes,
                  include_top=include_top,
                  groups=groups,
                  width_per_group=width_per_group)


def resnext101_32x8d(num_classes=1000, include_top=True):
    # https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth
    groups = 32
    width_per_group = 8
    return ResNet(Bottleneck, [3, 4, 23, 3],
                  num_classes=num_classes,
                  include_top=include_top,
                  groups=groups,
                  width_per_group=width_per_group)

train.py

import os
import sys
import json

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from tqdm import tqdm

# Import whichever network you want to train (resnet34, resnet50, ...).
from model import resnet34


def main():
    """Fine-tune an ImageNet-pretrained ResNet-34 on the flower dataset.

    Expects ``<cwd>/../../data_set/flower_data/{train,val}`` in ImageFolder
    layout and pretrained weights at ``./resnet34-pre.pth``. Writes the
    index-to-class mapping to ``class_indices.json`` and saves the weights
    with the best validation accuracy to ``./resNet34.pth``.

    Raises:
        FileNotFoundError: if the dataset directory or pretrained weights are missing.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     # Normalization parameters follow the official ImageNet values.
                                     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]),
        "val": transforms.Compose([transforms.Resize(256),      # scale the shorter side to 256
                                   transforms.CenterCrop(224),  # then take a 224x224 center crop
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])}

    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "data_set", "flower_data")  # flower data set path
    if not os.path.exists(image_path):
        raise FileNotFoundError("{} path does not exist.".format(image_path))
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # e.g. {'daisy':0, 'dandelion':1, 'roses':2, 'sunflower':3, 'tulips':4}
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    # Size the new classifier head from the data instead of hard-coding it.
    num_classes = len(cla_dict)
    # write dict into json file
    json_str = json.dumps(cla_dict, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    batch_size = 16
    # os.cpu_count() may return None when the count is undeterminable.
    nw = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=batch_size, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    # Transfer learning: start from the official PyTorch pretrained ResNet weights.
    net = resnet34()  # instantiate the same network the weights below belong to
    # load pretrain weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "./resnet34-pre.pth"  # change to match the chosen network
    if not os.path.exists(model_weight_path):
        raise FileNotFoundError("file {} does not exist.".format(model_weight_path))
    net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))
    # To train only the new classifier head, freeze the backbone here:
    #     for param in net.parameters():
    #         param.requires_grad = False

    # change fc layer structure: replace the ImageNet classifier with one for this dataset
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, num_classes)
    net.to(device)

    # define loss function
    loss_function = nn.CrossEntropyLoss()

    # construct an optimizer over the parameters that are still trainable
    params = [p for p in net.parameters() if p.requires_grad]
    optimizer = optim.Adam(params, lr=0.0001)

    epochs = 3
    best_acc = 0.0
    save_path = './resNet34.pth'  # rename to match the chosen network
    train_steps = len(train_loader)
    for epoch in range(epochs):
        # train: BatchNorm uses per-batch statistics in training mode
        net.train()
        running_loss = 0.0
        train_bar = tqdm(train_loader, file=sys.stdout)
        for images, labels in train_bar:
            optimizer.zero_grad()
            logits = net(images.to(device))
            loss = loss_function(logits, labels.to(device))
            loss.backward()
            optimizer.step()

            # print statistics
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss.item())

        # validate
        net.eval()
        acc = 0.0  # accumulate accurate number / epoch
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_images, val_labels in val_bar:
                outputs = net(val_images.to(device))
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

                val_bar.desc = "valid epoch[{}/{}]".format(epoch + 1,
                                                           epochs)

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f  val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')


if __name__ == '__main__':
    main()

predict.py

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import resnet34


def main():
    """Classify one image with the fine-tuned ResNet-34 and display it.

    Reads ``../tulip.jpg``, the class mapping ``./class_indices.json`` and the
    weights ``./resNet34.pth``; prints every class probability and shows the
    image titled with the top prediction.

    Raises:
        FileNotFoundError: if the image, class mapping or weights are missing.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # load image
    img_path = "../tulip.jpg"
    if not os.path.exists(img_path):
        raise FileNotFoundError("file: '{}' does not exist.".format(img_path))
    # Force 3 channels: grayscale/RGBA/palette images would otherwise break
    # the 3-channel Normalize above.
    img = Image.open(img_path).convert("RGB")
    plt.imshow(img)
    img = data_transform(img)  # [C, H, W]
    # expand batch dimension -> [N, C, H, W]
    img = torch.unsqueeze(img, dim=0)

    # read class_indict
    json_path = './class_indices.json'
    if not os.path.exists(json_path):
        raise FileNotFoundError("file: '{}' does not exist.".format(json_path))
    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # create model: use the same network (and class count) as in training
    model = resnet34(num_classes=len(class_indict)).to(device)

    # load model weights
    weights_path = "./resNet34.pth"  # change to match the chosen network
    if not os.path.exists(weights_path):
        raise FileNotFoundError("file: '{}' does not exist.".format(weights_path))
    model.load_state_dict(torch.load(weights_path, map_location=device))

    # prediction
    model.eval()
    with torch.no_grad():  # inference only: do not track gradients
        # predict class
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()  # index of the highest probability

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].numpy())
    plt.title(print_res)
    for i, prob in enumerate(predict):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)],
                                                  prob.numpy()))
    plt.show()


if __name__ == '__main__':
    main()

batch_predict.py

# Batch prediction: classify every .jpg image in a folder, a batch at a time.
import os
import json

import torch
from PIL import Image
from torchvision import transforms

from model import resnet34


def main():
    """Classify all ``.jpg`` files in ``/data/imgs`` with the fine-tuned ResNet-34.

    Prints the predicted class and its probability for every image, including
    those in a final batch smaller than ``batch_size``.

    Raises:
        FileNotFoundError: if the image folder, class mapping or weights are missing.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose(
        [transforms.Resize(256),
         transforms.CenterCrop(224),
         transforms.ToTensor(),
         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

    # load image
    # Folder whose images should be predicted.
    imgs_root = "/data/imgs"
    if not os.path.exists(imgs_root):
        raise FileNotFoundError(f"file: '{imgs_root}' does not exist.")
    # Collect the paths of all .jpg images in the folder (sorted for a stable order).
    img_path_list = [os.path.join(imgs_root, i) for i in sorted(os.listdir(imgs_root)) if i.endswith(".jpg")]

    # read class_indict
    json_path = './class_indices.json'
    if not os.path.exists(json_path):
        raise FileNotFoundError(f"file: '{json_path}' does not exist.")
    with open(json_path, "r") as json_file:
        class_indict = json.load(json_file)

    # create model with the same class count used in training
    model = resnet34(num_classes=len(class_indict)).to(device)

    # load model weights
    weights_path = "./resNet34.pth"
    if not os.path.exists(weights_path):
        raise FileNotFoundError(f"file: '{weights_path}' does not exist.")
    model.load_state_dict(torch.load(weights_path, map_location=device))

    # prediction
    model.eval()
    batch_size = 8  # number of images packed into one batch per forward pass
    with torch.no_grad():
        # Step by batch_size so the trailing partial batch is not dropped.
        for start in range(0, len(img_path_list), batch_size):
            batch_paths = img_path_list[start:start + batch_size]
            img_list = []
            for img_path in batch_paths:
                if not os.path.exists(img_path):
                    raise FileNotFoundError(f"file: '{img_path}' does not exist.")
                # Force 3 channels so the 3-channel Normalize always applies.
                with Image.open(img_path) as raw:
                    img = raw.convert("RGB")
                img_list.append(data_transform(img))

            # Pack all images of this chunk into one [N, C, H, W] batch.
            batch_img = torch.stack(img_list, dim=0)
            # predict class
            output = model(batch_img.to(device)).cpu()
            predict = torch.softmax(output, dim=1)
            probs, classes = torch.max(predict, dim=1)

            for img_path, pro, cla in zip(batch_paths, probs, classes):
                print("image: {}  class: {}  prob: {:.3}".format(img_path,
                                                                 class_indict[str(cla.numpy())],
                                                                 pro.numpy()))


if __name__ == '__main__':
    main()

网友评论