ํฐ์คํ ๋ฆฌ ๋ทฐ
Programming/JAX
๐ JAX๋ก Prophet ์คํ์ผ + Neural Basis Expansion Hybrid ์๊ณ์ด ์์ธก ๋ชจ๋ธ ๊ตฌํ
octo54 2025. 6. 13. 14:24๋ฐ์ํ
๐ JAX๋ก Prophet ์คํ์ผ + Neural Basis Expansion Hybrid ์๊ณ์ด ์์ธก ๋ชจ๋ธ ๊ตฌํ
๐ ์ Hybrid ๋ชจ๋ธ์ธ๊ฐ?
- Prophet์ ๊ณ์ ์ฑ, ์ถ์ธ, ํด์ผ ํจ๊ณผ ๋ฑ ํด์ ๊ฐ๋ฅํ ์์ธก์ด ๊ฐ์
- **Neural Basis Expansion (N-BEATS)**๋ ๋ฅ๋ฌ๋ ๊ธฐ๋ฐ์ ๊ณ ์ ๋ ๋ธ๋ก์ ์์ ๋์ ์์ธก ์ ํ๋๋ฅผ ๋ฌ์ฑ
- JAX์์๋ ์ ํ์ ๊ณ์ ์ฑ + ๋น์ ํ ๋ฅ๋ฌ๋ ๊ตฌ์กฐ๋ฅผ ๋์์ ๋ฐ์ํ Hybrid ๋ชจ๋ธ์ ์ ์ฐํ๊ฒ ๊ตฌํํ ์ ์์
๐ฏ ์ด ๊ธ์์ ๊ตฌํํ๋ Hybrid ๊ตฌ์ฑ
๊ตฌ์ฑ์์ ์ญํ
Trend Component (Linear/Logistic) | Prophet์ ์ถ์ธ ๋ชจ๋ธ๋ง |
Seasonal Component (Fourier basis) | ์ฃผ๊ธฐ์ฑ ์ ๋ณด ๋ฐ์ |
Residual Neural Component (MLP or Transformer) | ์ค์ฐจ ๋ณด์ , ๋น์ ํ ํจํด ์ฒ๋ฆฌ |
๐พ 1. ๋ฐ์ดํฐ ์ค๋น
import pandas as pd
import jax.numpy as jnp
from sklearn.preprocessing import MinMaxScaler
df = pd.read_csv("time_series_data.csv", parse_dates=["ds"])
y = df['y'].values.reshape(-1, 1)
scaler = MinMaxScaler()
y_scaled = scaler.fit_transform(y)
t = jnp.arange(len(y_scaled)).reshape(-1, 1) # ์์ ์ ๋ณด
๐ง 2. Fourier ๊ธฐ๋ฐ ๊ณ์ ์ฑ ํจ์ ๊ตฌํ
def fourier_features(t, period=365.25, order=3):
t = t / period
features = [jnp.sin(2 * jnp.pi * t * i) for i in range(1, order + 1)]
features += [jnp.cos(2 * jnp.pi * t * i) for i in range(1, order + 1)]
return jnp.concatenate(features, axis=1) # shape: (N, order*2)
๐งฑ 3. Hybrid ๋ชจ๋ธ ์ ์ (Flax)
from flax import linen as nn
class ProphetLikeHybridModel(nn.Module):
fourier_order: int = 3
@nn.compact
def __call__(self, t):
# Trend: ์ ํ ์ถ์ธ
trend = nn.Dense(1)(t) # (batch, 1)
# Seasonality: Fourier
fourier = fourier_features(t, order=self.fourier_order)
seasonal = nn.Dense(1)(fourier) # (batch, 1)
# Residual (non-linear part)
x = jnp.concatenate([t, fourier], axis=1)
x = nn.Dense(32)(x)
x = nn.relu(x)
x = nn.Dense(1)(x) # Residual output
return trend + seasonal + x # Hybrid ํฉ์ฐ
โ๏ธ 4. ํ์ต ์ด๊ธฐํ
๋ฐ์ํ
import optax
from flax.training import train_state
import jax
model = ProphetLikeHybridModel()
params = model.init(jax.random.PRNGKey(0), t)
tx = optax.adam(1e-3)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)
๐ 5. ์์ค ๋ฐ ํ์ต Step
@jax.jit
def loss_fn(params, t, y_true):
preds = model.apply(params, t)
return jnp.mean((preds - y_true) ** 2)
@jax.jit
def train_step(state, t, y_true):
loss, grads = jax.value_and_grad(loss_fn)(state.params, t, y_true)
state = state.apply_gradients(grads=grads)
return state, loss
๐ 6. ํ์ต ๋ฃจํ
epochs = 500
for epoch in range(epochs):
state, loss = train_step(state, t, y_scaled)
if epoch % 50 == 0:
print(f"Epoch {epoch} - Loss: {loss:.6f}")
๐ฎ 7. ์์ธก ๋ฐ ์๊ฐํ
import matplotlib.pyplot as plt
preds = model.apply(state.params, t)
preds_inv = scaler.inverse_transform(preds)
plt.plot(df['ds'], y, label="True")
plt.plot(df['ds'], preds_inv, label="Hybrid Prediction")
plt.title("Prophet + NN Hybrid Forecast")
plt.legend()
plt.show()
โ 8. ํ์ฉ ์๋๋ฆฌ์ค
๋ถ์ผ ์ค๋ช
์๋์ง ์์ธก | ์ฃผ๊ฐ/์ฐ๊ฐ ์ฃผ๊ธฐ + ์ด์์น |
๊ธ์ต | ํธ๋ ๋ + ๋น์ ํ์ ๊ธ๋ฑ๋ฝ |
์์ ์์ธก | ๋ช ์ /๊ณ์ ํจ๊ณผ + ๋จธ์ ๋ฌ๋ ๋ณด์ |
๐ ๋ค์ ๊ธ ์๊ณ : JAX ๊ธฐ๋ฐ์ Graph Neural Network(GNN)๋ก ์๊ณ์ด + ๊ด๊ณ ๋ฐ์ดํฐ ์์ธก ๋ชจ๋ธ ๊ตฌํํ๊ธฐ
JAX, ์๊ณ์ด ์์ธก, Hybrid ๋ชจ๋ธ, Prophet, Fourier Transform, Neural Basis Expansion, Flax, ๋ฅ๋ฌ๋ ์๊ณ์ด, ์๊ณ์ด ํด์, ์๊ณ์ด ํธ๋ ๋, ์๊ณ์ด ๊ณ์ ์ฑ, AI ์์ธก, ๋น์ ํ ์์ธก, Multi-component Forecasting, Python ML
'Programming > JAX' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
โป ์ด ํฌ์คํ
์ ์ฟ ํก ํํธ๋์ค ํ๋์ ์ผํ์ผ๋ก, ์ด์ ๋ฐ๋ฅธ ์ผ์ ์ก์ ์์๋ฃ๋ฅผ ์ ๊ณต๋ฐ์ต๋๋ค.
๊ณต์ง์ฌํญ
์ต๊ทผ์ ์ฌ๋ผ์จ ๊ธ
์ต๊ทผ์ ๋ฌ๋ฆฐ ๋๊ธ
- Total
- Today
- Yesterday
๋งํฌ
TAG
- fastapi
- kotlin
- Next.js
- PostgreSQL
- SEO์ต์ ํ
- ํ๋ก ํธ์๋๋ฉด์
- ์น๊ฐ๋ฐ
- AI์ฑ๋ด
- ๋ฅ๋ฌ๋
- nodejs
- REACT
- SEO ์ต์ ํ
- CI/CD
- ํ์ด์ฌ์๊ณ ๋ฆฌ์ฆ
- ํ๋ก ํธ์๋
- NestJS
- seo ์ต์ ํ 10๊ฐ
- nextJS
- Python
- ๊ฐ๋ฐ๋ธ๋ก๊ทธ
- llm
- rag
- gatsbyjs
- Docker
- App Router
- ๋ฐฑ์๋๊ฐ๋ฐ
- JAX
- Prisma
- Ktor
- flax
์ผ | ์ | ํ | ์ | ๋ชฉ | ๊ธ | ํ |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | ||
6 | 7 | 8 | 9 | 10 | 11 | 12 |
13 | 14 | 15 | 16 | 17 | 18 | 19 |
20 | 21 | 22 | 23 | 24 | 25 | 26 |
27 | 28 | 29 | 30 | 31 |
๊ธ ๋ณด๊ดํจ
๋ฐ์ํ