[v17.11,v2,2/4] vhost: add number of fds to vhost-user messages

Message ID 20191112151927.27418-2-maxime.coquelin@redhat.com (mailing list archive)
State Not Applicable, archived
Delegated to: Maxime Coquelin
Headers
Series [v17.11,v2,1/4] vhost: validate virtqueue size |

Commit Message

Maxime Coquelin Nov. 12, 2019, 3:19 p.m. UTC
  As soon as some ancillary data (fds) are received, it is copied
without checking its length.

This patch adds the number of fds received to the message,
which is set in read_vhost_message().

This is preliminary work to support sending fds to Qemu.

Signed-off-by: Dr. David Alan Gilbert <dgilbert@redhat.com>
Signed-off-by: Maxime Coquelin <maxime.coquelin@redhat.com>
(cherry picked from commit c00bb88d35fe975ede0ea35bdf4f765a2cece7e8)
Signed-off-by: Maxime Coquelin <maxime.coquelin@redhat.com>
---
 lib/librte_vhost/socket.c     | 22 +++++++++++++++++-----
 lib/librte_vhost/vhost_user.c |  2 +-
 lib/librte_vhost/vhost_user.h |  4 +++-
 3 files changed, 21 insertions(+), 7 deletions(-)
  

Patch

diff --git a/lib/librte_vhost/socket.c b/lib/librte_vhost/socket.c
index 88be697c2f..2fa7ea0e09 100644
--- a/lib/librte_vhost/socket.c
+++ b/lib/librte_vhost/socket.c
@@ -117,17 +117,23 @@  static struct vhost_user vhost_user = {
 	.mutex = PTHREAD_MUTEX_INITIALIZER,
 };
 
-/* return bytes# of read on success or negative val on failure. */
+/*
+ * return bytes# of read on success or negative val on failure. Update fdnum
+ * with number of fds read.
+ */
 int
-read_fd_message(int sockfd, char *buf, int buflen, int *fds, int fd_num)
+read_fd_message(int sockfd, char *buf, int buflen, int *fds, int max_fds,
+		int *fd_num)
 {
 	struct iovec iov;
 	struct msghdr msgh;
-	size_t fdsize = fd_num * sizeof(int);
-	char control[CMSG_SPACE(fdsize)];
+	char control[CMSG_SPACE(max_fds * sizeof(int))];
 	struct cmsghdr *cmsg;
+	int got_fds = 0;
 	int ret;
 
+	*fd_num = 0;
+
 	memset(&msgh, 0, sizeof(msgh));
 	iov.iov_base = buf;
 	iov.iov_len  = buflen;
@@ -152,11 +158,17 @@  read_fd_message(int sockfd, char *buf, int buflen, int *fds, int fd_num)
 		cmsg = CMSG_NXTHDR(&msgh, cmsg)) {
 		if ((cmsg->cmsg_level == SOL_SOCKET) &&
 			(cmsg->cmsg_type == SCM_RIGHTS)) {
-			memcpy(fds, CMSG_DATA(cmsg), fdsize);
+			got_fds = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
+			*fd_num = got_fds;
+			memcpy(fds, CMSG_DATA(cmsg), got_fds * sizeof(int));
 			break;
 		}
 	}
 
+	/* Clear out unused file descriptors */
+	while (got_fds < max_fds)
+		fds[got_fds++] = -1;
+
 	return ret;
 }
 
diff --git a/lib/librte_vhost/vhost_user.c b/lib/librte_vhost/vhost_user.c
index 93e871c5bb..9773691097 100644
--- a/lib/librte_vhost/vhost_user.c
+++ b/lib/librte_vhost/vhost_user.c
@@ -1248,7 +1248,7 @@  read_vhost_message(int sockfd, struct VhostUserMsg *msg)
 	int ret;
 
 	ret = read_fd_message(sockfd, (char *)msg, VHOST_USER_HDR_SIZE,
-		msg->fds, VHOST_MEMORY_MAX_NREGIONS);
+		msg->fds, VHOST_MEMORY_MAX_NREGIONS, &msg->fd_num);
 	if (ret <= 0)
 		return ret;
 
diff --git a/lib/librte_vhost/vhost_user.h b/lib/librte_vhost/vhost_user.h
index 76d9fe2fc5..10b8cc5e1a 100644
--- a/lib/librte_vhost/vhost_user.h
+++ b/lib/librte_vhost/vhost_user.h
@@ -130,6 +130,7 @@  typedef struct VhostUserMsg {
 		struct vhost_iotlb_msg iotlb;
 	} payload;
 	int fds[VHOST_MEMORY_MAX_NREGIONS];
+	int fd_num;
 } __attribute((packed)) VhostUserMsg;
 
 #define VHOST_USER_HDR_SIZE offsetof(VhostUserMsg, payload.u64)
@@ -143,7 +144,8 @@  int vhost_user_msg_handler(int vid, int fd);
 int vhost_user_iotlb_miss(struct virtio_net *dev, uint64_t iova, uint8_t perm);
 
 /* socket.c */
-int read_fd_message(int sockfd, char *buf, int buflen, int *fds, int fd_num);
+int read_fd_message(int sockfd, char *buf, int buflen, int *fds, int max_fds,
+		int *fd_num);
 int send_fd_message(int sockfd, char *buf, int buflen, int *fds, int fd_num);
 
 #endif