如何在只有6万张图像的MNIST训练数据集上训练模型。学术界当下使用最广泛的大规模图像数据集ImageNet,它有超过1,000万的图像和1,000类的物体。然而,我们平常接触到数据集的规模通常
如何在只有6万张图像的MNIST训练数据集上训练模型。学术界当下使用最广泛的大规模图像数据集ImageNet,它有超过1,000万的图像和1,000类的物体。然而,我们平常接触到数据集的规模通常在这两者之间。假设我们想从图像中识别出不同种类的椅子,然后将购买链接推荐给用户。一种可能的方法是先找出100种常见的椅子,为每种椅子拍摄1,000张不同角度的图像,然后在收集到的图像数据集上训练一个分类模型。另外一种解决办法是应用迁移学习(transfer learning),将从源数据集学到的知识迁移到目标数据集上。例如,虽然ImageNet数据集的图像大多跟椅子无关,但在该数据集上训练的模型可以抽取较通用的图像特征,从而能够帮助识别边缘、纹理、形状和物体组成等。这些类似的特征对于识别椅子也可能同样有效
图像迁移学习
- 1.迁移学习原理与流程
- 1.1微调
- 1.2特征提取
- 2.图像增强
- 2.1比例缩放
- 2.2位置裁剪
- 2.3水平/垂直翻转
- 2.4角度旋转
- 2.5色度、亮度、饱和度、对比度
- 2.6灰度化
- 2.7Padding
- 2.8模型中图像增强数据预处理
1.迁移学习原理与流程
图像迁移学习一共分两类:
1.1微调
选择使用Imagenet数据集训练好的模型,更新模型中所有参数
当目标数据集远小于源数据集时,微调有助于提升模型的泛化能力
1.2特征提取
选择使用Imagenet数据集训练好的模型,更新模型预测的最后一层的参数
流程与微调相似,只是更新参数的层不同。
2.图像增强
图像增强(image augmentation)指通过剪切、旋转/反射/翻转变换、缩放变换、平移变换、尺度变换、对比度变换、噪声扰动、颜色变换等一种或多种组合数据增强变换的方式来增加数据集的大小。图像增强的意义是通过对训练图像做一系列随机改变,来产生相似但又不同的训练样本,从而扩大训练数据集的规模,而且随机改变训练样本可以降低模型对某些属性的依赖,从而提高模型的泛化能力
原始图像
import cv2 as cv
from torchvision import transforms as transforms
image="data.jpg"
img=cv.imread(image)
b,g,r=cv.split(img)
img=cv.merge([r,g,b])
# 用来正常显示中文标签
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # 指定默认字体
plt.rcParams['axes.unicode_minus'] = False
plt.title('天气之子')
plt.axis('off')
plt.imshow(img)
plt.show()
2.1比例缩放
import matplotlib.pyplot as pltfrom torchvision import transforms as transforms
import PIL.Image as Image
image="data.jpg"
img=Image.open(image)
# 用来正常显示中文标签
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # 指定默认字体
plt.rcParams['axes.unicode_minus'] = False
resize = transforms.Resize([125,125])
img = resize(img)
plt.title('天气之子')
plt.axis('off')
plt.imshow(img)
plt.show()
2.2位置裁剪
import matplotlib.pyplot as pltfrom torchvision import transforms as transforms
import PIL.Image as Image
image="data.jpg"
img=Image.open(image)
# 用来正常显示中文标签
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # 指定默认字体
plt.rcParams['axes.unicode_minus'] = False
crop = transforms.RandomCrop([100,100])
img = crop(img)
plt.title('天气之子')
plt.axis('off')
plt.imshow(img)
plt.show()
2.3水平/垂直翻转
import matplotlib.pyplot as pltfrom torchvision import transforms as transforms
import PIL.Image as Image
image="data.jpg"
img=Image.open(image)
# 用来正常显示中文标签
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # 指定默认字体
plt.rcParams['axes.unicode_minus'] = False
HF = transforms.RandomHorizontalFlip()
imgHF = HF(img)
VF = transforms.RandomVerticalFlip()
imgVF = VF(img)
title=['水平','垂直']
img=[imgHF,imgVF]
for i in range(2):
plt.subplot(1, 2, i + 1), plt.imshow(img[i], 'gray')
plt.title(title[i])
plt.xticks([]), plt.yticks([])
plt.show()
2.4角度旋转
import matplotlib.pyplot as pltfrom torchvision import transforms as transforms
import PIL.Image as Image
image="data.jpg"
img=Image.open(image)
# 用来正常显示中文标签
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # 指定默认字体
plt.rcParams['axes.unicode_minus'] = False
rotation = transforms.RandomRotation(45)
img = rotation(img)
plt.title('天气之子')
plt.axis('off')
plt.imshow(img)
plt.show()
2.5色度、亮度、饱和度、对比度
import matplotlib.pyplot as pltfrom torchvision import transforms as transforms
import PIL.Image as Image
image="data.jpg"
img=Image.open(image)
# 用来正常显示中文标签
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # 指定默认字体
plt.rcParams['axes.unicode_minus'] = False
#色度
transform_1=transforms.ColorJitter(brightness=1)
img_1=transform_1(img)
#亮度
transform_2=transforms.ColorJitter(contrast=1)
img_2=transform_2(img)
#饱和度
transform_3=transforms.ColorJitter(saturation=0.5)
img_3=transform_3(img)
#对比度
transform_4=transforms.ColorJitter(hue=0.5)
img_4=transform_4(img)
title=['色度','亮度','饱和度','对比度']
img=[img_1,img_2,img_3,img_4]
for i in range(2):
plt.subplot(1, 2, i + 1), plt.imshow(img[i], 'gray')
plt.title(title[i])
plt.xticks([]), plt.yticks([])
for i in range(2):
plt.subplot(2, 2, i + 3), plt.imshow(img[i+2], 'gray')
plt.title(title[i+2])
plt.xticks([]), plt.yticks([])
plt.show()
2.6灰度化
import matplotlib.pyplot as pltfrom torchvision import transforms as transforms
import PIL.Image as Image
image="data.jpg"
img=Image.open(image)
# 用来正常显示中文标签
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # 指定默认字体
plt.rcParams['axes.unicode_minus'] = False
gray = transforms.RandomGrayscale(p=0.5)
img = gray(img)
plt.title('天气之子')
plt.axis('off')
plt.imshow(img)
plt.show()
2.7Padding
import matplotlib.pyplot as pltfrom torchvision import transforms as transforms
import PIL.Image as Image
image="data.jpg"
img=Image.open(image)
# 用来正常显示中文标签
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # 指定默认字体
plt.rcParams['axes.unicode_minus'] = False
pad = transforms.Pad((0,(img.size[0]-img.size[1])//2))
img = pad(img)
plt.title('天气之子')
plt.axis('off')
plt.imshow(img)
plt.show()
2.8模型中图像增强数据预处理
datatrian=transforms.Compose([transforms.RandomHorizontalFlip(),
transforms.Resize(image_size),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])