From 6c1b1fa1f52f7c79b98a6b86f4f1f03f069dd36b Mon Sep 17 00:00:00 2001 From: Takashi Kokubun Date: Tue, 7 Feb 2023 00:17:13 -0800 Subject: [PATCH] Refactor BranchStub --- lib/ruby_vm/mjit/branch_stub.rb | 36 +++++++++++------- lib/ruby_vm/mjit/c_pointer.rb | 3 ++ lib/ruby_vm/mjit/compiler.rb | 63 +++++++++++-------------------- lib/ruby_vm/mjit/exit_compiler.rb | 8 ++-- lib/ruby_vm/mjit/insn_compiler.rb | 44 ++++++++++----------- mjit.c | 6 +-- mjit_c.rb | 4 +- tool/mjit/bindgen.rb | 5 ++- vm_core.h | 4 +- 9 files changed, 83 insertions(+), 90 deletions(-) diff --git a/lib/ruby_vm/mjit/branch_stub.rb b/lib/ruby_vm/mjit/branch_stub.rb index 27ea5b9515..0f015e2f72 100644 --- a/lib/ruby_vm/mjit/branch_stub.rb +++ b/lib/ruby_vm/mjit/branch_stub.rb @@ -1,14 +1,24 @@ -class RubyVM::MJIT::BranchStub < Struct.new( - :iseq, # @param [RubyVM::MJIT::CPointer::Struct_rb_iseq_struct] Branch target ISEQ - :ctx, # @param [RubyVM::MJIT::Context] Branch target context - :branch_target_pc, # @param [Integer] Branch target PC - :branch_target_addr, # @param [Integer] Branch target address - :branch_target_next, # @param [Proc] Compile branch target next - :fallthrough_pc, # @param [Integer] Fallthrough PC - :fallthrough_addr, # @param [Integer] Fallthrough address - :fallthrough_next, # @param [Proc] Compile fallthrough next - :neither_next, # @param [Proc] Compile neither branch target nor fallthrough next - :start_addr, # @param [Integer] Stub source start address to be re-generated - :end_addr, # @param [Integer] Stub source end address to be re-generated -) +module RubyVM::MJIT + # Branch shapes + Next0 = :Next0 # target0 is a fallthrough + Next1 = :Next1 # target1 is a fallthrough + Default = :Default # neither targets is a fallthrough + + class BranchStub < Struct.new( + :iseq, # @param [RubyVM::MJIT::CPointer::Struct_rb_iseq_struct] Branch target ISEQ + :shape, # @param [Symbol] Next0, Next1, or Default + :target0, # @param [RubyVM::MJIT::BranchTarget] First branch target + :target1, # @param [RubyVM::MJIT::BranchTarget,NilClass] Second branch target (optional) + :compile, # @param [Proc] A callback to (re-)generate this branch stub + :start_addr, # @param [Integer] Stub source start address to be re-generated + :end_addr, # @param [Integer] Stub source end address to be re-generated + ) + end + + class BranchTarget < Struct.new( + :pc, + :ctx, + :address, + ) + end end diff --git a/lib/ruby_vm/mjit/c_pointer.rb b/lib/ruby_vm/mjit/c_pointer.rb index 6bdf92b6cf..03742dd53a 100644 --- a/lib/ruby_vm/mjit/c_pointer.rb +++ b/lib/ruby_vm/mjit/c_pointer.rb @@ -81,6 +81,9 @@ module RubyVM::MJIT end define_method("#{member}=") do |value| + if to_ruby + value = C.to_value(value) + end self[member] = value end end diff --git a/lib/ruby_vm/mjit/compiler.rb b/lib/ruby_vm/mjit/compiler.rb index 18f2d94016..0ad289c063 100644 --- a/lib/ruby_vm/mjit/compiler.rb +++ b/lib/ruby_vm/mjit/compiler.rb @@ -63,10 +63,7 @@ module RubyVM::MJIT asm.comment("Block: #{iseq.body.location.label}@#{C.rb_iseq_path(iseq)}:#{iseq.body.location.first_lineno}") compile_prologue(asm) compile_block(asm, jit:) - @cb.write(asm).tap do |addr| - jit.block.start_addr = addr - iseq.body.jit_func = addr - end + iseq.body.jit_func = @cb.write(asm) rescue Exception => e $stderr.puts e.full_message # TODO: check verbose end @@ -99,76 +96,62 @@ module RubyVM::MJIT @cb.write(asm) end new_addr - end.tap do |addr| - jit.block.start_addr = addr end end # Compile a branch stub. # @param branch_stub [RubyVM::MJIT::BranchStub] # @param cfp `RubyVM::MJIT::CPointer::Struct_rb_control_frame_t` - # @param branch_target_p [TrueClass,FalseClass] + # @param target0_p [TrueClass,FalseClass] # @return [Integer] The starting address of the compiled branch stub - def branch_stub_hit(branch_stub, cfp, branch_target_p) + def branch_stub_hit(branch_stub, cfp, target0_p) # Update cfp->pc for `jit.at_current_insn?` - pc = branch_target_p ? branch_stub.branch_target_pc : branch_stub.fallthrough_pc - cfp.pc = pc + target = target0_p ? branch_stub.target0 : branch_stub.target1 + cfp.pc = target.pc # Prepare the jump target new_asm = Assembler.new.tap do |asm| jit = JITState.new(iseq: branch_stub.iseq, cfp:) - compile_block(asm, jit:, pc:, ctx: branch_stub.ctx.dup) + compile_block(asm, jit:, pc: target.pc, ctx: target.ctx.dup) end # Rewrite the branch stub if @cb.write_addr == branch_stub.end_addr - # If the branch stub's jump is the last code, overwrite the jump with the new code. + # If the branch stub's jump is the last code, allow overwriting part of + # the old branch code with the new block code. @cb.set_write_addr(branch_stub.start_addr) + branch_stub.shape = target0_p ? Next0 : Next1 Assembler.new.tap do |branch_asm| - if branch_target_p - branch_stub.branch_target_next.call(branch_asm) - else - branch_stub.fallthrough_next.call(branch_asm) - end + branch_stub.compile.call(branch_asm) @cb.write(branch_asm) end - # Compile a fallthrough over the jump - if branch_target_p - branch_stub.branch_target_addr = @cb.write(new_asm) + # Compile a fallthrough right after the new branch code + if target0_p + branch_stub.target0.address = @cb.write(new_asm) else - branch_stub.fallthrough_addr = @cb.write(new_asm) + branch_stub.target1.address = @cb.write(new_asm) end else - # Otherwise, just prepare the new code somewhere - if branch_target_p - unless @cb.include?(branch_stub.branch_target_addr) - branch_stub.branch_target_addr = @cb.write(new_asm) - end + # Otherwise, just prepare the new block somewhere + if target0_p + branch_stub.target0.address = @cb.write(new_asm) else - unless @cb.include?(branch_stub.fallthrough_addr) - branch_stub.fallthrough_addr = @cb.write(new_asm) - end + branch_stub.target1.address = @cb.write(new_asm) end # Update jump destinations - branch_asm = Assembler.new - if branch_stub.end_addr == branch_stub.branch_target_addr # branch_target_next has been used - branch_stub.branch_target_next.call(branch_asm) - elsif branch_stub.end_addr == branch_stub.fallthrough_addr # fallthrough_next has been used - branch_stub.fallthrough_next.call(branch_asm) - else - branch_stub.neither_next.call(branch_asm) - end @cb.with_write_addr(branch_stub.start_addr) do + branch_asm = Assembler.new + branch_stub.compile.call(branch_asm) @cb.write(branch_asm) end end - if branch_target_p - branch_stub.branch_target_addr + if target0_p + branch_stub.target0.address else - branch_stub.fallthrough_addr + branch_stub.target1.address end end diff --git a/lib/ruby_vm/mjit/exit_compiler.rb b/lib/ruby_vm/mjit/exit_compiler.rb index f21ccced85..32ad59404f 100644 --- a/lib/ruby_vm/mjit/exit_compiler.rb +++ b/lib/ruby_vm/mjit/exit_compiler.rb @@ -75,13 +75,13 @@ module RubyVM::MJIT # @param ctx [RubyVM::MJIT::Context] # @param asm [RubyVM::MJIT::Assembler] # @param branch_stub [RubyVM::MJIT::BranchStub] - # @param branch_target_p [TrueClass,FalseClass] - def compile_branch_stub(jit, ctx, asm, branch_stub, branch_target_p) + # @param target0_p [TrueClass,FalseClass] + def compile_branch_stub(jit, ctx, asm, branch_stub, target0_p) # Call rb_mjit_branch_stub_hit - asm.comment("branch stub hit: #{branch_stub.iseq.body.location.label}@#{C.rb_iseq_path(branch_stub.iseq)}:#{iseq_lineno(branch_stub.iseq, branch_target_p ? branch_stub.branch_target_pc : branch_stub.fallthrough_pc)}") + asm.comment("branch stub hit: #{branch_stub.iseq.body.location.label}@#{C.rb_iseq_path(branch_stub.iseq)}:#{iseq_lineno(branch_stub.iseq, target0_p ? branch_stub.target0.pc : branch_stub.target1.pc)}") asm.mov(:rdi, to_value(branch_stub)) asm.mov(:esi, ctx.sp_offset) - asm.mov(:edx, branch_target_p ? 1 : 0) + asm.mov(:edx, target0_p ? 1 : 0) asm.call(C.rb_mjit_branch_stub_hit) # Jump to the address returned by rb_mjit_stub_hit diff --git a/lib/ruby_vm/mjit/insn_compiler.rb b/lib/ruby_vm/mjit/insn_compiler.rb index fb8b4af4ab..55d2e072f0 100644 --- a/lib/ruby_vm/mjit/insn_compiler.rb +++ b/lib/ruby_vm/mjit/insn_compiler.rb @@ -279,45 +279,38 @@ module RubyVM::MJIT ctx.stack_pop(1) # Set stubs - # TODO: reuse already-compiled blocks jumped from different blocks branch_stub = BranchStub.new( iseq: jit.iseq, - ctx: ctx.dup, - branch_target_pc: jit.pc + (jit.insn.len + jit.operand(0)) * C.VALUE.size, - fallthrough_pc: jit.pc + jit.insn.len * C.VALUE.size, + shape: Default, + target0: BranchTarget.new(ctx:, pc: jit.pc + C.VALUE.size * (jit.insn.len + jit.operand(0))), # branch target + target1: BranchTarget.new(ctx:, pc: jit.pc + C.VALUE.size * jit.insn.len), # fallthrough ) - branch_stub.branch_target_addr = Assembler.new.then do |ocb_asm| + branch_stub.target0.address = Assembler.new.then do |ocb_asm| @exit_compiler.compile_branch_stub(jit, ctx, ocb_asm, branch_stub, true) @ocb.write(ocb_asm) end - branch_stub.fallthrough_addr = Assembler.new.then do |ocb_asm| + branch_stub.target1.address = Assembler.new.then do |ocb_asm| @exit_compiler.compile_branch_stub(jit, ctx, ocb_asm, branch_stub, false) @ocb.write(ocb_asm) end - # Prepare codegen for all cases - branch_stub.branch_target_next = proc do |branch_asm| + # Jump to target0 on jz + branch_stub.compile = proc do |branch_asm| + branch_asm.comment("branchunless #{branch_stub.shape}") branch_asm.stub(branch_stub) do - branch_asm.comment('branch_target_next') - branch_asm.jnz(branch_stub.fallthrough_addr) - end - end - branch_stub.fallthrough_next = proc do |branch_asm| - branch_asm.stub(branch_stub) do - branch_asm.comment('fallthrough_next') - branch_asm.jz(branch_stub.branch_target_addr) - end - end - branch_stub.neither_next = proc do |branch_asm| - branch_asm.stub(branch_stub) do - branch_asm.comment('neither_next') - branch_asm.jz(branch_stub.branch_target_addr) - branch_asm.jmp(branch_stub.fallthrough_addr) + case branch_stub.shape + in Default + branch_asm.jz(branch_stub.target0.address) + branch_asm.jmp(branch_stub.target1.address) + in Next0 + branch_asm.jnz(branch_stub.target1.address) + in Next1 + branch_asm.jz(branch_stub.target0.address) + end end end + branch_stub.compile.call(asm) - # Just jump to stubs - branch_stub.neither_next.call(asm) EndBlock end @@ -598,6 +591,7 @@ module RubyVM::MJIT asm.incr_counter(:send_protected) return CantCompile # TODO: support this else + # TODO: Change them to a constant and use case-in instead raise 'unreachable' end diff --git a/mjit.c b/mjit.c index 26a05fc6ea..8cfaffca47 100644 --- a/mjit.c +++ b/mjit.c @@ -281,7 +281,7 @@ mjit_child_after_fork(void) void mjit_mark_cc_entries(const struct rb_iseq_constant_body *const body) { - // TODO: implement + rb_gc_mark(body->mjit_blocks); } // Compile ISeq to C code in `f`. It returns true if it succeeds to compile. @@ -401,7 +401,7 @@ rb_mjit_block_stub_hit(VALUE block_stub, int sp_offset) } void * -rb_mjit_branch_stub_hit(VALUE branch_stub, int sp_offset, int branch_target_p) +rb_mjit_branch_stub_hit(VALUE branch_stub, int sp_offset, int target0_p) { VALUE result; @@ -415,7 +415,7 @@ rb_mjit_branch_stub_hit(VALUE branch_stub, int sp_offset, int branch_target_p) cfp->sp += sp_offset; // preserve stack values, also using the actual sp_offset to make jit.peek_at_stack work VALUE cfp_ptr = rb_funcall(rb_cMJITCfpPtr, rb_intern("new"), 1, SIZET2NUM((size_t)cfp)); - result = rb_funcall(rb_MJITCompiler, rb_intern("branch_stub_hit"), 3, branch_stub, cfp_ptr, RBOOL(branch_target_p)); + result = rb_funcall(rb_MJITCompiler, rb_intern("branch_stub_hit"), 3, branch_stub, cfp_ptr, RBOOL(target0_p)); cfp->sp -= sp_offset; // reset for consistency with the code without the stub diff --git a/mjit_c.rb b/mjit_c.rb index 871e5b461c..45e8c260ae 100644 --- a/mjit_c.rb +++ b/mjit_c.rb @@ -45,7 +45,7 @@ module RubyVM::MJIT # :nodoc: all def rb_mjit_branch_stub_hit Primitive.cstmt! %{ - extern void *rb_mjit_branch_stub_hit(VALUE branch_stub, int sp_offset, int branch_target_p); + extern void *rb_mjit_branch_stub_hit(VALUE branch_stub, int sp_offset, int target0_p); return SIZET2NUM((size_t)rb_mjit_branch_stub_hit); } end @@ -703,7 +703,7 @@ module RubyVM::MJIT # :nodoc: all mandatory_only_iseq: [CType::Pointer.new { self.rb_iseq_t }, Primitive.cexpr!("OFFSETOF((*((struct rb_iseq_constant_body *)NULL)), mandatory_only_iseq)")], jit_func: [CType::Immediate.parse("void *"), Primitive.cexpr!("OFFSETOF((*((struct rb_iseq_constant_body *)NULL)), jit_func)")], total_calls: [CType::Immediate.parse("unsigned long"), Primitive.cexpr!("OFFSETOF((*((struct rb_iseq_constant_body *)NULL)), total_calls)")], - mjit_unit: [CType::Pointer.new { self.rb_mjit_unit }, Primitive.cexpr!("OFFSETOF((*((struct rb_iseq_constant_body *)NULL)), mjit_unit)")], + mjit_blocks: [self.VALUE, Primitive.cexpr!("OFFSETOF((*((struct rb_iseq_constant_body *)NULL)), mjit_blocks)"), true], ) end diff --git a/tool/mjit/bindgen.rb b/tool/mjit/bindgen.rb index 1dbb742c59..891cb7d297 100755 --- a/tool/mjit/bindgen.rb +++ b/tool/mjit/bindgen.rb @@ -437,11 +437,14 @@ generator = BindingGenerator.new( rb_iseq_constant_body: %w[yjit_payload], # conditionally defined }, ruby_fields: { + rb_iseq_constant_body: %w[ + mjit_blocks + ], rb_iseq_location_struct: %w[ base_label label pathobj - ] + ], }, ) generator.generate(nodes) diff --git a/vm_core.h b/vm_core.h index 9d7c54b4d1..449f9b9a26 100644 --- a/vm_core.h +++ b/vm_core.h @@ -511,8 +511,8 @@ struct rb_iseq_constant_body { #endif #if USE_MJIT - // MJIT stores some data on each iseq. - struct rb_mjit_unit *mjit_unit; + // MJIT stores { Context => Block } for each iseq. + VALUE mjit_blocks; #endif #if USE_YJIT