diff --git a/lib/rubygems/safe_marshal.rb b/lib/rubygems/safe_marshal.rb index 33e56867cd..fb8c23f7f6 100644 --- a/lib/rubygems/safe_marshal.rb +++ b/lib/rubygems/safe_marshal.rb @@ -11,8 +11,9 @@ module Gem module SafeMarshal PERMITTED_CLASSES = %w[ - Time Date + Time + Rational Gem::Dependency Gem::NameTuple @@ -28,45 +29,39 @@ module Gem private_constant :PERMITTED_CLASSES PERMITTED_SYMBOLS = %w[ - E - - offset - zone - nano_num - nano_den - submicro - - @_zone - @cpu - @debug_created_info - @force_ruby_platform - @marshal_with_utc_coercion - @name - @os - @platform - @prerelease - @requirement - @taguri - @type - @type_id - @value - @version - @version_requirement - @version_requirements - development runtime ].freeze private_constant :PERMITTED_SYMBOLS + PERMITTED_IVARS = { + "String" => %w[E @taguri @debug_created_info], + "Time" => %w[ + offset zone nano_num nano_den submicro + @_zone @marshal_with_utc_coercion + ], + "Gem::Dependency" => %w[ + @name @requirement @prerelease @version_requirement @version_requirements @type + @force_ruby_platform + ], + "Gem::NameTuple" => %w[@name @version @platform], + "Gem::Platform" => %w[@os @cpu @version], + "Psych::PrivateType" => %w[@value @type_id], + }.freeze + private_constant :PERMITTED_IVARS + def self.safe_load(input) - load(input, permitted_classes: PERMITTED_CLASSES, permitted_symbols: PERMITTED_SYMBOLS) + load(input, permitted_classes: PERMITTED_CLASSES, permitted_symbols: PERMITTED_SYMBOLS, permitted_ivars: PERMITTED_IVARS) end - def self.load(input, permitted_classes: [::Symbol], permitted_symbols: []) + def self.load(input, permitted_classes: [::Symbol], permitted_symbols: [], permitted_ivars: {}) root = Reader.new(StringIO.new(input, "r")).read! - Visitors::ToRuby.new(permitted_classes: permitted_classes, permitted_symbols: permitted_symbols).visit(root) + Visitors::ToRuby.new( + permitted_classes: permitted_classes, + permitted_symbols: permitted_symbols, + permitted_ivars: permitted_ivars, + ).visit(root) end end end diff --git a/lib/rubygems/safe_marshal/visitors/stream_printer.rb b/lib/rubygems/safe_marshal/visitors/stream_printer.rb new file mode 100644 index 0000000000..162b36ad05 --- /dev/null +++ b/lib/rubygems/safe_marshal/visitors/stream_printer.rb @@ -0,0 +1,31 @@ +# frozen_string_literal: true + +require_relative "visitor" + +module Gem::SafeMarshal + module Visitors + class StreamPrinter < Visitor + def initialize(io, indent: "") + @io = io + @indent = indent + @level = 0 + end + + def visit(target) + @io.write("#{@indent * @level}#{target.class}") + target.instance_variables.each do |ivar| + value = target.instance_variable_get(ivar) + next if Elements::Element === value || Array === value + @io.write(" #{ivar}=#{value.inspect}") + end + @io.write("\n") + begin + @level += 1 + super + ensure + @level -= 1 + end + end + end + end +end diff --git a/lib/rubygems/safe_marshal/visitors/to_ruby.rb b/lib/rubygems/safe_marshal/visitors/to_ruby.rb index e0e7b459cf..8d5c05e3ca 100644 --- a/lib/rubygems/safe_marshal/visitors/to_ruby.rb +++ b/lib/rubygems/safe_marshal/visitors/to_ruby.rb @@ -5,9 +5,10 @@ require_relative "visitor" module Gem::SafeMarshal module Visitors class ToRuby < Visitor - def initialize(permitted_classes:, permitted_symbols:) + def initialize(permitted_classes:, permitted_symbols:, permitted_ivars:) @permitted_classes = permitted_classes @permitted_symbols = permitted_symbols | permitted_classes | ["E"] + @permitted_ivars = permitted_ivars @objects = [] @symbols = [] @@ -17,7 +18,8 @@ module Gem::SafeMarshal end def inspect # :nodoc: - format("#<%s permitted_classes: %p permitted_symbols: %p>", self.class, @permitted_classes, @permitted_symbols) + format("#<%s permitted_classes: %p permitted_symbols: %p permitted_ivars: %p>", + self.class, @permitted_classes, @permitted_symbols, @permitted_ivars) end def visit(target) @@ -37,14 +39,16 @@ module Gem::SafeMarshal end def visit_Gem_SafeMarshal_Elements_Symbol(s) - resolve_symbol(s.name) + name = s.name + raise UnpermittedSymbolError.new(symbol: name, stack: @stack.dup) unless @permitted_symbols.include?(name) + visit_symbol_type(s) end - def map_ivars(ivars) + def map_ivars(klass, ivars) ivars.map.with_index do |(k, v), i| - @stack << "ivar #{i}" - k = visit(k) - @stack << k + @stack << "ivar_#{i}" + k = resolve_ivar(klass, k) + @stack[-1] = k next k, visit(v) end end @@ -54,12 +58,12 @@ module Gem::SafeMarshal object_offset = @objects.size @stack << "object" object = visit(e.object) - ivars = map_ivars(e.ivars) + ivars = map_ivars(object.class, e.ivars) case e.object when Elements::UserDefined if object.class == ::Time - offset = zone = nano_num = nano_den = nil + offset = zone = nano_num = nano_den = submicro = nil ivars.reject! do |k, v| case k when :offset @@ -71,6 +75,7 @@ module Gem::SafeMarshal when :nano_den nano_den = v when :submicro + submicro = v else next false end @@ -80,17 +85,23 @@ module Gem::SafeMarshal if (nano_den || nano_num) && !(nano_den && nano_num) raise FormatError, "Must have all of nano_den, nano_num for Time #{e.pretty_inspect}" elsif nano_den && nano_num - nano = Rational(nano_num, nano_den) - nsec, subnano = nano.divmod(1) - nano = nsec + subnano - - object = Time.at(object.to_r, nano, :nanosecond) + if RUBY_ENGINE == "jruby" + nano = Rational(nano_num, nano_den * 1_000_000_000) + object = Time.at(object.to_i + nano + object.subsec) + elsif RUBY_ENGINE == "truffleruby" + object = Time.at(object.to_i, Rational(nano_num, nano_den).to_i, :nanosecond) + else # assume "ruby" + nano = Rational(nano_num, nano_den) + nsec, subnano = nano.divmod(1) + nano = nsec + subnano + object = Time.at(object.to_r, nano, :nanosecond) + end end if zone require "time" zone = "+0000" if zone == "UTC" && offset == 0 - Time.send(:force_zone!, object, zone, offset) + call_method(Time, :force_zone!, object, zone, offset) elsif offset object = object.localtime offset end @@ -157,14 +168,23 @@ module Gem::SafeMarshal end def visit_Gem_SafeMarshal_Elements_UserDefined(o) - register_object(resolve_class(o.name).send(:_load, o.binary_string)) + register_object(call_method(resolve_class(o.name), :_load, o.binary_string)) end def visit_Gem_SafeMarshal_Elements_UserMarshal(o) - register_object(resolve_class(o.name).allocate).tap do |object| - @stack << :data - object.marshal_load visit(o.data) + klass = resolve_class(o.name) + compat = COMPAT_CLASSES.fetch(klass, nil) + idx = @objects.size + object = register_object(call_method(compat || klass, :allocate)) + + @stack << :data + ret = call_method(object, :marshal_load, visit(o.data)) + + if compat + object = @objects[idx] = ret end + + object end def visit_Gem_SafeMarshal_Elements_Integer(i) @@ -218,16 +238,9 @@ module Gem::SafeMarshal def resolve_class(n) @class_cache[n] ||= begin - name = nil - case n - when Elements::Symbol, Elements::SymbolLink - @stack << "class name" - name = visit(n) - else - raise FormatError, "Class names must be Symbol or SymbolLink" - end - to_s = name.to_s - raise UnpermittedClassError.new(name: name, stack: @stack.dup) unless @permitted_classes.include?(to_s) + to_s = resolve_symbol_name(n) + raise UnpermittedClassError.new(name: to_s, stack: @stack.dup) unless @permitted_classes.include?(to_s) + visit_symbol_type(n) begin ::Object.const_get(to_s) rescue NameError @@ -236,11 +249,47 @@ module Gem::SafeMarshal end end - def resolve_symbol(name) - raise UnpermittedSymbolError.new(symbol: name, stack: @stack.dup) unless @permitted_symbols.include?(name) - sym = name.to_sym - @symbols << sym - sym + class RationalCompat + def marshal_load(s) + num, den = s + raise ArgumentError, "Expected 2 ints" unless s.size == 2 && num.is_a?(Integer) && den.is_a?(Integer) + Rational(num, den) + end + end + + COMPAT_CLASSES = {}.tap do |h| + h[Rational] = RationalCompat if RUBY_VERSION >= "3" + end.freeze + private_constant :COMPAT_CLASSES + + def resolve_ivar(klass, name) + to_s = resolve_symbol_name(name) + + raise UnpermittedIvarError.new(symbol: to_s, klass: klass, stack: @stack.dup) unless @permitted_ivars.fetch(klass.name, [].freeze).include?(to_s) + + visit_symbol_type(name) + end + + def visit_symbol_type(element) + case element + when Elements::Symbol + sym = element.name.to_sym + @symbols << sym + sym + when Elements::SymbolLink + visit_Gem_SafeMarshal_Elements_SymbolLink(element) + end + end + + def resolve_symbol_name(element) + case element + when Elements::Symbol + element.name + when Elements::SymbolLink + visit_Gem_SafeMarshal_Elements_SymbolLink(element).to_s + else + raise FormatError, "Expected symbol or symbol link, got #{element.inspect} @ #{@stack.join(".")}" + end end def register_object(o) @@ -248,6 +297,14 @@ module Gem::SafeMarshal o end + def call_method(receiver, method, *args) + receiver.__send__(method, *args) + rescue NoMethodError => e + raise unless e.receiver == receiver + + raise MethodCallError, "Unable to call #{method.inspect} on #{receiver.inspect}, perhaps it is a class using marshal compat, which is not visible in ruby? #{e}" + end + class UnpermittedSymbolError < StandardError def initialize(symbol:, stack:) @symbol = symbol @@ -256,6 +313,15 @@ module Gem::SafeMarshal end end + class UnpermittedIvarError < StandardError + def initialize(symbol:, klass:, stack:) + @symbol = symbol + @klass = klass + @stack = stack + super "Attempting to set unpermitted ivar #{symbol.inspect} on object of class #{klass} @ #{stack.join "."}" + end + end + class UnpermittedClassError < StandardError def initialize(name:, stack:) @name = name @@ -266,6 +332,9 @@ module Gem::SafeMarshal class FormatError < StandardError end + + class MethodCallError < StandardError + end end end end diff --git a/test/rubygems/test_gem_safe_marshal.rb b/test/rubygems/test_gem_safe_marshal.rb index 9d29958b59..36b11080f7 100644 --- a/test/rubygems/test_gem_safe_marshal.rb +++ b/test/rubygems/test_gem_safe_marshal.rb @@ -20,7 +20,9 @@ class TestGemSafeMarshal < Gem::TestCase def test_recursive_string s = String.new("hello") s.instance_variable_set(:@type, s) - assert_safe_load_as s, additional_methods: [:instance_variables] + with_const(Gem::SafeMarshal, :PERMITTED_IVARS, { "String" => %w[@type E] }) do + assert_safe_load_as s, additional_methods: [:instance_variables] + end end def test_recursive_array @@ -39,11 +41,17 @@ class TestGemSafeMarshal < Gem::TestCase end def test_string_with_ivar - assert_safe_load_as String.new("abc").tap {|s| s.instance_variable_set :@type, "type" } + str = String.new("abc") + str.instance_variable_set :@type, "type" + with_const(Gem::SafeMarshal, :PERMITTED_IVARS, { "String" => %w[@type E] }) do + assert_safe_load_as str + end end def test_time_with_ivar - assert_safe_load_as Time.new.tap {|t| t.instance_variable_set :@type, "type" } + with_const(Gem::SafeMarshal, :PERMITTED_IVARS, { "Time" => %w[@type offset zone nano_num nano_den submicro], "String" => "E" }) do + assert_safe_load_as Time.new.tap {|t| t.instance_variable_set :@type, :runtime } + end end secs = Time.new(2000, 12, 31, 23, 59, 59).to_i @@ -64,7 +72,7 @@ class TestGemSafeMarshal < Gem::TestCase Time.at(secs, 1.01, :nanosecond), Time.at(secs, 1.001, :nanosecond), Time.at(secs, 1.00001, :nanosecond), - Time.at(secs, 1.00001, :nanosecond).tap {|t| t.instance_variable_set :@type, "type" }, + Time.at(secs, 1.00001, :nanosecond), ].each_with_index do |t, i| define_method("test_time_#{i} #{t.inspect}") do assert_safe_load_as t, additional_methods: [:ctime, :to_f, :to_r, :to_i, :zone, :subsec, :instance_variables, :dst?, :to_a] @@ -79,19 +87,33 @@ class TestGemSafeMarshal < Gem::TestCase end def test_hash_with_ivar - assert_safe_load_as({ runtime: :development }.tap {|h| h.instance_variable_set :@type, "null" }) + h = { runtime: :development } + h.instance_variable_set :@type, [] + with_const(Gem::SafeMarshal, :PERMITTED_IVARS, { "Hash" => %w[@type] }) do + assert_safe_load_as(h) + end end def test_hash_with_default_value assert_safe_load_as Hash.new([]) end + def test_hash_with_compare_by_identity + pend "`read_user_class` not yet implemented" + + assert_safe_load_as Hash.new.compare_by_identity + end + def test_frozen_object assert_safe_load_as Gem::Version.new("1.abc").freeze end def test_date - assert_safe_load_as Date.new + assert_safe_load_as Date.new(1994, 12, 9) + end + + def test_rational + assert_safe_load_as Rational(1, 3) end [ @@ -142,4 +164,18 @@ class TestGemSafeMarshal < Gem::TestCase end assert_equal Marshal.dump(loaded), Marshal.dump(safe_loaded), "should Marshal.dump the same" end + + def with_const(mod, name, new_value, &block) + orig = mod.const_get(name) + mod.send :remove_const, name + mod.const_set name, new_value + + begin + yield + ensure + mod.send :remove_const, name + mod.const_set name, orig + mod.send :private_constant, name + end + end end