
๐Ÿ“Œ Implementing a Time Series Forecasting Transformer in JAX - A Self-Attention Model That Learns Even Long-Range Patterns


๐Ÿš€ Why Use a Transformer for Time Series?

์ „ํ†ต์ ์ธ RNN/LSTM์€ ์žฅ๊ธฐ ์˜์กด์„ฑ ๋ฌธ์ œ๋กœ ์ธํ•ด ๊ณผ๊ฑฐ ์ •๋ณด๋ฅผ ๋ฉ€๋ฆฌ ๋ฐ˜์˜ํ•˜๊ธฐ ์–ด๋ ต์Šต๋‹ˆ๋‹ค.
๋ฐ˜๋ฉด Transformer ๋ชจ๋ธ์€ Self-Attention ๋ฉ”์ปค๋‹ˆ์ฆ˜์„ ํ†ตํ•ด
์ž…๋ ฅ ์‹œํ€€์Šค ๋‚ด์˜ ๋ชจ๋“  ์‹œ์  ๊ฐ„ ๊ด€๊ณ„๋ฅผ ๋™์‹œ ๊ณ„์‚ฐํ•˜์—ฌ,
๋ฉ€๋ฆฌ ๋–จ์–ด์ง„ ์‹œ์  ๊ฐ„์˜ ํŒจํ„ด๊นŒ์ง€ ํšจ๊ณผ์ ์œผ๋กœ ํ•™์Šตํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.


๐Ÿ’ก 1. Transformer for Time Series - Architecture Overview

๐Ÿ“ ๊ตฌ์„ฑ ์š”์†Œ

  • ํฌ์ง€์…”๋„ ์ธ์ฝ”๋”ฉ: ์‹œ๊ณ„์—ด ์ˆœ์„œ๋ฅผ ๋ฐ˜์˜
  • Self-Attention Layer: ๋ชจ๋“  ์‹œ์  ๊ฐ„์˜ ์ƒํ˜ธ๊ด€๊ณ„ ํ•™์Šต
  • Feed-Forward Network: ์ •๋ณด ๋ณ€ํ™˜
  • Output Layer: ๋ฏธ๋ž˜ ์‹œ์  ์˜ˆ์ธก
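
Before writing any layers, it helps to trace the tensor shapes through this pipeline. The sketch below is illustrative only, assuming the hyperparameters used later in this post (window length 30, d_model = 64, batch size B):

# Shape flow for one forward pass (illustrative; B = batch size)
# input window          : (B, 30, 1)   scaled values
# input projection      : (B, 30, 64)  Dense(d_model)
# + positional encoding : (B, 30, 64)  broadcast add of (30, 64)
# N Transformer blocks  : (B, 30, 64)  shape-preserving
# global average pool   : (B, 64)      mean over the time axis
# output head           : (B, 1)       next-step prediction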

๐Ÿ”ง 2. Preparing the Dataset

import pandas as pd
from sklearn.preprocessing import MinMaxScaler
import jax.numpy as jnp

# Example: a stock price time series
df = pd.read_csv('stock_prices.csv')  # assumes 'date' and 'close' columns
data = df['close'].values.reshape(-1, 1)

# Scale to [0, 1]
scaler = MinMaxScaler()
data_scaled = scaler.fit_transform(data)

# Build sliding windows: each sample is `window` consecutive values,
# and the target is the value immediately after the window
def make_window(data, window=30):
    X, y = [], []
    for i in range(len(data) - window):
        X.append(data[i:i+window])
        y.append(data[i+window])
    return jnp.array(X), jnp.array(y)

x_data, y_data = make_window(data_scaled)  # shapes: (N, 30, 1), (N, 1)
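
If you don't have a stock_prices.csv handy, a synthetic series works just as well for following along. This stand-in is my own assumption for demonstration purposes, not part of the original pipeline:

import numpy as np

# Hypothetical stand-in data: a noisy sine wave instead of real stock prices
t = np.arange(1000)
data = (np.sin(0.05 * t) + 0.1 * np.random.randn(1000)).reshape(-1, 1)

data_scaled = scaler.fit_transform(data)
x_data, y_data = make_window(data_scaled)
print(x_data.shape, y_data.shape)  # (970, 30, 1) (970, 1)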

๐Ÿง  3. Implementing the Transformer Model (Flax)

๐ŸŽฏ Positional Encoding

def positional_encoding(seq_len, d_model):
    # Sinusoidal encoding from "Attention Is All You Need":
    # even dimensions get sin, odd dimensions get cos
    pos = jnp.arange(seq_len)[:, jnp.newaxis]   # (seq_len, 1)
    i = jnp.arange(d_model)[jnp.newaxis, :]     # (1, d_model)
    angle_rates = 1 / jnp.power(10000, (2 * (i // 2)) / d_model)
    angle_rads = pos * angle_rates              # (seq_len, d_model)
    angle_rads = angle_rads.at[:, 0::2].set(jnp.sin(angle_rads[:, 0::2]))
    angle_rads = angle_rads.at[:, 1::2].set(jnp.cos(angle_rads[:, 1::2]))
    return angle_rads
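
A quick sanity check on the encoding (an assumed example using the seq_len=30 and d_model=64 defaults of the model below):

pe = positional_encoding(30, 64)
print(pe.shape)                           # (30, 64)
print(float(pe.min()), float(pe.max()))   # sin/cos values stay within [-1, 1]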

๐Ÿงฑ Self-Attention Layer

from flax import linen as nn

class MultiHeadAttention(nn.Module):
    d_model: int
    num_heads: int

    @nn.compact
    def __call__(self, x):
        assert self.d_model % self.num_heads == 0
        head_dim = self.d_model // self.num_heads

        # Single projection producing Q, K, V at once
        qkv = nn.Dense(self.d_model * 3)(x)
        q, k, v = jnp.split(qkv, 3, axis=-1)

        # (batch, seq, d_model) -> (batch, heads, seq, head_dim)
        q = q.reshape(x.shape[0], x.shape[1], self.num_heads, head_dim).transpose(0, 2, 1, 3)
        k = k.reshape(x.shape[0], x.shape[1], self.num_heads, head_dim).transpose(0, 2, 1, 3)
        v = v.reshape(x.shape[0], x.shape[1], self.num_heads, head_dim).transpose(0, 2, 1, 3)

        # Scaled dot-product attention, summing over the key axis
        attn = jnp.einsum('bhqd,bhkd->bhqk', q, k) / jnp.sqrt(head_dim)
        attn_weights = nn.softmax(attn, axis=-1)
        out = jnp.einsum('bhqk,bhkd->bhqd', attn_weights, v)
        out = out.transpose(0, 2, 1, 3).reshape(x.shape[0], x.shape[1], self.d_model)

        return nn.Dense(self.d_model)(out)
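
A quick shape check for the layer above; the dummy input and key here are assumptions for illustration:

import jax

key = jax.random.PRNGKey(42)
dummy = jnp.ones((2, 30, 64))  # (batch, seq_len, d_model)
mha = MultiHeadAttention(d_model=64, num_heads=4)
variables = mha.init(key, dummy)
print(mha.apply(variables, dummy).shape)  # (2, 30, 64): shape-preserving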

๐Ÿงฑ Transformer Block

๋ฐ˜์‘ํ˜•
class TransformerBlock(nn.Module):
    d_model: int
    num_heads: int
    ff_dim: int

    @nn.compact
    def __call__(self, x):
        # Attention sub-layer with residual connection + LayerNorm (post-norm)
        attn = MultiHeadAttention(self.d_model, self.num_heads)(x)
        x = nn.LayerNorm()(x + attn)

        # Position-wise feed-forward sub-layer, also residual + LayerNorm
        ff = nn.Dense(self.ff_dim)(x)
        ff = nn.relu(ff)
        ff = nn.Dense(self.d_model)(ff)
        x = nn.LayerNorm()(x + ff)
        return x

๐Ÿ”ฎ 4. Full Time Series Transformer Model

class TimeSeriesTransformer(nn.Module):
    d_model: int = 64
    num_heads: int = 4
    ff_dim: int = 128
    num_layers: int = 2
    seq_len: int = 30

    @nn.compact
    def __call__(self, x):
        # Project the 1-D input to d_model, then add the (seq_len, d_model)
        # positional encoding, which broadcasts over the batch axis
        pos_enc = positional_encoding(self.seq_len, self.d_model)
        x = nn.Dense(self.d_model)(x)
        x = x + pos_enc

        for _ in range(self.num_layers):
            x = TransformerBlock(self.d_model, self.num_heads, self.ff_dim)(x)

        x = jnp.mean(x, axis=1)  # global average pooling over the time axis
        return nn.Dense(1)(x)    # single next-step prediction
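
Before wiring up training, a forward pass on dummy data confirms the output shape, and a tree reduction gives the parameter count. The names check_model / check_params are hypothetical; section 5 initializes the real training state:

import jax

check_model = TimeSeriesTransformer()
check_params = check_model.init(jax.random.PRNGKey(1), jnp.ones((1, 30, 1)))
print(check_model.apply(check_params, jnp.ones((8, 30, 1))).shape)  # (8, 1)

# Total parameter count across all layers
n_params = sum(p.size for p in jax.tree_util.tree_leaves(check_params))
print(f"parameters: {n_params}")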

โš™๏ธ 5. ์†์‹ค ํ•จ์ˆ˜ ๋ฐ ํ•™์Šต ๋ฃจํ”„

import optax
from flax.training import train_state
import jax

model = TimeSeriesTransformer()
key = jax.random.PRNGKey(0)
params = model.init(key, jnp.ones((1, 30, 1)))  # dummy input fixes all shapes
tx = optax.adam(1e-3)

state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

def loss_fn(params, x, y):
    # Mean squared error between predictions and targets
    pred = model.apply(params, x)
    return jnp.mean((pred.squeeze() - y.squeeze()) ** 2)

@jax.jit
def train_step(state, x, y):
    # One optimizer step; jit compiles loss, gradients, and update together,
    # so there is no need to jit loss_fn separately
    loss, grads = jax.value_and_grad(loss_fn)(state.params, x, y)
    state = state.apply_gradients(grads=grads)
    return state, loss
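
For monitoring generalization, the same jitted pattern can evaluate on held-out windows. This is a sketch I'm adding, assuming a simple chronological 80/20 split (x_val and y_val are hypothetical names, not from the original post):

@jax.jit
def eval_step(state, x, y):
    # Validation MSE: same loss as training, but no gradients or update
    pred = state.apply_fn(state.params, x)
    return jnp.mean((pred.squeeze() - y.squeeze()) ** 2)

# Hypothetical chronological split: hold out the last 20% of windows
split = int(len(x_data) * 0.8)
x_val, y_val = x_data[split:], y_data[split:]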

๐Ÿƒ 6. ํ•™์Šต ์ˆ˜ํ–‰

for epoch in range(10):
    for i in range(0, len(x_data), 64):
        x_batch = x_data[i:i+64]  # already shaped (batch, 30, 1); no extra axis needed
        y_batch = y_data[i:i+64]
        state, loss = train_step(state, x_batch, y_batch)
    print(f"Epoch {epoch+1}, Loss: {loss:.4f}")  # loss of the last batch in the epoch

โœ… 7. Forecasting the Future and Visualization

import matplotlib.pyplot as plt

def predict_next(state, last_seq, steps=30):
    # Recursive forecasting: predict one step, slide the window forward,
    # and feed the prediction back in as the newest observation
    preds = []
    seq = last_seq  # shape (30, 1)
    for _ in range(steps):
        pred = model.apply(state.params, seq[None])  # add batch dim -> (1, 30, 1)
        preds.append(pred.squeeze())
        seq = jnp.concatenate([seq[1:], pred], axis=0)  # slide the window
    return scaler.inverse_transform(jnp.array(preds).reshape(-1, 1))

# Forecast the next 30 steps
future = predict_next(state, x_data[-1])
plt.plot(data, label='Original')  # raw values, no inverse transform needed
plt.plot(range(len(data), len(data) + 30), future, label='Transformer forecast', color='red')
plt.legend()
plt.title("Transformer-Based Time Series Forecasting")
plt.show()

๐Ÿ“Œ ๋‹ค์Œ ๊ธ€ ์˜ˆ๊ณ : JAX๋กœ ์‹œ๊ณ„์—ด ์ด์ƒ ํƒ์ง€ ๋ชจ๋ธ ๊ตฌํ˜„ (Autoencoder ๊ธฐ๋ฐ˜ Anomaly Detection)

๋‹ค์Œ ๊ธ€์—์„œ๋Š” JAX๋ฅผ ํ™œ์šฉํ•˜์—ฌ Autoencoder ๊ธฐ๋ฐ˜ ์ด์ƒ ํƒ์ง€ ๋ชจ๋ธ์„ ๊ตฌํ˜„ํ•˜์—ฌ
์‹œ๊ณ„์—ด์—์„œ ์ด์ƒ์น˜(Anomaly)๋ฅผ ์ž๋™์œผ๋กœ ํƒ์ง€ํ•˜๋Š” ๋ฐฉ๋ฒ•์„ ์†Œ๊ฐœํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค.


 

JAX, Transformer, Time Series Forecasting, Time Series, Self-Attention, Multi-Head Attention, Deep Learning, Python, Time Series Analysis, High-Speed Computation, Flax, Stock Prediction, Temperature Forecasting, Model Training, Artificial Intelligence, High-Performance Computing

โ€ป ์ด ํฌ์ŠคํŒ…์€ ์ฟ ํŒก ํŒŒํŠธ๋„ˆ์Šค ํ™œ๋™์˜ ์ผํ™˜์œผ๋กœ, ์ด์— ๋”ฐ๋ฅธ ์ผ์ •์•ก์˜ ์ˆ˜์ˆ˜๋ฃŒ๋ฅผ ์ œ๊ณต๋ฐ›์Šต๋‹ˆ๋‹ค.