
Running OpenAI Gym Atari Games on Colab

by 오징어땅콩2 2021. 8. 11.

1. The ROM files were uploaded manually through Google Drive.

2. For the rest, it seems you just need to set the paths correctly.

3. Managing saved files through Google Drive is very convenient as well.

!pip install tensorflow==2.3.1 keras-rl2 'gym[atari]' atari-py

# Mount Google Drive so the ROMs and saved weights persist across Colab sessions
from google.colab import drive
drive.mount('/content/drive', force_remount=True)

# Register the Atari ROMs that were uploaded to Drive by hand
!python -m atari_py.import_roms 'drive/MyDrive/rom'
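If the ROM path is wrong, the import can fail quietly, so it helps to list the games atari-py actually registered. This check is an extra sanity step, not part of the original walkthrough:

# Optional sanity check: 'space_invaders' should appear in this list
# if the ROM import above succeeded.
import atari_py
print(atari_py.list_games())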

import gym
import random

# Create the Space Invaders environment and read off the frame shape and action count
env = gym.make('SpaceInvaders-v0')
height, width, channels = env.observation_space.shape
actions = env.action_space.n


#env.unwrapped.get_action_meanings()
#episodes = 5
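Before training, a few episodes of random play are a quick way to confirm the environment works; the commented-out episodes variable above hints at this. A minimal sketch using the classic gym step API (the loop itself is reconstructed; only episodes = 5 comes from the comment):

# Random-play sanity check (optional; not needed for training itself)
episodes = 5
for episode in range(1, episodes + 1):
    state = env.reset()
    done = False
    score = 0
    while not done:
        action = random.choice(range(actions))  # pick a random action each frame
        state, reward, done, info = env.step(action)
        score += reward
    print('Episode: {} Score: {}'.format(episode, score))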

import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Convolution2D
from tensorflow.keras.optimizers import Adam

# Disable eager execution for compatibility with keras-rl2
tf.compat.v1.disable_eager_execution()


def build_model(height, width, channels, actions):
    model = Sequential()
    # Input is a window of 3 stacked frames; this must match window_length=3
    # in the SequentialMemory below.
    model.add(Convolution2D(32, (8,8), strides=(4,4), activation='relu', input_shape=(3, height, width, channels)))
    model.add(Convolution2D(64, (4,4), strides=(2,2), activation='relu'))
    model.add(Convolution2D(64, (3,3), activation='relu'))
    model.add(Flatten())
    model.add(Dense(512, activation='relu'))
    model.add(Dense(256, activation='relu'))
    model.add(Dense(actions, activation='linear'))  # one linear Q-value per action
    return model


#del model
model = build_model(height, width, channels, actions)
model.summary()

from rl.agents import DQNAgent
from rl.memory import SequentialMemory
from rl.policy import LinearAnnealedPolicy, EpsGreedyQPolicy

def build_agent(model, actions):
    # Epsilon-greedy exploration, annealed from 1.0 down to 0.1 over the first 10,000 steps
    policy = LinearAnnealedPolicy(EpsGreedyQPolicy(), attr='eps', value_max=1., value_min=.1, value_test=.2, nb_steps=10000)
    memory = SequentialMemory(limit=1000, window_length=3)
    dqn = DQNAgent(model=model, memory=memory, policy=policy, enable_dueling_network=True, dueling_type='avg', nb_actions=actions, nb_steps_warmup=1000)
    return dqn

dqn = build_agent(model, actions)
dqn.compile(Adam(lr=1e-4))

dqn.fit(env, nb_steps=10000, visualize=False, verbose=2)
dqn.save_weights('drive/MyDrive/Saved/dqn_weights.h5f')  # keep the trained weights on Drive
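Colab runtimes can disconnect mid-run, so for longer training it is safer to checkpoint periodically instead of saving only at the end. A sketch using keras-rl's ModelIntervalCheckpoint callback; the file name and interval here are assumptions, not values from the original post:

# Sketch: save weights to Drive every 5,000 steps during training
from rl.callbacks import ModelIntervalCheckpoint

checkpoint = ModelIntervalCheckpoint('drive/MyDrive/Saved/dqn_checkpoint.h5f', interval=5000)
dqn.fit(env, nb_steps=10000, visualize=False, verbose=2, callbacks=[checkpoint])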

#dqn.load_weights('drive/MyDrive/Saved/dqn_weights.h5f')
#scores = dqn.test(env, nb_episodes=10, visualize=True)
#print(np.mean(scores.history['episode_reward']))


del model, dqn  # free the model and agent before re-running cells
