From 5b7ccc0629baa7cd2c7ab92802ee1bf62e3ec0f4 Mon Sep 17 00:00:00 2001 From: nobu Date: Thu, 26 Aug 2010 22:57:39 +0000 Subject: [PATCH] * array.c (rb_ary_shuffle_bang): bail out from modification during shuffle. * array.c (rb_ary_sample): ditto. git-svn-id: svn+ssh://ci.ruby-lang.org/ruby/trunk@29108 b2dd03c8-39d4-4d8f-98ff-823fe69b080e --- ChangeLog | 7 +++++ array.c | 64 +++++++++++++++++++++++++++++------------ test/ruby/test_array.rb | 56 +++++++++++++++++++++++++++++++++++- 3 files changed, 108 insertions(+), 19 deletions(-) diff --git a/ChangeLog b/ChangeLog index c7853b63b8..d5474793aa 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,3 +1,10 @@ +Fri Aug 27 07:57:34 2010 Nobuyoshi Nakada + + * array.c (rb_ary_shuffle_bang): bail out from modification during + shuffle. + + * array.c (rb_ary_sample): ditto. + Fri Aug 27 05:11:51 2010 Tanaka Akira * ext/pathname/pathname.c (path_sysopen): Pathname#sysopen translated diff --git a/array.c b/array.c index 0e51bb0b98..8a828c506c 100644 --- a/array.c +++ b/array.c @@ -20,6 +20,8 @@ #endif #include +#define numberof(array) (int)(sizeof(array) / sizeof((array)[0])) + VALUE rb_cArray; static ID id_cmp; @@ -3748,8 +3750,8 @@ static VALUE sym_random; static VALUE rb_ary_shuffle_bang(int argc, VALUE *argv, VALUE ary) { - VALUE *ptr, opts, randgen = rb_cRandom; - long i = RARRAY_LEN(ary); + VALUE *ptr, opts, *snap_ptr, randgen = rb_cRandom; + long i, snap_len; if (OPTHASH_GIVEN_P(opts)) { randgen = rb_hash_lookup2(opts, sym_random, randgen); @@ -3758,10 +3760,17 @@ rb_ary_shuffle_bang(int argc, VALUE *argv, VALUE ary) rb_raise(rb_eArgError, "wrong number of arguments (%d for 0)", argc); } rb_ary_modify(ary); + i = RARRAY_LEN(ary); ptr = RARRAY_PTR(ary); + snap_len = i; + snap_ptr = ptr; while (i) { long j = RAND_UPTO(i); - VALUE tmp = ptr[--i]; + VALUE tmp; + if (snap_len != RARRAY_LEN(ary) || snap_ptr != RARRAY_PTR(ary)) { + rb_raise(rb_eRuntimeError, "modified during shuffle"); + } + tmp = ptr[--i]; ptr[i] = ptr[j]; ptr[j] = tmp; } @@ -3814,37 +3823,54 @@ static VALUE rb_ary_sample(int argc, VALUE *argv, VALUE ary) { VALUE nv, result, *ptr; - VALUE opts, randgen = rb_cRandom; + VALUE opts, snap, randgen = rb_cRandom; long n, len, i, j, k, idx[10]; + double rnds[numberof(idx)]; - len = RARRAY_LEN(ary); if (OPTHASH_GIVEN_P(opts)) { randgen = rb_hash_lookup2(opts, sym_random, randgen); } + ptr = RARRAY_PTR(ary); + len = RARRAY_LEN(ary); if (argc == 0) { if (len == 0) return Qnil; - i = len == 1 ? 0 : RAND_UPTO(len); + if (len == 1) { + i = 0; + } + else { + double x = rb_random_real(randgen); + if ((len = RARRAY_LEN(ary)) == 0) return Qnil; + i = (long)(x * len); + } return RARRAY_PTR(ary)[i]; } rb_scan_args(argc, argv, "1", &nv); n = NUM2LONG(nv); if (n < 0) rb_raise(rb_eArgError, "negative sample number"); - ptr = RARRAY_PTR(ary); + if (n > len) n = len; + if (n <= numberof(idx)) { + for (i = 0; i < n; ++i) { + rnds[i] = rb_random_real(randgen); + } + } len = RARRAY_LEN(ary); + ptr = RARRAY_PTR(ary); if (n > len) n = len; switch (n) { - case 0: return rb_ary_new2(0); + case 0: + return rb_ary_new2(0); case 1: - return rb_ary_new4(1, &ptr[RAND_UPTO(len)]); + i = (long)(rnds[0] * len); + return rb_ary_new4(1, &ptr[i]); case 2: - i = RAND_UPTO(len); - j = RAND_UPTO(len-1); + i = (long)(rnds[0] * len); + j = (long)(rnds[1] * (len-1)); if (j >= i) j++; return rb_ary_new3(2, ptr[i], ptr[j]); case 3: - i = RAND_UPTO(len); - j = RAND_UPTO(len-1); - k = RAND_UPTO(len-2); + i = (long)(rnds[0] * len); + j = (long)(rnds[1] * (len-1)); + k = (long)(rnds[2] * (len-2)); { long l = j, g = i; if (j >= i) l = i, g = ++j; @@ -3852,12 +3878,12 @@ rb_ary_sample(int argc, VALUE *argv, VALUE ary) } return rb_ary_new3(3, ptr[i], ptr[j], ptr[k]); } - if ((size_t)n < sizeof(idx)/sizeof(idx[0])) { + if (n <= numberof(idx)) { VALUE *ptr_result; - long sorted[sizeof(idx)/sizeof(idx[0])]; - sorted[0] = idx[0] = RAND_UPTO(len); + long sorted[numberof(idx)]; + sorted[0] = idx[0] = (long)(rnds[0] * len); for (i=1; iklass = 0; ptr_result = RARRAY_PTR(result); RB_GC_GUARD(ary); for (i=0; iklass = rb_cArray; } ARY_SET_LEN(result, n); diff --git a/test/ruby/test_array.rb b/test/ruby/test_array.rb index 4c3aba0589..44f71d3495 100644 --- a/test/ruby/test_array.rb +++ b/test/ruby/test_array.rb @@ -1901,7 +1901,6 @@ class TestArray < Test::Unit::TestCase end def test_shuffle_random - cc = nil gen = proc do 10000000 end @@ -1911,6 +1910,16 @@ class TestArray < Test::Unit::TestCase assert_raise(RangeError) { [*0..2].shuffle(random: gen) } + + ary = (0...10000).to_a + gen = proc do + ary.replace([]) + 0.5 + end + class << gen + alias rand call + end + assert_raise(RuntimeError) {ary.shuffle!(random: gen)} end def test_sample @@ -1951,6 +1960,51 @@ class TestArray < Test::Unit::TestCase end end + def test_sample_random + ary = (0...10000).to_a + assert_raise(ArgumentError) {ary.sample(1, 2, random: nil)} + gen0 = proc do + 0.5 + end + class << gen0 + alias rand call + end + gen1 = proc do + ary.replace([]) + 0.5 + end + class << gen1 + alias rand call + end + assert_equal(5000, ary.sample(random: gen0)) + assert_nil(ary.sample(random: gen1)) + assert_equal([], ary) + ary = (0...10000).to_a + assert_equal([5000], ary.sample(1, random: gen0)) + assert_equal([], ary.sample(1, random: gen1)) + assert_equal([], ary) + ary = (0...10000).to_a + assert_equal([5000, 4999], ary.sample(2, random: gen0)) + assert_equal([], ary.sample(2, random: gen1)) + assert_equal([], ary) + ary = (0...10000).to_a + assert_equal([5000, 4999, 5001], ary.sample(3, random: gen0)) + assert_equal([], ary.sample(3, random: gen1)) + assert_equal([], ary) + ary = (0...10000).to_a + assert_equal([5000, 4999, 5001, 4998], ary.sample(4, random: gen0)) + assert_equal([], ary.sample(4, random: gen1)) + assert_equal([], ary) + ary = (0...10000).to_a + assert_equal([5000, 4999, 5001, 4998, 5002, 4997, 5003, 4996, 5004, 4995], ary.sample(10, random: gen0)) + assert_equal([], ary.sample(10, random: gen1)) + assert_equal([], ary) + ary = (0...10000).to_a + assert_equal([5000, 0, 5001, 2, 5002, 4, 5003, 6, 5004, 8, 5005], ary.sample(11, random: gen0)) + ary.sample(11, random: gen1) # implementation detail, may change in the future + assert_equal([], ary) + end + def test_cycle a = [] [0, 1, 2].cycle do |i|