ํฐ์คํ ๋ฆฌ ๋ทฐ
๐ JAX์ ํต์ฌ ๊ธฐ๋ฅ - ์๋ ๋ฏธ๋ถ๊ณผ JIT ์ปดํ์ผ๋ก ์ฑ๋ฅ ๊ทน๋ํ
octo54 2025. 5. 8. 11:15๐ JAX์ ํต์ฌ ๊ธฐ๋ฅ - ์๋ ๋ฏธ๋ถ๊ณผ JIT ์ปดํ์ผ๋ก ์ฑ๋ฅ ๊ทน๋ํ
๐ JAX์ ํต์ฌ ๊ธฐ๋ฅ ๋๋ฌ๋ณด๊ธฐ
JAX๋ ๋ค์ํ ๊ธฐ๋ฅ์ ์ ๊ณตํ์ง๋ง, ํนํ **์๋ ๋ฏธ๋ถ(Automatic Differentiation)**๊ณผ **JIT ์ปดํ์ผ(Just-In-Time Compilation)**์ด ๊ฐ์ฅ ์ค์ํ ์์์
๋๋ค.
์ด ๋ ๊ฐ์ง ๊ธฐ๋ฅ์ ๊น์ด ์ดํดํ๋ฉด JAX๋ฅผ ํ์ฉํ์ฌ ๊ณ ์ฑ๋ฅ ๋ชจ๋ธ์ ๊ตฌ์ถํ ์ ์์ต๋๋ค.
๐ก 1. ์๋ ๋ฏธ๋ถ (Automatic Differentiation)
์๋ ๋ฏธ๋ถ์ ์ํ ํจ์์ ๋ฏธ๋ถ์ ๊ธฐ๊ณ์ ์ผ๋ก ๊ณ์ฐํ๋ ๊ธฐ๋ฒ์ผ๋ก,
๊ธฐ๊ณ ํ์ต ๋ชจ๋ธ์ ํ์ต ๋จ๊ณ์์ ํ์์ ์ธ **๊ธฐ์ธ๊ธฐ ๊ณ์ฐ(Gradient Calculation)**์ ์ฌ์ฉ๋ฉ๋๋ค.
โ ์๋ ๋ฏธ๋ถ์ ์ฅ์
- ์ํ์ ์ ๋ ๋ถํ์: ๋ณต์กํ ๋ฏธ๋ถ ๊ณต์์ ์ง์ ๊ณ์ฐํ ํ์๊ฐ ์์ต๋๋ค.
- ์ฑ๋ฅ ์ต์ ํ: GPU๋ฅผ ํ์ฉํ์ฌ ๋น ๋ฅด๊ฒ ๊ณ์ฐํ ์ ์์ต๋๋ค.
- ๋ณต์กํ ํจ์๋ ๋ฌธ์ ์์: ๊ณ ์ฐจ ๋ฏธ๋ถ๋ ์ฝ๊ฒ ๊ณ์ฐํ ์ ์์ต๋๋ค.
๐ ์๋ ๋ฏธ๋ถ ๊ธฐ์ด ์ฌ์ฉ๋ฒ
JAX์์ ๋ฏธ๋ถ์ ๊ณ์ฐํ๋ ๊ธฐ๋ณธ ํจ์๋ jax.grad()์
๋๋ค.
๋ค์์ ๊ฐ๋จํ ์์ ์
๋๋ค:
import jax.numpy as jnp
from jax import grad
# ํจ์ ์ ์
def loss_fn(x):
return jnp.sum(x ** 2)
# ๊ธฐ์ธ๊ธฐ ๊ณ์ฐ
x = jnp.array([1.0, 2.0, 3.0])
gradient = grad(loss_fn)(x)
print(f"Gradient: {gradient}")
์ถ๋ ฅ:
Gradient: [2. 4. 6.]
๐ก Tip:
grad() ํจ์๋ ์ค์นผ๋ผ ๊ฐ์ ๋ฐํํ๋ ํจ์์์๋ง ๋์ํฉ๋๋ค.
๋ง์ฝ ๋ฒกํฐ๋ฅผ ๋ฐํํ๋ ํจ์๋ผ๋ฉด jax.jacrev() ๋๋ jax.jacfwd()๋ฅผ ์ฌ์ฉํด์ผ ํฉ๋๋ค.
๐ ๊ณ ์ฐจ ๋ฏธ๋ถ ๊ณ์ฐ
JAX๋ ๊ณ ์ฐจ ๋ฏธ๋ถ๋ ์ฝ๊ฒ ์ง์ํฉ๋๋ค. ์๋ฅผ ๋ค์ด, 2์ฐจ ๋ฏธ๋ถ์ ๊ณ์ฐํ๋ ค๋ฉด grad() ํจ์๋ฅผ ์ค์ฒฉํด์ ์ฌ์ฉํฉ๋๋ค:
# 2์ฐจ ๋ฏธ๋ถ ๊ณ์ฐ
second_derivative = grad(grad(loss_fn))(x)
print(f"Second Gradient: {second_derivative}")
์ถ๋ ฅ:
Second Gradient: [2. 2. 2.]
โก 2. JIT ์ปดํ์ผ (Just-In-Time Compilation)
JIT ์ปดํ์ผ์ Python ์ฝ๋๋ฅผ GPU์์ ๋น ๋ฅด๊ฒ ์คํํ ์ ์๋๋ก ์ต์ ํํ๋ ๊ธฐ๋ฅ์
๋๋ค.
JAX์์๋ jax.jit() ๋ฐ์ฝ๋ ์ดํฐ๋ฅผ ํตํด JIT ์ปดํ์ผ์ ์ ์ฉํ ์ ์์ต๋๋ค.
โ JIT์ ์ฅ์
- ์ฑ๋ฅ ํฅ์: CPU ๋๋น ์์ญ ๋ฐฐ ๋น ๋ฅธ ์ฐ์ฐ ์๋๋ฅผ ์๋ํฉ๋๋ค.
- ์ฝ๋ ์ต์ ํ: ๋ฐ๋ณต ๊ณ์ฐ์ด ๋ง์ ์ ๊ฒฝ๋ง ํ์ต์ ์ ํฉํฉ๋๋ค.
- GPU ํ์ฉ ๊ทน๋ํ: CUDA๋ฅผ ํตํด ๋๊ท๋ชจ ๋ฐ์ดํฐ ์ฒ๋ฆฌ์ ์ ๋ฆฌํฉ๋๋ค.
๐ JIT ๊ธฐ์ด ์ฌ์ฉ๋ฒ
JIT์ ์ฌ์ฉํ์ฌ ํ๋ ฌ ๊ณฑ์ ์ฑ๋ฅ์ ๋น๊ตํด๋ด ์๋ค:
from jax import jit
import time
# ํ๋ ฌ ๊ณฑ์
ํจ์
def matmul(x, y):
return jnp.dot(x, y)
# JIT ์ ์ฉ
jit_matmul = jit(matmul)
# ์
๋ ฅ ๋ฐ์ดํฐ
x = jnp.ones((1000, 1000))
y = jnp.ones((1000, 1000))
# ์ฑ๋ฅ ๋น๊ต
start = time.time()
result = matmul(x, y)
print(f"์ผ๋ฐ ๊ณ์ฐ ์๊ฐ: {time.time() - start:.6f}์ด")
start = time.time()
result_jit = jit_matmul(x, y)
print(f"JIT ๊ณ์ฐ ์๊ฐ: {time.time() - start:.6f}์ด")
๐ก JIT์ ์ฌ์ฉํ๋ ๊ฒฝ์ฐ
- ๋๊ท๋ชจ ํ๋ ฌ ๊ณ์ฐ
- CNN๊ณผ ๊ฐ์ ๋ณต์กํ ์ ๊ฒฝ๋ง ํ์ต
- ๋ฐ๋ณต์ ์ธ ์์น ์ต์ ํ ๋ฌธ์
- ์ค์๊ฐ ์๋ต์ด ์ค์ํ ์ ํ๋ฆฌ์ผ์ด์
๐ JAX์ ์๋ ๋ฏธ๋ถ๊ณผ JIT์ ํ์ฉํ ์ค์ ๋ชจ๋ธ
๐ ์ ํ ํ๊ท ๋ชจ๋ธ ํ์ต
์๋ ๋ฏธ๋ถ๊ณผ JIT์ ์ด์ฉํ์ฌ ์ ํ ํ๊ท ๋ชจ๋ธ์ ํ์ตํด๋ด ์๋ค:
# ๋ชจ๋ธ ์ ์
def model(w, x):
return w[0] * x + w[1]
# ์์ค ํจ์
def loss_fn(w, x, y):
pred = model(w, x)
return jnp.mean((pred - y) ** 2)
# ๋ฐ์ดํฐ ์์ฑ
x_data = jnp.array([1.0, 2.0, 3.0, 4.0])
y_data = jnp.array([2.0, 4.0, 6.0, 8.0])
# ์ด๊ธฐ ๊ฐ์ค์น
w = jnp.array([0.0, 0.0])
# ๊ธฐ์ธ๊ธฐ ๊ณ์ฐ ํจ์
grad_fn = jit(grad(loss_fn))
# ํ์ต ๋ฃจํ
learning_rate = 0.01
for epoch in range(100):
gradient = grad_fn(w, x_data, y_data)
w = w - learning_rate * gradient
if epoch % 10 == 0:
loss = loss_fn(w, x_data, y_data)
print(f"Epoch {epoch}, Loss: {loss:.4f}")
๐ JAX์ ํ์ฉ์ฑ ์ ๋ฆฌ
- ์๋ ๋ฏธ๋ถ์ ํจ์จ์ฑ
- ๊ธฐ์ธ๊ธฐ ๊ณ์ฐ์ ํตํด ์ ๊ฒฝ๋ง ํ์ต์ ์ต์ ํ๋์ด ์์ต๋๋ค.
- JIT ์ปดํ์ผ์ ๊ณ ์ ์ฒ๋ฆฌ
- GPU๋ฅผ ์ต๋ํ ํ์ฉํ์ฌ ๋ฐ๋ณต ์ฐ์ฐ ์๋๋ฅผ ํฌ๊ฒ ๊ฐ์ ํฉ๋๋ค.
- ์ค์ ํ๋ก์ ํธ ํ์ฉ ๊ฐ๋ฅ
- CNN, RNN๊ณผ ๊ฐ์ ๋ฅ๋ฌ๋ ๋ชจ๋ธ ํ์ต
- ๊ฐํ ํ์ต ์๊ณ ๋ฆฌ์ฆ ํ์ต
- ํ๋ฅ ์ ๊ทธ๋ํ ๋ชจ๋ธ ๊ตฌํ
๐ ๋ค์ ๊ธ ์๊ณ : JAX๋ฅผ ํ์ฉํ ์ฌํ ๋ชจ๋ธ ๊ตฌ์ถ
๋ค์ ๊ธ์์๋ JAX๋ฅผ ์ฌ์ฉํ์ฌ CNN๊ณผ ๊ฐํ ํ์ต ๋ชจ๋ธ์ ๊ตฌํํ์ฌ ์ฑ๋ฅ์ ํ๊ฐํด๋ณด๊ฒ ์ต๋๋ค.
์ค์ ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ํ์ตํ๋ ๊ณผ์ ์ ๋จ๊ณ๋ณ๋ก ์ค๋ช
ํ ์์ ์
๋๋ค.
JAX, ์๋ ๋ฏธ๋ถ, JIT ์ปดํ์ผ, GPU ์ฐ์ฐ, ์ ํ ํ๊ท, ๋ชจ๋ธ ํ์ต, ๊ณ ์ ์ฒ๋ฆฌ, ์ค์ ํ๋ก์ ํธ, Python, ๊ณ ์ฑ๋ฅ ๊ณ์ฐ, ๋ฅ๋ฌ๋, ๊ธฐ์ธ๊ธฐ ๊ณ์ฐ, ์ฑ๋ฅ ์ต์ ํ, ๊ณ ์ฐจ ๋ฏธ๋ถ, ํจ์ ๋ฒกํฐํ
'Programming > JAX' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
- Total
- Today
- Yesterday
- ๋ฐฑ์๋๊ฐ๋ฐ
- seo ์ต์ ํ 10๊ฐ
- NestJS
- ์น๊ฐ๋ฐ
- fastapi
- kotlin
- gatsbyjs
- CI/CD
- Docker
- ํ์ด์ฌ์๊ณ ๋ฆฌ์ฆ
- Prisma
- AI์ฑ๋ด
- SEO ์ต์ ํ
- Ktor
- Next.js
- SEO์ต์ ํ
- flax
- rag
- PostgreSQL
- ํ๋ก ํธ์๋
- App Router
- JAX
- REACT
- ๋ฅ๋ฌ๋
- Python
- llm
- ํ๋ก ํธ์๋๋ฉด์
- nextJS
- nodejs
- ๊ฐ๋ฐ๋ธ๋ก๊ทธ
์ผ | ์ | ํ | ์ | ๋ชฉ | ๊ธ | ํ |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | ||
6 | 7 | 8 | 9 | 10 | 11 | 12 |
13 | 14 | 15 | 16 | 17 | 18 | 19 |
20 | 21 | 22 | 23 | 24 | 25 | 26 |
27 | 28 | 29 | 30 | 31 |