본문 바로가기
AI 학습/강화학습

CartPole-v0 텐서플로우로 구현

by 오징어땅콩2 2021. 7. 30.
반응형

정리되지 않은 코드이므로 참고용으로만 봐 주세요.

 

import os

import gym
import numpy as np
import matplotlib.pyplot as plt

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Convolution2D
from tensorflow.keras.optimizers import Adam

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activation
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.datasets import mnist

from rl.agents import DQNAgent
from rl.memory import SequentialMemory
from rl.policy import LinearAnnealedPolicy, EpsGreedyQPolicy




def build_model(input_size, nb_actions):
    """Build the Q-network: one hidden layer over a flattened observation.

    The input shape is derived from ``input_size`` rather than from a
    module-level environment, so the function can be called before (or
    without) any global ``env`` being defined.

    Args:
        input_size: Number of features in a single observation
            (4 for CartPole).
        nb_actions: Number of discrete actions; the network emits one
            linear output per action.

    Returns:
        An uncompiled Keras ``Sequential`` model.
    """
    model = Sequential()
    # The leading 1 is the observation-window axis; presumably it must agree
    # with the window_length=1 used for the agent's SequentialMemory.
    model.add(Flatten(input_shape=(1, input_size)))
    model.add(Dense(256, activation='relu'))
    model.add(Dense(nb_actions, activation='linear'))
    model.summary()  # print the layer overview to stdout

    return model
    



def build_agent(model, nb_actions):
    """Wrap ``model`` in a dueling DQN agent with epsilon-greedy exploration.

    Args:
        model: Keras model used by the agent as its Q-network.
        nb_actions: Number of discrete actions available to the agent.

    Returns:
        A ``DQNAgent`` that still has to be compiled by the caller.
    """
    # Annealed alternative that was tried earlier:
    #   LinearAnnealedPolicy(EpsGreedyQPolicy(), attr='eps', value_max=1.,
    #                        value_min=.1, value_test=.2, nb_steps=10000)
    exploration = EpsGreedyQPolicy()
    replay_buffer = SequentialMemory(limit=50000, window_length=1)

    return DQNAgent(
        model=model,
        nb_actions=nb_actions,
        memory=replay_buffer,
        policy=exploration,
        enable_dueling_network=True,
        nb_steps_warmup=10,
        target_model_update=1e-2,
    )


# Single path shared by loading and saving so training resumes across runs.
WEIGHTS_PATH = 'SavedWeights/dqn/dqn_weights.h5f'

env = gym.make('CartPole-v0')
input_size = env.observation_space.shape[0]  # 4
nb_actions = env.action_space.n
# Uncomment for reproducible runs:
# np.random.seed(123)
# env.seed(123)

model = build_model(input_size, nb_actions)
dqn = build_agent(model, nb_actions)
# `lr` is a deprecated alias; `learning_rate` is the documented argument name.
dqn.compile(Adam(learning_rate=1e-4))

# Resume from earlier weights only when they exist, so the very first run
# trains from scratch instead of failing on a missing file. Depending on the
# Keras backend the weights are presumably either one file at WEIGHTS_PATH or
# a TensorFlow checkpoint whose index is stored at "<path>.index".
if os.path.exists(WEIGHTS_PATH) or os.path.exists(WEIGHTS_PATH + '.index'):
    dqn.load_weights(WEIGHTS_PATH)

dqn.fit(env, nb_steps=5000, visualize=True, verbose=2)

# Make sure the target directory exists before writing the weights.
os.makedirs(os.path.dirname(WEIGHTS_PATH), exist_ok=True)
dqn.save_weights(WEIGHTS_PATH, overwrite=True)

# Evaluate the trained agent and report the mean episode reward.
scores = dqn.test(env, nb_episodes=10, visualize=True)
print(np.mean(scores.history['episode_reward']))

# Release the environment (and its render window) once finished.
env.close()

댓글