Migrate clientarggen to svcdef
This commit is contained in:
Родитель
4cafb6b349
Коммит
c57b063590
|
@ -9,10 +9,10 @@ import (
|
|||
"fmt"
|
||||
"html/template"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
log "github.com/Sirupsen/logrus"
|
||||
"github.com/TuneLab/go-truss/deftree"
|
||||
generatego "github.com/golang/protobuf/protoc-gen-go/generator"
|
||||
"github.com/TuneLab/go-truss/svcdef"
|
||||
gogen "github.com/golang/protobuf/protoc-gen-go/generator"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
|
@ -58,6 +58,7 @@ type ClientArg struct {
|
|||
// field corresponding to this arg, as provided by the protocol buffer
|
||||
// compiler. For a list of these basic types and their corresponding Go
|
||||
// types, see the ProtoToGoTypeMap map in this file.
|
||||
// Deprecated by Svcdef.
|
||||
ProtbufType string
|
||||
|
||||
// IsBaseType is true if this arg corresponds to a protobuf field which is
|
||||
|
@ -161,54 +162,50 @@ var ProtoToGoTypeMap = map[string]string{
|
|||
|
||||
// New creates a ClientServiceArgs struct containing all the arguments for all
|
||||
// the methods of a given RPC.
|
||||
func New(svc *deftree.ProtoService) *ClientServiceArgs {
|
||||
func New(svc *svcdef.Service) *ClientServiceArgs {
|
||||
svcArgs := ClientServiceArgs{
|
||||
MethArgs: make(map[string]*MethodArgs),
|
||||
}
|
||||
for _, meth := range svc.Methods {
|
||||
m := MethodArgs{}
|
||||
for _, field := range meth.RequestType.Fields {
|
||||
// Skip map fields, as they're currently incorrectly implemented by
|
||||
// deftree
|
||||
// TODO implement correct map support in client argument generation
|
||||
if field.IsMap {
|
||||
// TODO implement correct map support in client argument generation
|
||||
for _, field := range meth.RequestType.Message.Fields {
|
||||
if field.Type.Map != nil {
|
||||
continue
|
||||
}
|
||||
newArg := newClientArg(meth.GetName(), field)
|
||||
newArg := newClientArg(meth.Name, field)
|
||||
m.Args = append(m.Args, newArg)
|
||||
}
|
||||
svcArgs.MethArgs[meth.GetName()] = &m
|
||||
svcArgs.MethArgs[meth.Name] = &m
|
||||
}
|
||||
|
||||
return &svcArgs
|
||||
}
|
||||
|
||||
// newClientArg returns a ClientArg generated from the provided method name and MessageField
|
||||
func newClientArg(methName string, field *deftree.MessageField) *ClientArg {
|
||||
func newClientArg(methName string, field *svcdef.Field) *ClientArg {
|
||||
newArg := ClientArg{}
|
||||
newArg.Name = field.GetName()
|
||||
newArg.Name = lowCamelName(field.Name)
|
||||
|
||||
if field.Label == "LABEL_REPEATED" {
|
||||
if field.Type.ArrayType {
|
||||
newArg.Repeated = true
|
||||
}
|
||||
newArg.ProtbufType = field.Type.GetName()
|
||||
|
||||
newArg.FlagName = fmt.Sprintf("%s.%s", strings.ToLower(methName), strings.ToLower(field.GetName()))
|
||||
newArg.FlagArg = fmt.Sprintf("flag%s%s", generatego.CamelCase(newArg.Name), generatego.CamelCase(methName))
|
||||
newArg.FlagName = fmt.Sprintf("%s.%s", strings.ToLower(methName), strings.ToLower(field.Name))
|
||||
newArg.FlagArg = fmt.Sprintf("flag%s%s", gogen.CamelCase(field.Name), gogen.CamelCase(methName))
|
||||
|
||||
if field.Type.Enum != nil {
|
||||
newArg.Enum = true
|
||||
log.WithField("Method", methName).WithField("Arg", newArg.Name).Debugf("type: %s", field.Type.GetName())
|
||||
}
|
||||
// Determine the FlagType and flag invocation
|
||||
var ft string
|
||||
var ok bool
|
||||
log.WithField("Method", methName).WithField("Arg", newArg.Name).Debugf("type: %s", field.Type.GetName())
|
||||
// For types outside the base types, have flag treat them as strings
|
||||
if ft, ok = ProtoToGoTypeMap[field.Type.GetName()]; !ok {
|
||||
if field.Type.Message == nil && field.Type.Enum == nil && field.Type.Map == nil {
|
||||
ft = field.Type.Name
|
||||
newArg.IsBaseType = true
|
||||
} else {
|
||||
// For types outside the base types, have flag treat them as strings
|
||||
ft = "string"
|
||||
newArg.IsBaseType = false
|
||||
} else {
|
||||
newArg.IsBaseType = true
|
||||
}
|
||||
if newArg.Repeated {
|
||||
ft = "string"
|
||||
|
@ -216,19 +213,13 @@ func newClientArg(methName string, field *deftree.MessageField) *ClientArg {
|
|||
newArg.FlagType = ft
|
||||
newArg.FlagConvertFunc = createFlagConvertFunc(newArg)
|
||||
|
||||
newArg.GoArg = fmt.Sprintf("%s%s", generatego.CamelCase(newArg.Name), generatego.CamelCase(methName))
|
||||
newArg.GoArg = fmt.Sprintf("%s%s", gogen.CamelCase(newArg.Name), gogen.CamelCase(methName))
|
||||
// For types outside the base types, treat them as strings
|
||||
if newArg.IsBaseType {
|
||||
newArg.GoType = ProtoToGoTypeMap[field.Type.GetName()]
|
||||
//newArg.GoType = ProtoToGoTypeMap[field.Type.GetName()]
|
||||
newArg.GoType = field.Type.Name
|
||||
} else {
|
||||
// TODO: Have GoType derivation respect nested definitions
|
||||
tn := field.Type.GetName()
|
||||
sections := strings.Split(tn, ".")
|
||||
// Extract everything after the package name
|
||||
remaining := sections[2:]
|
||||
tn = generatego.CamelCaseSlice(remaining)
|
||||
//log.WithField("Method", methName).WithField("Arg", newArg.Name).Warnf("type: %v, %v", sections[2:], generatego.CamelCaseSlice(sections[2:]))
|
||||
newArg.GoType = "pb." + tn
|
||||
newArg.GoType = "pb." + field.Type.Name
|
||||
}
|
||||
// The GoType is a slice of the GoType if it's a repeated field
|
||||
if newArg.Repeated {
|
||||
|
@ -257,7 +248,7 @@ if {{.FlagArg}} != nil && len(*{{.FlagArg}}) > 0 {
|
|||
}
|
||||
`
|
||||
if a.Repeated || !a.IsBaseType {
|
||||
code, err := ApplyTemplate("UnmarshalCliArgs", jsonConvTmpl, a, nil)
|
||||
code, err := applyTemplate("UnmarshalCliArgs", jsonConvTmpl, a, nil)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("Couldn't apply template: %v", err))
|
||||
}
|
||||
|
@ -312,10 +303,10 @@ func flagTypeConversion(a ClientArg) string {
|
|||
return fmt.Sprintf(fType, a.FlagArg)
|
||||
}
|
||||
|
||||
// ApplyTemplate applies a template with a given name, executor context, and
|
||||
// applyTemplate applies a template with a given name, executor context, and
|
||||
// function map. Returns the output of the template on success, returns an
|
||||
// error if template failed to execute.
|
||||
func ApplyTemplate(name string, tmpl string, executor interface{}, fncs template.FuncMap) (string, error) {
|
||||
func applyTemplate(name string, tmpl string, executor interface{}, fncs template.FuncMap) (string, error) {
|
||||
codeTemplate := template.Must(template.New(name).Funcs(fncs).Parse(tmpl))
|
||||
|
||||
code := bytes.NewBuffer(nil)
|
||||
|
@ -325,3 +316,17 @@ func ApplyTemplate(name string, tmpl string, executor interface{}, fncs template
|
|||
}
|
||||
return code.String(), nil
|
||||
}
|
||||
|
||||
// lowCamelName returns a CamelCased string, but with the first letter
|
||||
// lowercased. "package_name" becomes "packageName".
|
||||
func lowCamelName(s string) string {
|
||||
s = gogen.CamelCase(s)
|
||||
new := []rune(s)
|
||||
if len(new) < 1 {
|
||||
return s
|
||||
}
|
||||
rv := []rune{}
|
||||
rv = append(rv, unicode.ToLower(new[0]))
|
||||
rv = append(rv, new[1:]...)
|
||||
return string(rv)
|
||||
}
|
||||
|
|
|
@ -6,8 +6,8 @@ import (
|
|||
|
||||
"github.com/davecgh/go-spew/spew"
|
||||
|
||||
"github.com/TuneLab/go-truss/deftree"
|
||||
"github.com/TuneLab/go-truss/gengokit/gentesthelper"
|
||||
"github.com/TuneLab/go-truss/svcdef"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -17,57 +17,37 @@ var (
|
|||
)
|
||||
|
||||
func TestNewClientServiceArgs(t *testing.T) {
|
||||
svc := deftree.ProtoService{
|
||||
Name: "AddSvc",
|
||||
Methods: []*deftree.ServiceMethod{
|
||||
&deftree.ServiceMethod{
|
||||
Name: "Sum",
|
||||
RequestType: &deftree.ProtoMessage{
|
||||
Name: "SumRequest",
|
||||
Fields: []*deftree.MessageField{
|
||||
&deftree.MessageField{
|
||||
Name: "a",
|
||||
Number: 1,
|
||||
Label: "LABEL_REPEATED",
|
||||
Type: deftree.FieldType{
|
||||
Name: "TYPE_INT64",
|
||||
},
|
||||
},
|
||||
&deftree.MessageField{
|
||||
Name: "b",
|
||||
Number: 2,
|
||||
Label: "LABEL_OPTIONAL",
|
||||
Type: deftree.FieldType{
|
||||
Name: "TYPE_INT64",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
ResponseType: &deftree.ProtoMessage{
|
||||
Name: "SumReply",
|
||||
Fields: []*deftree.MessageField{
|
||||
&deftree.MessageField{
|
||||
Name: "v",
|
||||
Number: 1,
|
||||
Label: "LABEL_OPTIONAL",
|
||||
Type: deftree.FieldType{
|
||||
Name: "TYPE_INT64",
|
||||
},
|
||||
},
|
||||
&deftree.MessageField{
|
||||
Name: "err",
|
||||
Number: 2,
|
||||
Label: "LABEL_OPTIONAL",
|
||||
Type: deftree.FieldType{
|
||||
Name: "TYPE_STRING",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
defStr := `
|
||||
syntax = "proto3";
|
||||
|
||||
// General package
|
||||
package general;
|
||||
|
||||
import "google/api/annotations.proto";
|
||||
|
||||
message SumRequest {
|
||||
repeated int64 a = 1;
|
||||
int64 b = 2;
|
||||
}
|
||||
|
||||
message SumReply {
|
||||
int64 v = 1;
|
||||
string err = 2;
|
||||
}
|
||||
|
||||
service SumSvc {
|
||||
rpc Sum(SumRequest) returns (SumReply) {
|
||||
option (google.api.http) = {
|
||||
get: "/sum/{a}"
|
||||
};
|
||||
}
|
||||
}
|
||||
`
|
||||
sd, err := svcdef.NewFromString(defStr)
|
||||
if err != nil {
|
||||
t.Fatal(err, "Failed to create a service from the definition string")
|
||||
}
|
||||
csa := New(&svc)
|
||||
csa := New(sd.Service)
|
||||
|
||||
expected := &ClientServiceArgs{
|
||||
MethArgs: map[string]*MethodArgs{
|
||||
|
@ -82,7 +62,7 @@ func TestNewClientServiceArgs(t *testing.T) {
|
|||
GoArg: "ASum",
|
||||
GoType: "[]int64",
|
||||
GoConvertInvoc: "\nvar ASum []int64\nif flagASum != nil && len(*flagASum) > 0 {\n\terr = json.Unmarshal([]byte(*flagASum), &ASum)\n\tif err != nil {\n\t\tpanic(errors.Wrapf(err, \"unmarshalling ASum from %v:\", flagASum))\n\t}\n}\n",
|
||||
ProtbufType: "TYPE_INT64",
|
||||
ProtbufType: "",
|
||||
IsBaseType: true,
|
||||
Repeated: true,
|
||||
},
|
||||
|
@ -96,7 +76,7 @@ func TestNewClientServiceArgs(t *testing.T) {
|
|||
GoArg: "BSum",
|
||||
GoType: "int64",
|
||||
GoConvertInvoc: "BSum := *flagBSum",
|
||||
ProtbufType: "TYPE_INT64",
|
||||
ProtbufType: "",
|
||||
IsBaseType: true,
|
||||
Repeated: false,
|
||||
},
|
||||
|
|
|
@ -67,7 +67,7 @@ func newTemplateExecutor(dt deftree.Deftree, sd *svcdef.Svcdef, conf config.Conf
|
|||
PBImportPath: conf.PBPackage,
|
||||
PackageName: sd.PkgName,
|
||||
Service: service,
|
||||
ClientArgs: clientarggen.New(service),
|
||||
ClientArgs: clientarggen.New(sd.Service),
|
||||
HTTPHelper: httptransport.NewHelper(sd.Service),
|
||||
funcMap: funcMap,
|
||||
}, nil
|
||||
|
|
|
@ -80,25 +80,25 @@ func NewBinding(i int, meth *svcdef.ServiceMethod) *Binding {
|
|||
for _, param := range binding.Params {
|
||||
// The 'Field' attr of each HTTPParameter always point to it's bound
|
||||
// Methods RequestType
|
||||
rq := param.Field
|
||||
field := param.Field
|
||||
newField := Field{
|
||||
Name: rq.Name,
|
||||
CamelName: gogen.CamelCase(rq.Name),
|
||||
LowCamelName: LowCamelName(rq.Name),
|
||||
Name: field.Name,
|
||||
CamelName: gogen.CamelCase(field.Name),
|
||||
LowCamelName: LowCamelName(field.Name),
|
||||
Location: param.Location,
|
||||
Repeated: rq.Type.ArrayType,
|
||||
GoType: rq.Type.Name,
|
||||
LocalName: fmt.Sprintf("%s%s", gogen.CamelCase(rq.Name), gogen.CamelCase(meth.Name)),
|
||||
Repeated: field.Type.ArrayType,
|
||||
GoType: field.Type.Name,
|
||||
LocalName: fmt.Sprintf("%s%s", gogen.CamelCase(field.Name), gogen.CamelCase(meth.Name)),
|
||||
}
|
||||
|
||||
if rq.Type.Message == nil && rq.Type.Enum == nil && rq.Type.Map == nil {
|
||||
if field.Type.Message == nil && field.Type.Enum == nil && field.Type.Map == nil {
|
||||
newField.IsBaseType = true
|
||||
}
|
||||
|
||||
// Modify GoType to reflect pointer or repeated status
|
||||
if rq.Type.StarExpr && rq.Type.ArrayType {
|
||||
if field.Type.StarExpr && field.Type.ArrayType {
|
||||
newField.GoType = "[]*" + newField.GoType
|
||||
} else if rq.Type.ArrayType {
|
||||
} else if field.Type.ArrayType {
|
||||
newField.GoType = "[]" + newField.GoType
|
||||
}
|
||||
|
||||
|
|
|
@ -92,7 +92,6 @@ func TestNewMethod(t *testing.T) {
|
|||
if got, want := newMeth, meth; !reflect.DeepEqual(got, want) {
|
||||
diff := gentesthelper.DiffStrings(spew.Sdump(got), spew.Sdump(want))
|
||||
t.Errorf("got != want; methods differ: %v\n", diff)
|
||||
//t.Errorf("methods differ;\ngot = %+v\nwant = %+v\n", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче