Mirror of https://github.com/microsoft/logrl.git
project init
This commit is contained in:
Parent
3db228bfb6
Commit
d9b3746b49
|
@@ -1,330 +0,0 @@
|
|||
## Ignore Visual Studio temporary files, build results, and
|
||||
## files generated by popular Visual Studio add-ons.
|
||||
##
|
||||
## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore
|
||||
|
||||
# User-specific files
|
||||
*.suo
|
||||
*.user
|
||||
*.userosscache
|
||||
*.sln.docstates
|
||||
|
||||
# User-specific files (MonoDevelop/Xamarin Studio)
|
||||
*.userprefs
|
||||
|
||||
# Build results
|
||||
[Dd]ebug/
|
||||
[Dd]ebugPublic/
|
||||
[Rr]elease/
|
||||
[Rr]eleases/
|
||||
x64/
|
||||
x86/
|
||||
bld/
|
||||
[Bb]in/
|
||||
[Oo]bj/
|
||||
[Ll]og/
|
||||
|
||||
# Visual Studio 2015/2017 cache/options directory
|
||||
.vs/
|
||||
# Uncomment if you have tasks that create the project's static files in wwwroot
|
||||
#wwwroot/
|
||||
|
||||
# Visual Studio 2017 auto generated files
|
||||
Generated\ Files/
|
||||
|
||||
# MSTest test Results
|
||||
[Tt]est[Rr]esult*/
|
||||
[Bb]uild[Ll]og.*
|
||||
|
||||
# NUNIT
|
||||
*.VisualState.xml
|
||||
TestResult.xml
|
||||
|
||||
# Build Results of an ATL Project
|
||||
[Dd]ebugPS/
|
||||
[Rr]eleasePS/
|
||||
dlldata.c
|
||||
|
||||
# Benchmark Results
|
||||
BenchmarkDotNet.Artifacts/
|
||||
|
||||
# .NET Core
|
||||
project.lock.json
|
||||
project.fragment.lock.json
|
||||
artifacts/
|
||||
**/Properties/launchSettings.json
|
||||
|
||||
# StyleCop
|
||||
StyleCopReport.xml
|
||||
|
||||
# Files built by Visual Studio
|
||||
*_i.c
|
||||
*_p.c
|
||||
*_i.h
|
||||
*.ilk
|
||||
*.meta
|
||||
*.obj
|
||||
*.iobj
|
||||
*.pch
|
||||
*.pdb
|
||||
*.ipdb
|
||||
*.pgc
|
||||
*.pgd
|
||||
*.rsp
|
||||
*.sbr
|
||||
*.tlb
|
||||
*.tli
|
||||
*.tlh
|
||||
*.tmp
|
||||
*.tmp_proj
|
||||
*.log
|
||||
*.vspscc
|
||||
*.vssscc
|
||||
.builds
|
||||
*.pidb
|
||||
*.svclog
|
||||
*.scc
|
||||
|
||||
# Chutzpah Test files
|
||||
_Chutzpah*
|
||||
|
||||
# Visual C++ cache files
|
||||
ipch/
|
||||
*.aps
|
||||
*.ncb
|
||||
*.opendb
|
||||
*.opensdf
|
||||
*.sdf
|
||||
*.cachefile
|
||||
*.VC.db
|
||||
*.VC.VC.opendb
|
||||
|
||||
# Visual Studio profiler
|
||||
*.psess
|
||||
*.vsp
|
||||
*.vspx
|
||||
*.sap
|
||||
|
||||
# Visual Studio Trace Files
|
||||
*.e2e
|
||||
|
||||
# TFS 2012 Local Workspace
|
||||
$tf/
|
||||
|
||||
# Guidance Automation Toolkit
|
||||
*.gpState
|
||||
|
||||
# ReSharper is a .NET coding add-in
|
||||
_ReSharper*/
|
||||
*.[Rr]e[Ss]harper
|
||||
*.DotSettings.user
|
||||
|
||||
# JustCode is a .NET coding add-in
|
||||
.JustCode
|
||||
|
||||
# TeamCity is a build add-in
|
||||
_TeamCity*
|
||||
|
||||
# DotCover is a Code Coverage Tool
|
||||
*.dotCover
|
||||
|
||||
# AxoCover is a Code Coverage Tool
|
||||
.axoCover/*
|
||||
!.axoCover/settings.json
|
||||
|
||||
# Visual Studio code coverage results
|
||||
*.coverage
|
||||
*.coveragexml
|
||||
|
||||
# NCrunch
|
||||
_NCrunch_*
|
||||
.*crunch*.local.xml
|
||||
nCrunchTemp_*
|
||||
|
||||
# MightyMoose
|
||||
*.mm.*
|
||||
AutoTest.Net/
|
||||
|
||||
# Web workbench (sass)
|
||||
.sass-cache/
|
||||
|
||||
# Installshield output folder
|
||||
[Ee]xpress/
|
||||
|
||||
# DocProject is a documentation generator add-in
|
||||
DocProject/buildhelp/
|
||||
DocProject/Help/*.HxT
|
||||
DocProject/Help/*.HxC
|
||||
DocProject/Help/*.hhc
|
||||
DocProject/Help/*.hhk
|
||||
DocProject/Help/*.hhp
|
||||
DocProject/Help/Html2
|
||||
DocProject/Help/html
|
||||
|
||||
# Click-Once directory
|
||||
publish/
|
||||
|
||||
# Publish Web Output
|
||||
*.[Pp]ublish.xml
|
||||
*.azurePubxml
|
||||
# Note: Comment the next line if you want to checkin your web deploy settings,
|
||||
# but database connection strings (with potential passwords) will be unencrypted
|
||||
*.pubxml
|
||||
*.publishproj
|
||||
|
||||
# Microsoft Azure Web App publish settings. Comment the next line if you want to
|
||||
# checkin your Azure Web App publish settings, but sensitive information contained
|
||||
# in these scripts will be unencrypted
|
||||
PublishScripts/
|
||||
|
||||
# NuGet Packages
|
||||
*.nupkg
|
||||
# The packages folder can be ignored because of Package Restore
|
||||
**/[Pp]ackages/*
|
||||
# except build/, which is used as an MSBuild target.
|
||||
!**/[Pp]ackages/build/
|
||||
# Uncomment if necessary however generally it will be regenerated when needed
|
||||
#!**/[Pp]ackages/repositories.config
|
||||
# NuGet v3's project.json files produces more ignorable files
|
||||
*.nuget.props
|
||||
*.nuget.targets
|
||||
|
||||
# Microsoft Azure Build Output
|
||||
csx/
|
||||
*.build.csdef
|
||||
|
||||
# Microsoft Azure Emulator
|
||||
ecf/
|
||||
rcf/
|
||||
|
||||
# Windows Store app package directories and files
|
||||
AppPackages/
|
||||
BundleArtifacts/
|
||||
Package.StoreAssociation.xml
|
||||
_pkginfo.txt
|
||||
*.appx
|
||||
|
||||
# Visual Studio cache files
|
||||
# files ending in .cache can be ignored
|
||||
*.[Cc]ache
|
||||
# but keep track of directories ending in .cache
|
||||
!*.[Cc]ache/
|
||||
|
||||
# Others
|
||||
ClientBin/
|
||||
~$*
|
||||
*~
|
||||
*.dbmdl
|
||||
*.dbproj.schemaview
|
||||
*.jfm
|
||||
*.pfx
|
||||
*.publishsettings
|
||||
orleans.codegen.cs
|
||||
|
||||
# Including strong name files can present a security risk
|
||||
# (https://github.com/github/gitignore/pull/2483#issue-259490424)
|
||||
#*.snk
|
||||
|
||||
# Since there are multiple workflows, uncomment next line to ignore bower_components
|
||||
# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
|
||||
#bower_components/
|
||||
|
||||
# RIA/Silverlight projects
|
||||
Generated_Code/
|
||||
|
||||
# Backup & report files from converting an old project file
|
||||
# to a newer Visual Studio version. Backup files are not needed,
|
||||
# because we have git ;-)
|
||||
_UpgradeReport_Files/
|
||||
Backup*/
|
||||
UpgradeLog*.XML
|
||||
UpgradeLog*.htm
|
||||
ServiceFabricBackup/
|
||||
*.rptproj.bak
|
||||
|
||||
# SQL Server files
|
||||
*.mdf
|
||||
*.ldf
|
||||
*.ndf
|
||||
|
||||
# Business Intelligence projects
|
||||
*.rdl.data
|
||||
*.bim.layout
|
||||
*.bim_*.settings
|
||||
*.rptproj.rsuser
|
||||
|
||||
# Microsoft Fakes
|
||||
FakesAssemblies/
|
||||
|
||||
# GhostDoc plugin setting file
|
||||
*.GhostDoc.xml
|
||||
|
||||
# Node.js Tools for Visual Studio
|
||||
.ntvs_analysis.dat
|
||||
node_modules/
|
||||
|
||||
# Visual Studio 6 build log
|
||||
*.plg
|
||||
|
||||
# Visual Studio 6 workspace options file
|
||||
*.opt
|
||||
|
||||
# Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
|
||||
*.vbw
|
||||
|
||||
# Visual Studio LightSwitch build output
|
||||
**/*.HTMLClient/GeneratedArtifacts
|
||||
**/*.DesktopClient/GeneratedArtifacts
|
||||
**/*.DesktopClient/ModelManifest.xml
|
||||
**/*.Server/GeneratedArtifacts
|
||||
**/*.Server/ModelManifest.xml
|
||||
_Pvt_Extensions
|
||||
|
||||
# Paket dependency manager
|
||||
.paket/paket.exe
|
||||
paket-files/
|
||||
|
||||
# FAKE - F# Make
|
||||
.fake/
|
||||
|
||||
# JetBrains Rider
|
||||
.idea/
|
||||
*.sln.iml
|
||||
|
||||
# CodeRush
|
||||
.cr/
|
||||
|
||||
# Python Tools for Visual Studio (PTVS)
|
||||
__pycache__/
|
||||
*.pyc
|
||||
|
||||
# Cake - Uncomment if you are using it
|
||||
# tools/**
|
||||
# !tools/packages.config
|
||||
|
||||
# Tabs Studio
|
||||
*.tss
|
||||
|
||||
# Telerik's JustMock configuration file
|
||||
*.jmconfig
|
||||
|
||||
# BizTalk build output
|
||||
*.btp.cs
|
||||
*.btm.cs
|
||||
*.odx.cs
|
||||
*.xsd.cs
|
||||
|
||||
# OpenCover UI analysis results
|
||||
OpenCover/
|
||||
|
||||
# Azure Stream Analytics local run output
|
||||
ASALocalRun/
|
||||
|
||||
# MSBuild Binary and Structured Log
|
||||
*.binlog
|
||||
|
||||
# NVidia Nsight GPU debugger configuration file
|
||||
*.nvuser
|
||||
|
||||
# MFractors (Xamarin productivity tool) working folder
|
||||
.mfractor/
|
127
README.md
|
@@ -1,64 +1,103 @@
|
|||
---
|
||||
page_type: sample
|
||||
languages:
|
||||
- csharp
|
||||
products:
|
||||
- dotnet
|
||||
description: "Add 150 character max description"
|
||||
urlFragment: "update-this-to-unique-url-stub"
|
||||
---
|
||||
# Logarithmic Reinforcement Learning
|
||||
|
||||
# Official Microsoft Sample
|
||||
This repository hosts sample code for the NeurIPS 2019 paper: [van Seijen, Fatemi, Tavakoli (2019)][log_rl].
|
||||
|
||||
<!--
|
||||
Guidelines on README format: https://review.docs.microsoft.com/help/onboard/admin/samples/concepts/readme-template?branch=master
|
||||
We provide code for the linear experiments of the paper as well as the deep RL Atari examples (LogDQN).
|
||||
|
||||
Guidance on onboarding samples to docs.microsoft.com/samples: https://review.docs.microsoft.com/help/onboard/admin/samples/process/onboarding?branch=master
|
||||
## For the license please see LICENSE.
|
||||
|
||||
Taxonomies for products and languages: https://review.docs.microsoft.com/new-hope/information-architecture/metadata/taxonomies?branch=master
|
||||
-->
|
||||
The code for LogDQN has been developed by [Arash Tavakoli] and the code for the linear experiments has been developed by [Harm van Seijen].
|
||||
|
||||
Give a short description for your sample here. What does it do and why is it important?
|
||||
## Citing
|
||||
|
||||
## Contents
|
||||
If you use this research in your work, please cite the accompanying [paper][log_rl]:
|
||||
|
||||
Outline the file contents of the repository. It helps users navigate the codebase, build configuration and any related assets.
|
||||
```
|
||||
@inproceedings{vanseijen2019logrl,
|
||||
title={Using a Logarithmic Mapping to Enable Lower Discount Factors in Reinforcement Learning},
|
||||
author={van Seijen, Harm and
|
||||
Fatemi, Mehdi and
|
||||
Tavakoli, Arash},
|
||||
booktitle={Advances in Neural Information Processing Systems},
|
||||
year={2019}
|
||||
}
|
||||
```
|
||||
|
||||
| File/folder | Description |
|
||||
|-------------------|--------------------------------------------|
|
||||
| `src` | Sample source code. |
|
||||
| `.gitignore` | Define what to ignore at commit time. |
|
||||
| `CHANGELOG.md` | List of changes to the sample. |
|
||||
| `CONTRIBUTING.md` | Guidelines for contributing to the sample. |
|
||||
| `README.md` | This README file. |
|
||||
| `LICENSE` | The license for the sample. |
|
||||
## Linear Experiments
|
||||
|
||||
## Prerequisites
|
||||
First navigate to the `linear_experiments` folder.
|
||||
|
||||
Outline the required components and tools that a user might need to have on their machine in order to run the sample. This can be anything from frameworks, SDKs, OS versions or IDE releases.
|
||||
To create result-files:
|
||||
```
|
||||
python main.py
|
||||
```
|
||||
To visualize result-files:
|
||||
```
|
||||
python show_results.py
|
||||
```
|
||||
|
||||
## Setup
|
||||
With the default settings (i.e., keeping `main.py` unchanged), a scan over different gamma values is performed with a tile-width of 2, using a version of Q-learning without a logarithmic mapping.
|
||||
|
||||
Explain how to prepare the sample once the user clones or downloads the repository. The section should outline every step necessary to install dependencies and set up any settings (for example, API keys and output folders).
|
||||
All experimental settings can be found at the top of the `main.py` file.
|
||||
To run the logarithmic-mapping version of Q-learning, set:
|
||||
```
|
||||
agent_settings['log_mapping'] = True
|
||||
```
|
||||
|
||||
## Running the sample
|
||||
Results of the full scans are provided. To visualize these results for regular Q-learning or logarithmic Q-learning, set `filename` in `show_results.py` to `full_scan_reg` or `full_scan_log`, respectively.
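For example, switching the plots to the logarithmic results is a one-line change near the top of `show_results.py` (the variable is defined there, as shown later in this commit):

```
# in show_results.py
filename = 'full_scan_log'   # use 'full_scan_reg' for regular Q-learning
```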
|
||||
|
||||
Outline step-by-step instructions to execute the sample and see its output. Include steps for executing the sample from the IDE, starting specific services in the Azure portal or anything related to the overall launch of the code.
|
||||
|
||||
## Key concepts
|
||||
## Logarithmic Deep Q-Network (LogDQN)
|
||||
|
||||
Provide users with more context on the tools and services used in the sample. Explain some of the code that is being used and how services interact with each other.
|
||||
This part presents an implementation of LogDQN from [van Seijen, Fatemi, Tavakoli (2019)][log_rl].
|
||||
|
||||
## Contributing
|
||||
### Instructions
|
||||
|
||||
This project welcomes contributions and suggestions. Most contributions require you to agree to a
|
||||
Contributor License Agreement (CLA) declaring that you have the right to, and actually do, grant us
|
||||
the rights to use your contribution. For details, visit https://cla.opensource.microsoft.com.
|
||||
Our implementation of LogDQN builds on Dopamine ([Castro et al., 2018][dopamine_paper]), a TensorFlow-based research framework for fast prototyping of reinforcement learning algorithms.
|
||||
|
||||
When you submit a pull request, a CLA bot will automatically determine whether you need to provide
|
||||
a CLA and decorate the PR appropriately (e.g., status check, comment). Simply follow the instructions
|
||||
provided by the bot. You will only need to do this once across all repos using our CLA.
|
||||
Follow the instructions below to install the LogDQN package along with a compatible version of Dopamine and their dependencies inside a conda environment.
|
||||
|
||||
This project has adopted the [Microsoft Open Source Code of Conduct](https://opensource.microsoft.com/codeofconduct/).
|
||||
For more information see the [Code of Conduct FAQ](https://opensource.microsoft.com/codeofconduct/faq/) or
|
||||
contact [opencode@microsoft.com](mailto:opencode@microsoft.com) with any additional questions or comments.
|
||||
First install [Anaconda](https://docs.anaconda.com/anaconda/install/), and then proceed below.
|
||||
|
||||
```
|
||||
conda create --name log-env python=3.6
|
||||
conda activate log-env
|
||||
```
|
||||
|
||||
#### Ubuntu
|
||||
|
||||
```
|
||||
sudo apt-get update && sudo apt-get install cmake zlib1g-dev
|
||||
pip install absl-py atari-py gin-config gym opencv-python tensorflow==1.15rc3
|
||||
pip install git+git://github.com/google/dopamine.git@a59d5d6c68b1a6e790d5808c550ae0f51d3e85ce
|
||||
```
|
||||
|
||||
Finally, install the LogDQN package from source.
|
||||
|
||||
```
|
||||
cd log_dqn_experiments/log_rl
|
||||
pip install .
|
||||
```
|
||||
|
||||
### Training an agent
|
||||
|
||||
To run a LogDQN agent, navigate to `log_dqn_experiments` and run the following:
|
||||
|
||||
```
|
||||
python -um log_dqn.train_atari \
|
||||
--agent_name=log_dqn \
|
||||
--base_dir=/tmp/log_dqn \
|
||||
--gin_files='log_dqn/log_dqn.gin' \
|
||||
--gin_bindings="Runner.game_name = \"Asterix\"" \
|
||||
--gin_bindings="LogDQNAgent.tf_device=\"/gpu:0\""
|
||||
```
|
||||
|
||||
You can set `LogDQNAgent.tf_device` to `/cpu:*` for a non-GPU version.
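For instance, the same command with a CPU-only device binding (only the last flag changes; everything else is as above):

```
python -um log_dqn.train_atari \
  --agent_name=log_dqn \
  --base_dir=/tmp/log_dqn \
  --gin_files='log_dqn/log_dqn.gin' \
  --gin_bindings="LogDQNAgent.tf_device=\"/cpu:*\""
```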
|
||||
|
||||
|
||||
|
||||
[log_rl]: https://arxiv.org/abs/1906.00572
|
||||
[dopamine_paper]: https://arxiv.org/abs/1812.06110
|
||||
[dqn]: https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf
|
||||
[Harm van Seijen]: mailto://Harm.vanSeijen@microsoft.com
|
||||
[Arash Tavakoli]: mailto://a.tavakoli16@imperial.ac.uk
|
|
@@ -0,0 +1,198 @@
|
|||
import numpy as np
|
||||
|
||||
class Agent(object):
|
||||
|
||||
def __init__(self, settings, domain):
|
||||
|
||||
self.init_beta_reg = settings['initial_beta_reg']
|
||||
self.init_beta_log = settings['initial_beta_log']
|
||||
self.init_alpha = settings['initial_alpha']
|
||||
self.final_beta_reg = settings['final_beta_reg']
|
||||
self.final_beta_log = settings['final_beta_log']
|
||||
self.final_alpha = settings['final_alpha']
|
||||
self.decay_period = settings['decay_period']
|
||||
|
||||
self.decay_beta_reg = (self.final_beta_reg/self.init_beta_reg)**(1./(self.decay_period-1))
|
||||
self.decay_beta_log = (self.final_beta_log/self.init_beta_log)**(1./(self.decay_period-1))
|
||||
self.decay_alpha = (self.final_alpha/self.init_alpha)**(1./(self.decay_period-1))
|
||||
|
||||
self.theta_init_reg = settings['theta_init_reg']
|
||||
self.log_mapping = settings['log_mapping']
|
||||
self.c = settings['c']
|
||||
self.h = settings['h']
|
||||
self.Q_init_log = settings['Q_init_log']
|
||||
self.domain = domain
|
||||
self.num_states = domain.get_num_states()
|
||||
self.max_return = settings['max_return']
|
||||
|
||||
|
||||
def perform_update_sweep(self):
|
||||
|
||||
num_samples = (self.num_states) * 2
|
||||
samples = [None] * num_samples
|
||||
|
||||
i = 0
|
||||
for s in range(0, self.num_states):
|
||||
for a in [0, 1]:
|
||||
s2, r = self.domain.take_action(s, a)
|
||||
samples[i] = (s, a, s2, r)
|
||||
i += 1
|
||||
|
||||
np.random.shuffle(samples)
|
||||
|
||||
for sample in samples:
|
||||
self._single_update(sample)
|
||||
|
||||
self.beta_log = max(self.beta_log*self.decay_beta_log, self.final_beta_log)
|
||||
self.beta_reg = max(self.beta_reg*self.decay_beta_reg, self.final_beta_reg)
|
||||
self.alpha = max(self.alpha*self.decay_alpha, self.final_alpha)
|
||||
return
|
||||
|
||||
|
||||
def initialize(self):
|
||||
self.num_features = self.domain.get_num_features()
|
||||
self.gamma = self.domain.get_gamma()
|
||||
self.qstar = self.domain.get_qstar()
|
||||
f = self.domain.get_features(0)
|
||||
|
||||
self.alpha = self.init_alpha
|
||||
self.beta_reg = self.init_beta_reg
|
||||
self.beta_log = self.init_beta_log
|
||||
if self.log_mapping:
|
||||
self.d = -self.c*np.log(self.Q_init_log + self.gamma**self.h)
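# d above is chosen so that f_inverse(0) == Q_init_log, i.e. zero-initialized weights correspond to the desired initial Q-value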
|
||||
self.theta_min = np.zeros([self.num_features, 2])
|
||||
self.theta_plus = np.zeros([self.num_features, 2])
|
||||
f = self.domain.get_features(0)
|
||||
v0_log = np.dot(self.theta_plus[:, 0], f) - np.dot(self.theta_min[:, 0], f)
|
||||
v0 = self._f_inverse(v0_log)
|
||||
else:
|
||||
self.theta = np.ones([self.num_features, 2]) * self.theta_init_reg
|
||||
f = self.domain.get_features(0)
|
||||
v0 = np.dot(self.theta[:, 0], f)
|
||||
print("reg. v(0): {:1.2f}".format(v0))
|
||||
|
||||
def _single_update(self, sample):
|
||||
s, a, s2, r = sample
|
||||
f = self.domain.get_features(s)
|
||||
f2 = self.domain.get_features(s2)
|
||||
|
||||
|
||||
if self.log_mapping:
|
||||
if r >= 0:
|
||||
r_plus = r
|
||||
r_min = 0
|
||||
else:
|
||||
r_plus = 0
|
||||
r_min = -r
|
||||
|
||||
# compute optimal action
|
||||
q_next_0 = self._f_inverse(np.dot(self.theta_plus[:, 0], f2)) \
|
||||
- self._f_inverse(np.dot(self.theta_min[:, 0], f2))
|
||||
q_next_1 = self._f_inverse(np.dot(self.theta_plus[:, 1], f2)) \
|
||||
- self._f_inverse(np.dot(self.theta_min[:, 1], f2))
|
||||
if q_next_0 > q_next_1:
|
||||
a_star = 0
|
||||
else:
|
||||
a_star = 1
|
||||
|
||||
|
||||
# plus-network update
|
||||
if s2 == -1: # terminal state
|
||||
v_next_log_plus = self._f(0.0)
|
||||
else:
|
||||
if a_star == 0:
|
||||
v_next_log_plus = np.dot(self.theta_plus[:, 0], f2)
|
||||
else:
|
||||
v_next_log_plus = np.dot(self.theta_plus[:, 1], f2)
|
||||
|
||||
q_sa_log_plus = np.dot(self.theta_plus[:, a], f)
|
||||
q_sa_plus = self._f_inverse(q_sa_log_plus)
|
||||
v_next_plus = self._f_inverse(v_next_log_plus)
|
||||
update_target_plus = min(r_plus + self.gamma * v_next_plus, self.max_return)
|
||||
update_target_new_plus = q_sa_plus + self.beta_reg * (update_target_plus - q_sa_plus)
|
||||
TD_error_log_plus = self._f(update_target_new_plus) - q_sa_log_plus
|
||||
self.theta_plus[:, a] += self.beta_log * TD_error_log_plus * f
|
||||
|
||||
# min-network update
|
||||
if s2 == -1: # terminal state
|
||||
v_next_log_min = self._f(0.0)
|
||||
else:
|
||||
if a_star == 0:
|
||||
v_next_log_min = np.dot(self.theta_min[:, 0], f2)
|
||||
else:
|
||||
v_next_log_min = np.dot(self.theta_min[:, 1], f2)
|
||||
q_sa_log_min = np.dot(self.theta_min[:, a], f)
|
||||
q_sa_min = self._f_inverse(q_sa_log_min)
|
||||
v_next_min = self._f_inverse(v_next_log_min)
|
||||
update_target_min = min(r_min + self.gamma * v_next_min, self.max_return)
|
||||
update_target_new_min = q_sa_min + self.beta_reg * (update_target_min - q_sa_min)
|
||||
TD_error_log_min = self._f(update_target_new_min) - q_sa_log_min
|
||||
self.theta_min[:, a] += self.beta_log * TD_error_log_min * f
|
||||
|
||||
if (self.theta_min > 100000).any():
|
||||
print('LARGE VALUE detected!')
|
||||
if np.isinf(self.theta_min).any():
|
||||
print('INF detected!')
|
||||
elif np.isnan(self.theta_min).any():
|
||||
print('NAN detected!')
|
||||
if (self.theta_plus > 100000).any():
|
||||
print('LARGE VALUE detected!')
|
||||
if np.isinf(self.theta_plus).any():
|
||||
print('INF detected!')
|
||||
elif np.isnan(self.theta_plus).any():
|
||||
print('NAN detected!')
|
||||
|
||||
else:
|
||||
# compute update target
|
||||
if s2 == -1: # terminal state
|
||||
v_next = 0.0
|
||||
else:
|
||||
q0_next = np.dot(self.theta[:, 0], f2)
|
||||
q1_next = np.dot(self.theta[:, 1], f2)
|
||||
v_next = max(q0_next, q1_next)
|
||||
q_sa = np.dot(self.theta[:, a], f)
|
||||
update_target = min(r + self.gamma * v_next, self.max_return)
|
||||
TD_error = update_target - q_sa
|
||||
self.theta[:, a] += self.alpha * TD_error * f
|
||||
|
||||
if (self.theta > 100000).any():
|
||||
print('LARGE VALUE detected!')
|
||||
if np.isinf(self.theta).any():
|
||||
print('INF detected!')
|
||||
elif np.isnan(self.theta).any():
|
||||
print('NAN detected!')
|
||||
|
||||
def evaluate(self):
|
||||
|
||||
q = np.zeros([self.num_states,2])
|
||||
|
||||
if self.log_mapping:
|
||||
for s in range(self.num_states):
|
||||
f = self.domain.get_features(s)
|
||||
q[s, 0] = self._f_inverse(np.dot(self.theta_plus[:, 0], f)) - self._f_inverse(np.dot(self.theta_min[:, 0], f))
|
||||
q[s, 1] = self._f_inverse(np.dot(self.theta_plus[:, 1], f)) - self._f_inverse(np.dot(self.theta_min[:, 1], f))
|
||||
else:
|
||||
for s in range(self.num_states):
|
||||
f = self.domain.get_features(s)
|
||||
q[s,0] = np.dot(self.theta[:, 0], f)
|
||||
q[s,1] = np.dot(self.theta[:, 1], f)
|
||||
|
||||
i = np.argmax(q,axis=1)
|
||||
|
||||
k = np.argmax(self.qstar)
|
||||
if (i == k).all():
|
||||
success = 1
|
||||
else:
|
||||
success = 0
|
||||
|
||||
return success
|
||||
|
||||
def _f(self, x):
|
||||
return self.c * np.log(x + self.gamma ** self.h) + self.d
|
||||
#return self.c * np.log(x)
|
||||
|
||||
|
||||
def _f_inverse(self,x):
|
||||
return np.exp((x - self.d )/ self.c) - self.gamma ** self.h
|
||||
#return np.exp(x/self.c)
|
||||
|
Binary file not shown.
|
@@ -0,0 +1 @@
|
|||
{"num_datapoints": 1100, "log_mapping": true, "widths": [1, 2, 3, 5], "gammas": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.85, 0.9, 0.94, 0.96, 0.98, 0.99], "window_size": 10000, "num_sweeps": 110000}
|
Binary file not shown.
|
@@ -0,0 +1 @@
|
|||
{"log_mapping": false, "widths": [1, 2, 3, 5], "num_sweeps": 110000, "gammas": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.85, 0.9, 0.94, 0.96, 0.98, 0.99], "num_datapoints": 1100, "window_size": 10000}
|
|
@@ -0,0 +1,128 @@
|
|||
import numpy as np
|
||||
import math
|
||||
|
||||
class Domain:
|
||||
# States [-1] <- 0 <-> 1 <-> 2 <-> 3 <-> ... <-> n-1 -> [-1] (-1 indicates terminal state)
|
||||
# Action 0: towards state 0
|
||||
# Action 1: towards state n-1
|
||||
# Left terminal state: max reward
|
||||
# Right terminal state: min reward
|
||||
|
||||
|
||||
def __init__(self, settings):
|
||||
self.gamma = None
|
||||
self.qstar = None
|
||||
self.num_states = settings['num_states']
|
||||
self.min_reward = settings['min_reward']
|
||||
self.max_reward = settings['max_reward']
|
||||
self.stoch = settings['stochasticity']
|
||||
return
|
||||
|
||||
def _compute_num_features(self):
|
||||
self.num_tiles_per_tiling = math.ceil(self.num_states / self.tile_width) + 1
|
||||
num_features = self.num_tilings * self.num_tiles_per_tiling
|
||||
return num_features
|
||||
|
||||
def get_num_features(self):
|
||||
return self.num_features
|
||||
|
||||
def get_qstar(self):
|
||||
return self.qstar
|
||||
|
||||
def set_tile_width(self, width):
|
||||
self.tile_width = width
|
||||
|
||||
def get_num_states(self):
|
||||
return self.num_states
|
||||
|
||||
def get_gamma(self):
|
||||
return self.gamma
|
||||
|
||||
def set_gamma(self,gamma):
|
||||
self.gamma = gamma
|
||||
self.qstar = self._compute_qstar()
|
||||
|
||||
def init_representation(self):
|
||||
self.num_tilings = self.tile_width
|
||||
self.feature_value = math.sqrt(1/self.num_tilings) # value of an active feature in the tile-coding representation
|
||||
self.num_features = self._compute_num_features()
|
||||
return
|
||||
|
||||
def take_action(self, state, action):
|
||||
assert(state >= 0)
|
||||
|
||||
if np.random.random() < self.stoch:
|
||||
if action == 0:
|
||||
action = 1
|
||||
else:
|
||||
action = 0
|
||||
|
||||
if (state == self.num_states-1) and (action == 1):
|
||||
next_state = -1
|
||||
reward = self.min_reward
|
||||
elif (state == 0) and (action == 0):
|
||||
next_state = -1
|
||||
reward = self.max_reward
|
||||
else:
|
||||
if action == 0:
|
||||
next_state = state-1
|
||||
else:
|
||||
next_state = state +1
|
||||
reward = 0
|
||||
return next_state, reward
|
||||
|
||||
def get_features(self, state):
|
||||
|
||||
if state < -1 or state >= self.num_states:
|
||||
print('state out-of-bounds!')
|
||||
assert(False)
|
||||
|
||||
features = np.zeros(self.num_features)
|
||||
if state != -1:
|
||||
features = np.zeros(self.num_features)
|
||||
for t in range(self.num_tilings):
|
||||
tilde_id = math.floor((state + t)/self.tile_width)
|
||||
features[tilde_id + t * self.num_tiles_per_tiling] = self.feature_value
|
||||
return features
|
||||
|
||||
def get_qstar_plus_min(self):
|
||||
q = np.zeros([self.num_states,2])
|
||||
v = np.zeros(self.num_states+2)
|
||||
v[0] = 0
|
||||
v[-1] = self.min_reward / self.gamma
|
||||
|
||||
for i in range(10000):
|
||||
v[1:-1] = q[:,0]
|
||||
q[:, 0] = (1-self.stoch)*self.gamma*v[:-2] + self.stoch * self.gamma*v[2:]
|
||||
q[:, 1] = (1-self.stoch)*self.gamma*v[2:] + self.stoch * self.gamma*v[:-2]
|
||||
|
||||
qstar_min = -1*q
|
||||
|
||||
q = np.zeros([self.num_states,2])
|
||||
v = np.zeros(self.num_states+2)
|
||||
v[0] = self.max_reward / self.gamma
|
||||
v[-1] = 0
|
||||
|
||||
for i in range(10000):
|
||||
v[1:-1] = q[:,0]
|
||||
q[:, 0] = (1-self.stoch)*self.gamma*v[:-2] + self.stoch * self.gamma*v[2:]
|
||||
q[:, 1] = (1-self.stoch)*self.gamma*v[2:] + self.stoch * self.gamma*v[:-2]
|
||||
|
||||
qstar_plus = q
|
||||
|
||||
return qstar_plus, qstar_min
|
||||
|
||||
|
||||
def _compute_qstar(self):
|
||||
q = np.zeros([self.num_states,2])
|
||||
v = np.zeros(self.num_states+2)
|
||||
v[0] = self.max_reward / self.gamma
|
||||
v[-1] = self.min_reward / self.gamma
|
||||
|
||||
for i in range(10000):
|
||||
v[1:-1] = np.max(q, 1)
|
||||
q[:, 0] = (1-self.stoch)*self.gamma*v[:-2] + self.stoch * self.gamma*v[2:]
|
||||
q[:, 1] = (1-self.stoch)*self.gamma*v[2:] + self.stoch * self.gamma*v[:-2]
|
||||
|
||||
return q
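A usage sketch of the corridor domain above, with the same settings that `main.py` passes in (illustrative only; it assumes this file is importable as `domain`, as `main.py` does):

```
from domain import Domain

settings = {'num_states': 50, 'min_reward': -1.0,
            'max_reward': 1.0, 'stochasticity': 0.25}
env = Domain(settings)

# Action 0 in state 0 terminates with the max reward, unless the 25%
# stochastic flip sends the agent right instead (then the reward is 0).
next_state, reward = env.take_action(0, 0)
print(next_state, reward)   # (-1, 1.0) with probability 0.75
```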
|
||||
|
|
@@ -0,0 +1,92 @@
|
|||
import numpy as np
|
||||
import json
|
||||
import time
|
||||
from domain import Domain
|
||||
from agent import Agent
|
||||
|
||||
# Settings ##########################################################################
|
||||
|
||||
# Experiment settings
|
||||
filename = 'default_reg'
|
||||
num_sweeps = 110000
|
||||
window_size = 10000 # number of sweeps that will be averaged over to get initial and final performance
|
||||
num_datapoints = 1100
|
||||
num_runs = 1
|
||||
gammas = [0.1, 0.8, 0.85, 0.9, 0.94, 0.96, 0.98, 0.99] ########
|
||||
#gammas = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.85, 0.9, 0.94, 0.96, 0.98, 0.99] ########
|
||||
widths = [2]
|
||||
#widths = [1, 2, 3, 5] #############
|
||||
|
||||
# Domain settings
|
||||
domain_settings = {}
|
||||
domain_settings['stochasticity'] = 0.25
|
||||
domain_settings['num_states'] = 50
|
||||
domain_settings['min_reward'] = -1.0
|
||||
domain_settings['max_reward'] = 1.0
|
||||
|
||||
# Agent settings
|
||||
agent_settings = {}
|
||||
agent_settings['log_mapping'] = False
|
||||
agent_settings['initial_beta_reg'] = 1 # used if log_mapping = True
|
||||
agent_settings['initial_beta_log'] = 1 # used if log_mapping = True
|
||||
agent_settings['initial_alpha'] = 1 # used if log_mapping = False
|
||||
agent_settings['final_beta_reg'] = 0.1
|
||||
agent_settings['final_beta_log'] = 0.01
|
||||
agent_settings['final_alpha'] = 0.001
|
||||
agent_settings['decay_period'] = 10000 # number of sweeps over which step_size is annealed (from initial to final)
|
||||
agent_settings['c'] = 1 # used if log_mapping = True
|
||||
agent_settings['h'] = 200 # used if log_mapping = True
|
||||
agent_settings['Q_init_log'] = 0 # used if log_mapping = True
|
||||
agent_settings['theta_init_reg'] = 0 # used if log_mapping = False
|
||||
agent_settings['max_return'] = domain_settings['max_reward']
|
||||
# max_return is used to bound the update-target (to improve stability)
|
||||
|
||||
#############################################################################################
|
||||
if num_datapoints > num_sweeps:
|
||||
num_datapoints = num_sweeps
|
||||
num_gammas = len(gammas)
|
||||
num_widths = len(widths)
|
||||
eval_interval = num_sweeps // num_datapoints
|
||||
my_domain = Domain(domain_settings)
|
||||
my_agent = Agent(agent_settings, my_domain)
|
||||
|
||||
|
||||
start = time.time()
|
||||
avg_performance = np.zeros([num_datapoints,num_gammas, num_widths])
|
||||
for run in range(num_runs):
|
||||
for width_index in range(num_widths):
|
||||
width = widths[width_index]
|
||||
my_domain.set_tile_width(width)
|
||||
my_domain.init_representation()
|
||||
for gamma_index in range(num_gammas):
|
||||
gamma = gammas[gamma_index]
|
||||
print('***** run {}, width: {}, gamma: {} *****'.format(run +1, width, gamma))
|
||||
my_domain.set_gamma(gamma)
|
||||
my_agent.initialize()
|
||||
performance = np.zeros(num_datapoints)
|
||||
eval_no = 0
|
||||
for sweep in range(num_sweeps):
|
||||
my_agent.perform_update_sweep()
|
||||
if (sweep % eval_interval == 0) & (eval_no < num_datapoints):
|
||||
performance[eval_no] = my_agent.evaluate()
|
||||
eval_no += 1
|
||||
mean = np.mean(performance)
|
||||
print('mean = {}'.format(mean))
|
||||
alpha = 1/(run+1)
|
||||
avg_performance[:,gamma_index, width_index] = (1-alpha)*avg_performance[:,gamma_index, width_index] + alpha*performance
|
||||
end = time.time()
|
||||
print("time: {}s".format(end-start))
|
||||
|
||||
# Store results + some essential settings
|
||||
settings = {}
|
||||
settings['log_mapping'] = agent_settings['log_mapping']
|
||||
settings['gammas'] = gammas
|
||||
settings['widths'] = widths
|
||||
settings['num_sweeps'] = num_sweeps
|
||||
settings['window_size'] = window_size
|
||||
settings['num_datapoints'] = num_datapoints
|
||||
with open('data/' + filename + '_settings.txt', 'w') as json_file:
|
||||
json.dump(settings, json_file)
|
||||
np.save('data/' + filename + '_results.npy', avg_performance)
|
||||
|
||||
print('Done.')
|
|
@@ -0,0 +1,71 @@
|
|||
import numpy as np
|
||||
import math
|
||||
import json
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
#filename = 'full_scan_log'
|
||||
filename = 'full_scan_reg'
|
||||
|
||||
results = np.load('data/' + filename + '_results.npy')
|
||||
with open('data/' + filename + '_settings.txt') as f:
|
||||
settings = json.load(f)
|
||||
gammas = settings['gammas']
|
||||
widths = settings['widths']
|
||||
num_sweeps = settings['num_sweeps']
|
||||
window_size_sweeps = settings['window_size']
|
||||
num_datapoints = settings['num_datapoints']
|
||||
window_size_datapoints = int(math.floor(window_size_sweeps/num_sweeps*num_datapoints))
|
||||
|
||||
print(results.shape)
|
||||
num_datapoints = results.shape[0]
|
||||
num_gammas = len(gammas)
|
||||
num_widths = len(widths)
|
||||
|
||||
|
||||
font_size = 20
|
||||
font_size_legend = 20
|
||||
font_size_title = 20
|
||||
|
||||
plt.rc('font', size=font_size) # controls default text sizes
|
||||
plt.rc('axes', titlesize=font_size_title) # fontsize of the axes title
|
||||
plt.rc('axes', labelsize=font_size) # fontsize of the x and y labels
|
||||
plt.rc('xtick', labelsize=font_size) # fontsize of the tick labels
|
||||
plt.rc('ytick', labelsize=font_size) # fontsize of the tick labels
|
||||
plt.rc('legend', fontsize=font_size_legend) # legend fontsize
|
||||
plt.rc('figure', titlesize=font_size) # fontsize of the figure title
|
||||
|
||||
plt.figure()
|
||||
for width_id in range(num_widths):
|
||||
mean_performance = np.mean(results[:window_size_datapoints, :, width_id], axis=0)
|
||||
plt.plot(gammas, mean_performance, linewidth=2.0, label='w: {}'.format(widths[width_id]))
|
||||
plt.ylabel('performance')
|
||||
plt.xlabel(r'$\gamma$')
|
||||
plt.title('early performance')
|
||||
plt.legend(loc='lower left')
|
||||
plt.axis([0.1, 1.0, -0.1, 1.1])
|
||||
|
||||
plt.figure()
|
||||
for width_id in range(num_widths):
|
||||
mean_performance = np.mean(results[num_datapoints - window_size_datapoints :, :, width_id], axis=0)
|
||||
plt.plot(gammas, mean_performance, linewidth=2.0, label='w: {}'.format(widths[width_id]))
|
||||
plt.ylabel('performance')
|
||||
plt.xlabel(r'$\gamma$')
|
||||
plt.title('late performance')
|
||||
plt.legend(loc='lower left')
|
||||
plt.axis([0.1, 1.0, -0.1, 1.1])
|
||||
|
||||
|
||||
# ### Individual plot ##########
|
||||
# plt.figure()
|
||||
# width_id = 0
|
||||
# gamma_id = 0
|
||||
# eval_interval = num_sweeps // num_datapoints
|
||||
# sweeps = [i*eval_interval for i in range(1,num_datapoints+1)]
|
||||
# plt.plot(sweeps, results[:,gamma_id,width_id], linewidth=2.0, label='gamma: {}, w: {}'.format(gammas[gamma_id], widths[width_id]))
|
||||
# plt.ylabel('performance')
|
||||
# plt.xlabel('#sweeps')
|
||||
# plt.legend(loc='lower right')
|
||||
|
||||
plt.show()
|
||||
|
||||
|
|
@@ -0,0 +1 @@
|
|||
name = "log_dqn"
|
|
@@ -0,0 +1,118 @@
|
|||
"""The standard DQN replay memory modified to support float64 rewards.
|
||||
|
||||
This script modifies the Dopamine's implementation of an out-of-graph
|
||||
replay memory + in-graph wrapper to support float64 formatted rewards.
|
||||
"""
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import gzip
|
||||
import math
|
||||
import os
|
||||
import pickle
|
||||
|
||||
from dopamine.replay_memory import circular_replay_buffer
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
import gin.tf
|
||||
|
||||
|
||||
class OutOfGraphReplayBuffer64(circular_replay_buffer.OutOfGraphReplayBuffer):
|
||||
|
||||
def __init__(self,
|
||||
observation_shape,
|
||||
stack_size,
|
||||
replay_capacity,
|
||||
batch_size,
|
||||
update_horizon=1,
|
||||
gamma=0.99,
|
||||
max_sample_attempts=circular_replay_buffer.MAX_SAMPLE_ATTEMPTS,
|
||||
extra_storage_types=None,
|
||||
observation_dtype=np.uint8):
|
||||
super(OutOfGraphReplayBuffer64, self).__init__(
|
||||
observation_shape=observation_shape,
|
||||
stack_size=stack_size,
|
||||
replay_capacity=replay_capacity,
|
||||
batch_size=batch_size,
|
||||
update_horizon=update_horizon,
|
||||
gamma=gamma,
|
||||
max_sample_attempts=max_sample_attempts,
|
||||
extra_storage_types=extra_storage_types,
|
||||
observation_dtype=observation_dtype)
|
||||
self._cumulative_discount_vector = np.array(
|
||||
[self._gamma**np.float64(n) for n in range(update_horizon)],
|
||||
dtype=np.float64)
|
||||
|
||||
def get_storage_signature(self):
|
||||
storage_elements = [
|
||||
circular_replay_buffer.ReplayElement('observation',
|
||||
self._observation_shape, self._observation_dtype),
|
||||
circular_replay_buffer.ReplayElement('action', (), np.int32),
|
||||
circular_replay_buffer.ReplayElement('reward', (), np.float64),
|
||||
circular_replay_buffer.ReplayElement('terminal', (), np.uint8)
|
||||
]
|
||||
|
||||
for extra_replay_element in self._extra_storage_types:
|
||||
storage_elements.append(extra_replay_element)
|
||||
return storage_elements
|
||||
|
||||
def get_transition_elements(self, batch_size=None):
|
||||
batch_size = self._batch_size if batch_size is None else batch_size
|
||||
|
||||
transition_elements = [
|
||||
circular_replay_buffer.ReplayElement('state',
|
||||
(batch_size,) + self._state_shape, self._observation_dtype),
|
||||
circular_replay_buffer.ReplayElement('action', (batch_size,), np.int32),
|
||||
circular_replay_buffer.ReplayElement('reward', (batch_size,), np.float64),
|
||||
circular_replay_buffer.ReplayElement('next_state',
|
||||
(batch_size,) + self._state_shape, self._observation_dtype),
|
||||
circular_replay_buffer.ReplayElement('terminal', (batch_size,), np.uint8),
|
||||
circular_replay_buffer.ReplayElement('indices', (batch_size,), np.int32)
|
||||
]
|
||||
for element in self._extra_storage_types:
|
||||
transition_elements.append(
|
||||
circular_replay_buffer.ReplayElement(element.name,
|
||||
(batch_size,) + tuple(element.shape), element.type))
|
||||
return transition_elements
|
||||
|
||||
|
||||
@gin.configurable(blacklist=['observation_shape', 'stack_size',
|
||||
'update_horizon', 'gamma'])
|
||||
class WrappedReplayBuffer64(circular_replay_buffer.WrappedReplayBuffer):
|
||||
|
||||
def __init__(self,
|
||||
observation_shape,
|
||||
stack_size,
|
||||
use_staging=True,
|
||||
replay_capacity=1000000,
|
||||
batch_size=32,
|
||||
update_horizon=1,
|
||||
gamma=0.99,
|
||||
wrapped_memory=None,
|
||||
max_sample_attempts=circular_replay_buffer.MAX_SAMPLE_ATTEMPTS,
|
||||
extra_storage_types=None,
|
||||
observation_dtype=np.uint8):
|
||||
if replay_capacity < update_horizon + 1:
|
||||
raise ValueError(
|
||||
'Update horizon ({}) should be significantly smaller '
|
||||
'than replay capacity ({}).'.format(update_horizon, replay_capacity))
|
||||
if not update_horizon >= 1:
|
||||
raise ValueError('Update horizon must be positive.')
|
||||
if not 0.0 <= gamma <= 1.0:
|
||||
raise ValueError('Discount factor (gamma) must be in [0, 1].')
|
||||
|
||||
self.batch_size = batch_size
|
||||
|
||||
if wrapped_memory is not None:
|
||||
self.memory = wrapped_memory
|
||||
else:
|
||||
self.memory = OutOfGraphReplayBuffer64(
|
||||
observation_shape, stack_size, replay_capacity, batch_size,
|
||||
update_horizon, gamma, max_sample_attempts,
|
||||
observation_dtype=observation_dtype,
|
||||
extra_storage_types=extra_storage_types)
|
||||
|
||||
self.create_sampling_ops(use_staging)
|
|
@@ -0,0 +1,43 @@
|
|||
# Hyperparameters follow van Seijen, Fatemi, Tavakoli (2019).
|
||||
import dopamine.atari.run_experiment
|
||||
import log_dqn.log_dqn_agent
|
||||
import log_dqn.circular_replay_buffer_64
|
||||
import gin.tf.external_configurables
|
||||
|
||||
# LogDQN specific hyperparameters.
|
||||
LogDQNAgent.c = 0.5
|
||||
LogDQNAgent.k = 100
|
||||
LogDQNAgent.pos_q_init = 1.0
|
||||
LogDQNAgent.neg_q_init = 0.0
|
||||
LogDQNAgent.net_init_method = 'asym'
|
||||
LogDQNAgent.alpha = 0.00025 # alpha = beta_reg * beta_log
|
||||
|
||||
LogDQNAgent.gamma = 0.96
|
||||
LogDQNAgent.update_horizon = 1
|
||||
LogDQNAgent.min_replay_history = 20000 # agent steps
|
||||
LogDQNAgent.update_period = 4
|
||||
LogDQNAgent.target_update_period = 8000 # agent steps
|
||||
LogDQNAgent.clip_qt_max = True
|
||||
LogDQNAgent.epsilon_train = 0.01
|
||||
LogDQNAgent.epsilon_eval = 0.001
|
||||
LogDQNAgent.epsilon_decay_period = 250000 # agent steps
|
||||
LogDQNAgent.tf_device = '/gpu:0' # use '/cpu:*' for non-GPU version
|
||||
LogDQNAgent.loss_type = 'Huber'
|
||||
LogDQNAgent.optimizer = @tf.train.RMSPropOptimizer()
|
||||
|
||||
tf.train.RMSPropOptimizer.learning_rate = 0.0025 # beta_log
|
||||
tf.train.RMSPropOptimizer.decay = 0.95
|
||||
tf.train.RMSPropOptimizer.momentum = 0.0
|
||||
tf.train.RMSPropOptimizer.epsilon = 0.00001
|
||||
tf.train.RMSPropOptimizer.centered = True
|
||||
|
||||
Runner.game_name = 'Pong'
|
||||
# Sticky actions with probability 0.25, as suggested by (Machado et al., 2017).
|
||||
Runner.sticky_actions = True
|
||||
Runner.num_iterations = 200
|
||||
Runner.training_steps = 250000 # agent steps
|
||||
Runner.evaluation_steps = 125000 # agent steps
|
||||
Runner.max_steps_per_episode = 27000 # agent steps
|
||||
|
||||
WrappedReplayBuffer64.replay_capacity = 1000000
|
||||
WrappedReplayBuffer64.batch_size = 32
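The step-size convention noted in the comments above can be checked directly; with these values the implied regular-space step size is `beta_reg = alpha / beta_log = 0.1`, which matches how `LogDQNAgent` derives `self.beta_reg` in `log_dqn_agent.py` below:

```
alpha = 0.00025      # LogDQNAgent.alpha (alpha = beta_reg * beta_log)
beta_log = 0.0025    # tf.train.RMSPropOptimizer.learning_rate
beta_reg = alpha / beta_log
print(beta_reg)      # 0.1
```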
|
|
@@ -0,0 +1,421 @@
|
|||
"""Compact implementation of a LogDQN agent.
|
||||
|
||||
Details in "Using a Logarithmic Mapping to Enable Lower Discount Factors
|
||||
in Reinforcement Learning" by van Seijen, Fatemi, Tavakoli (2019).
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
|
||||
|
||||
|
||||
from dopamine.agents.dqn import dqn_agent
|
||||
from log_dqn import circular_replay_buffer_64
|
||||
import numpy as np
|
||||
import tensorflow as tf
|
||||
|
||||
import gin.tf
|
||||
|
||||
slim = tf.contrib.slim
|
||||
|
||||
|
||||
@gin.configurable
|
||||
class LogDQNAgent(dqn_agent.DQNAgent):
|
||||
"""An implementation of the LogDQN agent."""
|
||||
|
||||
def __init__(self,
|
||||
sess,
|
||||
num_actions,
|
||||
gamma=0.96,
|
||||
c=0.5,
|
||||
k=100,
|
||||
pos_q_init=1.0,
|
||||
neg_q_init=0.0,
|
||||
net_init_method='asym',
|
||||
alpha=0.00025,
|
||||
clip_qt_max=True,
|
||||
update_horizon=1,
|
||||
min_replay_history=20000,
|
||||
update_period=4,
|
||||
target_update_period=8000,
|
||||
epsilon_fn=dqn_agent.linearly_decaying_epsilon,
|
||||
epsilon_train=0.01,
|
||||
epsilon_eval=0.001,
|
||||
epsilon_decay_period=250000,
|
||||
tf_device='/cpu:*',
|
||||
use_staging=True,
|
||||
max_tf_checkpoints_to_keep=3,
|
||||
loss_type='Huber',
|
||||
optimizer=tf.train.RMSPropOptimizer(
|
||||
learning_rate=0.0025,
|
||||
decay=0.95,
|
||||
momentum=0.0,
|
||||
epsilon=0.00001,
|
||||
centered=True),
|
||||
summary_writer=None,
|
||||
summary_writing_frequency=500):
|
||||
"""Initializes the agent and constructs the components of its graph.
|
||||
|
||||
Args:
|
||||
sess: `tf.Session`, for executing ops.
|
||||
num_actions: int, number of actions the agent can take at any state.
|
||||
gamma: float, discount factor with the usual RL meaning.
|
||||
c: float, a hyperparameter of the logarithmic mapping approach.
|
||||
k: int, a hyperparameter of the logarithmic mapping approach.
|
||||
pos_q_init: float, used to evaluate 'd' in logarithmic mapping.
|
||||
neg_q_init: float, used to evaluate 'd' in logarithmic mapping.
|
||||
net_init_method: str, determines how to initialize the weights of the
|
||||
LogDQN network heads.
|
||||
alpha: float, effective step-size: alpha = beta_reg * beta_log.
|
||||
clip_qt_max: bool, when True clip the maximum target value.
|
||||
update_horizon: int, horizon at which updates are performed, the 'n' in
|
||||
n-step update.
|
||||
min_replay_history: int, number of transitions that should be experienced
|
||||
before the agent begins training its value function.
|
||||
update_period: int, period between LogDQN updates.
|
||||
target_update_period: int, update period for the target network.
|
||||
epsilon_fn: function expecting 4 parameters:
|
||||
(decay_period, step, warmup_steps, epsilon). This function should return
|
||||
the epsilon value used for exploration during training.
|
||||
epsilon_train: float, the value to which the agent's epsilon is eventually
|
||||
decayed during training.
|
||||
epsilon_eval: float, epsilon used when evaluating the agent.
|
||||
epsilon_decay_period: int, length of the epsilon decay schedule.
|
||||
tf_device: str, Tensorflow device on which the agent's graph is executed.
|
||||
use_staging: bool, when True use a staging area to prefetch the next
|
||||
training batch, speeding training up by about 30%.
|
||||
max_tf_checkpoints_to_keep: int, the number of TensorFlow checkpoints to
|
||||
keep.
|
||||
optimizer: `tf.train.Optimizer`, for training the value function.
|
||||
summary_writer: SummaryWriter object for outputting training statistics.
|
||||
Summary writing disabled if set to None.
|
||||
summary_writing_frequency: int, frequency with which summaries will be
|
||||
written. Lower values will result in slower training.
|
||||
"""
|
||||
|
||||
tf.logging.info('Creating %s agent with the following parameters:',
|
||||
self.__class__.__name__)
|
||||
tf.logging.info('\t gamma: %f', gamma)
|
||||
tf.logging.info('\t c: %f', c)
|
||||
tf.logging.info('\t k: %d', k)
|
||||
tf.logging.info('\t pos_q_init: %s', np.amax([gamma**k, pos_q_init]))
|
||||
tf.logging.info('\t neg_q_init: %s', np.amax([gamma**k, neg_q_init]))
|
||||
tf.logging.info('\t pos_Delta: %s', -c * np.log(np.amax([gamma**k, pos_q_init])))
|
||||
tf.logging.info('\t neg_Delta: %s', -c * np.log(np.amax([gamma**k, neg_q_init])))
|
||||
tf.logging.info('\t net_init_method: %s', net_init_method)
|
||||
tf.logging.info('\t clip_qt_max: %s', clip_qt_max)
|
||||
tf.logging.info('\t update_horizon: %d', update_horizon)
|
||||
tf.logging.info('\t min_replay_history: %d', min_replay_history)
|
||||
tf.logging.info('\t update_period: %d', update_period)
|
||||
tf.logging.info('\t target_update_period: %d', target_update_period)
|
||||
tf.logging.info('\t epsilon_train: %f', epsilon_train)
|
||||
tf.logging.info('\t epsilon_eval: %f', epsilon_eval)
|
||||
tf.logging.info('\t epsilon_decay_period: %d', epsilon_decay_period)
|
||||
tf.logging.info('\t tf_device: %s', tf_device)
|
||||
tf.logging.info('\t use_staging: %s', use_staging)
|
||||
tf.logging.info('\t loss_type: %s', loss_type)
|
||||
tf.logging.info('\t optimizer: %s', optimizer)
|
||||
tf.logging.info('\t beta_log: %f', optimizer._learning_rate)
|
||||
tf.logging.info('\t beta_reg: %f', alpha / optimizer._learning_rate)
|
||||
tf.logging.info('\t alpha: %f', alpha)
|
||||
|
||||
self.tf_float = tf.float64
|
||||
self.np_float = np.float64
|
||||
|
||||
self.num_actions = num_actions
|
||||
self.gamma = self.np_float(gamma)
|
||||
self.c = self.np_float(c)
|
||||
self.k = self.np_float(k)
|
||||
self.pos_q_init = np.amax([self.gamma**self.k, self.np_float(pos_q_init)])
|
||||
self.neg_q_init = np.amax([self.gamma**self.k, self.np_float(neg_q_init)])
|
||||
self.pos_Delta = -self.c * np.log(self.pos_q_init)
|
||||
self.neg_Delta = -self.c * np.log(self.neg_q_init)
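# With Delta = -c * log(q_init), a head whose log-space output is zero maps back to q_init in the standard space.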
|
||||
self.clip_qt_max = clip_qt_max
|
||||
self.net_init_method = net_init_method
|
||||
self.alpha = alpha
|
||||
self.beta_reg = alpha / optimizer._learning_rate
|
||||
self.update_horizon = update_horizon
|
||||
self.cumulative_gamma = self.gamma**self.np_float(update_horizon)
|
||||
self.min_replay_history = min_replay_history
|
||||
self.target_update_period = target_update_period
|
||||
self.epsilon_fn = epsilon_fn
|
||||
self.epsilon_train = epsilon_train
|
||||
self.epsilon_eval = epsilon_eval
|
||||
self.epsilon_decay_period = epsilon_decay_period
|
||||
self.update_period = update_period
|
||||
self.eval_mode = False
|
||||
self.training_steps = 0
|
||||
self.optimizer = optimizer
|
||||
self.loss_type = loss_type
|
||||
self.summary_writer = summary_writer
|
||||
self.summary_writing_frequency = summary_writing_frequency
|
||||
|
||||
with tf.device(tf_device):
|
||||
# Create a placeholder for the state input to the LogDQN network.
|
||||
# The last axis indicates the number of consecutive frames stacked.
|
||||
state_shape = [1,
|
||||
dqn_agent.OBSERVATION_SHAPE,
|
||||
dqn_agent.OBSERVATION_SHAPE,
|
||||
dqn_agent.STACK_SIZE]
|
||||
self.state = np.zeros(state_shape)
|
||||
self.state_ph = tf.placeholder(tf.uint8, state_shape, name='state_ph')
|
||||
self._replay = self._build_replay_buffer(use_staging)
|
||||
|
||||
self._build_networks()
|
||||
|
||||
self._train_op = self._build_train_op()
|
||||
self._sync_qt_ops = self._build_sync_op()
|
||||
|
||||
if self.summary_writer is not None:
|
||||
# All tf.summaries should have been defined prior to running this.
|
||||
self._merged_summaries = tf.summary.merge_all()
|
||||
self._sess = sess
|
||||
self._saver = tf.train.Saver(max_to_keep=max_tf_checkpoints_to_keep)
|
||||
|
||||
# Variables to be initialized by the agent once it interacts with the
|
||||
# environment.
|
||||
self._observation = None
|
||||
self._last_observation = None
|
||||
|
||||
def _get_network_type(self):
|
||||
"""Returns the type of the outputs of a Q-value network.
|
||||
|
||||
Returns:
|
||||
net_type: _network_type object defining the outputs of the network.
|
||||
"""
|
||||
return collections.namedtuple('LogDQN_network', ['q_values',
|
||||
'pos_q_tilde_values', 'neg_q_tilde_values',
|
||||
'pos_q_values', 'neg_q_values'])
|
||||
|
||||
def _network_template(self, state):
|
||||
"""Builds the convolutional network used to compute the agent's Q-values.
|
||||
|
||||
Args:
|
||||
state: `tf.Tensor`, contains the agent's current state.
|
||||
|
||||
Returns:
|
||||
net: _network_type object containing the tensors output by the network.
|
||||
"""
|
||||
net = tf.cast(state, tf.float32)
|
||||
net = tf.div(net, 255.)
|
||||
net = slim.conv2d(net, 32, [8, 8], stride=4)
|
||||
net = slim.conv2d(net, 64, [4, 4], stride=2)
|
||||
net = slim.conv2d(net, 64, [3, 3], stride=1)
|
||||
net = slim.flatten(net)
|
||||
net = slim.fully_connected(net, 512)
|
||||
net = tf.cast(net, self.tf_float)
|
||||
|
||||
# Create two network heads with the specified initialization scheme.
|
||||
pos_q_tilde_values = slim.fully_connected(net, self.num_actions,
|
||||
activation_fn=None)
|
||||
if self.net_init_method=='standard':
|
||||
neg_q_tilde_values = slim.fully_connected(net, self.num_actions,
|
||||
activation_fn=None)
|
||||
elif self.net_init_method=='asym':
|
||||
neg_q_tilde_values = slim.fully_connected(net, self.num_actions,
|
||||
activation_fn=None,
|
||||
weights_initializer=tf.zeros_initializer())
|
||||
|
||||
# Inverse mapping of Q-tilde values.
|
||||
pos_q_values = tf.exp((pos_q_tilde_values - self.pos_Delta) / self.c)
|
||||
neg_q_values = tf.exp((neg_q_tilde_values - self.neg_Delta) / self.c)
|
||||
|
||||
# Aggregate positive and negative heads' Q-values.
|
||||
q_values = pos_q_values - neg_q_values
|
||||
|
||||
return self._get_network_type()(q_values, pos_q_tilde_values,
|
||||
neg_q_tilde_values, pos_q_values, neg_q_values)
|
||||
|
||||
def _build_networks(self):
|
||||
"""Builds the Q-value network computations needed for acting and training.
|
||||
|
||||
These are:
|
||||
self.online_convnet: For computing the current state's Q-values.
|
||||
self.target_convnet: For computing the next state's target Q-values.
|
||||
self._net_outputs: The actual Q-values.
|
||||
self._q_argmax: The action maximizing the current state's Q-values.
|
||||
self._replay_net_outputs: The replayed states' Q-values.
|
||||
self._replay_next_target_net_outputs: The replayed next states' target
|
||||
Q-values (see Mnih et al., 2015 for details).
|
||||
"""
|
||||
# Calling online_convnet will generate a new graph as defined in
|
||||
# self._get_network_template using whatever input is passed, but will always
|
||||
# share the same weights.
|
||||
self.online_convnet = tf.make_template('Online', self._network_template)
|
||||
self.target_convnet = tf.make_template('Target', self._network_template)
|
||||
self._net_outputs = self.online_convnet(self.state_ph)
|
||||
self._q_argmax = tf.argmax(self._net_outputs.q_values, axis=1)[0]
|
||||
|
||||
self._replay_net_outputs = self.online_convnet(self._replay.states)
|
||||
self._replay_next_target_net_outputs = self.target_convnet(
|
||||
self._replay.next_states)
|
||||
|
||||
# Gets greedy actions over the aggregated target-network's Q-values for the
|
||||
# replay's next states, used for retrieving the target Q-values for both heads.
|
||||
self._replay_next_target_net_q_argmax = tf.argmax(
|
||||
self._replay_next_target_net_outputs.q_values, axis=1)
|
||||
|
||||
def _build_replay_buffer(self, use_staging):
|
||||
"""Creates a float64-compatible replay buffer used by the agent.
|
||||
|
||||
Args:
|
||||
use_staging: bool, if True, uses a staging area to prefetch data for
|
||||
faster training.
|
||||
|
||||
Returns:
|
||||
A WrappedReplayBuffer64 object.
|
||||
"""
|
||||
return circular_replay_buffer_64.WrappedReplayBuffer64(
|
||||
observation_shape=dqn_agent.OBSERVATION_SHAPE,
|
||||
stack_size=dqn_agent.STACK_SIZE,
|
||||
use_staging=use_staging,
|
||||
update_horizon=self.update_horizon,
|
||||
gamma=self.gamma)
|
||||
|
||||
def _build_target_q_op(self):
|
||||
"""Build an op used as a target for the logarithmic Q-value.
|
||||
|
||||
Returns:
|
||||
target_q_op: An op calculating the logarithmic Q-value.
|
||||
"""
|
||||
one = tf.constant(1, dtype=self.tf_float)
|
||||
zero = tf.constant(0, dtype=self.tf_float)
|
||||
# One-hot encode the greedy actions over the target-network's aggregated
|
||||
# Q-values for the replay's next states.
|
||||
replay_next_target_net_q_argmax_one_hot = tf.one_hot(
|
||||
self._replay_next_target_net_q_argmax, self.num_actions, one, zero,
|
||||
name='replay_next_target_net_q_argmax_one_hot')
|
||||
# Calculate each head's target Q-value (in standard space) with the
|
||||
    # action that maximizes the target-network's aggregated Q-values for
    # the replay's next states.
    pos_replay_next_qt_max_unclipped = tf.reduce_sum(
        self._replay_next_target_net_outputs.pos_q_values * \
        replay_next_target_net_q_argmax_one_hot,
        reduction_indices=1,
        name='pos_replay_next_qt_max_unclipped')
    neg_replay_next_qt_max_unclipped = tf.reduce_sum(
        self._replay_next_target_net_outputs.neg_q_values * \
        replay_next_target_net_q_argmax_one_hot,
        reduction_indices=1,
        name='neg_replay_next_qt_max_unclipped')

    # Clips the target-network's maximum positive and negative Q-values
    # for the replay's next states.
    if self.clip_qt_max:
      min_return = zero
      max_return = one / (one - self.cumulative_gamma)

      pos_replay_next_qt_max_clipped_min = tf.maximum(
          min_return, pos_replay_next_qt_max_unclipped)
      pos_replay_next_qt_max = tf.minimum(
          max_return, pos_replay_next_qt_max_clipped_min)

      neg_replay_next_qt_max_clipped_min = tf.maximum(
          min_return, neg_replay_next_qt_max_unclipped)
      neg_replay_next_qt_max = tf.minimum(
          max_return, neg_replay_next_qt_max_clipped_min)
    else:
      pos_replay_next_qt_max = pos_replay_next_qt_max_unclipped
      neg_replay_next_qt_max = neg_replay_next_qt_max_unclipped

    # Masks out the bootstrap term for terminal next states.
    pos_replay_next_qt_max_masked = pos_replay_next_qt_max * \
        (1. - tf.cast(self._replay.terminals, self.tf_float))
    neg_replay_next_qt_max_masked = neg_replay_next_qt_max * \
        (1. - tf.cast(self._replay.terminals, self.tf_float))

    # Creates the positive and negative heads' separate reward signals
    # and bootstraps from the appropriate target for each head.
    # The positive head's reward signal is r if r > 0 and 0 otherwise.
    pos_standard_td_target_unclipped = self._replay.rewards * \
        tf.cast(tf.greater(self._replay.rewards, zero), self.tf_float) + \
        self.cumulative_gamma * pos_replay_next_qt_max_masked
    # The negative head's reward signal is -r if r < 0 and 0 otherwise.
    neg_standard_td_target_unclipped = -1 * self._replay.rewards * \
        tf.cast(tf.less(self._replay.rewards, zero), self.tf_float) + \
        self.cumulative_gamma * neg_replay_next_qt_max_masked

    # Clips the minimum of the standard-space TD-targets for both the positive
    # and negative heads, so as to avoid log(x <= 0).
    pos_standard_td_target = tf.maximum(self.cumulative_gamma**self.k,
                                        pos_standard_td_target_unclipped)
    neg_standard_td_target = tf.maximum(self.cumulative_gamma**self.k,
                                        neg_standard_td_target_unclipped)

    # Gets the current-network's positive and negative Q-values (in standard
    # space) for the replay's chosen actions.
    replay_action_one_hot = tf.one_hot(
        self._replay.actions, self.num_actions, one, zero,
        name='replay_action_one_hot')
    pos_replay_chosen_q = tf.reduce_sum(
        self._replay_net_outputs.pos_q_values * replay_action_one_hot,
        reduction_indices=1, name='pos_replay_chosen_q')
    neg_replay_chosen_q = tf.reduce_sum(
        self._replay_net_outputs.neg_q_values * replay_action_one_hot,
        reduction_indices=1, name='neg_replay_chosen_q')

    # Averages the samples in the standard space.
    pos_UT_new = pos_replay_chosen_q + \
        self.beta_reg * (pos_standard_td_target - pos_replay_chosen_q)
    neg_UT_new = neg_replay_chosen_q + \
        self.beta_reg * (neg_standard_td_target - neg_replay_chosen_q)

    # Forward mapping: maps the averaged standard-space targets into the
    # logarithmic space.
    pos_log_td_target = self.c * tf.log(pos_UT_new) + self.pos_Delta
    neg_log_td_target = self.c * tf.log(neg_UT_new) + self.neg_Delta

    pos_log_td_target = tf.cast(pos_log_td_target, tf.float32)
    neg_log_td_target = tf.cast(neg_log_td_target, tf.float32)
    return pos_log_td_target, neg_log_td_target

  def _build_train_op(self):
    """Builds a training op.

    Returns:
      train_op: An op performing one step of training from replay data.
    """
    one = tf.constant(1, dtype=self.tf_float)
    zero = tf.constant(0, dtype=self.tf_float)
    replay_action_one_hot = tf.one_hot(
        self._replay.actions, self.num_actions, one, zero,
        name='action_one_hot')
    # For the replay's chosen actions, these are the current-network's positive
    # and negative Q-tilde values, which will be updated for each head
    # separately.
    pos_replay_chosen_q_tilde = tf.reduce_sum(
        self._replay_net_outputs.pos_q_tilde_values * replay_action_one_hot,
        reduction_indices=1,
        name='pos_replay_chosen_q_tilde')
    neg_replay_chosen_q_tilde = tf.reduce_sum(
        self._replay_net_outputs.neg_q_tilde_values * replay_action_one_hot,
        reduction_indices=1,
        name='neg_replay_chosen_q_tilde')

    pos_replay_chosen_q_tilde = tf.cast(pos_replay_chosen_q_tilde, tf.float32)
    neg_replay_chosen_q_tilde = tf.cast(neg_replay_chosen_q_tilde, tf.float32)

    # Gets the target for both the positive and negative heads.
    pos_log_td_target, neg_log_td_target = self._build_target_q_op()
    pos_log_target = tf.stop_gradient(pos_log_td_target)
    neg_log_target = tf.stop_gradient(neg_log_td_target)

    if self.loss_type == 'Huber':
      pos_loss = tf.losses.huber_loss(
          pos_log_target, pos_replay_chosen_q_tilde,
          reduction=tf.losses.Reduction.NONE)
      neg_loss = tf.losses.huber_loss(
          neg_log_target, neg_replay_chosen_q_tilde,
          reduction=tf.losses.Reduction.NONE)
    elif self.loss_type == 'MSE':
      pos_loss = tf.losses.mean_squared_error(
          pos_log_target, pos_replay_chosen_q_tilde,
          reduction=tf.losses.Reduction.NONE)
      neg_loss = tf.losses.mean_squared_error(
          neg_log_target, neg_replay_chosen_q_tilde,
          reduction=tf.losses.Reduction.NONE)

    loss = pos_loss + neg_loss
    if self.summary_writer is not None:
      with tf.variable_scope('Losses'):
        tf.summary.scalar(self.loss_type + 'Loss', tf.reduce_mean(loss))
    return self.optimizer.minimize(tf.reduce_mean(loss))
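
The "forward mapping" step above is where the standard-space TD target enters the logarithmic space: the target is first clipped below at cumulative_gamma**k so the logarithm is defined, then transformed as c * ln(x) + Delta. Below is a minimal NumPy sketch of that transformation; it is not part of the repository, and the constants are made-up placeholders standing in for the agent attributes of the same names, not the agent's actual defaults.

import numpy as np

# Placeholder values standing in for `self.c`, `self.pos_Delta`,
# `self.cumulative_gamma`, and `self.k`.
c, pos_Delta = 0.5, 1.0
cumulative_gamma, k = 0.96, 100

# A small batch of standard-space TD targets, clipped below at gamma**k so
# that log() never sees a non-positive input, then mapped into log space.
standard_td_target = np.maximum(cumulative_gamma**k, np.array([0.0, 0.2, 1.5]))
log_td_target = c * np.log(standard_td_target) + pos_Delta
print(log_td_target)  # approximately [-1.04, 0.20, 1.20]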
@@ -0,0 +1,137 @@
r"""The entry point for running an agent on an Atari 2600 domain.
|
||||
|
||||
This script modifies Dopamine's `train.py` to support LogDQN.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
|
||||
|
||||
from absl import app
|
||||
from absl import flags
|
||||
from dopamine.agents.dqn import dqn_agent
|
||||
from log_dqn import log_dqn_agent
|
||||
from dopamine.agents.implicit_quantile import implicit_quantile_agent
|
||||
from dopamine.agents.rainbow import rainbow_agent
|
||||
from dopamine.atari import run_experiment
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
|
||||
flags.DEFINE_bool('debug_mode', False,
|
||||
'If set to true, the agent will output in-episode statistics '
|
||||
'to Tensorboard. Disabled by default as this results in '
|
||||
'slower training.')
|
||||
flags.DEFINE_string('agent_name', None,
|
||||
'Name of the agent. Must be one of '
|
||||
'(dqn, log_dqn, rainbow, implicit_quantile)')
|
||||
flags.DEFINE_string('base_dir', None,
|
||||
'Base directory to host all required sub-directories.')
|
||||
flags.DEFINE_multi_string(
|
||||
'gin_files', [], 'List of paths to gin configuration files (e.g.'
|
||||
'"dopamine/agents/dqn/dqn.gin").')
|
||||
flags.DEFINE_multi_string(
|
||||
'gin_bindings', [],
|
||||
'Gin bindings to override the values set in the config files '
|
||||
'(e.g. "DQNAgent.epsilon_train=0.1",'
|
||||
' "create_environment.game_name="Pong"").')
|
||||
flags.DEFINE_string(
|
||||
'schedule', 'continuous_train_and_eval',
|
||||
'The schedule with which to run the experiment and choose an appropriate '
|
||||
'Runner. Supported choices are '
|
||||
'{continuous_train, continuous_train_and_eval}.')
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
|
||||
|
||||
|
||||
def create_agent(sess, environment, summary_writer=None):
|
||||
"""Creates a DQN agent.
|
||||
|
||||
Args:
|
||||
sess: A `tf.Session` object for running associated ops.
|
||||
environment: An Atari 2600 Gym environment.
|
||||
summary_writer: A Tensorflow summary writer to pass to the agent
|
||||
for in-agent training statistics in Tensorboard.
|
||||
|
||||
Returns:
|
||||
agent: An RL agent.
|
||||
|
||||
Raises:
|
||||
ValueError: If `agent_name` is not in supported list.
|
||||
"""
|
||||
if not FLAGS.debug_mode:
|
||||
summary_writer = None
|
||||
if FLAGS.agent_name == 'dqn':
|
||||
return dqn_agent.DQNAgent(sess, num_actions=environment.action_space.n,
|
||||
summary_writer=summary_writer)
|
||||
elif FLAGS.agent_name == 'log_dqn':
|
||||
return log_dqn_agent.LogDQNAgent(sess, num_actions=environment.action_space.n,
|
||||
summary_writer=summary_writer)
|
||||
elif FLAGS.agent_name == 'rainbow':
|
||||
return rainbow_agent.RainbowAgent(
|
||||
sess, num_actions=environment.action_space.n,
|
||||
summary_writer=summary_writer)
|
||||
elif FLAGS.agent_name == 'implicit_quantile':
|
||||
return implicit_quantile_agent.ImplicitQuantileAgent(
|
||||
sess, num_actions=environment.action_space.n,
|
||||
summary_writer=summary_writer)
|
||||
else:
|
||||
raise ValueError('Unknown agent: {}'.format(FLAGS.agent_name))
|
||||
|
||||
|
||||
def create_runner(base_dir, create_agent_fn):
|
||||
"""Creates an experiment Runner.
|
||||
|
||||
Args:
|
||||
base_dir: str, base directory for hosting all subdirectories.
|
||||
create_agent_fn: A function that takes as args a Tensorflow session and an
|
||||
Atari 2600 Gym environment, and returns an agent.
|
||||
|
||||
Returns:
|
||||
runner: A `run_experiment.Runner` like object.
|
||||
|
||||
Raises:
|
||||
ValueError: When an unknown schedule is encountered.
|
||||
"""
|
||||
assert base_dir is not None
|
||||
# Continuously runs training and evaluation until max num_iterations is hit.
|
||||
if FLAGS.schedule == 'continuous_train_and_eval':
|
||||
return run_experiment.Runner(base_dir, create_agent_fn)
|
||||
# Continuously runs training until max num_iterations is hit.
|
||||
elif FLAGS.schedule == 'continuous_train':
|
||||
return run_experiment.TrainRunner(base_dir, create_agent_fn)
|
||||
else:
|
||||
raise ValueError('Unknown schedule: {}'.format(FLAGS.schedule))
|
||||
|
||||
|
||||
def launch_experiment(create_runner_fn, create_agent_fn):
|
||||
"""Launches the experiment.
|
||||
|
||||
Args:
|
||||
create_runner_fn: A function that takes as args a base directory and a
|
||||
function for creating an agent and returns a `Runner`-like object.
|
||||
create_agent_fn: A function that takes as args a Tensorflow session and an
|
||||
Atari 2600 Gym environment, and returns an agent.
|
||||
"""
|
||||
run_experiment.load_gin_configs(FLAGS.gin_files, FLAGS.gin_bindings)
|
||||
runner = create_runner_fn(FLAGS.base_dir, create_agent_fn)
|
||||
runner.run_experiment()
|
||||
|
||||
|
||||
def main(unused_argv):
|
||||
"""Main method.
|
||||
|
||||
Args:
|
||||
unused_argv: Arguments (unused).
|
||||
"""
|
||||
tf.logging.set_verbosity(tf.logging.INFO)
|
||||
launch_experiment(create_runner, create_agent)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
flags.mark_flag_as_required('agent_name')
|
||||
flags.mark_flag_as_required('base_dir')
|
||||
app.run(main)
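
For reference, a hypothetical invocation of this entry point could look as follows; only the flag names come from the definitions above, while the module path, base directory, and gin file are placeholders:

python -m log_dqn.train \
    --agent_name=log_dqn \
    --base_dir=/tmp/log_dqn \
    --gin_files='log_dqn/log_dqn.gin'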
@@ -0,0 +1,28 @@
import codecs
from os import path

from setuptools import find_packages
from setuptools import setup


here = path.abspath(path.dirname(__file__))

with codecs.open(path.join(here, 'README.md'), encoding='utf-8') as f:
  long_description = f.read()

install_requires = ['gin-config >= 0.1.1', 'absl-py >= 0.2.2',
                    'tensorflow==1.15rc3', 'opencv-python >= 3.4.1.15',
                    'gym >= 0.10.5', 'dopamine-rl==1.0.2']

log_dqn_description = (
    'LogDQN agent from van Seijen, Fatemi, Tavakoli (2019)')

setup(
    name='log_dqn',
    version='0.0.1',
    packages=find_packages(),
    author_email='a.tavakoli@imperial.ac.uk',
    install_requires=install_requires,
    description=log_dqn_description,
    long_description=long_description,
    long_description_content_type='text/markdown'
)
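
Assuming this setup.py sits at the root of a local checkout, the package could be installed in editable mode with standard pip usage (not a command documented in this commit):

pip install -e .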