From e1d16a9e560a615e122e457325bcfb7c47228ed6 Mon Sep 17 00:00:00 2001 From: Jeremy Evans Date: Fri, 5 Mar 2021 12:25:51 -0800 Subject: [PATCH] Make Enumerator#{+,chain} create lazy chain if any included enumerator is lazy Implements [Feature #17347] --- enumerator.c | 21 ++++++++++++++++----- test/ruby/test_enumerator.rb | 12 ++++++++++++ 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/enumerator.c b/enumerator.c index 1c1ece0cfe..45620f352a 100644 --- a/enumerator.c +++ b/enumerator.c @@ -3137,6 +3137,20 @@ enum_chain_initialize(VALUE obj, VALUE enums) return obj; } +static VALUE +new_enum_chain(VALUE enums) { + long i; + VALUE obj = enum_chain_initialize(enum_chain_allocate(rb_cEnumChain), enums); + + for (i = 0; i < RARRAY_LEN(enums); i++) { + if (RTEST(rb_obj_is_kind_of(RARRAY_AREF(enums, i), rb_cLazy))) { + return enumerable_lazy(obj); + } + } + + return obj; +} + /* :nodoc: */ static VALUE enum_chain_init_copy(VALUE obj, VALUE orig) @@ -3306,8 +3320,7 @@ enum_chain(int argc, VALUE *argv, VALUE obj) { VALUE enums = rb_ary_new_from_values(1, &obj); rb_ary_cat(enums, argv, argc); - - return enum_chain_initialize(enum_chain_allocate(rb_cEnumChain), enums); + return new_enum_chain(enums); } /* @@ -3323,9 +3336,7 @@ enum_chain(int argc, VALUE *argv, VALUE obj) static VALUE enumerator_plus(VALUE obj, VALUE eobj) { - VALUE enums = rb_ary_new_from_args(2, obj, eobj); - - return enum_chain_initialize(enum_chain_allocate(rb_cEnumChain), enums); + return new_enum_chain(rb_ary_new_from_args(2, obj, eobj)); } /* diff --git a/test/ruby/test_enumerator.rb b/test/ruby/test_enumerator.rb index 9b615ff9db..4e698fc478 100644 --- a/test/ruby/test_enumerator.rb +++ b/test/ruby/test_enumerator.rb @@ -820,6 +820,18 @@ class TestEnumerator < Test::Unit::TestCase assert_equal([[3, 0], [4, 1]], [3].chain([4]).with_index.to_a) end + def test_lazy_chain + ea = (10..).lazy.select(&:even?).take(10) + ed = (20..).lazy.select(&:odd?) + chain = (ea + ed).select{|x| x % 3 == 0} + assert_equal(12, chain.next) + assert_equal(18, chain.next) + assert_equal(24, chain.next) + assert_equal(21, chain.next) + assert_equal(27, chain.next) + assert_equal(33, chain.next) + end + def test_produce assert_raise(ArgumentError) { Enumerator.produce }