
πŸ“Œ Implementing a GAN (Generative Adversarial Network) with JAX - An Image Generation Project

octo54 2025. 5. 20. 11:01
λ°˜μ‘ν˜•

πŸ“Œ JAX둜 GAN(생성적 μ λŒ€ 신경망) κ΅¬ν˜„ - 이미지 생성 ν”„λ‘œμ νŠΈ


πŸš€ What Is a GAN (Generative Adversarial Network)?

A GAN (Generative Adversarial Network) is a model that learns from data and generates new data using two neural networks: a **Generator** and a **Discriminator**.

  • μƒμ„±μž(G): λžœλ€ν•œ 작음 λ²‘ν„°λ‘œλΆ€ν„° μ‹€μ œ 같은 데이터λ₯Ό 생성
  • νŒλ³„μž(D): 데이터가 μ‹€μ œ(real)인지 μƒμ„±λœ(fake) 것인지 νŒλ³„
  • λͺ©ν‘œ: μƒμ„±μžλŠ” νŒλ³„μžλ₯Ό 속일 μ •λ„λ‘œ μ‹€μ œ 같은 데이터λ₯Ό 생성
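
Formally, the two networks play a two-player minimax game over the value function from the original GAN paper (Goodfellow et al., 2014):

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]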

πŸ’‘ 1. How a GAN Trains

πŸ“ GAN μ•Œκ³ λ¦¬μ¦˜ ꡬ쑰

  1. μƒμ„±μž ν•™μŠ΅: νŒλ³„μžλ₯Ό 속이도둝 κ°€μ§œ 데이터λ₯Ό 생성
  2. νŒλ³„μž ν•™μŠ΅: μ‹€μ œ 데이터와 생성 데이터λ₯Ό ꡬ뢄
  3. μ λŒ€μ  ν›ˆλ ¨: 두 λͺ¨λΈμ΄ μ„œλ‘œ κ²½μŸν•˜λ©° λ°œμ „

πŸ”„ Defining the Loss Functions

  • μƒμ„±μž 손싀:

LG=−E[log⁑(D(G(z)))]L_G = -\mathbb{E}[\log(D(G(z)))]

  • νŒλ³„μž 손싀:

LD=−E[log⁑(D(x))]−E[log⁑(1−D(G(z)))]L_D = -\mathbb{E}[\log(D(x))] - \mathbb{E}[\log(1 - D(G(z)))]
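
Both losses involve log D(Β·). Because the discriminator below outputs raw logits, the implementation in section 5 evaluates these terms with jax.nn.log_sigmoid instead of log(sigmoid(Β·)); a minimal sketch of why the stable form matters:

import jax
import jax.numpy as jnp

logits = jnp.array([-200.0, 0.0, 200.0])

# Naive log(sigmoid(l)) underflows to log(0) = -inf for very negative logits
print(jnp.log(jax.nn.sigmoid(logits)))  # [-inf -0.6931472 0.]

# log_sigmoid computes the same quantity in a numerically stable form
print(jax.nn.log_sigmoid(logits))       # approximately [-200. -0.6931472 -0.]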


πŸ”§ 2. Installing the Libraries

pip install jax jaxlib flax optax tensorflow-datasets
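
To confirm the installation and see which backend JAX picked up, a quick check (the printed device list depends on your machine):

import jax

print(jax.__version__)  # installed JAX version
print(jax.devices())    # e.g. [CpuDevice(id=0)] on a CPU-only machine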

πŸ’Ύ 3. Preparing the Data - MNIST Handwritten Digits

import tensorflow_datasets as tfds
import jax.numpy as jnp

# Load the MNIST training split
ds = tfds.load("mnist", split="train", as_supervised=True)

# Preprocessing: scale pixels to [-1, 1] and flatten each image to a 784-vector
def preprocess(image, label):
    image = (image / 255.0) * 2 - 1  # normalize to [-1, 1]
    return jnp.array(image, dtype=jnp.float32).reshape(-1), label

# Materialize the dataset as a list of (image, label) pairs
train_data = [preprocess(image, label) for image, label in tfds.as_numpy(ds)]
print(f"Number of training examples: {len(train_data)}")

🧠 4. Implementing the GAN Models

λ°˜μ‘ν˜•

πŸ“ μƒμ„±μž λͺ¨λΈ (Generator)

from flax import linen as nn
import jax.numpy as jnp

class Generator(nn.Module):
    latent_dim: int

    @nn.compact
    def __call__(self, z):
        x = nn.Dense(128)(z)
        x = nn.relu(x)
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dense(512)(x)
        x = nn.relu(x)
        x = nn.Dense(28 * 28)(x)
        x = nn.tanh(x)  # squash outputs to [-1, 1], matching the data normalization
        return x.reshape((-1, 28, 28, 1))

πŸ“ νŒλ³„μž λͺ¨λΈ (Discriminator)

class Discriminator(nn.Module):

    @nn.compact
    def __call__(self, x):
        x = x.reshape(-1, 28 * 28)  # flatten each image to a 784-vector
        x = nn.Dense(512)(x)
        x = nn.leaky_relu(x, 0.2)
        x = nn.Dense(256)(x)
        x = nn.leaky_relu(x, 0.2)
        x = nn.Dense(1)(x)
        return x  # raw logit; sigmoid is applied inside the loss functions
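
As a quick shape check (a minimal sketch with throwaway parameters, separate from the training setup below), the generator should emit (batch, 28, 28, 1) images and the discriminator a (batch, 1) logit per image:

from jax import random

check_key = random.PRNGKey(0)
g, d = Generator(latent_dim=100), Discriminator()

g_params = g.init(check_key, jnp.zeros((1, 100)))
fake = g.apply(g_params, jnp.zeros((4, 100)))

d_params = d.init(check_key, fake)
logits = d.apply(d_params, fake)

print(fake.shape, logits.shape)  # (4, 28, 28, 1) (4, 1)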

πŸ“‰ 5. Loss Functions and Optimizer Setup

πŸ“ 손싀 ν•¨μˆ˜

import jax
import optax

# Generator loss: L_G = -E[log D(G(z))], computed from raw logits
def generator_loss(logits_fake):
    return -jnp.mean(jax.nn.log_sigmoid(logits_fake))

# Discriminator loss: L_D = -E[log D(x)] - E[log(1 - D(G(z)))]
# (log(1 - sigmoid(l)) is exactly log_sigmoid(-l))
def discriminator_loss(logits_real, logits_fake):
    loss_real = -jnp.mean(jax.nn.log_sigmoid(logits_real))
    loss_fake = -jnp.mean(jax.nn.log_sigmoid(-logits_fake))
    return loss_real + loss_fake

# Optimizer shared by both networks
optimizer = optax.adam(1e-4)
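
optax transformations are stateless recipes, so the same optax.adam(1e-4) instance can safely back both training states in section 6. If you want per-network hyperparameters (a common GAN tweak, not part of this post's setup), just build two optimizers:

# Hypothetical variant: separate optimizers with the GAN-style beta1 = 0.5
gen_optimizer = optax.adam(learning_rate=1e-4, b1=0.5)
dis_optimizer = optax.adam(learning_rate=4e-4, b1=0.5)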

πŸ—οΈ 6. λͺ¨λΈ μ΄ˆκΈ°ν™”

from flax.training import train_state
import jax
from jax import random

key = random.PRNGKey(42)
key, gen_key, dis_key = random.split(key, 3)

# Instantiate the models
gen_model = Generator(latent_dim=100)
dis_model = Discriminator()

# Initialize parameters with separate keys; reusing one key everywhere
# would make the independent random draws identical
gen_params = gen_model.init(gen_key, random.normal(gen_key, (1, 100)))
dis_params = dis_model.init(dis_key, random.normal(dis_key, (1, 28, 28, 1)))

# Wrap parameters and optimizer into training states
gen_state = train_state.TrainState.create(apply_fn=gen_model.apply, params=gen_params, tx=optimizer)
dis_state = train_state.TrainState.create(apply_fn=dis_model.apply, params=dis_params, tx=optimizer)

πŸ” 7. ν•™μŠ΅ 루프 μ •μ˜

πŸ“ μ—…λ°μ΄νŠΈ ν•¨μˆ˜

@jax.jit
def train_step(gen_state, dis_state, real_images, z):
    def gen_loss_fn(params):
        fake_images = gen_model.apply(params, z)
        logits_fake = dis_model.apply(dis_state.params, fake_images)
        return generator_loss(logits_fake)

    def dis_loss_fn(params):
        logits_real = dis_model.apply(params, real_images)
        fake_images = gen_model.apply(gen_state.params, z)
        logits_fake = dis_model.apply(params, fake_images)
        return discriminator_loss(logits_real, logits_fake)

    # Compute both gradients against the pre-update parameters (a simultaneous G/D update)
    gen_loss, gen_grads = jax.value_and_grad(gen_loss_fn)(gen_state.params)
    dis_loss, dis_grads = jax.value_and_grad(dis_loss_fn)(dis_state.params)

    # νŒŒλΌλ―Έν„° μ—…λ°μ΄νŠΈ
    gen_state = gen_state.apply_gradients(grads=gen_grads)
    dis_state = dis_state.apply_gradients(grads=dis_grads)
    
    return gen_state, dis_state, gen_loss, dis_loss
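
A one-off smoke test of train_step with dummy inputs (hypothetical values, just to confirm that the jitted function compiles and the shapes line up):

z0 = random.normal(random.PRNGKey(1), (8, 100))
x0 = jnp.zeros((8, 784))  # stand-in for a real MNIST batch
_, _, g_l, d_l = train_step(gen_state, dis_state, x0, z0)
print(float(g_l), float(d_l))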

πŸƒ 8. λͺ¨λΈ ν•™μŠ΅

epochs = 50
batch_size = 64

for epoch in range(epochs):
    for i in range(0, len(train_data), batch_size):
        batch = train_data[i:i + batch_size]
        real_images, _ = zip(*batch)
        real_images = jnp.stack(real_images)

        # Draw fresh noise each step; reusing one fixed key would feed the
        # generator identical noise on every batch
        key, z_key = random.split(key)
        z = random.normal(z_key, (real_images.shape[0], 100))

        # One simultaneous update of generator and discriminator
        gen_state, dis_state, g_loss, d_loss = train_step(gen_state, dis_state, real_images, z)

    print(f"Epoch {epoch+1}, Gen Loss: {g_loss:.4f}, Dis Loss: {d_loss:.4f}")

βœ… 9. Inspecting the Generated Images

import matplotlib.pyplot as plt

def generate_images(gen_state, num_images=16):
    z = random.normal(key, (num_images, 100))
    generated = gen_model.apply(gen_state.params, z)
    generated = (generated + 1) / 2  # [-1, 1] -> [0, 1]

    plt.figure(figsize=(4, 4))
    for i in range(num_images):
        plt.subplot(4, 4, i+1)
        plt.imshow(generated[i, :, :, 0], cmap='gray')
        plt.axis('off')
    plt.show()

generate_images(gen_state)

πŸ“Œ λ‹€μŒ κΈ€ 예고: JAX둜 VAE(λ³€μ΄ν˜• μ˜€ν† μΈμ½”λ”) κ΅¬ν˜„

λ‹€μŒ κΈ€μ—μ„œλŠ” JAXλ₯Ό μ‚¬μš©ν•˜μ—¬ VAE λͺ¨λΈμ„ κ΅¬μΆ•ν•˜κ³ ,
잠재 κ³΅κ°„μ—μ„œμ˜ 의미 μžˆλŠ” 데이터 생성 방식을 νƒκ΅¬ν•˜κ² μŠ΅λ‹ˆλ‹€.


 

JAX, GAN, image generation, deep learning, Python, accelerated computation, generative adversarial network, high-performance computing, training loop, model training, MNIST, image generation model, random noise, Generator, Discriminator, AI modeling