diff --git a/leisure/__init__.py b/leisure/__init__.py index 015410c..e809c84 100644 --- a/leisure/__init__.py +++ b/leisure/__init__.py @@ -35,7 +35,11 @@ import tempfile def main(): script = sys.argv[1] - run_script(script, tempfile.mkdtemp()) + if len(sys.argv) == 3: + data_root = sys.argv[2] + else: + data_root = tempfile.mkdtemp() + run_script(script, data_root) if __name__ == "__main__": diff --git a/leisure/disco.py b/leisure/disco.py index eab0823..e9a9ad5 100644 --- a/leisure/disco.py +++ b/leisure/disco.py @@ -17,18 +17,25 @@ from disco.core import Disco from .io import puts from . import event_loop from . import job_control +from . import server disco_url_regex = re.compile(".*?://.*?/disco/(.*)") +preferred_host_re = re.compile("^[a-zA-Z0-9]+://([^/:]*)") def run_script(script, data_root): loop = start_event_loop() job_control.set_event_loop(loop) try: patch_disco() + host,port = server.start(loop) os.environ['DISCO_HOME'] = disco.__path__[0] os.environ['DISCO_DATA'] = data_root + os.environ['DISCO_PORT'] = str(port) + os.environ['DDFS_PUT_PORT'] = str(port) + + globals_ = { "__name__" : "__main__", "__file__" : script, @@ -123,6 +130,12 @@ def hex_hash(path): return hashlib.md5(path).hexdigest()[:2] +def preferred_host(url): + + m = preferred_host_re.search(url) + if m: + return m.group(1) + def timestamp(dt=None): """ diff --git a/leisure/event_emmiter.py b/leisure/event_emmiter.py new file mode 100644 index 0000000..616e4d9 --- /dev/null +++ b/leisure/event_emmiter.py @@ -0,0 +1,15 @@ +from collections import defaultdict +class EventEmmiter(object): + + def on(self, event, callback, *args): + if not hasattr(self, "_callbacks"): + self._callbacks = defaultdict(list) + + self._callbacks[event].append((callback, args)) + return self + + def fire(self, event, sender): + for callback, args in self._callbacks[event]: + callback(sender, *args) + + \ No newline at end of file diff --git a/leisure/event_loop.py b/leisure/event_loop.py index b683565..a77bca4 100644 --- a/leisure/event_loop.py +++ b/leisure/event_loop.py @@ -32,6 +32,9 @@ def add_reader(fd, callback, *args): def remove_reader(fd): current_event_loop().remove_reader(fd) +def run(): + current_event_loop().run() + def fileno(fd): if isinstance(fd, int): diff --git a/leisure/job.py b/leisure/job.py index 2f67eec..0baa818 100644 --- a/leisure/job.py +++ b/leisure/job.py @@ -57,11 +57,16 @@ class Job(object): def jobfile_path(self): return os.path.join(self.job_dir, "jobfile") + @property + def nr_reduces(self): + return self.jobpack.jobdict['nr_reduces'] + @property def has_map_phase(self): """Return true if the job has a map phase""" return self.jobpack.jobdict['map?'] + @property def has_reduce_phase(self): """Return true if the job has a map phase""" diff --git a/leisure/job_control.py b/leisure/job_control.py index eb5fe53..aa5f729 100644 --- a/leisure/job_control.py +++ b/leisure/job_control.py @@ -55,6 +55,7 @@ def map_reduce(job): return reduce(inputs, job, _finished) def _finished(results): + job.results.extend(results) job.status = "ready" map(job.inputs, job, _reduce) @@ -68,18 +69,47 @@ def map(inputs, job, cb): return run_phase(map_inputs(inputs), "map", job, cb) def map_inputs(inputs): + # preferred_host = leisure.disco.preferred_host + # def case(input): + # if isinstance(input, list): + # return [ (i, preferred_host(input)) for i in input ] + # else: + # return [(input, preferred_host(input))] + + # return list(enumerate([ case(input) for input in inputs ])) + + if not hasattr(inputs, '__iter__'): inputs = [inputs] return inputs def reduce(inputs, job, cb): - + if not job.has_reduce_phase: return event_loop.call_soon(cb, inputs) else: - return run_phase(map_inputs(inputs), "reduce", job, cb) + return run_phase(reduce_inputs(inputs, job.nr_reduces), "reduce", job, cb) +def reduce_inputs(inputs, n_red): + return inputs + hosts = usort([ + leisure.disco.preferred_host(input) + for input in inputs + ]) + + num_hosts = len(hosts) + if num_hosts == 0: + return [] + else: + hosts_d = dict(enumerate(hosts)) + return [ + (task_id, [(inputs, hosts_d[task_id % n_red])]) + for task_id in range(num_hosts) + ] + +def usort(inputs): + return sorted(set(inputs)) def results(job, mode, local_results, global_results, **state): @@ -106,12 +136,11 @@ def run_phase(inputs, mode, job, cb): outstanding = len(inputs), local_results = [], global_results = [] - #task_results = {} ) for id, input in enumerate(inputs): task = Task(id, job, input, mode) - task.on('done', on_task_done, task, state) + task.on('done', on_task_done, state) worker.start(task) def on_task_done(task, state): diff --git a/leisure/send_file.py b/leisure/send_file.py new file mode 100644 index 0000000..75d09ec --- /dev/null +++ b/leisure/send_file.py @@ -0,0 +1,49 @@ +import mimetypes +import os +import re + +from flask import request, send_file, Response + + + +def send_file_partial(path): + """ + Simple wrapper around send_file which handles HTTP 206 Partial Content + (byte ranges) + TODO: handle all send_file args, mirror send_file's error handling + (if it has any) + """ + + range_header = request.headers.get('Range', None) + if not range_header: return send_file(path) + + size = os.path.getsize(path) + byte1, byte2 = 0, None + + m = re.search('(\d+)-(\d*)', range_header) + g = m.groups() + + if g[0]: byte1 = int(g[0]) + if g[1]: byte2 = int(g[1]) + + byte2 = min(size, byte2) + + #if byte1 == byte2: + # import pdb; pdb.set_trace() + + length = size - byte1 + if byte2 is not None: + length = byte2 - byte1 + + data = None + with open(path, 'rb') as f: + f.seek(byte1) + data = f.read(length) + + rv = Response(data, + 206, + mimetype="application/octet-stream",#mimetypes.guess_type(path)[0], + direct_passthrough=True) + rv.headers.add('Content-Range', 'bytes {0}-{1}/{2}'.format(byte1, byte1 + length - 1, size)) + + return rv diff --git a/leisure/server.py b/leisure/server.py new file mode 100644 index 0000000..78855bf --- /dev/null +++ b/leisure/server.py @@ -0,0 +1,105 @@ +from http_parser.pyparser import HttpParser + +from .io import puts, indent +from .transports import Socket + + +import os +import leisure + +def start(event_loop): + socket = Socket(("localhost", 0)) + socket.on("accept", new_connection) + addr = socket.listen(5, event_loop) + puts("{}".format(addr)) + return addr + +def new_connection(client): + parser =HttpParser(kind=0) + parser.environ = True + + client.on( + "data", on_read, parser , client + ).on( + "error", on_error, client + ) + +def on_read(data, parser, client): + parser.execute(data.tobytes(), len(data)) + if parser.is_headers_complete(): + env = parser.get_wsgi_environ() + dispatch(env, client) + +def on_error(exce, client): + print exce + +def dispatch(env, client): + sock = client._socket + + out = bytearray() + + def start_response(status, response_headers, exc_info=None): + out.extend("HTTP/1.1 ") + out.extend(status) + out.extend("\r\n") + + for header, value in response_headers: + out.extend("{}: {}\r\n".format(header, value)) + + out.extend("\r\n") + + return sock.send + + + headers_sent = False + sent = 0 + app_iter = app.wsgi_app(env, start_response) + + + try: + for data in app_iter: + if not headers_sent: + #puts(out), + sock.send(out) + headers_sent = True + + sock.send(data) + sent += len(data) + finally: + if hasattr(app_iter, 'close'): + app_iter.close() + + client.close() + + +from flask import Flask, request, Response, abort +from .send_file import send_file_partial + +app = Flask(__name__) +app.debug = True + +@app.after_request +def after_request(response): + response.headers.add('Accept-Ranges', 'bytes') + return response + + +@app.route('/') +def hello_world(): + return "hello" + + +@app.route('/disco/') +def disco(path): + if '..' in path or path.startswith('/'): + abort(404) + + puts(request.url) + real_path = os.path.join(os.environ['DISCO_DATA'], path) + return send_file_partial(real_path) + + + + + + diff --git a/leisure/shuffle.py b/leisure/shuffle.py index f2c2fa1..ce321d1 100644 --- a/leisure/shuffle.py +++ b/leisure/shuffle.py @@ -1,4 +1,5 @@ import os +from gzip import GzipFile from itertools import groupby, chain from . import disco @@ -76,8 +77,8 @@ def write_index(filename, lines): tmp_path = "{}-{}".format(filename, disco.timestamp()) - - output = open(tmp_path, 'w') + output = GzipFile(tmp_path, 'w') + #output = open(tmp_path, 'w') output.writelines(lines) output.close() os.rename(tmp_path, filename) diff --git a/leisure/task.py b/leisure/task.py index 49fe9ee..c14b7ec 100644 --- a/leisure/task.py +++ b/leisure/task.py @@ -1,12 +1,12 @@ import os import time -from collections import defaultdict from .path import ensure_dir +from .event_emmiter import EventEmmiter import leisure -class Task(object): +class Task(EventEmmiter): def __init__(self, id, job, input, mode): self.id = id self.job = job @@ -16,34 +16,30 @@ class Task(object): self.output_file_name = None self.output_file = None self.host ="localhost" - self.callbacks = defaultdict(list) - def on(self, event, callback, *args): - self.callbacks[event].append((callback, args)) + self.disco_port = int(os.environ['DISCO_PORT']) + self.put_port = int(os.environ['DDFS_PUT_PORT']) - def fire(self, event): - for callback, args in self.callbacks[event]: - callback(*args) def done(self): if self.output_file: self.output_file.close() - self.fire('done') + self.fire('done', self) def info(self): path = self.job.job_dir return dict( host = self.host, - disco_data = os.path.join(path, "data"), - ddfs_data = os.path.join(path, "ddfs"), - master = "http://localhost:8989", + disco_data = os.environ['DISCO_DATA'], + ddfs_data = os.path.join(os.environ['DISCO_DATA'], "ddfs"), + master = "http://localhost:{}".format(self.disco_port), taskid = self.id, jobfile = self.job.jobfile_path, mode = self.mode, jobname = self.job.name, - disco_port = 8989, - put_port = 8990 + disco_port = self.disco_port, + put_port = self.put_port ) @property diff --git a/leisure/transports.py b/leisure/transports.py index 3a3a2a5..1cbba1b 100644 --- a/leisure/transports.py +++ b/leisure/transports.py @@ -1,5 +1,112 @@ import fcntl import os +import socket +import errno + +from .event_emmiter import EventEmmiter +class Socket(EventEmmiter): + def __init__(self,address, delegate=None): + self.address = address + self.delegate = delegate + self.event_loop = None + + def listen(self, backlog, event_loop = None): + """Listen for incoming connections on this port. + + backlog - the maximum number of queued connectinos + + runLoop - the runLoop that will monitor this port for + incomming connections. Defaults to the + currentRunLoop if none is specified. + """ + + if type(self.address) == tuple: + serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM ) + socket_path = None + else: + socket_path = self.address + serversocket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM ) + + if os.path.exists(socket_path): + # possible stale socket let's see if any one is listning + err = serversocket.connect_ex(socket_path) + if err == errno.ECONNREFUSED: + os.unlink(socket_path) + else: + serversocket._reset() + raise RuntimeError("Socket path %s is in use" % socket_path ) + + + serversocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + serversocket.bind(self.address) + + if socket_path: # ensure the world can read/write this socket + os.chmod(socket_path, 666) + + serversocket.listen(backlog) + serversocket.setblocking(0) + + self._socket = serversocket + self.listening = True + self.connected = True + + if event_loop is None: + event_loop = current_event_loop() + + event_loop.add_reader(self._socket, self.new_connection, self._socket) + self.event_loop = event_loop + return self._socket.getsockname() + + def new_connection(self, srv_socket): + client, addr = srv_socket.accept() + new_socket = Socket(addr, self.delegate) + new_socket.connection_accepted(client, self.event_loop) + self.fire("accept", new_socket) + + def connection_accepted(self, socket, event_loop): + self._socket = socket + self.event_loop = event_loop + self.connected = True + self.event_loop.add_reader(socket, self.can_read, socket) + + def close(self): + if self._socket: + self.event_loop.remove_reader(self._socket) + self._socket = None + self.fire('closed', self) + + + def can_read(self, client): + + while True: + try: + buf = bytearray(4096) + mem = memoryview(buf) + bytes = client.recv_into(buf) + if bytes > 0: + self.fire('data', mem[:bytes]) + else: + self.close() + + + except socket.error,e: + if e[0] in (errno.EWOULDBLOCK, errno.EAGAIN): + # other end of the socket is full, so + # ask the runLoop when we can send more + # data + + break + else: + import pdb; pdb.set_trace() + # if we receive any other socket + # error we close the connection + # and raise and notify our delegate + + #self._reset() + #self.delegate.onError(self, e) + self.fire('error', e) + self.event_loop.remove_reader(client) + class Stream(object): def __init__(self, fd, delegate=None): diff --git a/leisure/worker.py b/leisure/worker.py index 6f14dc6..8aaab24 100644 --- a/leisure/worker.py +++ b/leisure/worker.py @@ -89,7 +89,7 @@ def response(proc, task, packet): elif type == 'MSG': puts(payload) return msg("OK","") - elif type == ('ERROR','FATAL'): + elif type in ('ERROR','FATAL'): # todo: fail, the task task.job.status = "dead" done(proc) @@ -99,7 +99,7 @@ def response(proc, task, packet): return msg('TASK',task.info()) elif type == "INPUT": - return msg('INPUT', [ + return msg('INPUT', [ u'done', [ [0, u'ok', [[0, task.input]]] ] @@ -113,4 +113,4 @@ def response(proc, task, packet): task.done() return done(proc) else: - pass + raise RuntimeError("Uknown message type '' received".format(type)) diff --git a/requirements.txt b/requirements.txt index a35219c..eb7f682 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1 +1,3 @@ git+https://github.com/trivio/disco.git#egg=disco +http-parser +Flask diff --git a/sample_jobs/word_count.py b/sample_jobs/word_count.py new file mode 100644 index 0000000..15556d0 --- /dev/null +++ b/sample_jobs/word_count.py @@ -0,0 +1,19 @@ +from disco.core import Job, result_iterator + +def map(line, params): + for word in line.split(): + yield word, 1 + +def reduce(iter, params): + from disco.util import kvgroup + for word, counts in kvgroup(sorted(iter)): + yield word, sum(counts) + +if __name__ == '__main__': + print "runnning job" + job = Job().run(input=["http://discoproject.org/media/text/chekhov.txt"], + map=map, + reduce=reduce) + for word, count in result_iterator(job.wait(show=True)): + #print(word, count) + pass \ No newline at end of file diff --git a/setup.py b/setup.py index ff56b7e..f750bec 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ except IOError: setup( name='leisure', py_modules = ['leisure'], - version='0.0.1', + version='0.0.2', description='local job runner for disco', long_description=README, author='Scott Robertson', diff --git a/tests/test_server.py b/tests/test_server.py new file mode 100644 index 0000000..b643244 --- /dev/null +++ b/tests/test_server.py @@ -0,0 +1,64 @@ +import sys +import os +from unittest import TestCase +from nose.tools import eq_ +import shutil +import tempfile +import threading + +import requests + +from leisure import server, event_loop +import logging +logging.basicConfig(stream=sys.stderr) + +from StringIO import StringIO +import gzip + + +class TestServer(object): + def setUp(self): + self.data_root = tempfile.mkdtemp() + os.environ['DISCO_DATA'] = self.data_root + self.event_loop = event_loop + self.event_loop.call_later(2, lambda: event_loop.stop()) + + def tearDown(self): + shutil.rmtree(self.data_root) + + def request(self, method, path=''): + loop = event_loop.current_event_loop() + context = [] + def fetch_data(addr): + def _(): # requests is blocking so needs it's own thread + + url = "http://{1}:{2}/{0}".format(path, *addr) + #import pdb; pdb.set_trace() + context.append( requests.get(url, timeout=3)) + loop.stop() + t = threading.Thread(target=_) + t.daemon = True + t.start() + + addr = server.start(event_loop) + self.event_loop.call_soon(fetch_data, addr) + self.event_loop.run() + return context.pop() + + def get(self, path=''): + return self.request('GET', path) + + + def test_get_compressed(self): + content = "line 1\nline 2\n" * 1024**2 + index = gzip.GzipFile(os.path.join(self.data_root, 'index.gz'), 'w') + index.write(content) + index.close() + + resp = self.get("disco/index.gz") + + data = gzip.GzipFile(fileobj=StringIO(resp.content)).read() + + eq_(data, content) + + diff --git a/tests/test_shuffle.py b/tests/test_shuffle.py index fdf1e78..27e6148 100644 --- a/tests/test_shuffle.py +++ b/tests/test_shuffle.py @@ -3,6 +3,7 @@ from unittest import TestCase from nose.tools import eq_ import tempfile import shutil +import gzip from leisure import shuffle, disco from leisure.path import makedirs @@ -32,7 +33,7 @@ class TestShuffle(TestCase): shutil.rmtree(self.data_root) def make_part_info(self, job_home): - part_dir = "partitions-{}".format(shuffle.timestamp()) + part_dir = "partitions-{}".format(disco.timestamp()) part_path = os.path.join( job_home, part_dir @@ -108,7 +109,7 @@ class TestShuffle(TestCase): filename = os.path.join(self.data_root, "blah") shuffle.write_index(filename, index) - read_lines = open(filename).readlines() + read_lines = gzip.GzipFile(filename).readlines() self.assertSequenceEqual(index, read_lines) def test_process_url_non_local(self): @@ -226,7 +227,5 @@ class TestShuffle(TestCase): mode="map", task_results=task_results )) - import pdb; pdb.set_trace() - pass