Add context to transport to context for http transport

This commit is contained in:
Adam Ryman 2016-10-25 15:47:19 -07:00
Родитель 7f984bfdff
Коммит 0e8bb160a6
2 изменённых файлов: 117 добавлений и 67 удалений

Просмотреть файл

@ -3,7 +3,7 @@ package httptransport
// ServerDecodeTemplate is the template for generating the server-side decoding
// function for a particular Binding.
var ServerDecodeTemplate = `
{{ with $binding := .}}
{{- with $binding := . -}}
// DecodeHTTP{{$binding.Label}}Request is a transport/http.DecodeRequestFunc that
// decodes a JSON-encoded {{ToLower $binding.Parent.Name}} request from the HTTP request
// body. Primarily useful in a server.
@ -34,71 +34,71 @@ var ServerDecodeTemplate = `
{{end}}
return &req, err
}
{{end}}
{{- end -}}
`
// ClientEncodeTemplate is the template for generating the client-side encoding
// function for a particular Binding.
var ClientEncodeTemplate = `
{{ with $binding := .}}
// EncodeHTTP{{$binding.Label}}Request is a transport/http.EncodeRequestFunc
// that encodes a {{ToLower $binding.Parent.Name}} request into the various portions of
// the http request (path, query, and body).
func EncodeHTTP{{$binding.Label}}Request(_ context.Context, r *http.Request, request interface{}) error {
fmt.Printf("Encoding request %v\n", request)
req := request.(*pb.{{GoName $binding.Parent.RequestType}})
_ = req
{{- with $binding := . -}}
// EncodeHTTP{{$binding.Label}}Request is a transport/http.EncodeRequestFunc
// that encodes a {{ToLower $binding.Parent.Name}} request into the various portions of
// the http request (path, query, and body).
func EncodeHTTP{{$binding.Label}}Request(_ context.Context, r *http.Request, request interface{}) error {
fmt.Printf("Encoding request %v\n", request)
req := request.(*pb.{{GoName $binding.Parent.RequestType}})
_ = req
// Set the path parameters
path := strings.Join([]string{
{{- range $section := $binding.PathSections}}
{{$section}},
{{- end}}
}, "/")
u, err := url.Parse(path)
if err != nil {
return errors.Wrapf(err, "couldn't unmarshal path %q", path)
}
r.URL.RawPath = u.RawPath
r.URL.Path = u.Path
// Set the path parameters
path := strings.Join([]string{
{{- range $section := $binding.PathSections}}
{{$section}},
{{- end}}
}, "/")
u, err := url.Parse(path)
if err != nil {
return errors.Wrapf(err, "couldn't unmarshal path %q", path)
}
r.URL.RawPath = u.RawPath
r.URL.Path = u.Path
// Set the query parameters
values := r.URL.Query()
var tmp []byte
_ = tmp
{{- range $field := $binding.Fields }}
{{- if eq $field.Location "query"}}
{{if or (not $field.IsBaseType) $field.Repeated}}
tmp, err = json.Marshal(req.{{$field.CamelName}})
if err != nil {
return errors.Wrap(err, "failed to marshal req.{{$field.CamelName}}")
}
values.Add("{{$field.Name}}", string(tmp))
{{else}}
values.Add("{{$field.Name}}", fmt.Sprint(req.{{$field.CamelName}}))
// Set the query parameters
values := r.URL.Query()
var tmp []byte
_ = tmp
{{- range $field := $binding.Fields }}
{{- if eq $field.Location "query"}}
{{if or (not $field.IsBaseType) $field.Repeated}}
tmp, err = json.Marshal(req.{{$field.CamelName}})
if err != nil {
return errors.Wrap(err, "failed to marshal req.{{$field.CamelName}}")
}
values.Add("{{$field.Name}}", string(tmp))
{{else}}
values.Add("{{$field.Name}}", fmt.Sprint(req.{{$field.CamelName}}))
{{- end }}
{{- end }}
{{- end }}
{{- end}}
{{- end}}
r.URL.RawQuery = values.Encode()
r.URL.RawQuery = values.Encode()
// Set the body parameters
var buf bytes.Buffer
toRet := map[string]interface{}{
{{- range $field := $binding.Fields -}}
{{if eq $field.Location "body"}}
"{{$field.CamelName}}" : req.{{$field.CamelName}},
{{end}}
{{- end -}}
// Set the body parameters
var buf bytes.Buffer
toRet := map[string]interface{}{
{{- range $field := $binding.Fields -}}
{{if eq $field.Location "body"}}
"{{$field.CamelName}}" : req.{{$field.CamelName}},
{{end}}
{{- end -}}
}
if err := json.NewEncoder(&buf).Encode(toRet); err != nil {
return errors.Wrapf(err, "couldn't encode body as json %v", toRet)
}
r.Body = ioutil.NopCloser(&buf)
fmt.Printf("URL: %v\n", r.URL)
return nil
}
if err := json.NewEncoder(&buf).Encode(toRet); err != nil {
return errors.Wrapf(err, "couldn't encode body as json %v", toRet)
}
r.Body = ioutil.NopCloser(&buf)
fmt.Printf("URL: %v\n", r.URL)
return nil
}
{{end}}
{{- end -}}
`
// WARNING: Changing the contents of these strings, even a little bit, will cause tests
@ -228,10 +228,9 @@ var (
// MakeHTTPHandler returns a handler that makes a set of endpoints available
// on predefined paths.
func MakeHTTPHandler(ctx context.Context, endpoints Endpoints, logger log.Logger) http.Handler {
/*options := []httptransport.ServerOption{
httptransport.ServerErrorEncoder(errorEncoder),
httptransport.ServerErrorLogger(logger),
}*/
serverOptions := []httptransport.ServerOption{
httptransport.ServerBefore(headersToContext),
}
m := http.NewServeMux()
{{range $method := .HTTPHelper.Methods}}
@ -241,7 +240,7 @@ func MakeHTTPHandler(ctx context.Context, endpoints Endpoints, logger log.Logger
endpoints.{{$method.Name}}Endpoint,
HttpDecodeLogger(DecodeHTTP{{$binding.Label}}Request, logger),
EncodeHTTPGenericResponse,
//options...,
serverOptions...,
))
{{- end}}
{{- end}}
@ -319,6 +318,14 @@ func EncodeHTTPGenericResponse(_ context.Context, w http.ResponseWriter, respons
// Helper functions
{{.HTTPHelper.PathParamsBuilder}}
func headersToContext(ctx context.Context, r *http.Request) context.Context {
for k, v := range r.Header {
ctx = context.WithValue(ctx, k, v)
}
return ctx
}
`
var clientTemplate = `
@ -328,9 +335,12 @@ package http
import (
"net/url"
"strings"
"net/http"
"github.com/go-kit/kit/endpoint"
httptransport "github.com/go-kit/kit/transport/http"
"github.com/pkg/errors"
"golang.org/x/net/context"
// This Service
handler "{{.ImportPath -}} /handlers/server"
@ -345,10 +355,20 @@ var (
// New returns a service backed by an HTTP server living at the remote
// instance. We expect instance to come from a service discovery system, so
// likely of the form "host:port".
func New(instance string) (handler.Service, error) {
//options := []httptransport.ClientOptions{
//httptransport.ClientBefore(),
//}
func New(instance string, options ...ClientOption) (handler.Service, error) {
var cc clientConfig
for _, f := range options {
err := f(&cc)
if err != nil {
return nil, errors.Wrap(err, "cannot apply option")
}
}
clientOptions := []httptransport.ClientOption{
httptransport.ClientBefore(
contextValuesToHttpHeaders(cc.headers)),
}
if !strings.HasPrefix(instance, "http") {
instance = "http://" + instance
@ -368,7 +388,7 @@ func New(instance string) (handler.Service, error) {
copyURL(u, "{{ToLower $binding.BasePath}}"),
svc.EncodeHTTP{{$binding.Label}}Request,
svc.DecodeHTTP{{$method.Name}}Response,
//options...,
clientOptions...,
).Endpoint()
}
{{- end}}
@ -389,4 +409,33 @@ func copyURL(base *url.URL, path string) *url.URL {
return &next
}
type clientConfig struct {
headers []string
}
// ClientOption is a function that modifies the client config
type ClientOption func(*clientConfig) error
// CtxValuesToSend configures the http client to pull the specified keys out of
// the context and add them to the http request as headers. Note that keys
// will have net/http.CanonicalHeaderKey called on them before being send over
// the wire and that is the form they will be available in the server context.
func CtxValuesToSend(keys []string) ClientOption {
return func(o *clientConfig) error {
o.headers = keys
return nil
}
}
func contextValuesToHttpHeaders(keys []string) httptransport.RequestFunc {
return func(ctx context.Context, r *http.Request) context.Context {
for _, k := range keys {
if v, ok := ctx.Value(k).(string); ok {
r.Header.Set(k, v)
}
}
return ctx
}
}
`

Просмотреть файл

@ -1,6 +1,7 @@
package httptransport
import (
"strings"
"testing"
"github.com/TuneLab/go-truss/gengokit/gentesthelper"
@ -97,7 +98,7 @@ func EncodeHTTPSumZeroRequest(_ context.Context, r *http.Request, request interf
}
`
if got, want := str, desired; got != want {
if got, want := strings.TrimSpace(str), strings.TrimSpace(desired); got != want {
t.Errorf("Generated code differs from result.\ngot = %s\nwant = %s", got, want)
t.Log(gentesthelper.DiffStrings(got, want))
}
@ -202,7 +203,7 @@ func DecodeHTTPSumZeroRequest(_ context.Context, r *http.Request) (interface{},
}
`
if got, want := str, desired; got != want {
if got, want := strings.TrimSpace(str), strings.TrimSpace(desired); got != want {
t.Errorf("Generated code differs from result.\ngot = %s\nwant = %s", got, want)
t.Log(gentesthelper.DiffStrings(got, want))
}