diff --git a/lib/ruby_vm/rjit/assembler.rb b/lib/ruby_vm/rjit/assembler.rb
index 6bd654fd3e..bd8f8ad1d6 100644
--- a/lib/ruby_vm/rjit/assembler.rb
+++ b/lib/ruby_vm/rjit/assembler.rb
@@ -838,6 +838,22 @@ module RubyVM::RJIT
       end
     end
 
+    # Emit a 64-bit XOR of two general-purpose registers; the result is
+    # written back to dst. Only the register/register form is matched here.
+    def xor(dst, src)
+      case [dst, src]
+      # XOR r/m64, r64 (Mod 11: reg)
+      in [R64 => dst_reg, R64 => src_reg]
+        # REX.W + 31 /r
+        # MR: Operand 1: ModRM:r/m (r, w), Operand 2: ModRM:reg (r)
+        insn(
+          prefix: REX_W,
+          opcode: 0x31,
+          mod_rm: ModRM[mod: Mod11, reg: src_reg, rm: dst_reg],
+        )
+      end
+    end
+
     #
     # Utilities
     #
diff --git a/lib/ruby_vm/rjit/insn_compiler.rb b/lib/ruby_vm/rjit/insn_compiler.rb
index f860eb1c71..336dd9cea9 100644
--- a/lib/ruby_vm/rjit/insn_compiler.rb
+++ b/lib/ruby_vm/rjit/insn_compiler.rb
@@ -2783,6 +2783,75 @@ module RubyVM::RJIT
       true
     end
 
+    # @param jit [RubyVM::RJIT::JITState]
+    # @param ctx [RubyVM::RJIT::Context]
+    # @param asm [RubyVM::RJIT::Assembler]
+    def jit_rb_str_concat(jit, ctx, asm, argc, known_recv_class)
+      # The << operator can accept integer codepoints for characters
+      # as the argument. We only specially optimise string arguments.
+      # If the peeked-at compile time argument is something other than
+      # a string, assume it won't be a string later either.
+      comptime_arg = jit.peek_at_stack(0)
+      unless C.RB_TYPE_P(comptime_arg, C::RUBY_T_STRING)
+        return false
+      end
+
+      # Generate a side exit
+      side_exit = side_exit(jit, ctx)
+
+      # Guard that the concat argument is a string
+      asm.mov(:rax, ctx.stack_opnd(0))
+      guard_object_is_string(asm, :rax, :rcx, side_exit)
+
+      # Guard buffers from GC since rb_str_buf_append may allocate.
+      jit_save_sp(ctx, asm)
+
+      concat_arg = ctx.stack_pop(1)
+      recv = ctx.stack_pop(1)
+
+      # Test if string encodings differ. If different, use rb_str_buf_append. If the same,
+      # use rjit_str_simple_append, which calls rb_str_cat.
+      asm.comment('<< on strings')
+
+      # Take receiver's object flags XOR arg's flags. If any
+      # string-encoding flags are different between the two,
+      # the encodings don't match.
+      recv_reg = :rax
+      asm.mov(recv_reg, recv)
+      concat_arg_reg = :rcx
+      asm.mov(concat_arg_reg, concat_arg)
+      asm.mov(recv_reg, [recv_reg, C.RBasic.offsetof(:flags)])
+      asm.mov(concat_arg_reg, [concat_arg_reg, C.RBasic.offsetof(:flags)])
+      asm.xor(recv_reg, concat_arg_reg)
+      asm.test(recv_reg, C::RUBY_ENCODING_MASK)
+
+      # Push once, use the resulting operand in both branches below.
+      stack_ret = ctx.stack_push
+
+      enc_mismatch = asm.new_label('enc_mismatch')
+      asm.jnz(enc_mismatch)
+
+      # If encodings match, call the simple append function and jump to return
+      asm.mov(C_ARGS[0], recv)
+      asm.mov(C_ARGS[1], concat_arg)
+      asm.call(C.rjit_str_simple_append)
+      ret_label = asm.new_label('func_return')
+      asm.mov(stack_ret, C_RET)
+      asm.jmp(ret_label)
+
+      # If encodings are different, use a slower encoding-aware concatenate
+      asm.write_label(enc_mismatch)
+      asm.mov(C_ARGS[0], recv)
+      asm.mov(C_ARGS[1], concat_arg)
+      asm.call(C.rb_str_buf_append)
+      asm.mov(stack_ret, C_RET)
+      # Drop through to return
+
+      asm.write_label(ret_label)
+
+      true
+    end
+
     # @param jit [RubyVM::RJIT::JITState]
     # @param ctx [RubyVM::RJIT::Context]
     # @param asm [RubyVM::RJIT::Assembler]
@@ -2870,7 +2939,7 @@ module RubyVM::RJIT
       register_cfunc_method(String, :to_s, :jit_rb_str_to_s)
       register_cfunc_method(String, :to_str, :jit_rb_str_to_s)
       register_cfunc_method(String, :bytesize, :jit_rb_str_bytesize)
-      #register_cfunc_method(String, :<<, :jit_rb_str_concat)
+      register_cfunc_method(String, :<<, :jit_rb_str_concat)
       #register_cfunc_method(String, :+@, :jit_rb_str_uplus)
 
       # rb_ary_empty_p() method in array.c
@@ -2994,6 +3063,32 @@ module RubyVM::RJIT
       asm.jne(side_exit)
     end
 
+    # Emit a guard that jumps to side_exit unless the value in object_reg
+    # is a heap object whose type is T_STRING.
+    #
+    # @param asm [RubyVM::RJIT::Assembler]
+    # @param object_reg [Symbol] register holding the VALUE to check
+    # @param flags_reg [Symbol] scratch register, overwritten with the masked flags
+    # @param side_exit jump target taken when the guard fails
+    def guard_object_is_string(asm, object_reg, flags_reg, side_exit)
+      asm.comment('guard object is string')
+      # Special constants (Fixnum, Symbol, nil, true, ...) are not pointers, so
+      # RBasic#flags must not be read from them; bail out before the load below.
+      asm.test(object_reg, C::RUBY_IMMEDIATE_MASK)
+      asm.jnz(side_exit)
+      # Qfalse has no immediate bits set, so it needs its own check.
+      asm.cmp(object_reg, Qfalse)
+      asm.je(side_exit)
+
+      # Pull out the type mask
+      asm.mov(flags_reg, [object_reg, C.RBasic.offsetof(:flags)])
+      asm.and(flags_reg, C::RUBY_T_MASK)
+
+      # Compare the result with T_STRING
+      asm.cmp(flags_reg, C::RUBY_T_STRING)
+      asm.jne(side_exit)
+    end
+
     # @param jit [RubyVM::RJIT::JITState]
     # @param ctx [RubyVM::RJIT::Context]
     # @param asm [RubyVM::RJIT::Assembler]
diff --git a/rjit_c.c b/rjit_c.c
index dd6067f334..bf0bf6f410 100644
--- a/rjit_c.c
+++ b/rjit_c.c
@@ -170,6 +170,12 @@ rjit_str_neq_internal(VALUE str1, VALUE str2)
     return rb_str_eql_internal(str1, str2) == Qtrue ? Qfalse : Qtrue;
 }
 
+static VALUE
+rjit_str_simple_append(VALUE str1, VALUE str2)
+{
+    return rb_str_cat(str1, RSTRING_PTR(str2), RSTRING_LEN(str2));
+}
+
 // The code we generate in gen_send_cfunc() doesn't fire the c_return TracePoint event
 // like the interpreter. When tracing for c_return is enabled, we patch the code after
 // the C method return to call into this to fire the event.
diff --git a/rjit_c.rb b/rjit_c.rb
index fe6cd38678..fcc510adcb 100644
--- a/rjit_c.rb
+++ b/rjit_c.rb
@@ -364,6 +364,7 @@ module RubyVM::RJIT # :nodoc: all
   C::RMODULE_IS_REFINEMENT = Primitive.cexpr! %q{ SIZET2NUM(RMODULE_IS_REFINEMENT) }
   C::ROBJECT_EMBED = Primitive.cexpr! %q{ SIZET2NUM(ROBJECT_EMBED) }
   C::RSTRUCT_EMBED_LEN_MASK = Primitive.cexpr! %q{ SIZET2NUM(RSTRUCT_EMBED_LEN_MASK) }
+  C::RUBY_ENCODING_MASK = Primitive.cexpr! %q{ SIZET2NUM(RUBY_ENCODING_MASK) }
   C::RUBY_EVENT_CLASS = Primitive.cexpr! %q{ SIZET2NUM(RUBY_EVENT_CLASS) }
   C::RUBY_EVENT_C_CALL = Primitive.cexpr! %q{ SIZET2NUM(RUBY_EVENT_C_CALL) }
   C::RUBY_EVENT_C_RETURN = Primitive.cexpr! %q{ SIZET2NUM(RUBY_EVENT_C_RETURN) }
@@ -603,6 +604,10 @@ module RubyVM::RJIT # :nodoc: all
     Primitive.cexpr! %q{ SIZET2NUM((size_t)rb_reg_nth_match) }
   end
 
+  def C.rb_str_buf_append
+    Primitive.cexpr! %q{ SIZET2NUM((size_t)rb_str_buf_append) }
+  end
+
   def C.rb_str_bytesize
     Primitive.cexpr! %q{ SIZET2NUM((size_t)rb_str_bytesize) }
   end
@@ -683,6 +688,10 @@ module RubyVM::RJIT # :nodoc: all
     Primitive.cexpr! %q{ SIZET2NUM((size_t)rjit_str_neq_internal) }
   end
 
+  def C.rjit_str_simple_append
+    Primitive.cexpr! %q{ SIZET2NUM((size_t)rjit_str_simple_append) }
+  end
+
   def C.CALL_DATA
     @CALL_DATA ||= self.rb_call_data
   end
diff --git a/test/ruby/rjit/test_assembler.rb b/test/ruby/rjit/test_assembler.rb
index 45805115a5..639cf170e5 100644
--- a/test/ruby/rjit/test_assembler.rb
+++ b/test/ruby/rjit/test_assembler.rb
@@ -321,6 +321,14 @@ module RubyVM::RJIT
       EOS
     end
 
+    def test_xor
+      asm = Assembler.new
+      asm.xor(:rax, :rbx)
+      assert_compile(asm, <<~EOS)
+        0x0: xor rax, rbx
+      EOS
+    end
+
     private
 
     def rel32(offset)
diff --git a/tool/rjit/bindgen.rb b/tool/rjit/bindgen.rb
index 6764f98594..4cca493bcf 100755
--- a/tool/rjit/bindgen.rb
+++ b/tool/rjit/bindgen.rb
@@ -475,6 +475,7 @@ generator = BindingGenerator.new(
       VM_METHOD_TYPE_UNDEF
       VM_METHOD_TYPE_ZSUPER
       VM_SPECIAL_OBJECT_VMCORE
+      RUBY_ENCODING_MASK
     ],
   },
   values: {
@@ -546,6 +547,8 @@ generator = BindingGenerator.new(
     rb_str_intern
     rb_vm_setclassvariable
     rb_str_bytesize
+    rjit_str_simple_append
+    rb_str_buf_append
   ],
   types: %w[
     CALL_DATA