[JAX] Automatic Differentiation; jax.grad

2025. 4. 9. 22:55·코딩/JAX

`[JAX] Automatic Vectorization`에 이어지는 내용입니다. 아래의 공식 문서를 정리(거의 번역)한 내용입니다.

https://docs.jax.dev/en/latest/automatic-differentiation.html


이 포스팅에서는 `JAX`의 자동 미분(autodiff)의 기본적인 활용법에 대해 학습합니다. `JAX`는 일반적인 자동 미분 시스템을 제공합니다. 미분의 계산은 현대 머신러닝에 있어 매우 중요한 부분이며, 튜토리얼에서는 다음과 같은 기본적인 자동 미분 주제를 다룹니다:

  1. `jax.grad`를 사용한 계산 (이번 포스팅)
  2. 로지스틱 회귀에서의 gradient 계산
  3. 중첩 리스트, 튜플, 딕셔너리에 대한 미분
  4. `jax.value_and_grad`를 사용하여 함수와 그 그래디언트 평가
  5. 수치 미분과의 비교

더 심화된 주제는 고급 자동 미분 튜토리얼에 있습니다.

 

자동 미분이 내부적으로 어찌 작동하는지에 대한 이해는 꼭 필수적이지는 않지만, 깊은 이해와 특수한 상황에서 필요합니다. 영상 참고: https://www.youtube.com/watch?v=wG_nF1awSSY

`jax.grad`를 사용한 계산

`JAX`에서는 `jax.grad()`를 통해 스칼라 값을 반환하는 함수를 미분할 수 있습니다:

import jax
import jax.numpy as jnp
from jax import grad

grad_tanh = grad(jnp.tanh)
print(grad_tanh(2.0))

>>> 0.070650816

`jax.grad()`는 함수를 입력받아 그 함수의 그래디언트를 계산하는 새로운 함수를 반환합니다. 즉, 파이썬 함수 `f`가 수학적 함수 $f$를 평가한다면, `jax.grad(f)`는 $\nabla{f}$를 평가하는 파이썬 함수입니다. 이를테면, `grad(f)(x)`는 $\nabla{f(x)}$의 값을 나타냅니다.

print(grad(grad(jnp.tanh))(2.0))
print(grad(grad(grad(jnp.tanh)))(2.0))

>>> -0.13621868
>>> 0.25265405

`JAX`에서의 자동미분은 미분함수에 대한 미분도, 그 함수에 대한 미분도 역시 가능하기 때문에 고차 도함수를 쉽게 구할 수 있습니다. 예를들어 $f(x)=x^3+2x^2-3x+1$의 도함수는 다음과같이 계산이 가능합니다:

f = lambda x: x**3 + 2*x**2 - 3*x + 1

dfdx = jax.grad(f)

이 함수에 대해 아래와 같은 고차 도함수가 계산됩니다:

 

$$f′(x)=3x^2+4x−3$$

$$f′′(x)=6x+4$$

$$f′′′(x)=6$$

$$f^{iv}(x)=0$$

아래는 `grad()`를 통해 도함수를 쉽게 계산할 수 있게됩니다:

d2fdx = jax.grad(dfdx)
d3fdx = jax.grad(d2fdx)
d4fdx = jax.grad(d3fdx)

 

위의 함수를 $x=1$에서 값을 구하면:

$$f′(1)=4$$

$$f′′(1)=10$$

$$f′′′(1)=6$$

$$f^{iv}(1)=0$$

`JAX`도 일치하는지 볼까요?

print(dfdx(1.))
print(d2fdx(1.))
print(d3fdx(1.))
print(d4fdx(1.))

>>> 4.0
>>> 10.0
>>> 6.0
>>> 0.0

'코딩 > JAX' 카테고리의 다른 글

[JAX] Automatic Differentiation; Logiostic Regression  (0) 2025.04.09
[JAX] Automatic Vectorization  (0) 2025.04.08
[JAX] Just-in-time compilation  (0) 2025.04.08
[JAX] Quick Start  (0) 2025.04.07
'코딩/JAX' 카테고리의 다른 글
  • [JAX] Automatic Differentiation; Logiostic Regression
  • [JAX] Automatic Vectorization
  • [JAX] Just-in-time compilation
  • [JAX] Quick Start
CDeo
CDeo
잘 부탁해요 ~.~
  • 링크

    • Inter-link
    • LinkedIn
  • CDeo
    Hello World!
    CDeo
  • 공지사항

    • Inter-link
    • 분류 전체보기 (121)
      • 월간 (1)
        • 2024 (1)
      • 논문참여 (2)
      • 통계 & 머신러닝 (47)
        • 피처 엔지니어링 (2)
        • 최적화 (2)
        • 군집화 (5)
        • 공변량 보정 (4)
        • 생물정보통계 모델 (3)
        • 연합학습 (13)
        • 통계적 머신러닝 (13)
        • 논의 (0)
        • 구현 (2)
        • 스터디 (3)
      • 데이터 엔지니어링 (1)
        • 하둡 (1)
      • 코딩 (26)
        • 웹개발 (1)
        • 시각화 (2)
        • 이슈 (8)
        • 노트 (5)
        • PyTorch Lightning (5)
        • JAX (5)
      • 기본 이론 (0)
        • 집합론 (0)
        • 그래프 이론 (0)
      • 약리학 (28)
        • 강의 (5)
        • ADMET parameter (16)
        • DDI (4)
        • DTI (0)
      • 생명과학 (1)
        • 분석기술 (1)
      • 일상 (15)
        • 연구일지 (3)
        • 생각 (8)
        • 영화 (1)
        • 동화책 만들기 (1)
        • 요리 (0)
        • 다이어트 (2)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 인기 글

  • 전체
    오늘
    어제
  • hELLO· Designed By정상우.v4.10.1
CDeo
[JAX] Automatic Differentiation; jax.grad
상단으로

티스토리툴바