`[JAX] Just-in-time compilation`에 이어지는 내용입니다. 아래의 공식문서를 공부(거의 번역)하며 기록합니다.
http://docs.jax.dev/en/latest/automatic-vectorization.html
이번 포스팅에는 `jax.vmap()`을 통한 벡터화에 대해 설명하겠습니다.
Manual vectorization
다음은 1차원 벡터 두개의 convolution을 계산하는 코드입니다:
import jax
import jax.numpy as jnp
x = jnp.arange(5)
w = jnp.array([2., 3., 4.])
def convolve(x, w):
output = []
for i in range(1, len(x)-1):
output.append(jnp.dot(x[i-1:i+2], w))
return jnp.array(output)
convolve(x, w)
>>> Array([11., 20., 29.], dtype=float32)
이제 여러 개의 `x`, 여러 개의 `w`로 이루어진 배치 데이터를 대상으로 이 함수를 적용하고 싶다고 가정하겠습니다:
xs = jnp.stack([x, x])
ws = jnp.stack([w, w])
가장 단순한 방법은 Python에서 루프를 돌리는 것입니다:
The most naive option would be to simply loop over the batch in Python:
def manually_batched_convolve(xs, ws):
output = []
for i in range(xs.shape[0]):
output.append(convolve(xs[i], ws[i]))
return jnp.stack(output)
manually_batched_convolve(xs, ws)
>>> Array([[11., 20., 29.],
[11., 20., 29.]], dtype=float32)
결과는 맞지만, 속도가 느립니다. 그림으로 보자면 붉은 영역만큼씩 두 행에 대해 (for을 통해) 순차적으로 적용되는 양상입니다.
보다 효율적인 벡터화를 위해서는 함수를 재구성해야합니다. 이 경우에는 인덱싱, 축 처리방식을 바꾸는 작업이 풀요합니다. 예를들어, 다음은 수동으로 벡터화된 버전입니다:
def manually_vectorized_convolve(xs, ws):
output = []
for i in range(1, xs.shape[-1] -1):
output.append(jnp.sum(xs[:, i-1:i+2] * ws, axis=1))
return jnp.stack(output, axis=1)
manually_vectorized_convolve(xs, ws)
>>> Array([[11., 20., 29.],
[11., 20., 29.]], dtype=float32)
이 경우 빠르긴합니다. 그림으로 보자면 두 행 모두에 대해 행렬곱으로 한번에 연산이 진행되기 때문이죠.
이런식으로 벡터화 하여 연산하면 좋긴하지만, 구현이 너무 복잡하여 오류의 위험이 있습니다.
Automatic vectorization
이를 해결하기 위해 `JAX`에서는 `jax.vmap()`을 이용합니다.
auto_batch_convolve = jax.vmap(convolve)
auto_batch_convolve(xs, ws)
>>> Array([[11., 20., 29.],
[11., 20., 29.]], dtype=float32)
이 함수는 내부적으로 `jax.jit()`처럼 함수를 `추적(tracing)`하여, 각 입력에 대해 batch 차원(축)을 자동으로 추가해 벡터화를 수행합니다. 만약 batch의 쌓이는 방향(축)이 첫 번째가 아닌 경우, `in_axes`, `out_axes` 인자를 사용해 batch 차원의 위치를 지정할 수 있습니다.
auto_batch_convolve_v2 = jax.vmap(convolve, in_axes=1, out_axes=1)
xst = jnp.transpose(xs)
wst = jnp.transpose(ws)
auto_batch_convolve_v2(xst, wst)
>>> Array([[11., 11.],
[20., 20.],
[29., 29.]], dtype=float32)
위의 스크립트에서 `in_axes=1`는 입력 텐서에서의 두번째 축(axis =1)이 batch 쌓인 방향이라는 의미입니다. `out_axes=1`의 경우 출력도 두번째 축을 배치로 설정한다는 뜻입니다. `axes` 0, 1 관계는 transpose 관계로 생각하시면 편합니다.
`jax.vmap()`은 입력 인자 중 일부만 배치 처리된 경우도 연산을 지원합니다. 예를 들어, 하나의 weight 벡터 w에 대해 여러 개의 입력 벡터 x를 컨볼루션하고 싶을 때 사용할 수 있습니다. 이 경우 `in_axes`인자를 다음과 같이 설정합니다:
batch_convolve_v3 = jax.vmap(convolve, in_axes=[0, None])
batch_convolve_v3(xs, w)
>>> Array([[11., 20., 29.],
[11., 20., 29.]], dtype=float32)
`in_axes=[0, None]` 첫 인자 `xs`는 배치차원을 따라 벡터화 되었으며, 두번째 인자 `w`는 벡터화되지 않음을 명시하는 내용입니다.
Combining transformations
다른 JAX 변환 함수들과 마찬가지로, `jax.jit()`와 `jax.vmap()`도 서로 조합하여 사용할 수 있도록 설계되어 있습니다.
즉, vmap된 함수에 jit을 감싸거나, 반대로 jit된 함수에 vmap을 감싸는 것 모두 가능하며, 정상적으로 동작합니다.
jitted_batch_convolve = jax.jit(auto_batch_convolve)
jitted_batch_convolve(xs, ws)
>>> Array([[11., 20., 29.],
[11., 20., 29.]], dtype=float32)
'코딩 > JAX' 카테고리의 다른 글
[JAX] Automatic Differentiation; Logiostic Regression (0) | 2025.04.09 |
---|---|
[JAX] Automatic Differentiation; jax.grad (0) | 2025.04.09 |
[JAX] Just-in-time compilation (0) | 2025.04.08 |
[JAX] Quick Start (0) | 2025.04.07 |