ruby/lib/prism/pattern.rb

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

269 строки
8.1 KiB
Ruby
Исходник Обычный вид История

# frozen_string_literal: true
2023-09-27 19:24:48 +03:00
module Prism
# A pattern is an object that wraps a Ruby pattern matching expression. The
# expression would normally be passed to an `in` clause within a `case`
# expression or a rightward assignment expression. For example, in the
# following snippet:
#
# case node
2023-09-27 19:24:48 +03:00
# in ConstantPathNode[ConstantReadNode[name: :Prism], ConstantReadNode[name: :Pattern]]
# end
#
# the pattern is the <tt>ConstantPathNode[...]</tt> expression.
#
# The pattern gets compiled into an object that responds to #call by running
2023-09-27 19:24:48 +03:00
# the #compile method. This method itself will run back through Prism to
# parse the expression into a tree, then walk the tree to generate the
# necessary callable objects. For example, if you wanted to compile the
# expression above into a callable, you would:
#
2023-09-27 19:24:48 +03:00
# callable = Prism::Pattern.new("ConstantPathNode[ConstantReadNode[name: :Prism], ConstantReadNode[name: :Pattern]]").compile
# callable.call(node)
#
# The callable object returned by #compile is guaranteed to respond to #call
# with a single argument, which is the node to match against. It also is
# guaranteed to respond to #===, which means it itself can be used in a `case`
# expression, as in:
#
# case node
# when callable
# end
#
# If the query given to the initializer cannot be compiled into a valid
# matcher (either because of a syntax error or because it is using syntax we
2023-09-27 19:24:48 +03:00
# do not yet support) then a Prism::Pattern::CompilationError will be
# raised.
class Pattern
# Raised when the query given to a pattern is either invalid Ruby syntax or
# is using syntax that we don't yet support.
class CompilationError < StandardError
# Create a new CompilationError with the given representation of the node
# that caused the error.
def initialize(repr)
super(<<~ERROR)
2023-09-27 19:24:48 +03:00
prism was unable to compile the pattern you provided into a usable
expression. It failed on to understand the node represented by:
#{repr}
Note that not all syntax supported by Ruby's pattern matching syntax
2023-09-27 19:24:48 +03:00
is also supported by prism's patterns. If you're using some syntax
that you believe should be supported, please open an issue on
2023-09-27 19:24:48 +03:00
GitHub at https://github.com/ruby/prism/issues/new.
ERROR
end
end
# The query that this pattern was initialized with.
attr_reader :query
# Create a new pattern with the given query. The query should be a string
# containing a Ruby pattern matching expression.
def initialize(query)
@query = query
@compiled = nil
end
# Compile the query into a callable object that can be used to match against
# nodes.
def compile
2023-09-27 19:24:48 +03:00
result = Prism.parse("case nil\nin #{query}\nend")
case_match_node = result.value.statements.body.last
raise CompilationError, case_match_node.inspect unless case_match_node.is_a?(CaseMatchNode)
in_node = case_match_node.conditions.last
raise CompilationError, in_node.inspect unless in_node.is_a?(InNode)
compile_node(in_node.pattern)
end
# Scan the given node and all of its children for nodes that match the
# pattern. If a block is given, it will be called with each node that
# matches the pattern. If no block is given, an enumerator will be returned
# that will yield each node that matches the pattern.
def scan(root)
return to_enum(:scan, root) unless block_given?
@compiled ||= compile
queue = [root]
while (node = queue.shift)
yield node if @compiled.call(node) # steep:ignore
queue.concat(node.compact_child_nodes)
end
end
private
# Shortcut for combining two procs into one that returns true if both return
# true.
def combine_and(left, right)
->(other) { left.call(other) && right.call(other) }
end
# Shortcut for combining two procs into one that returns true if either
# returns true.
def combine_or(left, right)
->(other) { left.call(other) || right.call(other) }
end
# Raise an error because the given node is not supported.
def compile_error(node)
raise CompilationError, node.inspect
end
# in [foo, bar, baz]
def compile_array_pattern_node(node)
compile_error(node) if !node.rest.nil? || node.posts.any?
constant = node.constant
compiled_constant = compile_node(constant) if constant
preprocessed = node.requireds.map { |required| compile_node(required) }
compiled_requireds = ->(other) do
deconstructed = other.deconstruct
deconstructed.length == preprocessed.length &&
preprocessed
.zip(deconstructed)
.all? { |(matcher, value)| matcher.call(value) }
end
if compiled_constant
combine_and(compiled_constant, compiled_requireds)
else
compiled_requireds
end
end
# in foo | bar
def compile_alternation_pattern_node(node)
combine_or(compile_node(node.left), compile_node(node.right))
end
2023-09-27 19:24:48 +03:00
# in Prism::ConstantReadNode
def compile_constant_path_node(node)
parent = node.parent
2023-09-27 19:24:48 +03:00
if parent.is_a?(ConstantReadNode) && parent.slice == "Prism"
name = node.name
raise CompilationError, node.inspect if name.nil?
compile_constant_name(node, name)
else
compile_error(node)
end
end
# in ConstantReadNode
# in String
def compile_constant_read_node(node)
compile_constant_name(node, node.name)
end
# Compile a name associated with a constant.
def compile_constant_name(node, name)
if Prism.const_defined?(name, false)
clazz = Prism.const_get(name)
->(other) { clazz === other }
elsif Object.const_defined?(name, false)
clazz = Object.const_get(name)
->(other) { clazz === other }
else
compile_error(node)
end
end
# in InstanceVariableReadNode[name: Symbol]
# in { name: Symbol }
def compile_hash_pattern_node(node)
compile_error(node) if node.rest
compiled_constant = compile_node(node.constant) if node.constant
preprocessed =
node.elements.to_h do |element|
key = element.key
if key.is_a?(SymbolNode)
[key.unescaped.to_sym, compile_node(element.value)]
else
raise CompilationError, element.inspect
end
end
compiled_keywords = ->(other) do
deconstructed = other.deconstruct_keys(preprocessed.keys)
preprocessed.all? do |keyword, matcher|
deconstructed.key?(keyword) && matcher.call(deconstructed[keyword])
end
end
if compiled_constant
combine_and(compiled_constant, compiled_keywords)
else
compiled_keywords
end
end
# in nil
def compile_nil_node(node)
->(attribute) { attribute.nil? }
end
# in /foo/
def compile_regular_expression_node(node)
regexp = Regexp.new(node.unescaped, node.closing[1..])
->(attribute) { regexp === attribute }
end
# in ""
# in "foo"
def compile_string_node(node)
string = node.unescaped
->(attribute) { string === attribute }
end
# in :+
# in :foo
def compile_symbol_node(node)
symbol = node.unescaped.to_sym
->(attribute) { symbol === attribute }
end
# Compile any kind of node. Dispatch out to the individual compilation
# methods based on the type of node.
def compile_node(node)
case node
when AlternationPatternNode
compile_alternation_pattern_node(node)
when ArrayPatternNode
compile_array_pattern_node(node)
when ConstantPathNode
compile_constant_path_node(node)
when ConstantReadNode
compile_constant_read_node(node)
when HashPatternNode
compile_hash_pattern_node(node)
when NilNode
compile_nil_node(node)
when RegularExpressionNode
compile_regular_expression_node(node)
when StringNode
compile_string_node(node)
when SymbolNode
compile_symbol_node(node)
else
compile_error(node)
end
end
end
end