1、dataset:(数据类型 dataset)
输入的数据类型,这里是原始数据的输入。PyTorch内也有这种数据结构。
2、batch_size:(数据类型 int)
批训练数据量的大小,根据具体情况设置即可(默认:1)。PyTorch训练模型时调用数据不是一行一行进行的(这样太没效率),而是一捆一捆来的。这里就是定义每次喂给神经网络多少行数据,如果设置成1,那就是一行一行进行(个人偏好,PyTorch默认设置是1)。每次是随机读取大小为batch_size。如果dataset中的数据个数不是batch_size的整数倍,这最后一次把剩余的数据全部输出。若想把剩下的不足batch size个的数据丢弃,则将drop_last设置为True,会将多出来不足一个batch的数据丢弃。
3、shuffle:(数据类型 bool)
洗牌。默认设置为False。在每次迭代训练时是否将数据洗牌,默认设置是False。将输入数据的顺序打乱,是为了使数据更有独立性,但如果数据是有序列特征的,就不要设置成True了。
4、collate_fn:(数据类型 callable,没见过的类型)
将一小段数据合并成数据列表,默认设置是False。如果设置成True,系统会在返回前会将张量数据(Tensors)复制到CUDA内存中。
5、batch_sampler:(数据类型 Sampler)
批量采样,默认设置为None。但每次返回的是一批数据的索引(注意:不是数据)。其和batch_size、shuffle 、sampler and drop_last参数是不兼容的。我想,应该是每次输入网络的数据是随机采样模式,这样能使数据更具有独立性质。所以,它和一捆一捆按顺序输入,数据洗牌,数据采样,等模式是不兼容的。
6、sampler:(数据类型 Sampler)
采样,默认设置为None。根据定义的策略从数据集中采样输入。如果定义采样规则,则洗牌(shuffle)设置必须为False。
7、num_workers:(数据类型 Int)
工作者数量,默认是0。使用多少个子进程来导入数据。设置为0,就是使用主进程来导入数据。注意:这个数字必须是大于等于0的,负数估计会出错。
8、pin_memory:(数据类型 bool)
内存寄存,默认为False。在数据返回前,是否将数据复制到CUDA内存中。
9、drop_last:(数据类型 bool)
丢弃最后数据,默认为False。设置了 batch_size 的数目后,最后一批数据未必是设置的数目,有可能会小些。这时你是否需要丢弃这批数据。
10、timeout:(数据类型 numeric)
超时,默认为0。是用来设置数据读取的超时时间的,但超过这个时间还没读取到数据的话就会报错。 所以,数值必须大于等于0。
11、worker_init_fn(数据类型 callable,没见过的类型)
子进程导入模式,默认为Noun。在数据导入前和步长结束后,根据工作子进程的ID逐个按顺序导入数据。
对batch_size举例分析:
""" 批训练,把数据变成一小批一小批数据进行训练。 DataLoader就是用来包装所使用的数据,每次抛出一批数据 """ import torch import torch.utils.data as Data BATCH_SIZE = 5 x = torch.linspace(1, 11, 11) y = torch.linspace(11, 1, 11) print(x) print(y) # 把数据放在数据库中 torch_dataset = Data.TensorDataset(x, y) loader = Data.DataLoader( # 从数据库中每次抽出batch size个样本 dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, # num_workers=2, ) def show_batch(): for epoch in range(3): for step, (batch_x, batch_y) in enumerate(loader): # training print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y)) if __name__ == '__main__': show_batch()
输出为:
tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.])
tensor([11., 10., 9., 8., 7., 6., 5., 4., 3., 2., 1.])
steop:0, batch_x:tensor([ 3., 2., 8., 11., 1.]), batch_y:tensor([ 9., 10., 4., 1., 11.])
steop:1, batch_x:tensor([ 5., 6., 7., 4., 10.]), batch_y:tensor([7., 6., 5., 8., 2.])
steop:2, batch_x:tensor([9.]), batch_y:tensor([3.])
steop:0, batch_x:tensor([ 9., 7., 10., 2., 4.]), batch_y:tensor([ 3., 5., 2., 10., 8.])
steop:1, batch_x:tensor([ 5., 11., 3., 6., 8.]), batch_y:tensor([7., 1., 9., 6., 4.])
steop:2, batch_x:tensor([1.]), batch_y:tensor([11.])
steop:0, batch_x:tensor([10., 5., 7., 4., 2.]), batch_y:tensor([ 2., 7., 5., 8., 10.])
steop:1, batch_x:tensor([3., 9., 1., 8., 6.]), batch_y:tensor([ 9., 3., 11., 4., 6.])
steop:2, batch_x:tensor([11.]), batch_y:tensor([1.])
Process finished with exit code 0
若drop_last=True
""" 批训练,把数据变成一小批一小批数据进行训练。 DataLoader就是用来包装所使用的数据,每次抛出一批数据 """ import torch import torch.utils.data as Data BATCH_SIZE = 5 x = torch.linspace(1, 11, 11) y = torch.linspace(11, 1, 11) print(x) print(y) # 把数据放在数据库中 torch_dataset = Data.TensorDataset(x, y) loader = Data.DataLoader( # 从数据库中每次抽出batch size个样本 dataset=torch_dataset, batch_size=BATCH_SIZE, shuffle=True, # num_workers=2, drop_last=True, ) def show_batch(): for epoch in range(3): for step, (batch_x, batch_y) in enumerate(loader): # training print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y)) if __name__ == '__main__': show_batch()
对应的输出为:
tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11.])
tensor([11., 10., 9., 8., 7., 6., 5., 4., 3., 2., 1.])
steop:0, batch_x:tensor([ 9., 2., 7., 4., 11.]), batch_y:tensor([ 3., 10., 5., 8., 1.])
steop:1, batch_x:tensor([ 3., 5., 10., 1., 8.]), batch_y:tensor([ 9., 7., 2., 11., 4.])
steop:0, batch_x:tensor([ 5., 11., 6., 1., 2.]), batch_y:tensor([ 7., 1., 6., 11., 10.])
steop:1, batch_x:tensor([ 3., 4., 10., 8., 9.]), batch_y:tensor([9., 8., 2., 4., 3.])
steop:0, batch_x:tensor([10., 4., 9., 8., 7.]), batch_y:tensor([2., 8., 3., 4., 5.])
steop:1, batch_x:tensor([ 6., 1., 11., 2., 5.]), batch_y:tensor([ 6., 11., 1., 10., 7.])
Process finished with exit code 0
总结
到此这篇关于PyTorch中torch.utils.data.DataLoader的文章就介绍到这了,更多相关PyTorch torch.utils.data.DataLoader内容请搜索自由互联以前的文章或继续浏览下面的相关文章希望大家以后多多支持自由互联!