diff --git a/api/client/hijack.go b/api/client/hijack.go index adc012bace..617a0b3f61 100644 --- a/api/client/hijack.go +++ b/api/client/hijack.go @@ -2,6 +2,7 @@ package client import ( "crypto/tls" + "errors" "fmt" "io" "net" @@ -10,6 +11,7 @@ import ( "os" "runtime" "strings" + "time" log "github.com/Sirupsen/logrus" "github.com/docker/docker/api" @@ -19,9 +21,99 @@ import ( "github.com/docker/docker/pkg/term" ) +type tlsClientCon struct { + *tls.Conn + rawConn net.Conn +} + +func (c *tlsClientCon) CloseWrite() error { + // Go standard tls.Conn doesn't provide the CloseWrite() method so we do it + // on its underlying connection. + if cwc, ok := c.rawConn.(interface { + CloseWrite() error + }); ok { + return cwc.CloseWrite() + } + return nil +} + +func tlsDial(network, addr string, config *tls.Config) (net.Conn, error) { + return tlsDialWithDialer(new(net.Dialer), network, addr, config) +} + +// We need to copy Go's implementation of tls.Dial (pkg/cryptor/tls/tls.go) in +// order to return our custom tlsClientCon struct which holds both the tls.Conn +// object _and_ its underlying raw connection. The rationale for this is that +// we need to be able to close the write end of the connection when attaching, +// which tls.Conn does not provide. +func tlsDialWithDialer(dialer *net.Dialer, network, addr string, config *tls.Config) (net.Conn, error) { + // We want the Timeout and Deadline values from dialer to cover the + // whole process: TCP connection and TLS handshake. This means that we + // also need to start our own timers now. + timeout := dialer.Timeout + + if !dialer.Deadline.IsZero() { + deadlineTimeout := dialer.Deadline.Sub(time.Now()) + if timeout == 0 || deadlineTimeout < timeout { + timeout = deadlineTimeout + } + } + + var errChannel chan error + + if timeout != 0 { + errChannel = make(chan error, 2) + time.AfterFunc(timeout, func() { + errChannel <- errors.New("") + }) + } + + rawConn, err := dialer.Dial(network, addr) + if err != nil { + return nil, err + } + + colonPos := strings.LastIndex(addr, ":") + if colonPos == -1 { + colonPos = len(addr) + } + hostname := addr[:colonPos] + + // If no ServerName is set, infer the ServerName + // from the hostname we're connecting to. + if config.ServerName == "" { + // Make a copy to avoid polluting argument or default. + c := *config + c.ServerName = hostname + config = &c + } + + conn := tls.Client(rawConn, config) + + if timeout == 0 { + err = conn.Handshake() + } else { + go func() { + errChannel <- conn.Handshake() + }() + + err = <-errChannel + } + + if err != nil { + rawConn.Close() + return nil, err + } + + // This is Docker difference with standard's crypto/tls package: returned a + // wrapper which holds both the TLS and raw connections. + return &tlsClientCon{conn, rawConn}, nil +} + func (cli *DockerCli) dial() (net.Conn, error) { if cli.tlsConfig != nil && cli.proto != "unix" { - return tls.Dial(cli.proto, cli.addr, cli.tlsConfig) + // Notice this isn't Go standard's tls.Dial function + return tlsDial(cli.proto, cli.addr, cli.tlsConfig) } return net.Dial(cli.proto, cli.addr) } @@ -109,12 +201,11 @@ func (cli *DockerCli) hijack(method, path string, setRawTerminal bool, in io.Rea io.Copy(rwc, in) log.Debugf("[hijack] End of stdin") } - if tcpc, ok := rwc.(*net.TCPConn); ok { - if err := tcpc.CloseWrite(); err != nil { - log.Debugf("Couldn't send EOF: %s", err) - } - } else if unixc, ok := rwc.(*net.UnixConn); ok { - if err := unixc.CloseWrite(); err != nil { + + if conn, ok := rwc.(interface { + CloseWrite() error + }); ok { + if err := conn.CloseWrite(); err != nil { log.Debugf("Couldn't send EOF: %s", err) } }