diff --git a/net/vmw_vsock/af_vsock.c b/net/vmw_vsock/af_vsock.c index 2828734b7bcf..b12d3a322242 100644 --- a/net/vmw_vsock/af_vsock.c +++ b/net/vmw_vsock/af_vsock.c @@ -421,7 +421,8 @@ static void vsock_deassign_transport(struct vsock_sock *vsk) * The vsk->remote_addr is used to decide which transport to use: * - remote CID == VMADDR_CID_LOCAL or g2h->local_cid or VMADDR_CID_HOST if * g2h is not loaded, will use local transport; - * - remote CID <= VMADDR_CID_HOST will use guest->host transport; + * - remote CID <= VMADDR_CID_HOST or h2g is not loaded or remote flags field + * includes VMADDR_FLAG_TO_HOST flag value, will use guest->host transport; * - remote CID > VMADDR_CID_HOST will use host->guest transport; */ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk) @@ -429,6 +430,7 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk) const struct vsock_transport *new_transport; struct sock *sk = sk_vsock(vsk); unsigned int remote_cid = vsk->remote_addr.svm_cid; + __u8 remote_flags; int ret; /* If the packet is coming with the source and destination CIDs higher @@ -443,6 +445,8 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk) vsk->remote_addr.svm_cid > VMADDR_CID_HOST) vsk->remote_addr.svm_flags |= VMADDR_FLAG_TO_HOST; + remote_flags = vsk->remote_addr.svm_flags; + switch (sk->sk_type) { case SOCK_DGRAM: new_transport = transport_dgram; @@ -450,7 +454,8 @@ int vsock_assign_transport(struct vsock_sock *vsk, struct vsock_sock *psk) case SOCK_STREAM: if (vsock_use_local_transport(remote_cid)) new_transport = transport_local; - else if (remote_cid <= VMADDR_CID_HOST || !transport_h2g) + else if (remote_cid <= VMADDR_CID_HOST || !transport_h2g || + (remote_flags & VMADDR_FLAG_TO_HOST)) new_transport = transport_g2h; else new_transport = transport_h2g;