[v5,07/19] vhost: add number of fds to vhost-user messages and use it
Checks
Commit Message
As soons as some anciliarry datai (fds) are received, it is copied
without checking its length.
This patch adds adds the number of fds received to the message,
which is set in read_vhost_message().
This is preliminary work to support sending fds to Qemu.
Signed-off-by: Dr. David Alan Gilbert <dgilbert@redhat.com>
Signed-off-by: Maxime Coquelin <maxime.coquelin@redhat.com>
---
lib/librte_vhost/socket.c | 25 ++++++++++++++++++++-----
lib/librte_vhost/vhost_user.c | 2 +-
lib/librte_vhost/vhost_user.h | 4 +++-
3 files changed, 24 insertions(+), 7 deletions(-)
Comments
On Tue, Oct 09, 2018 at 10:54:14PM +0200, Maxime Coquelin wrote:
> As soons as some anciliarry datai (fds) are received, it is copied
typo: soons anciliarry datai
> without checking its length.
>
> This patch adds adds the number of fds received to the message,
s/adds adds/adds/
> which is set in read_vhost_message().
>
> This is preliminary work to support sending fds to Qemu.
>
> Signed-off-by: Dr. David Alan Gilbert <dgilbert@redhat.com>
> Signed-off-by: Maxime Coquelin <maxime.coquelin@redhat.com>
> ---
> lib/librte_vhost/socket.c | 25 ++++++++++++++++++++-----
> lib/librte_vhost/vhost_user.c | 2 +-
> lib/librte_vhost/vhost_user.h | 4 +++-
> 3 files changed, 24 insertions(+), 7 deletions(-)
>
> diff --git a/lib/librte_vhost/socket.c b/lib/librte_vhost/socket.c
> index d63031747..7cad5593e 100644
> --- a/lib/librte_vhost/socket.c
> +++ b/lib/librte_vhost/socket.c
> @@ -94,18 +94,24 @@ static struct vhost_user vhost_user = {
> .mutex = PTHREAD_MUTEX_INITIALIZER,
> };
>
> -/* return bytes# of read on success or negative val on failure. */
> +/*
> + * return bytes# of read on success or negative val on failure. Update fdnum
> + * with number of fds read.
> + */
> int
> -read_fd_message(int sockfd, char *buf, int buflen, int *fds, int fd_num)
> +read_fd_message(int sockfd, char *buf, int buflen, int *fds, int max_fds,
> + int *fd_num)
> {
> struct iovec iov;
> struct msghdr msgh;
> - size_t fdsize = fd_num * sizeof(int);
> - char control[CMSG_SPACE(fdsize)];
> + char control[CMSG_SPACE(max_fds * sizeof(int))];
> struct cmsghdr *cmsg;
> int got_fds = 0;
> + int *tmp_fds;
> int ret;
>
> + *fd_num = 0;
> +
> memset(&msgh, 0, sizeof(msgh));
> iov.iov_base = buf;
> iov.iov_len = buflen;
> @@ -131,13 +137,22 @@ read_fd_message(int sockfd, char *buf, int buflen, int *fds, int fd_num)
> if ((cmsg->cmsg_level == SOL_SOCKET) &&
> (cmsg->cmsg_type == SCM_RIGHTS)) {
> got_fds = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
> + if (got_fds > max_fds) {
> + RTE_LOG(ERR, VHOST_CONFIG,
> + "Received msg contains more fds than supported\n");
I think it's preferred to keep the code aligned
with tab size set to 8. Above code looks like
this when tab size is 8:
¦ ¦ ¦ ¦ RTE_LOG(ERR, VHOST_CONFIG,
¦ ¦ ¦ ¦ ¦ ¦ "Received msg contains more fds than supported\n");
Something like this is better:
¦ ¦ ¦ ¦ RTE_LOG(ERR, VHOST_CONFIG,
¦ ¦ ¦ ¦ ¦ "Received msg contains more fds than supported\n");
There are some other similar cases in this series,
please also fix them. It should be quite easy to
find them by setting the tab size to 8. Thanks!
> + tmp_fds = (int *)CMSG_DATA(cmsg);
> + while (got_fds--)
> + close(tmp_fds[got_fds]);
> + return -1;
> + }
> + *fd_num = got_fds;
> memcpy(fds, CMSG_DATA(cmsg), got_fds * sizeof(int));
> break;
> }
> }
>
> /* Clear out unused file descriptors */
> - while (got_fds < fd_num)
> + while (got_fds < max_fds)
> fds[got_fds++] = -1;
>
> return ret;
> diff --git a/lib/librte_vhost/vhost_user.c b/lib/librte_vhost/vhost_user.c
> index 83d3e6321..c1c5f35ff 100644
> --- a/lib/librte_vhost/vhost_user.c
> +++ b/lib/librte_vhost/vhost_user.c
> @@ -1509,7 +1509,7 @@ read_vhost_message(int sockfd, struct VhostUserMsg *msg)
> int ret;
>
> ret = read_fd_message(sockfd, (char *)msg, VHOST_USER_HDR_SIZE,
> - msg->fds, VHOST_MEMORY_MAX_NREGIONS);
> + msg->fds, VHOST_MEMORY_MAX_NREGIONS, &msg->fd_num);
> if (ret <= 0)
> return ret;
>
> diff --git a/lib/librte_vhost/vhost_user.h b/lib/librte_vhost/vhost_user.h
> index 62654f736..9a91d496b 100644
> --- a/lib/librte_vhost/vhost_user.h
> +++ b/lib/librte_vhost/vhost_user.h
> @@ -132,6 +132,7 @@ typedef struct VhostUserMsg {
> VhostUserVringArea area;
> } payload;
> int fds[VHOST_MEMORY_MAX_NREGIONS];
> + int fd_num;
> } __attribute((packed)) VhostUserMsg;
>
> #define VHOST_USER_HDR_SIZE offsetof(VhostUserMsg, payload.u64)
> @@ -155,7 +156,8 @@ int vhost_user_iotlb_miss(struct virtio_net *dev, uint64_t iova, uint8_t perm);
> int vhost_user_host_notifier_ctrl(int vid, bool enable);
>
> /* socket.c */
> -int read_fd_message(int sockfd, char *buf, int buflen, int *fds, int fd_num);
> +int read_fd_message(int sockfd, char *buf, int buflen, int *fds, int max_fds,
> + int *fd_num);
> int send_fd_message(int sockfd, char *buf, int buflen, int *fds, int fd_num);
>
> #endif
> --
> 2.17.1
>
@@ -94,18 +94,24 @@ static struct vhost_user vhost_user = {
.mutex = PTHREAD_MUTEX_INITIALIZER,
};
-/* return bytes# of read on success or negative val on failure. */
+/*
+ * return bytes# of read on success or negative val on failure. Update fdnum
+ * with number of fds read.
+ */
int
-read_fd_message(int sockfd, char *buf, int buflen, int *fds, int fd_num)
+read_fd_message(int sockfd, char *buf, int buflen, int *fds, int max_fds,
+ int *fd_num)
{
struct iovec iov;
struct msghdr msgh;
- size_t fdsize = fd_num * sizeof(int);
- char control[CMSG_SPACE(fdsize)];
+ char control[CMSG_SPACE(max_fds * sizeof(int))];
struct cmsghdr *cmsg;
int got_fds = 0;
+ int *tmp_fds;
int ret;
+ *fd_num = 0;
+
memset(&msgh, 0, sizeof(msgh));
iov.iov_base = buf;
iov.iov_len = buflen;
@@ -131,13 +137,22 @@ read_fd_message(int sockfd, char *buf, int buflen, int *fds, int fd_num)
if ((cmsg->cmsg_level == SOL_SOCKET) &&
(cmsg->cmsg_type == SCM_RIGHTS)) {
got_fds = (cmsg->cmsg_len - CMSG_LEN(0)) / sizeof(int);
+ if (got_fds > max_fds) {
+ RTE_LOG(ERR, VHOST_CONFIG,
+ "Received msg contains more fds than supported\n");
+ tmp_fds = (int *)CMSG_DATA(cmsg);
+ while (got_fds--)
+ close(tmp_fds[got_fds]);
+ return -1;
+ }
+ *fd_num = got_fds;
memcpy(fds, CMSG_DATA(cmsg), got_fds * sizeof(int));
break;
}
}
/* Clear out unused file descriptors */
- while (got_fds < fd_num)
+ while (got_fds < max_fds)
fds[got_fds++] = -1;
return ret;
@@ -1509,7 +1509,7 @@ read_vhost_message(int sockfd, struct VhostUserMsg *msg)
int ret;
ret = read_fd_message(sockfd, (char *)msg, VHOST_USER_HDR_SIZE,
- msg->fds, VHOST_MEMORY_MAX_NREGIONS);
+ msg->fds, VHOST_MEMORY_MAX_NREGIONS, &msg->fd_num);
if (ret <= 0)
return ret;
@@ -132,6 +132,7 @@ typedef struct VhostUserMsg {
VhostUserVringArea area;
} payload;
int fds[VHOST_MEMORY_MAX_NREGIONS];
+ int fd_num;
} __attribute((packed)) VhostUserMsg;
#define VHOST_USER_HDR_SIZE offsetof(VhostUserMsg, payload.u64)
@@ -155,7 +156,8 @@ int vhost_user_iotlb_miss(struct virtio_net *dev, uint64_t iova, uint8_t perm);
int vhost_user_host_notifier_ctrl(int vid, bool enable);
/* socket.c */
-int read_fd_message(int sockfd, char *buf, int buflen, int *fds, int fd_num);
+int read_fd_message(int sockfd, char *buf, int buflen, int *fds, int max_fds,
+ int *fd_num);
int send_fd_message(int sockfd, char *buf, int buflen, int *fds, int fd_num);
#endif