当前位置 : 主页 > 编程语言 > python >

pytorch中dataloader 的sampler 参数详解

来源:互联网 收集:自由互联 发布时间:2023-01-30
目录 1. dataloader() 初始化函数 2. shuffle 与sample 之间的关系 3. sample 的定义方法 3.1 sampler 参数的使用 4. batch 生成过程 1. dataloader() 初始化函数 def __init__(self, dataset, batch_size=1, shuffle=False,
目录
  • 1. dataloader() 初始化函数
  • 2. shuffle 与sample 之间的关系
  • 3. sample 的定义方法
    • 3.1 sampler 参数的使用
  • 4. batch 生成过程

    1. dataloader() 初始化函数

     def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None,
     batch_sampler=None, num_workers=0, collate_fn=None,
    pin_memory=False, drop_last=False, timeout=0,
                     worker_init_fn=None, multiprocessing_context=None):
    

    其中几个常用的参数:

    • dataset 数据集,map-style and iterable-style 可以用index取值的对象、
    • batch_size 大小
    • shuffle 取batch是否随机取, 默认为False
    • sampler 定义取batch的方法,是一个迭代器, 每次生成一个key 用于读取dataset中的值
    • batch_sampler 也是一个迭代器, 每次生次一个batch_size的key
    • num_workers 参与工作的线程数collate_fn 对取出的batch进行处理
    • drop_last 对最后不足batchsize的数据的处理方法

    下面看两段取自DataLoader中的__init__代码, 帮助我们理解几个常用参数之间的关系

    2. shuffle 与sample 之间的关系

    当我们sampler有输入时,shuffle的值就没有意义,

    	if sampler is None:  # give default samplers
    	    if self._dataset_kind == _DatasetKind.Iterable:
    	        # See NOTE [ Custom Samplers and IterableDataset ]
    	        sampler = _InfiniteConstantSampler()
    	    else:  # map-style
    	        if shuffle:
    	            sampler = RandomSampler(dataset)
    	        else:
    	            sampler = SequentialSampler(dataset)

    当dataset类型是map style时, shuffle其实就是改变sampler的取值

    • shuffle为默认值 False时,sampler是SequentialSampler,就是按顺序取样,
    • shuffle为True时,sampler是RandomSampler, 就是按随机取样

    3. sample 的定义方法

    3.1 sampler 参数的使用

    sampler 是用来定义取batch方法的一个函数或者类,返回的是一个迭代器。

    我们可以看下自带的RandomSampler类中最重要的iter函数

        def __iter__(self):
            n = len(self.data_source)
            # dataset的长度, 按顺序索引
            if self.replacement:# 对应的replace参数
                return iter(torch.randint(high=n, size=(self.num_samples,), dtype=torch.int64).tolist())
            return iter(torch.randperm(n).tolist())        

    可以看出,其实就是生成索引,然后随机的取值, 然后再迭代。

    其实还有一些细节需要注意理解:

    比如__len__函数,包括DataLoader的len和sample的len, 两者区别, 这部分代码比较简单,可以自行阅读,其实参考着RandomSampler写也不会出现问题。
    比如,迭代器和生成器的使用, 以及区别

        if batch_size is not None and batch_sampler is None:
            # auto_collation without custom batch_sampler
            batch_sampler = BatchSampler(sampler, batch_size, drop_last)
            
        self.sampler = sampler
        self.batch_sampler = batch_sampler

    BatchSampler的生成过程:

    # 略去类的初始化
        def __iter__(self):
            batch = []
            for idx in self.sampler:
                batch.append(idx)
                if len(batch) == self.batch_size:
                    yield batch
                    batch = []
            if len(batch) > 0 and not self.drop_last:
                yield batch

    就是按batch_size从sampler中读取索引, 并形成生成器返回。

    以上可以看出, batch_sampler和sampler, batch_size, drop_last之间的关系

    • 如果batch_sampler没有定义的话且batch_size有定义, 会根据sampler, batch_size, drop_last生成一个batch_sampler
    • 自带的注释中对batch_sampler有一句话: Mutually exclusive with :attr:batch_size :attr:shuffle, :attr:sampler, and :attr:drop_last.
    • 意思就是b
    • atch_sampler 与这些参数冲突 ,即 如果你定义了batch_sampler, 其他参数都不需要有

    4. batch 生成过程

    每个batch都是由迭代器产生的:

    # DataLoader中iter的部分
        def __iter__(self):
            if self.num_workers == 0:
                return _SingleProcessDataLoaderIter(self)
            else:
                return _MultiProcessingDataLoaderIter(self)
    
    # 再看调用的另一个类
    class _SingleProcessDataLoaderIter(_BaseDataLoaderIter):
        def __init__(self, loader):
            super(_SingleProcessDataLoaderIter, self).__init__(loader)
            assert self._timeout == 0
            assert self._num_workers == 0
    
            self._dataset_fetcher = _DatasetKind.create_fetcher(
                self._dataset_kind, self._dataset, self._auto_collation, self._collate_fn, self._drop_last)
    
        def __next__(self):
            index = self._next_index()  
            data = self._dataset_fetcher.fetch(index)  
            if self._pin_memory:
                data = _utils.pin_memory.pin_memory(data)
            return data

    到此这篇关于pytorch中dataloader 的sampler 参数详解的文章就介绍到这了,更多相关pytorch sampler 内容请搜索自由互联以前的文章或继续浏览下面的相关文章希望大家以后多多支持自由互联!

    上一篇:OpenCV 读取图像imread的使用详解
    下一篇:没有了
    网友评论