ํ‹ฐ์Šคํ† ๋ฆฌ ๋ทฐ

๋ฐ˜์‘ํ˜•

๐Ÿ“Œ JAX๋กœ Time-Series GAN(TSGAN) ๊ตฌํ˜„ - ํ˜„์‹ค๊ฐ ์žˆ๋Š” ์‹œ๊ณ„์—ด ๋ฐ์ดํ„ฐ ์ƒ์„ฑํ•˜๊ธฐ


๐Ÿš€ ์™œ ์‹œ๊ณ„์—ด GAN์ด ํ•„์š”ํ•œ๊ฐ€?

ํ˜„์‹ค ์„ธ๊ณ„์—์„œ๋Š” ์‹œ๊ณ„์—ด ๋ฐ์ดํ„ฐ๊ฐ€ ์ ๊ฑฐ๋‚˜ ๋ถˆ๊ท ํ˜•ํ•œ ๊ฒฝ์šฐ๊ฐ€ ๋งŽ์Šต๋‹ˆ๋‹ค.
์˜ˆ๋ฅผ ๋“ค์–ด:

  • ์„ผ์„œ ์˜ค์ž‘๋™ ์ƒํ™ฉ์ด ์ ๊ฒŒ ๊ธฐ๋ก๋จ
  • ์ด์ƒ ๋ฐ์ดํ„ฐ ์ˆ˜์ง‘์ด ์ œํ•œ์ 
  • ๊ณ ๊ฐ€์˜ ์ธก์ • ๋น„์šฉ ๋“ฑ

**์‹œ๊ณ„์—ด GAN(Time-Series GAN, TSGAN)**์€ ์ด๋Ÿฌํ•œ ๋ฌธ์ œ๋ฅผ ํ•ด๊ฒฐํ•˜๊ธฐ ์œ„ํ•ด
์‹ค์ œ์™€ ์œ ์‚ฌํ•œ ์ƒˆ๋กœ์šด ์‹œ๊ณ„์—ด ๋ฐ์ดํ„ฐ๋ฅผ ์ƒ์„ฑํ•  ์ˆ˜ ์žˆ๋Š” ๊ฐ•๋ ฅํ•œ ๋„๊ตฌ์ž…๋‹ˆ๋‹ค.


๐Ÿ’ก 1. TSGAN ๊ตฌ์กฐ

๊ตฌ์„ฑ ์š”์†Œ ์—ญํ• 

Generator (G) ๋žœ๋ค ๋…ธ์ด์ฆˆ๋ฅผ ๋ฐ›์•„์„œ ์‹œ๊ณ„์—ด์„ ์ƒ์„ฑ
Discriminator (D) ์‹œ๊ณ„์—ด์ด ์ง„์งœ์ธ์ง€(GT) ๊ฐ€์งœ์ธ์ง€(G์—์„œ ์ƒ์„ฑ) ํŒ๋ณ„
Loss G๋Š” D๋ฅผ ์†์ด๋„๋ก ํ•™์Šต, D๋Š” G๋ฅผ ํŒ๋ณ„ํ•˜๋„๋ก ํ•™์Šต

๐Ÿงฑ 2. ๊ตฌํ˜„ ๊ฐœ์š” (JAX + Flax)

  • ์ƒ์„ฑ์ž: LSTM ๊ธฐ๋ฐ˜ ์‹œ๊ณ„์—ด ์ƒ์„ฑ ๋„คํŠธ์›Œํฌ
  • ํŒ๋ณ„์ž: ์‹œ๊ณ„์—ด์„ ๋ณด๊ณ  ์ง„์งœ/๊ฐ€์งœ ์—ฌ๋ถ€๋ฅผ ๋ถ„๋ฅ˜
  • ์ž…๋ ฅ ๋ฐ์ดํ„ฐ: ์„ผ์„œ ๋˜๋Š” ๊ธˆ์œต ์‹œ๊ณ„์—ด (์œˆ๋„์šฐํ˜• ๋ฐ์ดํ„ฐ)

๐Ÿ’พ 3. ์‹œ๊ณ„์—ด ๋ฐ์ดํ„ฐ ์ „์ฒ˜๋ฆฌ

import pandas as pd
import jax.numpy as jnp
from sklearn.preprocessing import MinMaxScaler

# ์˜ˆ์‹œ: ์‹œ๊ณ„์—ด ๋ฐ์ดํ„ฐ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ
df = pd.read_csv("sensor_data.csv")  # 'value' ์—ด ํฌํ•จ ๊ฐ€์ •
data = df['value'].values.reshape(-1, 1)

# ์ •๊ทœํ™”
scaler = MinMaxScaler()
data_scaled = scaler.fit_transform(data)

# ์œˆ๋„์šฐํ™”
def make_windows(data, window=30):
    X = []
    for i in range(len(data) - window):
        X.append(data[i:i+window])
    return jnp.array(X)

real_data = make_windows(data_scaled)  # shape: (samples, 30, 1)
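If `sensor_data.csv` is not at hand, the same preprocessing can be sanity-checked end to end on a synthetic signal; the noisy sine wave below is purely illustrative:

```python
import numpy as np
import jax.numpy as jnp
from sklearn.preprocessing import MinMaxScaler

# Synthetic "sensor" signal: a noisy sine wave, 1000 points
t = np.linspace(0, 20 * np.pi, 1000)
data = (np.sin(t) + 0.1 * np.random.randn(len(t))).reshape(-1, 1)

# Same normalization step as above
scaler = MinMaxScaler()
data_scaled = scaler.fit_transform(data)

# Same windowing step as above
def make_windows(data, window=30):
    X = [data[i:i + window] for i in range(len(data) - window)]
    return jnp.array(X)

real_data = make_windows(data_scaled)
print(real_data.shape)  # (970, 30, 1): 1000 points yield 1000 - 30 windows
```

Every value lands in [0, 1] after scaling, which matches what the generator will be asked to imitate.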

๐Ÿง  4. ๋ชจ๋ธ ์ •์˜ - Generator / Discriminator

๋ฐ˜์‘ํ˜•

๐Ÿงฌ Generator

from flax import linen as nn

class Generator(nn.Module):
    hidden_size: int = 64
    seq_len: int = 30

    @nn.compact
    def __call__(self, z):  # z: (batch, latent_dim)
        x = nn.Dense(self.hidden_size)(z)
        x = nn.relu(x)
        x = nn.Dense(self.seq_len * 1)(x)  # output: (batch, 30*1)
        x = x.reshape((z.shape[0], self.seq_len, 1))
        return x

๐Ÿ” Discriminator

class Discriminator(nn.Module):
    hidden_size: int = 64

    @nn.compact
    def __call__(self, x):  # x: (batch, 30, 1)
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(self.hidden_size)(x)
        x = nn.leaky_relu(x, 0.2)
        x = nn.Dense(1)(x)
        return x

โš™๏ธ 5. ํ•™์Šต ์ดˆ๊ธฐํ™”

import optax
from flax.training import train_state
import jax
import jax.random as random

gen = Generator()
dis = Discriminator()
key = random.PRNGKey(0)

latent_dim = 20
batch_input = jnp.ones((1, latent_dim))
real_input = jnp.ones((1, 30, 1))

gen_params = gen.init(key, batch_input)
dis_params = dis.init(key, real_input)

gen_tx = optax.adam(1e-4)
dis_tx = optax.adam(1e-4)

gen_state = train_state.TrainState.create(apply_fn=gen.apply, params=gen_params, tx=gen_tx)
dis_state = train_state.TrainState.create(apply_fn=dis.apply, params=dis_params, tx=dis_tx)

๐Ÿ“‰ 6. ์†์‹ค ํ•จ์ˆ˜ ์ •์˜ (GAN ๋ฐฉ์‹)

def discriminator_loss_fn(dis_params, real, fake):
    real_logits = dis.apply(dis_params, real)
    fake_logits = dis.apply(dis_params, fake)
    # Binary cross-entropy: push D(real) toward 1 and D(fake) toward 0
    loss = -jnp.mean(jnp.log(jax.nn.sigmoid(real_logits)) + jnp.log(1 - jax.nn.sigmoid(fake_logits)))
    return loss

def generator_loss_fn(gen_params, dis_params, z):
    fake = gen.apply(gen_params, z)
    fake_logits = dis.apply(dis_params, fake)
    # Non-saturating generator loss: push D(G(z)) toward 1
    loss = -jnp.mean(jnp.log(jax.nn.sigmoid(fake_logits)))
    return loss
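Note that `jnp.log(1 - jax.nn.sigmoid(x))` underflows to `log(0) = -inf` once the logit `x` gets large. If that becomes a problem during training, the same losses can be written with `jax.nn.log_sigmoid`, using the identity `log(1 - sigmoid(x)) = log_sigmoid(-x)`. This is a sketch of the numerically stable form, not a change of objective:

```python
import jax
import jax.numpy as jnp

def stable_d_loss(real_logits, fake_logits):
    # Same objective as -mean(log(sigmoid(real)) + log(1 - sigmoid(fake))),
    # rewritten with log_sigmoid to avoid evaluating log(0)
    return -jnp.mean(jax.nn.log_sigmoid(real_logits) + jax.nn.log_sigmoid(-fake_logits))

# A confident fake logit of 100 breaks the naive form but not the stable one
naive = -jnp.mean(jnp.log(jax.nn.sigmoid(jnp.array(2.0)))
                  + jnp.log(1 - jax.nn.sigmoid(jnp.array(100.0))))
stable = stable_d_loss(jnp.array([2.0]), jnp.array([100.0]))
print(naive)   # inf: log(1 - sigmoid(100)) underflows to log(0)
print(stable)  # finite, about 100.13
```

The generator loss can be rewritten the same way as `-jnp.mean(jax.nn.log_sigmoid(fake_logits))`.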

๐Ÿƒ 7. ํ•™์Šต ๋ฃจํ”„

@jax.jit
def train_step(gen_state, dis_state, real_batch, key):
    z = random.normal(key, (real_batch.shape[0], latent_dim))

    # Generator ์†์‹ค
    g_loss, g_grads = jax.value_and_grad(generator_loss_fn)(gen_state.params, dis_state.params, z)
    gen_state = gen_state.apply_gradients(grads=g_grads)

    # Discriminator ์†์‹ค
    fake_data = gen.apply(gen_state.params, z)
    d_loss, d_grads = jax.value_and_grad(discriminator_loss_fn)(dis_state.params, real_batch, fake_data)
    dis_state = dis_state.apply_gradients(grads=d_grads)

    return gen_state, dis_state, g_loss, d_loss

๐Ÿงช 8. ํ•™์Šต ์‹คํ–‰

batch_size = 64
epochs = 100

for epoch in range(epochs):
    # Full batches only: a fixed batch shape avoids jax.jit recompiling
    # train_step for the smaller final slice
    for i in range(0, len(real_data) - batch_size + 1, batch_size):
        real_batch = real_data[i:i+batch_size]
        key, subkey = random.split(key)
        gen_state, dis_state, g_loss, d_loss = train_step(gen_state, dis_state, real_batch, subkey)
    print(f"Epoch {epoch+1} | G Loss: {g_loss:.4f} | D Loss: {d_loss:.4f}")

๐ŸŽจ 9. ์‹œ๊ณ„์—ด ๋ฐ์ดํ„ฐ ์ƒ์„ฑ ๋ฐ ์‹œ๊ฐํ™”

def generate_samples(gen_state, num_samples=10):
    z = random.normal(key, (num_samples, latent_dim))
    fake = gen.apply(gen_state.params, z)
    return fake

import matplotlib.pyplot as plt

samples = generate_samples(gen_state, 5)

plt.figure(figsize=(12, 4))
for i, seq in enumerate(samples):
    plt.plot(seq.squeeze(), label=f"Sample {i+1}")
plt.title("Generated Time-Series Samples (TSGAN)")
plt.legend()
plt.show()
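The generated windows live in the scaler's [0, 1] range; to read them in the original units, map them back with `scaler.inverse_transform`. Since `inverse_transform` only accepts 2-D arrays, the window axis has to be flattened first. The sketch below uses a stand-in fitted scaler and a random batch in place of the `scaler` and `samples` from the steps above:

```python
import numpy as np
from sklearn.preprocessing import MinMaxScaler

# Stand-ins: a scaler fitted on values in [-3, 7] and a generated
# batch of 5 windows of length 30 in the normalized [0, 1) range
scaler = MinMaxScaler().fit(np.linspace(-3, 7, 100).reshape(-1, 1))
samples = np.random.rand(5, 30, 1)

# inverse_transform expects 2-D input, so flatten windows, invert, reshape back
flat = samples.reshape(-1, 1)                                # (5*30, 1)
restored = scaler.inverse_transform(flat).reshape(samples.shape)
print(restored.shape)  # (5, 30, 1), now in the original value range
```

With the real pipeline, plotting `restored` instead of `samples` makes generated and real series directly comparable.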

โœ… 10. TSGAN ํ™œ์šฉ ์•„์ด๋””์–ด

์‘์šฉ ๋ถ„์•ผ ์„ค๋ช…

๋ฐ์ดํ„ฐ ์ฆ๊ฐ• ์ด์ƒ ์ƒํ™ฉ ์žฌํ˜„ ๋ฐ์ดํ„ฐ ์ƒ์„ฑ
GAN-based ์ด์ƒ ํƒ์ง€ ์ƒ์„ฑ์ž๊ฐ€ ํ•™์Šตํ•˜์ง€ ๋ชปํ•œ ์‹œ๊ณ„์—ด์€ ์ด์ƒ์น˜๋กœ ๊ฐ„์ฃผ
์‹œ๋ฎฌ๋ ˆ์ด์…˜ ์‹œ์Šคํ…œ ์„ผ์„œ, ํ™˜๊ฒฝ ๋ณ€์ˆ˜, ์‚ฌ์šฉ๋Ÿ‰ ๋ฐ์ดํ„ฐ ์ƒ์„ฑ

๐Ÿ“Œ ๋‹ค์Œ ๊ธ€ ์˜ˆ๊ณ : JAX๋กœ ์‹œ๊ณ„์—ด + ์กฐ๊ฑด๋ถ€ GAN (cGAN) ๊ตฌํ˜„ - ์กฐ๊ฑด ๊ธฐ๋ฐ˜ ์‹œ๊ณ„์—ด ์ƒ์„ฑ

๋‹ค์Œ ๊ธ€์—์„œ๋Š” **์กฐ๊ฑด๋ถ€ GAN(Conditional GAN)**์„ ์‚ฌ์šฉํ•ด
"์„ผ์„œ ์ƒํƒœ๊ฐ€ ์ข‹์„ ๋•Œ์˜ ์‹œ๊ณ„์—ด", "์ฃผ๊ฐ€ ์ƒ์Šน ๊ตฌ๊ฐ„ ์˜ˆ์ธก" ๋“ฑ
์กฐ๊ฑด์„ ์ฃผ์–ด ๊ฐ€์งœ ์‹œ๊ณ„์—ด์„ ์ œ์–ดํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ๋ฐฐ์›๋‹ˆ๋‹ค.


 


โ€ป ์ด ํฌ์ŠคํŒ…์€ ์ฟ ํŒก ํŒŒํŠธ๋„ˆ์Šค ํ™œ๋™์˜ ์ผํ™˜์œผ๋กœ, ์ด์— ๋”ฐ๋ฅธ ์ผ์ •์•ก์˜ ์ˆ˜์ˆ˜๋ฃŒ๋ฅผ ์ œ๊ณต๋ฐ›์Šต๋‹ˆ๋‹ค.
๊ณต์ง€์‚ฌํ•ญ
์ตœ๊ทผ์— ์˜ฌ๋ผ์˜จ ๊ธ€
์ตœ๊ทผ์— ๋‹ฌ๋ฆฐ ๋Œ“๊ธ€
Total
Today
Yesterday
๋งํฌ
ยซ   2025/06   ยป
์ผ ์›” ํ™” ์ˆ˜ ๋ชฉ ๊ธˆ ํ† 
1 2 3 4 5 6 7
8 9 10 11 12 13 14
15 16 17 18 19 20 21
22 23 24 25 26 27 28
29 30
๊ธ€ ๋ณด๊ด€ํ•จ
๋ฐ˜์‘ํ˜•