diff --git a/example/Gemfile.lock b/example/Gemfile.lock index 7dbf097..a98eaad 100644 --- a/example/Gemfile.lock +++ b/example/Gemfile.lock @@ -1,7 +1,7 @@ PATH remote: .. specs: - twirp (0.5.2) + twirp (1.0.0) faraday (~> 0) google-protobuf (~> 3.0, >= 3.0.0) diff --git a/example/hello_world/service.proto b/example/hello_world/service.proto index 1e70a47..f7b4a81 100644 --- a/example/hello_world/service.proto +++ b/example/hello_world/service.proto @@ -1,6 +1,7 @@ syntax = "proto3"; package example.hello_world; + service HelloWorld { rpc Hello(HelloRequest) returns (HelloResponse); } diff --git a/protoc-gen-twirp_ruby/main.go b/protoc-gen-twirp_ruby/main.go index 49ab1ec..140d365 100644 --- a/protoc-gen-twirp_ruby/main.go +++ b/protoc-gen-twirp_ruby/main.go @@ -68,7 +68,7 @@ func (g *generator) generateRubyCode(file *descriptor.FileDescriptorProto, pbFil indent := indentation(0) pkgName := file.GetPackage() - modules := packageToRubyModules(pkgName) + modules := splitRubyConstants(pkgName) for _, m := range modules { print(b, "%smodule %s", indent, m) indent += 1 @@ -84,8 +84,10 @@ func (g *generator) generateRubyCode(file *descriptor.FileDescriptorProto, pbFil print(b, "%s service '%s'", indent, svcName) for _, method := range service.GetMethod() { rpcName := method.GetName() + rpcInput := toRubyType(method.GetInputType(), modules) + rpcOutput := toRubyType(method.GetOutputType(), modules) print(b, "%s rpc :%s, %s, %s, :ruby_method => :%s", - indent, rpcName, methodInputName(method), methodOutputName(method), snakeCase(rpcName)) + indent, rpcName, rpcInput, rpcOutput, snakeCase(rpcName)) } print(b, "%send", indent) print(b, "") @@ -147,18 +149,6 @@ func noExtension(path string) string { return strings.TrimSuffix(path, ext) } -func methodInputName(meth *descriptor.MethodDescriptorProto) string { - fullName := meth.GetInputType() - split := strings.Split(fullName, ".") - return split[len(split)-1] -} - -func methodOutputName(meth *descriptor.MethodDescriptorProto) string { - fullName := meth.GetOutputType() - split := strings.Split(fullName, ".") - return split[len(split)-1] -} - func Fail(msgs ...string) { s := strings.Join(msgs, " ") log.Print("error:", s) @@ -194,15 +184,42 @@ func writeGenResponse(w io.Writer, resp *plugin.CodeGeneratorResponse) { } } -// Modules converts protobuf package name to a list of Ruby module names to -// represent it. e.g. packageToRubyModules("my.cool.package") => ["My", "Cool", "Package"] -func packageToRubyModules(pkgName string) []string { - if pkgName == "" { +// toRubyType converts a protobuf type reference to a Ruby constant. +// e.g. toRubyType("MyMessage", []string{}) => "MyMessage" +// e.g. toRubyType(".foo.my_message", []string{}) => "Foo::MyMessage" +// e.g. toRubyType(".foo.my_message", []string{"Foo"}) => "MyMessage" +// e.g. toRubyType("google.protobuf.Empty", []string{"Foo"}) => "Google::Protobuf::Empty" +func toRubyType(protoType string, currentModules []string) string { + rubyConsts := splitRubyConstants(protoType) + if len(rubyConsts) == 0 { + return "" + } + rubyType := strings.Join(rubyConsts, "::") + + if len(rubyType) > 2 && rubyType[0:2] == "::" { + rubyType = rubyType[2:len(rubyType)] // Remove leading :: + } + + // Remove leading modules if they are the same as in the current context + currentModulesConst := strings.Join(currentModules, "::") + "::" + if strings.HasPrefix(rubyType, currentModulesConst) { + rubyType = rubyType[len(currentModulesConst):len(rubyType)] + } + + return rubyType +} + +// splitRubyConstants converts a namespaced protobuf type (package name or mesasge) +// to a list of names that can be used as Ruby constants. +// e.g. splitRubyConstants("my.cool.package") => ["My", "Cool", "Package"] +// e.g. splitRubyConstants("google.protobuf.Empty") => ["Google", "Protobuf", "Empty"] +func splitRubyConstants(protoPckgName string) []string { + if protoPckgName == "" { return []string{} // no modules } parts := []string{} - for _, p := range strings.Split(pkgName, ".") { + for _, p := range strings.Split(protoPckgName, ".") { parts = append(parts, camelCase(p)) } return parts diff --git a/protoc-gen-twirp_ruby/main_test.go b/protoc-gen-twirp_ruby/main_test.go index e606df0..594d2f8 100644 --- a/protoc-gen-twirp_ruby/main_test.go +++ b/protoc-gen-twirp_ruby/main_test.go @@ -35,19 +35,43 @@ func TestFilePathOnlyBaseNoExtension(t *testing.T) { } } -func TestPackageToRubyModules(t *testing.T) { +func TestToRubyType(t *testing.T) { + tests := []struct { + protoType string + modules []string + expected string + }{ + {"", []string{}, ""}, + {"", []string{"Foo", "Bar"}, ""}, + {".foo.my_message", []string{}, "Foo::MyMessage"}, + {".foo.my_message", []string{"Foo"}, "MyMessage"}, + {"m.v.p99.hello_world", []string{}, "M::V::P99::HelloWorld"}, + {"m.v.p99.hello_world", []string{"M", "V"}, "P99::HelloWorld"}, + {"m.v.p99.hello_world", []string{"M", "V", "P99"}, "HelloWorld"}, + {"m.v.p99.hello_world", []string{"P99"}, "M::V::P99::HelloWorld"}, + {"google.protobuf.Empty", []string{"Foo"}, "Google::Protobuf::Empty"}, + } + for _, tt := range tests { + actual := toRubyType(tt.protoType, tt.modules) + if !reflect.DeepEqual(actual, tt.expected) { + t.Errorf("expected %v; actual %v", tt.expected, actual) + } + } +} + +func TestSplitRubyConstants(t *testing.T) { tests := []struct { pkgName string expected []string }{ + {"", []string{}}, {"example", []string{"Example"}}, {"example.hello_world", []string{"Example", "HelloWorld"}}, {"m.v.p", []string{"M", "V", "P"}}, - {"p99.a2z", []string{"P99", "A2z"}}, // with numbers - {"", []string{}}, // empty, no modules + {"p99.a2z", []string{"P99", "A2z"}}, } for _, tt := range tests { - actual := packageToRubyModules(tt.pkgName) + actual := splitRubyConstants(tt.pkgName) if !reflect.DeepEqual(actual, tt.expected) { t.Errorf("expected %v; actual %v", tt.expected, actual) }