
๐Ÿ“Œ Implementing Multivariate Time Series Models with JAX - Multi-Sensor Forecasting and Anomaly Detection


๐Ÿš€ What Is a Multivariate Time Series?

๋ฉ€ํ‹ฐ์‹œ๊ณ„์—ด์€ ์—ฌ๋Ÿฌ ๊ฐœ์˜ ์‹œ๊ณ„์—ด ๋ฐ์ดํ„ฐ(์˜ˆ: ์—ฌ๋Ÿฌ ์„ผ์„œ, ์ง€ํ‘œ ๋“ฑ)๊ฐ€ ๋™์‹œ์— ์กด์žฌํ•˜๋Š” ์‹œ๊ณ„์—ด์„ ์˜๋ฏธํ•ฉ๋‹ˆ๋‹ค.
๋‹จ์ผ ์‹œ๊ณ„์—ด ๋ชจ๋ธ์€ ํ•˜๋‚˜์˜ ๋ณ€์ˆ˜๋งŒ ์˜ˆ์ธกํ•˜์ง€๋งŒ,
๋ฉ€ํ‹ฐ์‹œ๊ณ„์—ด ๋ชจ๋ธ์€ ๋ณ€์ˆ˜๋“ค ๊ฐ„์˜ **์ƒํ˜ธ์ž‘์šฉ๊ณผ ๊ณต๋ณ€์„ฑ(covariance)**์„ ํ•จ๊ป˜ ํ•™์Šตํ•ฉ๋‹ˆ๋‹ค.


๐Ÿ’ก 1. Key Use Cases

  • Anomaly detection for smart-factory sensors
  • Joint forecasting of a group of financial assets
  • Monitoring the overall system state of IoT devices

๐Ÿ“Š 2. Data Preparation - Example: Three Factory Sensor Channels

import pandas as pd
import jax.numpy as jnp
from sklearn.preprocessing import MinMaxScaler

df = pd.read_csv("factory_multisensor.csv")  # columns: ['temp', 'vibration', 'pressure']
features = ['temp', 'vibration', 'pressure']
data = df[features].values

# Normalize each feature to [0, 1]
scaler = MinMaxScaler()
data_scaled = scaler.fit_transform(data)

# ์œˆ๋„์šฐ ์ƒ์„ฑ
def create_multivariate_windows(data, window=30):
    X, y = [], []
    for i in range(len(data) - window):
        X.append(data[i:i+window])          # shape: (window, n_features)
        y.append(data[i+window])            # shape: (n_features,)
    return jnp.array(X), jnp.array(y)

x_data, y_data = create_multivariate_windows(data_scaled)
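
A quick sanity check on the resulting shapes (the exact number of windows depends on the length of the hypothetical CSV):

# x_data: (num_windows, 30, 3), y_data: (num_windows, 3)
print(x_data.shape, y_data.shape)

# Optional: a chronological hold-out split. The rest of this post trains on the
# full arrays for brevity, so this split is only a suggestion.
split = int(0.8 * len(x_data))
x_train, y_train = x_data[:split], y_data[:split]
x_test, y_test = x_data[split:], y_data[split:]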

๐Ÿง  3. Implementing a Multivariate RNN Model (Flax)

๋ฐ˜์‘ํ˜•

๐Ÿ“ RNN ๊ธฐ๋ฐ˜ ๋ฉ€ํ‹ฐ์‹œ๊ณ„์—ด ๋ชจ๋ธ

import jax
from flax import linen as nn

class MultivariateRNN(nn.Module):
    hidden_size: int
    output_size: int  # number of variables to predict (e.g., 3)

    @nn.compact
    def __call__(self, x):
        batch_size, seq_len, feature_dim = x.shape

        # LSTM cell; recent Flax versions take the hidden size as `features`
        lstm_cell = nn.LSTMCell(features=self.hidden_size)
        carry = lstm_cell.initialize_carry(jax.random.PRNGKey(0), (batch_size, feature_dim))

        h = jnp.zeros((batch_size, self.hidden_size))
        for t in range(seq_len):
            # Feed one time step at a time; `h` is the hidden state after step t
            carry, h = lstm_cell(carry, x[:, t, :])

        # Map the final hidden state to a prediction for every variable
        output = nn.Dense(self.output_size)(h)
        return output
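
The Python loop above gets unrolled at trace time when jitted, which is fine for a 30-step window but slow to compile for long sequences. As a hedged alternative, assuming a recent Flax version that provides nn.RNN (which scans a cell over the time axis with lax.scan), the same model could look like this:

class MultivariateRNNScan(nn.Module):
    hidden_size: int
    output_size: int

    @nn.compact
    def __call__(self, x):
        # nn.RNN scans the LSTM cell over the time axis of x: (batch, time, features)
        outputs = nn.RNN(nn.LSTMCell(features=self.hidden_size))(x)
        # Use the hidden state at the last time step for the prediction
        return nn.Dense(self.output_size)(outputs[:, -1, :])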

โš™๏ธ 4. ๋ชจ๋ธ ์ดˆ๊ธฐํ™” ๋ฐ ์˜ตํ‹ฐ๋งˆ์ด์ € ์„ค์ •

from flax.training import train_state
import optax
import jax

model = MultivariateRNN(hidden_size=64, output_size=3)
key = jax.random.PRNGKey(42)

params = model.init(key, jnp.ones((1, 30, 3)))  # batch=1, window=30, feature=3
tx = optax.adam(1e-3)

state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)
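
As a quick sanity check (just a sketch, not required for training), you can count the initialized parameters with jax.tree_util:

# Total number of trainable parameters in the initialized model
n_params = sum(p.size for p in jax.tree_util.tree_leaves(params))
print(f"parameters: {n_params:,}")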

๐Ÿ“‰ 5. Loss Function and Training Step

@jax.jit
def mse_loss(params, x, y):
    pred = model.apply(params, x)
    return jnp.mean((pred - y) ** 2)

@jax.jit
def train_step(state, x, y):
    loss, grads = jax.value_and_grad(mse_loss)(state.params, x, y)
    state = state.apply_gradients(grads=grads)
    return state, loss

๐Ÿƒ 6. ๋ชจ๋ธ ํ•™์Šต

batch_size = 64
epochs = 20

for epoch in range(epochs):
    for i in range(0, len(x_data), batch_size):
        x_batch = x_data[i:i+batch_size]
        y_batch = y_data[i:i+batch_size]
        state, loss = train_step(state, x_batch, y_batch)
    print(f"Epoch {epoch+1}, Loss: {loss:.5f}")

โœ… 7. Anomaly Detection (Prediction-Error Based)

def compute_prediction_errors(model, params, x_data, y_data):
    preds = model.apply(params, x_data)
    errors = jnp.mean((preds - y_data) ** 2, axis=1)  # per-window MSE averaged over the features
    return errors

# Flag anomalies from the prediction error
errors = compute_prediction_errors(model, state.params, x_data, y_data)
threshold = jnp.mean(errors) + 3 * jnp.std(errors)
anomalies = errors > threshold
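
Each error value corresponds to one window, and window i predicts time step i + window, so the flags can be mapped back to positions in the original series (a small sketch, assuming the same window length of 30 used above):

window = 30                                      # must match create_multivariate_windows
anomaly_idx = jnp.where(anomalies)[0] + window   # time steps flagged as anomalous
print("number of anomalies:", int(anomalies.sum()))
print("first anomalous time steps:", anomaly_idx[:10])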

๐Ÿ“ˆ 8. Visualizing the Results

import matplotlib.pyplot as plt

plt.figure(figsize=(15, 5))
plt.plot(errors, label='Prediction Error')
plt.axhline(threshold, color='red', linestyle='--', label='Anomaly Threshold')
plt.title("Multivariate Anomaly Detection - Prediction-Error Based")
plt.legend()
plt.show()

๐Ÿง  Ideas for Going Further

  • โœ… Extend to a Transformer-based model (multivariate Transformer)
  • โœ… Combine prediction- and reconstruction-based anomaly detection in one model
  • โœ… Introduce dynamic thresholding (see the sketch after this list)
  • โœ… Variational Autoencoder (MV-VAE)-based compression
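
For the dynamic thresholding idea above, a minimal sketch (a rolling mean + k·std rule, purely illustrative and not a tuned method) could look like this:

import numpy as np

def dynamic_threshold(errors, window=100, k=3.0):
    # Rolling mean + k * rolling std over the last `window` error values
    e = np.asarray(errors)
    thresh = np.empty_like(e)
    for t in range(len(e)):
        seg = e[max(0, t - window):t + 1]
        thresh[t] = seg.mean() + k * seg.std()
    return thresh

dyn_thresh = dynamic_threshold(errors)
dyn_anomalies = np.asarray(errors) > dyn_thresh   # time-varying anomaly flags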

๐Ÿ“Œ ๋‹ค์Œ ๊ธ€ ์˜ˆ๊ณ : JAX๋กœ Probabilistic Time Series Forecasting - ์˜ˆ์ธก ๋ถˆํ™•์‹ค์„ฑ ๋ชจ๋ธ๋งํ•˜๊ธฐ

๋‹ค์Œ ๊ธ€์—์„œ๋Š” **์˜ˆ์ธก๊ฐ’์˜ ํ™•๋ฅ  ๋ถ„ํฌ(Confidence Interval)**๋ฅผ ์ถ”์ •ํ•  ์ˆ˜ ์žˆ๋Š”
๋ฒ ์ด์ง€์•ˆ ์‹œ๊ณ„์—ด ๋ชจ๋ธ์„ JAX๋กœ ๊ตฌํ˜„ํ•ฉ๋‹ˆ๋‹ค.


 


โ€ป ์ด ํฌ์ŠคํŒ…์€ ์ฟ ํŒก ํŒŒํŠธ๋„ˆ์Šค ํ™œ๋™์˜ ์ผํ™˜์œผ๋กœ, ์ด์— ๋”ฐ๋ฅธ ์ผ์ •์•ก์˜ ์ˆ˜์ˆ˜๋ฃŒ๋ฅผ ์ œ๊ณต๋ฐ›์Šต๋‹ˆ๋‹ค.
๊ณต์ง€์‚ฌํ•ญ
์ตœ๊ทผ์— ์˜ฌ๋ผ์˜จ ๊ธ€
์ตœ๊ทผ์— ๋‹ฌ๋ฆฐ ๋Œ“๊ธ€
Total
Today
Yesterday
๋งํฌ
ยซ   2025/06   ยป
์ผ ์›” ํ™” ์ˆ˜ ๋ชฉ ๊ธˆ ํ† 
1 2 3 4 5 6 7
8 9 10 11 12 13 14
15 16 17 18 19 20 21
22 23 24 25 26 27 28
29 30
๊ธ€ ๋ณด๊ด€ํ•จ
๋ฐ˜์‘ํ˜•