ํ‹ฐ์Šคํ† ๋ฆฌ ๋ทฐ

๋ฐ˜์‘ํ˜•

๐Ÿ“Œ JAX๋กœ Prophet ์Šคํƒ€์ผ + Neural Basis Expansion Hybrid ์‹œ๊ณ„์—ด ์˜ˆ์ธก ๋ชจ๋ธ ๊ตฌํ˜„


๐ŸŒŸ ์™œ Hybrid ๋ชจ๋ธ์ธ๊ฐ€?

  • Prophet์€ ๊ณ„์ ˆ์„ฑ, ์ถ”์„ธ, ํœด์ผ ํšจ๊ณผ ๋“ฑ ํ•ด์„ ๊ฐ€๋Šฅํ•œ ์˜ˆ์ธก์ด ๊ฐ•์ 
  • **Neural Basis Expansion (N-BEATS)**๋Š” ๋”ฅ๋Ÿฌ๋‹ ๊ธฐ๋ฐ˜์˜ ๊ณ ์ •๋œ ๋ธ”๋ก์„ ์Œ“์•„ ๋†’์€ ์˜ˆ์ธก ์ •ํ™•๋„๋ฅผ ๋‹ฌ์„ฑ
  • JAX์—์„œ๋Š” ์ •ํ˜•์  ๊ณ„์ ˆ์„ฑ + ๋น„์„ ํ˜• ๋”ฅ๋Ÿฌ๋‹ ๊ตฌ์กฐ๋ฅผ ๋™์‹œ์— ๋ฐ˜์˜ํ•œ Hybrid ๋ชจ๋ธ์„ ์œ ์—ฐํ•˜๊ฒŒ ๊ตฌํ˜„ํ•  ์ˆ˜ ์žˆ์Œ

๐ŸŽฏ ์ด ๊ธ€์—์„œ ๊ตฌํ˜„ํ•˜๋Š” Hybrid ๊ตฌ์„ฑ

๊ตฌ์„ฑ์š”์†Œ ์—ญํ• 

Trend Component (Linear/Logistic) Prophet์˜ ์ถ”์„ธ ๋ชจ๋ธ๋ง
Seasonal Component (Fourier basis) ์ฃผ๊ธฐ์„ฑ ์ •๋ณด ๋ฐ˜์˜
Residual Neural Component (MLP or Transformer) ์˜ค์ฐจ ๋ณด์ •, ๋น„์„ ํ˜• ํŒจํ„ด ์ฒ˜๋ฆฌ

๐Ÿ’พ 1. ๋ฐ์ดํ„ฐ ์ค€๋น„

import pandas as pd
import jax.numpy as jnp
from sklearn.preprocessing import MinMaxScaler

df = pd.read_csv("time_series_data.csv", parse_dates=["ds"])
y = df['y'].values.reshape(-1, 1)

scaler = MinMaxScaler()
y_scaled = scaler.fit_transform(y)
t = jnp.arange(len(y_scaled)).reshape(-1, 1)  # ์‹œ์  ์ •๋ณด

๐Ÿง  2. Fourier ๊ธฐ๋ฐ˜ ๊ณ„์ ˆ์„ฑ ํ•จ์ˆ˜ ๊ตฌํ˜„

def fourier_features(t, period=365.25, order=3):
    t = t / period
    features = [jnp.sin(2 * jnp.pi * t * i) for i in range(1, order + 1)]
    features += [jnp.cos(2 * jnp.pi * t * i) for i in range(1, order + 1)]
    return jnp.concatenate(features, axis=1)  # shape: (N, order*2)

๐Ÿงฑ 3. Hybrid ๋ชจ๋ธ ์ •์˜ (Flax)

from flax import linen as nn

class ProphetLikeHybridModel(nn.Module):
    fourier_order: int = 3

    @nn.compact
    def __call__(self, t):
        # Trend: ์„ ํ˜• ์ถ”์„ธ
        trend = nn.Dense(1)(t)  # (batch, 1)

        # Seasonality: Fourier
        fourier = fourier_features(t, order=self.fourier_order)
        seasonal = nn.Dense(1)(fourier)  # (batch, 1)

        # Residual (non-linear part)
        x = jnp.concatenate([t, fourier], axis=1)
        x = nn.Dense(32)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)  # Residual output

        return trend + seasonal + x  # Hybrid ํ•ฉ์‚ฐ

โš™๏ธ 4. ํ•™์Šต ์ดˆ๊ธฐํ™”

๋ฐ˜์‘ํ˜•
import optax
from flax.training import train_state
import jax

model = ProphetLikeHybridModel()
params = model.init(jax.random.PRNGKey(0), t)
tx = optax.adam(1e-3)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

๐Ÿ“‰ 5. ์†์‹ค ๋ฐ ํ•™์Šต Step

@jax.jit
def loss_fn(params, t, y_true):
    preds = model.apply(params, t)
    return jnp.mean((preds - y_true) ** 2)

@jax.jit
def train_step(state, t, y_true):
    loss, grads = jax.value_and_grad(loss_fn)(state.params, t, y_true)
    state = state.apply_gradients(grads=grads)
    return state, loss

๐Ÿƒ 6. ํ•™์Šต ๋ฃจํ”„

epochs = 500
for epoch in range(epochs):
    state, loss = train_step(state, t, y_scaled)
    if epoch % 50 == 0:
        print(f"Epoch {epoch} - Loss: {loss:.6f}")

๐Ÿ”ฎ 7. ์˜ˆ์ธก ๋ฐ ์‹œ๊ฐํ™”

import matplotlib.pyplot as plt

preds = model.apply(state.params, t)
preds_inv = scaler.inverse_transform(preds)

plt.plot(df['ds'], y, label="True")
plt.plot(df['ds'], preds_inv, label="Hybrid Prediction")
plt.title("Prophet + NN Hybrid Forecast")
plt.legend()
plt.show()

โœ… 8. ํ™œ์šฉ ์‹œ๋‚˜๋ฆฌ์˜ค

๋ถ„์•ผ ์„ค๋ช…

์—๋„ˆ์ง€ ์˜ˆ์ธก ์ฃผ๊ฐ„/์—ฐ๊ฐ„ ์ฃผ๊ธฐ + ์ด์ƒ์น˜
๊ธˆ์œต ํŠธ๋ Œ๋“œ + ๋น„์ •ํ˜•์  ๊ธ‰๋“ฑ๋ฝ
์ˆ˜์š” ์˜ˆ์ธก ๋ช…์ ˆ/๊ณ„์ ˆ ํšจ๊ณผ + ๋จธ์‹ ๋Ÿฌ๋‹ ๋ณด์ •

๐Ÿ“Œ ๋‹ค์Œ ๊ธ€ ์˜ˆ๊ณ : JAX ๊ธฐ๋ฐ˜์˜ Graph Neural Network(GNN)๋กœ ์‹œ๊ณ„์—ด + ๊ด€๊ณ„ ๋ฐ์ดํ„ฐ ์˜ˆ์ธก ๋ชจ๋ธ ๊ตฌํ˜„ํ•˜๊ธฐ

 

JAX, ์‹œ๊ณ„์—ด ์˜ˆ์ธก, Hybrid ๋ชจ๋ธ, Prophet, Fourier Transform, Neural Basis Expansion, Flax, ๋”ฅ๋Ÿฌ๋‹ ์‹œ๊ณ„์—ด, ์‹œ๊ณ„์—ด ํ•ด์„, ์‹œ๊ณ„์—ด ํŠธ๋ Œ๋“œ, ์‹œ๊ณ„์—ด ๊ณ„์ ˆ์„ฑ, AI ์˜ˆ์ธก, ๋น„์„ ํ˜• ์˜ˆ์ธก, Multi-component Forecasting, Python ML

โ€ป ์ด ํฌ์ŠคํŒ…์€ ์ฟ ํŒก ํŒŒํŠธ๋„ˆ์Šค ํ™œ๋™์˜ ์ผํ™˜์œผ๋กœ, ์ด์— ๋”ฐ๋ฅธ ์ผ์ •์•ก์˜ ์ˆ˜์ˆ˜๋ฃŒ๋ฅผ ์ œ๊ณต๋ฐ›์Šต๋‹ˆ๋‹ค.
๊ณต์ง€์‚ฌํ•ญ
์ตœ๊ทผ์— ์˜ฌ๋ผ์˜จ ๊ธ€
์ตœ๊ทผ์— ๋‹ฌ๋ฆฐ ๋Œ“๊ธ€
Total
Today
Yesterday
๋งํฌ
ยซ   2025/07   ยป
์ผ ์›” ํ™” ์ˆ˜ ๋ชฉ ๊ธˆ ํ† 
1 2 3 4 5
6 7 8 9 10 11 12
13 14 15 16 17 18 19
20 21 22 23 24 25 26
27 28 29 30 31
๊ธ€ ๋ณด๊ด€ํ•จ
๋ฐ˜์‘ํ˜•