138 строки
5.3 KiB
Python
138 строки
5.3 KiB
Python
import os, sys, queue, logging
|
|
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
|
from decision_modules import Examiner, Interactor, Navigator, Hoarder, YesNo, YouHaveTo, Darkness, Idler
|
|
from event import *
|
|
from knowledge_graph import *
|
|
from gv import kg, event_stream, dbg, rng
|
|
from util import clean, action_recognized
|
|
from valid_detectors.learned_valid_detector import LearnedValidDetector
|
|
|
|
|
|
class NailAgent():
    """
    NAIL Agent: Navigate, Acquire, Interact, Learn

    NAIL has a set of decision modules which compete for control over low-level
    actions. Changes in world-state and knowledge_graph stream events to the
    decision modules. The modules then update how eager they are to take control.
    """

    def __init__(self, seed, env, rom_name, output_subdir='.'):
        """Create the agent and reset the shared global state.

        Args:
            seed: Seed passed to the shared ``rng``.
            env: Environment queried in ``take_action`` for the true player
                location (logging only). May be falsy, in which case
                ``self.env`` is set to ``None``.
            rom_name: Game name; used as the base name of the log file.
            output_subdir: Directory under which the ``nail_logs`` and
                ``kgs`` directories are created.
        """
        self.setup_logging(rom_name, output_subdir)
        rng.seed(seed)
        dbg("RandomSeed: {}".format(seed))
        # Use the directly imported globals, like every other method here.
        # The module name ``gv`` is not imported explicitly in this file, so
        # ``gv.kg`` would only resolve if a star import happened to expose it.
        self.knowledge_graph = kg
        self.knowledge_graph.__init__()  # Re-initialize KnowledgeGraph in place
        event_stream.clear()
        self.modules = [Examiner(True), Hoarder(True), Navigator(True), Interactor(True),
                        Idler(True), YesNo(True), YouHaveTo(True), Darkness(True)]
        self.active_module = None
        self.action_generator = None
        self.first_step = True
        self._valid_detector = LearnedValidDetector()
        # Always define these attributes. take_action() reads self.env
        # unconditionally, so leaving it unset caused an AttributeError.
        self.env = env if (env and rom_name) else None
        self.step_num = 0

    def setup_logging(self, rom_name, output_subdir):
        """Configure the logging facilities.

        Closes and detaches every handler on the root logger. Ensures that
        ``<output_subdir>/nail_logs`` and ``<output_subdir>/kgs`` exist. Then
        routes root DEBUG-level logging to ``nail_logs/<rom_name>.log``; the
        file is truncated on each run (``filemode='w'``).

        Sets ``self.logpath`` (the log path without extension) and
        ``self.kgs_dir_path``.
        """
        for handler in logging.root.handlers[:]:
            handler.close()
            logging.root.removeHandler(handler)
        log_dir = os.path.join(output_subdir, 'nail_logs')
        # makedirs(exist_ok=True) also creates a missing output_subdir and does
        # not race with another process creating the same directory.
        os.makedirs(log_dir, exist_ok=True)
        self.kgs_dir_path = os.path.join(output_subdir, 'kgs')
        os.makedirs(self.kgs_dir_path, exist_ok=True)
        self.logpath = os.path.join(log_dir, rom_name)
        logging.basicConfig(format='%(message)s', filename=self.logpath + '.log',
                            level=logging.DEBUG, filemode='w')

    def elect_new_active_module(self):
        """Select the most eager module and start its action generator.

        Because the comparison is ``>=``, a tie goes to the module that appears
        later in ``self.modules``. A module whose eagerness is below 0 is never
        chosen; if no module reaches 0, ``self.active_module`` keeps its
        previous value.
        """
        most_eager = 0.
        for module in self.modules:
            eagerness = module.get_eagerness()
            if eagerness >= most_eager:
                self.active_module = module
                most_eager = eagerness
        dbg("[NAIL](elect): {} Eagerness: {}"
            .format(type(self.active_module).__name__, most_eager))
        self.action_generator = self.active_module.take_control()
        # Prime the generator so that later .send(observation) calls reach
        # its first ``yield``.
        self.action_generator.send(None)

    def generate_next_action(self, observation):
        """Return the text of the action chosen by the active module.

        Sends ``observation`` to the active module's generator. When that
        generator is exhausted (StopIteration), pending events are consumed,
        a new module is elected, and the observation is re-sent. This repeats
        until a truthy action is produced.
        """
        next_action = None
        while not next_action:
            try:
                next_action = self.action_generator.send(observation)
            except StopIteration:
                self.consume_event_stream()
                self.elect_new_active_module()
        return next_action.text()

    def consume_event_stream(self):
        """Let each module process the stored events, then clear the stream."""
        for module in self.modules:
            module.process_event_stream()
        event_stream.clear()

    def take_action(self, observation):
        """Return the next action string for the given observation.

        On the very first call, logs the observation and returns ``'look'`` to
        get past the intro text. On the first call that sees no player
        location, creates the initial Location from the observation.
        """
        if self.env:
            # Add true locations to the .log file.
            loc = self.env.get_player_location()
            if loc and hasattr(loc, 'num') and hasattr(loc, 'name') and loc.num and loc.name:
                dbg("[TRUE_LOC] {} \"{}\"".format(loc.num, loc.name))

            # Debug aid: uncomment to dump a kg snapshot per step into kgs/.
            # with open(os.path.join(self.kgs_dir_path, str(self.step_num) + '.kng'), 'w') as f:
            #     f.write(str(self.knowledge_graph)+'\n\n')
            # self.step_num += 1

        observation = observation.strip()
        if self.first_step:
            dbg("[NAIL] {}".format(observation))
            self.first_step = False
            return 'look'  # Do a look to get rid of intro text

        if not kg.player_location:
            loc = Location(observation)
            kg.add_location(loc)
            kg.player_location = loc
            kg._init_loc = loc

        self.consume_event_stream()

        if not self.active_module:
            self.elect_new_active_module()

        next_action = self.generate_next_action(observation)
        return next_action
        
    def observe(self, obs, action, score, new_obs, terminal):
        """Record the outcome of ``action`` and publish it as an event.

        Logs the valid-action probability and, once a player location exists,
        the eagerness of the first five modules (Examiner, Hoarder, Navigator,
        Interactor, Idler). Then pushes a NewTransitionEvent, updates the
        unrecognized-word tracking, and resets the knowledge graph on a
        terminal transition.
        """
        p_valid = self._valid_detector.action_valid(action, new_obs)
        dbg("[VALID] p={:.3f} {}".format(p_valid, clean(new_obs)))
        if kg.player_location:
            dbg("[EAGERNESS] {}".format(' '.join([str(module.get_eagerness()) for module in self.modules[:5]])))
        event_stream.push(NewTransitionEvent(obs, action, score, new_obs, terminal))
        action_recognized(action, new_obs)  # Update the unrecognized words
        if terminal:
            kg.reset()

    def finalize(self):
        """Write the final knowledge graph to ``<self.logpath>.kng``.

        The output is ``str(self.knowledge_graph)`` followed by a blank line.
        It is written as UTF-8 so the result does not depend on the platform's
        default encoding.
        """
        with open(self.logpath + '.kng', 'w', encoding='utf-8') as f:
            f.write(str(self.knowledge_graph) + '\n\n')