Programming/JAX
📈 JAX로 멀티시계열 모델 구현 - 다중 센서 예측 및 이상 탐지
octo54
2025. 5. 27. 10:53
반응형
📈 JAX로 멀티시계열 모델 구현 - 다중 센서 예측 및 이상 탐지
📊 멀티시계열(Multivariate Time Series)이란?
멀티시계열은 여러 개의 시계열 데이터(예: 여러 센서, 지표 등)가 동시에 존재하는 시계열을 의미합니다.
단일 시계열 모델은 하나의 변수만 예측하지만,
멀티시계열 모델은 변수들 간의 **상호작용과 공분산(covariance)**을 함께 학습합니다.
💡 1. 주요 활용 사례
- 스마트 팩토리 센서 이상 탐지
- 금융 자산군의 동시 예측
- IoT 디바이스와 시스템 전체 상태 모니터링
📂 2. 데이터 준비 - 예: 공장 센서 3종 데이터
# --- 2. Data preparation: load three factory sensor channels ---
import pandas as pd
import jax.numpy as jnp
from sklearn.preprocessing import MinMaxScaler
df = pd.read_csv("factory_multisensor.csv") # columns: ['temp', 'vibration', 'pressure']
features = ['temp', 'vibration', 'pressure']
data = df[features].values
# Scale every feature into [0, 1] so the three sensors share one value range
scaler = MinMaxScaler()
data_scaled = scaler.fit_transform(data)
# Sliding-window construction for supervised next-step prediction
def create_multivariate_windows(data, window=30):
    """Cut `data` into overlapping training pairs.

    Args:
        data: array-like of shape (T, n_features).
        window: number of past steps fed to the model.

    Returns:
        (X, y) with X of shape (T - window, window, n_features) and
        y of shape (T - window, n_features) — each target is the
        observation immediately after its window.
    """
    n_windows = len(data) - window
    xs = [data[start:start + window] for start in range(n_windows)]
    ys = [data[start + window] for start in range(n_windows)]
    return jnp.array(xs), jnp.array(ys)
# Build the (X, y) training set from the scaled sensor matrix
x_data, y_data = create_multivariate_windows(data_scaled)
🧠 3. 멀티시계열 RNN 모델 구현 (Flax)
반응형
📘 RNN 기반 멀티시계열 모델
from flax import linen as nn
class MultivariateRNN(nn.Module):
    """LSTM regressor mapping a (batch, seq_len, n_features) window to the
    next observation of every feature.

    Attributes:
        hidden_size: width of the LSTM hidden/cell state.
        output_size: number of variables to predict (e.g., 3 sensors).
    """
    hidden_size: int
    output_size: int

    @nn.compact
    def __call__(self, x):
        batch_size, seq_len, feature_dim = x.shape
        # Flax >= 0.6.x API: the cell takes `features` at construction and
        # initialize_carry takes (rng, input_shape). The legacy
        # nn.recurrent.LSTMCell() / classmethod-style initialize_carry used
        # previously has been removed from current Flax releases.
        cell = nn.LSTMCell(features=self.hidden_size)
        # The RNG is unused by the zero-initialized LSTM carry, so a fixed
        # key is harmless here.
        carry = cell.initialize_carry(jax.random.PRNGKey(0), x[:, 0, :].shape)
        # Start from a zero hidden state so an empty sequence still yields a
        # well-defined (all-zero) input to the output head.
        h = jnp.zeros((batch_size, self.hidden_size))
        # Unrolled loop over time; the single cell instance shares its
        # parameters across all steps.
        for t in range(seq_len):
            carry, h = cell(carry, x[:, t, :])
        # Project the final hidden state to one prediction per feature.
        return nn.Dense(self.output_size)(h)
⚙️ 4. 모델 초기화 및 옵티마이저 설정
# --- 4. Model initialization and optimizer setup ---
from flax.training import train_state
import optax
import jax
model = MultivariateRNN(hidden_size=64, output_size=3)
key = jax.random.PRNGKey(42)
params = model.init(key, jnp.ones((1, 30, 3))) # dummy input: batch=1, window=30, feature=3
tx = optax.adam(1e-3)
# TrainState bundles apply_fn, params and optimizer state for the training loop
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)
📉 5. 손실 함수 및 학습 루프
@jax.jit
def mse_loss(params, x, y):
    """Mean squared error between the model's predictions and targets."""
    residual = model.apply(params, x) - y
    return jnp.mean(jnp.square(residual))
@jax.jit
def train_step(state, x, y):
    """One optimization step: compute loss and gradients, apply the update.

    Returns the updated TrainState and the scalar loss for this batch.
    """
    loss_and_grad = jax.value_and_grad(mse_loss)
    loss, grads = loss_and_grad(state.params, x, y)
    return state.apply_gradients(grads=grads), loss
🚀 6. 모델 학습
# --- 6. Training: mini-batch sweeps over the whole window set ---
batch_size = 64
epochs = 20
for epoch in range(epochs):
    for i in range(0, len(x_data), batch_size):
        x_batch = x_data[i:i+batch_size]
        y_batch = y_data[i:i+batch_size]
        # NOTE(review): the final slice can be shorter than batch_size, which
        # forces the jitted train_step to recompile for that shape — consider
        # dropping or padding the last partial batch.
        state, loss = train_step(state, x_batch, y_batch)
    # Only the loss of the epoch's last mini-batch is reported.
    print(f"Epoch {epoch+1}, Loss: {loss:.5f}")
✅ 7. 이상 탐지 (재구성/예측 기반)
def compute_prediction_errors(model, params, x_data, y_data):
    """Per-sample mean squared prediction error.

    Returns a 1-D array with one error score per window, obtained by
    averaging the squared residual over the feature axis (axis=1).
    """
    residuals = model.apply(params, x_data) - y_data
    return jnp.mean(jnp.square(residuals), axis=1)
# Prediction-error based outlier detection
errors = compute_prediction_errors(model, state.params, x_data, y_data)
# 3-sigma rule: flag windows whose error exceeds mean + 3 * standard deviation
threshold = jnp.mean(errors) + 3 * jnp.std(errors)
anomalies = errors > threshold
📊 8. 결과 시각화
# --- 8. Visualize per-window errors against the anomaly threshold ---
import matplotlib.pyplot as plt

plt.figure(figsize=(15, 5))
plt.plot(errors, label='Prediction Error')
plt.axhline(threshold, color='red', linestyle='--', label='Anomaly Threshold')
# Title restored (the published literal was mojibake-garbled):
# "Multivariate time-series anomaly detection - prediction-error based"
plt.title("멀티시계열 이상 탐지 - 예측 오차 기반")
plt.xlabel("Window index")
plt.ylabel("MSE")
plt.legend()
plt.show()
🔧 추가 고도화 아이디어
- ✅ Transformer 기반 모델로 확장 (Multivariate Transformer)
- ✅ 예측 + 재구성 융합 이상 탐지 모델
- ✅ Dynamic Thresholding 도입
- ✅ Variational Autoencoder (MV-VAE) 기반 압축
📌 다음 글 예고: JAX로 Probabilistic Time Series Forecasting - 예측 불확실성 모델링하기
다음 글에서는 **예측값의 확률 분포(Confidence Interval)**를 추정할 수 있는
베이지안 시계열 모델을 JAX로 구현합니다.
JAX, 멀티시계열, Time Series, 이상 탐지, Multivariate, RNN, LSTM, Flax, 고속 연산, 딥러닝, 시계열 예측, 고성능 컴퓨팅, Anomaly Detection, 센서 데이터, 공정 데이터, JAX 모델, JAX 예제, Python