`JAX` 쓰는 법에 대해 공부하는 문서입니다. `JAX` 문서에서 제공하는 튜토리얼을 순서대로 공부(번역)해나갈 예정입니다. https://docs.jax.dev/en/latest/tutorials.html
Installation
JAX는 NumPy와 같이 어레이 기반으로 연산을 다루는 라이브러리 입니다. 자동 미분과 JIT 을 통해 속도적인 측면에서 높은 성능을 보입니다. 설치는 다음과 같은 커맨드라인을 따라 진행하시면 됩니다.
# for cpu version
pip install jax
# for NVIDAI GPU version
pip install -U "jax[cuda12]"
JAX NumPy
대부분의 JAX는 `jax.numpy` API를 이용하여 사용됩니다. 보통 아래와같이 불러옵니다:
import jax.numpy as jnp
이렇게 불러오면 기존의 NumPy 방식과 비슷하게 코딩이 가능해집니다. 예를들어, SELU 활성화 함수에 대해 다음과 같이 구현 가능합니다:
def selu(x, alpha = 1.67, lmbda=1.05):
return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
x = jnp.arange(5.0)
print(selu(x))
JIT
from jax import random
key = random.key(1701)
x = random.normal(key, (1_000_000,))
%timeit selu(x).block_until_ready()
>>> 18.9 ms ± 739 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
`jax.jit()`을 통해 함수를 컴파일(첫 호출) 시켜놓고, 그 다음 호출부터는 빠르게 실행이 됩니다.
from jax import jit
selu_jit = jit(selu)
_ = selu_jit(x) # 첫 실행시 컴파일도 됨!
%timeit selu_jit(x).block_until_ready()
>>> 8.92 µs ± 75.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
한 2000배는 빠르네요! (뭐지...?) Warm-up 된 상태로 JIT 돌아가서 이런 차이가 난건가 싶기도 합니다.
Derivatives
JAX는 또한 자동 미분 기능을 제공합니다.
from jax import grad
def sum_logistic(x):
return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))
x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)
%timeit derivative_fn(x_small)
>>> 3.14 ms ± 15.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
print(derivative_fn(x_small))
>>> [0.25 0.19661197 0.10499357]
그냥 계산하려면 아래처럼 구현해야합니다.
import numpy as np
def sum_logistic(x):
return np.sum(1.0 / (1.0 + np.exp(-x)))
def numerical_gradient(f, x, eps=1e-6):
grad = np.zeros_like(x)
for i in range(len(x)):
x_eps_plus = x.copy()
x_eps_minus = x.copy()
x_eps_plus[i] += eps
x_eps_minus[i] -= eps
grad[i] = (f(x_eps_plus) - f(x_eps_minus)) / (2 * eps)
return grad
x = np.array([0, 1.0, 2.0])
%timeit numerical_gradient(sum_logistic, x)
>>> 72.5 µs ± 1.32 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
print(numerical_gradient(sum_logistic, x))
>>> [0.25 0.19661193 0.10499359]
대략 40배 정도 빠릅니다! (JIT과 grad를 통한 성능 증가 폭이 좀 다르군요..!)
추가로 이러한 `grad()`와 `jit()`의 조합으로도 계산 할 수 있습니다.
print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
>>> -0.0353256
여기서는 2계 도함수를 계산한 결과를 제시하는군요!
이 외에도 스칼라 값에 대한 계산을 넘어 `jax.jacobian()` 변환을 통해 벡터값을 반환하는 함수에 대해 야코비안 행렬을 계산할 수 있습니다.
from jax import jacobian
%timeit J = jacobian(jnp.exp)(x)
print("JAX Jacobian:\n", J)
>>> 2.18 ms ± 11.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
>>> JAX Jacobian:
[[1. 0. 0. ]
[0. 2.7182817 0. ]
[0. 0. 7.389056 ]]
넘파이로 하면
import numpy as np
def f(x):
return np.exp(x)
def numerical_jacobian(f, x, eps=1e-6):
x = x.astype(float) # ensure float
n = len(x)
J = np.zeros((n, n))
fx = f(x)
for i in range(n):
x_eps = x.copy()
x_eps[i] += eps # 이 상황에서는 어차피 변수가 다르면 편미분시 사라지니까 diagonal에만 변화를 줌
fx_eps = f(x_eps)
J[:, i] = (fx_eps - fx) / eps
return J
# 입력값
x_np = np.array([0, 1.0, 2.0])
%timeit J_np = numerical_jacobian(f, x_np)
print("NumPy Numerical Jacobian:\n", J_np)
>>> 18.4 µs ± 373 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
>>> NumPy Numerical Jacobian:
[[1.0000005 0. 0. ]
[0. 2.71828319 0. ]
[0. 0. 7.38905979]]
대략 120배 정도 빠르네요!
스칼라 나오고, 야코비안 나왔으면 이젠 헤시안 차례겠죠..? ㅎㅎ
우선 제시하는 원문 내용을 이야기하자면, 다음과 같습니다:
보다 심화된 자동 미분 연산을 위해, 역전파 방식의 벡터-야코비안 곱을 계산할 수 있는 `jax.vjp()`와 순전파 방식의 야코비안-벡터 곱을 계산하는 `jax.jvp()` 및 `jax.linearize()`를 사용할 수 있습니다. 이 두 방식은 서로 자유롭게 조합할 수 있으며,
다른 JAX 변환(jit, grad 등)과도 함께 사용할 수 있습니다. 예를 들어, `jax.jvp()`와 `jax.vjp()`는 각각 순전파 및 역전파 방식으로 야코비안을 계산하는 `jax.jacfwd()`와 `jax.jacrev()`의 기반이 됩니다. 아래는 이들을 조합하여 전체 헤시안(Hessian) 행렬을 효율적으로 계산하는 함수를 정의하는 방법 중 하나입니다:
from jax import jacfwd, jacrev
def hessian(fun):
return jit(jacfwd(jacrev(fun)))
print(hessian(sum_logistic)(x_small))
>>> [[-0. -0. -0. ]
[-0. -0.09085776 -0. ]
[-0. -0. -0.07996249]]
JAX 내장 함수 `jax.hessian()`도 이런식으로 구현되어있다고 합니다.
참고로, 위의 구현들에서 eps는 아주 작은수를 나타내는 $\epsilon$으로 이해해주시면 됩니다! 컴퓨터상에서는 무한히 작은 $\epsilon$ 을 구할 수 없으니 (저는 못구해요.. ㅠㅠ) 수치미분으로 매우 작은 $\epsilon$ 만큼 x 변동시 y가 얼마나 변하는지로 미분의 근사값을 추정합니다.
Auto-vectorization
또한 `JAX`는 벡터화가 가능합니다. 벡터화는 루프를 통해 계산하는 것이 아닌 벡터에서의 병렬 계산을 통해 빠른 연산이 가능하게 됩니다. 여기서 제공하는 `vmap()`은 함수를 자연스럽게 벡터화 하며, `jit()`과 결합하면 내적하여 나오는 결과처럼 (manually batched) 아주 강력한 성능을 보인다고 합니다.
먼저 루프기반의 계산입니다.
key1, key2 = random.split(key)
mat = random.normal(key1, (150, 100)) # 150 x 100
batched_x = random.normal(key2, (10, 100)) # 길이 100인 벡터 10개
def apply_matrix(x): # (150 x 100) x (100 x 1) 계산
return jnp.dot(mat, x)
def naively_batched_apply_matrix(v_batched):
return jnp.stack([apply_matrix(v) for v in v_batched])
print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()
>>> Naively batched
>>> 1.04 ms ± 27.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
내적의 결과입니다.
import numpy as np
@jit
def batched_apply_matrix(batched_x): # (10 x 100) x (100 x 150) 계산
return jnp.dot(batched_x, mat.T)
np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()
>>> Manually batched
>>> 30.2 µs ± 1.34 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
빠르긴 하지만, 복잡한 코딩으로 인해 에러가 발생할 가능성이 생깁니다.
`vmap()`을 이용하면 함수 자체를 이러한 빠른 실행이 가능하도록 변환이 가능합니다.
from jax import vmap
@jit
def vmap_batched_apply_matrix(batched_x):
return vmap(apply_matrix)(batched_x)
np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
vmap_batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()
>>> Auto-vectorized with vmap
>>> 37.2 µs ± 716 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
오우, 벡터화 되는 경우에는 300배 정도 빠르군요..!
참고로 여기서 rtol과 atol은 두가지의 결과를 비교시에 아주 작은 오차의 정도를 설정해주는 파라미터입니다.
'코딩 > JAX' 카테고리의 다른 글
[JAX] Automatic Differentiation; Logiostic Regression (0) | 2025.04.09 |
---|---|
[JAX] Automatic Differentiation; jax.grad (0) | 2025.04.09 |
[JAX] Automatic Vectorization (0) | 2025.04.08 |
[JAX] Just-in-time compilation (0) | 2025.04.08 |