From b20ef4cd7f83905f37d6e77ab217ed1e0a9aa04d Mon Sep 17 00:00:00 2001 From: Miroslav Lichvar Date: Tue, 24 Mar 2020 15:22:31 +0100 Subject: [PATCH] socket: simplify receiving messages Don't require the caller to provide a SCK_Message (on stack). Modify the SCK_ReceiveMessage*() functions to return a pointer to static buffers, as the message buffer which SCK_Message points to already is. --- cmdmon.c | 27 ++++++++++++++------------- ntp_io.c | 6 +++--- nts_ke_server.c | 11 ++++++----- privops.c | 14 +++++++------- socket.c | 41 +++++++++++++++++++++++++++-------------- socket.h | 14 +++++--------- 6 files changed, 62 insertions(+), 51 deletions(-) diff --git a/cmdmon.c b/cmdmon.c index 908977e..60ae1cc 100644 --- a/cmdmon.c +++ b/cmdmon.c @@ -1237,7 +1237,7 @@ handle_reset(CMD_Request *rx_message, CMD_Reply *tx_message) static void read_from_cmd_socket(int sock_fd, int event, void *anything) { - SCK_Message sck_message; + SCK_Message *sck_message; CMD_Request rx_message; CMD_Reply tx_message; IPAddr loopback_addr, remote_ip; @@ -1246,26 +1246,27 @@ read_from_cmd_socket(int sock_fd, int event, void *anything) unsigned short rx_command; struct timespec now, cooked_now; - if (!SCK_ReceiveMessage(sock_fd, &sck_message, 0)) + sck_message = SCK_ReceiveMessage(sock_fd, 0); + if (!sck_message) return; - read_length = sck_message.length; + read_length = sck_message->length; /* Get current time cheaply */ SCH_GetLastEventTime(&cooked_now, NULL, &now); /* Check if it's from localhost (127.0.0.1, ::1, or Unix domain), or an authorised address */ - switch (sck_message.addr_type) { + switch (sck_message->addr_type) { case SCK_ADDR_IP: assert(sock_fd == sock_fd4 || sock_fd == sock_fd6); - remote_ip = sck_message.remote_addr.ip.ip_addr; + remote_ip = sck_message->remote_addr.ip.ip_addr; SCK_GetLoopbackIPAddress(remote_ip.family, &loopback_addr); localhost = UTI_CompareIPs(&remote_ip, &loopback_addr, NULL) == 0; if (!localhost && !ADF_IsAllowed(access_auth_table, &remote_ip)) { DEBUG_LOG("Unauthorised host %s", - UTI_IPSockAddrToString(&sck_message.remote_addr.ip)); + UTI_IPSockAddrToString(&sck_message->remote_addr.ip)); return; } @@ -1291,7 +1292,7 @@ read_from_cmd_socket(int sock_fd, int event, void *anything) return; } - memcpy(&rx_message, sck_message.data, read_length); + memcpy(&rx_message, sck_message->data, read_length); if (rx_message.pkt_type != PKT_TYPE_CMD_REQUEST || rx_message.res1 != 0 || @@ -1313,8 +1314,8 @@ read_from_cmd_socket(int sock_fd, int event, void *anything) rx_command = ntohs(rx_message.command); memset(&tx_message, 0, sizeof (tx_message)); - sck_message.data = &tx_message; - sck_message.length = 0; + sck_message->data = &tx_message; + sck_message->length = 0; tx_message.version = PROTO_VERSION_NUMBER; tx_message.pkt_type = PKT_TYPE_CMD_REPLY; @@ -1329,7 +1330,7 @@ read_from_cmd_socket(int sock_fd, int event, void *anything) if (rx_message.version >= PROTO_VERSION_MISMATCH_COMPAT_SERVER) { tx_message.status = htons(STT_BADPKTVERSION); - transmit_reply(sock_fd, &sck_message); + transmit_reply(sock_fd, sck_message); } return; } @@ -1339,7 +1340,7 @@ read_from_cmd_socket(int sock_fd, int event, void *anything) DEBUG_LOG("Command packet has invalid command %d", rx_command); tx_message.status = htons(STT_INVALID); - transmit_reply(sock_fd, &sck_message); + transmit_reply(sock_fd, sck_message); return; } @@ -1348,7 +1349,7 @@ read_from_cmd_socket(int sock_fd, int event, void *anything) expected_length); tx_message.status = htons(STT_BADPKTLENGTH); - transmit_reply(sock_fd, &sck_message); + transmit_reply(sock_fd, sck_message); return; } @@ -1629,7 +1630,7 @@ read_from_cmd_socket(int sock_fd, int event, void *anything) static int do_it=1; if (do_it) { - transmit_reply(sock_fd, &sck_message); + transmit_reply(sock_fd, sck_message); } #if 0 diff --git a/ntp_io.c b/ntp_io.c index a70f6c6..31e21c7 100644 --- a/ntp_io.c +++ b/ntp_io.c @@ -407,7 +407,7 @@ read_from_socket(int sock_fd, int event, void *anything) /* This should only be called when there is something to read, otherwise it may block */ - SCK_Message messages[SCK_MAX_RECV_MESSAGES]; + SCK_Message *messages; int i, received, flags = 0; #ifdef HAVE_LINUX_TIMESTAMPING @@ -423,8 +423,8 @@ read_from_socket(int sock_fd, int event, void *anything) #endif } - received = SCK_ReceiveMessages(sock_fd, messages, SCK_MAX_RECV_MESSAGES, flags); - if (received <= 0) + messages = SCK_ReceiveMessages(sock_fd, flags, &received); + if (!messages) return; for (i = 0; i < received; i++) diff --git a/nts_ke_server.c b/nts_ke_server.c index 14bb621..a3fa9f9 100644 --- a/nts_ke_server.c +++ b/nts_ke_server.c @@ -139,28 +139,29 @@ handle_client(int sock_fd, IPSockAddr *addr) static void handle_helper_request(int fd, int event, void *arg) { - SCK_Message message; + SCK_Message *message; HelperRequest *req; IPSockAddr client_addr; int sock_fd; - if (!SCK_ReceiveMessage(fd, &message, SCK_FLAG_MSG_DESCRIPTOR)) + message = SCK_ReceiveMessage(fd, SCK_FLAG_MSG_DESCRIPTOR); + if (!message) return; - sock_fd = message.descriptor; + sock_fd = message->descriptor; if (sock_fd < 0) { /* Message with no descriptor is a shutdown command */ SCH_QuitProgram(); return; } - if (message.length != sizeof (HelperRequest)) { + if (message->length != sizeof (HelperRequest)) { DEBUG_LOG("Unexpected message length"); SCK_CloseSocket(sock_fd); return; } - req = message.data; + req = message->data; /* Extract the server key and client address from the request */ server_keys[current_server_key].id = ntohl(req->key_id); diff --git a/privops.c b/privops.c index 6d06c4c..e999f36 100644 --- a/privops.c +++ b/privops.c @@ -171,22 +171,22 @@ send_response(int fd, const PrvResponse *res) static int receive_from_daemon(int fd, PrvRequest *req) { - SCK_Message message; + SCK_Message *message; - if (!SCK_ReceiveMessage(fd, &message, SCK_FLAG_MSG_DESCRIPTOR) || - message.length != sizeof (*req)) + message = SCK_ReceiveMessage(fd, SCK_FLAG_MSG_DESCRIPTOR); + if (!message || message->length != sizeof (*req)) return 0; - memcpy(req, message.data, sizeof (*req)); + memcpy(req, message->data, sizeof (*req)); if (req->op == OP_BINDSOCKET) { - req->data.bind_socket.sock = message.descriptor; + req->data.bind_socket.sock = message->descriptor; /* return error if valid descriptor not found */ if (req->data.bind_socket.sock < 0) return 0; - } else if (message.descriptor >= 0) { - SCK_CloseSocket(message.descriptor); + } else if (message->descriptor >= 0) { + SCK_CloseSocket(message->descriptor); return 0; } diff --git a/socket.c b/socket.c index 3c1ffd3..9e874d2 100644 --- a/socket.c +++ b/socket.c @@ -68,7 +68,7 @@ struct Message { }; #ifdef HAVE_RECVMMSG -#define MAX_RECV_MESSAGES SCK_MAX_RECV_MESSAGES +#define MAX_RECV_MESSAGES 4 #define MessageHeader mmsghdr #else /* Compatible with mmsghdr */ @@ -85,9 +85,10 @@ static int initialised; /* Flags supported by socket() */ static int supported_socket_flags; -/* Arrays of Message and MessageHeader */ +/* Arrays of Message, MessageHeader, and SCK_Message */ static ARR_Instance recv_messages; static ARR_Instance recv_headers; +static ARR_Instance recv_sck_messages; static unsigned int received_messages; @@ -867,22 +868,27 @@ process_header(struct msghdr *msg, unsigned int msg_length, int sock_fd, int fla /* ================================================== */ -static int -receive_messages(int sock_fd, SCK_Message *messages, int max_messages, int flags) +static SCK_Message * +receive_messages(int sock_fd, int flags, int max_messages, int *num_messages) { struct MessageHeader *hdr; + SCK_Message *messages; unsigned int i, n; int ret, recv_flags = 0; assert(initialised); + *num_messages = 0; + if (max_messages < 1) - return 0; + return NULL; /* Prepare used buffers for new messages */ prepare_buffers(received_messages); received_messages = 0; + messages = ARR_GetElements(recv_sck_messages); + hdr = ARR_GetElements(recv_headers); n = ARR_GetSize(recv_headers); n = MIN(n, max_messages); @@ -903,7 +909,7 @@ receive_messages(int sock_fd, SCK_Message *messages, int max_messages, int flags if (ret < 0) { handle_recv_error(sock_fd, flags); - return 0; + return NULL; } received_messages = n; @@ -911,13 +917,15 @@ receive_messages(int sock_fd, SCK_Message *messages, int max_messages, int flags for (i = 0; i < n; i++) { hdr = ARR_GetElement(recv_headers, i); if (!process_header(&hdr->msg_hdr, hdr->msg_len, sock_fd, flags, &messages[i])) - return 0; + return NULL; log_message(sock_fd, 1, &messages[i], flags & SCK_FLAG_MSG_ERRQUEUE ? "Received error" : "Received", NULL); } - return n; + *num_messages = n; + + return messages; } /* ================================================== */ @@ -1092,6 +1100,8 @@ SCK_Initialise(void) ARR_SetSize(recv_messages, MAX_RECV_MESSAGES); recv_headers = ARR_CreateInstance(sizeof (struct MessageHeader)); ARR_SetSize(recv_headers, MAX_RECV_MESSAGES); + recv_sck_messages = ARR_CreateInstance(sizeof (SCK_Message)); + ARR_SetSize(recv_sck_messages, MAX_RECV_MESSAGES); received_messages = MAX_RECV_MESSAGES; @@ -1115,6 +1125,7 @@ SCK_Initialise(void) void SCK_Finalise(void) { + ARR_DestroyInstance(recv_sck_messages); ARR_DestroyInstance(recv_headers); ARR_DestroyInstance(recv_messages); @@ -1381,18 +1392,20 @@ SCK_Send(int sock_fd, const void *buffer, unsigned int length, int flags) /* ================================================== */ -int -SCK_ReceiveMessage(int sock_fd, SCK_Message *message, int flags) +SCK_Message * +SCK_ReceiveMessage(int sock_fd, int flags) { - return SCK_ReceiveMessages(sock_fd, message, 1, flags); + int num_messages; + + return receive_messages(sock_fd, flags, 1, &num_messages); } /* ================================================== */ -int -SCK_ReceiveMessages(int sock_fd, SCK_Message *messages, int max_messages, int flags) +SCK_Message * +SCK_ReceiveMessages(int sock_fd, int flags, int *num_messages) { - return receive_messages(sock_fd, messages, max_messages, flags); + return receive_messages(sock_fd, flags, MAX_RECV_MESSAGES, num_messages); } /* ================================================== */ diff --git a/socket.h b/socket.h index ee44526..949690b 100644 --- a/socket.h +++ b/socket.h @@ -41,9 +41,6 @@ #define SCK_FLAG_MSG_ERRQUEUE 1 #define SCK_FLAG_MSG_DESCRIPTOR 2 -/* Maximum number of received messages */ -#define SCK_MAX_RECV_MESSAGES 4 - typedef enum { SCK_ADDR_UNSPEC = 0, SCK_ADDR_IP, @@ -119,12 +116,11 @@ extern int SCK_ShutdownConnection(int sock_fd); extern int SCK_Receive(int sock_fd, void *buffer, unsigned int length, int flags); extern int SCK_Send(int sock_fd, const void *buffer, unsigned int length, int flags); -/* Receive a single message or multiple messages. The functions return the - number of received messages, or 0 on error. The returned data point to - static buffers, which are valid until another call of these functions. */ -extern int SCK_ReceiveMessage(int sock_fd, SCK_Message *message, int flags); -extern int SCK_ReceiveMessages(int sock_fd, SCK_Message *messages, int max_messages, - int flags); +/* Receive a single message or multiple messages. The functions return + a pointer to static buffers, or NULL on error. The buffers are valid until + another call of the functions and can be reused for sending messages. */ +extern SCK_Message *SCK_ReceiveMessage(int sock_fd, int flags); +extern SCK_Message *SCK_ReceiveMessages(int sock_fd, int flags, int *num_messages); /* Initialise a new message (e.g. before sending) */ extern void SCK_InitMessage(SCK_Message *message, SCK_AddressType addr_type);