ํ‹ฐ์Šคํ† ๋ฆฌ ๋ทฐ

๋ฐ˜์‘ํ˜•

๐Ÿ“Œ JAX๋กœ ๊ฐ•ํ™” ํ•™์Šต ๋ชจ๋ธ ๊ตฌ์ถ• - Q-learning์„ ์ด์šฉํ•œ CartPole ๋ฌธ์ œ ํ•ด๊ฒฐ


๐Ÿš€ ๊ฐ•ํ™” ํ•™์Šต์ด๋ž€?

๊ฐ•ํ™” ํ•™์Šต(RL)์€ **์—์ด์ „ํŠธ(Agent)**๊ฐ€ ํ™˜๊ฒฝ(Environment)๊ณผ ์ƒํ˜ธ์ž‘์šฉํ•˜๋ฉฐ, ๋ณด์ƒ(Reward)์„ ์ตœ๋Œ€ํ™”ํ•˜๋Š” ์ •์ฑ…(Policy)์„ ํ•™์Šตํ•˜๋Š” ๊ธฐ๋ฒ•์ž…๋‹ˆ๋‹ค.
์ด๋ฒˆ ๊ธ€์—์„œ๋Š” JAX๋ฅผ ํ™œ์šฉํ•˜์—ฌ Q-learning ๊ธฐ๋ฐ˜์˜ ๊ฐ•ํ™” ํ•™์Šต ๋ชจ๋ธ์„ ๊ตฌ์ถ•ํ•˜์—ฌ CartPole ๋ฌธ์ œ๋ฅผ ํ•ด๊ฒฐํ•ด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.


๐Ÿ’ก 1. Q-learning์˜ ํ•ต์‹ฌ ๊ฐœ๋…

๐Ÿ“ Q-learning ์•Œ๊ณ ๋ฆฌ์ฆ˜ ๊ตฌ์กฐ

  • Q ํ•จ์ˆ˜: ์ƒํƒœ-ํ–‰๋™ ์Œ (s, a)์˜ ๊ฐ€์น˜๋ฅผ ์ถ”์ •ํ•˜๋Š” ํ•จ์ˆ˜
  • ๋ชฉํ‘œ: Q๊ฐ’์„ ์—…๋ฐ์ดํŠธํ•˜์—ฌ ์ตœ์  ์ •์ฑ…์„ ํ•™์Šต
  • Bellman Equation:

Q(s,a)=Q(s,a)+α×[r+γ×maxโกa′Q(s′,a′)−Q(s,a)]Q(s, a) = Q(s, a) + \alpha \times [r + \gamma \times \max_a' Q(s', a') - Q(s, a)]

  • α\alpha: ํ•™์Šต๋ฅ  (Learning Rate)
  • γ\gamma: ํ• ์ธ์œจ (Discount Factor)
  • rr: ๋ณด์ƒ

๐Ÿ”ง 2. ํ™˜๊ฒฝ ์„ค์ •

๐Ÿ“ฅ ํ•„์ˆ˜ ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ ์„ค์น˜

pip install gym
pip install jax jaxlib
pip install optax  # optimization library for JAX

๐Ÿ—บ๏ธ ํ™˜๊ฒฝ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ (Gym)

๋ฐ˜์‘ํ˜•
import gym

# CartPole ํ™˜๊ฒฝ ์„ค์ •
env = gym.make("CartPole-v1")
state = env.reset()
print(f"์ดˆ๊ธฐ ์ƒํƒœ: {state}")

๐Ÿ“ 3. Q ๋„คํŠธ์›Œํฌ ๋ชจ๋ธ ์ •์˜

๐Ÿง  ์‹ ๊ฒฝ๋ง ๊ตฌ์กฐ

  • ์ž…๋ ฅ: ์ƒํƒœ ๋ฒกํ„ฐ (4์ฐจ์›)
  • ์ถœ๋ ฅ: ํ–‰๋™ ๊ฐ€์น˜ (2์ฐจ์›: ์ขŒ/์šฐ ์ด๋™)
  • ์€๋‹‰์ธต: 128๊ฐœ์˜ ๋…ธ๋“œ, ReLU ํ™œ์„ฑํ™”
  • ์ถœ๋ ฅ์ธต: ์ƒํƒœ-ํ–‰๋™ ๊ฐ€์น˜
import jax
import jax.numpy as jnp
from jax import random
import optax  # JAX์šฉ ์ตœ์ ํ™” ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ

# ๋„คํŠธ์›Œํฌ ์ดˆ๊ธฐํ™” ํ•จ์ˆ˜
def init_params(key, input_dim, hidden_dim, output_dim):
    key1, key2 = random.split(key)
    w1 = random.normal(key1, (input_dim, hidden_dim)) * 0.1
    b1 = jnp.zeros(hidden_dim)
    w2 = random.normal(key2, (hidden_dim, output_dim)) * 0.1
    b2 = jnp.zeros(output_dim)
    return (w1, b1), (w2, b2)

# Q ๋„คํŠธ์›Œํฌ
def q_network(params, x):
    (w1, b1), (w2, b2) = params
    hidden = jnp.tanh(jnp.dot(x, w1) + b1)  # ReLU ๋Œ€์‹  Tanh ์‚ฌ์šฉ
    q_values = jnp.dot(hidden, w2) + b2
    return q_values
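As a quick sanity check, the sketch below runs one forward pass of the untrained network on a dummy CartPole state. The seed, the dummy values, and the names key_demo / params_demo are arbitrary and used only here.

# Hypothetical sanity check: a single forward pass on a dummy 4-dimensional state
key_demo = random.PRNGKey(0)
params_demo = init_params(key_demo, input_dim=4, hidden_dim=128, output_dim=2)
dummy_state = jnp.array([0.0, 0.0, 0.05, 0.0])
print(q_network(params_demo, dummy_state))  # two Q values, one per action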

๐Ÿ“‰ 4. ์†์‹ค ํ•จ์ˆ˜์™€ ์ตœ์ ํ™”

๐Ÿงฎ ์†์‹ค ํ•จ์ˆ˜ (Mean Squared Error)

def loss_fn(params, states, actions, targets):
    # Q values for a batch of states; pick the Q value of the action actually taken
    q_values = q_network(params, states)
    q_selected = q_values[jnp.arange(q_values.shape[0]), actions]
    return jnp.mean((targets - q_selected) ** 2)

๐Ÿ”ง ์˜ตํ‹ฐ๋งˆ์ด์ € ์„ค์ •

optimizer = optax.adam(learning_rate=0.001)

# ์ตœ์ ํ™” ์ƒํƒœ ์ดˆ๊ธฐํ™”
@jax.jit
def init_optim_state(params):
    return optimizer.init(params)

# ํŒŒ๋ผ๋ฏธํ„ฐ ์—…๋ฐ์ดํŠธ ํ•จ์ˆ˜
@jax.jit
def update(params, opt_state, state, action, target):
    loss, grads = jax.value_and_grad(loss_fn)(params, state, action, target)
    updates, opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, loss
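To make the expected shapes concrete, here is a hypothetical single gradient step on a fake mini-batch of 32 transitions; all demo_* names are placeholders introduced only for this example.

# Hypothetical example: one update step on a fake batch (batch size 32)
demo_key = random.PRNGKey(1)
demo_params = init_params(demo_key, 4, 128, 2)
demo_opt_state = init_optim_state(demo_params)
demo_states = jnp.zeros((32, 4))               # batch of states
demo_actions = jnp.zeros(32, dtype=jnp.int32)  # batch of action indices
demo_targets = jnp.ones(32)                    # batch of Q targets
demo_params, demo_opt_state, loss = update(demo_params, demo_opt_state, demo_states, demo_actions, demo_targets)
print(f"Loss after one step: {loss}")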

๐Ÿ” 5. ํ•™์Šต ๋ฃจํ”„ ์ •์˜

๐ŸŒŸ ํ•˜์ดํผํŒŒ๋ผ๋ฏธํ„ฐ ์„ค์ •

epsilon = 0.1
gamma = 0.99
episodes = 300
batch_size = 32
buffer_size = 10000
learning_starts = 1000
target_update_freq = 50

๐Ÿ—ƒ๏ธ ๊ฒฝํ—˜ ๋ฆฌํ”Œ๋ ˆ์ด ๋ฒ„ํผ

import collections
import random as py_random  # aliased so it does not shadow jax.random imported earlier

# Experience replay buffer
buffer = collections.deque(maxlen=buffer_size)

def store_transition(state, action, reward, next_state, done):
    buffer.append((state, action, reward, next_state, done))

def sample_batch(batch_size):
    batch = py_random.sample(buffer, batch_size)
    states, actions, rewards, next_states, dones = zip(*batch)
    return jnp.array(states), jnp.array(actions), jnp.array(rewards), jnp.array(next_states), jnp.array(dones)
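A short sketch of how the buffer is used: store a transition after each environment step and, once enough transitions have accumulated, draw a random batch. The variable names here are for illustration only (classic Gym API assumed).

# Hypothetical usage: store one transition, then sample a batch when possible
s = env.reset()
a = env.action_space.sample()
s_next, r, d, _ = env.step(a)
store_transition(s, a, r, s_next, d)
if len(buffer) >= batch_size:
    states, actions, rewards, next_states, dones = sample_batch(batch_size)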

๐Ÿƒ 6. ๋ชจ๋ธ ํ•™์Šต

๐Ÿš€ ํ•™์Šต ๋ฃจํ”„

key = random.PRNGKey(42)
params = init_params(key, 4, 128, 2)
target_params = params  # the target network starts as a copy of the online network
opt_state = init_optim_state(params)

for episode in range(episodes):
    state = env.reset()
    episode_reward = 0

    for t in range(200):
        # Epsilon-greedy policy
        if py_random.random() < epsilon:
            action = env.action_space.sample()
        else:
            q_values = q_network(params, jnp.array(state))
            action = int(jnp.argmax(q_values))

        # ํ–‰๋™ ์ˆ˜ํ–‰
        next_state, reward, done, _ = env.step(action)
        store_transition(state, action, reward, next_state, done)

        # ํ•™์Šต ์‹œ์ž‘ ์กฐ๊ฑด
        if len(buffer) > learning_starts:
            states, actions, rewards, next_states, dones = sample_batch(batch_size)

            # Q ํƒ€๊ฒŸ ๊ณ„์‚ฐ
            next_q_values = q_network(params, next_states)
            targets = rewards + gamma * jnp.max(next_q_values, axis=1) * (1 - dones)

            # ํŒŒ๋ผ๋ฏธํ„ฐ ์—…๋ฐ์ดํŠธ
            params, opt_state, loss = update(params, opt_state, states, actions, targets)

        state = next_state
        episode_reward += reward

        if done:
            break

    print(f"Episode {episode + 1}, Reward: {episode_reward}")

    # ํƒ€๊ฒŸ ๋„คํŠธ์›Œํฌ ์—…๋ฐ์ดํŠธ
    if episode % target_update_freq == 0:
        target_params = params

โœ… 7. ํ…Œ์ŠคํŠธ์™€ ์„ฑ๋Šฅ ํ‰๊ฐ€

def evaluate(env, params, episodes=10):
    total_reward = 0
    for _ in range(episodes):
        state = env.reset()
        episode_reward = 0
        done = False
        while not done:
            q_values = q_network(params, jnp.array(state))
            action = int(jnp.argmax(q_values))
            state, reward, done, _ = env.step(action)
            episode_reward += reward
        total_reward += episode_reward
    return total_reward / episodes

avg_reward = evaluate(env, params)
print(f"ํ‰๊ท  ํ…Œ์ŠคํŠธ ๋ณด์ƒ: {avg_reward}")

๐Ÿ“Œ ๋‹ค์Œ ๊ธ€ ์˜ˆ๊ณ : JAX๋กœ Transformer ๋ชจ๋ธ ๊ตฌํ˜„ํ•˜๊ธฐ

๋‹ค์Œ ๊ธ€์—์„œ๋Š” JAX๋ฅผ ํ™œ์šฉํ•˜์—ฌ Transformer ๋ชจ๋ธ์„ ๊ตฌํ˜„ํ•˜๊ณ ,
์ž์—ฐ์–ด ์ฒ˜๋ฆฌ(NLP) ๋ฌธ์ œ๋ฅผ ํ•ด๊ฒฐํ•˜๋Š” ๋ฐ ์ ์šฉํ•ด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค.


 

JAX, ๊ฐ•ํ™” ํ•™์Šต, Q-learning, CartPole, ๋”ฅ๋Ÿฌ๋‹, Python, GPU ํ•™์Šต, ์—์ด์ „ํŠธ ํ•™์Šต, ๊ณ ์† ์—ฐ์‚ฐ, ๋ชจ๋ธ ํ•™์Šต, ๋ฐ์ดํ„ฐ ์ „์ฒ˜๋ฆฌ, ๊ฐ•ํ™” ํ•™์Šต ๋ชจ๋ธ, ๊ณ ์„ฑ๋Šฅ ์ปดํ“จํŒ…, Gym ํ™˜๊ฒฝ, ๋ชจ๋ธ ํ‰๊ฐ€

โ€ป ์ด ํฌ์ŠคํŒ…์€ ์ฟ ํŒก ํŒŒํŠธ๋„ˆ์Šค ํ™œ๋™์˜ ์ผํ™˜์œผ๋กœ, ์ด์— ๋”ฐ๋ฅธ ์ผ์ •์•ก์˜ ์ˆ˜์ˆ˜๋ฃŒ๋ฅผ ์ œ๊ณต๋ฐ›์Šต๋‹ˆ๋‹ค.
๊ณต์ง€์‚ฌํ•ญ
์ตœ๊ทผ์— ์˜ฌ๋ผ์˜จ ๊ธ€
์ตœ๊ทผ์— ๋‹ฌ๋ฆฐ ๋Œ“๊ธ€
Total
Today
Yesterday
๋งํฌ
ยซ   2025/05   ยป
์ผ ์›” ํ™” ์ˆ˜ ๋ชฉ ๊ธˆ ํ† 
1 2 3
4 5 6 7 8 9 10
11 12 13 14 15 16 17
18 19 20 21 22 23 24
25 26 27 28 29 30 31
๊ธ€ ๋ณด๊ด€ํ•จ
๋ฐ˜์‘ํ˜•