当前位置 : 主页 > 编程语言 > java >

计算机视觉PyTorch迁移学习 - (一)

来源:互联网 收集:自由互联 发布时间:2022-06-30
如何在只有6万张图像的MNIST训练数据集上训练模型。学术界当下使用最广泛的大规模图像数据集ImageNet,它有超过1,000万的图像和1,000类的物体。然而,我们平常接触到数据集的规模通常

如何在只有6万张图像的MNIST训练数据集上训练模型。学术界当下使用最广泛的大规模图像数据集ImageNet,它有超过1,000万的图像和1,000类的物体。然而,我们平常接触到数据集的规模通常在这两者之间。假设我们想从图像中识别出不同种类的椅子,然后将购买链接推荐给用户。一种可能的方法是先找出100种常见的椅子,为每种椅子拍摄1,000张不同角度的图像,然后在收集到的图像数据集上训练一个分类模型。另外一种解决办法是应用迁移学习(transfer learning),将从源数据集学到的知识迁移到目标数据集上。例如,虽然ImageNet数据集的图像大多跟椅子无关,但在该数据集上训练的模型可以抽取较通用的图像特征,从而能够帮助识别边缘、纹理、形状和物体组成等。这些类似的特征对于识别椅子也可能同样有效


图像迁移学习

  • ​​1.迁移学习原理与流程​​
  • ​​1.1微调​​
  • ​​1.2特征提取​​
  • ​​2.图像增强​​
  • ​​2.1比例缩放​​
  • ​​2.2位置裁剪​​
  • ​​2.3水平/垂直翻转​​
  • ​​2.4角度旋转​​
  • ​​2.5色度、亮度、饱和度、对比度​​
  • ​​2.6灰度化​​
  • ​​2.7Padding​​
  • ​​2.8模型中图像增强数据预处理​​

1.迁移学习原理与流程

图像迁移学习一共分两类:

1.1微调

选择使用Imagenet数据集训练好的模型,更新模型中所有参数

  • 在源数据集(如ImageNet数据集)上预训练一个神经网络模型,即源模型。
  • 创建一个新的神经网络模型,即目标模型。它复制了源模型上除了输出层外的所有模型设计及其参数。我们假设这些模型参数包含了源数据集上学习到的知识,且这些知识同样适用于目标数据集。我们还假设源模型的输出层跟源数据集的标签紧密相关,因此在目标模型中不予采用。
  • 为目标模型添加一个输出大小为目标数据集类别个数的输出层,并随机初始化该层的模型参数。
  • 在目标数据集(如椅子数据集)上训练目标模型。我们将从头训练输出层,而其余层的参数都是基于源模型的参数微调得到的。
    计算机视觉PyTorch迁移学习 - (一)_数据集
    当目标数据集远小于源数据集时,微调有助于提升模型的泛化能力
  • 1.2特征提取

    选择使用Imagenet数据集训练好的模型,更新模型预测的最后一层的参数
    流程与微调相似,只是更新参数的层不同。

    2.图像增强

    图像增强(image augmentation)指通过剪切、旋转/反射/翻转变换、缩放变换、平移变换、尺度变换、对比度变换、噪声扰动、颜色变换等一种或多种组合数据增强变换的方式来增加数据集的大小。图像增强的意义是通过对训练图像做一系列随机改变,来产生相似但又不同的训练样本,从而扩大训练数据集的规模,而且随机改变训练样本可以降低模型对某些属性的依赖,从而提高模型的泛化能力
    原始图像

    import matplotlib.pyplot as plt
    import cv2 as cv
    from torchvision import transforms as transforms

    image="data.jpg"
    img=cv.imread(image)
    b,g,r=cv.split(img)
    img=cv.merge([r,g,b])
    # 用来正常显示中文标签
    plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # 指定默认字体

    plt.rcParams['axes.unicode_minus'] = False
    plt.title('天气之子')
    plt.axis('off')
    plt.imshow(img)
    plt.show()

    计算机视觉PyTorch迁移学习 - (一)_ico_02

    2.1比例缩放

    import matplotlib.pyplot as plt
    from torchvision import transforms as transforms
    import PIL.Image as Image

    image="data.jpg"
    img=Image.open(image)
    # 用来正常显示中文标签
    plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # 指定默认字体

    plt.rcParams['axes.unicode_minus'] = False
    resize = transforms.Resize([125,125])
    img = resize(img)
    plt.title('天气之子')
    plt.axis('off')
    plt.imshow(img)
    plt.show()

    计算机视觉PyTorch迁移学习 - (一)_计算机视觉_03

    2.2位置裁剪

    import matplotlib.pyplot as plt
    from torchvision import transforms as transforms
    import PIL.Image as Image

    image="data.jpg"
    img=Image.open(image)
    # 用来正常显示中文标签
    plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # 指定默认字体

    plt.rcParams['axes.unicode_minus'] = False
    crop = transforms.RandomCrop([100,100])
    img = crop(img)
    plt.title('天气之子')
    plt.axis('off')
    plt.imshow(img)
    plt.show()

    计算机视觉PyTorch迁移学习 - (一)_数据集_04

    2.3水平/垂直翻转

    import matplotlib.pyplot as plt
    from torchvision import transforms as transforms
    import PIL.Image as Image

    image="data.jpg"
    img=Image.open(image)
    # 用来正常显示中文标签
    plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # 指定默认字体

    plt.rcParams['axes.unicode_minus'] = False
    HF = transforms.RandomHorizontalFlip()
    imgHF = HF(img)
    VF = transforms.RandomVerticalFlip()
    imgVF = VF(img)
    title=['水平','垂直']
    img=[imgHF,imgVF]
    for i in range(2):
    plt.subplot(1, 2, i + 1), plt.imshow(img[i], 'gray')
    plt.title(title[i])
    plt.xticks([]), plt.yticks([])
    plt.show()

    计算机视觉PyTorch迁移学习 - (一)_显示中文_05

    2.4角度旋转

    import matplotlib.pyplot as plt
    from torchvision import transforms as transforms
    import PIL.Image as Image

    image="data.jpg"
    img=Image.open(image)
    # 用来正常显示中文标签
    plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # 指定默认字体

    plt.rcParams['axes.unicode_minus'] = False
    rotation = transforms.RandomRotation(45)
    img = rotation(img)
    plt.title('天气之子')
    plt.axis('off')
    plt.imshow(img)
    plt.show()

    计算机视觉PyTorch迁移学习 - (一)_数据集_06

    2.5色度、亮度、饱和度、对比度

    import matplotlib.pyplot as plt
    from torchvision import transforms as transforms
    import PIL.Image as Image

    image="data.jpg"
    img=Image.open(image)
    # 用来正常显示中文标签
    plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # 指定默认字体

    plt.rcParams['axes.unicode_minus'] = False
    #色度
    transform_1=transforms.ColorJitter(brightness=1)
    img_1=transform_1(img)
    #亮度
    transform_2=transforms.ColorJitter(contrast=1)
    img_2=transform_2(img)
    #饱和度
    transform_3=transforms.ColorJitter(saturation=0.5)
    img_3=transform_3(img)
    #对比度
    transform_4=transforms.ColorJitter(hue=0.5)
    img_4=transform_4(img)

    title=['色度','亮度','饱和度','对比度']
    img=[img_1,img_2,img_3,img_4]
    for i in range(2):
    plt.subplot(1, 2, i + 1), plt.imshow(img[i], 'gray')
    plt.title(title[i])
    plt.xticks([]), plt.yticks([])
    for i in range(2):
    plt.subplot(2, 2, i + 3), plt.imshow(img[i+2], 'gray')
    plt.title(title[i+2])
    plt.xticks([]), plt.yticks([])
    plt.show()

    计算机视觉PyTorch迁移学习 - (一)_数据集_07
    计算机视觉PyTorch迁移学习 - (一)_计算机视觉_08

    2.6灰度化

    import matplotlib.pyplot as plt
    from torchvision import transforms as transforms
    import PIL.Image as Image

    image="data.jpg"
    img=Image.open(image)
    # 用来正常显示中文标签
    plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # 指定默认字体

    plt.rcParams['axes.unicode_minus'] = False
    gray = transforms.RandomGrayscale(p=0.5)
    img = gray(img)
    plt.title('天气之子')
    plt.axis('off')
    plt.imshow(img)
    plt.show()

    计算机视觉PyTorch迁移学习 - (一)_ico_09

    2.7Padding

    import matplotlib.pyplot as plt
    from torchvision import transforms as transforms
    import PIL.Image as Image

    image="data.jpg"
    img=Image.open(image)
    # 用来正常显示中文标签
    plt.rcParams['font.sans-serif'] = ['Arial Unicode MS'] # 指定默认字体

    plt.rcParams['axes.unicode_minus'] = False
    pad = transforms.Pad((0,(img.size[0]-img.size[1])//2))
    img = pad(img)
    plt.title('天气之子')
    plt.axis('off')
    plt.imshow(img)
    plt.show()

    计算机视觉PyTorch迁移学习 - (一)_ico_10

    2.8模型中图像增强数据预处理

    datatrian=transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])


    上一篇:机器学习面试题 (一)
    下一篇:没有了
    网友评论