diff --git a/test/fiber/test_mutex.rb b/test/fiber/test_mutex.rb index a70c6992ab..258f5358a6 100644 --- a/test/fiber/test_mutex.rb +++ b/test/fiber/test_mutex.rb @@ -70,6 +70,38 @@ class TestFiberMutex < Test::Unit::TestCase thread.join end + def test_mutex_fiber_raise + mutex = Mutex.new + ran = false + + main = Thread.new do + mutex.lock + + thread = Thread.new do + scheduler = Scheduler.new + Thread.current.scheduler = scheduler + + f = Fiber.schedule do + assert_raise_with_message(RuntimeError, "bye") do + assert_same scheduler, Thread.scheduler + mutex.lock + end + ran = true + end + + Fiber.schedule do + f.raise "bye" + end + end + + thread.join + end + + main.join # causes mutex to be released + assert_equal false, mutex.locked? + assert_equal true, ran + end + def test_condition_variable mutex = Mutex.new condition = ConditionVariable.new diff --git a/thread_sync.c b/thread_sync.c index 94e4d35395..148e6091e6 100644 --- a/thread_sync.c +++ b/thread_sync.c @@ -214,18 +214,17 @@ VALUE rb_mutex_trylock(VALUE self) { rb_mutex_t *mutex = mutex_ptr(self); - VALUE locked = Qfalse; if (mutex->fiber == 0) { rb_fiber_t *fiber = GET_EC()->fiber_ptr; rb_thread_t *th = GET_THREAD(); mutex->fiber = fiber; - locked = Qtrue; mutex_locked(th, self); + return Qtrue; } - return locked; + return Qfalse; } /* @@ -246,6 +245,16 @@ mutex_owned_p(rb_fiber_t *fiber, rb_mutex_t *mutex) } } +static VALUE call_rb_scheduler_block(VALUE mutex) { + return rb_scheduler_block(rb_thread_current_scheduler(), mutex); +} + +static VALUE remove_from_mutex_lock_waiters(VALUE arg) { + struct list_node *node = (struct list_node*)arg; + list_del(node); + return Qnil; +} + static VALUE do_mutex_lock(VALUE self, int interruptible_p) { @@ -276,9 +285,7 @@ do_mutex_lock(VALUE self, int interruptible_p) if (scheduler != Qnil) { list_add_tail(&mutex->waitq, &w.node); - rb_scheduler_block(scheduler, self); - - list_del(&w.node); + rb_ensure(call_rb_scheduler_block, self, remove_from_mutex_lock_waiters, (VALUE)&w.node); if (!mutex->fiber) { mutex->fiber = fiber;