socket: check length of received control messages

Make sure each processed control messages has the expected length.
Beside improved safety, this should prevent potential issues with broken
timestamps on systems that support both 64-bit and 32-bit time_t.
This commit is contained in:
Miroslav Lichvar 2021-01-13 13:36:13 +01:00
parent e7897eb9cc
commit fc8783a933

View file

@ -737,6 +737,17 @@ init_message_nonaddress(SCK_Message *message)
/* ================================================== */ /* ================================================== */
static int
match_cmsg(struct cmsghdr *cmsg, int level, int type, size_t length)
{
if (cmsg->cmsg_type == type && cmsg->cmsg_level == level &&
(length == 0 || cmsg->cmsg_len == CMSG_LEN(length)))
return 1;
return 0;
}
/* ================================================== */
static int static int
process_header(struct msghdr *msg, int msg_length, int sock_fd, int flags, process_header(struct msghdr *msg, int msg_length, int sock_fd, int flags,
SCK_Message *message) SCK_Message *message)
@ -795,7 +806,7 @@ process_header(struct msghdr *msg, int msg_length, int sock_fd, int flags,
for (cmsg = CMSG_FIRSTHDR(msg); cmsg; cmsg = CMSG_NXTHDR(msg, cmsg)) { for (cmsg = CMSG_FIRSTHDR(msg); cmsg; cmsg = CMSG_NXTHDR(msg, cmsg)) {
#ifdef HAVE_IN_PKTINFO #ifdef HAVE_IN_PKTINFO
if (cmsg->cmsg_level == IPPROTO_IP && cmsg->cmsg_type == IP_PKTINFO) { if (match_cmsg(cmsg, IPPROTO_IP, IP_PKTINFO, sizeof (struct in_pktinfo))) {
struct in_pktinfo ipi; struct in_pktinfo ipi;
if (message->addr_type != SCK_ADDR_IP) if (message->addr_type != SCK_ADDR_IP)
@ -807,7 +818,7 @@ process_header(struct msghdr *msg, int msg_length, int sock_fd, int flags,
message->if_index = ipi.ipi_ifindex; message->if_index = ipi.ipi_ifindex;
} }
#elif defined(IP_RECVDSTADDR) #elif defined(IP_RECVDSTADDR)
if (cmsg->cmsg_level == IPPROTO_IP && cmsg->cmsg_type == IP_RECVDSTADDR) { if (match_cmsg(cmsg, IPPROTO_IP, IP_RECVDSTADDR, sizeof (struct in_addr))) {
struct in_addr addr; struct in_addr addr;
if (message->addr_type != SCK_ADDR_IP) if (message->addr_type != SCK_ADDR_IP)
@ -820,7 +831,7 @@ process_header(struct msghdr *msg, int msg_length, int sock_fd, int flags,
#endif #endif
#ifdef HAVE_IN6_PKTINFO #ifdef HAVE_IN6_PKTINFO
if (cmsg->cmsg_level == IPPROTO_IPV6 && cmsg->cmsg_type == IPV6_PKTINFO) { if (match_cmsg(cmsg, IPPROTO_IPV6, IPV6_PKTINFO, sizeof (struct in6_pktinfo))) {
struct in6_pktinfo ipi; struct in6_pktinfo ipi;
if (message->addr_type != SCK_ADDR_IP) if (message->addr_type != SCK_ADDR_IP)
@ -835,7 +846,7 @@ process_header(struct msghdr *msg, int msg_length, int sock_fd, int flags,
#endif #endif
#ifdef SCM_TIMESTAMP #ifdef SCM_TIMESTAMP
if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_TIMESTAMP) { if (match_cmsg(cmsg, SOL_SOCKET, SCM_TIMESTAMP, sizeof (struct timeval))) {
struct timeval tv; struct timeval tv;
memcpy(&tv, CMSG_DATA(cmsg), sizeof (tv)); memcpy(&tv, CMSG_DATA(cmsg), sizeof (tv));
@ -844,14 +855,15 @@ process_header(struct msghdr *msg, int msg_length, int sock_fd, int flags,
#endif #endif
#ifdef SCM_TIMESTAMPNS #ifdef SCM_TIMESTAMPNS
if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_TIMESTAMPNS) { if (match_cmsg(cmsg, SOL_SOCKET, SCM_TIMESTAMPNS, sizeof (message->timestamp.kernel))) {
memcpy(&message->timestamp.kernel, CMSG_DATA(cmsg), sizeof (message->timestamp.kernel)); memcpy(&message->timestamp.kernel, CMSG_DATA(cmsg), sizeof (message->timestamp.kernel));
} }
#endif #endif
#ifdef HAVE_LINUX_TIMESTAMPING #ifdef HAVE_LINUX_TIMESTAMPING
#ifdef HAVE_LINUX_TIMESTAMPING_OPT_PKTINFO #ifdef HAVE_LINUX_TIMESTAMPING_OPT_PKTINFO
if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_TIMESTAMPING_PKTINFO) { if (match_cmsg(cmsg, SOL_SOCKET, SCM_TIMESTAMPING_PKTINFO,
sizeof (struct scm_ts_pktinfo))) {
struct scm_ts_pktinfo ts_pktinfo; struct scm_ts_pktinfo ts_pktinfo;
memcpy(&ts_pktinfo, CMSG_DATA(cmsg), sizeof (ts_pktinfo)); memcpy(&ts_pktinfo, CMSG_DATA(cmsg), sizeof (ts_pktinfo));
@ -860,7 +872,7 @@ process_header(struct msghdr *msg, int msg_length, int sock_fd, int flags,
} }
#endif #endif
if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_TIMESTAMPING) { if (match_cmsg(cmsg, SOL_SOCKET, SCM_TIMESTAMPING, sizeof (struct scm_timestamping))) {
struct scm_timestamping ts3; struct scm_timestamping ts3;
memcpy(&ts3, CMSG_DATA(cmsg), sizeof (ts3)); memcpy(&ts3, CMSG_DATA(cmsg), sizeof (ts3));
@ -868,8 +880,9 @@ process_header(struct msghdr *msg, int msg_length, int sock_fd, int flags,
message->timestamp.hw = ts3.ts[2]; message->timestamp.hw = ts3.ts[2];
} }
if ((cmsg->cmsg_level == SOL_IP && cmsg->cmsg_type == IP_RECVERR) || if ((match_cmsg(cmsg, SOL_IP, IP_RECVERR, 0) ||
(cmsg->cmsg_level == SOL_IPV6 && cmsg->cmsg_type == IPV6_RECVERR)) { match_cmsg(cmsg, SOL_IPV6, IPV6_RECVERR, 0)) &&
cmsg->cmsg_len >= CMSG_LEN(sizeof (struct sock_extended_err))) {
struct sock_extended_err err; struct sock_extended_err err;
memcpy(&err, CMSG_DATA(cmsg), sizeof (err)); memcpy(&err, CMSG_DATA(cmsg), sizeof (err));
@ -882,7 +895,7 @@ process_header(struct msghdr *msg, int msg_length, int sock_fd, int flags,
} }
#endif #endif
if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) { if (match_cmsg(cmsg, SOL_SOCKET, SCM_RIGHTS, 0)) {
if (!(flags & SCK_FLAG_MSG_DESCRIPTOR) || cmsg->cmsg_len != CMSG_LEN(sizeof (int))) { if (!(flags & SCK_FLAG_MSG_DESCRIPTOR) || cmsg->cmsg_len != CMSG_LEN(sizeof (int))) {
int i, fd; int i, fd;