* parse.y (ripper_initialize): allow generic input as source, if
  it has #gets method.

git-svn-id: svn+ssh://ci.ruby-lang.org/ruby/trunk@57022 b2dd03c8-39d4-4d8f-98ff-823fe69b080e
This commit is contained in:
nobu 2016-12-08 00:45:13 +00:00
Родитель d8bcfd2b12
Коммит 510f0ec869
2 изменённых файлов: 35 добавлений и 1 удалений

18
parse.y
Просмотреть файл

@ -776,7 +776,7 @@ static VALUE parser_heredoc_dedent(struct parser_params*,VALUE);
# define rb_warning3L(l,fmt,a,b,c) WARNING_CALL(WARNING_ARGS_L(l, fmt, 4), (a), (b), (c))
# define rb_warning4L(l,fmt,a,b,c,d) WARNING_CALL(WARNING_ARGS_L(l, fmt, 5), (a), (b), (c), (d))
#ifdef RIPPER
static ID id_warn, id_warning;
static ID id_warn, id_warning, id_gets;
# define WARN_S_L(s,l) STR_NEW(s,l)
# define WARN_S(s) STR_NEW2(s)
# define WARN_I(i) INT2NUM(i)
@ -11317,6 +11317,18 @@ ripper_compile_error(struct parser_params *parser, const char *fmt, ...)
static VALUE
ripper_lex_get_generic(struct parser_params *parser, VALUE src)
{
VALUE line = rb_funcallv_public(src, id_gets, 0, 0);
if (!NIL_P(line) && !RB_TYPE_P(line, T_STRING)) {
rb_raise(rb_eTypeError,
"gets returned %"PRIsVALUE" (expected String or nil)",
rb_obj_class(line));
}
return line;
}
static VALUE
ripper_lex_io_get(struct parser_params *parser, VALUE src)
{
return rb_io_gets(src);
}
@ -11352,6 +11364,9 @@ ripper_initialize(int argc, VALUE *argv, VALUE self)
TypedData_Get_Struct(self, struct parser_params, &parser_data_type, parser);
rb_scan_args(argc, argv, "12", &src, &fname, &lineno);
if (RB_TYPE_P(src, T_FILE)) {
lex_gets = ripper_lex_io_get;
}
else if (rb_respond_to(src, id_gets)) {
lex_gets = ripper_lex_get_generic;
}
else {
@ -11519,6 +11534,7 @@ Init_ripper(void)
ripper_init_eventids2();
id_warn = rb_intern_const("warn");
id_warning = rb_intern_const("warning");
id_gets = rb_intern_const("gets");
InitVM(ripper);
}

Просмотреть файл

@ -117,4 +117,22 @@ end
assert_nothing_raised { Ripper.lex src }
end
class TestInput < self
Input = Struct.new(:lines) do
def gets
lines.shift
end
end
def setup
@ripper = Ripper.new(Input.new(["1 + 1"]))
end
def test_invalid_gets
ripper = assert_nothing_raised {Ripper.new(Input.new([0]))}
assert_raise(TypeError) {ripper.parse}
end
end
end if ripper_test