티스토리 뷰
Programming/JAX
📌 JAX로 Transformer 기반 다변량 시계열 예측 모델 구현 - Multi-head Attention을 활용한 고급 Forecasting
octo54 2025. 6. 10. 13:41반응형
📌 JAX로 Transformer 기반 다변량 시계열 예측 모델 구현 - Multi-head Attention을 활용한 고급 Forecasting
⚡️ Transformer는 어떻게 시계열 예측에 유리할까?
- Self-Attention: 시계열 전 구간에서 정보 상호작용 가능
- 병렬 처리: RNN에 비해 학습/추론 속도 빠름
- 멀티채널 지원: 여러 센서 또는 변수(feature) 간 관계 모델링 용이
- 멀티스텝 예측: 미래 10, 20, 30 step까지도 동시에 예측 가능
💡 모델 구조 개요
구성 요소 설명
Positional Encoding | 시간 정보를 인코딩 |
Encoder Block | 다중 시점과 피처 간 상호작용 |
Decoder Block (선택적) | 과거 + 미래 조건 기반 예측 |
Output Head | 다변량 미래값 출력 |
💾 1. 다변량 시계열 데이터 준비
import pandas as pd
import jax.numpy as jnp
from sklearn.preprocessing import StandardScaler
df = pd.read_csv("multivariate_timeseries.csv") # 예: ['temp', 'humidity', 'vibration']
features = df[['temp', 'humidity', 'vibration']].values
scaler = StandardScaler()
features_scaled = scaler.fit_transform(features)
def create_dataset(data, input_len=30, pred_len=10):
X, Y = [], []
for i in range(len(data) - input_len - pred_len):
X.append(data[i:i+input_len])
Y.append(data[i+input_len:i+input_len+pred_len])
return jnp.array(X), jnp.array(Y)
X, Y = create_dataset(features_scaled)
🧱 2. Transformer 예측 모델 구현 (Flax)
🧠 Positional Encoding
def positional_encoding(seq_len, d_model):
pos = jnp.arange(seq_len)[:, None]
i = jnp.arange(d_model)[None, :]
angle_rates = 1 / jnp.power(10000, (2 * (i // 2)) / d_model)
angle_rads = pos * angle_rates
angle_rads = angle_rads.at[:, 0::2].set(jnp.sin(angle_rads[:, 0::2]))
angle_rads = angle_rads.at[:, 1::2].set(jnp.cos(angle_rads[:, 1::2]))
return angle_rads
🔧 Multi-head Attention & Transformer Block
반응형
from flax import linen as nn
class MultiHeadAttention(nn.Module):
d_model: int
num_heads: int
@nn.compact
def __call__(self, x):
head_dim = self.d_model // self.num_heads
qkv = nn.Dense(self.d_model * 3)(x)
q, k, v = jnp.split(qkv, 3, axis=-1)
def reshape_heads(t):
return t.reshape(x.shape[0], x.shape[1], self.num_heads, head_dim).transpose(0, 2, 1, 3)
q, k, v = map(reshape_heads, (q, k, v))
scores = jnp.einsum('bhqd,bhkd->bhqk', q, k) / jnp.sqrt(head_dim)
weights = jax.nn.softmax(scores, axis=-1)
output = jnp.einsum('bhqk,bhvd->bhqd', weights, v)
output = output.transpose(0, 2, 1, 3).reshape(x.shape[0], x.shape[1], self.d_model)
return nn.Dense(self.d_model)(output)
class TransformerBlock(nn.Module):
d_model: int
num_heads: int
ff_dim: int
@nn.compact
def __call__(self, x):
attn = MultiHeadAttention(self.d_model, self.num_heads)(x)
x = nn.LayerNorm()(x + attn)
ff = nn.Dense(self.ff_dim)(x)
ff = nn.relu(ff)
ff = nn.Dense(self.d_model)(ff)
x = nn.LayerNorm()(x + ff)
return x
🔮 전체 Transformer Forecasting 모델
class TransformerForecast(nn.Module):
d_model: int = 64
num_heads: int = 4
ff_dim: int = 128
num_layers: int = 2
out_len: int = 10
out_dim: int = 3 # 다변량
@nn.compact
def __call__(self, x):
# Linear projection + positional encoding
x = nn.Dense(self.d_model)(x)
x += positional_encoding(x.shape[1], self.d_model)
for _ in range(self.num_layers):
x = TransformerBlock(self.d_model, self.num_heads, self.ff_dim)(x)
pooled = jnp.mean(x, axis=1)
x = nn.Dense(self.out_len * self.out_dim)(pooled)
return x.reshape(-1, self.out_len, self.out_dim)
⚙️ 3. 학습 준비
import optax
from flax.training import train_state
model = TransformerForecast()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 30, 3)))
tx = optax.adam(1e-3)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)
📉 4. 손실 함수 및 학습
@jax.jit
def loss_fn(params, x, y):
preds = model.apply(params, x)
return jnp.mean((preds - y) ** 2)
@jax.jit
def train_step(state, x, y):
loss, grads = jax.value_and_grad(loss_fn)(state.params, x, y)
state = state.apply_gradients(grads=grads)
return state, loss
🏃 5. 학습 루프
batch_size = 64
epochs = 20
for epoch in range(epochs):
for i in range(0, len(X), batch_size):
x_batch = X[i:i+batch_size]
y_batch = Y[i:i+batch_size]
state, loss = train_step(state, x_batch, y_batch)
print(f"Epoch {epoch+1} - Loss: {loss:.4f}")
📊 6. 예측 시각화
import matplotlib.pyplot as plt
sample_input = X[-1:]
predicted = model.apply(state.params, sample_input)
for i in range(predicted.shape[-1]):
plt.plot(Y[-1][:, i], label=f"True Var {i}")
plt.plot(predicted[0][:, i], linestyle="--", label=f"Predicted Var {i}")
plt.title("Multivariate Forecast (Transformer)")
plt.legend()
plt.show()
✅ 7. 확장 아이디어
기능 설명
미래 조건 추가 | Decoder 구조 확장 |
Attention 시각화 | 가중치 출력 후 해석 |
실시간 추론 | @jax.jit으로 최적화 |
결측치 보간 | 예측 구조 그대로 활용 가능 |
📌 다음 글 예고: JAX + Prophet 또는 Neural Basis Expansion 기반 시계열 Hybrid 모델 구현
JAX, Time Series Forecasting, Transformer, Multivariate Forecasting, Self-Attention, Flax, 딥러닝 시계열, Multi-step Prediction, JAX 예제, 시계열 분석, Deep Learning, AI Forecasting, 센서 예측, LSTM 대안, 시계열 해석
'Programming > JAX' 카테고리의 다른 글
📌 JAX로 Graph Neural Network (GNN) 기반 시계열 예측 모델 만들기 – 관계 + 시계열 데이터를 동시에 학습하는 방법 (0) | 2025.06.16 |
---|---|
📌 JAX로 Prophet 스타일 + Neural Basis Expansion Hybrid 시계열 예측 모델 구현 (0) | 2025.06.13 |
📌 JAX로 Seq2Seq 시계열 예측 모델 구현 - 미래 시점을 예측하는 딥러닝 구조 (0) | 2025.06.09 |
📌 JAX로 시계열 Autoencoder 구현 - 재구성 오차 기반 이상 탐지 모델 (0) | 2025.06.05 |
📌 JAX로 시계열 모델 해석하기 - SHAP 기반 Explainable AI for Time Series (0) | 2025.06.04 |
※ 이 포스팅은 쿠팡 파트너스 활동의 일환으로, 이에 따른 일정액의 수수료를 제공받습니다.
공지사항
최근에 올라온 글
최근에 달린 댓글
- Total
- Today
- Yesterday
링크
TAG
- 파이썬알고리즘
- AI챗봇
- REACT
- SEO최적화
- NestJS
- flax
- 딥러닝
- 개발블로그
- 백엔드개발
- kotlin
- Ktor
- nextJS
- Docker
- Python
- nodejs
- App Router
- fastapi
- llm
- 웹개발
- gatsbyjs
- CI/CD
- 프론트엔드
- Next.js
- seo 최적화 10개
- 프론트엔드면접
- JAX
- PostgreSQL
- Prisma
- SEO 최적화
- rag
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | ||
6 | 7 | 8 | 9 | 10 | 11 | 12 |
13 | 14 | 15 | 16 | 17 | 18 | 19 |
20 | 21 | 22 | 23 | 24 | 25 | 26 |
27 | 28 | 29 | 30 | 31 |
글 보관함
반응형