diff --git a/socket.c b/socket.c index 39813cc..2211da1 100644 --- a/socket.c +++ b/socket.c @@ -630,12 +630,15 @@ init_message_nonaddress(SCK_Message *message) message->timestamp.if_index = INVALID_IF_INDEX; message->timestamp.l2_length = 0; message->timestamp.tx_flags = 0; + + message->descriptor = INVALID_SOCK_FD; } /* ================================================== */ static int -process_header(struct msghdr *msg, unsigned int msg_length, int sock_fd, SCK_Message *message) +process_header(struct msghdr *msg, unsigned int msg_length, int sock_fd, int flags, + SCK_Message *message) { struct cmsghdr *cmsg; @@ -773,6 +776,18 @@ process_header(struct msghdr *msg, unsigned int msg_length, int sock_fd, SCK_Mes } } #endif + + if (cmsg->cmsg_level == SOL_SOCKET && cmsg->cmsg_type == SCM_RIGHTS) { + if (!(flags & SCK_FLAG_MSG_DESCRIPTOR) || cmsg->cmsg_len != CMSG_LEN(sizeof (int))) { + unsigned int i; + + DEBUG_LOG("Unexpected SCM_RIGHTS"); + for (i = 0; CMSG_LEN((i + 1) * sizeof (int)) <= cmsg->cmsg_len; i++) + close(((int *)CMSG_DATA(cmsg))[i]); + return 0; + } + message->descriptor = *(int *)CMSG_DATA(cmsg); + } } return 1; @@ -823,7 +838,7 @@ receive_messages(int sock_fd, SCK_Message *messages, int max_messages, int flags for (i = 0; i < n; i++) { hdr = ARR_GetElement(recv_headers, i); - if (!process_header(&hdr->msg_hdr, hdr->msg_len, sock_fd, &messages[i])) + if (!process_header(&hdr->msg_hdr, hdr->msg_len, sock_fd, flags, &messages[i])) return 0; log_message(sock_fd, 1, &messages[i], @@ -876,8 +891,6 @@ send_message(int sock_fd, SCK_Message *message, int flags) struct msghdr msg; struct iovec iov; - assert(flags == 0); - switch (message->addr_type) { case SCK_ADDR_UNSPEC: saddr_len = 0; @@ -974,6 +987,16 @@ send_message(int sock_fd, SCK_Message *message, int flags) } #endif + if (flags & SCK_FLAG_MSG_DESCRIPTOR) { + int *fd; + + fd = add_control_message(&msg, SOL_SOCKET, SCM_RIGHTS, sizeof (*fd), sizeof (cmsg_buf)); + if (!fd) + return 0; + + *fd = message->descriptor; + } + /* This is apparently required on some systems */ if (msg.msg_controllen == 0) msg.msg_control = NULL; diff --git a/socket.h b/socket.h index 40a8fcb..f28329e 100644 --- a/socket.h +++ b/socket.h @@ -38,6 +38,7 @@ /* Flags for receiving and sending messages */ #define SCK_FLAG_MSG_ERRQUEUE 1 +#define SCK_FLAG_MSG_DESCRIPTOR 2 /* Maximum number of received messages */ #define SCK_MAX_RECV_MESSAGES 4 @@ -70,6 +71,8 @@ typedef struct { int l2_length; int tx_flags; } timestamp; + + int descriptor; } SCK_Message; /* Initialisation function */