티스토리 뷰

반응형

📌 JAX로 Transformer 기반 다변량 시계열 예측 모델 구현 - Multi-head Attention을 활용한 고급 Forecasting


⚡️ Transformer는 어떻게 시계열 예측에 유리할까?

  • Self-Attention: 시계열 전 구간에서 정보 상호작용 가능
  • 병렬 처리: RNN에 비해 학습/추론 속도 빠름
  • 멀티채널 지원: 여러 센서 또는 변수(feature) 간 관계 모델링 용이
  • 멀티스텝 예측: 미래 10, 20, 30 step까지도 동시에 예측 가능

💡 모델 구조 개요

구성 요소 설명

Positional Encoding 시간 정보를 인코딩
Encoder Block 다중 시점과 피처 간 상호작용
Decoder Block (선택적) 과거 + 미래 조건 기반 예측
Output Head 다변량 미래값 출력

💾 1. 다변량 시계열 데이터 준비

import pandas as pd
import jax.numpy as jnp
from sklearn.preprocessing import StandardScaler

df = pd.read_csv("multivariate_timeseries.csv")  # 예: ['temp', 'humidity', 'vibration']
features = df[['temp', 'humidity', 'vibration']].values

scaler = StandardScaler()
features_scaled = scaler.fit_transform(features)

def create_dataset(data, input_len=30, pred_len=10):
    X, Y = [], []
    for i in range(len(data) - input_len - pred_len):
        X.append(data[i:i+input_len])
        Y.append(data[i+input_len:i+input_len+pred_len])
    return jnp.array(X), jnp.array(Y)

X, Y = create_dataset(features_scaled)

🧱 2. Transformer 예측 모델 구현 (Flax)

🧠 Positional Encoding

def positional_encoding(seq_len, d_model):
    pos = jnp.arange(seq_len)[:, None]
    i = jnp.arange(d_model)[None, :]
    angle_rates = 1 / jnp.power(10000, (2 * (i // 2)) / d_model)
    angle_rads = pos * angle_rates
    angle_rads = angle_rads.at[:, 0::2].set(jnp.sin(angle_rads[:, 0::2]))
    angle_rads = angle_rads.at[:, 1::2].set(jnp.cos(angle_rads[:, 1::2]))
    return angle_rads

🔧 Multi-head Attention & Transformer Block

반응형
from flax import linen as nn

class MultiHeadAttention(nn.Module):
    d_model: int
    num_heads: int

    @nn.compact
    def __call__(self, x):
        head_dim = self.d_model // self.num_heads
        qkv = nn.Dense(self.d_model * 3)(x)
        q, k, v = jnp.split(qkv, 3, axis=-1)

        def reshape_heads(t):
            return t.reshape(x.shape[0], x.shape[1], self.num_heads, head_dim).transpose(0, 2, 1, 3)

        q, k, v = map(reshape_heads, (q, k, v))

        scores = jnp.einsum('bhqd,bhkd->bhqk', q, k) / jnp.sqrt(head_dim)
        weights = jax.nn.softmax(scores, axis=-1)
        output = jnp.einsum('bhqk,bhvd->bhqd', weights, v)

        output = output.transpose(0, 2, 1, 3).reshape(x.shape[0], x.shape[1], self.d_model)
        return nn.Dense(self.d_model)(output)

class TransformerBlock(nn.Module):
    d_model: int
    num_heads: int
    ff_dim: int

    @nn.compact
    def __call__(self, x):
        attn = MultiHeadAttention(self.d_model, self.num_heads)(x)
        x = nn.LayerNorm()(x + attn)
        ff = nn.Dense(self.ff_dim)(x)
        ff = nn.relu(ff)
        ff = nn.Dense(self.d_model)(ff)
        x = nn.LayerNorm()(x + ff)
        return x

🔮 전체 Transformer Forecasting 모델

class TransformerForecast(nn.Module):
    d_model: int = 64
    num_heads: int = 4
    ff_dim: int = 128
    num_layers: int = 2
    out_len: int = 10
    out_dim: int = 3  # 다변량

    @nn.compact
    def __call__(self, x):
        # Linear projection + positional encoding
        x = nn.Dense(self.d_model)(x)
        x += positional_encoding(x.shape[1], self.d_model)

        for _ in range(self.num_layers):
            x = TransformerBlock(self.d_model, self.num_heads, self.ff_dim)(x)

        pooled = jnp.mean(x, axis=1)
        x = nn.Dense(self.out_len * self.out_dim)(pooled)
        return x.reshape(-1, self.out_len, self.out_dim)

⚙️ 3. 학습 준비

import optax
from flax.training import train_state

model = TransformerForecast()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 30, 3)))
tx = optax.adam(1e-3)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

📉 4. 손실 함수 및 학습

@jax.jit
def loss_fn(params, x, y):
    preds = model.apply(params, x)
    return jnp.mean((preds - y) ** 2)

@jax.jit
def train_step(state, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(state.params, x, y)
    state = state.apply_gradients(grads=grads)
    return state, loss

🏃 5. 학습 루프

batch_size = 64
epochs = 20

for epoch in range(epochs):
    for i in range(0, len(X), batch_size):
        x_batch = X[i:i+batch_size]
        y_batch = Y[i:i+batch_size]
        state, loss = train_step(state, x_batch, y_batch)
    print(f"Epoch {epoch+1} - Loss: {loss:.4f}")

📊 6. 예측 시각화

import matplotlib.pyplot as plt

sample_input = X[-1:]
predicted = model.apply(state.params, sample_input)

for i in range(predicted.shape[-1]):
    plt.plot(Y[-1][:, i], label=f"True Var {i}")
    plt.plot(predicted[0][:, i], linestyle="--", label=f"Predicted Var {i}")
plt.title("Multivariate Forecast (Transformer)")
plt.legend()
plt.show()

7. 확장 아이디어

기능 설명

미래 조건 추가 Decoder 구조 확장
Attention 시각화 가중치 출력 후 해석
실시간 추론 @jax.jit으로 최적화
결측치 보간 예측 구조 그대로 활용 가능

📌 다음 글 예고: JAX + Prophet 또는 Neural Basis Expansion 기반 시계열 Hybrid 모델 구현


 

JAX, Time Series Forecasting, Transformer, Multivariate Forecasting, Self-Attention, Flax, 딥러닝 시계열, Multi-step Prediction, JAX 예제, 시계열 분석, Deep Learning, AI Forecasting, 센서 예측, LSTM 대안, 시계열 해석

※ 이 포스팅은 쿠팡 파트너스 활동의 일환으로, 이에 따른 일정액의 수수료를 제공받습니다.
공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2025/07   »
1 2 3 4 5
6 7 8 9 10 11 12
13 14 15 16 17 18 19
20 21 22 23 24 25 26
27 28 29 30 31
글 보관함
반응형