๐ JAX๋ก ์๊ณ์ด ๋ชจ๋ธ ํด์ํ๊ธฐ - SHAP ๊ธฐ๋ฐ Explainable AI for Time Series
๐ JAX๋ก ์๊ณ์ด ๋ชจ๋ธ ํด์ํ๊ธฐ - SHAP ๊ธฐ๋ฐ Explainable AI for Time Series
๐ง ์ ์๊ณ์ด ๋ชจ๋ธ๋ ์ค๋ช ๊ฐ๋ฅํด์ผ ํ ๊น?
๋ฅ๋ฌ๋ ๊ธฐ๋ฐ ์๊ณ์ด ๋ชจ๋ธ์ ์์ธก ์ฑ๋ฅ์ ๋์ง๋ง, **"์ ๊ทธ๋ ๊ฒ ์์ธกํ๋๊ฐ?"**๋ผ๋ ์ง๋ฌธ์ ๋ตํ๊ธฐ ์ด๋ ต์ต๋๋ค.
ํนํ ์๋ฃ/๊ธ์ต/์ ์กฐ/์ค๋งํธ์ํฐ ๋ถ์ผ์์ ๋ชจ๋ธ์ ํ๋จ ๊ทผ๊ฑฐ๋ ์ค์ํฉ๋๋ค.
**SHAP (SHapley Additive exPlanations)**๋ ๊ฐ ์ ๋ ฅ ์์ (feature)์ด ์์ธก ๊ฒฐ๊ณผ์ ์ผ๋ง๋ ๊ธฐ์ฌํ๋์ง๋ฅผ ์์น๋ก ์ ๋ํํฉ๋๋ค.
๐ฏ 1. ๋ชฉํ
- JAX ๊ธฐ๋ฐ ์๊ณ์ด ๋ถ๋ฅ/์์ธก ๋ชจ๋ธ์ ์ ๋ ฅ์ ๋ํด
- **์ด๋ค ์์ (timestep)**์ด ์์ธก์ ๊ฐ์ฅ ํฌ๊ฒ ์ํฅ์ ๋ฏธ์ณค๋์ง ๋ถ์
- SHAP ์ ์ฌ ๋ก์ง์ JAX+Flax๋ก ๊ตฌํ
๐ง 2. SHAP ์๋ฆฌ ์์ฝ
- SHAP์ ๊ฒ์์ด๋ก ์ Shapley value๋ฅผ ํ์ฉํ์ฌ
๊ฐ ํน์ฑ(feature)์ **๊ธฐ์ฌ๋(Contribution)**๋ฅผ ์ ๋ํํฉ๋๋ค. - ์๊ณ์ด์์๋ ์๊ฐ์ถ(์: tโ, tโ, ..., tโ)์ ๊ฐ ํฌ์ธํธ๊ฐ ๊ธฐ์ฌ๋ ๋จ์๊ฐ ๋ฉ๋๋ค.
๐พ 3. ์์ ๋ฐ์ดํฐ (์ผ์ ์๊ณ์ด ๋ถ๋ฅ)
# X_test: shape (samples, 100, 3)
# y_test: shape (samples,)
# ๋ชจ๋ธ: TimeSeriesTransformerClassifier
sample = X_test[0:1] # ํ๋์ ์ํ ์ ํ
๐ง 4. Perturbation ๊ธฐ๋ฐ SHAP ์ ์ฌ ๊ตฌํ
def shap_values_approx(model, params, sample, baseline=None, n_iter=100, key=None):
seq_len, num_feat = sample.shape[1], sample.shape[2]
if baseline is None:
baseline = jnp.zeros_like(sample)
shap_vals = jnp.zeros((seq_len, num_feat))
for i in range(n_iter):
key, subkey = jax.random.split(key)
mask = jax.random.bernoulli(subkey, p=0.5, shape=(seq_len, num_feat))
masked_input = jnp.where(mask, sample[0], baseline[0])
pred_full = model.apply(params, sample)
pred_masked = model.apply(params, masked_input[None, ...])
diff = (pred_full - pred_masked)[0]
shap_vals += mask * diff[None, None, :num_feat] # Broadcasting ์ง์
shap_vals /= n_iter
return shap_vals.squeeze()
๐ 5. SHAP ๊ฐ ์๊ฐํ
import matplotlib.pyplot as plt
shap_matrix = shap_values_approx(model, state.params, sample, key=jax.random.PRNGKey(42))
shap_sum = shap_matrix.sum(axis=1) # ๊ฐ ์์ ๋ณ ์ค์๋
plt.figure(figsize=(12, 4))
plt.plot(shap_sum, label="SHAP Importance per timestep")
plt.xlabel("Timestep")
plt.ylabel("Importance")
plt.title("SHAP-based Time-Series Feature Importance")
plt.legend()
plt.show()
โ 6. ํ์ฉ ์๋๋ฆฌ์ค
๋ถ์ผ ํด์ ์์
ํฌ์ค์ผ์ด | ์ฌ๋ฐ/ํธํก ๋ณํ ์ค ์ด๋ ์์ ์ด ์ง๋ณ ๋ถ๋ฅ์ ์ค์ํ๋๊ฐ |
์ค๋งํธํฉํ ๋ฆฌ | ์ค๋น ์ด์ ์์ธก์์ ๊ฐ์ฅ ์ค์ํ ์ง๋ ๋ณํ ๊ตฌ๊ฐ |
๊ธ์ต | ์ฃผ๊ฐ ์์ธก์์ ๊ฐ์ฅ ์ํฅ์ ๋ผ์น ์์ธ ๋ณํ ์์ |
๐ ๋ค์ ๊ธ ์๊ณ : JAX๋ก ์๊ณ์ด Autoencoder ๊ตฌํ - ์ด์ ํ์ง๋ฅผ ์ํ ์ฌ๊ตฌ์ฑ ์ค์ฐจ ๊ธฐ๋ฐ ๋ชจ๋ธ
๋ค์ ๊ธ์์๋ ์๊ณ์ด์ ์์ถ→๋ณต์ํ๋ Autoencoder๋ฅผ JAX๋ก ๊ตฌํํ์ฌ
**์ด์ ํ์ง (Anomaly Detection)**์ ํ์ฉํ๋ ๋ฐฉ๋ฒ์ ์๊ฐํฉ๋๋ค.
JAX, Explainable AI, SHAP, ์๊ณ์ด ํด์, ๋ชจ๋ธ ์ค๋ช , Time Series Interpretation, Flax, ์๊ณ์ด ๋ถ๋ฅ, Feature Importance, AI ํด์ ๊ฐ๋ฅ์ฑ, Transformer ํด์, Python SHAP, Game Theory, Shapley Value, AI Transparency, Time Series XAI