Add code-unit based offsets and columns to lib/prism/parse_result.rb.

code_units_offset and code_units_column convert a byte offset into a count of
code units in a caller-chosen encoding (UTF-16 counts two bytes per unit,
everything else counts characters). The Location helpers wrap them and default
to UTF-16LE. Bytes that cannot be transcoded are replaced instead of raising.

diff --git a/lib/prism/parse_result.rb b/lib/prism/parse_result.rb
index a27f30d43b..4b0c57ed4b 100644
--- a/lib/prism/parse_result.rb
+++ b/lib/prism/parse_result.rb
@@ -56,6 +56,23 @@ module Prism
       character_offset(byte_offset) - character_offset(line_start(byte_offset))
     end
 
+    # Returns the offset from the start of the file for the given byte offset
+    # counting in code units for the given encoding. Bytes that cannot be
+    # converted are replaced rather than raising, so offsets stay computable.
+    # This method is tested with UTF-8, UTF-16, and UTF-32. If there is the
+    # concept of code units that differs from the number of characters in other
+    # encodings, it is not captured here.
+    def code_units_offset(byte_offset, encoding)
+      byteslice = source.byteslice(0, byte_offset).encode(encoding, invalid: :replace, undef: :replace)
+      (encoding == Encoding::UTF_16LE || encoding == Encoding::UTF_16BE) ? (byteslice.bytesize / 2) : byteslice.length
+    end
+
+    # Returns the column number in code units for the given encoding for the
+    # given byte offset.
+    def code_units_column(byte_offset, encoding)
+      code_units_offset(byte_offset, encoding) - code_units_offset(line_start(byte_offset), encoding)
+    end
+
     private
 
     # Binary search through the offsets to find the line number for the given
@@ -138,6 +155,11 @@ module Prism
       source.character_offset(start_offset)
     end
 
+    # The offset from the start of the file in code units of the given encoding.
+    def start_code_units_offset(encoding = Encoding::UTF_16LE)
+      source.code_units_offset(start_offset, encoding)
+    end
+
     # The byte offset from the beginning of the source where this location ends.
     def end_offset
       start_offset + length
@@ -149,6 +171,11 @@ module Prism
       source.character_offset(end_offset)
     end
 
+    # The offset from the start of the file in code units of the given encoding.
+    def end_code_units_offset(encoding = Encoding::UTF_16LE)
+      source.code_units_offset(end_offset, encoding)
+    end
+
     # The line number where this location starts.
     def start_line
       source.line(start_offset)
@@ -177,6 +204,12 @@ module Prism
       source.character_column(start_offset)
     end
 
+    # The column number in code units of the given encoding where this location
+    # starts from the start of the line.
+    def start_code_units_column(encoding = Encoding::UTF_16LE)
+      source.code_units_column(start_offset, encoding)
+    end
+
     # The column number in bytes where this location ends from the start of the
     # line.
     def end_column
@@ -189,6 +222,12 @@ module Prism
       source.character_column(end_offset)
     end
 
+    # The column number in code units of the given encoding where this location
+    # ends from the start of the line.
+    def end_code_units_column(encoding = Encoding::UTF_16LE)
+      source.code_units_column(end_offset, encoding)
+    end
+
     # Implement the hash pattern matching interface for Location.
     def deconstruct_keys(keys)
       { start_offset: start_offset, end_offset: end_offset }
diff --git a/test/prism/ruby_api_test.rb b/test/prism/ruby_api_test.rb
index cf7ea437cf..ff69ef5417 100644
--- a/test/prism/ruby_api_test.rb
+++ b/test/prism/ruby_api_test.rb
@@ -116,6 +116,86 @@ module Prism
       assert_equal 7, location.end_character_column
     end
 
+    def test_location_code_units
+      program = Prism.parse("😀 + 😀\n😍 ||= 😍").value
+
+      # first 😀
+      location = program.statements.body.first.receiver.location
+
+      assert_equal 0, location.start_code_units_offset(Encoding::UTF_8)
+      assert_equal 0, location.start_code_units_offset(Encoding::UTF_16LE)
+      assert_equal 0, location.start_code_units_offset(Encoding::UTF_32LE)
+
+      assert_equal 1, location.end_code_units_offset(Encoding::UTF_8)
+      assert_equal 2, location.end_code_units_offset(Encoding::UTF_16LE)
+      assert_equal 1, location.end_code_units_offset(Encoding::UTF_32LE)
+
+      assert_equal 0, location.start_code_units_column(Encoding::UTF_8)
+      assert_equal 0, location.start_code_units_column(Encoding::UTF_16LE)
+      assert_equal 0, location.start_code_units_column(Encoding::UTF_32LE)
+
+      assert_equal 1, location.end_code_units_column(Encoding::UTF_8)
+      assert_equal 2, location.end_code_units_column(Encoding::UTF_16LE)
+      assert_equal 1, location.end_code_units_column(Encoding::UTF_32LE)
+
+      # second 😀
+      location = program.statements.body.first.arguments.arguments.first.location
+
+      assert_equal 4, location.start_code_units_offset(Encoding::UTF_8)
+      assert_equal 5, location.start_code_units_offset(Encoding::UTF_16LE)
+      assert_equal 4, location.start_code_units_offset(Encoding::UTF_32LE)
+
+      assert_equal 5, location.end_code_units_offset(Encoding::UTF_8)
+      assert_equal 7, location.end_code_units_offset(Encoding::UTF_16LE)
+      assert_equal 5, location.end_code_units_offset(Encoding::UTF_32LE)
+
+      assert_equal 4, location.start_code_units_column(Encoding::UTF_8)
+      assert_equal 5, location.start_code_units_column(Encoding::UTF_16LE)
+      assert_equal 4, location.start_code_units_column(Encoding::UTF_32LE)
+
+      assert_equal 5, location.end_code_units_column(Encoding::UTF_8)
+      assert_equal 7, location.end_code_units_column(Encoding::UTF_16LE)
+      assert_equal 5, location.end_code_units_column(Encoding::UTF_32LE)
+
+      # first 😍
+      location = program.statements.body.last.name_loc
+
+      assert_equal 6, location.start_code_units_offset(Encoding::UTF_8)
+      assert_equal 8, location.start_code_units_offset(Encoding::UTF_16LE)
+      assert_equal 6, location.start_code_units_offset(Encoding::UTF_32LE)
+
+      assert_equal 7, location.end_code_units_offset(Encoding::UTF_8)
+      assert_equal 10, location.end_code_units_offset(Encoding::UTF_16LE)
+      assert_equal 7, location.end_code_units_offset(Encoding::UTF_32LE)
+
+      assert_equal 0, location.start_code_units_column(Encoding::UTF_8)
+      assert_equal 0, location.start_code_units_column(Encoding::UTF_16LE)
+      assert_equal 0, location.start_code_units_column(Encoding::UTF_32LE)
+
+      assert_equal 1, location.end_code_units_column(Encoding::UTF_8)
+      assert_equal 2, location.end_code_units_column(Encoding::UTF_16LE)
+      assert_equal 1, location.end_code_units_column(Encoding::UTF_32LE)
+
+      # second 😍
+      location = program.statements.body.last.value.location
+
+      assert_equal 12, location.start_code_units_offset(Encoding::UTF_8)
+      assert_equal 15, location.start_code_units_offset(Encoding::UTF_16LE)
+      assert_equal 12, location.start_code_units_offset(Encoding::UTF_32LE)
+
+      assert_equal 13, location.end_code_units_offset(Encoding::UTF_8)
+      assert_equal 17, location.end_code_units_offset(Encoding::UTF_16LE)
+      assert_equal 13, location.end_code_units_offset(Encoding::UTF_32LE)
+
+      assert_equal 6, location.start_code_units_column(Encoding::UTF_8)
+      assert_equal 7, location.start_code_units_column(Encoding::UTF_16LE)
+      assert_equal 6, location.start_code_units_column(Encoding::UTF_32LE)
+
+      assert_equal 7, location.end_code_units_column(Encoding::UTF_8)
+      assert_equal 9, location.end_code_units_column(Encoding::UTF_16LE)
+      assert_equal 7, location.end_code_units_column(Encoding::UTF_32LE)
+    end
+
     def test_heredoc?
       refute parse_expression("\"foo\"").heredoc?
       refute parse_expression("\"foo \#{1}\"").heredoc?