Implementing a Time Series Forecasting Model with JAX - RNN/LSTM-Based Prediction of Future Data
What Is Time Series Forecasting?
Time series forecasting is the task of predicting future values from data that is ordered in time.
It is used in many areas, including temperature forecasting, stock price prediction, and demand forecasting.
In this post, we use JAX to implement time series forecasting models based on **RNN (Recurrent Neural Network)** and LSTM (Long Short-Term Memory) architectures.
1. Characteristics of Time Series Data
- Sequential: earlier values influence the values that follow
- Repeating patterns: periodicity and seasonality
- Instability: noise and outliers may be present (the short synthetic example below illustrates all three properties)
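To make these three properties concrete, the sketch below generates a synthetic daily series from a yearly cycle, a slow trend, and random noise; it can also stand in for the CSV file used later if no real dataset is at hand. The variable names here (e.g. `synthetic_temps`) are only illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
days = np.arange(3 * 365)                          # three years of daily points
seasonal = 10.0 * np.sin(2 * np.pi * days / 365)   # periodicity / seasonality
trend = 0.002 * days                               # slow drift over time
noise = rng.normal(scale=1.5, size=days.shape)     # instability / noise
synthetic_temps = 12.0 + seasonal + trend + noise  # shape (1095,)
```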
2. Installing the Libraries
pip install jax jaxlib flax optax pandas matplotlib scikit-learn
3. Preparing the Data - Daily Temperature Forecasting Example
Using a CSV or similar time series dataset
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
import jax  # needed later for jax.random and jax.jit
import jax.numpy as jnp

# Example: Seoul temperature data (date, average daily temperature)
df = pd.read_csv('seoul_temperature.csv')
temps = df['avg_temp'].values.reshape(-1, 1)

# Scale values to the [0, 1] range
scaler = MinMaxScaler()
temps_scaled = scaler.fit_transform(temps)
# Build sliding-window training samples: each input is `window` consecutive
# values and the target is the value that immediately follows the window.
def create_sequences(data, window=30):
    xs, ys = [], []
    for i in range(len(data) - window):
        xs.append(data[i:i+window])
        ys.append(data[i+window])
    return jnp.array(xs), jnp.array(ys)
x_data, y_data = create_sequences(temps_scaled)
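Before moving on, it is worth confirming the array shapes, since the model below expects inputs of shape (batch, sequence length, features):

```python
# With temps_scaled of shape (N, 1) and window=30:
#   x_data -> (N - 30, 30, 1)   (samples, sequence length, features)
#   y_data -> (N - 30, 1)       (samples, target value)
print(x_data.shape, y_data.shape)
```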
4. Implementing the RNN Model (with Flax)
Defining the RNN cell
from flax import linen as nn

class SimpleRNNCell(nn.Module):
    hidden_size: int

    @nn.compact
    def __call__(self, carry, x):
        # Vanilla RNN update: h_t = tanh(W_x x_t + W_h h_{t-1})
        h = carry
        h = nn.tanh(nn.Dense(self.hidden_size)(x) + nn.Dense(self.hidden_size)(h))
        return h, h
Defining the full RNN model
class RNNModel(nn.Module):
    hidden_size: int
    output_size: int = 1

    @nn.compact
    def __call__(self, x):
        # x has shape (batch, seq_len, features)
        batch_size, seq_len, _ = x.shape
        rnn_cell = SimpleRNNCell(self.hidden_size)
        h = jnp.zeros((batch_size, self.hidden_size))
        # Unroll the recurrence over time; the same cell (and therefore the
        # same parameters) is applied at every step.
        for t in range(seq_len):
            h, _ = rnn_cell(h, x[:, t, :])
        output = nn.Dense(self.output_size)(h)
        return output
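The title also mentions LSTM. The walkthrough above covers the plain RNN; as a minimal sketch, the same structure can be built around Flax's built-in nn.LSTMCell, assuming a recent Flax release where the cell takes a features argument (older releases construct it without one). It can be used as a drop-in replacement for RNNModel in the training code that follows.

```python
from flax import linen as nn
import jax.numpy as jnp

class LSTMModel(nn.Module):
    hidden_size: int
    output_size: int = 1

    @nn.compact
    def __call__(self, x):
        batch_size, seq_len, _ = x.shape
        # Assumes a Flax version where nn.LSTMCell requires `features`.
        cell = nn.LSTMCell(features=self.hidden_size)
        # The LSTM carry is the (cell state, hidden state) pair, both zeros here.
        carry = (jnp.zeros((batch_size, self.hidden_size)),
                 jnp.zeros((batch_size, self.hidden_size)))
        for t in range(seq_len):
            carry, h = cell(carry, x[:, t, :])
        return nn.Dense(self.output_size)(h)
```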
5. Loss Function and Training Loop
Training setup
import optax
from flax.training import train_state
model = RNNModel(hidden_size=64)
key = jax.random.PRNGKey(42)
params = model.init(key, jnp.ones((1, 30, 1))) # input: batch, seq, feature
tx = optax.adam(1e-3)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)
Loss and training step
@jax.jit
def mse_loss(params, x, y):
    preds = model.apply(params, x)
    return jnp.mean((preds.squeeze() - y.squeeze()) ** 2)

@jax.jit
def train_step(state, x, y):
    loss, grads = jax.value_and_grad(mse_loss)(state.params, x, y)
    state = state.apply_gradients(grads=grads)
    return state, loss
6. Training the Model
batch_size = 64
epochs = 20

for epoch in range(epochs):
    for i in range(0, len(x_data), batch_size):
        # x_data already has shape (batch, 30, 1), so it can be fed directly
        x_batch = x_data[i:i+batch_size]
        y_batch = y_data[i:i+batch_size]
        state, loss = train_step(state, x_batch, y_batch)
    # loss here comes from the last mini-batch of the epoch
    print(f"Epoch {epoch+1}, Loss: {loss:.4f}")
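As a quick sanity check, the training error can also be reported in the original temperature units instead of the scaled [0, 1] range. This sketch reuses model, state, x_data, y_data, and scaler from the steps above:

```python
import numpy as np

# Predict on the training windows and undo the MinMax scaling
train_preds = model.apply(state.params, x_data)
preds_deg = scaler.inverse_transform(np.asarray(train_preds).reshape(-1, 1))
targets_deg = scaler.inverse_transform(np.asarray(y_data).reshape(-1, 1))

rmse = float(np.sqrt(np.mean((preds_deg - targets_deg) ** 2)))
print(f"Train RMSE: {rmse:.2f} degrees")
```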
7. Forecasting Future Values and Visualizing the Result
import matplotlib.pyplot as plt

def predict_future(state, input_seq, steps=10):
    # Autoregressive rollout: predict one step ahead, append the prediction
    # to the window, drop the oldest value, and repeat.
    results = []
    current = input_seq                                  # shape (window, 1)
    for _ in range(steps):
        pred = model.apply(state.params, current[None])  # add batch axis -> (1, window, 1)
        results.append(pred.squeeze())
        current = jnp.concatenate([current[1:], pred.reshape(1, 1)], axis=0)
    return jnp.array(results)
# Forecast the next 30 days from the last observed window
last_seq = x_data[-1]
future = predict_future(state, last_seq, steps=30)
future_rescaled = scaler.inverse_transform(future.reshape(-1, 1))

# Visualize the observed series and the forecast
plt.plot(range(len(temps)), temps, label='Observed')
plt.plot(range(len(temps), len(temps) + 30), future_rescaled, label='Forecast', color='red')
plt.legend()
plt.title("RNN-based Temperature Forecast")
plt.show()
Next post preview: Implementing an Attention-Based Time Series Forecasting Model with JAX (Transformer for Time Series)
In the next post, we will use JAX to implement a time series forecasting model with a Self-Attention architecture, practicing forecasts that capture complex periodicity and long-range dependencies.
JAX, time series forecasting, RNN, LSTM, temperature forecasting, deep learning, Time Series, Python, fast computation, training loop, Flax, practical model, future prediction, data preprocessing, artificial intelligence, high-performance computing, visualization