Implementing a Time-Series Forecasting Transformer in JAX - A Self-Attention-Based Model That Learns Even Long-Range Patterns
octo54 · 2025. 5. 23. 11:39
Why use a Transformer for time series?
Traditional RNNs/LSTMs struggle to make use of information from the distant past because of the long-term dependency problem.
A Transformer, by contrast, uses the Self-Attention mechanism to compute the relationships between every pair of time steps in the input sequence at once,
so it can effectively learn patterns even between time steps that are far apart.
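To make that concrete, the minimal sketch below computes plain scaled dot-product attention over a toy sequence; the weight matrix it produces holds one score for every pair of time steps, which is what lets distant points influence each other directly. The toy embedding and its size are assumptions for illustration only; the full multi-head version used in this post appears in section 3.

import jax.numpy as jnp
from jax.nn import softmax

# Toy sequence: 5 time steps, each already embedded into 4 features (made-up values)
x = jnp.arange(20.0).reshape(5, 4)

# For illustration, use the sequence itself as queries, keys, and values
scores = x @ x.T / jnp.sqrt(4.0)      # (5, 5): one score for every pair of time steps
weights = softmax(scores, axis=-1)    # each row is a distribution over all time steps
context = weights @ x                 # every output step mixes information from all steps

print(weights.shape, context.shape)   # (5, 5) (5, 4)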
1. Transformer for Time Series - Architecture Summary
Components
- Positional encoding: injects the order of the time series
- Self-Attention layer: learns the relationships between all time steps
- Feed-forward network: transforms the attended features
- Output layer: predicts the future value (see the shape trace below)
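To make the data flow concrete, here is a rough shape trace through these components, assuming the settings used later in this post (window length 30, a single input feature, d_model = 64):

# (batch, 30, 1)  window of scaled values
#   -> Dense projection to d_model           -> (batch, 30, 64)
#   -> + positional encoding (30, 64)        -> (batch, 30, 64)
#   -> N x [Self-Attention + Feed-Forward]   -> (batch, 30, 64)
#   -> global average pooling over time      -> (batch, 64)
#   -> Dense(1) output head                  -> (batch, 1), the predicted next value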
2. Preparing the Dataset
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
import jax.numpy as jnp
# Example: a stock price time series
df = pd.read_csv('stock_prices.csv')  # assumes 'date' and 'close' columns
data = df['close'].values.reshape(-1, 1)

# Normalize to [0, 1]
scaler = MinMaxScaler()
data_scaled = scaler.fit_transform(data)

# Build sliding windows: 30 past steps as input, the next step as the target
def make_window(data, window=30):
    X, y = [], []
    for i in range(len(data) - window):
        X.append(data[i:i+window])
        y.append(data[i+window])
    return jnp.array(X), jnp.array(y)

x_data, y_data = make_window(data_scaled)
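A quick shape check on the result (with N rows in the CSV there are N - 30 windows):

print(x_data.shape)  # (N - 30, 30, 1): 30-step windows, one feature per step
print(y_data.shape)  # (N - 30, 1):     the value right after each window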
3. Implementing the Transformer Model (Flax)
Positional Encoding
def positional_encoding(seq_len, d_model):
    # Sinusoidal position encoding: even feature indices get sin, odd indices get cos
    pos = jnp.arange(seq_len)[:, jnp.newaxis]
    i = jnp.arange(d_model)[jnp.newaxis, :]
    angle_rates = 1 / jnp.power(10000, (2 * (i // 2)) / d_model)
    angle_rads = pos * angle_rates
    angle_rads = angle_rads.at[:, 0::2].set(jnp.sin(angle_rads[:, 0::2]))
    angle_rads = angle_rads.at[:, 1::2].set(jnp.cos(angle_rads[:, 1::2]))
    return angle_rads
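As a quick check (using the seq_len=30 and d_model=64 defaults assumed later in this post), the encoding has one row per time step:

pe = positional_encoding(30, 64)
print(pe.shape)    # (30, 64): one encoding vector per position
print(pe[0, :4])   # early positions...
print(pe[29, :4])  # ...and late positions get distinct sinusoidal patterns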
Self-Attention Layer
from flax import linen as nn

class MultiHeadAttention(nn.Module):
    d_model: int
    num_heads: int

    @nn.compact
    def __call__(self, x):
        assert self.d_model % self.num_heads == 0
        head_dim = self.d_model // self.num_heads

        # Project to queries, keys, and values with a single Dense, then split
        qkv = nn.Dense(self.d_model * 3)(x)
        q, k, v = jnp.split(qkv, 3, axis=-1)

        # (batch, seq, d_model) -> (batch, heads, seq, head_dim)
        q = q.reshape(x.shape[0], x.shape[1], self.num_heads, head_dim).transpose(0, 2, 1, 3)
        k = k.reshape(x.shape[0], x.shape[1], self.num_heads, head_dim).transpose(0, 2, 1, 3)
        v = v.reshape(x.shape[0], x.shape[1], self.num_heads, head_dim).transpose(0, 2, 1, 3)

        # Scaled dot-product attention over the time axis
        attn = jnp.einsum('bhqd,bhkd->bhqk', q, k) / jnp.sqrt(head_dim)
        attn_weights = nn.softmax(attn, axis=-1)
        out = jnp.einsum('bhqk,bhkd->bhqd', attn_weights, v)

        # Merge the heads back and apply the output projection
        out = out.transpose(0, 2, 1, 3).reshape(x.shape[0], x.shape[1], self.d_model)
        return nn.Dense(self.d_model)(out)
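A small smoke test for the attention layer (a sketch with dummy values; the batch size of 2 and the pre-projected d_model=64 input are assumptions):

import jax

mha = MultiHeadAttention(d_model=64, num_heads=4)
dummy = jnp.ones((2, 30, 64))                        # (batch, seq_len, d_model)
mha_params = mha.init(jax.random.PRNGKey(0), dummy)
out = mha.apply(mha_params, dummy)
print(out.shape)  # (2, 30, 64): same shape in and out, so blocks can be stacked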
Transformer Block
class TransformerBlock(nn.Module):
    d_model: int
    num_heads: int
    ff_dim: int

    @nn.compact
    def __call__(self, x):
        # Self-attention sub-layer with a residual connection and LayerNorm
        attn = MultiHeadAttention(self.d_model, self.num_heads)(x)
        x = nn.LayerNorm()(x + attn)

        # Position-wise feed-forward sub-layer, also with residual + LayerNorm
        ff = nn.Dense(self.ff_dim)(x)
        ff = nn.relu(ff)
        ff = nn.Dense(self.d_model)(ff)
        x = nn.LayerNorm()(x + ff)
        return x
4. The Full Time Series Transformer Model
class TimeSeriesTransformer(nn.Module):
    d_model: int = 64
    num_heads: int = 4
    ff_dim: int = 128
    num_layers: int = 2
    seq_len: int = 30

    @nn.compact
    def __call__(self, x):
        # Project the raw features to d_model and add positional information
        pos_enc = positional_encoding(self.seq_len, self.d_model)
        x = nn.Dense(self.d_model)(x)
        x = x + pos_enc

        # Stack of Transformer blocks
        for _ in range(self.num_layers):
            x = TransformerBlock(self.d_model, self.num_heads, self.ff_dim)(x)

        # Global average pooling over time, then predict a single next value
        x = jnp.mean(x, axis=1)
        return nn.Dense(1)(x)
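To inspect the stacked model before training, one option is to initialize it with a dummy window and list the parameter shapes. This is only an inspection sketch; the actual initialization used for training is repeated in the next section.

import jax

m = TimeSeriesTransformer()
variables = m.init(jax.random.PRNGKey(0), jnp.ones((1, 30, 1)))  # (batch, seq_len, features)
print(jax.tree_util.tree_map(lambda p: p.shape, variables))      # nested dict of parameter shapes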
5. Loss Function and Training Loop
import optax
from flax.training import train_state
import jax

model = TimeSeriesTransformer()
key = jax.random.PRNGKey(0)
params = model.init(key, jnp.ones((1, 30, 1)))  # dummy input: (batch, seq_len, features)

tx = optax.adam(1e-3)
state = train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

@jax.jit
def loss_fn(params, x, y):
    # Mean squared error between the predicted and the actual next value
    pred = model.apply(params, x)
    return jnp.mean((pred.squeeze() - y.squeeze()) ** 2)

@jax.jit
def train_step(state, x, y):
    loss, grads = jax.value_and_grad(loss_fn)(state.params, x, y)
    state = state.apply_gradients(grads=grads)
    return state, loss
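Before running the full loop, a single step on a small slice is a cheap way to confirm that everything compiles and the loss comes out as a finite scalar (a sanity-check sketch that reuses x_data and y_data from section 2):

# One throwaway step on the first 8 windows; a NaN here usually means a shape or scaling problem
_state, first_loss = train_step(state, x_data[:8], y_data[:8])
print(first_loss)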
6. Running the Training
for epoch in range(10):
    for i in range(0, len(x_data), 64):
        # x_data already has shape (num_windows, 30, 1), so no extra axis is needed
        x_batch = x_data[i:i+64]
        y_batch = y_data[i:i+64]
        state, loss = train_step(state, x_batch, y_batch)
    print(f"Epoch {epoch+1}, Loss: {loss:.4f}")
7. Forecasting the Future and Visualization
import matplotlib.pyplot as plt

def predict_next(state, last_seq, steps=30):
    # Autoregressive forecast: feed each prediction back in as the newest step
    preds = []
    seq = last_seq  # shape (30, 1)
    for _ in range(steps):
        pred = model.apply(state.params, seq[None, ...])  # add only the batch axis
        preds.append(pred.squeeze())
        seq = jnp.concatenate([seq[1:], pred], axis=0)
    return scaler.inverse_transform(jnp.array(preds).reshape(-1, 1))

# Forecast the next 30 steps from the most recent window
future = predict_next(state, x_data[-1])

plt.plot(data, label='Original')  # original (unscaled) series
plt.plot(range(len(data), len(data) + 30), future, label='Transformer forecast', color='red')
plt.legend()
plt.title("Transformer-based time series forecast")
plt.show()
Next post preview: Implementing a Time Series Anomaly Detection Model in JAX (Autoencoder-based Anomaly Detection)
In the next post, we will use JAX to build an Autoencoder-based anomaly detection model
and show how to automatically detect anomalies in a time series.
JAX, Transformer, Time Series Forecasting, Time Series, Self-Attention, Multi-Head Attention, Deep Learning, Python, Time Series Analysis, Fast Computation, Flax, Stock Prediction, Temperature Prediction, Model Training, Artificial Intelligence, High-Performance Computing