[JAX] Quick Start

2025. 4. 7. 23:49·코딩/JAX

`JAX` 쓰는 법에 대해 공부하는 문서입니다. `JAX` 문서에서 제공하는 튜토리얼을 순서대로 공부(번역)해나갈 예정입니다. https://docs.jax.dev/en/latest/tutorials.html


Installation

JAX는 NumPy와 같이 어레이 기반으로 연산을 다루는 라이브러리 입니다. 자동 미분과 JIT 을 통해 속도적인 측면에서 높은 성능을 보입니다. 설치는 다음과 같은 커맨드라인을 따라 진행하시면 됩니다.

# for cpu version
pip install jax

# for NVIDAI GPU version
pip install -U "jax[cuda12]"

JAX NumPy

대부분의 JAX는 `jax.numpy` API를 이용하여 사용됩니다. 보통 아래와같이 불러옵니다:

import jax.numpy as jnp

이렇게 불러오면 기존의 NumPy 방식과 비슷하게 코딩이 가능해집니다. 예를들어, SELU 활성화 함수에 대해 다음과 같이 구현 가능합니다:

def selu(x, alpha = 1.67, lmbda=1.05):
	return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)
    
x = jnp.arange(5.0)
print(selu(x))

JIT

from jax import random

key = random.key(1701)
x = random.normal(key, (1_000_000,))
%timeit selu(x).block_until_ready()
>>> 18.9 ms ± 739 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

`jax.jit()`을 통해 함수를 컴파일(첫 호출) 시켜놓고, 그 다음 호출부터는 빠르게 실행이 됩니다.

from jax import jit

selu_jit = jit(selu)
_ = selu_jit(x) # 첫 실행시 컴파일도 됨!
%timeit selu_jit(x).block_until_ready()
>>> 8.92 µs ± 75.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

한 2000배는 빠르네요! (뭐지...?) Warm-up 된 상태로 JIT 돌아가서 이런 차이가 난건가 싶기도 합니다.

Derivatives

JAX는 또한 자동 미분 기능을 제공합니다.

from jax import grad

def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))

x_small = jnp.arange(3.)
derivative_fn = grad(sum_logistic)

%timeit derivative_fn(x_small)
>>> 3.14 ms ± 15.6 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

print(derivative_fn(x_small))
>>> [0.25       0.19661197 0.10499357]

그냥 계산하려면 아래처럼 구현해야합니다.

import numpy as np

def sum_logistic(x):
    return np.sum(1.0 / (1.0 + np.exp(-x)))

def numerical_gradient(f, x, eps=1e-6):
    grad = np.zeros_like(x)
    for i in range(len(x)):
        x_eps_plus = x.copy()
        x_eps_minus = x.copy()
        x_eps_plus[i] += eps
        x_eps_minus[i] -= eps
        grad[i] = (f(x_eps_plus) - f(x_eps_minus)) / (2 * eps)
    return grad

x = np.array([0, 1.0, 2.0])

%timeit numerical_gradient(sum_logistic, x)
>>> 72.5 µs ± 1.32 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

print(numerical_gradient(sum_logistic, x))
>>> [0.25       0.19661193 0.10499359]

대략 40배 정도 빠릅니다! (JIT과 grad를 통한 성능 증가 폭이 좀 다르군요..!)

 

추가로 이러한 `grad()`와 `jit()`의 조합으로도 계산 할 수 있습니다.

print(grad(jit(grad(jit(grad(sum_logistic)))))(1.0))
>>> -0.0353256

여기서는 2계 도함수를 계산한 결과를 제시하는군요!

 

이 외에도 스칼라 값에 대한 계산을 넘어 `jax.jacobian()` 변환을 통해 벡터값을 반환하는 함수에 대해 야코비안 행렬을 계산할 수 있습니다.

from jax import jacobian

%timeit J = jacobian(jnp.exp)(x)
print("JAX Jacobian:\n", J)

>>> 2.18 ms ± 11.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
>>> JAX Jacobian:
     [[1.        0.        0.       ]
     [0.        2.7182817 0.       ]
     [0.        0.        7.389056 ]]

넘파이로 하면

import numpy as np

def f(x):
    return np.exp(x)

def numerical_jacobian(f, x, eps=1e-6):
    x = x.astype(float)  # ensure float
    n = len(x)
    J = np.zeros((n, n))
    fx = f(x)
    for i in range(n):
        x_eps = x.copy()
        x_eps[i] += eps # 이 상황에서는 어차피 변수가 다르면 편미분시 사라지니까 diagonal에만 변화를 줌
        fx_eps = f(x_eps)
        J[:, i] = (fx_eps - fx) / eps
    return J

# 입력값
x_np = np.array([0, 1.0, 2.0])
%timeit J_np = numerical_jacobian(f, x_np)

print("NumPy Numerical Jacobian:\n", J_np)

>>> 18.4 µs ± 373 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
>>> NumPy Numerical Jacobian:
     [[1.0000005  0.         0.        ]
     [0.         2.71828319 0.        ]
     [0.         0.         7.38905979]]

대략 120배 정도 빠르네요! 

 

스칼라 나오고, 야코비안 나왔으면 이젠 헤시안 차례겠죠..? ㅎㅎ

 

우선 제시하는 원문 내용을 이야기하자면, 다음과 같습니다:

보다 심화된 자동 미분 연산을 위해, 역전파 방식의 벡터-야코비안 곱을 계산할 수 있는 `jax.vjp()`와 순전파 방식의 야코비안-벡터 곱을 계산하는 `jax.jvp()` 및 `jax.linearize()`를 사용할 수 있습니다. 이 두 방식은 서로 자유롭게 조합할 수 있으며,
다른 JAX 변환(jit, grad 등)과도 함께 사용할 수 있습니다. 예를 들어, `jax.jvp()`와 `jax.vjp()`는 각각 순전파 및 역전파 방식으로 야코비안을 계산하는 `jax.jacfwd()`와 `jax.jacrev()`의 기반이 됩니다. 아래는 이들을 조합하여 전체 헤시안(Hessian) 행렬을 효율적으로 계산하는 함수를 정의하는 방법 중 하나입니다:

from jax import jacfwd, jacrev
def hessian(fun):
    return jit(jacfwd(jacrev(fun)))

print(hessian(sum_logistic)(x_small))

>>> [[-0.         -0.         -0.        ]
     [-0.         -0.09085776 -0.        ]
     [-0.         -0.         -0.07996249]]

JAX 내장 함수 `jax.hessian()`도 이런식으로 구현되어있다고 합니다.

 

참고로, 위의 구현들에서 eps는 아주 작은수를 나타내는 $\epsilon$으로 이해해주시면 됩니다! 컴퓨터상에서는 무한히 작은 $\epsilon$ 을 구할 수 없으니 (저는 못구해요.. ㅠㅠ) 수치미분으로 매우 작은 $\epsilon$ 만큼 x 변동시 y가 얼마나 변하는지로 미분의 근사값을 추정합니다.

Auto-vectorization

또한 `JAX`는 벡터화가 가능합니다. 벡터화는 루프를 통해 계산하는 것이 아닌 벡터에서의 병렬 계산을 통해 빠른 연산이 가능하게 됩니다. 여기서 제공하는 `vmap()`은 함수를 자연스럽게 벡터화 하며, `jit()`과 결합하면 내적하여 나오는 결과처럼 (manually batched) 아주 강력한 성능을 보인다고 합니다.

 

먼저 루프기반의 계산입니다.

key1, key2 = random.split(key)
mat = random.normal(key1, (150, 100)) # 150 x 100
batched_x = random.normal(key2, (10, 100)) # 길이 100인 벡터 10개

def apply_matrix(x): # (150 x 100) x (100 x 1) 계산
      return jnp.dot(mat, x)
      
def naively_batched_apply_matrix(v_batched):
    return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

>>> Naively batched
>>> 1.04 ms ± 27.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

내적의 결과입니다.

import numpy as np

@jit
def batched_apply_matrix(batched_x): # (10 x 100) x (100 x 150) 계산
    return jnp.dot(batched_x, mat.T)

np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
                           batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

>>> Manually batched
>>> 30.2 µs ± 1.34 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

빠르긴 하지만, 복잡한 코딩으로 인해 에러가 발생할 가능성이 생깁니다.

 

`vmap()`을 이용하면 함수 자체를 이러한 빠른 실행이 가능하도록 변환이 가능합니다.

from jax import vmap

@jit
def vmap_batched_apply_matrix(batched_x):
    return vmap(apply_matrix)(batched_x)

np.testing.assert_allclose(naively_batched_apply_matrix(batched_x),
                           vmap_batched_apply_matrix(batched_x), atol=1E-4, rtol=1E-4)
print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

>>> Auto-vectorized with vmap
>>> 37.2 µs ± 716 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

오우, 벡터화 되는 경우에는 300배 정도 빠르군요..!

참고로 여기서 rtol과 atol은 두가지의 결과를 비교시에 아주 작은 오차의 정도를 설정해주는 파라미터입니다.

'코딩 > JAX' 카테고리의 다른 글

[JAX] Automatic Differentiation; Logiostic Regression  (0) 2025.04.09
[JAX] Automatic Differentiation; jax.grad  (0) 2025.04.09
[JAX] Automatic Vectorization  (0) 2025.04.08
[JAX] Just-in-time compilation  (0) 2025.04.08
'코딩/JAX' 카테고리의 다른 글
  • [JAX] Automatic Differentiation; Logiostic Regression
  • [JAX] Automatic Differentiation; jax.grad
  • [JAX] Automatic Vectorization
  • [JAX] Just-in-time compilation
CDeo
CDeo
잘 부탁해요 ~.~
  • 링크

    • Inter-link
    • LinkedIn
  • CDeo
    Hello World!
    CDeo
  • 공지사항

    • Inter-link
    • 분류 전체보기 (121)
      • 월간 (1)
        • 2024 (1)
      • 논문참여 (2)
      • 통계 & 머신러닝 (47)
        • 피처 엔지니어링 (2)
        • 최적화 (2)
        • 군집화 (5)
        • 공변량 보정 (4)
        • 생물정보통계 모델 (3)
        • 연합학습 (13)
        • 통계적 머신러닝 (13)
        • 논의 (0)
        • 구현 (2)
        • 스터디 (3)
      • 데이터 엔지니어링 (1)
        • 하둡 (1)
      • 코딩 (26)
        • 웹개발 (1)
        • 시각화 (2)
        • 이슈 (8)
        • 노트 (5)
        • PyTorch Lightning (5)
        • JAX (5)
      • 기본 이론 (0)
        • 집합론 (0)
        • 그래프 이론 (0)
      • 약리학 (28)
        • 강의 (5)
        • ADMET parameter (16)
        • DDI (4)
        • DTI (0)
      • 생명과학 (1)
        • 분석기술 (1)
      • 일상 (15)
        • 연구일지 (3)
        • 생각 (8)
        • 영화 (1)
        • 동화책 만들기 (1)
        • 요리 (0)
        • 다이어트 (2)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 인기 글

  • 전체
    오늘
    어제
  • hELLO· Designed By정상우.v4.10.1
CDeo
[JAX] Quick Start
상단으로

티스토리툴바