
📌 Implementing a Time-Series Forecasting Model in JAX - RNN/LSTM-Based Future Data Prediction

octo54 2025. 5. 22. 15:51
๋ฐ˜์‘ํ˜•



🚀 What Is Time-Series Forecasting?

Time-series forecasting is the technique of predicting future values from data ordered in time.
It is used in many fields, including temperature forecasting, stock-price prediction, and demand forecasting.

In this post, we use JAX to implement a time-series forecasting model based on RNN (Recurrent Neural Network) and LSTM (Long Short-Term Memory).


💡 1. Characteristics of Time-Series Data

  • Sequential: previous values influence the values that follow
  • Recurring patterns: periodicity and seasonality
  • Instability: noise and outliers may be present
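To see these traits concretely before touching real data, here is a minimal sketch that generates a synthetic temperature-like series with yearly seasonality plus noise. synthetic_temps is a hypothetical stand-in for illustration only, not the Seoul dataset used below.

import numpy as np

# Hypothetical stand-in: a daily-temperature-like series with
# yearly seasonality (recurring pattern) plus Gaussian noise
rng = np.random.default_rng(0)
days = np.arange(3 * 365)
season = 12 + 10 * np.sin(2 * np.pi * days / 365)  # yearly cycle around 12 degrees
noise = rng.normal(scale=2.0, size=days.shape)     # measurement noise
synthetic_temps = season + noise                   # sequential, periodic, noisy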

🔧 2. Installing the Libraries

pip install jax jaxlib flax optax pandas matplotlib scikit-learn

💾 3. Preparing the Data - A Daily Temperature Forecasting Example

📥 Loading a CSV or other time-series dataset

import pandas as pd
from sklearn.preprocessing import MinMaxScaler
import jax.numpy as jnp

# Example: Seoul temperature data (date, average temperature)
df = pd.read_csv('seoul_temperature.csv')
temps = df['avg_temp'].values.reshape(-1, 1)

# Normalize to the [0, 1] range
scaler = MinMaxScaler()
temps_scaled = scaler.fit_transform(temps)

# Build sliding-window time-series samples
def create_sequences(data, window=30):
    xs, ys = [], []
    for i in range(len(data) - window):
        xs.append(data[i:i+window])
        ys.append(data[i+window])
    return jnp.array(xs), jnp.array(ys)

x_data, y_data = create_sequences(temps_scaled)
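Before training, it helps to confirm the sample shapes and hold out the most recent slice for validation. A minimal sketch; the 80/20 time-ordered split is an assumption, not part of the original recipe:

# x_data: (num_samples, window, 1), y_data: (num_samples, 1)
print(x_data.shape, y_data.shape)

# Hold out the most recent 20% as a validation set; keeping the split
# time-ordered avoids validating on data older than the training data
split = int(len(x_data) * 0.8)
x_train, y_train = x_data[:split], y_data[:split]
x_val, y_val = x_data[split:], y_data[split:]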

🧠 4. Implementing the RNN Model (with Flax)

๋ฐ˜์‘ํ˜•

📝 Defining the RNN cell

from flax import linen as nn

class SimpleRNNCell(nn.Module):
    hidden_size: int

    @nn.compact
    def __call__(self, carry, x):
        h = carry  # previous hidden state
        # Elman-style update: tanh(W_x x + W_h h); biases live inside the Dense layers
        h = nn.tanh(nn.Dense(self.hidden_size)(x) + nn.Dense(self.hidden_size)(h))
        return h, h  # (new carry, output)

🔁 Defining the full RNN model

class RNNModel(nn.Module):
    hidden_size: int
    output_size: int = 1

    @nn.compact
    def __call__(self, x):
        batch_size, seq_len, _ = x.shape  # x: (batch, seq, feature)
        rnn_cell = SimpleRNNCell(self.hidden_size)
        h = jnp.zeros((batch_size, self.hidden_size))  # initial hidden state

        # Calling the same cell instance every step shares its parameters;
        # the Python loop is unrolled at trace time, which is fine for short windows
        for t in range(seq_len):
            h, _ = rnn_cell(h, x[:, t, :])

        output = nn.Dense(self.output_size)(h)  # regression head on the final state
        return output
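The title also mentions LSTM. Below is a minimal LSTM cell sketch in the same hand-rolled style, with the gates built from Dense layers (Flax ships its own nn.LSTMCell, but its constructor signature has changed across versions, so this stays self-contained). To use it inside RNNModel, initialize the carry as a pair of zero arrays (h, c), loop with carry, _ = cell(carry, x[:, t, :]), and feed the final h (the first element of the carry) to the Dense head.

class SimpleLSTMCell(nn.Module):
    hidden_size: int

    @nn.compact
    def __call__(self, carry, x):
        h, c = carry                                   # hidden state, cell state
        z = jnp.concatenate([x, h], axis=-1)           # gate input: current step + hidden
        i = nn.sigmoid(nn.Dense(self.hidden_size)(z))  # input gate
        f = nn.sigmoid(nn.Dense(self.hidden_size)(z))  # forget gate
        o = nn.sigmoid(nn.Dense(self.hidden_size)(z))  # output gate
        g = nn.tanh(nn.Dense(self.hidden_size)(z))     # candidate cell state
        c = f * c + i * g                              # forget old memory, add new
        h = o * nn.tanh(c)                             # expose gated memory as hidden state
        return (h, c), h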

📉 5. Loss Function and Training Loop

⚙️ Training setup

import jax
import optax
from flax.training import train_state

model = RNNModel(hidden_size=64)
key = jax.random.PRNGKey(42)
# Initialize with a dummy batch of shape (batch, seq, feature)
params = model.init(key, jnp.ones((1, 30, 1)))['params']

tx = optax.adam(1e-3)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)
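A quick forward pass on a dummy batch is a cheap sanity check that the shapes line up before training starts (a sketch):

dummy = jnp.ones((8, 30, 1))                        # (batch, seq, feature)
out = state.apply_fn({'params': state.params}, dummy)
print(out.shape)                                    # expected: (8, 1)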

🧮 Loss and training step

@jax.jit
def mse_loss(params, x, y):
    preds = model.apply({'params': params}, x)
    return jnp.mean((preds.squeeze() - y.squeeze()) ** 2)

@jax.jit
def train_step(state, x, y):
    # Loss and gradients in one pass, then a parameter update
    loss, grads = jax.value_and_grad(mse_loss)(state.params, x, y)
    state = state.apply_gradients(grads=grads)
    return state, loss
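If you kept the x_val/y_val split sketched in section 3, a jitted evaluation step is one line on top of the same loss (a sketch; eval_step is not part of the original recipe):

@jax.jit
def eval_step(state, x, y):
    # Same MSE as training, but with no gradient computation or update
    return mse_loss(state.params, x, y)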

๐Ÿ” 6. ๋ชจ๋ธ ํ•™์Šต

batch_size = 64
epochs = 20

for epoch in range(epochs):
    for i in range(0, len(x_data), batch_size):
        # x_data already carries the feature axis: (num_samples, window, 1)
        x_batch = x_data[i:i+batch_size]
        y_batch = y_data[i:i+batch_size]
        state, loss = train_step(state, x_batch, y_batch)
    print(f"Epoch {epoch+1}, Loss: {loss:.4f}")  # loss of the last batch in the epoch
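The loop above visits batches in a fixed order every epoch. Shuffling the sample order each epoch usually helps; one way to add it with JAX's own RNG (a sketch, not the original loop):

rng = jax.random.PRNGKey(0)
for epoch in range(epochs):
    rng, shuffle_key = jax.random.split(rng)
    perm = jax.random.permutation(shuffle_key, len(x_data))  # new order each epoch
    for i in range(0, len(x_data), batch_size):
        idx = perm[i:i+batch_size]
        state, loss = train_step(state, x_data[idx], y_data[idx])
    print(f"Epoch {epoch+1}, Loss: {loss:.4f}")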

✅ 7. Forecasting Future Values and Visualization

import matplotlib.pyplot as plt

def predict_future(state, input_seq, steps=10):
    # Autoregressive roll-out: feed each prediction back in as the newest input
    results = []
    current = input_seq  # shape (window, 1)
    for _ in range(steps):
        pred = model.apply({'params': state.params}, current[None])  # (1, window, 1) -> (1, 1)
        results.append(pred.squeeze())
        # Drop the oldest step and append the new prediction
        current = jnp.concatenate([current[1:], pred.reshape(1, 1)], axis=0)
    return jnp.array(results)

# Forecast
last_seq = x_data[-1]
future = predict_future(state, last_seq, steps=30)
future_rescaled = scaler.inverse_transform(future.reshape(-1, 1))

# Visualization
plt.plot(range(len(temps)), temps, label='observed')
plt.plot(range(len(temps), len(temps) + 30), future_rescaled, label='forecast', color='red')
plt.legend()
plt.title("RNN-based temperature forecast")
plt.show()
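Because the autoregressive roll-out feeds its own predictions back in, errors compound with the horizon. It is therefore worth also reporting one-step-ahead error on held-out windows; a sketch using the x_val/y_val split from section 3 (RMSE in degrees is an assumed choice of metric):

import numpy as np

# One-step-ahead predictions on held-out windows, rescaled to degrees
val_preds = model.apply({'params': state.params}, x_val)
val_preds_c = scaler.inverse_transform(np.asarray(val_preds).reshape(-1, 1))
y_val_c = scaler.inverse_transform(np.asarray(y_val).reshape(-1, 1))
rmse = float(np.sqrt(np.mean((val_preds_c - y_val_c) ** 2)))
print(f"Validation RMSE: {rmse:.2f} degrees C")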

📌 Next up: Implementing an Attention-Based Time-Series Forecasting Model in JAX (Transformer for Time Series)

In the next post, we will implement a time-series forecasting model with a Self-Attention architecture in JAX,
practicing forecasts that capture complex periodicity and long-range dependencies.


 

JAX, time-series forecasting, RNN, LSTM, temperature prediction, deep learning, Time Series, Python, fast computation, training loop, Flax, hands-on model, future prediction, data preprocessing, artificial intelligence, high-performance computing, visualization