본문 바로가기
ML & DL/pytorch

[pytorch] Tutorial - Datasets & DataLoaders & Transforms

by 별준 2022. 11. 26.

References

  • Official PyTorch Tutorials (link)

파이토치에서는 데이터를 쉽게 전처리하고, 모델을 학습하는 코드와 별도로 모듈화할 수 있는 기능인 torch.utils.data.DataLoaders와 torch.utils.data.Dataset을 제공합니다. Datasets은 샘플 데이터와 대응하는 라벨을 저장하고, DataLoader는 Dataset을 순회하면서 쉽게 샘플 데이터에 액세스할 수 있도록 해줍니다.

 

Loading a FashionMNIST Dataset

torchvision에서 제공하는 FashionMNIST가 Dataset의 서브클래스인데, 파이토치는 다양한 이미지, 텍스트, 오디오 데이터셋을 제공합니다.

FashionMNIST 데이터셋을 예시로 어떻게 로드할 수 있는지 살펴보겠습니다.

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt


training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

datasets에는 다양한 데이터셋이 있으며, 여기서는 FashionMNIST 데이터를 사용합니다. 사용된 파라미터는 다음과 같습니다.

  • root : train/test 데이터가 저장되는 경로
  • train : training data를 로드할 것인지 test data를 로드할 것인지 지정
  • download : root에 데이터가 없는 경우, 인터넷으로부터 데이터를 다운로드할 것인지 지정
  • transform/target_transform : feature과 label의 transformation을 지정 (전처리)

 

Iterating and Visualizing the Dataset

Datasets은 인덱싱을 통해서 액세스할 수 있습니다.

labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

 

Creating a Custom Dataset

커스텀 데이터셋 클래스를 구현할 수 있습니다. 이때, __init__, __len__, __getitem__ 함수를 꼭 구현해주어야 합니다.

 

__init__

__init__ 함수는 Dataset 객체가 초기화될 때 한 번 수행되며, image, annotation file 디렉토리와 데이터와 라벨에 대한 transform을 초기화합니다.

 

__len__

__len__ 함수는 데이터셋에서 샘플의 총 갯수를 리턴합니다.

 

__getitem__

__getitem__ 함수는 데이터셋으로부터 주어진 인덱스 idx의 샘플을 반환합니다.

 

세 함수를 구현은 다음과 같이 구현할 수 있습니다.

import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

특히 __getitem__ 함수는 torch.io.read_image를 사용하여 이미지 파일에서 이미지를 읽어서 텐서로 변환하고 대응하는 라벨을 읽습니다. 그리고 transform 함수가 지정되었다면, transformation을 적용한 뒤 텐서 이미지와 대응되는 라벨을 튜플로 반환합니다.

 

 

Preparing your data for training with DataLoaders

Datatset은 한 번에 하나의 샘플을 순회합니다. 모델을 학습할 때는 일반적으로 minibatch로 샘플을 전달하고, 모델의 과적합을 피하기 위해 매 epoch마다 데이터를 다시 shuffle 합니다.

torch.utils.data.DataLoader를 사용하면, 쉽게 위와 같은 동작을 적용할 수 있습니다.

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

 

Iterate through the DataLoader

# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

train_dataloader는 batch_size가 64이기 때문에 한 번 순회할 때마다 64개의 샘플을 로드합니다.

 

Transforms

https://pytorch.org/vision/stable/transforms.html

 

Transforming and augmenting images — Torchvision 0.14 documentation

Shortcuts

pytorch.org

raw data는 학습에 바로 사용할 수 없기 때문에 전처리가 필요합니다. 이때, transforms를 사용하여 학습에 적절한 데이터로 만들어주는 몇 가지 manipulation을 수행할 수 있습니다.

 

torchvision의 모든 데이터셋은 transform과 target_transform 파라미터를 가지고 있으며, transformation logic을 포함하는 콜러블(callable)을 전달받습니다. torchvision.transforms 모듈은 자주 사용되는 몇 가지 transforms를 제공합니다.

 

FashionMNIST에서 이미지는 PIL Image format이고, 라벨은 정수입니다. 학습을 위해서는 이미지 데이터는 텐서로 normalization해야 하고, 라벨은 one-hot encoding된 텐서로 다음과 같이 변환해주어야 합니다.

import torch
from torchvision import datasets
from torchvision.transforms import ToTensor, Lambda

ds = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor(),
    target_transform=Lambda(lambda y: torch.zeros(10, dtype=torch.float).scatter_(0, torch.tensor(y), value=1))
)

 

여기서 ToTensor()는 [0,255] 범위의 값을 가지는 PIL Image 또는 numpy.ndarray(H x W x C)를 [0,1] 범위의 값을 갖는 torch.FloatTensor(C x H x W)로 변환합니다. 만약 입력 데이터가 이미 [0, 1.0] 범위라면 ToTensor()를 사용하면 안됩니다.

 

댓글