Added new example that uses the CNTK Keras backend to train a bird to fly through a cactus maze using reinforcement learning.
This commit is contained in:
Родитель
f217bacbcb
Коммит
2f4bf7475e
|
@ -0,0 +1,258 @@
|
|||
#!/usr/bin/env python
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
from collections import deque
|
||||
import json
|
||||
import numpy as np
|
||||
import os
|
||||
import random
|
||||
import requests
|
||||
import skimage as skimage
|
||||
from skimage import transform, color, exposure
|
||||
from skimage.transform import rotate
|
||||
from skimage.viewer import ImageViewer
|
||||
import sys
|
||||
|
||||
# Load the right urlretrieve based on python version
|
||||
try:
|
||||
from urllib.request import urlretrieve
|
||||
except ImportError:
|
||||
from urllib import urlretrieve
|
||||
|
||||
import game.wrapped_flappy_bird as game
|
||||
|
||||
from keras.models import model_from_json
|
||||
from keras.models import Sequential
|
||||
from keras.layers.core import Dense, Dropout, Activation, Flatten
|
||||
from keras.layers.convolutional import Convolution2D, MaxPooling2D
|
||||
from keras.optimizers import SGD , Adam
|
||||
|
||||
GAME = 'bird' # the name of the game being played for log files
|
||||
CONFIG = 'nothreshold'
|
||||
ACTIONS = 2 # number of valid actions
|
||||
GAMMA = 0.99 # decay rate of past observations
|
||||
OBSERVATION = 320. # timesteps to observe before training
|
||||
EXPLORE = 3000000. # frames over which to anneal epsilon
|
||||
FINAL_EPSILON = 0.0001 # final value of epsilon
|
||||
INITIAL_EPSILON = 0.1 # starting value of epsilon
|
||||
REPLAY_MEMORY = 50000 # number of previous transitions to remember
|
||||
BATCH = 32 # size of minibatch
|
||||
FRAME_PER_ACTION = 1
|
||||
LEARNING_RATE = 1e-4
|
||||
NUMRUNS = 400
|
||||
PRETRAINED_MODEL_URL_DEFAULT = 'https://cntk.ai/Models/FlappingBird_keras/FlappingBird_model.h5.model'
|
||||
PRETRAINED_MODEL_FNAME = 'FlappingBird_model.h5'
|
||||
|
||||
img_rows , img_cols = 80, 80
|
||||
#Convert image into Black and white
|
||||
img_channels = 4 #We stack 4 frames
|
||||
|
||||
def pretrained_model_download(url, filename, max_retries=5):
    '''Download the file unless it already exists, with retry.

    Args:
        url: location of the pre-trained model to fetch.
        filename: local path the model is cached at. If it already exists
            nothing is downloaded.
        max_retries: number of download attempts made before giving up.

    Raises:
        Exception: if all `max_retries` attempts fail.
    '''
    # Imported here because the module does not import `time` at file level.
    import time

    if os.path.exists(filename):
        print('Reusing locally cached: ', filename)
        return

    print('Starting download of {} to {}'.format(url, filename))
    retry_cnt = 0
    while True:
        try:
            urlretrieve(url, filename)
        except Exception:
            retry_cnt += 1
            # A failed attempt may leave a truncated file behind; remove it so
            # that a later call does not mistake it for a valid cached model.
            if os.path.exists(filename):
                os.remove(filename)
            if retry_cnt >= max_retries:
                raise Exception('Exceeded maximum retry count, aborting.')
            print('Failed to download, retrying.')
            # Random back-off of 1..9 seconds between attempts.
            time.sleep(np.random.randint(1, 10))
        else:
            print('Download completed.')
            return
|
||||
|
||||
def buildmodel():
    '''Build and compile the convolutional Q-network.

    The network takes a stack of `img_channels` frames of size
    `img_rows` x `img_cols` and outputs one value per action.

    Returns:
        A compiled Keras `Sequential` model (MSE loss, Adam optimizer with
        learning rate LEARNING_RATE).
    '''
    print("Now we build the model")
    model = Sequential()
    model.add(Convolution2D(32, 8, 8, subsample=(4, 4), border_mode='same',
                            input_shape=(img_rows, img_cols, img_channels)))  # 80*80*4
    model.add(Activation('relu'))
    model.add(Convolution2D(64, 4, 4, subsample=(2, 2), border_mode='same'))
    model.add(Activation('relu'))
    model.add(Convolution2D(64, 3, 3, subsample=(1, 1), border_mode='same'))
    model.add(Activation('relu'))
    model.add(Flatten())
    model.add(Dense(512))
    model.add(Activation('relu'))
    # One output per valid action; tied to ACTIONS so the network stays in
    # step with the action vectors built by the training loop.
    model.add(Dense(ACTIONS))

    adam = Adam(lr=LEARNING_RATE)
    model.compile(loss='mse', optimizer=adam)
    print("We finish building the model")
    return model
|
||||
|
||||
def _preprocess_frame(frame):
    '''Turn a raw RGB game frame into the 2-D image the network consumes.

    The frame is converted to grayscale, resized to img_rows x img_cols and
    its intensities are rescaled to the 0..255 range.
    '''
    gray = skimage.color.rgb2gray(frame)
    gray = skimage.transform.resize(gray, (img_rows, img_cols))
    return skimage.exposure.rescale_intensity(gray, out_range=(0, 255))


def trainNetwork(model, args, pretrained_model_url=None, internal_testing=False):
    '''Play the game with `model`, training it with DQN unless in 'Run' mode.

    Args:
        model: compiled Keras model that maps a (1, 80, 80, 4) stack of frames
            to one Q-value per action.
        args: dict with a 'mode' entry. 'Run' downloads/loads the pre-trained
            weights and plays without training until the step counter exceeds
            NUMRUNS; any other value trains from the current weights.
        pretrained_model_url: where to fetch the pre-trained weights from in
            'Run' mode; defaults to PRETRAINED_MODEL_URL_DEFAULT.
        internal_testing: when True, the first state is replaced by random
            noise and the function returns 0 after the first observed step.

    Returns:
        0 when `internal_testing` is set and a step completed while still in
        the observation phase; None when the 'Run' loop finishes.
    '''
    print(args)

    # Fall back to the default model location when no url is passed in.
    if not pretrained_model_url:
        pretrained_model_url = PRETRAINED_MODEL_URL_DEFAULT

    pretrained_model_fname = PRETRAINED_MODEL_FNAME

    # open up a game state to communicate with emulator
    game_state = game.GameState()

    # Replay memory: the deque drops the oldest transition by itself once
    # REPLAY_MEMORY entries are stored.
    D = deque(maxlen=REPLAY_MEMORY)

    # get the first state by doing nothing and preprocess the image to 80x80x4
    do_nothing = np.zeros(ACTIONS)
    do_nothing[0] = 1
    x_t, r_0, terminal = game_state.frame_step(do_nothing)
    x_t = _preprocess_frame(x_t)

    if internal_testing:
        x_t = np.random.rand(x_t.shape[0], x_t.shape[1])

    # The initial state is the first frame repeated four times.
    s_t = np.stack((x_t, x_t, x_t, x_t), axis=2).astype(np.float32)

    # Keras expects a leading batch dimension
    s_t = s_t.reshape(1, s_t.shape[0], s_t.shape[1], s_t.shape[2])  # 1*80*80*4

    run_mode = args['mode'] == 'Run'
    if run_mode:
        OBSERVE = 999999999  # We keep observing, never train
        epsilon = FINAL_EPSILON
        print ("Now we load weight")
        pretrained_model_download(pretrained_model_url, pretrained_model_fname)
        model.load_weights(pretrained_model_fname)
        adam = Adam(lr=LEARNING_RATE)
        model.compile(loss='mse', optimizer=adam)
        print ("Weight load successfully")
    else:  # We go to training mode
        OBSERVE = OBSERVATION
        epsilon = INITIAL_EPSILON

    t = 0
    while True:
        loss = 0
        Q_sa = 0
        action_index = 0
        r_t = 0
        a_t = np.zeros([ACTIONS])

        # choose an action epsilon greedy
        if t % FRAME_PER_ACTION == 0:
            if random.random() <= epsilon:
                print("----------Random Action----------")
                action_index = random.randrange(ACTIONS)
            else:
                q = model.predict(s_t)  # input a stack of 4 images, get the prediction
                action_index = np.argmax(q)
        # On frames where no action is chosen, action_index is still 0
        # ("do nothing"). frame_step rejects an all-zero action vector, so
        # exactly one entry is always set.
        a_t[action_index] = 1

        if run_mode and t > NUMRUNS:
            break

        # We reduce epsilon gradually once the observation phase is over
        if epsilon > FINAL_EPSILON and t > OBSERVE:
            epsilon -= (INITIAL_EPSILON - FINAL_EPSILON) / EXPLORE

        # run the selected action and observe next state and reward
        x_t1_colored, r_t, terminal = game_state.frame_step(a_t)
        x_t1 = _preprocess_frame(x_t1_colored)
        x_t1 = x_t1.reshape(1, x_t1.shape[0], x_t1.shape[1], 1).astype(np.float32)  # 1x80x80x1

        # New frame goes first; the oldest of the four stacked frames is dropped.
        s_t1 = np.append(x_t1, s_t[:, :, :, :3], axis=3)

        # store the transition in D
        D.append((s_t, action_index, r_t, s_t1, terminal))

        # only train if done observing
        if t > OBSERVE:
            # sample a minibatch to train on
            minibatch = random.sample(D, BATCH)

            inputs = np.zeros((BATCH, s_t.shape[1], s_t.shape[2], s_t.shape[3])).astype(np.float32)  # 32, 80, 80, 4
            targets = np.zeros((inputs.shape[0], ACTIONS)).astype(np.float32)  # 32, 2

            # Now we do the experience replay
            for i, (state_t, action_t, reward_t, state_t1, was_terminal) in enumerate(minibatch):
                inputs[i:i + 1] = state_t

                # Start from the current predictions so that only the entry of
                # the action actually taken contributes to the loss.
                targets[i] = model.predict(state_t)
                Q_sa = model.predict(state_t1)

                if was_terminal:
                    # if terminated, the target only equals the reward
                    targets[i, action_t] = reward_t
                else:
                    targets[i, action_t] = reward_t + GAMMA * np.max(Q_sa)

            loss += model.train_on_batch(inputs, targets)

        s_t = s_t1
        t = t + 1

        # save progress every 1000 iterations
        if t % 1000 == 0:
            print("Now we save model")
            model.save_weights(pretrained_model_fname, overwrite=True)
            with open("model.json", "w") as outfile:
                json.dump(model.to_json(), outfile)

        # print info
        if t <= OBSERVE:
            if internal_testing:
                return 0  # 0 means success
            state = "observe"
        elif t <= OBSERVE + EXPLORE:
            state = "explore"
        else:
            state = "train"

        print("TIMESTEP", t, "/ STATE", state,
              "/ EPSILON", epsilon, "/ ACTION", action_index, "/ REWARD", r_t,
              "/ Q_MAX ", np.max(Q_sa), "/ Loss ", loss)

    print("Episode finished!")
    print("************************")
|
||||
|
||||
def playGame(args):
    '''Build a fresh network and hand it to the train/run loop.

    Args:
        args: dict of command line options; forwarded to trainNetwork.
    '''
    network = buildmodel()
    trainNetwork(network, args)
|
||||
|
||||
def main():
    '''Parse the command line (-m/--mode is required) and start the game.'''
    cli = argparse.ArgumentParser(description='Description of your program')
    cli.add_argument('-m', '--mode', help='Train / Run', required=True)
    options = vars(cli.parse_args())
    playGame(options)
|
||||
|
||||
if __name__ == "__main__":
|
||||
# CNTK auto detects the GPU and is able to optimally allocate resources
|
||||
# Hence, these lines below are commented out.
|
||||
# from keras import backend as K
|
||||
#if K.backend() == 'tensorflow':
|
||||
# config = tf.ConfigProto()
|
||||
# config.gpu_options.allow_growth = True
|
||||
# sess = tf.Session(config=config)
|
||||
# K.set_session(sess)
|
||||
main()
|
|
@ -0,0 +1,72 @@
|
|||
# Flapping Bird using Keras and Reinforcement Learning
|
||||
|
||||
In [CNTK 203 tutorial](https://github.com/Microsoft/CNTK/blob/master/Tutorials/CNTK_203_Reinforcement_Learning_Basics.ipynb),
|
||||
we have introduced the basic concepts of reinforcement
|
||||
learning. In this example, we show an easy way to train a popular game called
|
||||
FlappyBird using Deep Q Network (DQN). This tutorial draws heavily on the
|
||||
[original work](https://yanpanlau.github.io/2016/07/10/FlappyBird-Keras.html)
|
||||
by Ben Lau on training the FlappyBird game with Keras frontend. This tutorial
|
||||
uses the CNTK backend and with very little change (commenting out a few specific
|
||||
references to TensorFlow) in the original code.
|
||||
|
||||
Note: Since we have replaced the game environment components with different components drawn
|
||||
from public data sources, we call the game Flapping Bird.
|
||||
|
||||
# Goals
|
||||
|
||||
The key objective behind this example is to show how to:
|
||||
- Use CNTK backend API with Keras frontend
|
||||
- Interchangeably use models trained between TensorFlow and CNTK via Keras
|
||||
- Train and test (evaluate) the flapping bird game using a simple DQN implementation.
|
||||
|
||||
# Pre-requisite
|
||||
|
||||
This example assumes you have installed CNTK, installed Keras and configured the Keras
|
||||
backend to be CNTK. The details are [here](https://docs.microsoft.com/en-us/cognitive-toolkit/using-cntk-with-keras).
|
||||
|
||||
This example takes additional dependency on the following Python packages:
|
||||
- pygame (`pip install pygame`)
|
||||
- scikit-learn (`conda install scikit-learn`)
|
||||
- scikit-image (`conda install scikit-image`)
|
||||
|
||||
These packages are needed to perform image manipulation operations and interactivity
|
||||
of the RL agent with the game environment.
|
||||
|
||||
# How to run?
|
||||
|
||||
From the example directory, run:
|
||||
|
||||
```
|
||||
python FlappingBird_with_keras_DQN.py -m Run
|
||||
```
|
||||
|
||||
Note: if you run the game for the first time in "Run" mode, a pre-trained model is
|
||||
locally downloaded. Note, even though this model was trained with TensorFlow,
|
||||
Keras takes care of saving and loading in a portable format. This allows for
|
||||
model evaluation with CNTK.
|
||||
|
||||
If you want to train the network from the beginning, delete the locally cached
|
||||
`FlappingBird_model.h5` (if you have a locally cached file) and run. After
|
||||
training, the trained model can be evaluated with CNTK or TensorFlow.
|
||||
|
||||
```
|
||||
python FlappingBird_with_keras_DQN.py -m Train
|
||||
```
|
||||
|
||||
# Brief recap
|
||||
|
||||
The code has 4 steps:
|
||||
|
||||
1. Receive the image of the game screen as pixel array
|
||||
2. Process the image
|
||||
3. Use a Deep Convolutional Neural Network (CNN) to predict the best action
|
||||
(flap or do nothing)
|
||||
4. Train the network (millions of times) to maximize flying time
|
||||
|
||||
Details can be found in Ben Lau's
|
||||
[original work](https://yanpanlau.github.io/2016/07/10/FlappyBird-Keras.html)
|
||||
|
||||
# Acknowledgements
|
||||
- Ben Lau: Developer of Keras RL example for contributing the [code](https://yanpanlau.github.io/2016/07/10/FlappyBird-Keras.html) with TensorFlow backend.
|
||||
- Bird sprite from [Open Game Art](https://opengameart.org/content/free-game-asset-grumpy-flappy-bird-sprite-sheets).
|
||||
- Shreyaan Pathak: Seventh Grade student at Northstar Middle School, Kirkland, WA for creating and processing new game sprites.
|
|
@ -0,0 +1,11 @@
|
|||
# ==============================================================================
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
# Licensed under the MIT license. See LICENSE.md file in the project root
|
||||
# for full license information.
|
||||
# ==============================================================================
|
||||
|
||||
'''
|
||||
CNTK Flappybird Game Env.
|
||||
'''
|
||||
# TODO: Remove * import
|
||||
from .game import *
|
Двоичные данные
Examples/ReinforcementLearning/FlappingBirdWithKeras/assets/sprites/background-black.png
Normal file
Двоичные данные
Examples/ReinforcementLearning/FlappingBirdWithKeras/assets/sprites/background-black.png
Normal file
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 3.9 KiB |
Двоичные данные
Examples/ReinforcementLearning/FlappingBirdWithKeras/assets/sprites/base.png
Normal file
Двоичные данные
Examples/ReinforcementLearning/FlappingBirdWithKeras/assets/sprites/base.png
Normal file
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 664 B |
Двоичные данные
Examples/ReinforcementLearning/FlappingBirdWithKeras/assets/sprites/cactus-green.png
Normal file
Двоичные данные
Examples/ReinforcementLearning/FlappingBirdWithKeras/assets/sprites/cactus-green.png
Normal file
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 16 KiB |
Двоичные данные
Examples/ReinforcementLearning/FlappingBirdWithKeras/assets/sprites/newbird-downflap.png
Normal file
Двоичные данные
Examples/ReinforcementLearning/FlappingBirdWithKeras/assets/sprites/newbird-downflap.png
Normal file
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 1.6 KiB |
Двоичные данные
Examples/ReinforcementLearning/FlappingBirdWithKeras/assets/sprites/newbird-midflap.png
Normal file
Двоичные данные
Examples/ReinforcementLearning/FlappingBirdWithKeras/assets/sprites/newbird-midflap.png
Normal file
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 1.5 KiB |
Двоичные данные
Examples/ReinforcementLearning/FlappingBirdWithKeras/assets/sprites/newbird-upflap.png
Normal file
Двоичные данные
Examples/ReinforcementLearning/FlappingBirdWithKeras/assets/sprites/newbird-upflap.png
Normal file
Двоичный файл не отображается.
После Ширина: | Высота: | Размер: 1.7 KiB |
|
@ -0,0 +1,9 @@
|
|||
# ==============================================================================
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
# Licensed under the MIT license. See LICENSE.md file in the project root
|
||||
# for full license information.
|
||||
# ==============================================================================
|
||||
|
||||
'''
|
||||
CNTK Flappybird Game Env.
|
||||
'''
|
|
@ -0,0 +1,89 @@
|
|||
import pygame
|
||||
import sys
|
||||
def load():
    """Load the game sprites and build their collision hitmasks.

    Sprite paths are relative to the current working directory. Sounds and
    score-digit sprites are intentionally not loaded in this example, so the
    returned SOUNDS dict is empty and IMAGES has no 'numbers' entry.

    Returns:
        A tuple (IMAGES, SOUNDS, HITMASKS) of dicts. IMAGES holds the 'base',
        'background', 'player' (three animation frames) and 'pipe'
        (rotated upper pipe, lower pipe) surfaces; HITMASKS holds the
        matching 'pipe' and 'player' hitmasks.
    """
    # path of player with different states
    PLAYER_PATH = (
        'assets/sprites/newbird-upflap.png',
        'assets/sprites/newbird-midflap.png',
        'assets/sprites/newbird-downflap.png'
    )

    # path of background
    BACKGROUND_PATH = 'assets/sprites/background-black.png'

    # path of pipe
    PIPE_PATH = 'assets/sprites/cactus-green.png'

    IMAGES, SOUNDS, HITMASKS = {}, {}, {}

    # base (ground) sprite
    IMAGES['base'] = pygame.image.load('assets/sprites/base.png').convert_alpha()

    # background sprite
    IMAGES['background'] = pygame.image.load(BACKGROUND_PATH).convert()

    # player sprites, one per animation frame
    IMAGES['player'] = tuple(
        pygame.image.load(path).convert_alpha() for path in PLAYER_PATH
    )

    # pipe sprites: the upper pipe is the lower pipe rotated by 180 degrees,
    # so the image only needs to be read from disk once
    lower_pipe = pygame.image.load(PIPE_PATH).convert_alpha()
    IMAGES['pipe'] = (
        pygame.transform.rotate(lower_pipe, 180),
        lower_pipe,
    )

    # hitmask for pipes
    HITMASKS['pipe'] = tuple(getHitmask(sprite) for sprite in IMAGES['pipe'])

    # hitmask for player
    HITMASKS['player'] = tuple(getHitmask(sprite) for sprite in IMAGES['player'])

    return IMAGES, SOUNDS, HITMASKS
|
||||
|
||||
def getHitmask(image):
    """Return a hitmask built from an image's alpha channel.

    The result is indexed as mask[x][y] and is True wherever the pixel's
    alpha value (fourth component of `image.get_at`) is non-zero.
    """
    columns = range(image.get_width())
    return [
        [bool(image.get_at((col, row))[3]) for row in range(image.get_height())]
        for col in columns
    ]
|
|
@ -0,0 +1,227 @@
|
|||
import numpy as np
|
||||
import sys
|
||||
import random
|
||||
|
||||
from game import flappy_bird_utils
|
||||
import pygame
|
||||
import pygame.surfarray as surfarray
|
||||
from pygame.locals import *
|
||||
from itertools import cycle
|
||||
|
||||
FPS = 30
|
||||
SCREENWIDTH = 288
|
||||
SCREENHEIGHT = 512
|
||||
|
||||
pygame.init()
|
||||
FPSCLOCK = pygame.time.Clock()
|
||||
SCREEN = pygame.display.set_mode((SCREENWIDTH, SCREENHEIGHT))
|
||||
pygame.display.set_caption('Flappy Bird')
|
||||
|
||||
IMAGES, SOUNDS, HITMASKS = flappy_bird_utils.load()
|
||||
PIPEGAPSIZE = 100 # gap between upper and lower part of pipe
|
||||
BASEY = SCREENHEIGHT * 0.79
|
||||
|
||||
PLAYER_WIDTH = IMAGES['player'][0].get_width()
|
||||
PLAYER_HEIGHT = IMAGES['player'][0].get_height()
|
||||
PIPE_WIDTH = IMAGES['pipe'][0].get_width()
|
||||
PIPE_HEIGHT = IMAGES['pipe'][0].get_height()
|
||||
BACKGROUND_WIDTH = IMAGES['background'].get_width()
|
||||
|
||||
PLAYER_INDEX_GEN = cycle([0, 1, 2, 1])
|
||||
|
||||
|
||||
class GameState:
    """State of one running game plus the logic that advances it a frame."""

    def __init__(self):
        """Place the bird and the first two pipe pairs and reset the physics."""
        self.score = self.playerIndex = self.loopIter = 0
        self.playerx = int(SCREENWIDTH * 0.2)
        self.playery = int((SCREENHEIGHT - PLAYER_HEIGHT) / 2)
        self.basex = 0
        self.baseShift = IMAGES['base'].get_width() - BACKGROUND_WIDTH

        first_pair = getRandomPipe()
        second_pair = getRandomPipe()
        first_x = SCREENWIDTH
        second_x = SCREENWIDTH + (SCREENWIDTH / 2)
        self.upperPipes = [
            {'x': first_x, 'y': first_pair[0]['y']},
            {'x': second_x, 'y': second_pair[0]['y']},
        ]
        self.lowerPipes = [
            {'x': first_x, 'y': first_pair[1]['y']},
            {'x': second_x, 'y': second_pair[1]['y']},
        ]

        # pipe speed, player velocity limits, downward acceleration and
        # velocity applied on flap
        self.pipeVelX = -4
        self.playerVelY = 0          # player's velocity along Y
        self.playerMaxVelY = 10      # max vel along Y, max descend speed
        self.playerMinVelY = -8      # min vel along Y, max ascend speed
        self.playerAccY = 1          # player's downward acceleration
        self.playerFlapAcc = -9      # player's speed on flapping
        self.playerFlapped = False   # True when player flaps

    def frame_step(self, input_actions):
        """Advance the game by one frame.

        Args:
            input_actions: one-hot action vector; index 0 means "do nothing",
                index 1 means "flap the bird".

        Returns:
            (image_data, reward, terminal): the rendered screen as an array,
            the reward for this frame (0.1 by default, 1 when a pipe is
            passed, -1 on a crash) and whether the bird crashed. After a
            crash the game state is reset.

        Raises:
            ValueError: if the entries of `input_actions` do not sum to 1.
        """
        pygame.event.pump()

        if sum(input_actions) != 1:
            raise ValueError('Multiple input actions!')

        reward = 0.1
        terminal = False

        # input_actions[0] == 1: do nothing
        # input_actions[1] == 1: flap the bird
        if input_actions[1] == 1 and self.playery > -2 * PLAYER_HEIGHT:
            self.playerVelY = self.playerFlapAcc
            self.playerFlapped = True

        # check for score: the bird's midpoint has just passed a pipe midpoint
        player_mid = self.playerx + PLAYER_WIDTH / 2
        for upper in self.upperPipes:
            pipe_mid = upper['x'] + PIPE_WIDTH / 2
            if pipe_mid <= player_mid < pipe_mid + 4:
                self.score += 1
                reward = 1

        # animation frame and base scrolling
        next_iter = self.loopIter + 1
        if next_iter % 3 == 0:
            self.playerIndex = next(PLAYER_INDEX_GEN)
        self.loopIter = next_iter % 30
        self.basex = -((-self.basex + 100) % self.baseShift)

        # player's movement: gravity applies unless the bird flapped this frame
        if not self.playerFlapped and self.playerVelY < self.playerMaxVelY:
            self.playerVelY += self.playerAccY
        self.playerFlapped = False
        # never move the bird below the top of the base
        self.playery += min(self.playerVelY, BASEY - self.playery - PLAYER_HEIGHT)
        if self.playery < 0:
            self.playery = 0

        # move pipes to left
        for upper, lower in zip(self.upperPipes, self.lowerPipes):
            upper['x'] += self.pipeVelX
            lower['x'] += self.pipeVelX

        # add new pipe when first pipe is about to touch left of screen
        if 0 < self.upperPipes[0]['x'] < 5:
            fresh_upper, fresh_lower = getRandomPipe()
            self.upperPipes.append(fresh_upper)
            self.lowerPipes.append(fresh_lower)

        # remove first pipe if it is out of the screen
        if self.upperPipes[0]['x'] < -PIPE_WIDTH:
            self.upperPipes.pop(0)
            self.lowerPipes.pop(0)

        # check if crash here
        bird = {'x': self.playerx, 'y': self.playery, 'index': self.playerIndex}
        if checkCrash(bird, self.upperPipes, self.lowerPipes):
            terminal = True
            self.__init__()  # restart the game from its initial state
            reward = -1

        # draw sprites
        SCREEN.blit(IMAGES['background'], (0, 0))

        for upper, lower in zip(self.upperPipes, self.lowerPipes):
            SCREEN.blit(IMAGES['pipe'][0], (upper['x'], upper['y']))
            SCREEN.blit(IMAGES['pipe'][1], (lower['x'], lower['y']))

        SCREEN.blit(IMAGES['base'], (self.basex, BASEY))
        SCREEN.blit(IMAGES['player'][self.playerIndex],
                    (self.playerx, self.playery))

        # grab the rendered frame before flipping the display
        image_data = pygame.surfarray.array3d(pygame.display.get_surface())
        pygame.display.update()
        FPSCLOCK.tick(FPS)
        return image_data, reward, terminal
|
||||
|
||||
def getRandomPipe():
    """Return a randomly generated [upper, lower] pipe pair.

    Both pipes share an x position 10 pixels right of the screen edge; the
    vertical gap between them is PIPEGAPSIZE high.
    """
    # candidate y offsets of the gap between upper and lower pipe
    gap_offsets = [20, 30, 40, 50, 60, 70, 80, 90]
    picked = random.randint(0, len(gap_offsets) - 1)
    gap_top = gap_offsets[picked] + int(BASEY * 0.2)

    spawn_x = SCREENWIDTH + 10
    upper_pipe = {'x': spawn_x, 'y': gap_top - PIPE_HEIGHT}
    lower_pipe = {'x': spawn_x, 'y': gap_top + PIPEGAPSIZE}
    return [upper_pipe, lower_pipe]
|
||||
|
||||
|
||||
def showScore(score):
    """Draw the score, horizontally centred, near the top of the screen.

    Reads the digit sprites from IMAGES['numbers'].
    """
    digit_sprites = [IMAGES['numbers'][int(ch)] for ch in str(score)]

    # total width of all numbers to be printed
    total_width = 0
    for sprite in digit_sprites:
        total_width += sprite.get_width()

    x_offset = (SCREENWIDTH - total_width) / 2
    for sprite in digit_sprites:
        SCREEN.blit(sprite, (x_offset, SCREENHEIGHT * 0.1))
        x_offset += sprite.get_width()
|
||||
|
||||
|
||||
def checkCrash(player, upperPipes, lowerPipes):
    """Return True if the player collides with the base or any pipe.

    `player` is a dict with 'x', 'y' and 'index' (animation frame); its 'w'
    and 'h' entries are set here from the first player sprite.
    """
    frame = player['index']
    player['w'] = IMAGES['player'][0].get_width()
    player['h'] = IMAGES['player'][0].get_height()

    # the player crashes into the ground
    if player['y'] + player['h'] >= BASEY - 1:
        return True

    bird_rect = pygame.Rect(player['x'], player['y'],
                            player['w'], player['h'])

    for upper, lower in zip(upperPipes, lowerPipes):
        bird_mask = HITMASKS['player'][frame]

        # pixel-accurate test against the upper pipe, then the lower pipe
        upper_rect = pygame.Rect(upper['x'], upper['y'], PIPE_WIDTH, PIPE_HEIGHT)
        lower_rect = pygame.Rect(lower['x'], lower['y'], PIPE_WIDTH, PIPE_HEIGHT)
        hit_upper = pixelCollision(bird_rect, upper_rect, bird_mask, HITMASKS['pipe'][0])
        hit_lower = pixelCollision(bird_rect, lower_rect, bird_mask, HITMASKS['pipe'][1])
        if hit_upper or hit_lower:
            return True

    return False
|
||||
|
||||
def pixelCollision(rect1, rect2, hitmask1, hitmask2):
    """Check whether two objects overlap pixel-wise, not just by their rects.

    The hitmasks are indexed as mask[x][y] relative to each rect's top-left
    corner; a collision needs a position inside the rects' intersection that
    is set in both masks.
    """
    overlap = rect1.clip(rect2)
    if not (overlap.width and overlap.height):
        return False

    # offsets of the overlap region inside each rect
    ox1, oy1 = overlap.x - rect1.x, overlap.y - rect1.y
    ox2, oy2 = overlap.x - rect2.x, overlap.y - rect2.y

    return any(
        hitmask1[ox1 + dx][oy1 + dy] and hitmask2[ox2 + dx][oy2 + dy]
        for dx in range(overlap.width)
        for dy in range(overlap.height)
    )
|
|
@ -16,15 +16,19 @@ dependencies:
|
|||
- pip=8.1.2=py27_0
|
||||
- python=2.7.11=5
|
||||
- pyyaml=3.12=py27_0
|
||||
- scikit-image=0.12.3=np111py27_1
|
||||
- scikit-learn=0.18.1=np111py27_0
|
||||
- scipy=0.18.1=np111py27_0
|
||||
- seaborn=0.7.1=py27_0
|
||||
- setuptools=27.2.0=py27_0
|
||||
- six=1.10.0=py27_0
|
||||
- wheel=0.29.0=py27_0
|
||||
- pip:
|
||||
- pytest==3.0.3
|
||||
- sphinx==1.5.4
|
||||
- sphinx-rtd-theme==0.2.4
|
||||
- twine==1.8.1
|
||||
- gym[atari]==0.8.1
|
||||
- keras==2.0.6
|
||||
- pydot-ng==1.0.0
|
||||
- pygame==1.9.3
|
||||
- pytest==3.0.3
|
||||
- sphinx-rtd-theme==0.2.4
|
||||
- sphinx==1.5.4
|
||||
- twine==1.8.1
|
||||
|
|
|
@ -9,25 +9,28 @@ dependencies:
|
|||
- jupyter=1.0.0=py34_3
|
||||
- matplotlib=1.5.3=np111py34_0
|
||||
- numpy=1.11.2=py34_0
|
||||
- opencv=3.1.0=np111py34_1
|
||||
- pandas=0.19.1=np111py34_0
|
||||
- pandas-datareader=0.2.1=py34_0
|
||||
- pillow=3.4.2=py34_0
|
||||
- pip=8.1.2=py34_0
|
||||
- python=3.4.4=5
|
||||
- pyyaml=3.12=py34_0
|
||||
- scikit-image=0.12.3=np111py34_1
|
||||
- scikit-learn=0.18.1=np111py34_0
|
||||
- scipy=0.18.1=np111py34_0
|
||||
- seaborn=0.7.1=py34_0
|
||||
- setuptools=27.2.0=py34_0
|
||||
- six=1.10.0=py34_0
|
||||
- wheel=0.29.0=py34_0
|
||||
- opencv=3.1.0=np111py34_1
|
||||
- pip:
|
||||
- pytest==3.0.3
|
||||
- sphinx==1.5.4
|
||||
- sphinx-rtd-theme==0.2.4
|
||||
- twine==1.8.1
|
||||
- gym[atari]==0.8.1
|
||||
- pydot-ng==1.0.0
|
||||
- future==0.16.0
|
||||
- easydict==1.6.0
|
||||
- scikit-image==0.12.3
|
||||
- future==0.16.0
|
||||
- gym[atari]==0.8.1
|
||||
- keras==2.0.6
|
||||
- pydot-ng==1.0.0
|
||||
- pygame==1.9.3
|
||||
- pytest==3.0.3
|
||||
- sphinx-rtd-theme==0.2.4
|
||||
- sphinx==1.5.4
|
||||
- twine==1.8.1
|
||||
|
|
|
@ -15,15 +15,19 @@ dependencies:
|
|||
- pip=8.1.2=py35_0
|
||||
- python=3.5.2=0
|
||||
- pyyaml=3.12=py35_0
|
||||
- scikit-image=0.12.3=np111py35_1
|
||||
- scikit-learn=0.18.1=np111py35_0
|
||||
- scipy=0.18.1=np111py35_0
|
||||
- seaborn=0.7.1=py35_0
|
||||
- setuptools=27.2.0=py35_0
|
||||
- six=1.10.0=py35_0
|
||||
- wheel=0.29.0=py35_0
|
||||
- pip:
|
||||
- pytest==3.0.3
|
||||
- sphinx==1.5.4
|
||||
- sphinx-rtd-theme==0.2.4
|
||||
- twine==1.8.1
|
||||
- gym[atari]==0.8.1
|
||||
- keras==2.0.6
|
||||
- pydot-ng==1.0.0
|
||||
- pygame==1.9.3
|
||||
- pytest==3.0.3
|
||||
- sphinx-rtd-theme==0.2.4
|
||||
- sphinx==1.5.4
|
||||
- twine==1.8.1
|
||||
|
|
|
@ -15,15 +15,19 @@ dependencies:
|
|||
- pip=9.0.1=py36_1
|
||||
- python=3.6.0=0
|
||||
- pyyaml=3.12=py36_0
|
||||
- scikit-image=0.12.3=np111py36_1
|
||||
- scikit-learn=0.18.1=np111py36_0
|
||||
- scipy=0.18.1=np111py36_0
|
||||
- seaborn=0.7.1=py36_0
|
||||
- setuptools=27.2.0=py36_0
|
||||
- six=1.10.0=py36_0
|
||||
- wheel=0.29.0=py36_0
|
||||
- pip:
|
||||
- pytest==3.0.3
|
||||
- sphinx==1.5.4
|
||||
- sphinx-rtd-theme==0.2.4
|
||||
- twine==1.8.1
|
||||
- gym[atari]==0.8.1
|
||||
- keras==2.0.6
|
||||
- pydot-ng==1.0.0
|
||||
- pygame==1.9.3
|
||||
- pytest==3.0.3
|
||||
- sphinx-rtd-theme==0.2.4
|
||||
- sphinx==1.5.4
|
||||
- twine==1.8.1
|
||||
|
|
|
@ -16,15 +16,19 @@ dependencies:
|
|||
- pip=8.1.2=py27_0
|
||||
- python=2.7.11=5
|
||||
- pyyaml=3.12=py27_0
|
||||
- scikit-image=0.12.3=np111py27_1
|
||||
- scikit-learn=0.18.1=np111py27_0
|
||||
- scipy=0.18.1=np111py27_0
|
||||
- seaborn=0.7.1=py27_0
|
||||
- setuptools=27.2.0=py27_1
|
||||
- six=1.10.0=py27_0
|
||||
- wheel=0.29.0=py27_0
|
||||
- pip:
|
||||
- pytest==3.0.3
|
||||
- sphinx==1.5.4
|
||||
- sphinx-rtd-theme==0.2.4
|
||||
- twine==1.8.1
|
||||
- gym==0.5.2
|
||||
- keras==2.0.6
|
||||
- pydot-ng==1.0.0
|
||||
- pygame==1.9.3
|
||||
- pytest==3.0.3
|
||||
- sphinx-rtd-theme==0.2.4
|
||||
- sphinx==1.5.4
|
||||
- twine==1.8.1
|
||||
|
|
|
@ -15,15 +15,19 @@ dependencies:
|
|||
- pip=8.1.2=py34_0
|
||||
- python=3.4.4=5
|
||||
- pyyaml=3.12=py34_0
|
||||
- scikit-image=0.12.3=np111py34_1
|
||||
- scikit-learn=0.18.1=np111py34_0
|
||||
- scipy=0.18.1=np111py34_0
|
||||
- seaborn=0.7.1=py34_0
|
||||
- setuptools=27.2.0=py34_1
|
||||
- six=1.10.0=py34_0
|
||||
- wheel=0.29.0=py34_0
|
||||
- pip:
|
||||
- pytest==3.0.3
|
||||
- sphinx==1.5.4
|
||||
- sphinx-rtd-theme==0.2.4
|
||||
- twine==1.8.1
|
||||
- gym==0.5.2
|
||||
- keras==2.0.6
|
||||
- pydot-ng==1.0.0
|
||||
- pygame==1.9.3
|
||||
- pytest==3.0.3
|
||||
- sphinx-rtd-theme==0.2.4
|
||||
- sphinx==1.5.4
|
||||
- twine==1.8.1
|
||||
|
|
|
@ -15,15 +15,19 @@ dependencies:
|
|||
- pip=8.1.2=py35_0
|
||||
- python=3.5.2=0
|
||||
- pyyaml=3.12=py35_0
|
||||
- scikit-image=0.12.3=np111py35_1
|
||||
- scikit-learn=0.18.1=np111py35_0
|
||||
- scipy=0.18.1=np111py35_0
|
||||
- seaborn=0.7.1=py35_0
|
||||
- setuptools=27.2.0=py35_1
|
||||
- six=1.10.0=py35_0
|
||||
- wheel=0.29.0=py35_0
|
||||
- pip:
|
||||
- pytest==3.0.3
|
||||
- sphinx==1.5.4
|
||||
- sphinx-rtd-theme==0.2.4
|
||||
- twine==1.8.1
|
||||
- gym==0.5.2
|
||||
- keras==2.0.6
|
||||
- pydot-ng==1.0.0
|
||||
- pygame==1.9.3
|
||||
- pytest==3.0.3
|
||||
- sphinx-rtd-theme==0.2.4
|
||||
- sphinx==1.5.4
|
||||
- twine==1.8.1
|
||||
|
|
|
@ -15,15 +15,19 @@ dependencies:
|
|||
- pip=9.0.1=py36_1
|
||||
- python=3.6.0=0
|
||||
- pyyaml=3.12=py36_0
|
||||
- scikit-image=0.12.3=np111py36_1
|
||||
- scikit-learn=0.18.1=np111py36_0
|
||||
- scipy=0.18.1=np111py36_0
|
||||
- seaborn=0.7.1=py36_0
|
||||
- setuptools=27.2.0=py36_1
|
||||
- six=1.10.0=py36_0
|
||||
- wheel=0.29.0=py36_0
|
||||
- pip:
|
||||
- pytest==3.0.3
|
||||
- sphinx==1.5.4
|
||||
- sphinx-rtd-theme==0.2.4
|
||||
- twine==1.8.1
|
||||
- gym==0.5.2
|
||||
- keras==2.0.6
|
||||
- pydot-ng==1.0.0
|
||||
- pygame==1.9.3
|
||||
- pytest==3.0.3
|
||||
- sphinx-rtd-theme==0.2.4
|
||||
- sphinx==1.5.4
|
||||
- twine==1.8.1
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
# Licensed under the MIT license. See LICENSE.md file in the project root
|
||||
# for full license information.
|
||||
# ==============================================================================
|
||||
|
||||
import numpy as np
|
||||
import os
|
||||
import platform
|
||||
import sys
|
||||
import pytest
|
||||
|
||||
os.environ["SDL_VIDEODRIVER"] = "dummy"
|
||||
os.environ['KERAS_BACKEND'] = "cntk"
|
||||
|
||||
from cntk.device import try_set_default_device, gpu
|
||||
|
||||
abs_path = os.path.dirname(os.path.abspath(__file__))
|
||||
example_dir = os.path.join(abs_path, "..", "..", "..", "..",
|
||||
"Examples", "ReinforcementLearning",
|
||||
"FlappingBirdWithKeras")
|
||||
sys.path.append(example_dir)
|
||||
current_dir = os.getcwd()
|
||||
os.chdir(example_dir)
|
||||
|
||||
def test_FlappingBird_with_keras_DQN_noerror(device_id):
    '''Smoke-test the Flapping Bird example in both 'Run' and 'Train' mode.

    Each mode is started with internal_testing=True, in which case
    trainNetwork is expected to return 0.
    '''
    if platform.system() != 'Windows':
        pytest.skip('Test only runs on Windows, pygame video device requirement constraint')
    from cntk.ops.tests.ops_test_utils import cntk_device
    try_set_default_device(cntk_device(device_id))

    # The module level code has already put example_dir on sys.path and saved
    # the original working directory in `current_dir`. That value must not be
    # re-read here: the process is already inside example_dir, so doing so
    # would make the final restore a no-op.
    os.chdir(example_dir)
    try:
        import FlappingBird_with_keras_DQN as fbgame

        # TODO: Currently the model is downloaded from a cached site
        #       Change the code to pick up the model from a locally
        #       cached directory.
        model = fbgame.buildmodel()
        args = {'mode': 'Run'}
        res = fbgame.trainNetwork(model, args, internal_testing=True)

        np.testing.assert_array_equal(res, 0,
            err_msg='Error in running Flapping Bird example', verbose=True)

        args = {'mode': 'Train'}
        res = fbgame.trainNetwork(model, args, internal_testing=True)

        np.testing.assert_array_equal(res, 0,
            err_msg='Error in training Flapping Bird example', verbose=True)

        # TODO: Add a test case to start with a CNTK trained cached model
    finally:
        # Restore the working directory even when an assertion fails.
        os.chdir(current_dir)
    print("Done")
|
||||
|
||||
|
||||
|
||||
|
|
@ -539,4 +539,13 @@
|
|||
"type": ["Recipe"],
|
||||
"dataadded": "05/05/2017"
|
||||
},
|
||||
{
|
||||
"category": ["Reinforcement Learning"],
|
||||
"name": "Flapping Bird with Keras",
|
||||
"url":"https://github.com/Microsoft/CNTK/tree/master/Examples/ReinforcementLearning/FlappingBirdWithKeras",
|
||||
"description": "Using CNTK Keras backend train an agent to navigate a bird through a cactus maze",
|
||||
"language": ["Python"],
|
||||
"type": ["Recipe"],
|
||||
"dataadded": "07/21/2017"
|
||||
}
|
||||
]
|
||||
|
|
Загрузка…
Ссылка в новой задаче