From b1a28b05db87ff520807ccd55d34e69fcd6ea53c Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Thu, 28 Sep 2023 12:36:17 -0400 Subject: [PATCH] Support local variable targeting in pattern matching --- prism_compile.c | 46 +++++++++++++++++++++++---------- test/ruby/test_compile_prism.rb | 2 ++ 2 files changed, 35 insertions(+), 13 deletions(-) diff --git a/prism_compile.c b/prism_compile.c index 0aee358c8b..7e0786c1f5 100644 --- a/prism_compile.c +++ b/prism_compile.c @@ -614,8 +614,8 @@ pm_reg_flags(const pm_node_t *node) { /** * Compile a pattern matching expression. */ -static void -pm_compile_pattern(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret, const uint8_t *src, pm_compile_context_t *compile_context, LABEL *matched_label, LABEL *unmatched_label) +static int +pm_compile_pattern(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret, const uint8_t *src, pm_compile_context_t *compile_context, LABEL *matched_label, LABEL *unmatched_label, bool in_alternation_pattern) { int lineno = (int) pm_newline_list_line_column(&compile_context->parser->newline_list, node->location.start).line; NODE dummy_line_node = generate_dummy_line_node(lineno, lineno); @@ -630,15 +630,38 @@ pm_compile_pattern(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const re case PM_HASH_PATTERN_NODE: rb_bug("Hash pattern matching not yet supported."); break; - case PM_LOCAL_VARIABLE_TARGET_NODE: - rb_bug("Local variable target node matching not yet supported."); - break; case PM_IF_NODE: rb_bug("If guards on pattern matching not yet supported."); break; case PM_UNLESS_NODE: rb_bug("Unless guards on pattern matching not yet supported."); break; + case PM_CAPTURE_PATTERN_NODE: + rb_bug("Capture pattern matching not yet supported."); + break; + case PM_LOCAL_VARIABLE_TARGET_NODE: { + // Local variables can be targetted by placing identifiers in the place + // of a pattern. For example, foo in bar. This results in the value + // being matched being written to that local variable. + pm_local_variable_target_node_t *cast = (pm_local_variable_target_node_t *) node; + int index = pm_lookup_local_index(iseq, compile_context, cast->name); + + // If this local variable is being written from within an alternation + // pattern, then it cannot actually be added to the local table since + // it's ambiguous which value should be used. So instead we indicate + // this with a compile error. + if (in_alternation_pattern) { + const char *name = rb_id2name(pm_constant_id_lookup(compile_context, cast->name)); + if (name && strlen(name) > 0 && name[0] != '_') { + COMPILE_ERROR(ERROR_ARGS "illegal variable in alternative pattern (%"PRIsVALUE")", name); + return COMPILE_NG; + } + } + + ADD_SETLOCAL(ret, &dummy_line_node, index, (int) cast->depth); + ADD_INSNL(ret, &dummy_line_node, jump, matched_label); + break; + } case PM_ALTERNATION_PATTERN_NODE: { // Alternation patterns allow you to specify multiple patterns in a // single expression using the | operator. @@ -650,7 +673,7 @@ pm_compile_pattern(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const re // First, we're going to attempt to match against the left pattern. If // that pattern matches, then we'll skip matching the right pattern. ADD_INSN(ret, &dummy_line_node, dup); - pm_compile_pattern(iseq, cast->left, ret, src, compile_context, matched_left_label, unmatched_left_label); + pm_compile_pattern(iseq, cast->left, ret, src, compile_context, matched_left_label, unmatched_left_label, true); // If we get here, then we matched on the left pattern. In this case we // should pop out the duplicate value that we preemptively added to @@ -663,12 +686,9 @@ pm_compile_pattern(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const re // If we get here, then we didn't match on the left pattern. In this // case we attempt to match against the right pattern. ADD_LABEL(ret, unmatched_left_label); - pm_compile_pattern(iseq, cast->right, ret, src, compile_context, matched_label, unmatched_label); + pm_compile_pattern(iseq, cast->right, ret, src, compile_context, matched_label, unmatched_label, true); break; } - case PM_CAPTURE_PATTERN_NODE: - rb_bug("Capture pattern matching not yet supported."); - break; case PM_ARRAY_NODE: case PM_CLASS_VARIABLE_READ_NODE: case PM_CONSTANT_PATH_NODE: @@ -701,12 +721,12 @@ pm_compile_pattern(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const re break; case PM_PINNED_VARIABLE_NODE: { pm_pinned_variable_node_t *cast = (pm_pinned_variable_node_t *) node; - pm_compile_pattern(iseq, cast->variable, ret, src, compile_context, matched_label, unmatched_label); + pm_compile_pattern(iseq, cast->variable, ret, src, compile_context, matched_label, unmatched_label, false); break; } case PM_PINNED_EXPRESSION_NODE: { pm_pinned_expression_node_t *cast = (pm_pinned_expression_node_t *) node; - pm_compile_pattern(iseq, cast->expression, ret, src, compile_context, matched_label, unmatched_label); + pm_compile_pattern(iseq, cast->expression, ret, src, compile_context, matched_label, unmatched_label, false); break; } default: @@ -1739,7 +1759,7 @@ pm_compile_node(rb_iseq_t *iseq, const pm_node_t *node, LINK_ANCHOR *const ret, LABEL *matched_label = NEW_LABEL(lineno); LABEL *unmatched_label = NEW_LABEL(lineno); LABEL *done_label = NEW_LABEL(lineno); - pm_compile_pattern(iseq, cast->pattern, ret, src, compile_context, matched_label, unmatched_label); + pm_compile_pattern(iseq, cast->pattern, ret, src, compile_context, matched_label, unmatched_label, false); // If the pattern did not match, then compile the necessary instructions // to handle pushing false onto the stack, then jump to the end. diff --git a/test/ruby/test_compile_prism.rb b/test/ruby/test_compile_prism.rb index d3fdfd68f1..f65025b9df 100644 --- a/test/ruby/test_compile_prism.rb +++ b/test/ruby/test_compile_prism.rb @@ -373,6 +373,8 @@ module Prism test_prism_eval("\"foo1\" in /...\#{1}/") test_prism_eval("4 in ->(v) { v.even? }") + test_prism_eval("5 in foo") + test_prism_eval("1 in 2") end