[PyTorch] Softmax + CrossEntropyLoss
Softmax 두번 쓰는 경우... 자칫 분석이 잘못 될 수도 있습니다..!
문제
다중 분류문제에서 보통 손실함수로 `CrossEnropyLoss`를 사용합니다. 이 함수는 보통 모델의 아웃풋 로짓값을 인풋으로 받게끔 구현되어있습니다. 만약 이 로짓에 `softmax`를 적용하여 `CrossEnropyLoss`에 집어넣으면 조금 왜곡된 결과가 나오겠죠..?
그런데 정말 조심해야하고 재미있는점은, 코드는 돌아가며 오류 경고가 뜨지않는 문제라는 점이고, 그 이상한 결과가 `아주 약간`만 이상해 보일 수 있다는 것입니다. 예를들어, 정상적으로 손실함수에 로짓을 집어넣은 모델의 `F1 macro score`가 `0.8` 이라하면, 위와같은 처리를 한 뒤의 결과는 `0.8`과 비슷하거나 별다른 문제가 없을 가능성이 큽니다. 물론 오류가 발견되더라도 비교적 다른 오류적인 결과들에 비해 납득할 만하게 나올 때가 있습니다. 그럼에도 분명 조심해야하는 결과로, 그 이유는 다음과 같습니다.
원인
먼저 이상한 결과가 나오는 이유는 우리가 구현시 명시해놓은 `softmax`와 `CrossEnropyLoss`에 내제된 `softmax`가 두번 연속으로 적용되어 발생하기 때문입니다. 이렇게 되면 일반적으로는 문제가 없지만, 과신하는 모델일 수록 (모델의 결과되는 probability의 특정 값이 극단적으로 1에 가까운 값이 나올 수록) 예측 값 사이의 극단적인 차이게 존재하게 되어 원래의 확률 분포가 왜곡되게 됩니다.
예를들어,
원래 이렇던 output이,
이렇게 바뀔 수 있습니다. 이렇게되면, 피드백이 왜곡되어 올바른 방향으로의 학습이 방해될 수 있습니다.
재미있게도 이런식으로 학습된 모델은 비슷한 데이터 셋에서는 성능이 유지되며, 약간 더 과적합 상태에 이르게됩니다. 사실 큰 문제가 없을 가능성이 더 큽니다만... 이러한 오류를 모른채로 분석을 진행하다가 여러 우연이 겹친다면, 예상과는 아주 다른 결과를 내놓는 경우도 있게 됩니다. 특히 `Uncertainty`나 `Model`의 `Reliability`는 prediction value인 `confidence`를 확인해야하는데, 이 경우 직접적인 타격을 입게되어 결과가 아주 박살납니다...;;
저의 경우에는 두번의 박살이 있었습니다.
첫째로, 낮은 확률을 뚫고 성능이 박살났습니다. `validation set`을 기준으로 `hyperparameter`를 최적화 시키고, `early stopping`을 하여 `epoch`수를 정하고, `train + validation`으로 다시 학습을 시키는 방식을 사용했습니다. `train`, `validation`, `test` 모두 의도된 `covariate shift`를 보이며, multi-class의 `imbalance`가 제법 심한 자료였습니다. 이런 조건들이 맞아 떨어지면, 단지 일반화가 잘 안되었다 수준이 아니라, 그냥 하나도 못 맞추는 수준으로 `test`결과가 나올 수 있습니다.
둘쨰로, 비교적 최근 `Evidential Deep Learning`에 같은 실수를 반복하여, `Uncertainty`의 설명 가능한 부분과 그렇지 못한 부분이 모든 샘플에 대해 같게 나오게되어 사용할 수 없는 결과를 얻는 대참사가 있었습니다.
혹시 나도...?
보통은 실수를 안하시겠지만, 제가 실수하게된 배경을 말씀드리겠습니다. 😢
원래의 실험은 `multi-label`의 경우를 가정하고 진행되었는데, `multi-class`의 경우로 바꾸어 진행하게 되었습니다. 일반적으로 `multi-label`은 `Sigmoid` + `binary_cross_entropy` 조합으로 학습을 시킵니다. 그리고 수업때 외운 `multi-label - Sigmoid, multi-class - Softmax` 대로 그냥, 사용하던 `Sigmoid`를 `Softmax`로 바꾸고 `binary_cross_entropy`를 `CrossEntropyLoss`로 바꾸어 진행하였습니다.
앞서 말씀드린 여러 우연적인 상황에서의 데이터 셋에 대하여 위의 방식으로 모델링 한 뒤, `test set`을 집어넣으니 말도 안되게 성능이 떨어지더군요... 한 `0.7` 기대했는데 `0.003` 이렇게 나오더랍니다.
여러분들은 혹시 바보같지만 재미있던 실수를 하신적 있으신가요? ㅎㅎ
있으시다면 댓글로 공유해주세요~