diff --git a/bootstraptest/test_yjit.rb b/bootstraptest/test_yjit.rb index 31bb626690..1eb6e393e7 100644 --- a/bootstraptest/test_yjit.rb +++ b/bootstraptest/test_yjit.rb @@ -4859,3 +4859,38 @@ assert_equal '["raised", "Module", "Object"]', %q{ ret += [foo(Class), foo(Class.new)] } + +# test TrueClass#=== before and after redefining TrueClass#== +assert_equal '[[true, false, false], [true, true, false], [true, :error, :error]]', %q{ + def true_eqq(x) + true === x + rescue NoMethodError + :error + end + + def test + [ + # first one is always true because rb_equal does object comparison before calling #== + true_eqq(true), + # these will use TrueClass#== + true_eqq(false), + true_eqq(:truthy), + ] + end + + results = [test] + + class TrueClass + def ==(x) + !x + end + end + + results << test + + class TrueClass + undef_method :== + end + + results << test +} unless rjit_enabled? # Not yet working on RJIT diff --git a/yjit/bindgen/src/main.rs b/yjit/bindgen/src/main.rs index 953ab0ac42..a7473c1bf6 100644 --- a/yjit/bindgen/src/main.rs +++ b/yjit/bindgen/src/main.rs @@ -380,6 +380,7 @@ fn main() { // From internal/object.h .allowlist_function("rb_class_allocate_instance") + .allowlist_function("rb_obj_equal") // From gc.h and internal/gc.h .allowlist_function("rb_obj_info") diff --git a/yjit/src/codegen.rs b/yjit/src/codegen.rs index 000f9fb516..62d6c085fb 100644 --- a/yjit/src/codegen.rs +++ b/yjit/src/codegen.rs @@ -250,6 +250,33 @@ impl JITState { } } + pub fn assume_expected_cfunc( + &mut self, + asm: &mut Assembler, + ocb: &mut OutlinedCb, + class: VALUE, + method: ID, + cfunc: *mut c_void, + ) -> bool { + let cme = unsafe { rb_callable_method_entry(class, method) }; + + if cme.is_null() { + return false; + } + + let def_type = unsafe { get_cme_def_type(cme) }; + if def_type != VM_METHOD_TYPE_CFUNC { + return false; + } + if unsafe { get_mct_func(get_cme_def_body_cfunc(cme)) } != cfunc { + return false; + } + + self.assume_method_lookup_stable(asm, ocb, cme); + + true + } + pub fn assume_method_lookup_stable(&mut self, asm: &mut Assembler, ocb: &mut OutlinedCb, cme: CmePtr) -> Option<()> { jit_ensure_block_entry_exit(self, asm, ocb)?; self.method_lookup_assumptions.push(cme); @@ -6164,6 +6191,34 @@ fn jit_rb_class_superclass( true } +// Codegen for rb_trueclass_case_equal() +fn jit_rb_trueclass_case_equal( + jit: &mut JITState, + asm: &mut Assembler, + ocb: &mut OutlinedCb, + _ci: *const rb_callinfo, + _cme: *const rb_callable_method_entry_t, + _block: Option, + _argc: i32, + _known_recv_class: Option, +) -> bool { + if !jit.assume_expected_cfunc( asm, ocb, unsafe { rb_cTrueClass }, ID!(eq), rb_obj_equal as _) { + return false; + } + + // Compare the arguments + asm_comment!(asm, "TrueClass#==="); + let arg1 = asm.stack_pop(1); + let arg0 = asm.stack_pop(1); + asm.cmp(arg0, arg1); + let ret_opnd = asm.csel_e(Qtrue.into(), Qfalse.into()); + + let stack_ret = asm.stack_push(Type::UnknownImm); + asm.mov(stack_ret, ret_opnd); + + true +} + fn jit_thread_s_current( _jit: &mut JITState, asm: &mut Assembler, @@ -10166,6 +10221,8 @@ pub fn yjit_reg_method_codegen_fns() { yjit_reg_method(rb_cString, "<<", jit_rb_str_concat); yjit_reg_method(rb_cString, "+@", jit_rb_str_uplus); + yjit_reg_method(rb_cTrueClass, "===", jit_rb_trueclass_case_equal); + yjit_reg_method(rb_cArray, "empty?", jit_rb_ary_empty_p); yjit_reg_method(rb_cArray, "length", jit_rb_ary_length); yjit_reg_method(rb_cArray, "size", jit_rb_ary_length); diff --git a/yjit/src/cruby.rs b/yjit/src/cruby.rs index d07262ad4f..9c44ed681f 100644 --- a/yjit/src/cruby.rs +++ b/yjit/src/cruby.rs @@ -805,6 +805,7 @@ pub(crate) mod ids { name: hash content: b"hash" name: respond_to_missing content: b"respond_to_missing?" name: to_ary content: b"to_ary" + name: eq content: b"==" } } diff --git a/yjit/src/cruby_bindings.inc.rs b/yjit/src/cruby_bindings.inc.rs index 70578ec7e9..a03c2d0f00 100644 --- a/yjit/src/cruby_bindings.inc.rs +++ b/yjit/src/cruby_bindings.inc.rs @@ -1024,6 +1024,7 @@ extern "C" { pub fn rb_attr_get(obj: VALUE, name: ID) -> VALUE; pub fn rb_obj_info_dump(obj: VALUE); pub fn rb_class_allocate_instance(klass: VALUE) -> VALUE; + pub fn rb_obj_equal(obj1: VALUE, obj2: VALUE) -> VALUE; pub fn rb_reg_new_ary(ary: VALUE, options: ::std::os::raw::c_int) -> VALUE; pub fn rb_ary_tmp_new_from_values( arg1: VALUE,