[JAX] Automatic Vectorization
·
코딩/JAX
`[JAX] Just-in-time compilation`에 이어지는 내용입니다. 아래의 공식문서를 공부(거의 번역)하며 기록합니다.http://docs.jax.dev/en/latest/automatic-vectorization.html이번 포스팅에는 `jax.vmap()`을 통한 벡터화에 대해 설명하겠습니다.Manual vectorization다음은 1차원 벡터 두개의 convolution을 계산하는 코드입니다:import jaximport jax.numpy as jnpx = jnp.arange(5)w = jnp.array([2., 3., 4.])def convolve(x, w): output = [] for i in range(1, len(x)-1): output.append(jnp.dot(..
[JAX] Quick Start
·
코딩/JAX
`JAX` 쓰는 법에 대해 공부하는 문서입니다. `JAX` 문서에서 제공하는 튜토리얼을 순서대로 공부(번역)해나갈 예정입니다. https://docs.jax.dev/en/latest/tutorials.htmlInstallationJAX는 NumPy와 같이 어레이 기반으로 연산을 다루는 라이브러리 입니다. 자동 미분과 JIT 을 통해 속도적인 측면에서 높은 성능을 보입니다. 설치는 다음과 같은 커맨드라인을 따라 진행하시면 됩니다.# for cpu versionpip install jax# for NVIDAI GPU versionpip install -U "jax[cuda12]"JAX NumPy대부분의 JAX는 `jax.numpy` API를 이용하여 사용됩니다. 보통 아래와같이 불러옵니다:import jax..