nts: improve session code

Add more comments and assertions, replace getsockopt() call with
SCK_GetIntOption(), replace strncmp() with memcmp(), move a return
statement for clarity, and remove an unused field from the instance
record.
This commit is contained in:
Miroslav Lichvar 2020-07-07 12:34:29 +02:00
parent b0f5024d56
commit a3436c26f0

View file

@ -81,7 +81,6 @@ struct NKSN_Instance_Record {
struct Message message;
int new_message;
int ended_message;
};
/* ================================================== */
@ -110,6 +109,8 @@ add_record(struct Message *message, int critical, int type, const void *body, in
{
struct RecordHeader header;
assert(message->length <= sizeof (message->data));
if (body_length < 0 || body_length > 0xffff || type < 0 || type > 0x7fff ||
message->length + sizeof (header) + body_length > sizeof (message->data))
return 0;
@ -301,31 +302,14 @@ session_timeout(void *arg)
/* ================================================== */
static int
get_socket_error(int sock_fd)
{
int optval;
socklen_t optlen = sizeof (optval);
if (getsockopt(sock_fd, SOL_SOCKET, SO_ERROR, &optval, &optlen) < 0) {
DEBUG_LOG("getsockopt() failed : %s", strerror(errno));
return EINVAL;
}
return optval;
}
/* ================================================== */
static int
check_alpn(NKSN_Instance inst)
{
gnutls_datum_t alpn;
int r;
r = gnutls_alpn_get_selected_protocol(inst->tls_session, &alpn);
if (r < 0 || alpn.size != sizeof (NKE_ALPN_NAME) - 1 ||
strncmp((const char *)alpn.data, NKE_ALPN_NAME, sizeof (NKE_ALPN_NAME) - 1))
if (gnutls_alpn_get_selected_protocol(inst->tls_session, &alpn) < 0 ||
alpn.size != sizeof (NKE_ALPN_NAME) - 1 ||
memcmp(alpn.data, NKE_ALPN_NAME, sizeof (NKE_ALPN_NAME) - 1) != 0)
return 0;
return 1;
@ -375,9 +359,11 @@ handle_event(NKSN_Instance inst, int event)
if (event != SCH_FILE_OUTPUT)
return 0;
r = get_socket_error(inst->sock_fd);
/* Get the socket error */
if (!SCK_GetIntOption(inst->sock_fd, SOL_SOCKET, SO_ERROR, &r))
r = EINVAL;
if (r) {
if (r != 0) {
LOG(LOGS_ERR, "Could not connect to %s : %s", inst->label, strerror(r));
stop_session(inst);
return 0;
@ -446,6 +432,7 @@ handle_event(NKSN_Instance inst, int event)
case KE_SEND:
assert(inst->new_message && message->complete);
assert(message->length <= sizeof (message->data) && message->length > message->sent);
r = gnutls_record_send(inst->tls_session, &message->data[message->sent],
message->length - message->sent);
@ -513,7 +500,9 @@ handle_event(NKSN_Instance inst, int event)
/* Server will send a response to the client */
change_state(inst, inst->server ? KE_SEND : KE_SHUTDOWN);
break;
/* Return success to process the received message */
return 1;
case KE_SHUTDOWN:
r = gnutls_bye(inst->tls_session, GNUTLS_SHUT_RDWR);
@ -539,9 +528,8 @@ handle_event(NKSN_Instance inst, int event)
default:
assert(0);
return 0;
}
return 1;
}
/* ================================================== */
@ -554,6 +542,9 @@ read_write_socket(int fd, int event, void *arg)
if (!handle_event(inst, event))
return;
/* A valid message was received. Call the handler to process the message,
and prepare a response if it is a server. */
reset_message_parsing(&inst->message);
if (!(inst->handler)(inst->handler_arg)) {
@ -602,13 +593,15 @@ init_gnutls(void)
if (r < 0)
LOG_FATAL("Could not initialise %s : %s", "gnutls", gnutls_strerror(r));
/* NTS specification requires TLS1.3 or later */
/* Prepare a priority cache for server and client NTS-KE sessions
(the NTS specification requires TLS1.3 or later) */
r = gnutls_priority_init2(&priority_cache,
"-VERS-SSL3.0:-VERS-TLS1.0:-VERS-TLS1.1:-VERS-TLS1.2",
NULL, GNUTLS_PRIORITY_INIT_DEF_APPEND);
if (r < 0)
LOG_FATAL("Could not initialise %s : %s", "priority cache", gnutls_strerror(r));
/* Use our clock instead of the system clock in certificate verification */
gnutls_global_set_time_function(get_time);
gnutls_initialised = 1;
@ -704,7 +697,7 @@ NKSN_CreateInstance(int server_mode, const char *server_name,
inst->server_name = server_name ? Strdup(server_name) : NULL;
inst->handler = handler;
inst->handler_arg = handler_arg;
/* Replace NULL arg with the session itself */
/* Replace a NULL argument with the session itself */
if (!inst->handler_arg)
inst->handler_arg = inst;
@ -751,7 +744,6 @@ NKSN_StartSession(NKSN_Instance inst, int sock_fd, const char *label,
reset_message(&inst->message);
inst->new_message = 0;
inst->ended_message = 0;
change_state(inst, inst->server ? KE_HANDSHAKE : KE_WAIT_CONNECT);
@ -785,6 +777,7 @@ NKSN_EndMessage(NKSN_Instance inst)
{
assert(!inst->message.complete);
/* Terminate the message */
if (!add_record(&inst->message, 1, NKE_RECORD_END_OF_MESSAGE, NULL, 0))
return 0;
@ -806,6 +799,7 @@ NKSN_GetRecord(NKSN_Instance inst, int *critical, int *type, int *body_length,
if (!get_record(&inst->message, critical, &type2, body_length, body, buffer_length))
return 0;
/* Hide the end-of-message record */
if (type2 == NKE_RECORD_END_OF_MESSAGE)
return 0;