twirp-ruby/protoc-gen-twirp_ruby/main.go

330 строки
9.3 KiB
Go

// Copyright 2018 Twitch Interactive, Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may not
// use this file except in compliance with the License. A copy of the License is
// located at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// or in the "license" file accompanying this file. This file is distributed on
// an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.
package main
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"log"
"os"
"path/filepath"
"strings"
"unicode"
"google.golang.org/protobuf/proto"
descriptor "google.golang.org/protobuf/types/descriptorpb"
plugin "google.golang.org/protobuf/types/pluginpb"
"github.com/arthurnn/twirp-ruby/internal/gen/typemap"
)
func main() {
genReq := readGenRequest(os.Stdin)
genResp := newGenerator(genReq).Generate()
writeGenResponse(os.Stdout, genResp)
}
func newGenerator(genReq *plugin.CodeGeneratorRequest) *generator {
return &generator{
version: Version,
genReq: genReq,
fileToGoPackageName: make(map[*descriptor.FileDescriptorProto]string),
reg: typemap.New(genReq.ProtoFile),
}
}
type generator struct {
version string
genReq *plugin.CodeGeneratorRequest
reg *typemap.Registry
genFiles []*descriptor.FileDescriptorProto
fileToGoPackageName map[*descriptor.FileDescriptorProto]string
}
func fileDescSliceContains(slice []*descriptor.FileDescriptorProto, f *descriptor.FileDescriptorProto) bool {
for _, sf := range slice {
if f == sf {
return true
}
}
return false
}
func (g *generator) Generate() *plugin.CodeGeneratorResponse {
resp := new(plugin.CodeGeneratorResponse)
resp.SupportedFeatures = proto.Uint64(uint64(plugin.CodeGeneratorResponse_FEATURE_PROTO3_OPTIONAL))
g.findProtoFilesToGenerate()
for _, f := range g.genFiles {
twirpFileName := noExtension(filePath(f)) + "_twirp.rb" // e.g. "hello_world/service_twirp.rb"
pbFileRelativePath := noExtension(onlyBase(filePath(f))) + "_pb.rb" // e.g. "service_pb.rb"
rubyCode := g.generateRubyCode(f, pbFileRelativePath)
respFile := &plugin.CodeGeneratorResponse_File{
Name: proto.String(twirpFileName),
Content: proto.String(rubyCode),
}
resp.File = append(resp.File, respFile)
}
return resp
}
func (g *generator) generateRubyCode(file *descriptor.FileDescriptorProto, pbFileRelativePath string) string {
b := new(bytes.Buffer)
print(b, "# Code generated by protoc-gen-twirp_ruby %s, DO NOT EDIT.", g.version)
print(b, "require 'twirp'")
print(b, "require_relative '%s'", pbFileRelativePath) // require generated file with messages
print(b, "")
indent := indentation(0)
pkgName := file.GetPackage()
var modules []string
if file.Options != nil && file.Options.RubyPackage != nil {
modules = strings.Split(*file.Options.RubyPackage, "::")
} else {
modules = splitRubyConstants(pkgName)
}
for _, m := range modules {
print(b, "%smodule %s", indent, m)
indent += 1
}
for i, service := range file.Service {
svcName := service.GetName()
print(b, "%sclass %sService < ::Twirp::Service", indent, camelCase(svcName))
if pkgName != "" {
print(b, "%s package '%s'", indent, pkgName)
}
print(b, "%s service '%s'", indent, svcName)
for _, method := range service.GetMethod() {
rpcName := method.GetName()
rpcInput := g.toRubyType(method.GetInputType())
rpcOutput := g.toRubyType(method.GetOutputType())
print(b, "%s rpc :%s, %s, %s, :ruby_method => :%s",
indent, rpcName, rpcInput, rpcOutput, snakeCase(rpcName))
}
print(b, "%send", indent)
print(b, "")
print(b, "%sclass %sClient < ::Twirp::Client", indent, camelCase(svcName))
print(b, "%s client_for %sService", indent, camelCase(svcName))
print(b, "%send", indent)
if i < len(file.Service)-1 {
print(b, "")
}
}
for range modules {
indent -= 1
print(b, "%send", indent)
}
return b.String()
}
// protoFilesToGenerate selects descriptor proto files that were explicitly listed on the command-line.
func (g *generator) findProtoFilesToGenerate() {
for _, name := range g.genReq.FileToGenerate { // explicitly listed on the command-line
for _, f := range g.genReq.ProtoFile { // all files and everything they import
if f.GetName() == name { // match
g.genFiles = append(g.genFiles, f)
continue
}
}
}
for _, f := range g.genReq.ProtoFile {
if fileDescSliceContains(g.genFiles, f) {
g.fileToGoPackageName[f] = ""
} else {
g.fileToGoPackageName[f] = f.GetPackage()
}
}
}
// indentation represents the level of Ruby indentation for a block of code. It
// implements the fmt.Stringer interface to output the correct number of spaces
// for the given level of indentation
type indentation int
func (i indentation) String() string {
return strings.Repeat(" ", int(i))
}
func print(buf *bytes.Buffer, tpl string, args ...interface{}) {
buf.WriteString(fmt.Sprintf(tpl, args...))
buf.WriteByte('\n')
}
func filePath(f *descriptor.FileDescriptorProto) string {
return *f.Name
}
func onlyBase(path string) string {
return filepath.Base(path)
}
func noExtension(path string) string {
ext := filepath.Ext(path)
return strings.TrimSuffix(path, ext)
}
func Fail(msgs ...string) {
s := strings.Join(msgs, " ")
log.Print("error:", s)
os.Exit(1)
}
func readGenRequest(r io.Reader) *plugin.CodeGeneratorRequest {
data, err := ioutil.ReadAll(r)
if err != nil {
Fail(err.Error(), "reading input")
}
req := new(plugin.CodeGeneratorRequest)
if err = proto.Unmarshal(data, req); err != nil {
Fail(err.Error(), "parsing input proto")
}
if len(req.FileToGenerate) == 0 {
Fail("no files to generate")
}
return req
}
func writeGenResponse(w io.Writer, resp *plugin.CodeGeneratorResponse) {
data, err := proto.Marshal(resp)
if err != nil {
Fail(err.Error(), "marshaling response")
}
_, err = w.Write(data)
if err != nil {
Fail(err.Error(), "writing response")
}
}
// toRubyType converts a protobuf type reference to a Ruby constant.
// e.g. toRubyType("MyMessage", []string{}) => "MyMessage"
// e.g. toRubyType(".foo.my_message", []string{}) => "Foo::MyMessage"
// e.g. toRubyType(".foo.my_message", []string{"Foo"}) => "MyMessage"
// e.g. toRubyType("google.protobuf.Empty", []string{"Foo"}) => "Google::Protobuf::Empty"
func (g *generator) toRubyType(protoType string) string {
def := g.reg.MessageDefinition(protoType)
if def == nil {
panic("could not find message for " + protoType)
}
var prefix string
if pkg := g.fileToGoPackageName[def.File]; pkg != "" {
prefix = strings.Join(splitRubyConstants(pkg), "::") + "::"
}
var name string
for _, parent := range def.Lineage() {
name += camelCase(parent.Descriptor.GetName()) + "::"
}
name += camelCase(def.Descriptor.GetName())
return prefix + name
}
// splitRubyConstants converts a namespaced protobuf type (package name or mesasge)
// to a list of names that can be used as Ruby constants.
// e.g. splitRubyConstants("my.cool.package") => ["My", "Cool", "Package"]
// e.g. splitRubyConstants("google.protobuf.Empty") => ["Google", "Protobuf", "Empty"]
func splitRubyConstants(protoPckgName string) []string {
if protoPckgName == "" {
return []string{} // no modules
}
parts := []string{}
for _, p := range strings.Split(protoPckgName, ".") {
parts = append(parts, camelCase(p))
}
return parts
}
// snakeCase converts a string from CamelCase to snake_case.
func snakeCase(s string) string {
var buf bytes.Buffer
for i, r := range s {
if unicode.IsUpper(r) && i > 0 {
fmt.Fprintf(&buf, "_")
}
r = unicode.ToLower(r)
fmt.Fprintf(&buf, "%c", r)
}
return buf.String()
}
// camelCase converts a string from snake_case to CamelCased.
// If there is an interior underscore followed by a lower case letter, drop the
// underscore and convert the letter to upper case. There is a remote
// possibility of this rewrite causing a name collision, but it's so remote
// we're prepared to pretend it's nonexistent - since the C++ generator
// lowercases names, it's extremely unlikely to have two fields with different
// capitalizations. In short, _my_field_name_2 becomes XMyFieldName_2.
func camelCase(s string) string {
if s == "" {
return ""
}
t := make([]byte, 0, 32)
i := 0
if s[0] == '_' {
// Need a capital letter; drop the '_'.
t = append(t, 'X')
i++
}
// Invariant: if the next letter is lower case, it must be converted
// to upper case.
//
// That is, we process a word at a time, where words are marked by _ or upper
// case letter. Digits are treated as words.
for ; i < len(s); i++ {
c := s[i]
if c == '_' && i+1 < len(s) && (isASCIILower(s[i+1]) || isASCIIDigit(s[i+1])) {
continue // Skip the underscore in s.
}
// Assume we have a letter now - if not, it's a bogus identifier. The next
// word is a sequence of characters that must start upper case.
if isASCIILower(c) {
c ^= ' ' // Make it a capital letter.
}
t = append(t, c) // Guaranteed not lower case.
// Accept lower case sequence that follows.
for i+1 < len(s) && (isASCIILower(s[i+1]) || isASCIIDigit(s[i+1])) {
i++
t = append(t, s[i])
}
}
return string(t)
}
// Is c an ASCII lower-case letter?
func isASCIILower(c byte) bool {
return 'a' <= c && c <= 'z'
}
// Is c an ASCII digit?
func isASCIIDigit(c byte) bool {
return '0' <= c && c <= '9'
}