Programming/JAX
π JAXλ‘ GAN(μμ±μ μ λ μ κ²½λ§) ꡬν - μ΄λ―Έμ§ μμ± νλ‘μ νΈ
octo54
2025. 5. 20. 11:01
λ°μν
π JAXλ‘ GAN(μμ±μ μ λ μ κ²½λ§) ꡬν - μ΄λ―Έμ§ μμ± νλ‘μ νΈ
π GAN(μμ±μ μ λ μ κ²½λ§)μ΄λ?
GAN (Generative Adversarial Network) is a model that learns to generate data by pitting two neural networks against each other: a **Generator** and a **Discriminator**.
- μμ±μ(G): λλ€ν μ‘μ 벑ν°λ‘λΆν° μ€μ κ°μ λ°μ΄ν°λ₯Ό μμ±
- νλ³μ(D): λ°μ΄ν°κ° μ€μ (real)μΈμ§ μμ±λ(fake) κ²μΈμ§ νλ³
- λͺ©ν: μμ±μλ νλ³μλ₯Ό μμΌ μ λλ‘ μ€μ κ°μ λ°μ΄ν°λ₯Ό μμ±
π‘ 1. GANμ νμ΅ κ³Όμ
π GAN μκ³ λ¦¬μ¦ κ΅¬μ‘°
- μμ±μ νμ΅: νλ³μλ₯Ό μμ΄λλ‘ κ°μ§ λ°μ΄ν°λ₯Ό μμ±
- νλ³μ νμ΅: μ€μ λ°μ΄ν°μ μμ± λ°μ΄ν°λ₯Ό ꡬλΆ
- μ λμ νλ ¨: λ λͺ¨λΈμ΄ μλ‘ κ²½μνλ©° λ°μ
π μμ€ ν¨μ μ μ
- μμ±μ μμ€:
$L_G = -\mathbb{E}[\log(D(G(z)))]$
- νλ³μ μμ€:
$L_D = -\mathbb{E}[\log(D(x))] - \mathbb{E}[\log(1 - D(G(z)))]$
π§ 2. λΌμ΄λΈλ¬λ¦¬ μ€μΉ
pip install jax jaxlib flax optax tensorflow-datasets
πΎ 3. λ°μ΄ν° μ€λΉ - MNIST μκΈμ¨ λ°μ΄ν°
import tensorflow_datasets as tfds
# Load the MNIST training split as (image, label) pairs (uint8 28x28x1 images).
ds = tfds.load("mnist", split="train", as_supervised=True)
# λ°μ΄ν° μ μ²λ¦¬ ν¨μ
def preprocess(image, label):
    """Normalize a uint8 image from [0, 255] into [-1, 1] and flatten it.

    Returns a (flat_image, label) tuple where flat_image is a float32
    jnp array of shape (784,) for a 28x28x1 MNIST image.
    """
    scaled = image / 255.0 * 2 - 1  # map [0, 255] -> [-1, 1] to match tanh output
    flat = jnp.asarray(scaled, dtype=jnp.float32).ravel()
    return flat, label
# Materialize the whole dataset in host memory as (flat_image, label) tuples.
# NOTE(review): this holds all of MNIST as float32 (~188 MB); fine for MNIST,
# but a streaming pipeline would be needed for larger datasets.
train_data = [preprocess(image, label) for image, label in tfds.as_numpy(ds)]
print(f"νλ ¨ λ°μ΄ν° μ: {len(train_data)}")
π§ 4. GAN λͺ¨λΈ ꡬν
λ°μν
π μμ±μ λͺ¨λΈ (Generator)
from flax import linen as nn
import jax.numpy as jnp
class Generator(nn.Module):
    """MLP generator: maps a latent vector z to a 28x28x1 image in [-1, 1]."""

    latent_dim: int

    @nn.compact
    def __call__(self, z):
        hidden = z
        # Widening MLP trunk; call order matches the original so Flax's
        # auto-assigned parameter names (Dense_0..Dense_3) are unchanged.
        for width in (128, 256, 512):
            hidden = nn.relu(nn.Dense(width)(hidden))
        pixels = nn.tanh(nn.Dense(28 * 28)(hidden))  # tanh keeps output in [-1, 1]
        return pixels.reshape((-1, 28, 28, 1))
π νλ³μ λͺ¨λΈ (Discriminator)
class Discriminator(nn.Module):
    """MLP discriminator: maps an image batch to one real/fake logit each."""

    @nn.compact
    def __call__(self, x):
        flat = x.reshape(-1, 28 * 28)  # flatten each 28x28x1 image
        h = nn.leaky_relu(nn.Dense(512)(flat), 0.2)
        h = nn.leaky_relu(nn.Dense(256)(h), 0.2)
        # Raw logit: the loss functions apply (log-)sigmoid themselves.
        return nn.Dense(1)(h)
π 5. μμ€ ν¨μμ μ΅μ ν μ€μ
π μμ€ ν¨μ
import optax
# μμ±μ μμ€
def generator_loss(logits_fake):
    """Non-saturating generator loss: -E[log D(G(z))].

    Args:
        logits_fake: discriminator logits on generated images.

    The original used plain ``jax.nn.sigmoid`` here, i.e. -E[D(G(z))],
    which does not match the stated loss L_G = -E[log(D(G(z)))] and is
    inconsistent with discriminator_loss below. ``log_sigmoid`` computes
    log(D(G(z))) in a numerically stable way.
    """
    return -jnp.mean(jax.nn.log_sigmoid(logits_fake))
# νλ³μ μμ€
def discriminator_loss(logits_real, logits_fake):
    """Binary cross-entropy discriminator loss.

    -E[log sigmoid(real)] rewards high logits on real images, and
    -E[log sigmoid(-fake)] = -E[log(1 - D(fake))] rewards low logits
    on generated ones; both terms use the stable log-sigmoid form.
    """
    real_term = jnp.mean(jax.nn.log_sigmoid(logits_real))
    fake_term = jnp.mean(jax.nn.log_sigmoid(-logits_fake))
    return -(real_term + fake_term)
# Single Adam optimizer definition shared by both train states below.
# NOTE(review): GAN recipes often use adam(lr, b1=0.5); this keeps optax
# defaults (b1=0.9) — confirm intended.
optimizer = optax.adam(1e-4)
ποΈ 6. λͺ¨λΈ μ΄κΈ°ν
from flax.training import train_state
import jax
from jax import random
key = random.PRNGKey(42)
# Instantiate the two network definitions.
gen_model = Generator(latent_dim=100)
dis_model = Discriminator()
# Initialize parameters with dummy inputs of the expected shapes.
# NOTE(review): the same `key` is reused for both inits and both dummy
# inputs; consider jax.random.split for independent streams.
gen_params = gen_model.init(key, random.normal(key, (1, 100)))
dis_params = dis_model.init(key, random.normal(key, (1, 28, 28, 1)))
# Bundle apply_fn + params + optimizer into Flax TrainStates.
gen_state = train_state.TrainState.create(apply_fn=gen_model.apply, params=gen_params, tx=optimizer)
dis_state = train_state.TrainState.create(apply_fn=dis_model.apply, params=dis_params, tx=optimizer)
π 7. νμ΅ λ£¨ν μ μ
π μ λ°μ΄νΈ ν¨μ
@jax.jit
def train_step(gen_state, dis_state, real_images, z):
    """One GAN step: compute both losses, then update G and D simultaneously.

    Both gradients are taken against the parameters from the start of the
    step (neither update sees the other's new weights within the step).

    Args:
        gen_state, dis_state: Flax TrainStates for generator/discriminator.
        real_images: batch of real images scaled to [-1, 1].
        z: batch of latent noise vectors.

    Returns:
        (gen_state, dis_state, gen_loss, dis_loss) with updated states.
    """
    def gen_loss_fn(params):
        # Generator objective: fakes should fool the current discriminator.
        fake_images = gen_model.apply(params, z)
        logits_fake = dis_model.apply(dis_state.params, fake_images)
        return generator_loss(logits_fake)

    def dis_loss_fn(params):
        # Discriminator objective: separate the real batch from fresh fakes.
        logits_real = dis_model.apply(params, real_images)
        fake_images = gen_model.apply(gen_state.params, z)
        logits_fake = dis_model.apply(params, fake_images)
        return discriminator_loss(logits_real, logits_fake)

    # Losses and gradients for both networks.
    gen_loss, gen_grads = jax.value_and_grad(gen_loss_fn)(gen_state.params)
    dis_loss, dis_grads = jax.value_and_grad(dis_loss_fn)(dis_state.params)
    # Apply the parameter updates. (In the original, this comment had been
    # split across two lines by an encoding error, leaving a bare stray
    # token on its own line — a NameError/syntax break. Fixed here.)
    gen_state = gen_state.apply_gradients(grads=gen_grads)
    dis_state = dis_state.apply_gradients(grads=dis_grads)
    return gen_state, dis_state, gen_loss, dis_loss
π 8. λͺ¨λΈ νμ΅
epochs = 50
batch_size = 64
for epoch in range(epochs):
    for i in range(0, len(train_data), batch_size):
        batch = train_data[i:i + batch_size]
        real_images, _ = zip(*batch)
        real_images = jnp.stack(real_images)
        # Draw fresh latent noise every step. The original reused the global
        # `key` unchanged, so the generator received the *identical* z on
        # every single step — the PRNG key must be split to advance it.
        key, z_key = random.split(key)
        # Size z by the actual batch so the last (partial) batch of the
        # epoch gets a matching number of fake samples.
        z = random.normal(z_key, (real_images.shape[0], 100))
        gen_state, dis_state, g_loss, d_loss = train_step(gen_state, dis_state, real_images, z)
    print(f"Epoch {epoch+1}, Gen Loss: {g_loss:.4f}, Dis Loss: {d_loss:.4f}")
β 9. μ΄λ―Έμ§ μμ± κ²°κ³Ό νμΈ
import matplotlib.pyplot as plt
def generate_images(gen_state, num_images=16):
    """Sample digits from the trained generator and show them in a 4x4 grid.

    Assumes num_images fits the hard-coded 4x4 subplot layout (default 16).
    """
    noise = random.normal(key, (num_images, 100))
    samples = gen_model.apply(gen_state.params, noise)
    samples = (samples + 1) / 2  # tanh range [-1, 1] -> display range [0, 1]
    plt.figure(figsize=(4, 4))
    for idx in range(num_images):
        plt.subplot(4, 4, idx + 1)
        plt.imshow(samples[idx, :, :, 0], cmap='gray')
        plt.axis('off')
    plt.show()
generate_images(gen_state)
π λ€μ κΈ μκ³ : JAXλ‘ VAE(λ³μ΄ν μ€ν μΈμ½λ) ꡬν
λ€μ κΈμμλ JAXλ₯Ό μ¬μ©νμ¬ VAE λͺ¨λΈμ ꡬμΆνκ³ ,
μ μ¬ κ³΅κ°μμμ μλ―Έ μλ λ°μ΄ν° μμ± λ°©μμ νꡬνκ² μ΅λλ€.
JAX, GAN, μ΄λ―Έμ§ μμ±, λ₯λ¬λ, Python, κ³ μ μ°μ°, μμ±μ μ λ μ κ²½λ§, κ³ μ±λ₯ μ»΄ν¨ν , νμ΅ λ£¨ν, λͺ¨λΈ νμ΅, MNIST, μ΄λ―Έμ§ μμ± λͺ¨λΈ, λλ€ μ‘μ, Generator, Discriminator, μΈκ³΅μ§λ₯ λͺ¨λΈλ§