๐ JAX๋ฅผ ํ์ฉํ ์ ๊ฒฝ๋ง ๋ชจ๋ธ ๊ตฌ์ถ - MLP๋ฅผ ์ด์ฉํ ์๊ธ์จ ์ธ์ (MNIST)
๐ JAX๋ฅผ ํ์ฉํ ์ ๊ฒฝ๋ง ๋ชจ๋ธ ๊ตฌ์ถ - MLP๋ฅผ ์ด์ฉํ ์๊ธ์จ ์ธ์ (MNIST)
๐ JAX๋ก ์ ๊ฒฝ๋ง ๋ชจ๋ธ ๋ง๋ค๊ธฐ
์ด์ ๊ธ์์๋ JAX์ ์๋ ๋ฏธ๋ถ๊ณผ JIT ์ปดํ์ผ์ ๊ธฐ๋ณธ ์ฌ์ฉ๋ฒ์ ๋ฐฐ์ ์ต๋๋ค.
์ด๋ฒ์๋ ์ด๋ฅผ ํ์ฉํ์ฌ ๊ฐ๋จํ ๋ค์ธต ํผ์
ํธ๋ก (MLP) ๋ชจ๋ธ์ ๊ตฌ์ถํ์ฌ ์๊ธ์จ ๋ฐ์ดํฐ(MNIST)๋ฅผ ๋ถ๋ฅํด๋ณด๊ฒ ์ต๋๋ค.
๐ 1. ๋ฐ์ดํฐ์ ์ค๋น
๐พ MNIST ๋ฐ์ดํฐ ๋ถ๋ฌ์ค๊ธฐ
JAX์์๋ ์ง์ ๋ฐ์ดํฐ๋ฅผ ๋ถ๋ฌ์์ผ ํ๋ฏ๋ก tensorflow_datasets๋ฅผ ์ฌ์ฉํ์ฌ ๋ฐ์ดํฐ๋ฅผ ๋ก๋ํฉ๋๋ค.
pip install tensorflow-datasets
import tensorflow_datasets as tfds
import jax.numpy as jnp
# MNIST ๋ฐ์ดํฐ ๋ถ๋ฌ์ค๊ธฐ
ds = tfds.load('mnist', split='train', as_supervised=True)
# ๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ ํจ์
def preprocess(image, label):
image = jnp.array(image, dtype=jnp.float32) / 255.0
label = jnp.array(label, dtype=jnp.int32)
return image.reshape(-1), label
# ๋ฐ์ดํฐ ๋ณํ
train_data = [(preprocess(image, label)) for image, label in tfds.as_numpy(ds)]
print(f"ํ๋ จ ๋ฐ์ดํฐ ์ํ ์: {len(train_data)}")
๐ก 2. ์ ๊ฒฝ๋ง ๋ชจ๋ธ ์ ์
๐ง MLP ๋ชจ๋ธ ๊ตฌ์กฐ
- ์ ๋ ฅ์ธต: 784 (28x28 ์ด๋ฏธ์ง ํผ์นจ)
- ์๋์ธต: 128 (ReLU ํ์ฑํ ํจ์)
- ์ถ๋ ฅ์ธต: 10 (Softmax ํ์ฑํ ํจ์)
import jax
from jax import random, jit, grad
import jax.numpy as jnp
# ํ๋ผ๋ฏธํฐ ์ด๊ธฐํ ํจ์
def init_params(layer_sizes, key):
params = []
for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
key, subkey = random.split(key)
weights = random.normal(subkey, (n_in, n_out)) * 0.01
biases = jnp.zeros(n_out)
params.append((weights, biases))
return params
# MLP ๋ชจ๋ธ ํจ์
def predict(params, x):
for w, b in params[:-1]:
x = jnp.dot(x, w) + b
x = jnp.maximum(x, 0) # ReLU ํ์ฑํ ํจ์
final_w, final_b = params[-1]
logits = jnp.dot(x, final_w) + final_b
return logits - jax.scipy.special.logsumexp(logits, axis=1, keepdims=True)
โ๏ธ 3. ์์ค ํจ์์ ์ ํ๋ ๊ณ์ฐ
๐ ์์ค ํจ์ (Cross Entropy)
def cross_entropy_loss(params, x, y):
logits = predict(params, x)
one_hot = jax.nn.one_hot(y, num_classes=10)
return -jnp.mean(jnp.sum(one_hot * logits, axis=1))
๐งฎ ์ ํ๋ ๊ณ์ฐ ํจ์
def accuracy(params, x, y):
logits = predict(params, x)
predictions = jnp.argmax(logits, axis=1)
return jnp.mean(predictions == y)
๐ง 4. ํ์ต ๋ฃจํ ์ ์
๐ ํ๋ จ ๋จ๊ณ
learning_rate = 0.01
epochs = 5
batch_size = 128
key = random.PRNGKey(42)
# ๋ชจ๋ธ ์ด๊ธฐํ
params = init_params([784, 128, 10], key)
# ๊ธฐ์ธ๊ธฐ ๊ณ์ฐ ํจ์
grad_loss = jit(grad(cross_entropy_loss))
# ํ๋ผ๋ฏธํฐ ์
๋ฐ์ดํธ ํจ์
@jit
def update(params, x, y, lr):
grads = grad_loss(params, x, y)
return [(w - lr * dw, b - lr * db) for (w, b), (dw, db) in zip(params, grads)]
๐ 5. ๋ชจ๋ธ ํ์ต
๐ ํ์ต ๋ฃจํ
for epoch in range(epochs):
# ๋ฏธ๋๋ฐฐ์น ํ์ต
for i in range(0, len(train_data), batch_size):
batch = train_data[i:i + batch_size]
x_batch, y_batch = zip(*batch)
x_batch = jnp.stack(x_batch)
y_batch = jnp.array(y_batch)
# ํ๋ผ๋ฏธํฐ ์
๋ฐ์ดํธ
params = update(params, x_batch, y_batch, learning_rate)
# ์ํฌํฌ๋ณ ์์ค ๋ฐ ์ ํ๋ ์ถ๋ ฅ
train_loss = cross_entropy_loss(params, x_batch, y_batch)
train_acc = accuracy(params, x_batch, y_batch)
print(f"Epoch {epoch + 1}, Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}")
โ 6. ํ์ต ๊ฒฐ๊ณผ ํ๊ฐ
ํ์ต์ด ์๋ฃ๋ ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ํ ์คํธ ๋ฐ์ดํฐ๋ฅผ ํ๊ฐํฉ๋๋ค.
๐งฉ ํ ์คํธ ๋ฐ์ดํฐ ์ ํ๋ ๊ณ์ฐ
# ํ
์คํธ ๋ฐ์ดํฐ ๋ก๋
ds_test = tfds.load('mnist', split='test', as_supervised=True)
test_data = [(preprocess(image, label)) for image, label in tfds.as_numpy(ds_test)]
# ํ๊ฐ
x_test, y_test = zip(*test_data)
x_test = jnp.stack(x_test)
y_test = jnp.array(y_test)
test_acc = accuracy(params, x_test, y_test)
print(f"ํ
์คํธ ์ ํ๋: {test_acc:.4f}")
๐ JAX์ ์ฅ์ ํ์ฉ
- ์๋ ๋ฏธ๋ถ
- grad()๋ฅผ ์ฌ์ฉํ์ฌ ๊ธฐ์ธ๊ธฐ ๊ณ์ฐ์ ๊ฐ๋จํ๊ฒ ์ํํฉ๋๋ค.
- JIT ์ปดํ์ผ
- jit()๋ฅผ ์ฌ์ฉํ์ฌ ํ์ต ์๋๋ฅผ ๋ํญ ํฅ์์์ผฐ์ต๋๋ค.
- ํจ์ ๋ฒกํฐํ
- JAX์ ๋ฒกํฐํ ๊ธฐ๋ฅ์ ํตํด ๋ฐฐ์น ๋ฐ์ดํฐ๋ฅผ ํจ๊ณผ์ ์ผ๋ก ์ฒ๋ฆฌํ์ต๋๋ค.
๐ก ๋ชจ๋ธ ๊ฐ์ ์์ด๋์ด
- ๋ฐ์ดํฐ ์ฆ๊ฐ: ํ์ต ๋ฐ์ดํฐ๋ฅผ ๋๋ ค ๋ชจ๋ธ ์ฑ๋ฅ ํฅ์
- ๋๋กญ์์ ์ถ๊ฐ: ๊ณผ์ ํฉ ๋ฐฉ์ง
- ๊ณ ๊ธ ์ตํฐ๋ง์ด์ ์ฌ์ฉ: Adam์ด๋ RMSprop์ผ๋ก ์ฑ๋ฅ ๊ฐ์
๐ ๋ค์ ๊ธ ์๊ณ : JAX๋ก CNN ๊ตฌํํ๊ธฐ
๋ค์ ๊ธ์์๋ JAX๋ฅผ ์ฌ์ฉํ์ฌ **ํฉ์ฑ๊ณฑ ์ ๊ฒฝ๋ง(CNN)**์ ๊ตฌ์ถํ๊ณ , ์ด๋ฏธ์ง ๋ถ๋ฅ ๋ฌธ์ ๋ฅผ ๋ ๋ณต์กํ๊ฒ ํด๊ฒฐํด๋ณด๊ฒ ์ต๋๋ค.
GPU ํ์ฉ์ ๊ทน๋ํํ์ฌ ์ฑ๋ฅ์ ๋น๊ต ๋ถ์ํด๋ณด๊ฒ ์ต๋๋ค.
JAX, ์ ๊ฒฝ๋ง, MLP, MNIST, ๋ฅ๋ฌ๋, ๋จธ์ ๋ฌ๋, Python, ์๋ ๋ฏธ๋ถ, JIT ์ปดํ์ผ, ํ์ต ๋ฃจํ, ๋ฐ์ดํฐ ์ ์ฒ๋ฆฌ, ๋ชจ๋ธ ํ์ต, ์ ํ๋ ๊ณ์ฐ, GPU ํ์ฉ, ๊ณ ์ฑ๋ฅ ๋ชจ๋ธ, ํจ์ ๋ฒกํฐํ, ๋ชจ๋ธ ํ๊ฐ