
📌 Interpreting Time-Series Models with JAX - SHAP-Based Explainable AI for Time Series

octo54 2025. 6. 4. 10:53
๋ฐ˜์‘ํ˜•



🧠 Why Should Time-Series Models Be Explainable, Too?

Deep-learning time-series models deliver strong predictive performance, but they struggle to answer the question **"why did the model predict that?"**
In domains such as healthcare, finance, manufacturing, and smart cities in particular, the rationale behind a model's decision matters.

**SHAP (SHapley Additive exPlanations)** quantifies numerically how much each input timestep (feature) contributed to the prediction.


🎯 1. Goals

  • For the inputs of a JAX-based time-series classification/forecasting model,
  • analyze **which timesteps** influenced the prediction most strongly, and
  • implement SHAP-like logic with JAX + Flax.

🔧 2. SHAP in a Nutshell

  • SHAP uses the Shapley value from game theory to quantify
    each feature's **contribution** to the prediction.
  • For time series, each point along the time axis (e.g., t₁, t₂, ..., tₙ) becomes a unit of attribution.
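To make the Shapley value concrete, here is a minimal exact computation for a toy two-"timestep" game. The value function `v` and the player names are made up purely for illustration; real SHAP libraries approximate this, since the exact sum is exponential in the number of features:

```python
from itertools import combinations
from math import factorial

def shapley_values(value_fn, players):
    """Exact Shapley values: weighted marginal contributions over all coalitions."""
    n = len(players)
    phi = {p: 0.0 for p in players}
    for p in players:
        others = [q for q in players if q != p]
        for r in range(n):
            for S in combinations(others, r):
                weight = factorial(len(S)) * factorial(n - len(S) - 1) / factorial(n)
                phi[p] += weight * (value_fn(set(S) | {p}) - value_fn(set(S)))
    return phi

# Made-up value function: score of the model when only timesteps in S are "present"
v = lambda S: 2.0 * ("t1" in S) + 1.0 * ("t2" in S) + 0.5 * ("t1" in S and "t2" in S)

phi = shapley_values(v, ["t1", "t2"])  # {'t1': 2.25, 't2': 1.25}
```

Note the efficiency property: the attributions sum to `v(all) - v(empty)` = 3.5, so every bit of the prediction is distributed across the timesteps.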

💾 3. Example Data (Sensor Time-Series Classification)

```python
# X_test: shape (samples, 100, 3)
# y_test: shape (samples,)
# model: TimeSeriesTransformerClassifier

sample = X_test[0:1]  # select a single sample
```
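The post does not ship the dataset itself; a hypothetical stand-in with the same shapes (random sensor readings and made-up labels) is enough to exercise the later snippets:

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in data: 64 samples, 100 timesteps, 3 sensor channels
key = jax.random.PRNGKey(0)
X_test = jax.random.normal(key, (64, 100, 3))
# Made-up binary labels derived from the first channel's mean
y_test = (X_test[:, :, 0].mean(axis=1) > 0).astype(jnp.int32)

sample = X_test[0:1]  # shape (1, 100, 3)
```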

🧠 4. A Perturbation-Based SHAP-Style Implementation

```python
import jax
import jax.numpy as jnp

def shap_values_approx(model, params, sample, baseline=None, n_iter=100, key=None):
    """Monte-Carlo perturbation approximation of per-(timestep, feature)
    SHAP-style attributions for one sample of shape (1, seq_len, num_feat)."""
    if key is None:
        key = jax.random.PRNGKey(0)
    seq_len, num_feat = sample.shape[1], sample.shape[2]
    if baseline is None:
        baseline = jnp.zeros_like(sample)

    pred_full = model.apply(params, sample)   # (1, num_classes)
    target = jnp.argmax(pred_full[0])         # explain the predicted class

    shap_vals = jnp.zeros((seq_len, num_feat))
    for _ in range(n_iter):
        key, subkey = jax.random.split(key)
        # True = keep the original value, False = replace with the baseline
        mask = jax.random.bernoulli(subkey, p=0.5, shape=(seq_len, num_feat))
        masked_input = jnp.where(mask, sample[0], baseline[0])
        pred_masked = model.apply(params, masked_input[None, ...])

        # Scalar drop in the predicted-class score, credited to the
        # entries that were masked out in this iteration
        diff = pred_full[0, target] - pred_masked[0, target]
        shap_vals += (1 - mask) * diff

    return shap_vals / n_iter
```
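The masking step inside the loop is just `jax.random.bernoulli` plus `jnp.where`; isolated, with small illustrative shapes:

```python
import jax
import jax.numpy as jnp

sample = jnp.arange(12.0).reshape(1, 4, 3)   # (1, seq_len=4, num_feat=3)
baseline = jnp.zeros_like(sample)

key = jax.random.PRNGKey(42)
mask = jax.random.bernoulli(key, p=0.5, shape=(4, 3))   # True = keep original
masked_input = jnp.where(mask, sample[0], baseline[0])  # (4, 3)
batched = masked_input[None, ...]                       # restore the batch dim
```

Every entry of `masked_input` is either the original value or the baseline, and adding the leading axis back gives the `(1, seq_len, num_feat)` shape the model expects.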

📊 5. Visualizing the SHAP Values

๋ฐ˜์‘ํ˜•
```python
import matplotlib.pyplot as plt

shap_matrix = shap_values_approx(model, state.params, sample, key=jax.random.PRNGKey(42))
shap_sum = shap_matrix.sum(axis=1)  # importance per timestep

plt.figure(figsize=(12, 4))
plt.plot(shap_sum, label="SHAP Importance per timestep")
plt.xlabel("Timestep")
plt.ylabel("Importance")
plt.title("SHAP-based Time-Series Feature Importance")
plt.legend()
plt.show()
```
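Summing over the feature axis hides which sensor drove each timestep; a heatmap keeps both axes visible. This sketch uses a random matrix as a hypothetical stand-in for `shap_matrix`, and the channel names are made up:

```python
import numpy as np
import matplotlib.pyplot as plt

# Hypothetical stand-in for the (seq_len, num_feat) attribution matrix
shap_matrix = np.random.default_rng(0).normal(size=(100, 3)) * 0.1

plt.figure(figsize=(12, 3))
plt.imshow(shap_matrix.T, aspect="auto", cmap="coolwarm")
plt.yticks(range(3), ["sensor_0", "sensor_1", "sensor_2"])  # made-up channel names
plt.xlabel("Timestep")
plt.colorbar(label="SHAP value")
plt.title("SHAP attributions per timestep and feature")
plt.show()
```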

✅ 6. Use-Case Scenarios

| Domain | Example interpretation |
| --- | --- |
| Healthcare | Which moments in heart-rate/respiration changes mattered for disease classification |
| Smart factory | The vibration-change interval that mattered most for equipment-fault prediction |
| Finance | The price-change moments that most influenced a stock forecast |

📌 Coming Up Next: Implementing a Time-Series Autoencoder in JAX - A Reconstruction-Error-Based Model for Anomaly Detection

The next post will implement an Autoencoder in JAX that compresses and then reconstructs time series,
and show how to use it for **anomaly detection**.


 

JAX, Explainable AI, SHAP, Time Series Interpretation, Model Explanation, Flax, Time Series Classification, Feature Importance, AI Interpretability, Transformer Interpretation, Python SHAP, Game Theory, Shapley Value, AI Transparency, Time Series XAI