[JAX] Automatic Vectorization

2025. 4. 8. 07:25·코딩/JAX

`[JAX] Just-in-time compilation`에 이어지는 내용입니다. 아래의 공식문서를 공부(거의 번역)하며 기록합니다.

http://docs.jax.dev/en/latest/automatic-vectorization.html


이번 포스팅에는 `jax.vmap()`을 통한 벡터화에 대해 설명하겠습니다.

Manual vectorization

다음은 1차원 벡터 두개의 convolution을 계산하는 코드입니다:

import jax
import jax.numpy as jnp

x = jnp.arange(5)
w = jnp.array([2., 3., 4.])

def convolve(x, w):
  output = []
  for i in range(1, len(x)-1):
    output.append(jnp.dot(x[i-1:i+2], w))
  return jnp.array(output)

convolve(x, w)

>>> Array([11., 20., 29.], dtype=float32)

이제 여러 개의 `x`, 여러 개의 `w`로 이루어진 배치 데이터를 대상으로 이 함수를 적용하고 싶다고 가정하겠습니다:

xs = jnp.stack([x, x])
ws = jnp.stack([w, w])

가장 단순한 방법은 Python에서 루프를 돌리는 것입니다:

The most naive option would be to simply loop over the batch in Python:

def manually_batched_convolve(xs, ws):
  output = []
  for i in range(xs.shape[0]):
    output.append(convolve(xs[i], ws[i]))
  return jnp.stack(output)

manually_batched_convolve(xs, ws)

>>> Array([[11., 20., 29.],
    	   [11., 20., 29.]], dtype=float32)

결과는 맞지만, 속도가 느립니다. 그림으로 보자면 붉은 영역만큼씩 두 행에 대해 (for을 통해) 순차적으로 적용되는 양상입니다.

보다 효율적인 벡터화를 위해서는 함수를 재구성해야합니다. 이 경우에는 인덱싱, 축 처리방식을 바꾸는 작업이 풀요합니다. 예를들어, 다음은 수동으로 벡터화된 버전입니다:

def manually_vectorized_convolve(xs, ws):
  output = []
  for i in range(1, xs.shape[-1] -1):
    output.append(jnp.sum(xs[:, i-1:i+2] * ws, axis=1))
  return jnp.stack(output, axis=1)

manually_vectorized_convolve(xs, ws)

>>> Array([[11., 20., 29.],
       	   [11., 20., 29.]], dtype=float32)

이 경우 빠르긴합니다. 그림으로 보자면 두 행 모두에 대해 행렬곱으로 한번에 연산이 진행되기 때문이죠.

이런식으로 벡터화 하여 연산하면 좋긴하지만, 구현이 너무 복잡하여 오류의 위험이 있습니다.

Automatic vectorization

이를 해결하기 위해 `JAX`에서는 `jax.vmap()`을 이용합니다. 

auto_batch_convolve = jax.vmap(convolve)

auto_batch_convolve(xs, ws)

>>> Array([[11., 20., 29.],
           [11., 20., 29.]], dtype=float32)

이 함수는 내부적으로 `jax.jit()`처럼 함수를 `추적(tracing)`하여, 각 입력에 대해 batch 차원(축)을 자동으로 추가해 벡터화를 수행합니다. 만약 batch의 쌓이는 방향(축)이 첫 번째가 아닌 경우, `in_axes`, `out_axes` 인자를 사용해 batch 차원의 위치를 지정할 수 있습니다. 

auto_batch_convolve_v2 = jax.vmap(convolve, in_axes=1, out_axes=1)

xst = jnp.transpose(xs)
wst = jnp.transpose(ws)

auto_batch_convolve_v2(xst, wst)

>>> Array([[11., 11.],
           [20., 20.],
           [29., 29.]], dtype=float32)

위의 스크립트에서 `in_axes=1`는 입력 텐서에서의 두번째 축(axis =1)이 batch 쌓인 방향이라는 의미입니다. `out_axes=1`의 경우 출력도 두번째 축을 배치로 설정한다는 뜻입니다. `axes` 0, 1 관계는 transpose 관계로 생각하시면 편합니다.

 

`jax.vmap()`은 입력 인자 중 일부만 배치 처리된 경우도 연산을 지원합니다. 예를 들어, 하나의 weight 벡터 w에 대해 여러 개의 입력 벡터 x를 컨볼루션하고 싶을 때 사용할 수 있습니다. 이 경우 `in_axes`인자를 다음과 같이 설정합니다:

batch_convolve_v3 = jax.vmap(convolve, in_axes=[0, None])

batch_convolve_v3(xs, w)

>>> Array([[11., 20., 29.],
           [11., 20., 29.]], dtype=float32)

`in_axes=[0, None]` 첫 인자 `xs`는 배치차원을 따라 벡터화 되었으며, 두번째 인자 `w`는 벡터화되지 않음을 명시하는 내용입니다.

Combining transformations

다른 JAX 변환 함수들과 마찬가지로, `jax.jit()`와 `jax.vmap()`도 서로 조합하여 사용할 수 있도록 설계되어 있습니다.

즉, vmap된 함수에 jit을 감싸거나, 반대로 jit된 함수에 vmap을 감싸는 것 모두 가능하며, 정상적으로 동작합니다.

jitted_batch_convolve = jax.jit(auto_batch_convolve)

jitted_batch_convolve(xs, ws)

>>> Array([[11., 20., 29.],
           [11., 20., 29.]], dtype=float32)

 

'코딩 > JAX' 카테고리의 다른 글

[JAX] Automatic Differentiation; Logiostic Regression  (0) 2025.04.09
[JAX] Automatic Differentiation; jax.grad  (0) 2025.04.09
[JAX] Just-in-time compilation  (0) 2025.04.08
[JAX] Quick Start  (0) 2025.04.07
'코딩/JAX' 카테고리의 다른 글
  • [JAX] Automatic Differentiation; Logiostic Regression
  • [JAX] Automatic Differentiation; jax.grad
  • [JAX] Just-in-time compilation
  • [JAX] Quick Start
CDeo
CDeo
잘 부탁해요 ~.~
  • 링크

    • Inter-link
    • LinkedIn
  • CDeo
    Hello World!
    CDeo
  • 공지사항

    • Inter-link
    • 분류 전체보기 (123)
      • 월간 (1)
        • 2024 (1)
      • 논문참여 (2)
      • 통계 & 머신러닝 (47)
        • 피처 엔지니어링 (2)
        • 최적화 (2)
        • 군집화 (5)
        • 공변량 보정 (4)
        • 생물정보통계 모델 (3)
        • 연합학습 (13)
        • 통계적 머신러닝 (13)
        • 논의 (0)
        • 구현 (2)
        • 스터디 (3)
      • 데이터 엔지니어링 (1)
        • 하둡 (1)
      • 코딩 (26)
        • 웹개발 (1)
        • 시각화 (2)
        • 이슈 (8)
        • 노트 (5)
        • PyTorch Lightning (5)
        • JAX (5)
      • 에너지 (2)
        • 뉴스 및 동향 (2)
        • 용어 정리 (0)
      • 기본 이론 (0)
        • 집합론 (0)
        • 그래프 이론 (0)
      • 약리학 (28)
        • 강의 (5)
        • ADMET parameter (16)
        • DDI (4)
        • DTI (0)
      • 생명과학 (1)
        • 분석기술 (1)
      • 일상 (15)
        • 연구일지 (3)
        • 생각 (8)
        • 영화 (1)
        • 동화책 만들기 (1)
        • 요리 (0)
        • 다이어트 (2)
  • 블로그 메뉴

    • 홈
    • 태그
    • 방명록
  • 인기 글

  • 전체
    오늘
    어제
  • hELLO· Designed By정상우.v4.10.1
CDeo
[JAX] Automatic Vectorization
상단으로

티스토리툴바