diff --git a/bootstraptest/test_yjit.rb b/bootstraptest/test_yjit.rb index a52ed8027d..d124d180d1 100644 --- a/bootstraptest/test_yjit.rb +++ b/bootstraptest/test_yjit.rb @@ -2462,6 +2462,57 @@ assert_equal '[[1, 2, 3, 4]]', %q{ 5.times.map { foo(specified: 2, required: 1) }.uniq } +# cfunc kwargs +assert_equal '{:foo=>123}', %q{ + def foo(bar) + bar.store(:value, foo: 123) + bar[:value] + end + + foo({}) + foo({}) +} + +# cfunc kwargs +assert_equal '{:foo=>123}', %q{ + def foo(bar) + bar.replace(foo: 123) + end + + foo({}) + foo({}) +} + +# cfunc kwargs +assert_equal '{:foo=>123, :bar=>456}', %q{ + def foo(bar) + bar.replace(foo: 123, bar: 456) + end + + foo({}) + foo({}) +} + +# variadic cfunc kwargs +assert_equal '{:foo=>123}', %q{ + def foo(bar) + bar.merge(foo: 123) + end + + foo({}) + foo({}) +} + +# optimized cfunc kwargs +assert_equal 'false', %q{ + def foo + :foo.eql?(foo: :foo) + end + + foo + foo +} + # attr_reader on frozen object assert_equal 'false', %q{ class Foo diff --git a/test/ruby/test_yjit.rb b/test/ruby/test_yjit.rb index 41a6d50779..88f8e42813 100644 --- a/test/ruby/test_yjit.rb +++ b/test/ruby/test_yjit.rb @@ -523,6 +523,13 @@ class TestYJIT < Test::Unit::TestCase RUBY end + def test_cfunc_kwarg + assert_no_exits('{}.store(:value, foo: 123)') + assert_no_exits('{}.store(:value, foo: 123, bar: 456, baz: 789)') + assert_no_exits('{}.merge(foo: 123)') + assert_no_exits('{}.merge(foo: 123, bar: 456, baz: 789)') + end + def test_ctx_different_mappings # regression test simplified from URI::Generic#hostname= assert_compiles(<<~'RUBY', frozen_string_literal: true) diff --git a/yjit_codegen.c b/yjit_codegen.c index c8de630747..52ac8aaa32 100644 --- a/yjit_codegen.c +++ b/yjit_codegen.c @@ -3239,19 +3239,40 @@ c_method_tracing_currently_enabled(const jitstate_t *jit) return tracing_events & (RUBY_EVENT_C_CALL | RUBY_EVENT_C_RETURN); } +// Called at runtime to build hashes of passed kwargs +static VALUE +yjit_runtime_build_kwhash(const struct rb_callinfo *ci, const VALUE *sp) { + // similar to args_kw_argv_to_hash + const VALUE *const passed_keywords = vm_ci_kwarg(ci)->keywords; + const int kw_len = vm_ci_kwarg(ci)->keyword_len; + const VALUE h = rb_hash_new_with_size(kw_len); + + for (int i = 0; i < kw_len; i++) { + rb_hash_aset(h, passed_keywords[i], (sp - kw_len)[i]); + } + return h; +} + static codegen_status_t gen_send_cfunc(jitstate_t *jit, ctx_t *ctx, const struct rb_callinfo *ci, const rb_callable_method_entry_t *cme, rb_iseq_t *block, const int32_t argc, VALUE *recv_known_klass) { const rb_method_cfunc_t *cfunc = UNALIGNED_MEMBER_PTR(cme->def, body.cfunc); + const struct rb_callinfo_kwarg *kw_arg = vm_ci_kwarg(ci); + const int kw_arg_num = kw_arg ? kw_arg->keyword_len : 0; + + // Number of args which will be passed through to the callee + // This is adjusted by the kwargs being combined into a hash. + const int passed_argc = kw_arg ? argc - kw_arg_num + 1 : argc; + // If the argument count doesn't match - if (cfunc->argc >= 0 && cfunc->argc != argc) { + if (cfunc->argc >= 0 && cfunc->argc != passed_argc) { GEN_COUNTER_INC(cb, send_cfunc_argc_mismatch); return YJIT_CANT_COMPILE; } // Don't JIT functions that need C stack arguments for now - if (cfunc->argc >= 0 && argc + 1 > NUM_C_ARG_REGS) { + if (cfunc->argc >= 0 && passed_argc + 1 > NUM_C_ARG_REGS) { GEN_COUNTER_INC(cb, send_cfunc_toomany_args); return YJIT_CANT_COMPILE; } @@ -3265,7 +3286,7 @@ gen_send_cfunc(jitstate_t *jit, ctx_t *ctx, const struct rb_callinfo *ci, const // Delegate to codegen for C methods if we have it. { method_codegen_t known_cfunc_codegen; - if ((known_cfunc_codegen = lookup_cfunc_codegen(cme->def))) { + if (!kw_arg && (known_cfunc_codegen = lookup_cfunc_codegen(cme->def))) { if (known_cfunc_codegen(jit, ctx, ci, cme, block, argc, recv_known_klass)) { // cfunc codegen generated code. Terminate the block so // there isn't multiple calls in the same block. @@ -3337,6 +3358,9 @@ gen_send_cfunc(jitstate_t *jit, ctx_t *ctx, const struct rb_callinfo *ci, const // Write env flags at sp[-1] // sp[-1] = frame_type; uint64_t frame_type = VM_FRAME_MAGIC_CFUNC | VM_FRAME_FLAG_CFRAME | VM_ENV_FLAG_LOCAL; + if (kw_arg) { + frame_type |= VM_FRAME_FLAG_CFRAME_KW; + } mov(cb, mem_opnd(64, REG0, 8 * -1), imm_opnd(frame_type)); // Allocate a new CFP (ec->cfp--) @@ -3377,11 +3401,22 @@ gen_send_cfunc(jitstate_t *jit, ctx_t *ctx, const struct rb_callinfo *ci, const call_ptr(cb, REG0, (void *)&check_cfunc_dispatch); } + if (kw_arg) { + // Build a hash from all kwargs passed + jit_mov_gc_ptr(jit, cb, C_ARG_REGS[0], (VALUE)ci); + lea(cb, C_ARG_REGS[1], ctx_sp_opnd(ctx, 0)); + call_ptr(cb, REG0, (void *)&yjit_runtime_build_kwhash); + + // Replace the stack location at the start of kwargs with the new hash + x86opnd_t stack_opnd = ctx_stack_opnd(ctx, argc - passed_argc); + mov(cb, stack_opnd, RAX); + } + // Non-variadic method if (cfunc->argc >= 0) { // Copy the arguments from the stack to the C argument registers // self is the 0th argument and is at index argc from the stack top - for (int32_t i = 0; i < argc + 1; ++i) + for (int32_t i = 0; i < passed_argc + 1; ++i) { x86opnd_t stack_opnd = ctx_stack_opnd(ctx, argc - i); x86opnd_t c_arg_reg = C_ARG_REGS[i]; @@ -3392,7 +3427,7 @@ gen_send_cfunc(jitstate_t *jit, ctx_t *ctx, const struct rb_callinfo *ci, const if (cfunc->argc == -1) { // The method gets a pointer to the first argument // rb_f_puts(int argc, VALUE *argv, VALUE recv) - mov(cb, C_ARG_REGS[0], imm_opnd(argc)); + mov(cb, C_ARG_REGS[0], imm_opnd(passed_argc)); lea(cb, C_ARG_REGS[1], ctx_stack_opnd(ctx, argc - 1)); mov(cb, C_ARG_REGS[2], ctx_stack_opnd(ctx, argc)); } @@ -3410,7 +3445,7 @@ gen_send_cfunc(jitstate_t *jit, ctx_t *ctx, const struct rb_callinfo *ci, const // rb_ec_ary_new_from_values(rb_execution_context_t *ec, long n, const VLAUE *elts) mov(cb, C_ARG_REGS[0], REG_EC); - mov(cb, C_ARG_REGS[1], imm_opnd(argc)); + mov(cb, C_ARG_REGS[1], imm_opnd(passed_argc)); lea(cb, C_ARG_REGS[2], ctx_stack_opnd(ctx, argc - 1)); call_ptr(cb, REG0, (void *)rb_ec_ary_new_from_values); @@ -4122,10 +4157,6 @@ gen_send_general(jitstate_t *jit, ctx_t *ctx, struct rb_call_data *cd, rb_iseq_t case VM_METHOD_TYPE_ISEQ: return gen_send_iseq(jit, ctx, ci, cme, block, argc); case VM_METHOD_TYPE_CFUNC: - if ((vm_ci_flag(ci) & VM_CALL_KWARG) != 0) { - GEN_COUNTER_INC(cb, send_cfunc_kwargs); - return YJIT_CANT_COMPILE; - } return gen_send_cfunc(jit, ctx, ci, cme, block, argc, &comptime_recv_klass); case VM_METHOD_TYPE_IVAR: if (argc != 0) {