[dpdk-dev,RFC,01/19] vhost: protect virtio_net device struct

Message ID 20170704094922.11405-2-maxime.coquelin@redhat.com (mailing list archive)
State Superseded, archived
Delegated to: Yuanhan Liu
Checks

Context Check Description
ci/checkpatch success coding style OK
ci/Intel-compilation success Compilation OK

Commit Message

Maxime Coquelin July 4, 2017, 9:49 a.m. UTC
  The virtio_net device struct might be accessed while being
reallocated in the NUMA-aware case. This case might be theoretical,
but the protection will be needed anyway to guard the vring pages
against invalidation.

The virtio_net devices are now protected with a readers/writers
lock, so that before reallocating a device, it is ensured that it is
no longer being referenced by the processing threads.

Signed-off-by: Maxime Coquelin <maxime.coquelin@redhat.com>
---
 lib/librte_vhost/vhost.c      | 223 +++++++++++++++++++++++++++++++++++-------
 lib/librte_vhost/vhost.h      |   3 +-
 lib/librte_vhost/vhost_user.c |  73 +++++---------
 lib/librte_vhost/virtio_net.c |  17 +++-
 4 files changed, 228 insertions(+), 88 deletions(-)
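
The access pattern enforced by this patch is that every external entry
point pairs get_device(), which now takes the per-device read lock, with
put_device() on every return path, while realloc_device() takes the write
lock to swap the struct. A minimal sketch of a reader, mirroring
rte_vhost_get_vring_num() from the patch below (the function name here is
illustrative only):

static uint16_t
example_get_vring_num(int vid)
{
	struct virtio_net *dev;
	uint16_t nr_vring;

	dev = get_device(vid);	/* takes the read lock on vhost_devices[vid].lock */
	if (dev == NULL)	/* lock is already released when the lookup fails */
		return 0;

	nr_vring = dev->nr_vring;	/* safe: realloc_device() needs the write lock */

	put_device(vid);	/* release the read lock on every return path */

	return nr_vring;
}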
  

Comments

Jens Freimann July 5, 2017, 10:07 a.m. UTC | #1
On Tue, Jul 04, 2017 at 11:49:04AM +0200, Maxime Coquelin wrote:
>virtio_net device might be accessed while being reallocated
>in case of NUMA awareness. This case might be theoretical,
>but it will be needed anyway to protect vrings pages against
>invalidation.
>
>The virtio_net devs are now protected with a readers/writers
>lock, so that before reallocating the device, it is ensured
>that it is not being referenced by the processing threads.
>
>Signed-off-by: Maxime Coquelin <maxime.coquelin@redhat.com>
>---
> lib/librte_vhost/vhost.c      | 223 +++++++++++++++++++++++++++++++++++-------
> lib/librte_vhost/vhost.h      |   3 +-
> lib/librte_vhost/vhost_user.c |  73 +++++---------
> lib/librte_vhost/virtio_net.c |  17 +++-
> 4 files changed, 228 insertions(+), 88 deletions(-)
[...]
>+int
>+realloc_device(int vid, int vq_index, int node)
>+{
>+	struct virtio_net *dev, *old_dev;
>+	struct vhost_virtqueue *vq;
>+
>+	dev = rte_malloc_socket(NULL, sizeof(*dev), 0, node);
>+	if (!dev)
>+		return -1;
>+
>+	vq = rte_malloc_socket(NULL, sizeof(*vq), 0, node);
>+	if (!vq)
>+		return -1;
>+
>+	old_dev = get_device_wr(vid);
>+	if (!old_dev)
>+		return -1;

Should we free vq and dev here?

regards,
Jens
  
Maxime Coquelin July 7, 2017, 7:31 a.m. UTC | #2
On 07/05/2017 12:07 PM, Jens Freimann wrote:
> On Tue, Jul 04, 2017 at 11:49:04AM +0200, Maxime Coquelin wrote:
>> virtio_net device might be accessed while being reallocated
>> in case of NUMA awareness. This case might be theoretical,
>> but it will be needed anyway to protect vrings pages against
>> invalidation.
>>
>> The virtio_net devs are now protected with a readers/writers
>> lock, so that before reallocating the device, it is ensured
>> that it is not being referenced by the processing threads.
>>
>> Signed-off-by: Maxime Coquelin <maxime.coquelin@redhat.com>
>> ---
>> lib/librte_vhost/vhost.c      | 223 
>> +++++++++++++++++++++++++++++++++++-------
>> lib/librte_vhost/vhost.h      |   3 +-
>> lib/librte_vhost/vhost_user.c |  73 +++++---------
>> lib/librte_vhost/virtio_net.c |  17 +++-
>> 4 files changed, 228 insertions(+), 88 deletions(-)
> [...]
>> +int
>> +realloc_device(int vid, int vq_index, int node)
>> +{
>> +    struct virtio_net *dev, *old_dev;
>> +    struct vhost_virtqueue *vq;
>> +
>> +    dev = rte_malloc_socket(NULL, sizeof(*dev), 0, node);
>> +    if (!dev)
>> +        return -1;
>> +
>> +    vq = rte_malloc_socket(NULL, sizeof(*vq), 0, node);
>> +    if (!vq)
>> +        return -1;
>> +
>> +    old_dev = get_device_wr(vid);
>> +    if (!old_dev)
>> +        return -1;
> 
> Should we free vq and dev here?

Of course we should.
This will be fixed in the next release.

Thanks,
Maxime
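
For reference, the agreed fix would presumably free both fresh
allocations on the failing paths of realloc_device(); note that the
earlier vq allocation failure leaks dev in the same way. A sketch of the
intended error handling (not the code from a later revision):

	dev = rte_malloc_socket(NULL, sizeof(*dev), 0, node);
	if (!dev)
		return -1;

	vq = rte_malloc_socket(NULL, sizeof(*vq), 0, node);
	if (!vq) {
		rte_free(dev);		/* don't leak the new device copy */
		return -1;
	}

	old_dev = get_device_wr(vid);
	if (!old_dev) {
		rte_free(vq);		/* free both new allocations, as noted above */
		rte_free(dev);
		return -1;
	}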
  

Patch

diff --git a/lib/librte_vhost/vhost.c b/lib/librte_vhost/vhost.c
index 19c5a43..2a4bc91 100644
--- a/lib/librte_vhost/vhost.c
+++ b/lib/librte_vhost/vhost.c
@@ -45,16 +45,25 @@ 
 #include <rte_string_fns.h>
 #include <rte_memory.h>
 #include <rte_malloc.h>
+#include <rte_rwlock.h>
 #include <rte_vhost.h>
 
 #include "vhost.h"
 
-struct virtio_net *vhost_devices[MAX_VHOST_DEVICE];
+struct vhost_device {
+	struct virtio_net *dev;
+	rte_rwlock_t lock;
+};
 
-struct virtio_net *
-get_device(int vid)
+/* Declared as static so that .lock is initialized */
+static struct vhost_device vhost_devices[MAX_VHOST_DEVICE];
+
+static inline struct virtio_net *
+__get_device(int vid)
 {
-	struct virtio_net *dev = vhost_devices[vid];
+	struct virtio_net *dev;
+
+	dev = vhost_devices[vid].dev;
 
 	if (unlikely(!dev)) {
 		RTE_LOG(ERR, VHOST_CONFIG,
@@ -64,6 +73,78 @@  get_device(int vid)
 	return dev;
 }
 
+struct virtio_net *
+get_device(int vid)
+{
+	struct virtio_net *dev;
+
+	rte_rwlock_read_lock(&vhost_devices[vid].lock);
+
+	dev = __get_device(vid);
+	if (unlikely(!dev))
+		rte_rwlock_read_unlock(&vhost_devices[vid].lock);
+
+	return dev;
+}
+
+void
+put_device(int vid)
+{
+	rte_rwlock_read_unlock(&vhost_devices[vid].lock);
+}
+
+static struct virtio_net *
+get_device_wr(int vid)
+{
+	struct virtio_net *dev;
+
+	rte_rwlock_write_lock(&vhost_devices[vid].lock);
+
+	dev = __get_device(vid);
+	if (unlikely(!dev))
+		rte_rwlock_write_unlock(&vhost_devices[vid].lock);
+
+	return dev;
+}
+
+static void
+put_device_wr(int vid)
+{
+	rte_rwlock_write_unlock(&vhost_devices[vid].lock);
+}
+
+int
+realloc_device(int vid, int vq_index, int node)
+{
+	struct virtio_net *dev, *old_dev;
+	struct vhost_virtqueue *vq;
+
+	dev = rte_malloc_socket(NULL, sizeof(*dev), 0, node);
+	if (!dev)
+		return -1;
+
+	vq = rte_malloc_socket(NULL, sizeof(*vq), 0, node);
+	if (!vq)
+		return -1;
+
+	old_dev = get_device_wr(vid);
+	if (!old_dev)
+		return -1;
+
+	memcpy(dev, old_dev, sizeof(*dev));
+	memcpy(vq, old_dev->virtqueue[vq_index], sizeof(*vq));
+	dev->virtqueue[vq_index] = vq;
+
+	rte_free(old_dev->virtqueue[vq_index]);
+	rte_free(old_dev);
+
+	vhost_devices[vid].dev = dev;
+
+	put_device_wr(vid);
+
+	return 0;
+}
+
 static void
 cleanup_vq(struct vhost_virtqueue *vq, int destroy)
 {
@@ -194,7 +275,7 @@  vhost_new_device(void)
 	}
 
 	for (i = 0; i < MAX_VHOST_DEVICE; i++) {
-		if (vhost_devices[i] == NULL)
+		if (vhost_devices[i].dev == NULL)
 			break;
 	}
 	if (i == MAX_VHOST_DEVICE) {
@@ -204,8 +285,10 @@  vhost_new_device(void)
 		return -1;
 	}
 
-	vhost_devices[i] = dev;
+	rte_rwlock_write_lock(&vhost_devices[i].lock);
+	vhost_devices[i].dev = dev;
 	dev->vid = i;
+	rte_rwlock_write_unlock(&vhost_devices[i].lock);
 
 	return i;
 }
@@ -227,10 +310,15 @@  vhost_destroy_device(int vid)
 		dev->notify_ops->destroy_device(vid);
 	}
 
+	put_device(vid);
+	dev = get_device_wr(vid);
+
 	cleanup_device(dev, 1);
 	free_device(dev);
 
-	vhost_devices[vid] = NULL;
+	vhost_devices[vid].dev = NULL;
+
+	put_device_wr(vid);
 }
 
 void
@@ -248,6 +336,8 @@  vhost_set_ifname(int vid, const char *if_name, unsigned int if_len)
 
 	strncpy(dev->ifname, if_name, len);
 	dev->ifname[sizeof(dev->ifname) - 1] = '\0';
+
+	put_device(vid);
 }
 
 void
@@ -259,25 +349,30 @@  vhost_enable_dequeue_zero_copy(int vid)
 		return;
 
 	dev->dequeue_zero_copy = 1;
+
+	put_device(vid);
 }
 
 int
 rte_vhost_get_mtu(int vid, uint16_t *mtu)
 {
 	struct virtio_net *dev = get_device(vid);
+	int ret = 0;
 
 	if (!dev)
 		return -ENODEV;
 
 	if (!(dev->flags & VIRTIO_DEV_READY))
-		return -EAGAIN;
+		ret = -EAGAIN;
 
 	if (!(dev->features & VIRTIO_NET_F_MTU))
-		return -ENOTSUP;
+		ret = -ENOTSUP;
 
 	*mtu = dev->mtu;
 
-	return 0;
+	put_device(vid);
+
+	return ret;
 }
 
 int
@@ -296,9 +391,11 @@  rte_vhost_get_numa_node(int vid)
 	if (ret < 0) {
 		RTE_LOG(ERR, VHOST_CONFIG,
 			"(%d) failed to query numa node: %d\n", vid, ret);
-		return -1;
+		numa_node = -1;
 	}
 
+	put_device(vid);
+
 	return numa_node;
 #else
 	RTE_SET_USED(vid);
@@ -310,22 +407,32 @@  uint32_t
 rte_vhost_get_queue_num(int vid)
 {
 	struct virtio_net *dev = get_device(vid);
+	uint32_t queue_num;
 
 	if (dev == NULL)
 		return 0;
 
-	return dev->nr_vring / 2;
+	queue_num = dev->nr_vring / 2;
+
+	put_device(vid);
+
+	return queue_num;
 }
 
 uint16_t
 rte_vhost_get_vring_num(int vid)
 {
 	struct virtio_net *dev = get_device(vid);
+	uint16_t vring_num;
 
 	if (dev == NULL)
 		return 0;
 
-	return dev->nr_vring;
+	vring_num = dev->nr_vring;
+
+	put_device(vid);
+
+	return vring_num;
 }
 
 int
@@ -341,6 +448,8 @@  rte_vhost_get_ifname(int vid, char *buf, size_t len)
 	strncpy(buf, dev->ifname, len);
 	buf[len - 1] = '\0';
 
+	put_device(vid);
+
 	return 0;
 }
 
@@ -354,6 +463,9 @@  rte_vhost_get_negotiated_features(int vid, uint64_t *features)
 		return -1;
 
 	*features = dev->features;
+
+	put_device(vid);
+
 	return 0;
 }
 
@@ -363,6 +475,7 @@  rte_vhost_get_mem_table(int vid, struct rte_vhost_memory **mem)
 	struct virtio_net *dev;
 	struct rte_vhost_memory *m;
 	size_t size;
+	int ret = 0;
 
 	dev = get_device(vid);
 	if (!dev)
@@ -370,14 +483,19 @@  rte_vhost_get_mem_table(int vid, struct rte_vhost_memory **mem)
 
 	size = dev->mem->nregions * sizeof(struct rte_vhost_mem_region);
 	m = malloc(sizeof(struct rte_vhost_memory) + size);
-	if (!m)
-		return -1;
+	if (!m) {
+		ret = -1;
+		goto out;
+	}
 
 	m->nregions = dev->mem->nregions;
 	memcpy(m->regions, dev->mem->regions, size);
 	*mem = m;
 
-	return 0;
+out:
+	put_device(vid);
+
+	return ret;
 }
 
 int
@@ -386,17 +504,22 @@  rte_vhost_get_vhost_vring(int vid, uint16_t vring_idx,
 {
 	struct virtio_net *dev;
 	struct vhost_virtqueue *vq;
+	int ret = 0;
 
 	dev = get_device(vid);
 	if (!dev)
 		return -1;
 
-	if (vring_idx >= VHOST_MAX_VRING)
-		return -1;
+	if (vring_idx >= VHOST_MAX_VRING) {
+		ret = -1;
+		goto out;
+	}
 
 	vq = dev->virtqueue[vring_idx];
-	if (!vq)
-		return -1;
+	if (!vq) {
+		ret = -1;
+		goto out;
+	}
 
 	vring->desc  = vq->desc;
 	vring->avail = vq->avail;
@@ -407,7 +530,10 @@  rte_vhost_get_vhost_vring(int vid, uint16_t vring_idx,
 	vring->kickfd  = vq->kickfd;
 	vring->size    = vq->size;
 
-	return 0;
+out:
+	put_device(vid);
+
+	return ret;
 }
 
 uint16_t
@@ -415,6 +541,7 @@  rte_vhost_avail_entries(int vid, uint16_t queue_id)
 {
 	struct virtio_net *dev;
 	struct vhost_virtqueue *vq;
+	uint16_t avail_entries = 0;
 
 	dev = get_device(vid);
 	if (!dev)
@@ -422,15 +549,23 @@  rte_vhost_avail_entries(int vid, uint16_t queue_id)
 
 	vq = dev->virtqueue[queue_id];
 	if (!vq->enabled)
-		return 0;
+		goto out;
+
 
-	return *(volatile uint16_t *)&vq->avail->idx - vq->last_used_idx;
+	avail_entries = *(volatile uint16_t *)&vq->avail->idx;
+	avail_entries -= vq->last_used_idx;
+
+out:
+	put_device(vid);
+
+	return avail_entries;
 }
 
 int
 rte_vhost_enable_guest_notification(int vid, uint16_t queue_id, int enable)
 {
 	struct virtio_net *dev = get_device(vid);
+	int ret = 0;
 
 	if (dev == NULL)
 		return -1;
@@ -438,11 +573,16 @@  rte_vhost_enable_guest_notification(int vid, uint16_t queue_id, int enable)
 	if (enable) {
 		RTE_LOG(ERR, VHOST_CONFIG,
 			"guest notification isn't supported.\n");
-		return -1;
+		ret = -1;
+		goto out;
 	}
 
 	dev->virtqueue[queue_id]->used->flags = VRING_USED_F_NO_NOTIFY;
-	return 0;
+
+out:
+	put_device(vid);
+
+	return ret;
 }
 
 void
@@ -454,6 +594,8 @@  rte_vhost_log_write(int vid, uint64_t addr, uint64_t len)
 		return;
 
 	vhost_log_write(dev, addr, len);
+
+	put_device(vid);
 }
 
 void
@@ -468,12 +610,15 @@  rte_vhost_log_used_vring(int vid, uint16_t vring_idx,
 		return;
 
 	if (vring_idx >= VHOST_MAX_VRING)
-		return;
+		goto out;
 	vq = dev->virtqueue[vring_idx];
 	if (!vq)
-		return;
+		goto out;
 
 	vhost_log_used_vring(dev, vq, offset, len);
+
+out:
+	put_device(vid);
 }
 
 uint32_t
@@ -481,6 +626,7 @@  rte_vhost_rx_queue_count(int vid, uint16_t qid)
 {
 	struct virtio_net *dev;
 	struct vhost_virtqueue *vq;
+	uint32_t queue_count;
 
 	dev = get_device(vid);
 	if (dev == NULL)
@@ -489,15 +635,26 @@  rte_vhost_rx_queue_count(int vid, uint16_t qid)
 	if (unlikely(qid >= dev->nr_vring || (qid & 1) == 0)) {
 		RTE_LOG(ERR, VHOST_DATA, "(%d) %s: invalid virtqueue idx %d.\n",
 			dev->vid, __func__, qid);
-		return 0;
+		queue_count = 0;
+		goto out;
 	}
 
 	vq = dev->virtqueue[qid];
-	if (vq == NULL)
-		return 0;
+	if (vq == NULL) {
+		queue_count = 0;
+		goto out;
+	}
 
-	if (unlikely(vq->enabled == 0 || vq->avail == NULL))
-		return 0;
+	if (unlikely(vq->enabled == 0 || vq->avail == NULL)) {
+		queue_count = 0;
+		goto out;
+	}
+
+	queue_count = *((volatile uint16_t *)&vq->avail->idx);
+	queue_count -= vq->last_avail_idx;
+
+out:
+	put_device(vid);
 
-	return *((volatile uint16_t *)&vq->avail->idx) - vq->last_avail_idx;
+	return queue_count;
 }
diff --git a/lib/librte_vhost/vhost.h b/lib/librte_vhost/vhost.h
index 0f294f3..18ad69c 100644
--- a/lib/librte_vhost/vhost.h
+++ b/lib/librte_vhost/vhost.h
@@ -269,7 +269,6 @@  vhost_log_used_vring(struct virtio_net *dev, struct vhost_virtqueue *vq,
 
 extern uint64_t VHOST_FEATURES;
 #define MAX_VHOST_DEVICE	1024
-extern struct virtio_net *vhost_devices[MAX_VHOST_DEVICE];
 
 /* Convert guest physical address to host physical address */
 static __rte_always_inline phys_addr_t
@@ -292,6 +291,8 @@  gpa_to_hpa(struct virtio_net *dev, uint64_t gpa, uint64_t size)
 }
 
 struct virtio_net *get_device(int vid);
+void put_device(int vid);
+int realloc_device(int vid, int vq_index, int node);
 
 int vhost_new_device(void);
 void cleanup_device(struct virtio_net *dev, int destroy);
diff --git a/lib/librte_vhost/vhost_user.c b/lib/librte_vhost/vhost_user.c
index ad2e8d3..5b3b881 100644
--- a/lib/librte_vhost/vhost_user.c
+++ b/lib/librte_vhost/vhost_user.c
@@ -241,62 +241,31 @@  vhost_user_set_vring_num(struct virtio_net *dev,
 static struct virtio_net*
 numa_realloc(struct virtio_net *dev, int index)
 {
-	int oldnode, newnode;
-	struct virtio_net *old_dev;
-	struct vhost_virtqueue *old_vq, *vq;
-	int ret;
+	int oldnode, newnode, vid, ret;
 
-	old_dev = dev;
-	vq = old_vq = dev->virtqueue[index];
+	vid = dev->vid;
 
-	ret = get_mempolicy(&newnode, NULL, 0, old_vq->desc,
+	ret = get_mempolicy(&newnode, NULL, 0, dev->virtqueue[index]->desc,
 			    MPOL_F_NODE | MPOL_F_ADDR);
 
 	/* check if we need to reallocate vq */
-	ret |= get_mempolicy(&oldnode, NULL, 0, old_vq,
+	ret |= get_mempolicy(&oldnode, NULL, 0, dev->virtqueue[index],
 			     MPOL_F_NODE | MPOL_F_ADDR);
 	if (ret) {
 		RTE_LOG(ERR, VHOST_CONFIG,
 			"Unable to get vq numa information.\n");
 		return dev;
 	}
-	if (oldnode != newnode) {
-		RTE_LOG(INFO, VHOST_CONFIG,
-			"reallocate vq from %d to %d node\n", oldnode, newnode);
-		vq = rte_malloc_socket(NULL, sizeof(*vq), 0, newnode);
-		if (!vq)
-			return dev;
-
-		memcpy(vq, old_vq, sizeof(*vq));
-		rte_free(old_vq);
-	}
 
-	/* check if we need to reallocate dev */
-	ret = get_mempolicy(&oldnode, NULL, 0, old_dev,
-			    MPOL_F_NODE | MPOL_F_ADDR);
-	if (ret) {
-		RTE_LOG(ERR, VHOST_CONFIG,
-			"Unable to get dev numa information.\n");
-		goto out;
-	}
 	if (oldnode != newnode) {
 		RTE_LOG(INFO, VHOST_CONFIG,
-			"reallocate dev from %d to %d node\n",
-			oldnode, newnode);
-		dev = rte_malloc_socket(NULL, sizeof(*dev), 0, newnode);
-		if (!dev) {
-			dev = old_dev;
-			goto out;
-		}
-
-		memcpy(dev, old_dev, sizeof(*dev));
-		rte_free(old_dev);
+			"reallocate vq from %d to %d node\n", oldnode, newnode);
+		put_device(vid);
+		if (realloc_device(vid, index, newnode))
+			RTE_LOG(ERR, VHOST_CONFIG, "Failed to realloc device\n");
+		dev = get_device(vid);
 	}
 
-out:
-	dev->virtqueue[index] = vq;
-	vhost_devices[dev->vid] = dev;
-
 	return dev;
 }
 #else
@@ -336,9 +305,10 @@  qva_to_vva(struct virtio_net *dev, uint64_t qva)
  * This function then converts these to our address space.
  */
 static int
-vhost_user_set_vring_addr(struct virtio_net *dev, VhostUserMsg *msg)
+vhost_user_set_vring_addr(struct virtio_net **pdev, VhostUserMsg *msg)
 {
 	struct vhost_virtqueue *vq;
+	struct virtio_net *dev = *pdev;
 
 	if (dev->mem == NULL)
 		return -1;
@@ -356,7 +326,7 @@  vhost_user_set_vring_addr(struct virtio_net *dev, VhostUserMsg *msg)
 		return -1;
 	}
 
-	dev = numa_realloc(dev, msg->payload.addr.index);
+	*pdev = dev = numa_realloc(dev, msg->payload.addr.index);
 	vq = dev->virtqueue[msg->payload.addr.index];
 
 	vq->avail = (struct vring_avail *)(uintptr_t)qva_to_vva(dev,
@@ -966,7 +936,7 @@  vhost_user_msg_handler(int vid, int fd)
 {
 	struct virtio_net *dev;
 	struct VhostUserMsg msg;
-	int ret;
+	int ret = 0;
 
 	dev = get_device(vid);
 	if (dev == NULL)
@@ -978,7 +948,8 @@  vhost_user_msg_handler(int vid, int fd)
 			RTE_LOG(ERR, VHOST_CONFIG,
 				"failed to get callback ops for driver %s\n",
 				dev->ifname);
-			return -1;
+			ret = -1;
+			goto out;
 		}
 	}
 
@@ -994,10 +965,10 @@  vhost_user_msg_handler(int vid, int fd)
 			RTE_LOG(ERR, VHOST_CONFIG,
 				"vhost read incorrect message\n");
 
-		return -1;
+		ret = -1;
+		goto out;
 	}
 
-	ret = 0;
 	RTE_LOG(INFO, VHOST_CONFIG, "read message %s\n",
 		vhost_message_str[msg.request]);
 
@@ -1005,7 +976,8 @@  vhost_user_msg_handler(int vid, int fd)
 	if (ret < 0) {
 		RTE_LOG(ERR, VHOST_CONFIG,
 			"failed to alloc queue\n");
-		return -1;
+		ret = -1;
+		goto out;
 	}
 
 	switch (msg.request) {
@@ -1054,7 +1026,7 @@  vhost_user_msg_handler(int vid, int fd)
 		vhost_user_set_vring_num(dev, &msg);
 		break;
 	case VHOST_USER_SET_VRING_ADDR:
-		vhost_user_set_vring_addr(dev, &msg);
+		vhost_user_set_vring_addr(&dev, &msg);
 		break;
 	case VHOST_USER_SET_VRING_BASE:
 		vhost_user_set_vring_base(dev, &msg);
@@ -1122,5 +1094,8 @@  vhost_user_msg_handler(int vid, int fd)
 		}
 	}
 
-	return 0;
+out:
+	put_device(vid);
+
+	return ret;
 }
diff --git a/lib/librte_vhost/virtio_net.c b/lib/librte_vhost/virtio_net.c
index ebfda1c..726d349 100644
--- a/lib/librte_vhost/virtio_net.c
+++ b/lib/librte_vhost/virtio_net.c
@@ -587,14 +587,19 @@  rte_vhost_enqueue_burst(int vid, uint16_t queue_id,
 	struct rte_mbuf **pkts, uint16_t count)
 {
 	struct virtio_net *dev = get_device(vid);
+	int ret = 0;
 
 	if (!dev)
 		return 0;
 
 	if (dev->features & (1 << VIRTIO_NET_F_MRG_RXBUF))
-		return virtio_dev_merge_rx(dev, queue_id, pkts, count);
+		ret = virtio_dev_merge_rx(dev, queue_id, pkts, count);
 	else
-		return virtio_dev_rx(dev, queue_id, pkts, count);
+		ret = virtio_dev_rx(dev, queue_id, pkts, count);
+
+	put_device(vid);
+
+	return ret;
 }
 
 static inline bool
@@ -993,12 +998,12 @@  rte_vhost_dequeue_burst(int vid, uint16_t queue_id,
 	if (unlikely(!is_valid_virt_queue_idx(queue_id, 1, dev->nr_vring))) {
 		RTE_LOG(ERR, VHOST_DATA, "(%d) %s: invalid virtqueue idx %d.\n",
 			dev->vid, __func__, queue_id);
-		return 0;
+		goto out;
 	}
 
 	vq = dev->virtqueue[queue_id];
 	if (unlikely(vq->enabled == 0))
-		return 0;
+		goto out;
 
 	if (unlikely(dev->dequeue_zero_copy)) {
 		struct zcopy_mbuf *zmbuf, *next;
@@ -1048,7 +1053,7 @@  rte_vhost_dequeue_burst(int vid, uint16_t queue_id,
 		if (rarp_mbuf == NULL) {
 			RTE_LOG(ERR, VHOST_DATA,
 				"Failed to allocate memory for mbuf.\n");
-			return 0;
+			goto out;
 		}
 
 		if (make_rarp_packet(rarp_mbuf, &dev->mac)) {
@@ -1167,5 +1172,7 @@  rte_vhost_dequeue_burst(int vid, uint16_t queue_id,
 		i += 1;
 	}
 
+	put_device(vid);
+
 	return i;
 }