From beafae97505f9def3967e958bb1f7bc7fd7b9a7a Mon Sep 17 00:00:00 2001 From: Randy Stauner Date: Wed, 13 Nov 2024 10:25:09 -0700 Subject: [PATCH] YJIT: Specialize `String#[]` (`String#slice`) with fixnum arguments (#12069) * YJIT: Specialize `String#[]` (`String#slice`) with fixnum arguments String#[] is in the top few C calls of several YJIT benchmarks: liquid-compile rubocop mail sudoku This speeds up these benchmarks by 1-2%. * YJIT: Try harder to get type info for `String#[]` In the large generated code of the mail gem the context doesn't have the type info. In that case if we peek at the stack and add a guard we can still apply the specialization and it speeds up the mail benchmark by 5%. Co-authored-by: Maxime Chevalier-Boisvert Co-authored-by: Takashi Kokubun (k0kubun) --------- Co-authored-by: Maxime Chevalier-Boisvert Co-authored-by: Takashi Kokubun (k0kubun) --- bootstraptest/test_yjit.rb | 40 +++++++++++++++++ internal/string.h | 1 + string.c | 10 +++-- yjit/bindgen/src/main.rs | 1 + yjit/src/codegen.rs | 78 ++++++++++++++++++++++++++++++++++ yjit/src/cruby_bindings.inc.rs | 6 +++ yjit/src/stats.rs | 1 + 7 files changed, 134 insertions(+), 3 deletions(-) diff --git a/bootstraptest/test_yjit.rb b/bootstraptest/test_yjit.rb index d647992986..40f1db2ec5 100644 --- a/bootstraptest/test_yjit.rb +++ b/bootstraptest/test_yjit.rb @@ -5229,3 +5229,43 @@ assert_equal '[true, true]', <<~'RUBY' [pack, with_buffer] RUBY + +assert_equal 'ok', <<~'RUBY' + def error(klass) + yield + rescue klass + true + end + + def test + str = "こんにちは" + substr = "にち" + failures = [] + + # Use many small statements to keep context for each slice call smaller than MAX_CTX_TEMPS + + str[1] == "ん" && str.slice(4) == "は" || failures << :index + str[5].nil? && str.slice(5).nil? 
|| failures << :index_end + + str[1, 2] == "んに" && str.slice(2, 1) == "に" || failures << :beg_len + str[5, 1] == "" && str.slice(5, 1) == "" || failures << :beg_len_end + + str[1..2] == "んに" && str.slice(2..2) == "に" || failures << :range + + str[/に./] == "にち" && str.slice(/に./) == "にち" || failures << :regexp + + str[/に./, 0] == "にち" && str.slice(/に./, 0) == "にち" || failures << :regexp_cap0 + + str[/に(.)/, 1] == "ち" && str.slice(/に(.)/, 1) == "ち" || failures << :regexp_cap1 + + str[substr] == substr && str.slice(substr) == substr || failures << :substr + + error(TypeError) { str[Object.new] } && error(TypeError) { str.slice(Object.new, 1) } || failures << :type_error + error(RangeError) { str[Float::INFINITY] } && error(RangeError) { str.slice(Float::INFINITY) } || failures << :range_error + + return "ok" if failures.empty? + {failures: failures} + end + + test +RUBY diff --git a/internal/string.h b/internal/string.h index efeb0827c9..87cefa13d5 100644 --- a/internal/string.h +++ b/internal/string.h @@ -53,6 +53,7 @@ int rb_enc_str_coderange_scan(VALUE str, rb_encoding *enc); int rb_ascii8bit_appendable_encoding_index(rb_encoding *enc, unsigned int code); VALUE rb_str_include(VALUE str, VALUE arg); VALUE rb_str_byte_substr(VALUE str, VALUE beg, VALUE len); +VALUE rb_str_substr_two_fixnums(VALUE str, VALUE beg, VALUE len, int empty); VALUE rb_str_tmp_frozen_no_embed_acquire(VALUE str); void rb_str_make_embedded(VALUE); VALUE rb_str_upto_each(VALUE, VALUE, int, int (*each)(VALUE, VALUE), VALUE); diff --git a/string.c b/string.c index f6f67e185c..d80d214b50 100644 --- a/string.c +++ b/string.c @@ -3152,6 +3152,12 @@ rb_str_substr(VALUE str, long beg, long len) return str_substr(str, beg, len, TRUE); } +VALUE +rb_str_substr_two_fixnums(VALUE str, VALUE beg, VALUE len, int empty) +{ + return str_substr(str, NUM2LONG(beg), NUM2LONG(len), empty); +} + static VALUE str_substr(VALUE str, long beg, long len, int empty) { @@ -5680,9 +5686,7 @@ rb_str_aref_m(int argc, VALUE 
*argv, VALUE str) return rb_str_subpat(str, argv[0], argv[1]); } else { - long beg = NUM2LONG(argv[0]); - long len = NUM2LONG(argv[1]); - return rb_str_substr(str, beg, len); + return rb_str_substr_two_fixnums(str, argv[0], argv[1], TRUE); } } rb_check_arity(argc, 1, 2); diff --git a/yjit/bindgen/src/main.rs b/yjit/bindgen/src/main.rs index 5f0f599b9a..e8f88cef64 100644 --- a/yjit/bindgen/src/main.rs +++ b/yjit/bindgen/src/main.rs @@ -226,6 +226,7 @@ fn main() { .allowlist_function("rb_str_concat_literals") .allowlist_function("rb_obj_as_string_result") .allowlist_function("rb_str_byte_substr") + .allowlist_function("rb_str_substr_two_fixnums") // From include/ruby/internal/intern/parse.h .allowlist_function("rb_backref_get") diff --git a/yjit/src/codegen.rs b/yjit/src/codegen.rs index bff7960990..d810a9f0dd 100644 --- a/yjit/src/codegen.rs +++ b/yjit/src/codegen.rs @@ -5792,6 +5792,82 @@ fn jit_rb_str_byteslice( true } +fn jit_rb_str_aref_m( + jit: &mut JITState, + asm: &mut Assembler, + _ci: *const rb_callinfo, + _cme: *const rb_callable_method_entry_t, + _block: Option<BlockHandler>, + argc: i32, + _known_recv_class: Option<VALUE>, +) -> bool { + // In yjit-bench the most common usages by far are single fixnum or two fixnums. + // rb_str_substr should be leaf if indexes are fixnums + if argc == 2 { + match (asm.ctx.get_opnd_type(StackOpnd(0)), asm.ctx.get_opnd_type(StackOpnd(1))) { + (Type::Fixnum, Type::Fixnum) => {}, + // There is a two-argument form of (RegExp, Fixnum) which needs a different c func. + // Other types will raise. + _ => { return false }, + } + } else if argc == 1 { + match asm.ctx.get_opnd_type(StackOpnd(0)) { + Type::Fixnum => {}, + // Besides Fixnum this could also be a Range or a RegExp which are handled by separate c funcs. + // Other types will raise. + _ => { + // If the context doesn't have the type info we try a little harder. 
+ let comptime_arg = jit.peek_at_stack(&asm.ctx, 0); + let arg0 = asm.stack_opnd(0); + if comptime_arg.fixnum_p() { + asm.test(arg0, Opnd::UImm(RUBY_FIXNUM_FLAG as u64)); + + jit_chain_guard( + JCC_JZ, + jit, + asm, + SEND_MAX_DEPTH, + Counter::guard_send_str_aref_not_fixnum, + ); + } else { + return false + } + }, + } + } else { + return false + } + + asm_comment!(asm, "String#[]"); + + // rb_str_substr allocates a substring + jit_prepare_call_with_gc(jit, asm); + + // Get stack operands after potential SP change + + // The "empty" arg distinguishes between the normal "one arg" behavior + // and the "two arg" special case that returns an empty string + // when the begin index is the length of the string. + // See the usages of rb_str_substr in string.c for more information. + let (beg_idx, empty, len) = if argc == 2 { + (1, Opnd::Imm(1), asm.stack_opnd(0)) + } else { + // If there is only one arg, the length will be 1. + (0, Opnd::Imm(0), VALUE::fixnum_from_usize(1).into()) + }; + + let beg = asm.stack_opnd(beg_idx); + let recv = asm.stack_opnd(beg_idx + 1); + + let ret_opnd = asm.ccall(rb_str_substr_two_fixnums as *const u8, vec![recv, beg, len, empty]); + asm.stack_pop(beg_idx as usize + 2); + + let out_opnd = asm.stack_push(Type::Unknown); + asm.mov(out_opnd, ret_opnd); + + true +} + fn jit_rb_str_getbyte( jit: &mut JITState, asm: &mut Assembler, @@ -10469,6 +10545,8 @@ pub fn yjit_reg_method_codegen_fns() { reg_method_codegen(rb_cString, "getbyte", jit_rb_str_getbyte); reg_method_codegen(rb_cString, "setbyte", jit_rb_str_setbyte); reg_method_codegen(rb_cString, "byteslice", jit_rb_str_byteslice); + reg_method_codegen(rb_cString, "[]", jit_rb_str_aref_m); + reg_method_codegen(rb_cString, "slice", jit_rb_str_aref_m); reg_method_codegen(rb_cString, "<<", jit_rb_str_concat); reg_method_codegen(rb_cString, "+@", jit_rb_str_uplus); diff --git a/yjit/src/cruby_bindings.inc.rs b/yjit/src/cruby_bindings.inc.rs index 4635fd9c33..90d585ac37 100644 --- 
a/yjit/src/cruby_bindings.inc.rs +++ b/yjit/src/cruby_bindings.inc.rs @@ -1091,6 +1091,12 @@ extern "C" { pub fn rb_ensure_iv_list_size(obj: VALUE, len: u32, newsize: u32); pub fn rb_vm_barrier(); pub fn rb_str_byte_substr(str_: VALUE, beg: VALUE, len: VALUE) -> VALUE; + pub fn rb_str_substr_two_fixnums( + str_: VALUE, + beg: VALUE, + len: VALUE, + empty: ::std::os::raw::c_int, + ) -> VALUE; pub fn rb_obj_as_string_result(str_: VALUE, obj: VALUE) -> VALUE; pub fn rb_str_concat_literals(num: usize, strary: *const VALUE) -> VALUE; pub fn rb_ec_str_resurrect( diff --git a/yjit/src/stats.rs b/yjit/src/stats.rs index 3dc37d4bac..ee38bf7fb9 100644 --- a/yjit/src/stats.rs +++ b/yjit/src/stats.rs @@ -462,6 +462,7 @@ make_counters! { guard_send_not_fixnum_or_flonum, guard_send_not_string, guard_send_respond_to_mid_mismatch, + guard_send_str_aref_not_fixnum, guard_send_cfunc_bad_splat_vargs,