`[ML with JAX] XOR 분류기`에 이어지는 내용입니다.
`MLP(Multi-Layer Perceptron)`는 비선형 분류 문제를 해결할 수 있도록 해줍니다. 이는 결국 문제에 대해 적절한 `가중치`와 `바이어스(절편)`가 존재한다는 전제에서 출발합니다. 이전 포스팅에서 다룬 논리 연산 이진 분류 문제에서는 데이터의 위치가 명확했기 때문에 적당히 계산한 가중치로도 해결이 가능했습니다. 하지만 실생활 문제에서는 데이터도 가지가지 양상을 보이고 이러한 상황을 반영하는 가중치를 직접 계산하기 어렵기 때문에, 주어진 데이터를 바탕으로 결과를 예측하고, 오차를 확인한 뒤, 수정하는 과정을 반복하여 가장 적합한 가중치를 추정합니다.
Loss function
이 오차를 측정하기 위해 사용하는 것이 바로 `손실함수(loss function)`입니다. 손실함수는 추정한 결과와 실제값 간의 차이를 정량화하며, 이 값을 최소화하는 방향으로 가중치를 학습하게 됩니다. 문제에 따라 손실을 적합하게 표현하는 여러 손실함수가 존재하게 됩니다. 디테일한 식의 의미는 따로 정리할 예정입니다
Classification
분류 문제는 "Yes or No" 와 같은 이진 분류문제와 "건강기능식품 vs. 일반의약품 vs. 전문의약품" 과 같은 다중 분류문제로 나뉩니다. 다중 분류문제 중에서도 순서를 갖는 서수형 분류문제도 있습니다 ("정상 vs. 당뇨 전단계 vs. 당뇨병"). 기본적으로 정보이론에서의 entropy라는 개념을 이용하여 손실함수를 구성합니다. 구체적인 내용은 첨부된 스크립트에 있습니다~!
이진분류에서는 Binary Cross-Entropy(BCE)를 이용합니다:
$$\mathrm{BCE} = -\frac{1}{n} \sum_{i=1}^{n} \left[ y_i \log(\hat{y}_i) + (1 - y_i) \log(1 - \hat{y}_i) \right]$$
여기서의 $n$은 샘플 수, $y_i$는 0 또는 1의 값을 갖는 실제 정답이며, $\hat{y_i}$는 예측을 위함 확률값을 의미합니다.
@jax.jit
def jax_bce(y, y_hat):
eps = 1e-8
return jnp.mean(- (y * jnp.log(y_hat + eps) + (1 - y) * jnp.log(1 - y_hat + eps)))
다중분류의 경우 Categorical Cross-Entropy(CCE)를 이용합니다:
$$\mathrm{CCE} = -\frac{1}{n} \sum_{i=1}^{n} \sum_{j=1}^{C} y_{ij} \log(\hat{y}_{ij})$$
여기서 $C$는 총 클래스의 수, $y_{ij}$와 $\hat{y}_{ij}$는 실제의 클래스 값과 예측된 값으로 정수가 아니라 보통 원핫인코딩된 (더미코딩) 값이 들어갑니다.
@jax.jit
def jax_cce(y, y_hat):
eps = 1e-8
return jnp.mean(-jnp.log(y_hat[jnp.arange(y.shape[0]), y] + eps))
Regression
회귀의 경우 "아이스크림 판매량" 과 같이 특정한 값을 예측하여 제시는 문제입니다. 가장 기본적으로 사용되는 손실은 Mean Squared Error(MSE)입니다:
$$\mathrm{MSE} = \frac{1}{n} \sum_{i=1}^{n} (y_{ij} - \hat{y}_{ij})^2$$
@jax.jit
def jax_mse(y, y_hat):
return jnp.mean((y - y_hat) ** 2)
비교
NumPy와 `PyTorch`를 통한 `JAX`와의 시간 비교를 해보겠습니다. 실험설정은, 공통적으로 1000000개의 샘플에 대해 진행하였으며, 다중분류에서는 5개의 범주를 설정하였습니다. 연산의 비교를 위해 `JAX`의 asynchronous (비동기) 실행을 불활성화 하고 잰 결과입니다.
PyTorch
torch_mse: 0.008488 sec
torch_bce: 0.009204 sec
torch_cce: 0.016950 sec
JAX
첫 실행(컴파일 비용 포함):
jax_mse: 0.120055 sec
jax_bce: 0.081467 sec
jax_cce: 0.099774 sec
이후의 실행:
jax_mse: 0.000133 sec
jax_bce: 0.000056 sec
jax_cce: 0.000053 sec
역시 `JAX` 가 컴파일 비용은 크지만, 그 이후의 속도는 훨~~씬 빠릅니다!
스크립트는 아래의 첨부를 참고해주세요~
'통계 & 머신러닝 > 구현' 카테고리의 다른 글
[ML with JAX] XOR 분류기 (0) | 2025.04.03 |
---|