[Torch Uncertainty] Tutorial: Deep Evidential Classification on a Toy Example

2024. 8. 23. 00:50·통계 & 머신러닝/통계적 머신러닝

`Torch Uncertainty` 공식문서의 `Tutorials`를 번역한 내용입니다.


Deep Evidential Classification on a Toy Example

이 튜토리얼은 Deep Evidential Classification(DEC)의 개요를 소개하고자 하며, MNIST 데이터셋을 사용하는 Multi-Layer Perceptron(MLP) 신경망 모델을 통해 DEC를 적용하는 방법을 보여줍니다. MLP의 출력은 디리클레(Dirichlet) 분포로 모델링되며, DEC 손실 함수는 베이지안 리스크 제곱 오차 손실과 KL Divergence 기반의 정규화 항으로 구성됩니다.

Training a LeNet with DEC using TorchUncertainty models

다음 단계에서는 이미 TorchUncertainty(TU)에 구현된 모델과 루틴을 기반으로 신경망을 훈련합니다.

1. Loading the utilities

DEC 손실 함수를 사용하여 LeNet을 훈련하기 위해 TorchUncertainty에서 다음 유틸리티를 로드해야 합니다:

  • Lightning의 Trainer
  • `torch_uncertainty.models`의 LeNet 모델
  • `torch_uncertainty.routines`의 분류 훈련 루틴
  • `torch_uncertainty.losses`의 DEC 손실 함수 `DECLoss`
  • 데이터로더와 변환을 처리하는 데이터 모듈: `torch_uncertainty.datamodules`의 `MNISTDataModule`

또한, `torch.optim`을 사용하여 옵티마이저를 정의하고, `torch.nn`에서 신경망 유틸리티를 임포트해야 합니다.

from pathlib import Path

import torch
from lightning.pytorch import Trainer
from torch import nn, optim

from torch_uncertainty.datamodules import MNISTDataModule
from torch_uncertainty.losses import DECLoss
from torch_uncertainty.models.lenet import lenet
from torch_uncertainty.routines import ClassificationRoutine

2. Creating the Optimizer Wrapper

DEC의 공식 구현을 따르며, 기본 학습률 0.001과 스텝 스케줄러를 사용하는 Adam 옵티마이저를 사용합니다.

def optim_lenet(model: nn.Module) -> dict:
    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=0.005)
    exp_lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
    return {"optimizer": optimizer, "lr_scheduler": exp_lr_scheduler}

3. Creating the necessary variables

다음 단계에서는 로그의 루트 경로를 정의하고, PyTorch Lightning Trainer를 사용하기 위해 필요한 인수를 설정합니다. 또한, DEC 논문에서 사용된 것과 동일한 MNIST 분류 예제를 사용합니다. 시간 절약을 위해 3 에포크만 훈련합니다.

trainer = Trainer(accelerator="cpu", max_epochs=3, enable_progress_bar=False)

# datamodule
root = Path() / "data"
datamodule = MNISTDataModule(root=root, batch_size=128)

model = lenet(
    in_channels=datamodule.num_channels,
    num_classes=datamodule.num_classes,
)

4. The Loss and the Training Routine

이제 훈련 중에 사용할 손실 함수를 정의해야 합니다. 그런 다음 `torch_uncertainty.routines.ClassificationRoutine`에서 제공하는 단일 분류 모델 훈련 루틴을 사용하여 훈련 루틴을 정의합니다. 이 루틴에 모델, DEC 손실, 옵티마이저 및 기본 인수를 제공합니다.

loss = DECLoss(reg_weight=1e-2)

routine = ClassificationRoutine(
    model=model,
    num_classes=datamodule.num_classes,
    loss=loss,
    optim_recipe=optim_lenet(model),
)

5. Gathering Everything and Training the Model

이제 모든 것이 준비되었으므로, Lightning Trainer를 사용하여 모델을 훈련합니다.

trainer.fit(model=routine, datamodule=datamodule)
trainer.test(model=routine, datamodule=datamodule)
6. Testing the Model

이제 모델이 훈련되었으므로 MNIST에서 테스트해 보겠습니다. 테스트 단계에서는 MNIST 이미지를 회전시켜 이미지를 확인하고, 예측 및 확신도를 출력합니다.

import matplotlib.pyplot as plt
import numpy as np
import torchvision
import torchvision.transforms.functional as F


def imshow(img) -> None:
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()


def rotated_mnist(angle: int) -> None:
    """Rotate MNIST images and show images and confidence.

    Args:
        angle: Rotation angle in degrees.
    """
    rotated_images = F.rotate(images, angle)
    # print rotated images
    plt.axis("off")
    imshow(torchvision.utils.make_grid(rotated_images[:4, ...]))
    print("Ground truth: ", " ".join(f"{labels[j]}" for j in range(4)))

    evidence = routine(rotated_images)
    alpha = torch.relu(evidence) + 1
    strength = torch.sum(alpha, dim=1, keepdim=True)
    probs = alpha / strength
    entropy = -1 * torch.sum(probs * torch.log(probs), dim=1, keepdim=True)
    for j in range(4):
        predicted = torch.argmax(probs[j, :])
        print(
            f"Predicted digits for the image {j}: {predicted} with strength "
            f"{strength[j,0]:.3} and entropy {entropy[j,0]:.3}."
        )


dataiter = iter(datamodule.val_dataloader())
images, labels = next(dataiter)

with torch.no_grad():
    routine.eval()
    rotated_mnist(0)
    rotated_mnist(45)
    rotated_mnist(90)

Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Ground truth:  7 2 1 0
Predicted digits for the image 0: 7 with strength 99.7 and entropy 0.502.
Predicted digits for the image 1: 0 with strength 10.0 and entropy 2.3.
Predicted digits for the image 2: 1 with strength 99.2 and entropy 0.504.
Predicted digits for the image 3: 0 with strength 61.6 and entropy 0.737.
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Ground truth:  7 2 1 0
Predicted digits for the image 0: 9 with strength 21.9 and entropy 1.58.
Predicted digits for the image 1: 0 with strength 13.6 and entropy 2.1.
Predicted digits for the image 2: 8 with strength 26.1 and entropy 1.72.
Predicted digits for the image 3: 0 with strength 43.1 and entropy 1.2.
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
Ground truth:  7 2 1 0
Predicted digits for the image 0: 0 with strength 42.4 and entropy 1.44.
Predicted digits for the image 1: 4 with strength 37.8 and entropy 1.07.
Predicted digits for the image 2: 7 with strength 42.9 and entropy 0.975.
Predicted digits for the image 3: 9 with strength 38.7 and entropy 1.38.

이 코드에서는 모델이 회전된 이미지를 어떻게 처리하는지 보여줍니다. 이를 통해 예측과 그에 따른 불확실성을 분석할 수 있습니다.

참조

https://torch-uncertainty.github.io/auto_tutorials/tutorial_evidential_classification.html#sphx-glr-auto-tutorials-tutorial-evidential-classification-py

 

Deep Evidential Classification on a Toy Example — TorchUncertainty 0.2.1.post0 documentation

Deep Evidential Classification on a Toy Example This tutorial aims to provide an introductory overview of Deep Evidential Classification (DEC) using a practical example. We demonstrate an application of DEC by tackling the toy-problem of fitting the MNIST

torch-uncertainty.github.io

Deep Evidential Classification: Murat Sensoy, Lance Kaplan, & Melih Kandemir (2018). Evidential Deep Learning to Quantify Classification Uncertainty NeurIPS 2018.

 

'통계 & 머신러닝 > 통계적 머신러닝' 카테고리의 다른 글

[논문] Simple and Scalable Predictive Uncertainty Estimation Using Deep Ensembles: Idea  (0) 2024.10.30
[Torch Uncertainty] Tutorial: Train a Bayesian Neural Network in Three Minutes  (0) 2024.08.23
[Torch Uncertainty] Tutorial: Training a LeNet with Monte Carlo Batch Normalization  (0) 2024.08.22
[Torch Uncertainty] Tutorial: Training a LeNet with Monte-Carlo Dropout  (0) 2024.08.22
[Torch Uncertainty] Tutorial: Improve Top-label Calibration with Temperature Scaling  (0) 2024.08.22
'통계 & 머신러닝/통계적 머신러닝' 카테고리의 다른 글
  • [논문] Simple and Scalable Predictive Uncertainty Estimation Using Deep Ensembles: Idea
  • [Torch Uncertainty] Tutorial: Train a Bayesian Neural Network in Three Minutes
  • [Torch Uncertainty] Tutorial: Training a LeNet with Monte Carlo Batch Normalization
  • [Torch Uncertainty] Tutorial: Training a LeNet with Monte-Carlo Dropout
CDeo
CDeo
잘 부탁해요 ~.~
  • 링크

    • Inter-link
    • LinkedIn
  • CDeo
    Hello World!
    CDeo
  • 공지사항

    • Inter-link
    • 분류 전체보기 (123)
      • 월간 (1)
        • 2024 (1)
      • 논문참여 (2)
      • 통계 & 머신러닝 (47)
        • 피처 엔지니어링 (2)
        • 최적화 (2)
        • 군집화 (5)
        • 공변량 보정 (4)
        • 생물정보통계 모델 (3)
        • 연합학습 (13)
        • 통계적 머신러닝 (13)
        • 논의 (0)
        • 구현 (2)
        • 스터디 (3)
      • 데이터 엔지니어링 (1)
        • 하둡 (1)
      • 코딩 (26)
        • 웹개발 (1)
        • 시각화 (2)
        • 이슈 (8)
        • 노트 (5)
        • PyTorch Lightning (5)
        • JAX (5)
      • 에너지 (2)
        • 뉴스 및 동향 (2)
        • 용어 정리 (0)
      • 기본 이론 (0)
        • 집합론 (0)
        • 그래프 이론 (0)
      • 약리학 (28)
        • 강의 (5)
        • ADMET parameter (16)
        • DDI (4)
        • DTI (0)
      • 생명과학 (1)
        • 분석기술 (1)
      • 일상 (15)
        • 연구일지 (3)
        • 생각 (8)
        • 영화 (1)
        • 동화책 만들기 (1)
        • 요리 (0)
        • 다이어트 (2)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 인기 글

  • 전체
    오늘
    어제
  • hELLO· Designed By정상우.v4.10.1
CDeo
[Torch Uncertainty] Tutorial: Deep Evidential Classification on a Toy Example
상단으로

티스토리툴바