当前位置 : 主页 > 编程语言 > python >

使用pytorch自定义DataSet,以加载图像数据集为例,实现一些骚操作

来源:互联网 收集:自由互联 发布时间:2022-06-15
使用pytorch自定义DataSet,以加载图像数据集为例,实现一些骚操作 总共分为四步 构造一个​​my_dataset​​​类,继承自​​torch.utils.data.Dataset​​ 重写​​__getitem__​​​ 和​​__l


使用pytorch自定义DataSet,以加载图像数据集为例,实现一些骚操作

总共分为四步

  • 构造一个​​my_dataset​​​类,继承自​​torch.utils.data.Dataset​​
  • 重写​​__getitem__​​​ 和​​__len__​​ 类函数
  • 建立两个函数​​find_classes​​​、​​has_file_allowed_extension​​,直接从这copy过去
  • 建立my_make_dataset函数用来构造(path,label)对

一、构造一个​​my_dataset​​​类,继承自​​torch.utils.data.Dataset​​

二、 重写​​__getitem__​​​ 和​​__len__​​ 类函数

要构造Dataset的子类,就必须要实现两个方法:

  • __getitem__(self, index):根据index来返回数据集中标号为index的元素及其标签。
  • __len__(self):返回数据集的长度。
class my_dataset(Dataset):
    """Paired image dataset backed by two class-folder directory trees.

    Both roots are scanned with ``my_make_dataset`` for ``.jpg``/``.png``
    files. Sample ``index`` pairs the ``index``-th entry of the first list
    with the ``index``-th entry of the second list.

    Args:
        root_original: Root directory of the first dataset.
        root_cdtfed: Root directory of the second dataset.
        transform: Optional callable applied to each loaded RGB image.
    """

    def __init__(self, root_original, root_cdtfed, transform=None):
        super().__init__()
        self.transform = transform
        self.root_original = root_original
        self.root_cdtfed = root_cdtfed

        # Lists of (img_path, label) pairs, one list per root.
        self.original_imgs = my_make_dataset(
            root_original,
            class_to_idx=None,
            extensions=('.jpg', '.png'),
            is_valid_file=None,
        )
        # Fixed: this list used to be built from root_original as well,
        # which meant root_cdtfed was never read.
        self.cdtfed_imgs = my_make_dataset(
            root_cdtfed,
            class_to_idx=None,
            extensions=('.jpg', '.png'),
            is_valid_file=None,
        )

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path`` and return an RGB copy, closing the file."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, index):
        """Return ``([img1, img2], label, name)`` for the pair at ``index``.

        ``label`` and ``name`` are the label and file path of the entry
        taken from ``original_imgs``; the label of the second entry is
        not returned.
        """
        fn1, label1 = self.original_imgs[index]
        fn2, _ = self.cdtfed_imgs[index]

        img1 = self._load_rgb(fn1)
        img2 = self._load_rgb(fn2)

        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        # Whatever is returned here is what each batch yields during training.
        return [img1, img2], label1, fn1

    def __len__(self):
        """Return the number of complete pairs (not the loader's batch count)."""
        # Bound by the shorter list so every valid index exists in both lists.
        return min(len(self.original_imgs), len(self.cdtfed_imgs))

三、建立两个函数​​find_classes​​​、​​has_file_allowed_extension​​,直接从这copy过去

def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Locate the class sub-folders of a dataset root.

    Args:
        directory: Root directory whose immediate sub-directories are classes.

    Returns:
        A tuple ``(classes, class_to_idx)``: the sorted sub-directory names
        and a mapping from each name to its position in that sorted list.

    Raises:
        FileNotFoundError: If ``directory`` contains no sub-directory.
    """
    classes = [entry.name for entry in os.scandir(directory) if entry.is_dir()]
    classes.sort()
    if len(classes) == 0:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    # Index each class by its rank in the sorted name list.
    class_to_idx = dict(zip(classes, range(len(classes))))
    return classes, class_to_idx

def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
    """Tell whether a file name carries one of the allowed extensions.

    Args:
        filename: Path to a file.
        extensions: Extensions to accept, given in lowercase.

    Returns:
        True if the lower-cased ``filename`` ends with one of ``extensions``.
    """
    # Only the file name is lower-cased; the extensions are compared as given.
    lowered = filename.lower()
    return lowered.endswith(extensions)
四、建立my_make_dataset函数用来构造(path,label)对
def my_make_dataset(
    directory: str,
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Tuple[str, ...]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, str]]:
    """Generate a list of samples of the form ``(path_to_sample, class_name)``.

    The label stored in each sample is the class folder name itself (a
    string), not the integer index from ``class_to_idx``.

    Args:
        directory: Dataset root; ``~`` is expanded.
        class_to_idx: Mapping whose keys are the class folder names to scan.
            When None, it is obtained from ``find_classes(directory)``.
        extensions: Allowed file extensions. Exactly one of ``extensions``
            and ``is_valid_file`` must be given.
        is_valid_file: Predicate deciding whether a file is a sample. It is
            called with the bare file name, not the full path.

    Returns:
        ``(path, class_name)`` tuples, ordered by class name, then by walked
        directory, then by file name.

    Raises:
        ValueError: If ``class_to_idx`` is empty, or if ``extensions`` and
            ``is_valid_file`` are both None or both given.
        FileNotFoundError: If any class in ``class_to_idx`` yields no valid
            file.
    """
    directory = os.path.expanduser(directory)

    if class_to_idx is None:
        _, class_to_idx = find_classes(directory)
    elif not class_to_idx:
        raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

    # Exactly one of the two filters has to be supplied.
    if (extensions is None) == (is_valid_file is None):
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if extensions is not None:
        allowed_extensions = extensions

        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, allowed_extensions)

    instances: List[Tuple[str, str]] = []
    available_classes = set()
    for target_class in sorted(class_to_idx):
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                if not is_valid_file(fname):
                    continue
                # The label is the folder name rather than its index.
                instances.append((os.path.join(root, fname), target_class))
                available_classes.add(target_class)

    empty_classes = set(class_to_idx) - available_classes
    if empty_classes:
        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        if extensions is not None:
            msg += f"Supported extensions are: {', '.join(extensions)}"
        raise FileNotFoundError(msg)

    return instances

附录:完整代码

我这里传入两个root_dir,因为我要用一个dataset加载两个数据集,分别放在data1和data2里

class my_dataset(Dataset):
    """Paired image dataset backed by two class-folder directory trees.

    Both roots are scanned with ``my_make_dataset`` for ``.jpg``/``.png``
    files. Sample ``index`` pairs the ``index``-th entry of the first list
    with the ``index``-th entry of the second list.

    Args:
        root_original: Root directory of the first dataset.
        root_cdtfed: Root directory of the second dataset.
        transform: Optional callable applied to each loaded RGB image.
    """

    def __init__(self, root_original, root_cdtfed, transform=None):
        super().__init__()
        self.transform = transform
        self.root_original = root_original
        self.root_cdtfed = root_cdtfed

        # Lists of (img_path, label) pairs, one list per root.
        self.original_imgs = my_make_dataset(
            root_original,
            class_to_idx=None,
            extensions=('.jpg', '.png'),
            is_valid_file=None,
        )
        # Fixed: this list used to be built from root_original as well,
        # which meant root_cdtfed was never read.
        self.cdtfed_imgs = my_make_dataset(
            root_cdtfed,
            class_to_idx=None,
            extensions=('.jpg', '.png'),
            is_valid_file=None,
        )

    @staticmethod
    def _load_rgb(path):
        """Open the image at ``path`` and return an RGB copy, closing the file."""
        with Image.open(path) as img:
            return img.convert('RGB')

    def __getitem__(self, index):
        """Return ``([img1, img2], label, name)`` for the pair at ``index``.

        ``label`` and ``name`` are the label and file path of the entry
        taken from ``original_imgs``; the label of the second entry is
        not returned.
        """
        fn1, label1 = self.original_imgs[index]
        fn2, _ = self.cdtfed_imgs[index]

        img1 = self._load_rgb(fn1)
        img2 = self._load_rgb(fn2)

        if self.transform is not None:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        # Whatever is returned here is what each batch yields during training.
        return [img1, img2], label1, fn1

    def __len__(self):
        """Return the number of complete pairs (not the loader's batch count)."""
        # Bound by the shorter list so every valid index exists in both lists.
        return min(len(self.original_imgs), len(self.cdtfed_imgs))


def find_classes(directory: str) -> Tuple[List[str], Dict[str, int]]:
    """Locate the class sub-folders of a dataset root.

    Args:
        directory: Root directory whose immediate sub-directories are classes.

    Returns:
        A tuple ``(classes, class_to_idx)``: the sorted sub-directory names
        and a mapping from each name to its position in that sorted list.

    Raises:
        FileNotFoundError: If ``directory`` contains no sub-directory.
    """
    classes = [entry.name for entry in os.scandir(directory) if entry.is_dir()]
    classes.sort()
    if len(classes) == 0:
        raise FileNotFoundError(f"Couldn't find any class folder in {directory}.")

    # Index each class by its rank in the sorted name list.
    class_to_idx = dict(zip(classes, range(len(classes))))
    return classes, class_to_idx

def has_file_allowed_extension(filename: str, extensions: Tuple[str, ...]) -> bool:
    """Tell whether a file name carries one of the allowed extensions.

    Args:
        filename: Path to a file.
        extensions: Extensions to accept, given in lowercase.

    Returns:
        True if the lower-cased ``filename`` ends with one of ``extensions``.
    """
    # Only the file name is lower-cased; the extensions are compared as given.
    lowered = filename.lower()
    return lowered.endswith(extensions)

def my_make_dataset(
    directory: str,
    class_to_idx: Optional[Dict[str, int]] = None,
    extensions: Optional[Tuple[str, ...]] = None,
    is_valid_file: Optional[Callable[[str], bool]] = None,
) -> List[Tuple[str, str]]:
    """Generate a list of samples of the form ``(path_to_sample, class_name)``.

    The label stored in each sample is the class folder name itself (a
    string), not the integer index from ``class_to_idx``.

    Args:
        directory: Dataset root; ``~`` is expanded.
        class_to_idx: Mapping whose keys are the class folder names to scan.
            When None, it is obtained from ``find_classes(directory)``.
        extensions: Allowed file extensions. Exactly one of ``extensions``
            and ``is_valid_file`` must be given.
        is_valid_file: Predicate deciding whether a file is a sample. It is
            called with the bare file name, not the full path.

    Returns:
        ``(path, class_name)`` tuples, ordered by class name, then by walked
        directory, then by file name.

    Raises:
        ValueError: If ``class_to_idx`` is empty, or if ``extensions`` and
            ``is_valid_file`` are both None or both given.
        FileNotFoundError: If any class in ``class_to_idx`` yields no valid
            file.
    """
    directory = os.path.expanduser(directory)

    if class_to_idx is None:
        _, class_to_idx = find_classes(directory)
    elif not class_to_idx:
        raise ValueError("'class_to_index' must have at least one entry to collect any samples.")

    # Exactly one of the two filters has to be supplied.
    if (extensions is None) == (is_valid_file is None):
        raise ValueError("Both extensions and is_valid_file cannot be None or not None at the same time")

    if extensions is not None:
        allowed_extensions = extensions

        def is_valid_file(x: str) -> bool:
            return has_file_allowed_extension(x, allowed_extensions)

    instances: List[Tuple[str, str]] = []
    available_classes = set()
    for target_class in sorted(class_to_idx):
        target_dir = os.path.join(directory, target_class)
        if not os.path.isdir(target_dir):
            continue
        for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)):
            for fname in sorted(fnames):
                if not is_valid_file(fname):
                    continue
                # The label is the folder name rather than its index.
                instances.append((os.path.join(root, fname), target_class))
                available_classes.add(target_class)

    empty_classes = set(class_to_idx) - available_classes
    if empty_classes:
        msg = f"Found no valid file for the classes {', '.join(sorted(empty_classes))}. "
        if extensions is not None:
            msg += f"Supported extensions are: {', '.join(extensions)}"
        raise FileNotFoundError(msg)

    return instances



上一篇:写给朋友的 Python知识点,字符串方法
下一篇:没有了
网友评论