socket: improve code

Add more assertions and other checks, and improve coding style a bit.
This commit is contained in:
Miroslav Lichvar 2020-08-11 17:07:14 +02:00
parent 8c75f44603
commit 9ac582fa35
3 changed files with 31 additions and 14 deletions

View file

@ -757,7 +757,7 @@ NIO_Linux_ProcessMessage(SCK_Message *message, NTP_Local_Address *local_addr,
l2_length = message->length; l2_length = message->length;
message->length = extract_udp_data(message->data, &message->remote_addr.ip, message->length); message->length = extract_udp_data(message->data, &message->remote_addr.ip, message->length);
DEBUG_LOG("Extracted message for %s fd=%d len=%u", DEBUG_LOG("Extracted message for %s fd=%d len=%d",
UTI_IPSockAddrToString(&message->remote_addr.ip), UTI_IPSockAddrToString(&message->remote_addr.ip),
local_addr->sock_fd, message->length); local_addr->sock_fd, message->length);

View file

@ -118,7 +118,7 @@ prepare_buffers(unsigned int n)
hdr->msg_hdr.msg_namelen = sizeof (msg->name); hdr->msg_hdr.msg_namelen = sizeof (msg->name);
hdr->msg_hdr.msg_iov = &msg->iov; hdr->msg_hdr.msg_iov = &msg->iov;
hdr->msg_hdr.msg_iovlen = 1; hdr->msg_hdr.msg_iovlen = 1;
hdr->msg_hdr.msg_control = &msg->cmsg_buf; hdr->msg_hdr.msg_control = msg->cmsg_buf;
hdr->msg_hdr.msg_controllen = sizeof (msg->cmsg_buf); hdr->msg_hdr.msg_controllen = sizeof (msg->cmsg_buf);
hdr->msg_hdr.msg_flags = 0; hdr->msg_hdr.msg_flags = 0;
hdr->msg_len = 0; hdr->msg_len = 0;
@ -176,7 +176,7 @@ check_socket_flag(int sock_flag, int fd_flag, int fs_flag)
static int static int
set_socket_nonblock(int sock_fd) set_socket_nonblock(int sock_fd)
{ {
if (fcntl(sock_fd, F_SETFL, O_NONBLOCK)) { if (fcntl(sock_fd, F_SETFL, O_NONBLOCK) < 0) {
DEBUG_LOG("Could not set O_NONBLOCK : %s", strerror(errno)); DEBUG_LOG("Could not set O_NONBLOCK : %s", strerror(errno));
return 0; return 0;
} }
@ -656,9 +656,8 @@ log_message(int sock_fd, int direction, SCK_Message *message, const char *prefix
case SCK_ADDR_IP: case SCK_ADDR_IP:
if (message->remote_addr.ip.ip_addr.family != IPADDR_UNSPEC) if (message->remote_addr.ip.ip_addr.family != IPADDR_UNSPEC)
remote_addr = UTI_IPSockAddrToString(&message->remote_addr.ip); remote_addr = UTI_IPSockAddrToString(&message->remote_addr.ip);
if (message->local_addr.ip.family != IPADDR_UNSPEC) { if (message->local_addr.ip.family != IPADDR_UNSPEC)
local_addr = UTI_IPToString(&message->local_addr.ip); local_addr = UTI_IPToString(&message->local_addr.ip);
}
break; break;
case SCK_ADDR_UNIX: case SCK_ADDR_UNIX:
remote_addr = message->remote_addr.path; remote_addr = message->remote_addr.path;
@ -684,7 +683,7 @@ log_message(int sock_fd, int direction, SCK_Message *message, const char *prefix
snprintf(tslen, sizeof (tslen), " tslen=%d", message->timestamp.l2_length); snprintf(tslen, sizeof (tslen), " tslen=%d", message->timestamp.l2_length);
} }
DEBUG_LOG("%s message%s%s%s%s fd=%d len=%u%s%s%s%s%s%s", DEBUG_LOG("%s message%s%s%s%s fd=%d len=%d%s%s%s%s%s%s",
prefix, prefix,
remote_addr ? (direction > 0 ? " from " : " to ") : "", remote_addr ? (direction > 0 ? " from " : " to ") : "",
remote_addr ? remote_addr : "", remote_addr ? remote_addr : "",
@ -739,7 +738,7 @@ init_message_nonaddress(SCK_Message *message)
/* ================================================== */ /* ================================================== */
static int static int
process_header(struct msghdr *msg, unsigned int msg_length, int sock_fd, int flags, process_header(struct msghdr *msg, int msg_length, int sock_fd, int flags,
SCK_Message *message) SCK_Message *message)
{ {
struct cmsghdr *cmsg; struct cmsghdr *cmsg;
@ -921,7 +920,10 @@ receive_messages(int sock_fd, int flags, int max_messages, int *num_messages)
hdr = ARR_GetElements(recv_headers); hdr = ARR_GetElements(recv_headers);
n = ARR_GetSize(recv_headers); n = ARR_GetSize(recv_headers);
n = MIN(n, max_messages); n = MIN(n, max_messages);
assert(n >= 1);
if (n < 1 || n > MAX_RECV_MESSAGES ||
n > ARR_GetSize(recv_messages) || n > ARR_GetSize(recv_sck_messages))
assert(0);
recv_flags = get_recv_flags(flags); recv_flags = get_recv_flags(flags);
@ -1030,6 +1032,11 @@ send_message(int sock_fd, SCK_Message *message, int flags)
msg.msg_namelen = 0; msg.msg_namelen = 0;
} }
if (message->length < 0) {
DEBUG_LOG("Invalid length %d", message->length);
return 0;
}
iov.iov_base = message->data; iov.iov_base = message->data;
iov.iov_len = message->length; iov.iov_len = message->length;
msg.msg_iov = &iov; msg.msg_iov = &iov;
@ -1404,10 +1411,15 @@ SCK_ShutdownConnection(int sock_fd)
/* ================================================== */ /* ================================================== */
int int
SCK_Receive(int sock_fd, void *buffer, unsigned int length, int flags) SCK_Receive(int sock_fd, void *buffer, int length, int flags)
{ {
int r; int r;
if (length < 0) {
DEBUG_LOG("Invalid length %d", length);
return -1;
}
r = recv(sock_fd, buffer, length, get_recv_flags(flags)); r = recv(sock_fd, buffer, length, get_recv_flags(flags));
if (r < 0) { if (r < 0) {
@ -1423,16 +1435,21 @@ SCK_Receive(int sock_fd, void *buffer, unsigned int length, int flags)
/* ================================================== */ /* ================================================== */
int int
SCK_Send(int sock_fd, const void *buffer, unsigned int length, int flags) SCK_Send(int sock_fd, const void *buffer, int length, int flags)
{ {
int r; int r;
assert(flags == 0); assert(flags == 0);
if (length < 0) {
DEBUG_LOG("Invalid length %d", length);
return -1;
}
r = send(sock_fd, buffer, length, 0); r = send(sock_fd, buffer, length, 0);
if (r < 0) { if (r < 0) {
DEBUG_LOG("Could not send data fd=%d len=%u : %s", sock_fd, length, strerror(errno)); DEBUG_LOG("Could not send data fd=%d len=%d : %s", sock_fd, length, strerror(errno));
return r; return r;
} }

View file

@ -49,7 +49,7 @@ typedef enum {
typedef struct { typedef struct {
void *data; void *data;
unsigned int length; int length;
SCK_AddressType addr_type; SCK_AddressType addr_type;
int if_index; int if_index;
@ -119,8 +119,8 @@ extern int SCK_AcceptConnection(int sock_fd, IPSockAddr *remote_addr);
extern int SCK_ShutdownConnection(int sock_fd); extern int SCK_ShutdownConnection(int sock_fd);
/* Receive and send data on connected sockets - recv()/send() wrappers */ /* Receive and send data on connected sockets - recv()/send() wrappers */
extern int SCK_Receive(int sock_fd, void *buffer, unsigned int length, int flags); extern int SCK_Receive(int sock_fd, void *buffer, int length, int flags);
extern int SCK_Send(int sock_fd, const void *buffer, unsigned int length, int flags); extern int SCK_Send(int sock_fd, const void *buffer, int length, int flags);
/* Receive a single message or multiple messages. The functions return /* Receive a single message or multiple messages. The functions return
a pointer to static buffers, or NULL on error. The buffers are valid until a pointer to static buffers, or NULL on error. The buffers are valid until