【学习】torchvision.datasets.ImageFolder()

小明 2025-04-28 18:51:21 6

在分���任务中,数据集文件存储往往是如下形式:

- train
  - class1
    - image1.jpg
    - image2.jpg
    ...
  - class2
    - image1.jpg
    - image2.jpg
    ...
  ...

此时,我们想要获取图片和标签,标签即为文件名(class1、class2…)

可以使用torchvision.datasets.ImageFolder()来进行获取,示例代码如下:

dataset = datasets.ImageFolder(root=DATA_PATH/'train')

torchvision.datasets.ImageFolder() 参数列表:

  • root:图像文件读取路径
  • transform:对图像数据采取的数据增强策略
  • target_transform:对label进行转换
  • loader:指定加载图像的函数
  • is_valid_file:获取图像路径,检查文件的有效性

    返回值

    dataset 返回有如下三个属性:

    • self.classes:用一个 list 保存类别名称
    • self.class_to_idx:类别对应的索引,与不做任何转换返回的 target 对应
    • self.imgs:保存(img-path, class) tuple的 list

      我们得到的dataset,它的结构就是[(img_data,class_id),(img_data,class_id),…]下面我们打印dataset第一个元素中的图片:

      返回对应的label:

      其中,对于dataset[0]来说,其中也存储了两个元素,第一个是图片,第二个是类别索引号。

      sample = dataset[0]
      img = sample[0]   #图片
      label = sample[1]  #类别索引
      

      注意:

      1. dataset中存储的label是按文件夹顺序生成对应索引的,且以下标为0开始。如果要读取类别的字符,可以通过self.classes[0]来获取。
      2. train文件夹下的文件格式是固定的,不能有多余的文件,否则会读取出错。
The End
微信