티스토리 뷰

728x90
반응형

PyTorch는 딥러닝 프레임워크 중에서 가장 인기 있는 도구 중 하나로, 특히 연구와 실험에서 많이 사용됩니다. 이 PyTorch 생태계에서 컴퓨터 비전 프로젝트에 필수적인 역할을 하는 것이 바로 torchvision입니다. 이번 글에서는 torchvision이 무엇인지, 어떤 기능을 제공하는지, 그리고 각 모듈의 활용법까지 최대한 자세히 정리하겠습니다.


📦 torchvision이란?

torchvision은 PyTorch에서 이미지 관련 작업을 보다 편리하게 처리할 수 있도록 만들어진 부가 라이브러리입니다. 컴퓨터 비전 프로젝트에서 반복적으로 사용되는 데이터셋, 모델, 전처리(transform), 이미지 유틸리티 등을 간단하게 사용할 수 있게 도와줍니다.

쉽게 말하면, **"이미지 데이터를 불러오고, 전처리하고, 유명한 pretrained 모델까지 쉽게 가져와서 학습에 바로 활용할 수 있게 하는 도구 모음"**이라고 할 수 있습니다.


🗂️ torchvision의 주요 모듈

torchvision은 크게 네 가지 핵심 모듈로 나뉩니다.

1️⃣ torchvision.datasets

  • 유명한 공개 데이터셋을 간단히 불러올 수 있도록 지원합니다.
  • 주로 학습/테스트용 이미지 데이터셋을 다운로드하고 관리합니다.

예시:

from torchvision import datasets
cifar10 = datasets.CIFAR10(root='./data', train=True, download=True)

지원하는 대표 데이터셋:

  • CIFAR-10, CIFAR-100
  • MNIST, FashionMNIST
  • ImageNet, COCO, VOC, CelebA 등

2️⃣ torchvision.models

  • 유명한 이미지 분류, 객체 탐지, 분할 모델의 pretrained 버전을 제공합니다.
  • 학습된 weight를 활용해 transfer learning이나 fine-tuning을 빠르게 시작할 수 있습니다.

예시:

from torchvision import models
resnet18 = models.resnet18(pretrained=True)

지원하는 대표 모델:

  • ResNet, AlexNet, VGG, DenseNet, Inception, MobileNet, EfficientNet
  • Faster R-CNN, Mask R-CNN, SSD (객체 탐지용)

3️⃣ torchvision.transforms

  • 이미지 전처리 및 데이터 증강(Augmentation)을 쉽게 적용할 수 있는 모듈입니다.
  • 학습 안정성 향상과 일반화 성능을 높이는 데 필수적입니다.

예시:

from torchvision import transforms
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

주요 transform 함수:

  • ToTensor() : PIL → Tensor, [0,255] → [0,1]
  • Normalize(mean, std) : 평균·표준편차 정규화
  • Resize(), CenterCrop() : 크기 조정, 중앙 자르기
  • RandomHorizontalFlip(), RandomRotation() : 데이터 증강

4️⃣ torchvision.utils

  • 이미지 그리드 만들기, 이미지 저장 등 편리한 시각화 및 유틸리티 함수 제공

예시:

from torchvision.utils import make_grid

주요 함수:

  • make_grid() : 여러 이미지를 그리드로 묶음
  • save_image() : Tensor 이미지를 파일로 저장

🔍 torchvision 사용 시 주의할 점 및 팁

  1. 데이터셋 별 root 디렉토리 관리
    • 여러 실험에서 root 경로가 꼬이지 않도록 프로젝트별 폴더를 구분하세요.
  2. transforms는 학습/테스트 분리
    • 학습용은 데이터 증강 포함 (예: RandomCrop, RandomFlip)
    • 테스트용은 크기 맞춤, 정규화만 (예: Resize, Normalize)
  3. pretrained 모델은 ImageNet 기준
    • pretrained=True로 불러온 모델은 ImageNet 학습 weight를 사용합니다.
    • 다른 데이터셋에서 fine-tuning할 때는 마지막 layer 교체 필요.
  4. GPU 사용 최적화
    • DataLoader에서 num_workers와 pin_memory를 적절히 설정하면 데이터 로딩 속도가 빨라집니다.

💡 torchvision으로 할 수 있는 실전 예제들

  • ResNet18을 CIFAR-10에 fine-tuning
  • Faster R-CNN으로 COCO 데이터셋 객체 탐지
  • 데이터셋 augmentation 실험 (transforms.RandomErasing, ColorJitter)
  • make_grid로 학습 샘플 시각화

✨ 마무리

torchvision은 PyTorch를 이용한 컴퓨터 비전 연구 및 개발에서 사실상 표준 라이브러리로 자리 잡았습니다. 데이터셋, 모델, 전처리, 유틸리티까지 한 번에 다룰 수 있어 학습 속도를 높이고, 실험의 질을 높여줍니다.

이 글을 통해 torchvision의 기본 개념과 활용법을 명확히 이해하고, 앞으로 실험에 적극적으로 활용해보시길 바랍니다. 더 깊은 코드 예제나 프로젝트별 활용법이 필요하다면 댓글로 알려주세요!

728x90
반응형