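1. I uploaded the ROM files to Google Drive manually.
2. Beyond that, everything seems to work as long as the paths are set correctly.
3. Going through Google Drive also makes saving and managing files very convenient.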
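Since most of the setup comes down to getting the paths right, it can help to keep them in one place. A minimal sketch, assuming Drive is mounted at /content/drive and using the folder names from this post ("rom" and "Saved"). `project_paths` is a hypothetical helper, not part of any library. The post itself uses relative paths such as 'drive/MyDrive/rom', which resolve against Colab's default working directory, /content.

```python
from pathlib import Path


def project_paths(root="/content/drive/MyDrive", create_save_dir=False):
    """Collect the Google Drive paths this notebook uses (hypothetical helper).

    Folder names follow the paths used in this post; adjust them to your Drive layout.
    """
    root = Path(root)
    paths = {
        "rom_dir": root / "rom",     # ROM files uploaded manually
        "save_dir": root / "Saved",  # trained weights are written here
    }
    paths["weights"] = paths["save_dir"] / "dqn_weights.h5f"
    if create_save_dir:
        # Create the save folder up front so the save step does not hit a missing directory
        paths["save_dir"].mkdir(parents=True, exist_ok=True)
    return paths


paths = project_paths()
print(paths["weights"].as_posix())  # /content/drive/MyDrive/Saved/dqn_weights.h5f
```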
!pip install tensorflow==2.3.1 gym keras-rl2 gym[atari]
!pip install gym[atari]
!pip install atari-py
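from google.colab import drive
# Mount Google Drive; the ROM files were uploaded to MyDrive/rom beforehand
drive.mount('/content/drive', force_remount=True)
# Register the Atari ROMs with atari-py
!python -m atari_py.import_roms 'drive/MyDrive/rom'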
import gym
import random
env = gym.make('SpaceInvaders-v0')
height, width, channels = env.observation_space.shape
actions = env.action_space.n
#env.unwrapped.get_action_meanings()
#episodes = 5
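Before training, it can be useful to play a few episodes with random actions (the commented-out `episodes = 5` hints at such a loop) to get a baseline score for the trained agent. A minimal sketch; `run_random_episodes` is a hypothetical helper, and it assumes the classic gym API of this era (gym < 0.26), where `reset()` returns an observation and `step()` returns `(obs, reward, done, info)`.

```python
import random


def run_random_episodes(env, episodes=5, seed=None):
    """Play `episodes` episodes with uniformly random actions; return each episode's total reward.

    Assumes the pre-0.26 gym API: reset() -> obs, step(a) -> (obs, reward, done, info).
    """
    rng = random.Random(seed)
    n_actions = env.action_space.n
    scores = []
    for _ in range(episodes):
        env.reset()
        done = False
        score = 0.0
        while not done:
            action = rng.randrange(n_actions)  # every action equally likely
            _, reward, done, _ = env.step(action)
            score += reward
        scores.append(score)
    return scores
```

In the notebook this would be called as `scores = run_random_episodes(env, episodes=5)`, and `sum(scores) / len(scores)` gives the random baseline to compare with the test scores after training.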
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Convolution2D
from tensorflow.keras.optimizers import Adam
tf.compat.v1.disable_eager_execution()
def build_model(height, width, channels, actions):
    model = Sequential()
    # The leading 3 matches window_length=3 in SequentialMemory: each input is a stack of the last 3 frames
    model.add(Convolution2D(32, (8,8), strides=(4,4), activation='relu', input_shape=(3,height, width, channels)))
    model.add(Convolution2D(64, (4,4), strides=(2,2), activation='relu'))
    model.add(Convolution2D(64, (3,3), activation='relu'))
    model.add(Flatten())
    model.add(Dense(512, activation='relu'))
    model.add(Dense(256, activation='relu'))
    # One linear output per action: the estimated Q-value of that action
    model.add(Dense(actions, activation='linear'))
    return model
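The leading 3 in `input_shape` is there because `SequentialMemory(window_length=3)` hands the model the last three frames at once. A pure-Python sketch of that stacking, assuming zero frames are used to pad the window at the start of an episode; `FrameStack` and `zeros_like` are hypothetical names for illustration, not keras-rl2 classes.

```python
from collections import deque


def zeros_like(frame):
    """Return a nested list of zeros with the same shape as `frame`."""
    if isinstance(frame, list):
        return [zeros_like(x) for x in frame]
    return 0


class FrameStack:
    """Keep the last `window_length` observations as one state (hypothetical helper)."""

    def __init__(self, window_length=3):
        self.window_length = window_length
        self.frames = deque(maxlen=window_length)  # the oldest frame drops out automatically

    def reset(self):
        self.frames.clear()

    def push(self, frame):
        self.frames.append(frame)

    def state(self):
        if not self.frames:
            raise ValueError("push at least one frame first")
        frames = list(self.frames)
        # Early in an episode there are fewer frames than the window: pad on the left with zeros
        while len(frames) < self.window_length:
            frames.insert(0, zeros_like(frames[0]))
        return frames


stack = FrameStack(window_length=3)
frame = [[[1], [1]], [[1], [1]]]  # toy 2x2 frame with 1 channel
stack.push(frame)
state = stack.state()
print(len(state), len(state[0]), len(state[0][0]), len(state[0][0][0]))  # 3 2 2 1
```

The printed shape (3, 2, 2, 1) corresponds to (window, height, width, channels), the same layout as `input_shape=(3, height, width, channels)`.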
#del model
model = build_model(height, width, channels, actions)
model.summary()
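To see where the parameters in the summary come from, the spatial size after each convolution can be worked out by hand. A sketch assuming SpaceInvaders-v0 frames of 210x160x3, 'valid' padding (the Conv2D default), and that the window axis of 3 is carried through the convolutions unchanged, so Flatten sees three times the per-frame features. `conv_out` and `flatten_size` are hypothetical helpers; compare the numbers with the `model.summary()` output rather than taking them as given.

```python
def conv_out(size, kernel, stride):
    """Output length along one axis of a convolution with 'valid' padding (Keras default)."""
    return (size - kernel) // stride + 1


def flatten_size(height=210, width=160, window=3, filters=64):
    """Feature count reaching Flatten for the three Conv2D layers in build_model."""
    h, w = height, width
    for kernel, stride in [(8, 4), (4, 2), (3, 1)]:  # (kernel, stride) of each Conv2D
        h, w = conv_out(h, kernel, stride), conv_out(w, kernel, stride)
    return h, w, window * h * w * filters


h, w, flat = flatten_size()
print(h, w, flat)        # 22 16 67584
print(flat * 512 + 512)  # 34603520 weights + biases in Dense(512)
```

Under these assumptions almost all of the parameters sit in the first Dense layer; the three convolutions together contribute well under 100,000.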
from rl.agents import DQNAgent
from rl.memory import SequentialMemory
from rl.policy import LinearAnnealedPolicy, EpsGreedyQPolicy
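def build_agent(model, actions):
    # Epsilon decays linearly from 1.0 to 0.1 over 10,000 steps; 0.2 is used during testing
    policy = LinearAnnealedPolicy(EpsGreedyQPolicy(), attr='eps', value_max=1., value_min=.1, value_test=.2, nb_steps=10000)
    # Replay memory holding up to 1,000 entries; window_length=3 stacks 3 consecutive frames per state
    memory = SequentialMemory(limit=1000, window_length=3)
    dqn = DQNAgent(model=model, memory=memory, policy=policy,
                   enable_dueling_network=True, dueling_type='avg',
                   nb_actions=actions, nb_steps_warmup=1000)
    return dqn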
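With `enable_dueling_network=True`, keras-rl2's DQNAgent appears to replace the model's last Dense layer with a head that outputs one state value V(s) plus one advantage per action and then merges them. `dueling_type='avg'` subtracts the mean advantage, so the advantages are centred and V(s) equals the average Q-value over actions. A pure-Python sketch of that aggregation for a single state; `dueling_avg` is a hypothetical name, not a keras-rl2 function.

```python
def dueling_avg(value, advantages):
    """Combine a state value V(s) and advantages A(s, a) with 'avg' aggregation.

    Q(s, a) = V(s) + (A(s, a) - mean over a' of A(s, a'))
    """
    mean_adv = sum(advantages) / len(advantages)
    return [value + (a - mean_adv) for a in advantages]


q = dueling_avg(2.0, [1.0, 3.0, -1.0, 1.0])
print(q)  # [2.0, 4.0, 0.0, 2.0]
```

Because the same V(s) is added to every action, the greedy choice depends only on the advantages.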
dqn = build_agent(model, actions)
dqn.compile(Adam(learning_rate=1e-4))
dqn.fit(env, nb_steps=10000, visualize=False, verbose=2)
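`LinearAnnealedPolicy` lowers the exploration rate eps linearly, and `EpsGreedyQPolicy` uses it to choose between a random action and the best-looking one. Because `fit` runs for `nb_steps=10000`, the same length as the annealing schedule, eps only reaches its floor of 0.1 at the very end of training. A pure-Python sketch of both pieces with the parameter values from `build_agent`; the function names are hypothetical and `step` stands for the agent's training step counter.

```python
import random


def annealed_epsilon(step, value_max=1.0, value_min=0.1, nb_steps=10000):
    """Linearly decay epsilon from value_max to value_min over nb_steps, then hold it."""
    slope = -(value_max - value_min) / nb_steps
    return max(value_min, slope * step + value_max)


def eps_greedy(q_values, eps, rng=random):
    """With probability eps take a random action, otherwise the action with the highest Q-value."""
    if rng.random() < eps:
        return rng.randrange(len(q_values))
    return max(range(len(q_values)), key=lambda a: q_values[a])


for step in (0, 5000, 10000, 20000):
    print(step, round(annealed_epsilon(step), 3))
# 0 1.0
# 5000 0.55
# 10000 0.1
# 20000 0.1
```

During `dqn.test` the policy uses `value_test=.2` instead of the annealed value, so some random actions remain even when evaluating.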
dqn.save_weights('drive/MyDrive/Saved/dqn_weights.h5f')
#dqn.load_weights('drive/MyDrive/Saved/dqn_weights.h5f')
#scores = dqn.test(env, nb_episodes=10, visualize=True)
#print(np.mean(scores.history['episode_reward']))
del model, dqn