nts: improve session code

Add more comments and assertions, replace getsockopt() call with
SCK_GetIntOption(), replace strncmp() with memcmp(), move a return
statement for clarity, and remove an unused field from the instance
record.
This commit is contained in:
Miroslav Lichvar 2020-07-07 12:34:29 +02:00
parent b0f5024d56
commit a3436c26f0

View file

@ -81,7 +81,6 @@ struct NKSN_Instance_Record {
struct Message message; struct Message message;
int new_message; int new_message;
int ended_message;
}; };
/* ================================================== */ /* ================================================== */
@ -110,6 +109,8 @@ add_record(struct Message *message, int critical, int type, const void *body, in
{ {
struct RecordHeader header; struct RecordHeader header;
assert(message->length <= sizeof (message->data));
if (body_length < 0 || body_length > 0xffff || type < 0 || type > 0x7fff || if (body_length < 0 || body_length > 0xffff || type < 0 || type > 0x7fff ||
message->length + sizeof (header) + body_length > sizeof (message->data)) message->length + sizeof (header) + body_length > sizeof (message->data))
return 0; return 0;
@ -301,31 +302,14 @@ session_timeout(void *arg)
/* ================================================== */ /* ================================================== */
static int
get_socket_error(int sock_fd)
{
int optval;
socklen_t optlen = sizeof (optval);
if (getsockopt(sock_fd, SOL_SOCKET, SO_ERROR, &optval, &optlen) < 0) {
DEBUG_LOG("getsockopt() failed : %s", strerror(errno));
return EINVAL;
}
return optval;
}
/* ================================================== */
static int static int
check_alpn(NKSN_Instance inst) check_alpn(NKSN_Instance inst)
{ {
gnutls_datum_t alpn; gnutls_datum_t alpn;
int r;
r = gnutls_alpn_get_selected_protocol(inst->tls_session, &alpn); if (gnutls_alpn_get_selected_protocol(inst->tls_session, &alpn) < 0 ||
if (r < 0 || alpn.size != sizeof (NKE_ALPN_NAME) - 1 || alpn.size != sizeof (NKE_ALPN_NAME) - 1 ||
strncmp((const char *)alpn.data, NKE_ALPN_NAME, sizeof (NKE_ALPN_NAME) - 1)) memcmp(alpn.data, NKE_ALPN_NAME, sizeof (NKE_ALPN_NAME) - 1) != 0)
return 0; return 0;
return 1; return 1;
@ -375,9 +359,11 @@ handle_event(NKSN_Instance inst, int event)
if (event != SCH_FILE_OUTPUT) if (event != SCH_FILE_OUTPUT)
return 0; return 0;
r = get_socket_error(inst->sock_fd); /* Get the socket error */
if (!SCK_GetIntOption(inst->sock_fd, SOL_SOCKET, SO_ERROR, &r))
r = EINVAL;
if (r) { if (r != 0) {
LOG(LOGS_ERR, "Could not connect to %s : %s", inst->label, strerror(r)); LOG(LOGS_ERR, "Could not connect to %s : %s", inst->label, strerror(r));
stop_session(inst); stop_session(inst);
return 0; return 0;
@ -446,6 +432,7 @@ handle_event(NKSN_Instance inst, int event)
case KE_SEND: case KE_SEND:
assert(inst->new_message && message->complete); assert(inst->new_message && message->complete);
assert(message->length <= sizeof (message->data) && message->length > message->sent);
r = gnutls_record_send(inst->tls_session, &message->data[message->sent], r = gnutls_record_send(inst->tls_session, &message->data[message->sent],
message->length - message->sent); message->length - message->sent);
@ -513,7 +500,9 @@ handle_event(NKSN_Instance inst, int event)
/* Server will send a response to the client */ /* Server will send a response to the client */
change_state(inst, inst->server ? KE_SEND : KE_SHUTDOWN); change_state(inst, inst->server ? KE_SEND : KE_SHUTDOWN);
break;
/* Return success to process the received message */
return 1;
case KE_SHUTDOWN: case KE_SHUTDOWN:
r = gnutls_bye(inst->tls_session, GNUTLS_SHUT_RDWR); r = gnutls_bye(inst->tls_session, GNUTLS_SHUT_RDWR);
@ -539,9 +528,8 @@ handle_event(NKSN_Instance inst, int event)
default: default:
assert(0); assert(0);
return 0;
} }
return 1;
} }
/* ================================================== */ /* ================================================== */
@ -554,6 +542,9 @@ read_write_socket(int fd, int event, void *arg)
if (!handle_event(inst, event)) if (!handle_event(inst, event))
return; return;
/* A valid message was received. Call the handler to process the message,
and prepare a response if it is a server. */
reset_message_parsing(&inst->message); reset_message_parsing(&inst->message);
if (!(inst->handler)(inst->handler_arg)) { if (!(inst->handler)(inst->handler_arg)) {
@ -602,13 +593,15 @@ init_gnutls(void)
if (r < 0) if (r < 0)
LOG_FATAL("Could not initialise %s : %s", "gnutls", gnutls_strerror(r)); LOG_FATAL("Could not initialise %s : %s", "gnutls", gnutls_strerror(r));
/* NTS specification requires TLS1.3 or later */ /* Prepare a priority cache for server and client NTS-KE sessions
(the NTS specification requires TLS1.3 or later) */
r = gnutls_priority_init2(&priority_cache, r = gnutls_priority_init2(&priority_cache,
"-VERS-SSL3.0:-VERS-TLS1.0:-VERS-TLS1.1:-VERS-TLS1.2", "-VERS-SSL3.0:-VERS-TLS1.0:-VERS-TLS1.1:-VERS-TLS1.2",
NULL, GNUTLS_PRIORITY_INIT_DEF_APPEND); NULL, GNUTLS_PRIORITY_INIT_DEF_APPEND);
if (r < 0) if (r < 0)
LOG_FATAL("Could not initialise %s : %s", "priority cache", gnutls_strerror(r)); LOG_FATAL("Could not initialise %s : %s", "priority cache", gnutls_strerror(r));
/* Use our clock instead of the system clock in certificate verification */
gnutls_global_set_time_function(get_time); gnutls_global_set_time_function(get_time);
gnutls_initialised = 1; gnutls_initialised = 1;
@ -704,7 +697,7 @@ NKSN_CreateInstance(int server_mode, const char *server_name,
inst->server_name = server_name ? Strdup(server_name) : NULL; inst->server_name = server_name ? Strdup(server_name) : NULL;
inst->handler = handler; inst->handler = handler;
inst->handler_arg = handler_arg; inst->handler_arg = handler_arg;
/* Replace NULL arg with the session itself */ /* Replace a NULL argument with the session itself */
if (!inst->handler_arg) if (!inst->handler_arg)
inst->handler_arg = inst; inst->handler_arg = inst;
@ -751,7 +744,6 @@ NKSN_StartSession(NKSN_Instance inst, int sock_fd, const char *label,
reset_message(&inst->message); reset_message(&inst->message);
inst->new_message = 0; inst->new_message = 0;
inst->ended_message = 0;
change_state(inst, inst->server ? KE_HANDSHAKE : KE_WAIT_CONNECT); change_state(inst, inst->server ? KE_HANDSHAKE : KE_WAIT_CONNECT);
@ -785,6 +777,7 @@ NKSN_EndMessage(NKSN_Instance inst)
{ {
assert(!inst->message.complete); assert(!inst->message.complete);
/* Terminate the message */
if (!add_record(&inst->message, 1, NKE_RECORD_END_OF_MESSAGE, NULL, 0)) if (!add_record(&inst->message, 1, NKE_RECORD_END_OF_MESSAGE, NULL, 0))
return 0; return 0;
@ -806,6 +799,7 @@ NKSN_GetRecord(NKSN_Instance inst, int *critical, int *type, int *body_length,
if (!get_record(&inst->message, critical, &type2, body_length, body, buffer_length)) if (!get_record(&inst->message, critical, &type2, body_length, body, buffer_length))
return 0; return 0;
/* Hide the end-of-message record */
if (type2 == NKE_RECORD_END_OF_MESSAGE) if (type2 == NKE_RECORD_END_OF_MESSAGE)
return 0; return 0;