diff --git a/arch/um/drivers/virtio_uml.c b/arch/um/drivers/virtio_uml.c index ba562d68dc04..82ff3785bf69 100644 --- a/arch/um/drivers/virtio_uml.c +++ b/arch/um/drivers/virtio_uml.c @@ -63,6 +63,7 @@ struct virtio_uml_device { u8 config_changed_irq:1; uint64_t vq_irq_vq_map; + int recv_rc; }; struct virtio_uml_vq_info { @@ -148,14 +149,6 @@ static int vhost_user_recv(struct virtio_uml_device *vu_dev, rc = vhost_user_recv_header(fd, msg); - if (rc == -ECONNRESET && vu_dev->registered) { - struct virtio_uml_platform_data *pdata; - - pdata = vu_dev->pdata; - - virtio_break_device(&vu_dev->vdev); - schedule_work(&pdata->conn_broken_wk); - } if (rc) return rc; size = msg->header.size; @@ -164,6 +157,21 @@ static int vhost_user_recv(struct virtio_uml_device *vu_dev, return full_read(fd, &msg->payload, size, false); } +static void vhost_user_check_reset(struct virtio_uml_device *vu_dev, + int rc) +{ + struct virtio_uml_platform_data *pdata = vu_dev->pdata; + + if (rc != -ECONNRESET) + return; + + if (!vu_dev->registered) + return; + + virtio_break_device(&vu_dev->vdev); + schedule_work(&pdata->conn_broken_wk); +} + static int vhost_user_recv_resp(struct virtio_uml_device *vu_dev, struct vhost_user_msg *msg, size_t max_payload_size) @@ -171,8 +179,10 @@ static int vhost_user_recv_resp(struct virtio_uml_device *vu_dev, int rc = vhost_user_recv(vu_dev, vu_dev->sock, msg, max_payload_size, true); - if (rc) + if (rc) { + vhost_user_check_reset(vu_dev, rc); return rc; + } if (msg->header.flags != (VHOST_USER_FLAG_REPLY | VHOST_USER_VERSION)) return -EPROTO; @@ -369,6 +379,7 @@ static irqreturn_t vu_req_read_message(struct virtio_uml_device *vu_dev, sizeof(msg.msg.payload) + sizeof(msg.extra_payload)); + vu_dev->recv_rc = rc; if (rc) return IRQ_NONE; @@ -412,7 +423,9 @@ static irqreturn_t vu_req_interrupt(int irq, void *data) if (!um_irq_timetravel_handler_used()) ret = vu_req_read_message(vu_dev, NULL); - if (vu_dev->vq_irq_vq_map) { + if (vu_dev->recv_rc) { + vhost_user_check_reset(vu_dev, vu_dev->recv_rc); + } else if (vu_dev->vq_irq_vq_map) { struct virtqueue *vq; virtio_device_for_each_vq((&vu_dev->vdev), vq) {