
[JAX] Automatic Vectorization
·
코딩/JAX
`[JAX] Just-in-time compilation`에 이어지는 내용입니다. 아래의 공식문서를 공부(거의 번역)하며 기록합니다.http://docs.jax.dev/en/latest/automatic-vectorization.html이번 포스팅에는 `jax.vmap()`을 통한 벡터화에 대해 설명하겠습니다.Manual vectorization다음은 1차원 벡터 두개의 convolution을 계산하는 코드입니다:import jaximport jax.numpy as jnpx = jnp.arange(5)w = jnp.array([2., 3., 4.])def convolve(x, w): output = [] for i in range(1, len(x)-1): output.append(jnp.dot(..