Programming/JAX

📌 Building a Neural Network Model in JAX – Handwritten Digit Recognition with an MLP (MNIST)

octo54 2025. 5. 9. 11:07
๋ฐ˜์‘ํ˜•



🚀 Building a Neural Network with JAX

In the previous post we covered the basics of JAX's automatic differentiation and JIT compilation.
This time we'll put them to work, building a simple multilayer perceptron (MLP) to classify handwritten digits (MNIST).


๐Ÿ“ 1. ๋ฐ์ดํ„ฐ์…‹ ์ค€๋น„

💾 Loading the MNIST Data

JAX has no built-in data loaders, so we use tensorflow_datasets to fetch the data ourselves.

pip install tensorflow-datasets

import tensorflow_datasets as tfds
import jax.numpy as jnp

# Load the MNIST training split as (image, label) pairs
ds = tfds.load('mnist', split='train', as_supervised=True)

# Preprocessing: scale pixels to [0, 1] and flatten each 28x28 image to a 784-vector
def preprocess(image, label):
    image = jnp.array(image, dtype=jnp.float32) / 255.0
    label = jnp.array(label, dtype=jnp.int32)
    return image.reshape(-1), label

# Convert the dataset into a list of preprocessed (x, y) pairs
train_data = [preprocess(image, label) for image, label in tfds.as_numpy(ds)]
print(f"Number of training samples: {len(train_data)}")
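The preprocessing step can be sanity-checked without downloading anything, by feeding it a synthetic 28x28 image in place of a real MNIST sample (the array contents here are made up purely for the check):

```python
import numpy as np
import jax.numpy as jnp

def preprocess(image, label):
    # Scale pixel values to [0, 1] and flatten the 28x28x1 image to a 784-vector
    image = jnp.array(image, dtype=jnp.float32) / 255.0
    label = jnp.array(label, dtype=jnp.int32)
    return image.reshape(-1), label

# Synthetic all-white "image" standing in for an MNIST sample
fake_image = np.full((28, 28, 1), 255, dtype=np.uint8)
x, y = preprocess(fake_image, 7)
print(x.shape)         # (784,)
print(float(x.max()))  # 1.0
```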

💡 2. Defining the Neural Network Model

🧠 MLP Architecture

  • Input layer: 784 units (the 28x28 image, flattened)
  • Hidden layer: 128 units (ReLU activation)
  • Output layer: 10 units (softmax activation)

import jax
from jax import random, jit, grad
import jax.numpy as jnp

# Initialize weights and biases for each layer
def init_params(layer_sizes, key):
    params = []
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        key, subkey = random.split(key)
        weights = random.normal(subkey, (n_in, n_out)) * 0.01
        biases = jnp.zeros(n_out)
        params.append((weights, biases))
    return params

# Forward pass: hidden layers with ReLU, then log-softmax over the outputs
def predict(params, x):
    for w, b in params[:-1]:
        x = jnp.dot(x, w) + b
        x = jnp.maximum(x, 0)  # ReLU activation
    final_w, final_b = params[-1]
    logits = jnp.dot(x, final_w) + final_b
    # Subtracting logsumexp turns the logits into log-probabilities (log-softmax)
    return logits - jax.scipy.special.logsumexp(logits, axis=1, keepdims=True)
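Repeating the two functions above in a self-contained snippet, we can verify that predict really returns valid log-probabilities: exponentiating each row and summing should give 1 (the input here is a toy batch, not real data):

```python
import jax
import jax.numpy as jnp
from jax import random

def init_params(layer_sizes, key):
    params = []
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        key, subkey = random.split(key)
        params.append((random.normal(subkey, (n_in, n_out)) * 0.01,
                       jnp.zeros(n_out)))
    return params

def predict(params, x):
    for w, b in params[:-1]:
        x = jnp.maximum(jnp.dot(x, w) + b, 0)  # ReLU
    final_w, final_b = params[-1]
    logits = jnp.dot(x, final_w) + final_b
    return logits - jax.scipy.special.logsumexp(logits, axis=1, keepdims=True)

params = init_params([784, 128, 10], random.PRNGKey(0))
x = jnp.ones((2, 784))  # toy batch of two flattened "images"
log_probs = predict(params, x)
print(log_probs.shape)  # (2, 10)
print(jnp.exp(log_probs).sum(axis=1))  # each row sums to ~1.0
```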

⚙️ 3. Loss Function and Accuracy

📉 Loss Function (Cross-Entropy)

def cross_entropy_loss(params, x, y):
    log_probs = predict(params, x)  # predict returns log-probabilities
    one_hot = jax.nn.one_hot(y, num_classes=10)
    return -jnp.mean(jnp.sum(one_hot * log_probs, axis=1))
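To see the loss behave on known inputs, here is a variant that takes log-probabilities directly instead of calling predict (a restructuring for illustration only). With uniform predictions, the cross-entropy is exactly log(10) ≈ 2.3026 regardless of the labels:

```python
import jax
import jax.numpy as jnp

def cross_entropy_loss_from_logprobs(log_probs, y):
    # Same formula as the article's loss, with predict() factored out
    one_hot = jax.nn.one_hot(y, num_classes=10)
    return -jnp.mean(jnp.sum(one_hot * log_probs, axis=1))

# Uniform predictions: log(1/10) for every class
log_probs = jnp.full((4, 10), jnp.log(0.1))
y = jnp.array([0, 3, 7, 9])
print(float(cross_entropy_loss_from_logprobs(log_probs, y)))  # ~2.3026
```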

🧮 Accuracy Function

๋ฐ˜์‘ํ˜•
def accuracy(params, x, y):
    logits = predict(params, x)
    predictions = jnp.argmax(logits, axis=1)
    return jnp.mean(predictions == y)
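A quick check of the accuracy formula on hand-picked scores (argmax works the same whether applied to logits or log-probabilities, since log-softmax is monotone):

```python
import jax.numpy as jnp

def accuracy_from_logits(logits, y):
    # Fraction of rows whose highest-scoring class matches the label
    return jnp.mean(jnp.argmax(logits, axis=1) == y)

logits = jnp.array([[0.1, 2.0, 0.0],   # predicts class 1
                    [3.0, 0.0, 0.0],   # predicts class 0
                    [0.0, 0.0, 5.0]])  # predicts class 2
y = jnp.array([1, 0, 1])  # the last prediction is wrong
print(float(accuracy_from_logits(logits, y)))  # 2 of 3 correct -> ~0.667
```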

🔧 4. Defining the Training Step

🔁 Training Step

learning_rate = 0.01
epochs = 5
batch_size = 128
key = random.PRNGKey(42)

# Initialize the model parameters
params = init_params([784, 128, 10], key)

# JIT-compiled gradient of the loss with respect to the parameters
grad_loss = jit(grad(cross_entropy_loss))

# SGD parameter update
@jit
def update(params, x, y, lr):
    grads = grad_loss(params, x, y)
    return [(w - lr * dw, b - lr * db) for (w, b), (dw, db) in zip(params, grads)]
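The list comprehension in update unpacks each (weights, biases) pair by hand. The same SGD step can also be written with jax.tree_util.tree_map, a common JAX idiom that applies the rule to any nesting of parameters (shown here on a tiny made-up pytree, not the MNIST model):

```python
import jax
import jax.numpy as jnp

# Toy parameter pytree and matching gradients
params = [(jnp.ones((2, 2)), jnp.zeros(2))]
grads = [(jnp.full((2, 2), 0.5), jnp.full((2,), 0.5))]
lr = 0.1

# One SGD step applied leaf-by-leaf across the whole pytree
new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
print(float(new_params[0][0][0, 0]))  # 1.0 - 0.1*0.5 = 0.95
```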

๐Ÿƒ 5. ๋ชจ๋ธ ํ•™์Šต

📊 Training Loop

for epoch in range(epochs):
    # Mini-batch training
    for i in range(0, len(train_data), batch_size):
        batch = train_data[i:i + batch_size]
        x_batch, y_batch = zip(*batch)
        x_batch = jnp.stack(x_batch)
        y_batch = jnp.array(y_batch)

        # Parameter update
        params = update(params, x_batch, y_batch, learning_rate)

    # Report loss and accuracy on the last mini-batch of the epoch
    train_loss = cross_entropy_loss(params, x_batch, y_batch)
    train_acc = accuracy(params, x_batch, y_batch)
    print(f"Epoch {epoch + 1}, Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}")
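The grad/jit/update pattern used above can be smoke-tested on a toy problem before trusting it on MNIST. Here gradient descent minimizes a quadratic with a known minimum at w = 3 (the objective and data are made up for the test):

```python
from jax import grad, jit
import jax.numpy as jnp

# Toy objective, minimized at w = 3
def loss(w):
    return jnp.sum((w - 3.0) ** 2)

# Same shape as the article's update: a JIT-compiled gradient step
@jit
def step(w, lr):
    return w - lr * grad(loss)(w)

w = jnp.zeros(2)
for _ in range(100):
    w = step(w, 0.1)
print(jnp.allclose(w, 3.0, atol=1e-3))  # True: gradient descent converged
```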

✅ 6. Evaluating the Trained Model

We evaluate the trained model on the held-out test data.

🧩 Computing Test Accuracy

# Load the MNIST test split
ds_test = tfds.load('mnist', split='test', as_supervised=True)
test_data = [preprocess(image, label) for image, label in tfds.as_numpy(ds_test)]

# Evaluate on the full test set
x_test, y_test = zip(*test_data)
x_test = jnp.stack(x_test)
y_test = jnp.array(y_test)

test_acc = accuracy(params, x_test, y_test)
print(f"Test accuracy: {test_acc:.4f}")

🌟 Where JAX Helps

  1. Automatic differentiation
    • grad() makes computing gradients a one-liner.
  2. JIT compilation
    • jit() substantially speeds up the gradient and update steps.
  3. Vectorization
    • Batched array operations process mini-batches efficiently; jax.vmap extends this to per-example functions.
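The predict function above handles batches via matrix multiplication, but jax.vmap achieves the same effect by mapping a single-example function over a batch axis. A minimal illustration:

```python
import jax
import jax.numpy as jnp

def square_norm(x):
    # Operates on a single vector
    return jnp.dot(x, x)

# vmap maps square_norm over the leading (batch) axis
batched = jax.vmap(square_norm)
xs = jnp.array([[1.0, 2.0], [3.0, 4.0]])
print(batched(xs))  # [ 5. 25.]
```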

💡 Ideas for Improving the Model

  • Data augmentation: enlarge the training set to improve generalization
  • Dropout: mitigate overfitting
  • A more advanced optimizer: Adam or RMSprop typically converge faster than plain SGD
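As a sketch of what an Adam-style update involves, here is a hand-rolled single-parameter version (in practice one would usually reach for the optax library rather than write this by hand; the hyperparameter values are the common defaults):

```python
import jax.numpy as jnp

def adam_update(p, g, m, v, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8):
    # Exponential moving averages of the gradient and squared gradient
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g**2
    # Bias correction for the zero-initialized moments
    m_hat = m / (1 - b1**t)
    v_hat = v / (1 - b2**t)
    return p - lr * m_hat / (jnp.sqrt(v_hat) + eps), m, v

p = jnp.array([1.0])
m = v = jnp.zeros(1)
p, m, v = adam_update(p, jnp.array([2.0]), m, v, t=1)
print(float(p[0]))  # first step moves by ~lr, regardless of gradient scale
```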

📌 Coming Up Next: Implementing a CNN in JAX

๋‹ค์Œ ๊ธ€์—์„œ๋Š” JAX๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ **ํ•ฉ์„ฑ๊ณฑ ์‹ ๊ฒฝ๋ง(CNN)**์„ ๊ตฌ์ถ•ํ•˜๊ณ , ์ด๋ฏธ์ง€ ๋ถ„๋ฅ˜ ๋ฌธ์ œ๋ฅผ ๋” ๋ณต์žกํ•˜๊ฒŒ ํ•ด๊ฒฐํ•ด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.
GPU ํ™œ์šฉ์„ ๊ทน๋Œ€ํ™”ํ•˜์—ฌ ์„ฑ๋Šฅ์„ ๋น„๊ต ๋ถ„์„ํ•ด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.


 

JAX, neural networks, MLP, MNIST, deep learning, machine learning, Python, automatic differentiation, JIT compilation, training loop, data preprocessing, model training, accuracy, GPU, high-performance models, vectorization, model evaluation