From a3436c26f0e8e2c679bc045e35d8e3905de39456 Mon Sep 17 00:00:00 2001 From: Miroslav Lichvar Date: Tue, 7 Jul 2020 12:34:29 +0200 Subject: [PATCH] nts: improve session code Add more comments and assertions, replace getsockopt() call with SCK_GetIntOption(), replace strncmp() with memcmp(), move a return statement for clarity, and remove an unused field from the instance record. --- nts_ke_session.c | 52 +++++++++++++++++++++--------------------------- 1 file changed, 23 insertions(+), 29 deletions(-) diff --git a/nts_ke_session.c b/nts_ke_session.c index 05ca99f..a686db2 100644 --- a/nts_ke_session.c +++ b/nts_ke_session.c @@ -81,7 +81,6 @@ struct NKSN_Instance_Record { struct Message message; int new_message; - int ended_message; }; /* ================================================== */ @@ -110,6 +109,8 @@ add_record(struct Message *message, int critical, int type, const void *body, in { struct RecordHeader header; + assert(message->length <= sizeof (message->data)); + if (body_length < 0 || body_length > 0xffff || type < 0 || type > 0x7fff || message->length + sizeof (header) + body_length > sizeof (message->data)) return 0; @@ -301,31 +302,14 @@ session_timeout(void *arg) /* ================================================== */ -static int -get_socket_error(int sock_fd) -{ - int optval; - socklen_t optlen = sizeof (optval); - - if (getsockopt(sock_fd, SOL_SOCKET, SO_ERROR, &optval, &optlen) < 0) { - DEBUG_LOG("getsockopt() failed : %s", strerror(errno)); - return EINVAL; - } - - return optval; -} - -/* ================================================== */ - static int check_alpn(NKSN_Instance inst) { gnutls_datum_t alpn; - int r; - r = gnutls_alpn_get_selected_protocol(inst->tls_session, &alpn); - if (r < 0 || alpn.size != sizeof (NKE_ALPN_NAME) - 1 || - strncmp((const char *)alpn.data, NKE_ALPN_NAME, sizeof (NKE_ALPN_NAME) - 1)) + if (gnutls_alpn_get_selected_protocol(inst->tls_session, &alpn) < 0 || + alpn.size != sizeof (NKE_ALPN_NAME) - 1 || + memcmp(alpn.data, NKE_ALPN_NAME, sizeof (NKE_ALPN_NAME) - 1) != 0) return 0; return 1; @@ -375,9 +359,11 @@ handle_event(NKSN_Instance inst, int event) if (event != SCH_FILE_OUTPUT) return 0; - r = get_socket_error(inst->sock_fd); + /* Get the socket error */ + if (!SCK_GetIntOption(inst->sock_fd, SOL_SOCKET, SO_ERROR, &r)) + r = EINVAL; - if (r) { + if (r != 0) { LOG(LOGS_ERR, "Could not connect to %s : %s", inst->label, strerror(r)); stop_session(inst); return 0; @@ -446,6 +432,7 @@ handle_event(NKSN_Instance inst, int event) case KE_SEND: assert(inst->new_message && message->complete); + assert(message->length <= sizeof (message->data) && message->length > message->sent); r = gnutls_record_send(inst->tls_session, &message->data[message->sent], message->length - message->sent); @@ -513,7 +500,9 @@ handle_event(NKSN_Instance inst, int event) /* Server will send a response to the client */ change_state(inst, inst->server ? KE_SEND : KE_SHUTDOWN); - break; + + /* Return success to process the received message */ + return 1; case KE_SHUTDOWN: r = gnutls_bye(inst->tls_session, GNUTLS_SHUT_RDWR); @@ -539,9 +528,8 @@ handle_event(NKSN_Instance inst, int event) default: assert(0); + return 0; } - - return 1; } /* ================================================== */ @@ -554,6 +542,9 @@ read_write_socket(int fd, int event, void *arg) if (!handle_event(inst, event)) return; + /* A valid message was received. Call the handler to process the message, + and prepare a response if it is a server. */ + reset_message_parsing(&inst->message); if (!(inst->handler)(inst->handler_arg)) { @@ -602,13 +593,15 @@ init_gnutls(void) if (r < 0) LOG_FATAL("Could not initialise %s : %s", "gnutls", gnutls_strerror(r)); - /* NTS specification requires TLS1.3 or later */ + /* Prepare a priority cache for server and client NTS-KE sessions + (the NTS specification requires TLS1.3 or later) */ r = gnutls_priority_init2(&priority_cache, "-VERS-SSL3.0:-VERS-TLS1.0:-VERS-TLS1.1:-VERS-TLS1.2", NULL, GNUTLS_PRIORITY_INIT_DEF_APPEND); if (r < 0) LOG_FATAL("Could not initialise %s : %s", "priority cache", gnutls_strerror(r)); + /* Use our clock instead of the system clock in certificate verification */ gnutls_global_set_time_function(get_time); gnutls_initialised = 1; @@ -704,7 +697,7 @@ NKSN_CreateInstance(int server_mode, const char *server_name, inst->server_name = server_name ? Strdup(server_name) : NULL; inst->handler = handler; inst->handler_arg = handler_arg; - /* Replace NULL arg with the session itself */ + /* Replace a NULL argument with the session itself */ if (!inst->handler_arg) inst->handler_arg = inst; @@ -751,7 +744,6 @@ NKSN_StartSession(NKSN_Instance inst, int sock_fd, const char *label, reset_message(&inst->message); inst->new_message = 0; - inst->ended_message = 0; change_state(inst, inst->server ? KE_HANDSHAKE : KE_WAIT_CONNECT); @@ -785,6 +777,7 @@ NKSN_EndMessage(NKSN_Instance inst) { assert(!inst->message.complete); + /* Terminate the message */ if (!add_record(&inst->message, 1, NKE_RECORD_END_OF_MESSAGE, NULL, 0)) return 0; @@ -806,6 +799,7 @@ NKSN_GetRecord(NKSN_Instance inst, int *critical, int *type, int *body_length, if (!get_record(&inst->message, critical, &type2, body_length, body, buffer_length)) return 0; + /* Hide the end-of-message record */ if (type2 == NKE_RECORD_END_OF_MESSAGE) return 0;