diff --git a/decode_test.go b/decode_test.go index b94e7d7..ae68896 100644 --- a/decode_test.go +++ b/decode_test.go @@ -647,6 +647,37 @@ func (s *S) TestUnmarshalerError(c *C) { c.Assert(err, Equals, failingErr) } +type sliceUnmarshaler []int + +func (su *sliceUnmarshaler) UnmarshalYAML(unmarshal func(interface{}) error) error { + var slice []int + err := unmarshal(&slice) + if err == nil { + *su = slice + return nil + } + + var intVal int + err = unmarshal(&intVal) + if err == nil { + *su = []int{intVal} + return nil + } + + return err +} + +func (s *S) TestUnmarshalerRetry(c *C) { + var su sliceUnmarshaler + err := yaml.Unmarshal([]byte("[1, 2, 3]"), &su) + c.Assert(err, IsNil) + c.Assert(su, DeepEquals, sliceUnmarshaler([]int{1, 2, 3})) + + err = yaml.Unmarshal([]byte("1"), &su) + c.Assert(err, IsNil) + c.Assert(su, DeepEquals, sliceUnmarshaler([]int{1})) +} + // From http://yaml.org/type/merge.html var mergeTests = ` anchors: diff --git a/yaml.go b/yaml.go index 5d1b86c..70fb66b 100644 --- a/yaml.go +++ b/yaml.go @@ -90,7 +90,7 @@ func Unmarshal(in []byte, out interface{}) (err error) { } d.unmarshal(node, v) } - if d.terrors != nil { + if len(d.terrors) > 0 { return &TypeError{d.terrors} } return nil