Fix Array#flatten for recursive array when given positive depth [Bug #17092]

This commit is contained in:
Marc-Andre Lafortune 2020-07-29 16:59:06 -04:00
Родитель 2bd1f827f1
Коммит 1b1ea7b3bc
2 изменённых файлов: 37 добавлений и 21 удалений

12
array.c
Просмотреть файл

@ -6943,8 +6943,6 @@ flatten(VALUE ary, int level)
} }
if (i == RARRAY_LEN(ary)) { if (i == RARRAY_LEN(ary)) {
return ary; return ary;
} else if (tmp == ary) {
rb_raise(rb_eArgError, "tried to flatten recursive array");
} }
result = ary_new(0, RARRAY_LEN(ary)); result = ary_new(0, RARRAY_LEN(ary));
@ -6955,12 +6953,14 @@ flatten(VALUE ary, int level)
rb_ary_push(stack, ary); rb_ary_push(stack, ary);
rb_ary_push(stack, LONG2NUM(i + 1)); rb_ary_push(stack, LONG2NUM(i + 1));
if (level < 0) {
vmemo = rb_hash_new(); vmemo = rb_hash_new();
RBASIC_CLEAR_CLASS(vmemo); RBASIC_CLEAR_CLASS(vmemo);
memo = st_init_numtable(); memo = st_init_numtable();
rb_hash_st_table_set(vmemo, memo); rb_hash_st_table_set(vmemo, memo);
st_insert(memo, (st_data_t)ary, (st_data_t)Qtrue); st_insert(memo, (st_data_t)ary, (st_data_t)Qtrue);
st_insert(memo, (st_data_t)tmp, (st_data_t)Qtrue); st_insert(memo, (st_data_t)tmp, (st_data_t)Qtrue);
}
ary = tmp; ary = tmp;
i = 0; i = 0;
@ -6974,20 +6974,24 @@ flatten(VALUE ary, int level)
} }
tmp = rb_check_array_type(elt); tmp = rb_check_array_type(elt);
if (RBASIC(result)->klass) { if (RBASIC(result)->klass) {
if (level < 0) {
RB_GC_GUARD(vmemo); RB_GC_GUARD(vmemo);
st_clear(memo); st_clear(memo);
}
rb_raise(rb_eRuntimeError, "flatten reentered"); rb_raise(rb_eRuntimeError, "flatten reentered");
} }
if (NIL_P(tmp)) { if (NIL_P(tmp)) {
rb_ary_push(result, elt); rb_ary_push(result, elt);
} }
else { else {
if (level < 0) {
id = (st_data_t)tmp; id = (st_data_t)tmp;
if (st_is_member(memo, id)) { if (st_is_member(memo, id)) {
st_clear(memo); st_clear(memo);
rb_raise(rb_eArgError, "tried to flatten recursive array"); rb_raise(rb_eArgError, "tried to flatten recursive array");
} }
st_insert(memo, id, (st_data_t)Qtrue); st_insert(memo, id, (st_data_t)Qtrue);
}
rb_ary_push(stack, ary); rb_ary_push(stack, ary);
rb_ary_push(stack, LONG2NUM(i)); rb_ary_push(stack, LONG2NUM(i));
ary = tmp; ary = tmp;
@ -6997,14 +7001,18 @@ flatten(VALUE ary, int level)
if (RARRAY_LEN(stack) == 0) { if (RARRAY_LEN(stack) == 0) {
break; break;
} }
if (level < 0) {
id = (st_data_t)ary; id = (st_data_t)ary;
st_delete(memo, &id, 0); st_delete(memo, &id, 0);
}
tmp = rb_ary_pop(stack); tmp = rb_ary_pop(stack);
i = NUM2LONG(tmp); i = NUM2LONG(tmp);
ary = rb_ary_pop(stack); ary = rb_ary_pop(stack);
} }
if (level < 0) {
st_clear(memo); st_clear(memo);
}
RBASIC_SET_CLASS(result, rb_obj_class(ary)); RBASIC_SET_CLASS(result, rb_obj_class(ary));
return result; return result;

Просмотреть файл

@ -886,6 +886,17 @@ class TestArray < Test::Unit::TestCase
assert_raise(NoMethodError, bug12738) { a.flatten.m } assert_raise(NoMethodError, bug12738) { a.flatten.m }
end end
def test_flatten_recursive
a = []
a << a
assert_raise(ArgumentError) { a.flatten }
b = [1]; c = [2, b]; b << c
assert_raise(ArgumentError) { b.flatten }
assert_equal([1, 2, b], b.flatten(1))
assert_equal([1, 2, 1, 2, 1, c], b.flatten(4))
end
def test_flatten! def test_flatten!
a1 = @cls[ 1, 2, 3] a1 = @cls[ 1, 2, 3]
a2 = @cls[ 5, 6 ] a2 = @cls[ 5, 6 ]
@ -2649,9 +2660,6 @@ class TestArray < Test::Unit::TestCase
def test_flatten_error def test_flatten_error
a = [] a = []
a << a
assert_raise(ArgumentError) { a.flatten }
f = [].freeze f = [].freeze
assert_raise(ArgumentError) { a.flatten!(1, 2) } assert_raise(ArgumentError) { a.flatten!(1, 2) }
assert_raise(TypeError) { a.flatten!(:foo) } assert_raise(TypeError) { a.flatten!(:foo) }