diff --git a/bootstraptest/test_yjit.rb b/bootstraptest/test_yjit.rb index 25085f0e8d..349417595f 100644 --- a/bootstraptest/test_yjit.rb +++ b/bootstraptest/test_yjit.rb @@ -3436,3 +3436,23 @@ assert_equal '1', %q{ foo(1) } + +# case-when with redefined === +assert_equal 'ok', %q{ + class Symbol + def ===(a) + true + end + end + + def cw(arg) + case arg + when :b + :ok + when 4 + :ng + end + end + + cw(4) +} diff --git a/yjit/bindgen/src/main.rs b/yjit/bindgen/src/main.rs index ffcc148685..8098de87b0 100644 --- a/yjit/bindgen/src/main.rs +++ b/yjit/bindgen/src/main.rs @@ -100,6 +100,10 @@ fn main() { // From internal/hash.h .allowlist_function("rb_hash_new_with_size") .allowlist_function("rb_hash_resurrect") + .allowlist_function("rb_hash_stlike_foreach") + + // From include/ruby/st.h + .allowlist_type("st_retval") // From include/ruby/internal/intern/hash.h .allowlist_function("rb_hash_aset") diff --git a/yjit/src/codegen.rs b/yjit/src/codegen.rs index 5b7eb8f67e..7c2c0dbe87 100644 --- a/yjit/src/codegen.rs +++ b/yjit/src/codegen.rs @@ -16,7 +16,7 @@ use std::cmp; use std::collections::HashMap; use std::ffi::CStr; use std::mem::{self, size_of}; -use std::os::raw::c_uint; +use std::os::raw::{c_int, c_uint}; use std::ptr; use std::slice; @@ -3154,7 +3154,27 @@ fn gen_opt_case_dispatch( // Supporting only Fixnum for now so that the implementation can be an equality check. let key_opnd = ctx.stack_pop(1); let comptime_key = jit_peek_at_stack(jit, ctx, 0); - if comptime_key.fixnum_p() && comptime_key.0 <= u32::MAX.as_usize() { + + // Check that all cases are fixnums to avoid having to register BOP assumptions on + // all the types that case hashes support. This spends compile time to save memory. + fn case_hash_all_fixnum_p(hash: VALUE) -> bool { + let mut all_fixnum = true; + unsafe { + unsafe extern "C" fn per_case(key: st_data_t, _value: st_data_t, data: st_data_t) -> c_int { + (if VALUE(key as usize).fixnum_p() { + ST_CONTINUE + } else { + (data as *mut bool).write(false); + ST_STOP + }) as c_int + } + rb_hash_stlike_foreach(hash, Some(per_case), (&mut all_fixnum) as *mut _ as st_data_t); + } + + all_fixnum + } + + if comptime_key.fixnum_p() && comptime_key.0 <= u32::MAX.as_usize() && case_hash_all_fixnum_p(case_hash) { if !assume_bop_not_redefined(jit, ocb, INTEGER_REDEFINED_OP_FLAG, BOP_EQQ) { return CantCompile; } diff --git a/yjit/src/cruby_bindings.inc.rs b/yjit/src/cruby_bindings.inc.rs index 671c0ad353..12617af86c 100644 --- a/yjit/src/cruby_bindings.inc.rs +++ b/yjit/src/cruby_bindings.inc.rs @@ -234,6 +234,19 @@ pub const RUBY_FL_SINGLETON: ruby_fl_type = 4096; pub type ruby_fl_type = i32; pub type st_data_t = ::std::os::raw::c_ulong; pub type st_index_t = st_data_t; +pub const ST_CONTINUE: st_retval = 0; +pub const ST_STOP: st_retval = 1; +pub const ST_DELETE: st_retval = 2; +pub const ST_CHECK: st_retval = 3; +pub const ST_REPLACE: st_retval = 4; +pub type st_retval = u32; +pub type st_foreach_callback_func = ::std::option::Option< + unsafe extern "C" fn( + arg1: st_data_t, + arg2: st_data_t, + arg3: st_data_t, + ) -> ::std::os::raw::c_int, +>; pub const RARRAY_EMBED_FLAG: ruby_rarray_flags = 8192; pub const RARRAY_EMBED_LEN_MASK: ruby_rarray_flags = 4161536; pub const RARRAY_TRANSIENT_FLAG: ruby_rarray_flags = 33554432; @@ -1023,6 +1036,13 @@ extern "C" { extern "C" { pub fn rb_ec_str_resurrect(ec: *mut rb_execution_context_struct, str_: VALUE) -> VALUE; } +extern "C" { + pub fn rb_hash_stlike_foreach( + hash: VALUE, + func: st_foreach_callback_func, + arg: st_data_t, + ) -> ::std::os::raw::c_int; +} extern "C" { pub fn rb_hash_new_with_size(size: st_index_t) -> VALUE; }