ํฐ์คํ ๋ฆฌ ๋ทฐ
๐ JAX ๊ธฐ์ด - ์ JAX๋ฅผ ๋ฐฐ์์ผ ํ ๊น?
๐ JAX๋ ๋ฌด์์ธ๊ฐ?
JAX๋ Google์์ ๊ฐ๋ฐํ ๊ณ ์ฑ๋ฅ ์์น ๊ณ์ฐ ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ก, GPU์ TPU๋ฅผ ํ์ฉํ์ฌ ๋๊ท๋ชจ ๋ฐ์ดํฐ๋ฅผ ๋น ๋ฅด๊ฒ ์ฒ๋ฆฌํ ์ ์๋ ํน์ง์ ๊ฐ์ง๊ณ ์์ต๋๋ค.
Python์ Numpy์ ์ ์ฌํ ๋ฌธ๋ฒ์ ๊ฐ์ง๊ณ ์์ด ์ง๊ด์ ์ด๋ฉด์๋ ๊ฐ๋ ฅํ ์ฑ๋ฅ์ ์ ๊ณตํฉ๋๋ค.
๐ก JAX์ ์ฃผ์ ํน์ง
- ์๋ ๋ฏธ๋ถ (Automatic Differentiation)
- JAX๋ grad() ํจ์๋ฅผ ํตํด ๋งค์ฐ ๊ฐ๋จํ๊ฒ ๋ฏธ๋ถ์ ์ํํ ์ ์์ต๋๋ค.
- ๋ฅ๋ฌ๋ ๋ชจ๋ธ ํ์ต ์ ํ์์ ์ธ ๊ธฐ์ธ๊ธฐ ๊ณ์ฐ์ด ๊ฐํธํฉ๋๋ค.
- JIT ์ปดํ์ผ (Just-In-Time Compilation)
- JIT์ ์ฌ์ฉํ์ฌ ์ฑ๋ฅ์ ํฌ๊ฒ ํฅ์์ํฌ ์ ์์ต๋๋ค.
- ๋ฐ๋ณต๋๋ ์ฐ์ฐ์ GPU ๋๋ TPU์์ ๋น ๋ฅด๊ฒ ์ฒ๋ฆฌํ ์ ์์ต๋๋ค.
- ํจ์ ๋ฒกํฐํ (Vectorization)
- vmap() ํจ์๋ฅผ ํตํด ๋ฐ๋ณต๋ฌธ์ ๋ณ๋ ฌํํ์ฌ ๋๊ท๋ชจ ๋ฐ์ดํฐ ์ฒ๋ฆฌ ์๋๋ฅผ ๊ฐ์ ํฉ๋๋ค.
- GPU๋ฅผ ์ต๋ํ ํ์ฉํ์ฌ ํจ์จ์ ์ธ ๊ณ์ฐ์ด ๊ฐ๋ฅํฉ๋๋ค.
- GPU/TPU ์ง์
- JAX๋ CUDA์ TPU๋ฅผ ์ง์ ์ง์ํ์ฌ ๋๊ท๋ชจ ๋ฐ์ดํฐ ํ์ต์ ์ ํฉํฉ๋๋ค.
๐ ์ JAX๋ฅผ ์ฌ์ฉํด์ผ ํ ๊น?
- Numpy์ ๋น์ทํ ์ฌ์ฉ์ฑ
- ๊ธฐ์กด Numpy ์ฌ์ฉ์๋ผ๋ฉด ๋งค์ฐ ์ฝ๊ฒ ์ ์ํ ์ ์์ต๋๋ค.
- GPU/TPU๋ฅผ ์ด์ฉํ ๊ณ ์ ์ฐ์ฐ
- ๋ฅ๋ฌ๋ ๋ชจ๋ธ ํ์ต ์๋๋ฅผ ๊ทน๋ํํ ์ ์์ต๋๋ค.
- ์๋ ๋ฏธ๋ถ์ ๊ฐ๋ ฅํจ
- ์ํ์ ์ต์ ํ ๋ฌธ์ ๋ฅผ ์ง์ ํ๊ฑฐ๋, ๋ชจ๋ธ ํ์ต์ ๊ตฌํํ ๋ ๋งค์ฐ ์ ์ฉํฉ๋๋ค.
- ํจ์ํ ํ๋ก๊ทธ๋๋ฐ ์ง์
- JAX๋ ๋ถ๋ณ์ฑ(immutability)์ ์งํค๋ฉฐ ํจ์ํ ํ๋ก๊ทธ๋๋ฐ ์คํ์ผ์ ์ฑํํ์ฌ ๋๋ฒ๊น ์ด ์ฉ์ดํฉ๋๋ค.
๐ง JAX ์ค์น ๋ฐฉ๋ฒ
JAX๋ Python 3.12 ๋ฒ์ ์ ๊ธฐ์ค์ผ๋ก ์ค์นํ ์ ์์ต๋๋ค.
๋จผ์ Conda ํ๊ฒฝ์ ์ค๋นํ๊ณ ์๋์ ๊ฐ์ด ์ค์นํฉ๋๋ค:
conda create -n jax_env python=3.12 -y
conda activate jax_env
pip install --upgrade pip
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
๐ก Tip:
JAX๋ CUDA ๋ฒ์ ๊ณผ ํธํ์ด ์ค์ํฉ๋๋ค. CUDA 11 ์ด์์์ ์์ ์ ์ผ๋ก ์๋ํฉ๋๋ค.
์์ ์ GPU ํ๊ฒฝ์ ๋ง์ถฐ ์ค์น ๋งํฌ๋ฅผ ํ์ธํ์ธ์.
๐ JAX ๊ธฐ์ด ์ฌ์ฉ๋ฒ
1. ๊ธฐ๋ณธ ์ฐ์ฐ
JAX๋ฅผ ์ด์ฉํ์ฌ ๊ฐ๋จํ ์์น ์ฐ์ฐ์ ํด๋ณด๊ฒ ์ต๋๋ค:
import jax.numpy as jnp
# ๋ฐฐ์ด ์์ฑ
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([0.1, 0.2, 0.3])
# ๊ธฐ๋ณธ ์ฐ์ฐ
sum_xy = jnp.add(x, y)
print(f"Sum: {sum_xy}")
์ถ๋ ฅ:
Sum: [1.1 2.2 3.3]
2. ์๋ ๋ฏธ๋ถ
JAX์ ๊ฐ๋ ฅํ ๊ธฐ๋ฅ ์ค ํ๋์ธ ์๋ ๋ฏธ๋ถ์ ์ฌ์ฉํด๋ณด๊ฒ ์ต๋๋ค:
from jax import grad
# ์์ค ํจ์ ์ ์
def loss_fn(w):
return jnp.sum(w ** 2)
# ๊ธฐ์ธ๊ธฐ ๊ณ์ฐ
grad_loss = grad(loss_fn)
print(f"Gradient: {grad_loss(x)}")
์ถ๋ ฅ:
Gradient: [2.0 4.0 6.0]
3. JIT ์ปดํ์ผ
์ฑ๋ฅ์ ๊ทน๋ํํ๊ธฐ ์ํด JIT์ ์ฌ์ฉํด๋ด ์๋ค:
from jax import jit
# JIT์ ์ด์ฉํ ํจ์
@jit
def multiply(a, b):
return a * b
print(f"JIT Multiply: {multiply(x, y)}")
์ถ๋ ฅ:
JIT Multiply: [0.1 0.4 0.9]
๐ JAX์ ํ์ฉ ์ฌ๋ก
- ๋ฅ๋ฌ๋ ๋ชจ๋ธ ํ์ต
- ๊ณ ์ ๊ธฐ์ธ๊ธฐ ๊ณ์ฐ ๋๋ถ์ CNN, RNN ๋ชจ๋ธ ํ์ต์ ์ ์ฉํฉ๋๋ค.
- ๊ฐํ ํ์ต ์๊ณ ๋ฆฌ์ฆ ๊ตฌํ
- ์ ์ฑ ๊ฒฝ์ฌ๋ฒ, Q-learning๊ณผ ๊ฐ์ ์๊ณ ๋ฆฌ์ฆ์ ํจ์จ์ ์ผ๋ก ๊ตฌํํ ์ ์์ต๋๋ค.
- ๊ณ ์ฐจ์ ํ๋ ฌ ๊ณ์ฐ
- ๊ณ ์ฐจ์ ํ๋ ฌ ๊ณฑ์ ๋ฐ ํฉ์ฑ๊ณฑ ๊ณ์ฐ์์ GPU ํ์ฉ ์ฑ๋ฅ์ด ๋ฐ์ด๋ฉ๋๋ค.
๐บ๏ธ JAX ํ์ต์ ์ํ ๋ก๋๋งต
- ๊ธฐ์ด ๋ฌธ๋ฒ ์ตํ๊ธฐ
- ๋ฐฐ์ด ์ฐ์ฐ, ๊ธฐ๋ณธ ํจ์ ์ฌ์ฉ๋ฒ
- ๊ณ ๊ธ ๊ธฐ๋ฅ ํ๊ตฌ
- ์๋ ๋ฏธ๋ถ๊ณผ JIT ์ปดํ์ผ์ ํตํ ์ฑ๋ฅ ํฅ์
- ํ๋ก์ ํธ ์ค์ต
- ๊ฐํ ํ์ต ๋ชจ๋ธ ๊ตฌ์ถ
- ๊ณ ์ฐจ์ ๋ฐ์ดํฐ ๋ถ์
- ์ค์ ๋ฐ์ดํฐ ์ฌ์ฉ ํ๋ก์ ํธ
- MNIST ์ด๋ฏธ์ง ๋ถ๋ฅ
- ์๊ณ์ด ์์ธก ๋ชจ๋ธ
JAX, Python, ์๋ ๋ฏธ๋ถ, JIT ์ปดํ์ผ, GPU ์ฐ์ฐ, TPU ํ์ฉ, ๋ฅ๋ฌ๋, ํจ์ํ ํ๋ก๊ทธ๋๋ฐ, ๊ณ ์ฑ๋ฅ ๊ณ์ฐ, ์์น ์ฐ์ฐ, AI ๋ชจ๋ธ ํ์ต, JAX ์ค์น, ๊ธฐ์ด ์ฌ์ฉ๋ฒ, ์ฑ๋ฅ ์ต์ ํ, ๊ณ ๊ธ ๊ธฐ๋ฅ
JAX์ ๊ฐ๋ ฅํจ์ ๋จ์ํ ์์น ์ฐ์ฐ์ ๋์ด์ ๋ค์ํ ๋จธ์ ๋ฌ๋ ํ๋ก์ ํธ์ ํ์ฉํ ์ ์๋ค๋ ๋ฐ ์์ต๋๋ค.
๋ค์ ๊ธ์์๋ JAX๋ฅผ ์ด์ฉํ ์ค์ ํ๋ก์ ํธ ์์ ๋ก ๊ฐํ ํ์ต ์๊ณ ๋ฆฌ์ฆ์ ๊ตฌํํด๋ณด๊ฒ ์ต๋๋ค! ๐
'Programming > Python' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
๐ฑ ํ๋ ์ธ๊ณต์ง๋ฅ ํ์ต 1๋จ๊ณ: ๊ธฐ์ด ์ํ ๋ฐ ์ ํ๋์ (0) | 2025.05.08 |
---|---|
๐ฑ ํ๋ ์ธ๊ณต์ง๋ฅ ํ์ต 1๋จ๊ณ: ๊ธฐ์ด ํ๋ก๊ทธ๋๋ฐ (Python) (0) | 2025.05.08 |
[Python] Ubuntu 20.04์ python 3.10 ์ค์น (0) | 2024.03.11 |
[Python] Tensoflow F1 score metrics (0) | 2022.09.22 |
[python]๋ฌด์์ ํ์ดํฌ ์ด๋ฆ ์์ฑ๊ธฐ (0) | 2022.05.13 |
- Total
- Today
- Yesterday
- PostgreSQL
- Ktor
- gatsbyjs
- nextJS
- ํ๋ก ํธ์๋
- SEO ์ต์ ํ
- Webpack
- CI/CD
- nodejs
- fastapi
- github
- kotlin
- NestJS
- App Router
- rag
- Next.js
- ํ๋ก ํธ์๋๋ฉด์
- Python
- ๋ฐฑ์๋๊ฐ๋ฐ
- ์น๊ฐ๋ฐ
- LangChain
- AI์ฑ๋ด
- Prisma
- REACT
- SEO์ต์ ํ
- Docker
- llm
- seo ์ต์ ํ 10๊ฐ
- ๊ฐ๋ฐ๋ธ๋ก๊ทธ
- ๊ด๋ฆฌ์
์ผ | ์ | ํ | ์ | ๋ชฉ | ๊ธ | ํ |
---|---|---|---|---|---|---|
1 | 2 | 3 | ||||
4 | 5 | 6 | 7 | 8 | 9 | 10 |
11 | 12 | 13 | 14 | 15 | 16 | 17 |
18 | 19 | 20 | 21 | 22 | 23 | 24 |
25 | 26 | 27 | 28 | 29 | 30 | 31 |