当前位置 : 主页 > 编程语言 > java >

pytorch使用ImageFolder和random_split读取和划分数据集

来源:互联网 收集:自由互联 发布时间:2022-10-26
1. 最近重新学习torch知识,想实现对自己的数据集的封装和划分,由于自己的数据集格式如图所示 层级结构: |---data |---amazon |---images |---back_pack |---frame_0001.jpg |---frame_0002.jpg |---frame_00


1. 最近重新学习torch知识,想实现对自己的数据集的封装和划分,由于自己的数据集格式如图所示

pytorch使用ImageFolder和random_split读取和划分数据集_深度学习


层级结构:

|---data
|---amazon
|---images
|---back_pack
|---frame_0001.jpg
|---frame_0002.jpg
|---frame_0002.jpg
...

2. 首先,如果数据集层级结构是这样的格式,则可以进行如下方式处理

import torch
import torch.utils.data
from torchvision import transforms,datasets


# 定义transforms的一些操作
data_transform = transforms.Compose([
# Resize后数据的大小为224 * 224
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
# 数据标准化,采用的图片标准化参数
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# 使用ImageFolder去读取,返回后的数据路径和标签对应起来
all_dataset = datasets.ImageFolder('../data/amazon/images', transform=data_transform)

# 使用random_split实现数据集的划分,lengths是一个list,按照对应的数量返回数据个数。
# 这儿需要注意的是,lengths的数据量总和等于all_dataset中的数据个数,这儿不是按比例划分的
train, test, valid = torch.utils.data.random_split(dataset= all_dataset, lengths=[2000, 417, 400])

# 接着按照正常方式使用DataLoader读取数据,返回的是DataLoader对象
train = torch.utils.data.DataLoader(train, batch_size=4, shuffle=True, num_workers=4)
test = torch.utils.data.DataLoader(test, batch_size=4, shuffle=True, num_workers=4)
valid = torch.utils.data.DataLoader(valid, batch_size=4, shuffle=True, num_workers=4)

3. 进行遍历数据

# 使用迭代器进行迭代数据进行查看,如果这儿报错:The “freeze_support()” line can be omitted if the program
# is not going to be frozen to produce an executable
# 需要将你要运行的代码块放到main函数中运行即可
for step, (x, y) in enumerate(train):
print(step)
print(x.size())
print(y.size())
print(x)
break

输出结果如图:

pytorch使用ImageFolder和random_split读取和划分数据集_数据集_02

4. 总结:

一开始自己是写代码实现数据的读取,划分,并封装成DataLoader,殊不知还有这么好的库函数供使用。。。

库函数处理的思路过程如图:

pytorch使用ImageFolder和random_split读取和划分数据集_python_03

5. 如有问题可以留言~


网友评论