ํฐ์คํ ๋ฆฌ ๋ทฐ
Programming/JAX
๐ JAX๋ก ๋ฉํฐ์๊ณ์ด ๋ชจ๋ธ ๊ตฌํ - ๋ค์ค ์ผ์ ์์ธก ๋ฐ ์ด์ ํ์ง
octo54 2025. 5. 27. 10:53๋ฐ์ํ
๐ JAX๋ก ๋ฉํฐ์๊ณ์ด ๋ชจ๋ธ ๊ตฌํ - ๋ค์ค ์ผ์ ์์ธก ๋ฐ ์ด์ ํ์ง
๐ ๋ฉํฐ์๊ณ์ด(Multivariate Time Series)์ด๋?
๋ฉํฐ์๊ณ์ด์ ์ฌ๋ฌ ๊ฐ์ ์๊ณ์ด ๋ฐ์ดํฐ(์: ์ฌ๋ฌ ์ผ์, ์งํ ๋ฑ)๊ฐ ๋์์ ์กด์ฌํ๋ ์๊ณ์ด์ ์๋ฏธํฉ๋๋ค.
๋จ์ผ ์๊ณ์ด ๋ชจ๋ธ์ ํ๋์ ๋ณ์๋ง ์์ธกํ์ง๋ง,
๋ฉํฐ์๊ณ์ด ๋ชจ๋ธ์ ๋ณ์๋ค ๊ฐ์ **์ํธ์์ฉ๊ณผ ๊ณต๋ณ์ฑ(covariance)**์ ํจ๊ป ํ์ตํฉ๋๋ค.
๐ก 1. ์ฃผ์ ํ์ฉ ์ฌ๋ก
- ์ค๋งํธ ํฉํ ๋ฆฌ ์ผ์ ์ด์ ํ์ง
- ๊ธ์ต ์์ฐ๊ตฐ์ ๋์ ์์ธก
- IoT ๋๋ฐ์ด์ค์ ์์คํ ์ ์ฒด ์ํ ๋ชจ๋ํฐ๋ง
๐ 2. ๋ฐ์ดํฐ ์ค๋น - ์: ๊ณต์ฅ ์ผ์ 3์ข ๋ฐ์ดํฐ
import pandas as pd
import jax.numpy as jnp
from sklearn.preprocessing import MinMaxScaler
df = pd.read_csv("factory_multisensor.csv") # ์ปฌ๋ผ: ['temp', 'vibration', 'pressure']
features = ['temp', 'vibration', 'pressure']
data = df[features].values
# ์ ๊ทํ
scaler = MinMaxScaler()
data_scaled = scaler.fit_transform(data)
# ์๋์ฐ ์์ฑ
def create_multivariate_windows(data, window=30):
X, y = [], []
for i in range(len(data) - window):
X.append(data[i:i+window]) # shape: (window, n_features)
y.append(data[i+window]) # shape: (n_features,)
return jnp.array(X), jnp.array(y)
x_data, y_data = create_multivariate_windows(data_scaled)
๐ง 3. ๋ฉํฐ์๊ณ์ด RNN ๋ชจ๋ธ ๊ตฌํ (Flax)
๋ฐ์ํ
๐ RNN ๊ธฐ๋ฐ ๋ฉํฐ์๊ณ์ด ๋ชจ๋ธ
from flax import linen as nn
class MultivariateRNN(nn.Module):
hidden_size: int
output_size: int # ์์ธกํ ๋ณ์ ์ (e.g., 3)
@nn.compact
def __call__(self, x):
batch_size, seq_len, feature_dim = x.shape
h = jnp.zeros((batch_size, self.hidden_size))
rnn_cell = nn.recurrent.LSTMCell()
carry = rnn_cell.initialize_carry(jax.random.PRNGKey(0), (batch_size,), self.hidden_size)
for t in range(seq_len):
carry, h = rnn_cell(carry, x[:, t, :])
output = nn.Dense(self.output_size)(h)
return output
โ๏ธ 4. ๋ชจ๋ธ ์ด๊ธฐํ ๋ฐ ์ตํฐ๋ง์ด์ ์ค์
from flax.training import train_state
import optax
import jax
model = MultivariateRNN(hidden_size=64, output_size=3)
key = jax.random.PRNGKey(42)
params = model.init(key, jnp.ones((1, 30, 3))) # batch=1, window=30, feature=3
tx = optax.adam(1e-3)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)
๐ 5. ์์ค ํจ์ ๋ฐ ํ์ต ๋ฃจํ
@jax.jit
def mse_loss(params, x, y):
pred = model.apply(params, x)
return jnp.mean((pred - y) ** 2)
@jax.jit
def train_step(state, x, y):
loss, grads = jax.value_and_grad(mse_loss)(state.params, x, y)
state = state.apply_gradients(grads=grads)
return state, loss
๐ 6. ๋ชจ๋ธ ํ์ต
batch_size = 64
epochs = 20
for epoch in range(epochs):
for i in range(0, len(x_data), batch_size):
x_batch = x_data[i:i+batch_size]
y_batch = y_data[i:i+batch_size]
state, loss = train_step(state, x_batch, y_batch)
print(f"Epoch {epoch+1}, Loss: {loss:.5f}")
โ 7. ์ด์ ํ์ง (์ฌ๊ตฌ์ฑ/์์ธก ๊ธฐ๋ฐ)
def compute_prediction_errors(model, params, x_data, y_data):
preds = model.apply(params, x_data)
errors = jnp.mean((preds - y_data) ** 2, axis=1)
return errors
# ์์ธก ์ค์ฐจ ๊ธฐ๋ฐ ์ด์์น ํ์ง
errors = compute_prediction_errors(model, state.params, x_data, y_data)
threshold = jnp.mean(errors) + 3 * jnp.std(errors)
anomalies = errors > threshold
๐ 8. ๊ฒฐ๊ณผ ์๊ฐํ
import matplotlib.pyplot as plt
plt.figure(figsize=(15, 5))
plt.plot(errors, label='Prediction Error')
plt.axhline(threshold, color='red', linestyle='--', label='Anomaly Threshold')
plt.title("๋ฉํฐ์๊ณ์ด ์ด์ ํ์ง - ์์ธก ์ค์ฐจ ๊ธฐ๋ฐ")
plt.legend()
plt.show()
๐ง ์ถ๊ฐ ๊ณ ๋ํ ์์ด๋์ด
- โ Transformer ๊ธฐ๋ฐ ๋ชจ๋ธ๋ก ํ์ฅ (Multivariate Transformer)
- โ ์์ธก + ์ฌ๊ตฌ์ฑ ์ตํฉ ์ด์ ํ์ง ๋ชจ๋ธ
- โ Dynamic Thresholding ๋์
- โ Variational Autoencoder (MV-VAE) ๊ธฐ๋ฐ ์์ถ
๐ ๋ค์ ๊ธ ์๊ณ : JAX๋ก Probabilistic Time Series Forecasting - ์์ธก ๋ถํ์ค์ฑ ๋ชจ๋ธ๋งํ๊ธฐ
๋ค์ ๊ธ์์๋ **์์ธก๊ฐ์ ํ๋ฅ ๋ถํฌ(Confidence Interval)**๋ฅผ ์ถ์ ํ ์ ์๋
๋ฒ ์ด์ง์ ์๊ณ์ด ๋ชจ๋ธ์ JAX๋ก ๊ตฌํํฉ๋๋ค.
JAX, ๋ฉํฐ์๊ณ์ด, Time Series, ์ด์ ํ์ง, Multivariate, RNN, LSTM, Flax, ๊ณ ์ ์ฐ์ฐ, ๋ฅ๋ฌ๋, ์๊ณ์ด ์์ธก, ๊ณ ์ฑ๋ฅ ์ปดํจํ , Anomaly Detection, ์ผ์ ๋ฐ์ดํฐ, ๊ณต์ ๋ฐ์ดํฐ, JAX ๋ชจ๋ธ, JAX ์์ , Python
'Programming > JAX' ์นดํ ๊ณ ๋ฆฌ์ ๋ค๋ฅธ ๊ธ
โป ์ด ํฌ์คํ
์ ์ฟ ํก ํํธ๋์ค ํ๋์ ์ผํ์ผ๋ก, ์ด์ ๋ฐ๋ฅธ ์ผ์ ์ก์ ์์๋ฃ๋ฅผ ์ ๊ณต๋ฐ์ต๋๋ค.
๊ณต์ง์ฌํญ
์ต๊ทผ์ ์ฌ๋ผ์จ ๊ธ
์ต๊ทผ์ ๋ฌ๋ฆฐ ๋๊ธ
- Total
- Today
- Yesterday
๋งํฌ
TAG
- Ktor
- JAX
- SEO์ต์ ํ
- ์น๊ฐ๋ฐ
- AI์ฑ๋ด
- SEO ์ต์ ํ
- ๋ฐฑ์๋๊ฐ๋ฐ
- Next.js
- Prisma
- ๊ฐ๋ฐ๋ธ๋ก๊ทธ
- REACT
- ํ๋ก ํธ์๋
- rag
- CI/CD
- ํ๋ก ํธ์๋๋ฉด์
- App Router
- nextJS
- NestJS
- PostgreSQL
- ๋ฅ๋ฌ๋
- seo ์ต์ ํ 10๊ฐ
- Python
- Webpack
- llm
- kotlin
- gatsbyjs
- nodejs
- fastapi
- ํ์ด์ฌ ์๊ณ ๋ฆฌ์ฆ
- Docker
์ผ | ์ | ํ | ์ | ๋ชฉ | ๊ธ | ํ |
---|---|---|---|---|---|---|
1 | 2 | 3 | 4 | 5 | 6 | 7 |
8 | 9 | 10 | 11 | 12 | 13 | 14 |
15 | 16 | 17 | 18 | 19 | 20 | 21 |
22 | 23 | 24 | 25 | 26 | 27 | 28 |
29 | 30 |
๊ธ ๋ณด๊ดํจ
๋ฐ์ํ