From fc8783a93340d38378bdf6702f8ad56e26d9171a Mon Sep 17 00:00:00 2001 From: Miroslav Lichvar Date: Wed, 13 Jan 2021 13:36:13 +0100 Subject: [PATCH] socket: check length of received control messages Make sure each processed control messages has the expected length. Beside improved safety, this should prevent potential issues with broken timestamps on systems that support both 64-bit and 32-bit time_t. --- socket.c | 33 +++++++++++++++++++++++---------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/socket.c b/socket.c index 3551228..4d5c578 100644 --- a/socket.c +++ b/socket.c @@ -737,6 +737,17 @@ init_message_nonaddress(SCK_Message *message) /* ================================================== */ +static int +match_cmsg(struct cmsghdr *cmsg, int level, int type, size_t length) +{ + if (cmsg->cmsg_type == type && cmsg->cmsg_level == level && + (length == 0 || cmsg->cmsg_len == CMSG_LEN(length))) + return 1; + return 0; +} + +/* ================================================== */ + static int process_header(struct msghdr *msg, int msg_length, int sock_fd, int flags, SCK_Message *message) @@ -795,7 +806,7 @@ process_header(struct msghdr *msg, int msg_length, int sock_fd, int flags, for (cmsg = CMSG_FIRSTHDR(msg); cmsg; cmsg = CMSG_NXTHDR(msg, cmsg)) { #ifdef HAVE_IN_PKTINFO - if (cmsg->cmsg_level == IPPROTO_IP && cmsg->cmsg_type == IP_PKTINFO) { + if (match_cmsg(cmsg, IPPROTO_IP, IP_PKTINFO, sizeof (struct in_pktinfo))) { struct in_pktinfo ipi; if (message->addr_type != SCK_ADDR_IP) @@ -807,7 +818,7 @@ process_header(struct msghdr *msg, int msg_length, int sock_fd, int flags, message->if_index = ipi.ipi_ifindex; } #elif defined(IP_RECVDSTADDR) - if (cmsg->cmsg_level == IPPROTO_IP && cmsg->cmsg_type == IP_RECVDSTADDR) { + if (match_cmsg(cmsg, IPPROTO_IP, IP_RECVDSTADDR, sizeof (struct in_addr))) { struct in_addr addr; if (message->addr_type != SCK_ADDR_IP) @@ -820,7 +831,7 @@ process_header(struct msghdr *msg, int msg_length, int sock_fd, int flags, #endif #ifdef HAVE_IN6_PKTINFO - if (cmsg->cmsg_level == IPPROTO_IPV6 && cmsg->cmsg_type == IPV6_PKTINFO) { + if (match_cmsg(cmsg, IPPROTO_IPV6, IPV6_PKTINFO, sizeof (struct in6_pktinfo))) { struct in6_pktinfo ipi; if (message->addr_type != SCK_ADDR_IP) @@ -835,7 +846,7 @@ process_header(struct msghdr *msg, int msg_length, int sock_fd, int flags, #endif #ifdef SCM_TIMESTAMP - if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_TIMESTAMP) { + if (match_cmsg(cmsg, SOL_SOCKET, SCM_TIMESTAMP, sizeof (struct timeval))) { struct timeval tv; memcpy(&tv, CMSG_DATA(cmsg), sizeof (tv)); @@ -844,14 +855,15 @@ process_header(struct msghdr *msg, int msg_length, int sock_fd, int flags, #endif #ifdef SCM_TIMESTAMPNS - if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_TIMESTAMPNS) { + if (match_cmsg(cmsg, SOL_SOCKET, SCM_TIMESTAMPNS, sizeof (message->timestamp.kernel))) { memcpy(&message->timestamp.kernel, CMSG_DATA(cmsg), sizeof (message->timestamp.kernel)); } #endif #ifdef HAVE_LINUX_TIMESTAMPING #ifdef HAVE_LINUX_TIMESTAMPING_OPT_PKTINFO - if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_TIMESTAMPING_PKTINFO) { + if (match_cmsg(cmsg, SOL_SOCKET, SCM_TIMESTAMPING_PKTINFO, + sizeof (struct scm_ts_pktinfo))) { struct scm_ts_pktinfo ts_pktinfo; memcpy(&ts_pktinfo, CMSG_DATA(cmsg), sizeof (ts_pktinfo)); @@ -860,7 +872,7 @@ process_header(struct msghdr *msg, int msg_length, int sock_fd, int flags, } #endif - if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_TIMESTAMPING) { + if (match_cmsg(cmsg, SOL_SOCKET, SCM_TIMESTAMPING, sizeof (struct scm_timestamping))) { struct scm_timestamping ts3; memcpy(&ts3, CMSG_DATA(cmsg), sizeof (ts3)); @@ -868,8 +880,9 @@ process_header(struct msghdr *msg, int msg_length, int sock_fd, int flags, message->timestamp.hw = ts3.ts[2]; } - if ((cmsg->cmsg_level == SOL_IP && cmsg->cmsg_type == IP_RECVERR) || - (cmsg->cmsg_level == SOL_IPV6 && cmsg->cmsg_type == IPV6_RECVERR)) { + if ((match_cmsg(cmsg, SOL_IP, IP_RECVERR, 0) || + match_cmsg(cmsg, SOL_IPV6, IPV6_RECVERR, 0)) && + cmsg->cmsg_len >= CMSG_LEN(sizeof (struct sock_extended_err))) { struct sock_extended_err err; memcpy(&err, CMSG_DATA(cmsg), sizeof (err)); @@ -882,7 +895,7 @@ process_header(struct msghdr *msg, int msg_length, int sock_fd, int flags, } #endif - if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) { + if (match_cmsg(cmsg, SOL_SOCKET, SCM_RIGHTS, 0)) { if (!(flags & SCK_FLAG_MSG_DESCRIPTOR) || cmsg->cmsg_len != CMSG_LEN(sizeof (int))) { int i, fd;