Building a Reinforcement Learning Model with JAX - Solving the CartPole Problem with Q-learning
octo54 · 2025. 5. 13. 12:10
What is Reinforcement Learning?
Reinforcement learning (RL) is a technique in which an **agent** interacts with an environment and learns a policy that maximizes reward.
In this post, we will build a Q-learning based reinforcement learning model with JAX and use it to solve the CartPole problem.
1. Core Concepts of Q-learning
Q-learning Algorithm Structure
- Q function: estimates the value of a state-action pair (s, a)
- Goal: learn the optimal policy by repeatedly updating the Q values
- Bellman update rule (a minimal code sketch follows this list):
  Q(s, a) = Q(s, a) + \alpha \times [r + \gamma \times \max_{a'} Q(s', a') - Q(s, a)]
  - \alpha: learning rate
  - \gamma: discount factor
  - r: reward
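Before moving to a neural-network approximation, it may help to see this update rule as code. Below is a minimal sketch of the tabular Q-learning update; the table size, states, and reward values are hypothetical and only illustrate the arithmetic, they are not part of the CartPole setup.

import jax.numpy as jnp

def q_update(q_table, s, a, r, s_next, alpha=0.1, gamma=0.99):
    # TD target: immediate reward plus the discounted value of the best next action
    td_target = r + gamma * jnp.max(q_table[s_next])
    # Move Q(s, a) a fraction alpha toward the TD target
    return q_table.at[s, a].set(q_table[s, a] + alpha * (td_target - q_table[s, a]))

# Hypothetical table with 3 states and 2 actions, initialized to zero
q_table = jnp.zeros((3, 2))
q_table = q_update(q_table, s=0, a=1, r=1.0, s_next=2)
print(q_table[0, 1])  # 0.1 = 0.1 * (1.0 + 0.99 * 0 - 0)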
2. Environment Setup
Installing the Required Libraries
pip install gym
pip install jax jaxlib
pip install optax  # optimization library for JAX
Loading the Environment (Gym)
import gym

# Create the CartPole environment
# (this post assumes the legacy Gym API, gym < 0.26: reset() returns only the state
#  and step() returns 4 values)
env = gym.make("CartPole-v1")
state = env.reset()
print(f"Initial state: {state}")
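The CartPole observation is a 4-dimensional vector and there are 2 discrete actions, which fixes the input and output sizes of the Q network defined in the next section. A quick check:

# CartPole-v1: 4-dimensional observation, 2 discrete actions (push left / push right)
print(env.observation_space.shape)  # (4,)
print(env.action_space.n)           # 2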
3. Defining the Q-Network Model
Neural Network Architecture
- Input: state vector (4-dimensional)
- Output: action values (2-dimensional: move left / move right)
- Hidden layer: 128 nodes (the code below uses Tanh activation instead of ReLU)
- Output layer: state-action values
import jax
import jax.numpy as jnp
from jax import random
import optax  # optimization library for JAX

# Network initialization function
def init_params(key, input_dim, hidden_dim, output_dim):
    key1, key2 = random.split(key)
    w1 = random.normal(key1, (input_dim, hidden_dim)) * 0.1
    b1 = jnp.zeros(hidden_dim)
    w2 = random.normal(key2, (hidden_dim, output_dim)) * 0.1
    b2 = jnp.zeros(output_dim)
    return (w1, b1), (w2, b2)

# Q network: maps a state (or a batch of states) to Q values for each action
def q_network(params, x):
    (w1, b1), (w2, b2) = params
    hidden = jnp.tanh(jnp.dot(x, w1) + b1)  # Tanh instead of ReLU
    q_values = jnp.dot(hidden, w2) + b2
    return q_values
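As a quick sanity check (a sketch using only the functions defined above, with an arbitrary PRNG seed), you can initialize the parameters and run a forward pass on a dummy state:

key = random.PRNGKey(0)  # arbitrary seed for this check
params = init_params(key, input_dim=4, hidden_dim=128, output_dim=2)
dummy_state = jnp.zeros(4)
print(q_network(params, dummy_state))  # two Q values, one per action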
4. Loss Function and Optimization
Loss Function (Mean Squared Error)
def loss_fn(params, states, actions, targets):
    q_values = q_network(params, states)                           # (batch, num_actions)
    q_selected = q_values[jnp.arange(q_values.shape[0]), actions]  # Q value of the action taken
    return jnp.mean((targets - q_selected) ** 2)
Optimizer Setup
optimizer = optax.adam(learning_rate=0.001)

# Initialize the optimizer state
@jax.jit
def init_optim_state(params):
    return optimizer.init(params)

# Parameter update function
@jax.jit
def update(params, opt_state, states, actions, targets):
    loss, grads = jax.value_and_grad(loss_fn)(params, states, actions, targets)
    updates, opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, opt_state, loss
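Before wiring up the full training loop, a single update step can be exercised on a dummy mini-batch (a sketch; the batch contents are placeholders with the shapes used later for CartPole):

key = random.PRNGKey(0)
params = init_params(key, 4, 128, 2)
opt_state = init_optim_state(params)

# Placeholder batch: 32 states, actions, and TD targets
states = jnp.zeros((32, 4))
actions = jnp.zeros(32, dtype=jnp.int32)
targets = jnp.ones(32)

params, opt_state, loss = update(params, opt_state, states, actions, targets)
print(loss)  # MSE between the predicted Q values and the targets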
5. Defining the Training Loop
Hyperparameter Settings
epsilon = 0.1
gamma = 0.99
episodes = 300
batch_size = 32
buffer_size = 10000
learning_starts = 1000
target_update_freq = 50
Experience Replay Buffer
import collections
import random as py_random  # renamed to avoid shadowing jax.random imported above

# Experience replay buffer
buffer = collections.deque(maxlen=buffer_size)

def store_transition(state, action, reward, next_state, done):
    buffer.append((state, action, reward, next_state, done))

def sample_batch(batch_size):
    batch = py_random.sample(buffer, batch_size)
    states, actions, rewards, next_states, dones = zip(*batch)
    return (jnp.array(states), jnp.array(actions), jnp.array(rewards),
            jnp.array(next_states), jnp.array(dones))
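As a usage sketch (illustration only, with placeholder transition values), a transition is stored as it is collected and mini-batches are sampled uniformly at random:

# Illustration only: store one placeholder transition, sample it back, then clear the buffer
store_transition([0.0, 0.0, 0.0, 0.0], 0, 1.0, [0.0, 0.0, 0.0, 0.0], False)
states, actions, rewards, next_states, dones = sample_batch(1)
print(states.shape, actions.shape)  # (1, 4) (1,)
buffer.clear()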
6. Training the Model
Training Loop
key = random.PRNGKey(42)
params = init_params(key, 4, 128, 2)
opt_state = init_optim_state(params)
target_params = params  # target network starts as a copy of the online network

for episode in range(episodes):
    state = env.reset()
    episode_reward = 0

    for t in range(200):
        # Epsilon-greedy policy
        if py_random.random() < epsilon:
            action = env.action_space.sample()
        else:
            q_values = q_network(params, jnp.array(state))
            action = int(jnp.argmax(q_values))

        # Take the action
        next_state, reward, done, _ = env.step(action)
        store_transition(state, action, reward, next_state, done)

        # Start learning once the buffer is large enough
        if len(buffer) > learning_starts:
            states, actions, rewards, next_states, dones = sample_batch(batch_size)

            # Compute the Q targets from the target network
            next_q_values = q_network(target_params, next_states)
            targets = rewards + gamma * jnp.max(next_q_values, axis=1) * (1 - dones)

            # Parameter update
            params, opt_state, loss = update(params, opt_state, states, actions, targets)

        state = next_state
        episode_reward += reward
        if done:
            break

    print(f"Episode {episode + 1}, Reward: {episode_reward}")

    # Target network update
    if episode % target_update_freq == 0:
        target_params = params
7. Testing and Performance Evaluation
def evaluate(env, params, episodes=10):
    total_reward = 0
    for _ in range(episodes):
        state = env.reset()
        episode_reward = 0
        done = False
        while not done:
            q_values = q_network(params, jnp.array(state))
            action = int(jnp.argmax(q_values))
            state, reward, done, _ = env.step(action)
            episode_reward += reward
        total_reward += episode_reward
    return total_reward / episodes

avg_reward = evaluate(env, params)
print(f"Average test reward: {avg_reward}")
Next Post Preview: Implementing a Transformer Model with JAX
In the next post, we will implement a Transformer model with JAX and apply it to natural language processing (NLP) problems.
Tags: JAX, reinforcement learning, Q-learning, CartPole, deep learning, Python, GPU training, agent training, high-speed computation, model training, data preprocessing, reinforcement learning model, high-performance computing, Gym environment, model evaluation