socket: simplify receiving messages
Don't require the caller to provide a SCK_Message (on stack). Modify the SCK_ReceiveMessage*() functions to return a pointer to static buffers, as the message buffer which SCK_Message points to already is.
This commit is contained in:
parent
b8b751a932
commit
b20ef4cd7f
6 changed files with 62 additions and 51 deletions
27
cmdmon.c
27
cmdmon.c
|
@ -1237,7 +1237,7 @@ handle_reset(CMD_Request *rx_message, CMD_Reply *tx_message)
|
||||||
static void
|
static void
|
||||||
read_from_cmd_socket(int sock_fd, int event, void *anything)
|
read_from_cmd_socket(int sock_fd, int event, void *anything)
|
||||||
{
|
{
|
||||||
SCK_Message sck_message;
|
SCK_Message *sck_message;
|
||||||
CMD_Request rx_message;
|
CMD_Request rx_message;
|
||||||
CMD_Reply tx_message;
|
CMD_Reply tx_message;
|
||||||
IPAddr loopback_addr, remote_ip;
|
IPAddr loopback_addr, remote_ip;
|
||||||
|
@ -1246,26 +1246,27 @@ read_from_cmd_socket(int sock_fd, int event, void *anything)
|
||||||
unsigned short rx_command;
|
unsigned short rx_command;
|
||||||
struct timespec now, cooked_now;
|
struct timespec now, cooked_now;
|
||||||
|
|
||||||
if (!SCK_ReceiveMessage(sock_fd, &sck_message, 0))
|
sck_message = SCK_ReceiveMessage(sock_fd, 0);
|
||||||
|
if (!sck_message)
|
||||||
return;
|
return;
|
||||||
|
|
||||||
read_length = sck_message.length;
|
read_length = sck_message->length;
|
||||||
|
|
||||||
/* Get current time cheaply */
|
/* Get current time cheaply */
|
||||||
SCH_GetLastEventTime(&cooked_now, NULL, &now);
|
SCH_GetLastEventTime(&cooked_now, NULL, &now);
|
||||||
|
|
||||||
/* Check if it's from localhost (127.0.0.1, ::1, or Unix domain),
|
/* Check if it's from localhost (127.0.0.1, ::1, or Unix domain),
|
||||||
or an authorised address */
|
or an authorised address */
|
||||||
switch (sck_message.addr_type) {
|
switch (sck_message->addr_type) {
|
||||||
case SCK_ADDR_IP:
|
case SCK_ADDR_IP:
|
||||||
assert(sock_fd == sock_fd4 || sock_fd == sock_fd6);
|
assert(sock_fd == sock_fd4 || sock_fd == sock_fd6);
|
||||||
remote_ip = sck_message.remote_addr.ip.ip_addr;
|
remote_ip = sck_message->remote_addr.ip.ip_addr;
|
||||||
SCK_GetLoopbackIPAddress(remote_ip.family, &loopback_addr);
|
SCK_GetLoopbackIPAddress(remote_ip.family, &loopback_addr);
|
||||||
localhost = UTI_CompareIPs(&remote_ip, &loopback_addr, NULL) == 0;
|
localhost = UTI_CompareIPs(&remote_ip, &loopback_addr, NULL) == 0;
|
||||||
|
|
||||||
if (!localhost && !ADF_IsAllowed(access_auth_table, &remote_ip)) {
|
if (!localhost && !ADF_IsAllowed(access_auth_table, &remote_ip)) {
|
||||||
DEBUG_LOG("Unauthorised host %s",
|
DEBUG_LOG("Unauthorised host %s",
|
||||||
UTI_IPSockAddrToString(&sck_message.remote_addr.ip));
|
UTI_IPSockAddrToString(&sck_message->remote_addr.ip));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1291,7 +1292,7 @@ read_from_cmd_socket(int sock_fd, int event, void *anything)
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
memcpy(&rx_message, sck_message.data, read_length);
|
memcpy(&rx_message, sck_message->data, read_length);
|
||||||
|
|
||||||
if (rx_message.pkt_type != PKT_TYPE_CMD_REQUEST ||
|
if (rx_message.pkt_type != PKT_TYPE_CMD_REQUEST ||
|
||||||
rx_message.res1 != 0 ||
|
rx_message.res1 != 0 ||
|
||||||
|
@ -1313,8 +1314,8 @@ read_from_cmd_socket(int sock_fd, int event, void *anything)
|
||||||
rx_command = ntohs(rx_message.command);
|
rx_command = ntohs(rx_message.command);
|
||||||
|
|
||||||
memset(&tx_message, 0, sizeof (tx_message));
|
memset(&tx_message, 0, sizeof (tx_message));
|
||||||
sck_message.data = &tx_message;
|
sck_message->data = &tx_message;
|
||||||
sck_message.length = 0;
|
sck_message->length = 0;
|
||||||
|
|
||||||
tx_message.version = PROTO_VERSION_NUMBER;
|
tx_message.version = PROTO_VERSION_NUMBER;
|
||||||
tx_message.pkt_type = PKT_TYPE_CMD_REPLY;
|
tx_message.pkt_type = PKT_TYPE_CMD_REPLY;
|
||||||
|
@ -1329,7 +1330,7 @@ read_from_cmd_socket(int sock_fd, int event, void *anything)
|
||||||
|
|
||||||
if (rx_message.version >= PROTO_VERSION_MISMATCH_COMPAT_SERVER) {
|
if (rx_message.version >= PROTO_VERSION_MISMATCH_COMPAT_SERVER) {
|
||||||
tx_message.status = htons(STT_BADPKTVERSION);
|
tx_message.status = htons(STT_BADPKTVERSION);
|
||||||
transmit_reply(sock_fd, &sck_message);
|
transmit_reply(sock_fd, sck_message);
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
@ -1339,7 +1340,7 @@ read_from_cmd_socket(int sock_fd, int event, void *anything)
|
||||||
DEBUG_LOG("Command packet has invalid command %d", rx_command);
|
DEBUG_LOG("Command packet has invalid command %d", rx_command);
|
||||||
|
|
||||||
tx_message.status = htons(STT_INVALID);
|
tx_message.status = htons(STT_INVALID);
|
||||||
transmit_reply(sock_fd, &sck_message);
|
transmit_reply(sock_fd, sck_message);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1348,7 +1349,7 @@ read_from_cmd_socket(int sock_fd, int event, void *anything)
|
||||||
expected_length);
|
expected_length);
|
||||||
|
|
||||||
tx_message.status = htons(STT_BADPKTLENGTH);
|
tx_message.status = htons(STT_BADPKTLENGTH);
|
||||||
transmit_reply(sock_fd, &sck_message);
|
transmit_reply(sock_fd, sck_message);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1629,7 +1630,7 @@ read_from_cmd_socket(int sock_fd, int event, void *anything)
|
||||||
static int do_it=1;
|
static int do_it=1;
|
||||||
|
|
||||||
if (do_it) {
|
if (do_it) {
|
||||||
transmit_reply(sock_fd, &sck_message);
|
transmit_reply(sock_fd, sck_message);
|
||||||
}
|
}
|
||||||
|
|
||||||
#if 0
|
#if 0
|
||||||
|
|
6
ntp_io.c
6
ntp_io.c
|
@ -407,7 +407,7 @@ read_from_socket(int sock_fd, int event, void *anything)
|
||||||
/* This should only be called when there is something
|
/* This should only be called when there is something
|
||||||
to read, otherwise it may block */
|
to read, otherwise it may block */
|
||||||
|
|
||||||
SCK_Message messages[SCK_MAX_RECV_MESSAGES];
|
SCK_Message *messages;
|
||||||
int i, received, flags = 0;
|
int i, received, flags = 0;
|
||||||
|
|
||||||
#ifdef HAVE_LINUX_TIMESTAMPING
|
#ifdef HAVE_LINUX_TIMESTAMPING
|
||||||
|
@ -423,8 +423,8 @@ read_from_socket(int sock_fd, int event, void *anything)
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
received = SCK_ReceiveMessages(sock_fd, messages, SCK_MAX_RECV_MESSAGES, flags);
|
messages = SCK_ReceiveMessages(sock_fd, flags, &received);
|
||||||
if (received <= 0)
|
if (!messages)
|
||||||
return;
|
return;
|
||||||
|
|
||||||
for (i = 0; i < received; i++)
|
for (i = 0; i < received; i++)
|
||||||
|
|
|
@ -139,28 +139,29 @@ handle_client(int sock_fd, IPSockAddr *addr)
|
||||||
static void
|
static void
|
||||||
handle_helper_request(int fd, int event, void *arg)
|
handle_helper_request(int fd, int event, void *arg)
|
||||||
{
|
{
|
||||||
SCK_Message message;
|
SCK_Message *message;
|
||||||
HelperRequest *req;
|
HelperRequest *req;
|
||||||
IPSockAddr client_addr;
|
IPSockAddr client_addr;
|
||||||
int sock_fd;
|
int sock_fd;
|
||||||
|
|
||||||
if (!SCK_ReceiveMessage(fd, &message, SCK_FLAG_MSG_DESCRIPTOR))
|
message = SCK_ReceiveMessage(fd, SCK_FLAG_MSG_DESCRIPTOR);
|
||||||
|
if (!message)
|
||||||
return;
|
return;
|
||||||
|
|
||||||
sock_fd = message.descriptor;
|
sock_fd = message->descriptor;
|
||||||
if (sock_fd < 0) {
|
if (sock_fd < 0) {
|
||||||
/* Message with no descriptor is a shutdown command */
|
/* Message with no descriptor is a shutdown command */
|
||||||
SCH_QuitProgram();
|
SCH_QuitProgram();
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (message.length != sizeof (HelperRequest)) {
|
if (message->length != sizeof (HelperRequest)) {
|
||||||
DEBUG_LOG("Unexpected message length");
|
DEBUG_LOG("Unexpected message length");
|
||||||
SCK_CloseSocket(sock_fd);
|
SCK_CloseSocket(sock_fd);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
req = message.data;
|
req = message->data;
|
||||||
|
|
||||||
/* Extract the server key and client address from the request */
|
/* Extract the server key and client address from the request */
|
||||||
server_keys[current_server_key].id = ntohl(req->key_id);
|
server_keys[current_server_key].id = ntohl(req->key_id);
|
||||||
|
|
14
privops.c
14
privops.c
|
@ -171,22 +171,22 @@ send_response(int fd, const PrvResponse *res)
|
||||||
static int
|
static int
|
||||||
receive_from_daemon(int fd, PrvRequest *req)
|
receive_from_daemon(int fd, PrvRequest *req)
|
||||||
{
|
{
|
||||||
SCK_Message message;
|
SCK_Message *message;
|
||||||
|
|
||||||
if (!SCK_ReceiveMessage(fd, &message, SCK_FLAG_MSG_DESCRIPTOR) ||
|
message = SCK_ReceiveMessage(fd, SCK_FLAG_MSG_DESCRIPTOR);
|
||||||
message.length != sizeof (*req))
|
if (!message || message->length != sizeof (*req))
|
||||||
return 0;
|
return 0;
|
||||||
|
|
||||||
memcpy(req, message.data, sizeof (*req));
|
memcpy(req, message->data, sizeof (*req));
|
||||||
|
|
||||||
if (req->op == OP_BINDSOCKET) {
|
if (req->op == OP_BINDSOCKET) {
|
||||||
req->data.bind_socket.sock = message.descriptor;
|
req->data.bind_socket.sock = message->descriptor;
|
||||||
|
|
||||||
/* return error if valid descriptor not found */
|
/* return error if valid descriptor not found */
|
||||||
if (req->data.bind_socket.sock < 0)
|
if (req->data.bind_socket.sock < 0)
|
||||||
return 0;
|
return 0;
|
||||||
} else if (message.descriptor >= 0) {
|
} else if (message->descriptor >= 0) {
|
||||||
SCK_CloseSocket(message.descriptor);
|
SCK_CloseSocket(message->descriptor);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
41
socket.c
41
socket.c
|
@ -68,7 +68,7 @@ struct Message {
|
||||||
};
|
};
|
||||||
|
|
||||||
#ifdef HAVE_RECVMMSG
|
#ifdef HAVE_RECVMMSG
|
||||||
#define MAX_RECV_MESSAGES SCK_MAX_RECV_MESSAGES
|
#define MAX_RECV_MESSAGES 4
|
||||||
#define MessageHeader mmsghdr
|
#define MessageHeader mmsghdr
|
||||||
#else
|
#else
|
||||||
/* Compatible with mmsghdr */
|
/* Compatible with mmsghdr */
|
||||||
|
@ -85,9 +85,10 @@ static int initialised;
|
||||||
/* Flags supported by socket() */
|
/* Flags supported by socket() */
|
||||||
static int supported_socket_flags;
|
static int supported_socket_flags;
|
||||||
|
|
||||||
/* Arrays of Message and MessageHeader */
|
/* Arrays of Message, MessageHeader, and SCK_Message */
|
||||||
static ARR_Instance recv_messages;
|
static ARR_Instance recv_messages;
|
||||||
static ARR_Instance recv_headers;
|
static ARR_Instance recv_headers;
|
||||||
|
static ARR_Instance recv_sck_messages;
|
||||||
|
|
||||||
static unsigned int received_messages;
|
static unsigned int received_messages;
|
||||||
|
|
||||||
|
@ -867,22 +868,27 @@ process_header(struct msghdr *msg, unsigned int msg_length, int sock_fd, int fla
|
||||||
|
|
||||||
/* ================================================== */
|
/* ================================================== */
|
||||||
|
|
||||||
static int
|
static SCK_Message *
|
||||||
receive_messages(int sock_fd, SCK_Message *messages, int max_messages, int flags)
|
receive_messages(int sock_fd, int flags, int max_messages, int *num_messages)
|
||||||
{
|
{
|
||||||
struct MessageHeader *hdr;
|
struct MessageHeader *hdr;
|
||||||
|
SCK_Message *messages;
|
||||||
unsigned int i, n;
|
unsigned int i, n;
|
||||||
int ret, recv_flags = 0;
|
int ret, recv_flags = 0;
|
||||||
|
|
||||||
assert(initialised);
|
assert(initialised);
|
||||||
|
|
||||||
|
*num_messages = 0;
|
||||||
|
|
||||||
if (max_messages < 1)
|
if (max_messages < 1)
|
||||||
return 0;
|
return NULL;
|
||||||
|
|
||||||
/* Prepare used buffers for new messages */
|
/* Prepare used buffers for new messages */
|
||||||
prepare_buffers(received_messages);
|
prepare_buffers(received_messages);
|
||||||
received_messages = 0;
|
received_messages = 0;
|
||||||
|
|
||||||
|
messages = ARR_GetElements(recv_sck_messages);
|
||||||
|
|
||||||
hdr = ARR_GetElements(recv_headers);
|
hdr = ARR_GetElements(recv_headers);
|
||||||
n = ARR_GetSize(recv_headers);
|
n = ARR_GetSize(recv_headers);
|
||||||
n = MIN(n, max_messages);
|
n = MIN(n, max_messages);
|
||||||
|
@ -903,7 +909,7 @@ receive_messages(int sock_fd, SCK_Message *messages, int max_messages, int flags
|
||||||
|
|
||||||
if (ret < 0) {
|
if (ret < 0) {
|
||||||
handle_recv_error(sock_fd, flags);
|
handle_recv_error(sock_fd, flags);
|
||||||
return 0;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
received_messages = n;
|
received_messages = n;
|
||||||
|
@ -911,13 +917,15 @@ receive_messages(int sock_fd, SCK_Message *messages, int max_messages, int flags
|
||||||
for (i = 0; i < n; i++) {
|
for (i = 0; i < n; i++) {
|
||||||
hdr = ARR_GetElement(recv_headers, i);
|
hdr = ARR_GetElement(recv_headers, i);
|
||||||
if (!process_header(&hdr->msg_hdr, hdr->msg_len, sock_fd, flags, &messages[i]))
|
if (!process_header(&hdr->msg_hdr, hdr->msg_len, sock_fd, flags, &messages[i]))
|
||||||
return 0;
|
return NULL;
|
||||||
|
|
||||||
log_message(sock_fd, 1, &messages[i],
|
log_message(sock_fd, 1, &messages[i],
|
||||||
flags & SCK_FLAG_MSG_ERRQUEUE ? "Received error" : "Received", NULL);
|
flags & SCK_FLAG_MSG_ERRQUEUE ? "Received error" : "Received", NULL);
|
||||||
}
|
}
|
||||||
|
|
||||||
return n;
|
*num_messages = n;
|
||||||
|
|
||||||
|
return messages;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ================================================== */
|
/* ================================================== */
|
||||||
|
@ -1092,6 +1100,8 @@ SCK_Initialise(void)
|
||||||
ARR_SetSize(recv_messages, MAX_RECV_MESSAGES);
|
ARR_SetSize(recv_messages, MAX_RECV_MESSAGES);
|
||||||
recv_headers = ARR_CreateInstance(sizeof (struct MessageHeader));
|
recv_headers = ARR_CreateInstance(sizeof (struct MessageHeader));
|
||||||
ARR_SetSize(recv_headers, MAX_RECV_MESSAGES);
|
ARR_SetSize(recv_headers, MAX_RECV_MESSAGES);
|
||||||
|
recv_sck_messages = ARR_CreateInstance(sizeof (SCK_Message));
|
||||||
|
ARR_SetSize(recv_sck_messages, MAX_RECV_MESSAGES);
|
||||||
|
|
||||||
received_messages = MAX_RECV_MESSAGES;
|
received_messages = MAX_RECV_MESSAGES;
|
||||||
|
|
||||||
|
@ -1115,6 +1125,7 @@ SCK_Initialise(void)
|
||||||
void
|
void
|
||||||
SCK_Finalise(void)
|
SCK_Finalise(void)
|
||||||
{
|
{
|
||||||
|
ARR_DestroyInstance(recv_sck_messages);
|
||||||
ARR_DestroyInstance(recv_headers);
|
ARR_DestroyInstance(recv_headers);
|
||||||
ARR_DestroyInstance(recv_messages);
|
ARR_DestroyInstance(recv_messages);
|
||||||
|
|
||||||
|
@ -1381,18 +1392,20 @@ SCK_Send(int sock_fd, const void *buffer, unsigned int length, int flags)
|
||||||
|
|
||||||
/* ================================================== */
|
/* ================================================== */
|
||||||
|
|
||||||
int
|
SCK_Message *
|
||||||
SCK_ReceiveMessage(int sock_fd, SCK_Message *message, int flags)
|
SCK_ReceiveMessage(int sock_fd, int flags)
|
||||||
{
|
{
|
||||||
return SCK_ReceiveMessages(sock_fd, message, 1, flags);
|
int num_messages;
|
||||||
|
|
||||||
|
return receive_messages(sock_fd, flags, 1, &num_messages);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ================================================== */
|
/* ================================================== */
|
||||||
|
|
||||||
int
|
SCK_Message *
|
||||||
SCK_ReceiveMessages(int sock_fd, SCK_Message *messages, int max_messages, int flags)
|
SCK_ReceiveMessages(int sock_fd, int flags, int *num_messages)
|
||||||
{
|
{
|
||||||
return receive_messages(sock_fd, messages, max_messages, flags);
|
return receive_messages(sock_fd, flags, MAX_RECV_MESSAGES, num_messages);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ================================================== */
|
/* ================================================== */
|
||||||
|
|
14
socket.h
14
socket.h
|
@ -41,9 +41,6 @@
|
||||||
#define SCK_FLAG_MSG_ERRQUEUE 1
|
#define SCK_FLAG_MSG_ERRQUEUE 1
|
||||||
#define SCK_FLAG_MSG_DESCRIPTOR 2
|
#define SCK_FLAG_MSG_DESCRIPTOR 2
|
||||||
|
|
||||||
/* Maximum number of received messages */
|
|
||||||
#define SCK_MAX_RECV_MESSAGES 4
|
|
||||||
|
|
||||||
typedef enum {
|
typedef enum {
|
||||||
SCK_ADDR_UNSPEC = 0,
|
SCK_ADDR_UNSPEC = 0,
|
||||||
SCK_ADDR_IP,
|
SCK_ADDR_IP,
|
||||||
|
@ -119,12 +116,11 @@ extern int SCK_ShutdownConnection(int sock_fd);
|
||||||
extern int SCK_Receive(int sock_fd, void *buffer, unsigned int length, int flags);
|
extern int SCK_Receive(int sock_fd, void *buffer, unsigned int length, int flags);
|
||||||
extern int SCK_Send(int sock_fd, const void *buffer, unsigned int length, int flags);
|
extern int SCK_Send(int sock_fd, const void *buffer, unsigned int length, int flags);
|
||||||
|
|
||||||
/* Receive a single message or multiple messages. The functions return the
|
/* Receive a single message or multiple messages. The functions return
|
||||||
number of received messages, or 0 on error. The returned data point to
|
a pointer to static buffers, or NULL on error. The buffers are valid until
|
||||||
static buffers, which are valid until another call of these functions. */
|
another call of the functions and can be reused for sending messages. */
|
||||||
extern int SCK_ReceiveMessage(int sock_fd, SCK_Message *message, int flags);
|
extern SCK_Message *SCK_ReceiveMessage(int sock_fd, int flags);
|
||||||
extern int SCK_ReceiveMessages(int sock_fd, SCK_Message *messages, int max_messages,
|
extern SCK_Message *SCK_ReceiveMessages(int sock_fd, int flags, int *num_messages);
|
||||||
int flags);
|
|
||||||
|
|
||||||
/* Initialise a new message (e.g. before sending) */
|
/* Initialise a new message (e.g. before sending) */
|
||||||
extern void SCK_InitMessage(SCK_Message *message, SCK_AddressType addr_type);
|
extern void SCK_InitMessage(SCK_Message *message, SCK_AddressType addr_type);
|
||||||
|
|
Loading…
Reference in a new issue