diff --git a/portfwd.c b/portfwd.c index 2fbaa299..8e5f087c 100644 --- a/portfwd.c +++ b/portfwd.c @@ -1120,6 +1120,19 @@ bool portfwdmgr_unlisten(PortFwdManager *mgr, const char *host, int port) return true; } +struct portfwdmgr_connect_ctx { + SockAddr *addr; + int port; + char *canonical_hostname; + Conf *conf; +}; +static Socket *portfwdmgr_connect_helper(void *vctx, Plug *plug) +{ + struct portfwdmgr_connect_ctx *ctx = (struct portfwdmgr_connect_ctx *)vctx; + return new_connection(ctx->addr, ctx->canonical_hostname, ctx->port, + false, true, false, false, plug, ctx->conf); +} + /* * Called when receiving a PORT OPEN from the server to make a * connection to a destination host. @@ -1131,26 +1144,39 @@ char *portfwdmgr_connect(PortFwdManager *mgr, Channel **chan_ret, char *hostname, int port, SshChannel *c, int addressfamily) { - SockAddr *addr; - const char *err; - char *dummy_realhost = NULL; - struct PortForwarding *pf; + struct portfwdmgr_connect_ctx ctx[1]; + const char *err_retd; + char *err_toret; /* * Try to find host. */ - addr = name_lookup(hostname, port, &dummy_realhost, mgr->conf, - addressfamily, NULL, NULL); - if ((err = sk_addr_error(addr)) != NULL) { - char *err_ret = dupstr(err); - sk_addr_free(addr); - sfree(dummy_realhost); - return err_ret; + ctx->addr = name_lookup(hostname, port, &ctx->canonical_hostname, + mgr->conf, addressfamily, NULL, NULL); + if ((err_retd = sk_addr_error(ctx->addr)) != NULL) { + err_toret = dupstr(err_retd); + goto out; } - /* - * Open socket. - */ + ctx->conf = mgr->conf; + ctx->port = port; + + err_toret = portfwdmgr_connect_socket( + mgr, chan_ret, portfwdmgr_connect_helper, ctx, c); + + out: + sk_addr_free(ctx->addr); + sfree(ctx->canonical_hostname); + return err_toret; +} + +char *portfwdmgr_connect_socket(PortFwdManager *mgr, Channel **chan_ret, + Socket *(*connect)(void *, Plug *), void *ctx, + SshChannel *c) +{ + struct PortForwarding *pf; + const char *err; + pf = new_portfwd_state(); *chan_ret = &pf->chan; pf->plug.vt = &PortForwarding_plugvt; @@ -1162,9 +1188,7 @@ char *portfwdmgr_connect(PortFwdManager *mgr, Channel **chan_ret, pf->cl = mgr->cl; pf->socks_state = SOCKS_NONE; - pf->s = new_connection(addr, dummy_realhost, port, - false, true, false, false, &pf->plug, mgr->conf); - sfree(dummy_realhost); + pf->s = connect(ctx, &pf->plug); if ((err = sk_socket_error(pf->s)) != NULL) { char *err_ret = dupstr(err); sk_close(pf->s); diff --git a/ssh.h b/ssh.h index 2a023121..e3acf5be 100644 --- a/ssh.h +++ b/ssh.h @@ -380,6 +380,9 @@ void portfwdmgr_close_all(PortFwdManager *mgr); char *portfwdmgr_connect(PortFwdManager *mgr, Channel **chan_ret, char *hostname, int port, SshChannel *c, int addressfamily); +char *portfwdmgr_connect_socket(PortFwdManager *mgr, Channel **chan_ret, + Socket *(*connect)(void *, Plug *), void *ctx, + SshChannel *c); bool portfwdmgr_listen(PortFwdManager *mgr, const char *host, int port, const char *keyhost, int keyport, Conf *conf); bool portfwdmgr_unlisten(PortFwdManager *mgr, const char *host, int port);