diff --git a/drivers/hv/hv_utils_transport.c b/drivers/hv/hv_utils_transport.c index 59c6f3d29f9a..31c2f8649271 100644 --- a/drivers/hv/hv_utils_transport.c +++ b/drivers/hv/hv_utils_transport.c @@ -27,11 +27,9 @@ static struct list_head hvt_list = LIST_HEAD_INIT(hvt_list); static void hvt_reset(struct hvutil_transport *hvt) { - mutex_lock(&hvt->lock); kfree(hvt->outmsg); hvt->outmsg = NULL; hvt->outmsg_len = 0; - mutex_unlock(&hvt->lock); if (hvt->on_reset) hvt->on_reset(); } @@ -44,10 +42,17 @@ static ssize_t hvt_op_read(struct file *file, char __user *buf, hvt = container_of(file->f_op, struct hvutil_transport, fops); - if (wait_event_interruptible(hvt->outmsg_q, hvt->outmsg_len > 0)) + if (wait_event_interruptible(hvt->outmsg_q, hvt->outmsg_len > 0 || + hvt->mode != HVUTIL_TRANSPORT_CHARDEV)) return -EINTR; mutex_lock(&hvt->lock); + + if (hvt->mode == HVUTIL_TRANSPORT_DESTROY) { + ret = -EBADF; + goto out_unlock; + } + if (!hvt->outmsg) { ret = -EAGAIN; goto out_unlock; @@ -85,7 +90,10 @@ static ssize_t hvt_op_write(struct file *file, const char __user *buf, if (IS_ERR(inmsg)) return PTR_ERR(inmsg); - ret = hvt->on_msg(inmsg, count); + if (hvt->mode == HVUTIL_TRANSPORT_DESTROY) + ret = -EBADF; + else + ret = hvt->on_msg(inmsg, count); kfree(inmsg); @@ -99,6 +107,10 @@ static unsigned int hvt_op_poll(struct file *file, poll_table *wait) hvt = container_of(file->f_op, struct hvutil_transport, fops); poll_wait(file, &hvt->outmsg_q, wait); + + if (hvt->mode == HVUTIL_TRANSPORT_DESTROY) + return -EBADF; + if (hvt->outmsg_len > 0) return POLLIN | POLLRDNORM; @@ -108,26 +120,39 @@ static unsigned int hvt_op_poll(struct file *file, poll_table *wait) static int hvt_op_open(struct inode *inode, struct file *file) { struct hvutil_transport *hvt; + int ret = 0; + bool issue_reset = false; hvt = container_of(file->f_op, struct hvutil_transport, fops); - /* - * Switching to CHARDEV mode. We switch bach to INIT when device - * gets released. - */ - if (hvt->mode == HVUTIL_TRANSPORT_INIT) + mutex_lock(&hvt->lock); + + if (hvt->mode == HVUTIL_TRANSPORT_DESTROY) { + ret = -EBADF; + } else if (hvt->mode == HVUTIL_TRANSPORT_INIT) { + /* + * Switching to CHARDEV mode. We switch bach to INIT when + * device gets released. + */ hvt->mode = HVUTIL_TRANSPORT_CHARDEV; + } else if (hvt->mode == HVUTIL_TRANSPORT_NETLINK) { /* * We're switching from netlink communication to using char * device. Issue the reset first. */ - hvt_reset(hvt); + issue_reset = true; hvt->mode = HVUTIL_TRANSPORT_CHARDEV; - } else - return -EBUSY; + } else { + ret = -EBUSY; + } - return 0; + if (issue_reset) + hvt_reset(hvt); + + mutex_unlock(&hvt->lock); + + return ret; } static int hvt_op_release(struct inode *inode, struct file *file) @@ -136,12 +161,15 @@ static int hvt_op_release(struct inode *inode, struct file *file) hvt = container_of(file->f_op, struct hvutil_transport, fops); - hvt->mode = HVUTIL_TRANSPORT_INIT; + mutex_lock(&hvt->lock); + if (hvt->mode != HVUTIL_TRANSPORT_DESTROY) + hvt->mode = HVUTIL_TRANSPORT_INIT; /* * Cleanup message buffers to avoid spurious messages when the daemon * connects back. */ hvt_reset(hvt); + mutex_unlock(&hvt->lock); return 0; } @@ -168,6 +196,7 @@ static void hvt_cn_callback(struct cn_msg *msg, struct netlink_skb_parms *nsp) * Switching to NETLINK mode. Switching to CHARDEV happens when someone * opens the device. */ + mutex_lock(&hvt->lock); if (hvt->mode == HVUTIL_TRANSPORT_INIT) hvt->mode = HVUTIL_TRANSPORT_NETLINK; @@ -175,6 +204,7 @@ static void hvt_cn_callback(struct cn_msg *msg, struct netlink_skb_parms *nsp) hvt_found->on_msg(msg->data, msg->len); else pr_warn("hvt_cn_callback: unexpected netlink message!\n"); + mutex_unlock(&hvt->lock); } int hvutil_transport_send(struct hvutil_transport *hvt, void *msg, int len) @@ -182,7 +212,8 @@ int hvutil_transport_send(struct hvutil_transport *hvt, void *msg, int len) struct cn_msg *cn_msg; int ret = 0; - if (hvt->mode == HVUTIL_TRANSPORT_INIT) { + if (hvt->mode == HVUTIL_TRANSPORT_INIT || + hvt->mode == HVUTIL_TRANSPORT_DESTROY) { return -EINVAL; } else if (hvt->mode == HVUTIL_TRANSPORT_NETLINK) { cn_msg = kzalloc(sizeof(*cn_msg) + len, GFP_ATOMIC); @@ -198,6 +229,11 @@ int hvutil_transport_send(struct hvutil_transport *hvt, void *msg, int len) } /* HVUTIL_TRANSPORT_CHARDEV */ mutex_lock(&hvt->lock); + if (hvt->mode != HVUTIL_TRANSPORT_CHARDEV) { + ret = -EINVAL; + goto out_unlock; + } + if (hvt->outmsg) { /* Previous message wasn't received */ ret = -EFAULT; @@ -268,6 +304,11 @@ err_free_hvt: void hvutil_transport_destroy(struct hvutil_transport *hvt) { + mutex_lock(&hvt->lock); + hvt->mode = HVUTIL_TRANSPORT_DESTROY; + wake_up_interruptible(&hvt->outmsg_q); + mutex_unlock(&hvt->lock); + spin_lock(&hvt_list_lock); list_del(&hvt->list); spin_unlock(&hvt_list_lock); diff --git a/drivers/hv/hv_utils_transport.h b/drivers/hv/hv_utils_transport.h index bff4c921e817..06254a165a18 100644 --- a/drivers/hv/hv_utils_transport.h +++ b/drivers/hv/hv_utils_transport.h @@ -25,6 +25,7 @@ enum hvutil_transport_mode { HVUTIL_TRANSPORT_INIT = 0, HVUTIL_TRANSPORT_NETLINK, HVUTIL_TRANSPORT_CHARDEV, + HVUTIL_TRANSPORT_DESTROY, }; struct hvutil_transport {