From 98b4256aa7a558e19739ac1d39ba10179277e1f4 Mon Sep 17 00:00:00 2001 From: Maxime Chevalier-Boisvert Date: Thu, 3 Aug 2023 16:09:18 -0400 Subject: [PATCH] YJIT: handle expandarray_rhs_too_small case (#8161) * YJIT: handle expandarray_rhs_too_small case YJIT: fix csel bug in x86 backend, add test Remove commented out lines Refactor expandarray to use chain guards Propagate Type::Nil when known Update yjit/src/codegen.rs Co-authored-by: Takashi Kokubun * Add missing counter, use get_array_ptr() in expandarray * Make change suggested by Kokubun to reuse loop --------- Co-authored-by: Takashi Kokubun --- bootstraptest/test_yjit.rb | 11 ++++ yjit/src/backend/x86_64/mod.rs | 66 +++++++++++++++++---- yjit/src/codegen.rs | 101 +++++++++++++++++++++++++-------- yjit/src/stats.rs | 3 +- 4 files changed, 143 insertions(+), 38 deletions(-) diff --git a/bootstraptest/test_yjit.rb b/bootstraptest/test_yjit.rb index 80d0ef00d8..6cc509f212 100644 --- a/bootstraptest/test_yjit.rb +++ b/bootstraptest/test_yjit.rb @@ -2276,6 +2276,17 @@ assert_equal '[1, 2, nil]', %q{ expandarray_rhs_too_small } +assert_equal '[nil, 2, nil]', %q{ + def foo(arr) + a, b, c = arr + end + + a, b, c1 = foo([0, 1]) + a, b, c2 = foo([0, 1, 2]) + a, b, c3 = foo([0, 1]) + [c1, c2, c3] +} + assert_equal '[1, [2]]', %q{ def expandarray_splat a, *b = [1, 2] diff --git a/yjit/src/backend/x86_64/mod.rs b/yjit/src/backend/x86_64/mod.rs index 0cc276fca1..1ae5ee7477 100644 --- a/yjit/src/backend/x86_64/mod.rs +++ b/yjit/src/backend/x86_64/mod.rs @@ -428,11 +428,31 @@ impl Assembler } } - fn emit_csel(cb: &mut CodeBlock, truthy: Opnd, falsy: Opnd, out: Opnd, cmov_fn: fn(&mut CodeBlock, X86Opnd, X86Opnd)) { - if out != truthy { - mov(cb, out.into(), truthy.into()); + fn emit_csel( + cb: &mut CodeBlock, + truthy: Opnd, + falsy: Opnd, + out: Opnd, + cmov_fn: fn(&mut CodeBlock, X86Opnd, X86Opnd), + cmov_neg: fn(&mut CodeBlock, X86Opnd, X86Opnd)){ + + // Assert that output is a register + out.unwrap_reg(); + + // If the truthy value is a memory operand + if let Opnd::Mem(_) = truthy { + if out != falsy { + mov(cb, out.into(), falsy.into()); + } + + cmov_fn(cb, out.into(), truthy.into()); + } else { + if out != truthy { + mov(cb, out.into(), truthy.into()); + } + + cmov_neg(cb, out.into(), falsy.into()); } - cmov_fn(cb, out.into(), falsy.into()); } //dbg!(&self.insns); @@ -724,28 +744,28 @@ impl Assembler Insn::Breakpoint => int3(cb), Insn::CSelZ { truthy, falsy, out } => { - emit_csel(cb, *truthy, *falsy, *out, cmovnz); + emit_csel(cb, *truthy, *falsy, *out, cmovz, cmovnz); }, Insn::CSelNZ { truthy, falsy, out } => { - emit_csel(cb, *truthy, *falsy, *out, cmovz); + emit_csel(cb, *truthy, *falsy, *out, cmovnz, cmovz); }, Insn::CSelE { truthy, falsy, out } => { - emit_csel(cb, *truthy, *falsy, *out, cmovne); + emit_csel(cb, *truthy, *falsy, *out, cmove, cmovne); }, Insn::CSelNE { truthy, falsy, out } => { - emit_csel(cb, *truthy, *falsy, *out, cmove); + emit_csel(cb, *truthy, *falsy, *out, cmovne, cmove); }, Insn::CSelL { truthy, falsy, out } => { - emit_csel(cb, *truthy, *falsy, *out, cmovge); + emit_csel(cb, *truthy, *falsy, *out, cmovl, cmovge); }, Insn::CSelLE { truthy, falsy, out } => { - emit_csel(cb, *truthy, *falsy, *out, cmovg); + emit_csel(cb, *truthy, *falsy, *out, cmovle, cmovg); }, Insn::CSelG { truthy, falsy, out } => { - emit_csel(cb, *truthy, *falsy, *out, cmovle); + emit_csel(cb, *truthy, *falsy, *out, cmovg, cmovle); }, Insn::CSelGE { truthy, falsy, out } => { - emit_csel(cb, *truthy, *falsy, *out, cmovl); + emit_csel(cb, *truthy, *falsy, *out, cmovge, cmovl); } Insn::LiveReg { .. } => (), // just a reg alloc signal, no code Insn::PadInvalPatch => { @@ -1177,4 +1197,26 @@ mod tests { 0x23: call rax "}); } + + #[test] + fn test_cmov_mem() { + let (mut asm, mut cb) = setup_asm(); + + let top = Opnd::mem(64, SP, 0); + let ary_opnd = SP; + let array_len_opnd = Opnd::mem(64, SP, 16); + + asm.cmp(array_len_opnd, 1.into()); + let elem_opnd = asm.csel_g(Opnd::mem(64, ary_opnd, 0), Qnil.into()); + asm.mov(top, elem_opnd); + + asm.compile_with_num_regs(&mut cb, 1); + + assert_disasm!(cb, "48837b1001b804000000480f4f03488903", {" + 0x0: cmp qword ptr [rbx + 0x10], 1 + 0x5: mov eax, 4 + 0xa: cmovg rax, qword ptr [rbx] + 0xe: mov qword ptr [rbx], rax + "}); + } } diff --git a/yjit/src/codegen.rs b/yjit/src/codegen.rs index 312bf3db16..ee299574a7 100644 --- a/yjit/src/codegen.rs +++ b/yjit/src/codegen.rs @@ -1467,10 +1467,10 @@ fn guard_object_is_not_ruby2_keyword_hash( fn gen_expandarray( jit: &mut JITState, asm: &mut Assembler, - _ocb: &mut OutlinedCb, + ocb: &mut OutlinedCb, ) -> Option { // Both arguments are rb_num_t which is unsigned - let num = jit.get_arg(0).as_usize(); + let num = jit.get_arg(0).as_u32(); let flag = jit.get_arg(1).as_usize(); // If this instruction has the splat flag, then bail out. @@ -1500,6 +1500,23 @@ fn gen_expandarray( return Some(KeepCompiling); } + // Defer compilation so we can specialize on a runtime `self` + if !jit.at_current_insn() { + defer_compilation(jit, asm, ocb); + return Some(EndBlock); + } + + let comptime_recv = jit.peek_at_stack(&asm.ctx, 0); + + // If the comptime receiver is not an array, bail + if comptime_recv.class_of() != unsafe { rb_cArray } { + gen_counter_incr(asm, Counter::expandarray_comptime_not_array); + return None; + } + + // Get the compile-time array length + let comptime_len = unsafe { rb_yjit_array_len(comptime_recv) as u32 }; + // Move the array from the stack and check that it's an array. guard_object_is_array( asm, @@ -1507,42 +1524,75 @@ fn gen_expandarray( array_opnd.into(), Counter::expandarray_not_array, ); - let array_opnd = asm.stack_pop(1); // pop after using the type info // If we don't actually want any values, then just return. if num == 0 { + asm.stack_pop(1); // pop the array return Some(KeepCompiling); } + let array_opnd = asm.stack_opnd(0); let array_reg = asm.load(array_opnd); let array_len_opnd = get_array_len(asm, array_reg); - // Only handle the case where the number of values in the array is greater - // than or equal to the number of values requested. - asm.cmp(array_len_opnd, num.into()); - asm.jl(Target::side_exit(Counter::expandarray_rhs_too_small)); + /* + // FIXME: JCC_JB not implemented + // Guard on the comptime/expected array length + if comptime_len >= num { + asm.comment(&format!("guard array length >= {}", num)); + asm.cmp(array_len_opnd, num.into()); + jit_chain_guard( + JCC_JB, + jit, + asm, + ocb, + OPT_AREF_MAX_CHAIN_DEPTH, + Counter::expandarray_chain_max_depth, + ); + } else { + asm.comment(&format!("guard array length == {}", comptime_len)); + asm.cmp(array_len_opnd, comptime_len.into()); + jit_chain_guard( + JCC_JNE, + jit, + asm, + ocb, + OPT_AREF_MAX_CHAIN_DEPTH, + Counter::expandarray_chain_max_depth, + ); + } + */ - // Load the address of the embedded array into REG1. - // (struct RArray *)(obj)->as.ary - let array_reg = asm.load(array_opnd); - let ary_opnd = asm.lea(Opnd::mem(VALUE_BITS, array_reg, RUBY_OFFSET_RARRAY_AS_ARY)); - - // Conditionally load the address of the heap array into REG1. - // (struct RArray *)(obj)->as.heap.ptr - let flags_opnd = Opnd::mem(VALUE_BITS, array_reg, RUBY_OFFSET_RBASIC_FLAGS); - asm.test(flags_opnd, Opnd::UImm(RARRAY_EMBED_FLAG as u64)); - let heap_ptr_opnd = Opnd::mem( - usize::BITS as u8, - asm.load(array_opnd), - RUBY_OFFSET_RARRAY_AS_HEAP_PTR, + asm.comment(&format!("guard array length == {}", comptime_len)); + asm.cmp(array_len_opnd, comptime_len.into()); + jit_chain_guard( + JCC_JNE, + jit, + asm, + ocb, + OPT_AREF_MAX_CHAIN_DEPTH, + Counter::expandarray_chain_max_depth, ); - let ary_opnd = asm.csel_nz(ary_opnd, heap_ptr_opnd); + + let array_opnd = asm.stack_pop(1); // pop after using the type info + + // Load the pointer to the embedded or heap array + let ary_opnd = if comptime_len > 0 { + let array_reg = asm.load(array_opnd); + Some(get_array_ptr(asm, array_reg)) + } else { + None + }; // Loop backward through the array and push each element onto the stack. for i in (0..num).rev() { - let top = asm.stack_push(Type::Unknown); - let offset = i32::try_from(i * SIZEOF_VALUE).unwrap(); - asm.mov(top, Opnd::mem(64, ary_opnd, offset)); + let top = asm.stack_push(if i < comptime_len { Type::Unknown } else { Type::Nil }); + let offset = i32::try_from(i * (SIZEOF_VALUE as u32)).unwrap(); + + // Missing elements are Qnil + asm.comment(&format!("load array[{}]", i)); + let elem_opnd = if i < comptime_len { Opnd::mem(64, ary_opnd.unwrap(), offset) } else { Qnil.into() }; + asm.mov(top, elem_opnd); } Some(KeepCompiling) @@ -5324,6 +5374,7 @@ fn get_array_ptr(asm: &mut Assembler, array_reg: Opnd) -> Opnd { array_reg, RUBY_OFFSET_RARRAY_AS_HEAP_PTR, ); + // Load the address of the embedded array // (struct RArray *)(obj)->as.ary let ary_opnd = asm.lea(Opnd::mem(VALUE_BITS, array_reg, RUBY_OFFSET_RARRAY_AS_ARY)); @@ -7394,7 +7445,7 @@ fn gen_leave( ocb: &mut OutlinedCb, ) -> Option { // Only the return value should be on the stack - assert_eq!(1, asm.ctx.get_stack_size()); + assert_eq!(1, asm.ctx.get_stack_size(), "leave instruction expects stack size 1, but was: {}", asm.ctx.get_stack_size()); let ocb_asm = Assembler::new(); diff --git a/yjit/src/stats.rs b/yjit/src/stats.rs index 686516cb9b..c9cf95daca 100644 --- a/yjit/src/stats.rs +++ b/yjit/src/stats.rs @@ -358,7 +358,8 @@ make_counters! { expandarray_splat, expandarray_postarg, expandarray_not_array, - expandarray_rhs_too_small, + expandarray_comptime_not_array, + expandarray_chain_max_depth, // getblockparam gbp_wb_required,