diff --git a/socket.c b/socket.c index d9687d0..b43c9b0 100644 --- a/socket.c +++ b/socket.c @@ -168,6 +168,19 @@ check_socket_flag(int sock_flag, int fd_flag, int fs_flag) /* ================================================== */ +static int +set_socket_nonblock(int sock_fd) +{ + if (fcntl(sock_fd, F_SETFL, O_NONBLOCK)) { + DEBUG_LOG("Could not set O_NONBLOCK : %s", strerror(errno)); + return 0; + } + + return 1; +} + +/* ================================================== */ + static int open_socket(int domain, int type, int flags) { @@ -196,8 +209,7 @@ open_socket(int domain, int type, int flags) #ifdef SOCK_NONBLOCK (socket_flags & SOCK_NONBLOCK) == 0 && #endif - fcntl(sock_fd, F_SETFL, O_NONBLOCK)) { - DEBUG_LOG("Could not set O_NONBLOCK : %s", strerror(errno)); + !set_socket_nonblock(sock_fd)) { close(sock_fd); return INVALID_SOCK_FD; } @@ -318,7 +330,7 @@ connect_ip_address(int sock_fd, IPSockAddr *addr) if (saddr_len == 0) return 0; - if (connect(sock_fd, &saddr.sa, saddr_len) < 0) { + if (connect(sock_fd, &saddr.sa, saddr_len) < 0 && errno != EINPROGRESS) { DEBUG_LOG("Could not connect socket to %s : %s", UTI_IPSockAddrToString(addr), strerror(errno)); return 0; @@ -377,8 +389,9 @@ open_ip_socket(IPSockAddr *remote_addr, IPSockAddr *local_addr, int type, int fl goto error; if (remote_addr || local_addr) - DEBUG_LOG("Opened %s socket fd=%d%s%s%s%s", - family == IPADDR_INET4 ? "IPv4" : "IPv6", + DEBUG_LOG("Opened %s%s socket fd=%d%s%s%s%s", + type == SOCK_DGRAM ? "UDP" : type == SOCK_STREAM ? "TCP" : "?", + family == IPADDR_INET4 ? "v4" : "v6", sock_fd, remote_addr ? " remote=" : "", remote_addr ? UTI_IPSockAddrToString(remote_addr) : "", @@ -508,11 +521,8 @@ handle_recv_error(int sock_fd, int flags) be for a socket error. Clear the error to avoid a busy loop. */ if (flags & SCK_FLAG_MSG_ERRQUEUE) { int error = 0; - socklen_t len = sizeof (error); - if (getsockopt(sock_fd, SOL_SOCKET, SO_ERROR, &error, &len)) - DEBUG_LOG("Could not get SO_ERROR"); - if (error) + if (SCK_GetIntOption(sock_fd, SOL_SOCKET, SO_ERROR, &error)) errno = error; } #endif @@ -1077,6 +1087,14 @@ SCK_OpenUdpSocket(IPSockAddr *remote_addr, IPSockAddr *local_addr, int flags) /* ================================================== */ +int +SCK_OpenTcpSocket(IPSockAddr *remote_addr, IPSockAddr *local_addr, int flags) +{ + return open_ip_socket(remote_addr, local_addr, SOCK_STREAM, flags); +} + +/* ================================================== */ + int SCK_OpenUnixDatagramSocket(const char *remote_addr, const char *local_addr, int flags) { @@ -1107,6 +1125,22 @@ SCK_SetIntOption(int sock_fd, int level, int name, int value) /* ================================================== */ +int +SCK_GetIntOption(int sock_fd, int level, int name, int *value) +{ + socklen_t len = sizeof (*value); + + if (getsockopt(sock_fd, level, name, value, &len) < 0) { + DEBUG_LOG("getsockopt() failed fd=%d level=%d name=%d : %s", + sock_fd, level, name, strerror(errno)); + return 0; + } + + return 1; +} + +/* ================================================== */ + int SCK_EnableKernelRxTimestamping(int sock_fd) { @@ -1124,6 +1158,57 @@ SCK_EnableKernelRxTimestamping(int sock_fd) /* ================================================== */ +int +SCK_ListenOnSocket(int sock_fd, int backlog) +{ + if (listen(sock_fd, backlog) < 0) { + DEBUG_LOG("listen() failed : %s", strerror(errno)); + return 0; + } + + return 1; +} + +/* ================================================== */ + +int +SCK_AcceptConnection(int sock_fd, IPSockAddr *remote_addr) +{ + union sockaddr_all saddr; + socklen_t saddr_len = sizeof (saddr); + int conn_fd; + + conn_fd = accept(sock_fd, &saddr.sa, &saddr_len); + if (conn_fd < 0) { + DEBUG_LOG("accept() failed : %s", strerror(errno)); + return INVALID_SOCK_FD; + } + + if (!UTI_FdSetCloexec(conn_fd) || !set_socket_nonblock(conn_fd)) { + close(conn_fd); + return INVALID_SOCK_FD; + } + + SCK_SockaddrToIPSockAddr(&saddr.sa, saddr_len, remote_addr); + + return conn_fd; +} + +/* ================================================== */ + +int +SCK_ShutdownConnection(int sock_fd) +{ + if (shutdown(sock_fd, SHUT_RDWR) < 0) { + DEBUG_LOG("shutdown() failed : %s", strerror(errno)); + return 0; + } + + return 1; +} + +/* ================================================== */ + int SCK_Receive(int sock_fd, void *buffer, unsigned int length, int flags) { diff --git a/socket.h b/socket.h index b394918..40a8fcb 100644 --- a/socket.h +++ b/socket.h @@ -95,11 +95,18 @@ extern int SCK_OpenUnixDatagramSocket(const char *remote_addr, const char *local extern int SCK_OpenUnixStreamSocket(const char *remote_addr, const char *local_addr, int flags); -/* Set a socket option */ -extern int SCK_SetIntOption(int sock_fd, int level, int option, int value); +/* Set and get a socket option of int size */ +extern int SCK_SetIntOption(int sock_fd, int level, int name, int value); +extern int SCK_GetIntOption(int sock_fd, int level, int name, int *value); + /* Enable RX timestamping socket option */ extern int SCK_EnableKernelRxTimestamping(int sock_fd); +/* Operate on a stream socket - listen()/accept()/shutdown() wrappers */ +extern int SCK_ListenOnSocket(int sock_fd, int backlog); +extern int SCK_AcceptConnection(int sock_fd, IPSockAddr *remote_addr); +extern int SCK_ShutdownConnection(int sock_fd); + /* Receive and send data on connected sockets - recv()/send() wrappers */ extern int SCK_Receive(int sock_fd, void *buffer, unsigned int length, int flags); extern int SCK_Send(int sock_fd, void *buffer, unsigned int length, int flags); diff --git a/sys_linux.c b/sys_linux.c index 898dc7a..1f36696 100644 --- a/sys_linux.c +++ b/sys_linux.c @@ -496,9 +496,10 @@ SYS_Linux_EnableSystemCallFilter(int level) SCMP_SYS(stat), SCMP_SYS(stat64), SCMP_SYS(statfs), SCMP_SYS(statfs64), SCMP_SYS(unlink), SCMP_SYS(unlinkat), /* Socket */ - SCMP_SYS(bind), SCMP_SYS(connect), SCMP_SYS(getsockname), SCMP_SYS(getsockopt), - SCMP_SYS(recv), SCMP_SYS(recvfrom), SCMP_SYS(recvmmsg), SCMP_SYS(recvmsg), - SCMP_SYS(send), SCMP_SYS(sendmmsg), SCMP_SYS(sendmsg), SCMP_SYS(sendto), + SCMP_SYS(accept), SCMP_SYS(bind), SCMP_SYS(connect), SCMP_SYS(getsockname), + SCMP_SYS(getsockopt), SCMP_SYS(recv), SCMP_SYS(recvfrom), + SCMP_SYS(recvmmsg), SCMP_SYS(recvmsg), SCMP_SYS(send), SCMP_SYS(sendmmsg), + SCMP_SYS(sendmsg), SCMP_SYS(sendto), SCMP_SYS(shutdown), /* TODO: check socketcall arguments */ SCMP_SYS(socketcall), /* General I/O */