From 604a0968472780d851121a73d889ea3b157b0a46 Mon Sep 17 00:00:00 2001 From: Leland Batey Date: Thu, 29 Sep 2016 14:23:08 -0700 Subject: [PATCH] Add support for repeated query parameters --- gengokit/gentesthelper/helper.go | 64 ++++++++++++++++++++++++ gengokit/httptransport/httptransport.go | 60 +++++++++++++++++++++- gengokit/httptransport/templates.go | 22 +++----- gengokit/httptransport/templates_test.go | 24 ++------- gengokit/template/template.go | 26 +++++----- 5 files changed, 148 insertions(+), 48 deletions(-) create mode 100644 gengokit/gentesthelper/helper.go diff --git a/gengokit/gentesthelper/helper.go b/gengokit/gentesthelper/helper.go new file mode 100644 index 0000000..69d7bab --- /dev/null +++ b/gengokit/gentesthelper/helper.go @@ -0,0 +1,64 @@ +package gentesthelper + +import ( + "bytes" + "fmt" + "go/ast" + "go/parser" + "go/printer" + "go/token" + "reflect" + "runtime" + "strings" + + "github.com/pmezard/go-difflib/difflib" +) + +// FuncSourceCode returns a string representing the source code of the function +// provided to it. +func FuncSourceCode(val interface{}) (string, error) { + ptr := reflect.ValueOf(val).Pointer() + fpath, _ := runtime.FuncForPC(ptr).FileLine(ptr) + + funcName := runtime.FuncForPC(ptr).Name() + parts := strings.Split(funcName, ".") + funcName = parts[len(parts)-1] + + // Parse the go file into the ast + fset := token.NewFileSet() + fileAst, err := parser.ParseFile(fset, fpath, nil, parser.ParseComments) + if err != nil { + return "", fmt.Errorf("ERROR: go parser couldn't parse file '%v'\n", fpath) + } + + // Search ast for function declaration with name of function passed + var fAst *ast.FuncDecl + for _, decs := range fileAst.Decls { + if f, ok := decs.(*ast.FuncDecl); ok && f.Name.String() == funcName { + fAst = f + break + } + } + code := bytes.NewBuffer(nil) + err = printer.Fprint(code, fset, fAst) + + if err != nil { + return "", fmt.Errorf("couldn't print code for func %q: %v\n", funcName, err) + } + + return code.String(), nil +} + +// DiffStrings returns the line differences of two strings. Useful for +// examining how generated code differs from expected code. +func DiffStrings(a, b string) string { + t := difflib.UnifiedDiff{ + A: difflib.SplitLines(a), + B: difflib.SplitLines(b), + FromFile: "A", + ToFile: "B", + Context: 5, + } + text, _ := difflib.GetUnifiedDiffString(t) + return text +} diff --git a/gengokit/httptransport/httptransport.go b/gengokit/httptransport/httptransport.go index 913bfd7..452a88e 100644 --- a/gengokit/httptransport/httptransport.go +++ b/gengokit/httptransport/httptransport.go @@ -11,6 +11,7 @@ import ( "text/template" "unicode" + log "github.com/Sirupsen/logrus" "github.com/TuneLab/go-truss/deftree" "github.com/TuneLab/go-truss/gengokit/clientarggen" gogen "github.com/golang/protobuf/protoc-gen-go/generator" @@ -87,7 +88,7 @@ func NewBinding(i int, meth *deftree.ServiceMethod) *Binding { var gt string var ok bool tmap := clientarggen.ProtoToGoTypeMap - if gt, ok = tmap[nField.ProtobufType]; !ok || field.Label == "LABEL_REPEATED" { + if gt, ok = tmap[nField.ProtobufType]; !ok { gt = "string" nField.IsBaseType = false } else { @@ -101,6 +102,24 @@ func NewBinding(i int, meth *deftree.ServiceMethod) *Binding { nField.LowCamelName = LowCamelName(nField.Name) nBinding.Fields = append(nBinding.Fields, &nField) + + // Emit warnings for certain cases + if !nField.IsBaseType { + log.Warnf("%s.%s is a custom type '%s', only base types and repeated base "+ + "types are supported. As a result, the generated HTTP "+ + "transport will fail to compile. Remove non-base types.", + meth.GetName(), + nField.Name, + nField.ProtobufType) + } + if field.Label == "LABEL_REPEATED" && nField.Location == "path" { + log.Warnf( + "%s.%s is a repeated field specified to be in the path. "+ + "Repeated fields are not supported in the path and may"+ + "result in generated code which fails to compile.", + meth.GetName(), + nField.Name) + } } return &nBinding } @@ -159,6 +178,45 @@ func (b *Binding) PathSections() []string { return rv } +// GenQueryUnmarshaler returns the generated code for server-side unmarshaling +// of a query parameter into it's correct field on the request struct. +func (f *Field) GenQueryUnmarshaler() (string, error) { + repeatedQueryLogic := ` +for _, {{.LocalName}}Str := range r.URL.Query()["{{.Name}}"] { + {{.ConvertFunc}} + if err != nil { + fmt.Printf("Error while extracting {{.LocalName}} from {{.Location}}: %v\n", err) + fmt.Printf("{{.Location}}Params: %v\n", {{.Location}}Params) + return nil, err + } + req.{{.CamelName}} = append(req.{{.CamelName}}, {{.TypeConversion}}) +} +` + genericLogic := ` +{{.LocalName}}Str := {{.Location}}Params["{{.Name}}"] +{{.ConvertFunc}} +// TODO: Better error handling +if err != nil { + fmt.Printf("Error while extracting {{.LocalName}} from {{.Location}}: %v\n", err) + fmt.Printf("{{.Location}}Params: %v\n", {{.Location}}Params) + return nil, err +} +req.{{.CamelName}} = {{.TypeConversion}} +` + var selected string + if f.Location == "query" && f.ProtobufLabel == "LABEL_REPEATED" { + selected = repeatedQueryLogic + } else if f.Location != "body" { + selected = genericLogic + } + code, err := ApplyTemplate("FieldEncodeLogic", selected, f, TemplateFuncs) + if err != nil { + return "", err + } + code = FormatCode(code) + return code, nil +} + // createDecodeConvertFunc creates a go string representing the function to // convert the string form of the field to it's correct go type. func createDecodeConvertFunc(f Field) string { diff --git a/gengokit/httptransport/templates.go b/gengokit/httptransport/templates.go index 7f2948c..6101a16 100644 --- a/gengokit/httptransport/templates.go +++ b/gengokit/httptransport/templates.go @@ -13,29 +13,19 @@ func DecodeHTTP{{$binding.Label}}Request(_ context.Context, r *http.Request) (in pathParams, err := PathParams(r.URL.Path, "{{$binding.PathTemplate}}") _ = pathParams - // TODO: Better error handling if err != nil { fmt.Printf("Error while reading path params: %v\n", err) return nil, err } queryParams, err := QueryParams(r.URL.Query()) _ = queryParams - // TODO: Better error handling if err != nil { fmt.Printf("Error while reading query params: %v\n", err) return nil, err } {{range $field := $binding.Fields}} {{if ne $field.Location "body"}} - {{$field.LocalName}}Str := {{$field.Location}}Params["{{$field.Name}}"] - {{$field.ConvertFunc}} - // TODO: Better error handling - if err != nil { - fmt.Printf("Error while extracting {{$field.LocalName}} from {{$field.Location}}: %v\n", err) - fmt.Printf("{{$field.Location}}Params: %v\n", {{$field.Location}}Params) - return nil, err - } - req.{{$field.CamelName}} = {{$field.TypeConversion}} + {{$field.GenQueryUnmarshaler}} {{end}} {{end}} return &req, err @@ -61,8 +51,6 @@ func EncodeHTTP{{$binding.Label}}Request(_ context.Context, r *http.Request, req {{$section}}, {{- end}} }, "/") - //r.URL.Scheme, - //r.URL.Host, u, err := url.Parse(path) if err != nil { return err @@ -74,7 +62,13 @@ func EncodeHTTP{{$binding.Label}}Request(_ context.Context, r *http.Request, req values := r.URL.Query() {{- range $field := $binding.Fields }} {{- if eq $field.Location "query"}} - values.Add("{{$field.Name}}", fmt.Sprint(req.{{$field.CamelName}})) + {{- if eq $field.ProtobufLabel "LABEL_REPEATED"}} + for _, v := range req.{{$field.CamelName}} { + values.Add("{{$field.Name}}", fmt.Sprint(v)) + } + {{else}} + values.Add("{{$field.Name}}", fmt.Sprint(req.{{$field.CamelName}})) + {{- end }} {{- end }} {{- end}} diff --git a/gengokit/httptransport/templates_test.go b/gengokit/httptransport/templates_test.go index df6d50a..d2d498b 100644 --- a/gengokit/httptransport/templates_test.go +++ b/gengokit/httptransport/templates_test.go @@ -3,7 +3,7 @@ package httptransport import ( "testing" - "github.com/pmezard/go-difflib/difflib" + "github.com/TuneLab/go-truss/gengokit/gentesthelper" ) func TestGenClientEncode(t *testing.T) { @@ -69,8 +69,6 @@ func EncodeHTTPSumZeroRequest(_ context.Context, r *http.Request, request interf "sum", fmt.Sprint(req.A), }, "/") - //r.URL.Scheme, - //r.URL.Host, u, err := url.Parse(path) if err != nil { return err @@ -98,7 +96,7 @@ func EncodeHTTPSumZeroRequest(_ context.Context, r *http.Request, request interf ` if got, want := str, desired; got != want { t.Errorf("Generated code differs from result.\ngot = %s\nwant = %s", got, want) - t.Log(DiffStrings(got, want)) + t.Log(gentesthelper.DiffStrings(got, want)) } } @@ -162,14 +160,12 @@ func DecodeHTTPSumZeroRequest(_ context.Context, r *http.Request) (interface{}, pathParams, err := PathParams(r.URL.Path, "/sum/{a}") _ = pathParams - // TODO: Better error handling if err != nil { fmt.Printf("Error while reading path params: %v\n", err) return nil, err } queryParams, err := QueryParams(r.URL.Query()) _ = queryParams - // TODO: Better error handling if err != nil { fmt.Printf("Error while reading query params: %v\n", err) return nil, err @@ -201,7 +197,7 @@ func DecodeHTTPSumZeroRequest(_ context.Context, r *http.Request) (interface{}, ` if got, want := str, desired; got != want { t.Errorf("Generated code differs from result.\ngot = %s\nwant = %s", got, want) - t.Log(DiffStrings(got, want)) + t.Log(gentesthelper.DiffStrings(got, want)) } } @@ -214,18 +210,6 @@ func TestHTTPAssistFuncs(t *testing.T) { if got, want := tmplfncs, FormatCode(source); got != want { t.Errorf("Assistant functions in templates differ from the source of those functions as they exist within the codebase") - t.Log(DiffStrings(got, want)) + t.Log(gentesthelper.DiffStrings(got, want)) } } - -func DiffStrings(a, b string) string { - t := difflib.UnifiedDiff{ - A: difflib.SplitLines(a), - B: difflib.SplitLines(b), - FromFile: "A", - ToFile: "B", - Context: 5, - } - text, _ := difflib.GetUnifiedDiffString(t) - return text -} diff --git a/gengokit/template/template.go b/gengokit/template/template.go index acd477d..7d4a838 100644 --- a/gengokit/template/template.go +++ b/gengokit/template/template.go @@ -95,7 +95,7 @@ func nameServiceNameClientClient_mainGotemplate() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "NAME-service/NAME-client/client_main.gotemplate", size: 4462, mode: os.FileMode(436), modTime: time.Unix(1475107892, 0)} + info := bindataFileInfo{name: "NAME-service/NAME-client/client_main.gotemplate", size: 4462, mode: os.FileMode(436), modTime: time.Unix(1475181328, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -115,7 +115,7 @@ func nameServiceNameServerServer_mainGotemplate() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "NAME-service/NAME-server/server_main.gotemplate", size: 6628, mode: os.FileMode(436), modTime: time.Unix(1474074090, 0)} + info := bindataFileInfo{name: "NAME-service/NAME-server/server_main.gotemplate", size: 6628, mode: os.FileMode(436), modTime: time.Unix(1474335271, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -135,7 +135,7 @@ func nameServiceGeneratedClientGrpcClientGotemplate() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "NAME-service/generated/client/grpc/client.gotemplate", size: 2482, mode: os.FileMode(436), modTime: time.Unix(1474072015, 0)} + info := bindataFileInfo{name: "NAME-service/generated/client/grpc/client.gotemplate", size: 2482, mode: os.FileMode(436), modTime: time.Unix(1474057071, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -155,7 +155,7 @@ func nameServiceGeneratedClientHttpClientGotemplate() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "NAME-service/generated/client/http/client.gotemplate", size: 2739, mode: os.FileMode(436), modTime: time.Unix(1474072015, 0)} + info := bindataFileInfo{name: "NAME-service/generated/client/http/client.gotemplate", size: 2739, mode: os.FileMode(436), modTime: time.Unix(1474057071, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -175,7 +175,7 @@ func nameServiceGeneratedEndpointsGotemplate() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "NAME-service/generated/endpoints.gotemplate", size: 3492, mode: os.FileMode(436), modTime: time.Unix(1474673731, 0)} + info := bindataFileInfo{name: "NAME-service/generated/endpoints.gotemplate", size: 3492, mode: os.FileMode(436), modTime: time.Unix(1474335271, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -195,7 +195,7 @@ func nameServiceGeneratedService_middleGotemplate() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "NAME-service/generated/service_middle.gotemplate", size: 2777, mode: os.FileMode(436), modTime: time.Unix(1474072015, 0)} + info := bindataFileInfo{name: "NAME-service/generated/service_middle.gotemplate", size: 2777, mode: os.FileMode(436), modTime: time.Unix(1474057071, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -215,7 +215,7 @@ func nameServiceGeneratedTransport_grpcGotemplate() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "NAME-service/generated/transport_grpc.gotemplate", size: 3601, mode: os.FileMode(436), modTime: time.Unix(1474673731, 0)} + info := bindataFileInfo{name: "NAME-service/generated/transport_grpc.gotemplate", size: 3601, mode: os.FileMode(436), modTime: time.Unix(1474335271, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -235,7 +235,7 @@ func nameServiceGeneratedTransport_httpGotemplate() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "NAME-service/generated/transport_http.gotemplate", size: 3980, mode: os.FileMode(436), modTime: time.Unix(1474072015, 0)} + info := bindataFileInfo{name: "NAME-service/generated/transport_http.gotemplate", size: 3980, mode: os.FileMode(436), modTime: time.Unix(1474057071, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -255,7 +255,7 @@ func nameServiceHandlersClientClient_handlerGotemplate() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "NAME-service/handlers/client/client_handler.gotemplate", size: 940, mode: os.FileMode(436), modTime: time.Unix(1475107892, 0)} + info := bindataFileInfo{name: "NAME-service/handlers/client/client_handler.gotemplate", size: 940, mode: os.FileMode(436), modTime: time.Unix(1475181328, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -275,7 +275,7 @@ func nameServiceHandlersServerServer_handlerGotemplate() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "NAME-service/handlers/server/server_handler.gotemplate", size: 1191, mode: os.FileMode(436), modTime: time.Unix(1474673731, 0)} + info := bindataFileInfo{name: "NAME-service/handlers/server/server_handler.gotemplate", size: 1191, mode: os.FileMode(436), modTime: time.Unix(1474335271, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -295,7 +295,7 @@ func nameServicePartial_templateClient_handlerMethods() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "NAME-service/partial_template/client_handler.methods", size: 867, mode: os.FileMode(436), modTime: time.Unix(1474673731, 0)} + info := bindataFileInfo{name: "NAME-service/partial_template/client_handler.methods", size: 867, mode: os.FileMode(436), modTime: time.Unix(1474335271, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -315,7 +315,7 @@ func nameServicePartial_templateServiceInterface() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "NAME-service/partial_template/service.interface", size: 197, mode: os.FileMode(436), modTime: time.Unix(1474673731, 0)} + info := bindataFileInfo{name: "NAME-service/partial_template/service.interface", size: 197, mode: os.FileMode(436), modTime: time.Unix(1474335271, 0)} a := &asset{bytes: bytes, info: info} return a, nil } @@ -335,7 +335,7 @@ func nameServicePartial_templateServiceMethods() (*asset, error) { return nil, err } - info := bindataFileInfo{name: "NAME-service/partial_template/service.methods", size: 539, mode: os.FileMode(436), modTime: time.Unix(1474673731, 0)} + info := bindataFileInfo{name: "NAME-service/partial_template/service.methods", size: 539, mode: os.FileMode(436), modTime: time.Unix(1474335271, 0)} a := &asset{bytes: bytes, info: info} return a, nil }