🚀 Implementing a Time-Series GAN (TSGAN) in JAX - Generating Realistic Time-Series Data
octo54 2025. 5. 29. 12:30
📘 Why Do We Need a Time-Series GAN?
In the real world, time-series data is often scarce or imbalanced.
For example:
- Sensor-malfunction events are only rarely recorded
- Collecting anomaly data is difficult or restricted
- Measurements can be expensive to take
A **Time-Series GAN (TSGAN)** is a powerful tool for tackling these problems: it can generate new time-series data that closely resembles the real thing.
💡 1. TSGAN Architecture

Component | Role
---|---
Generator (G) | Takes random noise and generates a time-series window
Discriminator (D) | Judges whether a time series is real (ground truth) or fake (produced by G)
Loss | G is trained to fool D; D is trained to tell real from generated
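Formally, the two networks play the standard GAN minimax game (the usual formulation, stated here for reference):

min_G max_D E_x[log D(x)] + E_z[log(1 - D(G(z)))]

The losses defined in section 6 implement this objective with the common non-saturating variant for G: instead of minimizing log(1 - D(G(z))), the generator maximizes log D(G(z)).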
🧱 2. Implementation Overview (JAX + Flax)
- Generator: maps random noise to a time-series window (kept as a simple dense network below; an LSTM-based variant is sketched in section 4)
- Discriminator: looks at a time series and classifies it as real or fake
- Input data: sensor or financial time series, cut into fixed-length windows
💾 3. Preprocessing the Time-Series Data
import pandas as pd
import jax.numpy as jnp
from sklearn.preprocessing import MinMaxScaler

# Example: load a time series (assumes a CSV with a 'value' column)
df = pd.read_csv("sensor_data.csv")
data = df['value'].values.reshape(-1, 1)

# Scale values into [0, 1]
scaler = MinMaxScaler()
data_scaled = scaler.fit_transform(data)

# Slice the series into overlapping windows
def make_windows(data, window=30):
    X = []
    for i in range(len(data) - window):
        X.append(data[i:i+window])
    return jnp.array(X)

real_data = make_windows(data_scaled)  # shape: (samples, 30, 1)
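sensor_data.csv is an assumed input. If you want to run the pipeline without it, a noisy sine wave works as a stand-in (the signal below is purely illustrative, not from the original post):

import numpy as np

# Hypothetical stand-in for sensor_data.csv: a noisy sine wave
t = np.linspace(0, 50, 2000)
values = np.sin(t) + 0.1 * np.random.randn(len(t))
data = values.reshape(-1, 1)

data_scaled = scaler.fit_transform(data)  # re-fit the MinMaxScaler above
real_data = make_windows(data_scaled)     # same (samples, 30, 1) shape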
🧠 4. Model Definitions - Generator / Discriminator
๐งฌ Generator
from flax import linen as nn

class Generator(nn.Module):
    hidden_size: int = 64
    seq_len: int = 30

    @nn.compact
    def __call__(self, z):  # z: (batch, latent_dim)
        x = nn.Dense(self.hidden_size)(z)
        x = nn.relu(x)
        x = nn.Dense(self.seq_len * 1)(x)  # output: (batch, 30*1)
        x = x.reshape((z.shape[0], self.seq_len, 1))
        return x
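The overview in section 2 mentioned an LSTM-based generator; the class above trades that for a minimal dense network. For completeness, here is a hedged sketch of an LSTM variant, assuming a recent Flax release that exposes nn.RNN and nn.LSTMCell:

class LSTMGenerator(nn.Module):
    hidden_size: int = 64
    seq_len: int = 30

    @nn.compact
    def __call__(self, z):  # z: (batch, latent_dim)
        # Feed the latent vector to the RNN at every time step
        x = jnp.repeat(z[:, None, :], self.seq_len, axis=1)          # (batch, seq_len, latent_dim)
        x = nn.RNN(nn.LSTMCell(features=self.hidden_size))(x)        # (batch, seq_len, hidden)
        return nn.Dense(1)(x)                                        # per-step projection -> (batch, seq_len, 1)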
๐ Discriminator
class Discriminator(nn.Module):
    hidden_size: int = 64

    @nn.compact
    def __call__(self, x):  # x: (batch, 30, 1)
        x = x.reshape((x.shape[0], -1))  # flatten
        x = nn.Dense(self.hidden_size)(x)
        x = nn.leaky_relu(x, 0.2)
        x = nn.Dense(1)(x)  # raw logit; the sigmoid is applied inside the loss
        return x
⚙️ 5. Training Initialization
import optax
from flax.training import train_state
import jax
import jax.random as random

gen = Generator()
dis = Discriminator()

key = random.PRNGKey(0)
gen_key, dis_key = random.split(key)  # separate init keys for the two networks
latent_dim = 20

# Dummy inputs, used only to trace shapes during init
batch_input = jnp.ones((1, latent_dim))
real_input = jnp.ones((1, 30, 1))

gen_params = gen.init(gen_key, batch_input)
dis_params = dis.init(dis_key, real_input)

gen_tx = optax.adam(1e-4)
dis_tx = optax.adam(1e-4)

gen_state = train_state.TrainState.create(apply_fn=gen.apply, params=gen_params, tx=gen_tx)
dis_state = train_state.TrainState.create(apply_fn=dis.apply, params=dis_params, tx=dis_tx)
📉 6. Defining the Loss Functions (Standard GAN)
def discriminator_loss_fn(dis_params, real, fake):
    real_logits = dis.apply(dis_params, real)
    fake_logits = dis.apply(dis_params, fake)
    # log_sigmoid is the numerically stable form of log(sigmoid(x));
    # note log(1 - sigmoid(x)) == log_sigmoid(-x)
    loss = -jnp.mean(jax.nn.log_sigmoid(real_logits) + jax.nn.log_sigmoid(-fake_logits))
    return loss

def generator_loss_fn(gen_params, dis_params, z):
    fake = gen.apply(gen_params, z)
    fake_logits = dis.apply(dis_params, fake)
    # Non-saturating generator loss: maximize log D(G(z))
    loss = -jnp.mean(jax.nn.log_sigmoid(fake_logits))
    return loss
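The same objective can be written with optax's binary cross-entropy helper, which some readers may find more familiar. A sketch, equivalent to the hand-written version above and assuming the optax import from section 5:

def discriminator_loss_fn_bce(dis_params, real, fake):
    real_logits = dis.apply(dis_params, real)
    fake_logits = dis.apply(dis_params, fake)
    # Label real windows 1 and generated windows 0
    real_loss = optax.sigmoid_binary_cross_entropy(real_logits, jnp.ones_like(real_logits))
    fake_loss = optax.sigmoid_binary_cross_entropy(fake_logits, jnp.zeros_like(fake_logits))
    return jnp.mean(real_loss + fake_loss)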
🔁 7. The Training Loop
@jax.jit
def train_step(gen_state, dis_state, real_batch, key):
    z = random.normal(key, (real_batch.shape[0], latent_dim))

    # Generator update: try to fool the current discriminator
    g_loss, g_grads = jax.value_and_grad(generator_loss_fn)(gen_state.params, dis_state.params, z)
    gen_state = gen_state.apply_gradients(grads=g_grads)

    # Discriminator update: separate real windows from freshly generated ones
    fake_data = gen.apply(gen_state.params, z)
    d_loss, d_grads = jax.value_and_grad(discriminator_loss_fn)(dis_state.params, real_batch, fake_data)
    dis_state = dis_state.apply_gradients(grads=d_grads)

    return gen_state, dis_state, g_loss, d_loss
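Before the full run, a single call is a cheap way to trigger JIT compilation and catch shape errors early (a usage sketch; the slice size 64 matches the batch size used in section 8):

# One smoke-test step, using the states from section 5
key, subkey = random.split(key)
gen_state, dis_state, g_loss, d_loss = train_step(gen_state, dis_state, real_data[:64], subkey)
print(g_loss, d_loss)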
🧪 8. Running Training
batch_size = 64
epochs = 100

for epoch in range(epochs):
    for i in range(0, len(real_data) - batch_size + 1, batch_size):
        # Fixed-size batches only: a smaller final batch would change the
        # input shape and force jax.jit to recompile train_step
        real_batch = real_data[i:i+batch_size]
        key, subkey = random.split(key)
        gen_state, dis_state, g_loss, d_loss = train_step(gen_state, dis_state, real_batch, subkey)
    print(f"Epoch {epoch+1} | G Loss: {g_loss:.4f} | D Loss: {d_loss:.4f}")
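The loop above visits the windows in file order every epoch. Shuffling them at the start of each epoch is standard practice; a drop-in snippet (an optional tweak, not in the original post) that would go at the top of the epoch loop:

# Optional: reshuffle window order at the start of each epoch
key, perm_key = random.split(key)
perm = random.permutation(perm_key, len(real_data))
real_data = real_data[perm]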
🎨 9. Generating and Visualizing Time Series
def generate_samples(gen_state, num_samples=10):
    z = random.normal(key, (num_samples, latent_dim))
    fake = gen.apply(gen_state.params, z)
    return fake

import matplotlib.pyplot as plt

samples = generate_samples(gen_state, 5)

plt.figure(figsize=(12, 4))
for i, seq in enumerate(samples):
    plt.plot(seq.squeeze(), label=f"Sample {i+1}")
plt.title("Generated Time-Series Samples (TSGAN)")
plt.legend()
plt.show()
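The samples live in the scaler's [0, 1] range. To compare them with the raw series, undo the scaling (a sketch assuming the MinMaxScaler fitted in section 3 is still in scope):

import numpy as np

# inverse_transform expects a 2-D array, so flatten the windows first
flat = np.asarray(samples).reshape(-1, 1)                   # (num_samples * seq_len, 1)
samples_orig = scaler.inverse_transform(flat)               # back to original units
samples_orig = samples_orig.reshape(samples.shape[0], -1)   # (num_samples, seq_len)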
✅ 10. Ideas for Applying TSGAN

Application Area | Description
---|---
Data augmentation | Generate data that reproduces rare anomaly scenarios
GAN-based anomaly detection | Treat time series the generator never learned to reproduce as outliers (see the sketch after this table)
Simulation systems | Generate sensor, environmental-variable, and usage data
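One minimal way to realize the anomaly-detection idea is to reuse the trained discriminator as a scorer: windows it rates as unlikely to be real get a high anomaly score. A hedged sketch (the scoring heuristic and the threshold are assumptions, not from the post; full GAN-based detectors usually also search the latent space for a reconstruction):

def anomaly_score(dis_state, windows):
    # Higher score = less "real" according to the discriminator
    logits = dis.apply(dis_state.params, windows)  # (batch, 1)
    return 1.0 - jax.nn.sigmoid(logits).squeeze(-1)

scores = anomaly_score(dis_state, real_data[:100])
suspect = scores > 0.9  # illustrative threshold, tune on validation data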
📌 Next post preview: Implementing a Conditional GAN (cGAN) for Time Series in JAX - Condition-Driven Generation
In the next post, we'll use a **Conditional GAN (cGAN)** to control the generated series by conditioning on labels,
for example "the time series when the sensor is healthy" or "a window where the stock price rises".
Tags: JAX, GAN, time-series generation, TSGAN, deep learning, Flax, fast computation, Anomaly Detection, sensor time series, time-series data augmentation, Generative Model, Autoencoder, Time-Series Forecasting, JAX examples, Python GAN, time-series deep learning