
📌 Implementing a Multivariate Time Series Model with JAX: Multi-Sensor Forecasting and Anomaly Detection

octo54 2025. 5. 27. 10:53


🚀 What Is a Multivariate Time Series?

A multivariate time series consists of several time series (e.g., readings from multiple sensors or indicators) observed over the same time steps.
A univariate model forecasts a single variable, whereas
a multivariate model also learns the **interactions and covariance** between the variables.
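To make "covariance between variables" concrete, here is a minimal sketch on made-up data (the three synthetic signals below are assumptions for illustration, not real factory readings). A multivariate series is simply an array of shape (T, n_features), and `jnp.cov` over the feature axis gives the covariance matrix that a multivariate model can exploit.

```python
import jax
import jax.numpy as jnp

# Hypothetical synthetic sensors: vibration and pressure both depend on temperature
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
T = 500
temp = jnp.sin(jnp.linspace(0, 20, T)) + 0.1 * jax.random.normal(k1, (T,))
vibration = 0.8 * temp + 0.1 * jax.random.normal(k2, (T,))
pressure = -0.5 * temp + 0.1 * jax.random.normal(k3, (T,))

# Multivariate series: one row per time step, one column per sensor
series = jnp.stack([temp, vibration, pressure], axis=1)  # shape (T, 3)

# rowvar=False: columns are variables, rows are observations
cov = jnp.cov(series, rowvar=False)        # shape (3, 3)
corr = jnp.corrcoef(series, rowvar=False)  # same information scaled to [-1, 1]
print(series.shape, cov.shape)
print(jnp.round(corr, 2))
```

The strong positive temp/vibration and negative temp/pressure correlations are exactly the cross-variable structure a univariate model would ignore.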


💡 1. Key Use Cases

  • Sensor anomaly detection in smart factories
  • Joint forecasting of multiple financial asset classes
  • System-wide health monitoring of IoT devices

📊 2. Data Preparation (Example: Three Factory Sensors)

import pandas as pd
import jax.numpy as jnp
from sklearn.preprocessing import MinMaxScaler

df = pd.read_csv("factory_multisensor.csv")  # columns: ['temp', 'vibration', 'pressure']
features = ['temp', 'vibration', 'pressure']
data = df[features].values

# Normalize each sensor to [0, 1]
scaler = MinMaxScaler()
data_scaled = scaler.fit_transform(data)

# Build sliding windows: `window` past steps -> the next step
def create_multivariate_windows(data, window=30):
    X, y = [], []
    for i in range(len(data) - window):
        X.append(data[i:i+window])          # shape: (window, n_features)
        y.append(data[i+window])            # shape: (n_features,)
    return jnp.array(X), jnp.array(y)

x_data, y_data = create_multivariate_windows(data_scaled)
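The Python loop above works well for small files. Below is a minimal sketch of two refinements, assuming the same (T, n_features) layout: building all windows at once with index arithmetic, and fitting the min-max scaling on the training portion only so that statistics from the test period do not leak into training (the code above fits `MinMaxScaler` on the full series). The random-walk data, the helper names `make_windows` and `minmax_fit`, and the 80/20 chronological split are all assumptions for illustration.

```python
import numpy as np

def make_windows(data, window=30):
    """Vectorized version of the windowing loop: X (N, window, F), y (N, F)."""
    n = len(data) - window
    # idx[i] = [i, i+1, ..., i+window-1]
    idx = np.arange(window)[None, :] + np.arange(n)[:, None]
    return data[idx], data[window:]

def minmax_fit(train):
    lo, hi = train.min(axis=0), train.max(axis=0)
    # Guard against division by zero for a constant sensor
    return lo, np.where(hi > lo, hi - lo, 1.0)

# Synthetic stand-in for the factory CSV: 1,000 steps x 3 sensors
rng = np.random.default_rng(0)
raw = rng.normal(size=(1000, 3)).cumsum(axis=0)

split = int(len(raw) * 0.8)            # chronological split, no shuffling
lo, scale = minmax_fit(raw[:split])    # statistics from the training part only
scaled = (raw - lo) / scale            # test values may fall outside [0, 1]

X_train, y_train = make_windows(scaled[:split])
X_test, y_test = make_windows(scaled[split:])
print(X_train.shape, y_train.shape, X_test.shape)  # (770, 30, 3) (770, 3) (170, 30, 3)
```

Windowing each split separately also guarantees that no test window contains training time steps.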

🧠 3. Implementing a Multivariate RNN Model (Flax)

📝 RNN-based multivariate time series model

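import jax
import jax.numpy as jnp
from flax import linen as nn

class MultivariateRNN(nn.Module):
    hidden_size: int
    output_size: int  # number of variables to predict (e.g., 3)

    @nn.compact
    def __call__(self, x):
        # x: (batch, seq_len, n_features)
        batch_size, seq_len, _ = x.shape

        # One LSTM cell, reused with shared weights at every time step
        rnn_cell = nn.LSTMCell(features=self.hidden_size)

        # Flax's LSTM carry is a (cell state, hidden state) tuple; start from zeros
        h = jnp.zeros((batch_size, self.hidden_size))
        carry = (jnp.zeros((batch_size, self.hidden_size)), h)

        for t in range(seq_len):
            carry, h = rnn_cell(carry, x[:, t, :])

        # Predict the next time step of every sensor from the last hidden state
        output = nn.Dense(self.output_size)(h)
        return output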

⚙️ 4. Model Initialization and Optimizer Setup

from flax.training import train_state
import optax
import jax

model = MultivariateRNN(hidden_size=64, output_size=3)
key = jax.random.PRNGKey(42)

params = model.init(key, jnp.ones((1, 30, 3)))  # batch=1, window=30, feature=3
tx = optax.adam(1e-3)

state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

📉 5. Loss Function and Training Step

@jax.jit
def mse_loss(params, x, y):
    pred = model.apply(params, x)
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(state, x, y):
    loss, grads = jax.value_and_grad(mse_loss)(state.params, x, y)
    state = state.apply_gradients(grads=grads)
    return state, loss
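`mse_loss` averages over both the batch axis and the sensor axis. For diagnostics it can help to keep the sensor axis, so you can see which variable the model predicts worst. A minimal sketch; `per_sensor_mse` is a hypothetical helper, and the small arrays stand in for `model.apply(...)` output and targets.

```python
import jax
import jax.numpy as jnp

@jax.jit
def per_sensor_mse(pred, y):
    # pred, y: (batch, n_features) -> one MSE value per sensor, shape (n_features,)
    return jnp.mean((pred - y) ** 2, axis=0)

pred = jnp.array([[0.1, 0.5, 0.9],
                  [0.2, 0.4, 0.7]])
y = jnp.array([[0.1, 0.3, 0.9],
               [0.2, 0.6, 0.8]])
per_sensor = per_sensor_mse(pred, y)
print(per_sensor)
# Averaging the per-sensor values recovers the scalar loss used for training
print(jnp.mean(per_sensor))
```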

🏃 6. Training the Model

batch_size = 64
epochs = 20

for epoch in range(epochs):
    for i in range(0, len(x_data), batch_size):
        x_batch = x_data[i:i+batch_size]
        y_batch = y_data[i:i+batch_size]
        state, loss = train_step(state, x_batch, y_batch)
    print(f"Epoch {epoch+1}, Loss: {loss:.5f}")  # loss of the last mini-batch in this epoch
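The loop above visits the windows in the same order every epoch, and its final batch is usually smaller than `batch_size`, which triggers an extra `jit` compilation for the new shape. A minimal sketch of a shuffled, fixed-size batch iterator; `iterate_batches` is a hypothetical helper, and dropping the incomplete last batch is a design choice (those windows are still seen in other epochs after reshuffling).

```python
import jax
import jax.numpy as jnp

def iterate_batches(key, x, y, batch_size):
    """Yield shuffled (x_batch, y_batch) pairs with exactly batch_size rows."""
    perm = jax.random.permutation(key, len(x))
    n_full = len(x) // batch_size  # drop the incomplete last batch
    for b in range(n_full):
        idx = perm[b * batch_size:(b + 1) * batch_size]
        yield x[idx], y[idx]

# Dummy windows shaped like the post's data: (N, 30, 3) inputs, (N, 3) targets
x = jnp.arange(200 * 30 * 3, dtype=jnp.float32).reshape(200, 30, 3)
y = jnp.arange(200 * 3, dtype=jnp.float32).reshape(200, 3)

key = jax.random.PRNGKey(0)
for epoch in range(2):
    key, subkey = jax.random.split(key)  # a fresh permutation every epoch
    batches = list(iterate_batches(subkey, x, y, batch_size=64))
    print(epoch, len(batches), batches[0][0].shape)  # 3 batches of (64, 30, 3)
```

In the training loop, the inner `for i in range(...)` would become `for x_batch, y_batch in iterate_batches(subkey, x_data, y_data, batch_size):`.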

✅ 7. Anomaly Detection (Prediction/Reconstruction Error)

def compute_prediction_errors(model, params, x_data, y_data):
    preds = model.apply(params, x_data)
    errors = jnp.mean((preds - y_data) ** 2, axis=1)
    return errors

# Flag anomalies whose prediction error exceeds mean + 3 * std
errors = compute_prediction_errors(model, state.params, x_data, y_data)
threshold = jnp.mean(errors) + 3 * jnp.std(errors)
anomalies = errors > threshold
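Two details are easy to miss here. First, error `i` belongs to window `i`, whose prediction target is time step `i + window` of the original series, so flags must be shifted before they are matched to rows of `df`. Second, the threshold above is computed on the same errors it judges, so large anomalies inflate it. One alternative, assumed here rather than taken from the post, is to estimate the mean and standard deviation on a period believed to be normal. A minimal sketch with synthetic errors and an assumed 300-window reference period:

```python
import jax
import jax.numpy as jnp

window = 30

# Synthetic per-window errors: mostly small, with two injected spikes
errors = 0.01 * jnp.abs(jax.random.normal(jax.random.PRNGKey(0), (500,)))
errors = errors.at[400].set(1.0).at[450].set(0.8)

# Threshold from a reference period assumed to be normal (first 300 windows)
ref = errors[:300]
threshold = jnp.mean(ref) + 3 * jnp.std(ref)

flags = errors > threshold
window_idx = jnp.nonzero(flags)[0]  # indices into errors / x_data
time_idx = window_idx + window      # indices into the original series (rows of df)
print(threshold, time_idx)
```

With the real data, `df.iloc[np.asarray(time_idx)]` would return the sensor rows at the flagged time steps.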

📈 8. Visualizing the Results

import matplotlib.pyplot as plt

plt.figure(figsize=(15, 5))
plt.plot(errors, label='Prediction Error')
plt.axhline(threshold, color='red', linestyle='--', label='Anomaly Threshold')
plt.title("Multivariate Time Series Anomaly Detection (Prediction Error)")
plt.legend()
plt.show()
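Optionally, the flagged points can be marked directly on the error curve. A minimal sketch with synthetic errors (two injected spikes); the `Agg` backend line is only there so the snippet also runs without a display, and can be dropped in a notebook.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; remove in Jupyter
import matplotlib.pyplot as plt
import numpy as np

rng = np.random.default_rng(0)
errors = np.abs(rng.normal(0, 0.01, 500))
errors[[400, 450]] = [1.0, 0.8]  # injected anomalies
threshold = errors.mean() + 3 * errors.std()
anomaly_idx = np.flatnonzero(errors > threshold)

fig, ax = plt.subplots(figsize=(15, 5))
ax.plot(errors, label="Prediction Error")
ax.axhline(threshold, color="red", linestyle="--", label="Anomaly Threshold")
# Red dots on top of the curve for every window above the threshold
ax.scatter(anomaly_idx, errors[anomaly_idx], color="red", zorder=3, label="Anomaly")
ax.set_title("Multivariate Time Series Anomaly Detection (Prediction Error)")
ax.legend()
# plt.show() in an interactive session
```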

🧠 Ideas for Further Improvement

  • ✅ Extend to a Transformer-based model (Multivariate Transformer)
  • ✅ Anomaly detection that combines prediction and reconstruction errors
  • ✅ Introduce dynamic thresholding
  • ✅ Compression with a multivariate Variational Autoencoder (MV-VAE)
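Of the ideas listed above, dynamic thresholding is the easiest to sketch: instead of one global mean + 3σ cut-off, each error is compared with statistics of the errors immediately before it. A minimal trailing-window version, assuming a 100-step history and k = 3 (both arbitrary); this is a simple rolling z-score rule, not a specific published dynamic-thresholding method, and `rolling_threshold` is a hypothetical helper.

```python
import numpy as np

def rolling_threshold(errors, history=100, k=3.0):
    """threshold[t] = mean + k * std of errors[t-history:t]; NaN until enough history."""
    errors = np.asarray(errors, dtype=float)
    thr = np.full(errors.shape, np.nan)
    for t in range(history, len(errors)):
        past = errors[t - history:t]
        thr[t] = past.mean() + k * past.std()
    return thr

rng = np.random.default_rng(1)
# Error level drifts upward over time (e.g. sensor ageing), plus one spike
errors = np.abs(rng.normal(0, 0.01, 1000)) + np.linspace(0, 0.05, 1000)
errors[700] += 0.2

thr = rolling_threshold(errors)
flags = errors > thr  # comparisons with NaN are False, so the warm-up is never flagged
print(np.flatnonzero(flags))
```

Because the threshold follows the recent error level, a slow drift raises the bar gradually instead of being flagged as one long anomaly.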

📌 Coming up next: Probabilistic Time Series Forecasting with JAX - Modeling Forecast Uncertainty

In the next post, we will use JAX to implement a Bayesian time series model
that can estimate the **probability distribution of its predictions (confidence intervals)**.


 
