`[JAX] Automatic Vectorization`에 이어지는 내용입니다. 아래의 공식 문서를 정리(거의 번역)한 내용입니다.
https://docs.jax.dev/en/latest/automatic-differentiation.html
이 포스팅에서는 `JAX`의 자동 미분(autodiff)의 기본적인 활용법에 대해 학습합니다. `JAX`는 일반적인 자동 미분 시스템을 제공합니다. 미분의 계산은 현대 머신러닝에 있어 매우 중요한 부분이며, 튜토리얼에서는 다음과 같은 기본적인 자동 미분 주제를 다룹니다:
- `jax.grad`를 사용한 계산 (이번 포스팅)
- 로지스틱 회귀에서의 gradient 계산
- 중첩 리스트, 튜플, 딕셔너리에 대한 미분
- `jax.value_and_grad`를 사용하여 함수와 그 그래디언트 평가
- 수치 미분과의 비교
더 심화된 주제는 고급 자동 미분 튜토리얼에 있습니다.
자동 미분이 내부적으로 어찌 작동하는지에 대한 이해는 꼭 필수적이지는 않지만, 깊은 이해와 특수한 상황에서 필요합니다. 영상 참고: https://www.youtube.com/watch?v=wG_nF1awSSY
`jax.grad`를 사용한 계산
`JAX`에서는 `jax.grad()`를 통해 스칼라 값을 반환하는 함수를 미분할 수 있습니다:
import jax
import jax.numpy as jnp
from jax import grad
grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.0))
>>> 0.070650816
`jax.grad()`는 함수를 입력받아 그 함수의 그래디언트를 계산하는 새로운 함수를 반환합니다. 즉, 파이썬 함수 `f`가 수학적 함수 $f$를 평가한다면, `jax.grad(f)`는 $\nabla{f}$를 평가하는 파이썬 함수입니다. 이를테면, `grad(f)(x)`는 $\nabla{f(x)}$의 값을 나타냅니다.
print(grad(grad(jnp.tanh))(2.0))
print(grad(grad(grad(jnp.tanh)))(2.0))
>>> -0.13621868
>>> 0.25265405
`JAX`에서의 자동미분은 미분함수에 대한 미분도, 그 함수에 대한 미분도 역시 가능하기 때문에 고차 도함수를 쉽게 구할 수 있습니다. 예를들어 $f(x)=x^3+2x^2-3x+1$의 도함수는 다음과같이 계산이 가능합니다:
f = lambda x: x**3 + 2*x**2 - 3*x + 1
dfdx = jax.grad(f)
이 함수에 대해 아래와 같은 고차 도함수가 계산됩니다:
$$f′(x)=3x^2+4x−3$$
$$f′′(x)=6x+4$$
$$f′′′(x)=6$$
$$f^{iv}(x)=0$$
d2fdx = jax.grad(dfdx)
d3fdx = jax.grad(d2fdx)
d4fdx = jax.grad(d3fdx)
위의 함수를 $x=1$에서 값을 구하면:
$$f′(1)=4$$
$$f′′(1)=10$$
$$f′′′(1)=6$$
$$f^{iv}(1)=0$$
print(dfdx(1.))
print(d2fdx(1.))
print(d3fdx(1.))
print(d4fdx(1.))
>>> 4.0
>>> 10.0
>>> 6.0
>>> 0.0
'코딩 > JAX' 카테고리의 다른 글
[JAX] Automatic Differentiation; Logiostic Regression (0) | 2025.04.09 |
---|---|
[JAX] Automatic Vectorization (0) | 2025.04.08 |
[JAX] Just-in-time compilation (0) | 2025.04.08 |
[JAX] Quick Start (0) | 2025.04.07 |