nts: improve session code
Add more comments and assertions, replace getsockopt() call with SCK_GetIntOption(), replace strncmp() with memcmp(), move a return statement for clarity, and remove an unused field from the instance record.
This commit is contained in:
parent
b0f5024d56
commit
a3436c26f0
1 changed files with 23 additions and 29 deletions
|
@ -81,7 +81,6 @@ struct NKSN_Instance_Record {
|
||||||
|
|
||||||
struct Message message;
|
struct Message message;
|
||||||
int new_message;
|
int new_message;
|
||||||
int ended_message;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/* ================================================== */
|
/* ================================================== */
|
||||||
|
@ -110,6 +109,8 @@ add_record(struct Message *message, int critical, int type, const void *body, in
|
||||||
{
|
{
|
||||||
struct RecordHeader header;
|
struct RecordHeader header;
|
||||||
|
|
||||||
|
assert(message->length <= sizeof (message->data));
|
||||||
|
|
||||||
if (body_length < 0 || body_length > 0xffff || type < 0 || type > 0x7fff ||
|
if (body_length < 0 || body_length > 0xffff || type < 0 || type > 0x7fff ||
|
||||||
message->length + sizeof (header) + body_length > sizeof (message->data))
|
message->length + sizeof (header) + body_length > sizeof (message->data))
|
||||||
return 0;
|
return 0;
|
||||||
|
@ -301,31 +302,14 @@ session_timeout(void *arg)
|
||||||
|
|
||||||
/* ================================================== */
|
/* ================================================== */
|
||||||
|
|
||||||
static int
|
|
||||||
get_socket_error(int sock_fd)
|
|
||||||
{
|
|
||||||
int optval;
|
|
||||||
socklen_t optlen = sizeof (optval);
|
|
||||||
|
|
||||||
if (getsockopt(sock_fd, SOL_SOCKET, SO_ERROR, &optval, &optlen) < 0) {
|
|
||||||
DEBUG_LOG("getsockopt() failed : %s", strerror(errno));
|
|
||||||
return EINVAL;
|
|
||||||
}
|
|
||||||
|
|
||||||
return optval;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ================================================== */
|
|
||||||
|
|
||||||
static int
|
static int
|
||||||
check_alpn(NKSN_Instance inst)
|
check_alpn(NKSN_Instance inst)
|
||||||
{
|
{
|
||||||
gnutls_datum_t alpn;
|
gnutls_datum_t alpn;
|
||||||
int r;
|
|
||||||
|
|
||||||
r = gnutls_alpn_get_selected_protocol(inst->tls_session, &alpn);
|
if (gnutls_alpn_get_selected_protocol(inst->tls_session, &alpn) < 0 ||
|
||||||
if (r < 0 || alpn.size != sizeof (NKE_ALPN_NAME) - 1 ||
|
alpn.size != sizeof (NKE_ALPN_NAME) - 1 ||
|
||||||
strncmp((const char *)alpn.data, NKE_ALPN_NAME, sizeof (NKE_ALPN_NAME) - 1))
|
memcmp(alpn.data, NKE_ALPN_NAME, sizeof (NKE_ALPN_NAME) - 1) != 0)
|
||||||
return 0;
|
return 0;
|
||||||
|
|
||||||
return 1;
|
return 1;
|
||||||
|
@ -375,9 +359,11 @@ handle_event(NKSN_Instance inst, int event)
|
||||||
if (event != SCH_FILE_OUTPUT)
|
if (event != SCH_FILE_OUTPUT)
|
||||||
return 0;
|
return 0;
|
||||||
|
|
||||||
r = get_socket_error(inst->sock_fd);
|
/* Get the socket error */
|
||||||
|
if (!SCK_GetIntOption(inst->sock_fd, SOL_SOCKET, SO_ERROR, &r))
|
||||||
|
r = EINVAL;
|
||||||
|
|
||||||
if (r) {
|
if (r != 0) {
|
||||||
LOG(LOGS_ERR, "Could not connect to %s : %s", inst->label, strerror(r));
|
LOG(LOGS_ERR, "Could not connect to %s : %s", inst->label, strerror(r));
|
||||||
stop_session(inst);
|
stop_session(inst);
|
||||||
return 0;
|
return 0;
|
||||||
|
@ -446,6 +432,7 @@ handle_event(NKSN_Instance inst, int event)
|
||||||
|
|
||||||
case KE_SEND:
|
case KE_SEND:
|
||||||
assert(inst->new_message && message->complete);
|
assert(inst->new_message && message->complete);
|
||||||
|
assert(message->length <= sizeof (message->data) && message->length > message->sent);
|
||||||
|
|
||||||
r = gnutls_record_send(inst->tls_session, &message->data[message->sent],
|
r = gnutls_record_send(inst->tls_session, &message->data[message->sent],
|
||||||
message->length - message->sent);
|
message->length - message->sent);
|
||||||
|
@ -513,7 +500,9 @@ handle_event(NKSN_Instance inst, int event)
|
||||||
|
|
||||||
/* Server will send a response to the client */
|
/* Server will send a response to the client */
|
||||||
change_state(inst, inst->server ? KE_SEND : KE_SHUTDOWN);
|
change_state(inst, inst->server ? KE_SEND : KE_SHUTDOWN);
|
||||||
break;
|
|
||||||
|
/* Return success to process the received message */
|
||||||
|
return 1;
|
||||||
|
|
||||||
case KE_SHUTDOWN:
|
case KE_SHUTDOWN:
|
||||||
r = gnutls_bye(inst->tls_session, GNUTLS_SHUT_RDWR);
|
r = gnutls_bye(inst->tls_session, GNUTLS_SHUT_RDWR);
|
||||||
|
@ -539,9 +528,8 @@ handle_event(NKSN_Instance inst, int event)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
assert(0);
|
assert(0);
|
||||||
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
return 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ================================================== */
|
/* ================================================== */
|
||||||
|
@ -554,6 +542,9 @@ read_write_socket(int fd, int event, void *arg)
|
||||||
if (!handle_event(inst, event))
|
if (!handle_event(inst, event))
|
||||||
return;
|
return;
|
||||||
|
|
||||||
|
/* A valid message was received. Call the handler to process the message,
|
||||||
|
and prepare a response if it is a server. */
|
||||||
|
|
||||||
reset_message_parsing(&inst->message);
|
reset_message_parsing(&inst->message);
|
||||||
|
|
||||||
if (!(inst->handler)(inst->handler_arg)) {
|
if (!(inst->handler)(inst->handler_arg)) {
|
||||||
|
@ -602,13 +593,15 @@ init_gnutls(void)
|
||||||
if (r < 0)
|
if (r < 0)
|
||||||
LOG_FATAL("Could not initialise %s : %s", "gnutls", gnutls_strerror(r));
|
LOG_FATAL("Could not initialise %s : %s", "gnutls", gnutls_strerror(r));
|
||||||
|
|
||||||
/* NTS specification requires TLS1.3 or later */
|
/* Prepare a priority cache for server and client NTS-KE sessions
|
||||||
|
(the NTS specification requires TLS1.3 or later) */
|
||||||
r = gnutls_priority_init2(&priority_cache,
|
r = gnutls_priority_init2(&priority_cache,
|
||||||
"-VERS-SSL3.0:-VERS-TLS1.0:-VERS-TLS1.1:-VERS-TLS1.2",
|
"-VERS-SSL3.0:-VERS-TLS1.0:-VERS-TLS1.1:-VERS-TLS1.2",
|
||||||
NULL, GNUTLS_PRIORITY_INIT_DEF_APPEND);
|
NULL, GNUTLS_PRIORITY_INIT_DEF_APPEND);
|
||||||
if (r < 0)
|
if (r < 0)
|
||||||
LOG_FATAL("Could not initialise %s : %s", "priority cache", gnutls_strerror(r));
|
LOG_FATAL("Could not initialise %s : %s", "priority cache", gnutls_strerror(r));
|
||||||
|
|
||||||
|
/* Use our clock instead of the system clock in certificate verification */
|
||||||
gnutls_global_set_time_function(get_time);
|
gnutls_global_set_time_function(get_time);
|
||||||
|
|
||||||
gnutls_initialised = 1;
|
gnutls_initialised = 1;
|
||||||
|
@ -704,7 +697,7 @@ NKSN_CreateInstance(int server_mode, const char *server_name,
|
||||||
inst->server_name = server_name ? Strdup(server_name) : NULL;
|
inst->server_name = server_name ? Strdup(server_name) : NULL;
|
||||||
inst->handler = handler;
|
inst->handler = handler;
|
||||||
inst->handler_arg = handler_arg;
|
inst->handler_arg = handler_arg;
|
||||||
/* Replace NULL arg with the session itself */
|
/* Replace a NULL argument with the session itself */
|
||||||
if (!inst->handler_arg)
|
if (!inst->handler_arg)
|
||||||
inst->handler_arg = inst;
|
inst->handler_arg = inst;
|
||||||
|
|
||||||
|
@ -751,7 +744,6 @@ NKSN_StartSession(NKSN_Instance inst, int sock_fd, const char *label,
|
||||||
|
|
||||||
reset_message(&inst->message);
|
reset_message(&inst->message);
|
||||||
inst->new_message = 0;
|
inst->new_message = 0;
|
||||||
inst->ended_message = 0;
|
|
||||||
|
|
||||||
change_state(inst, inst->server ? KE_HANDSHAKE : KE_WAIT_CONNECT);
|
change_state(inst, inst->server ? KE_HANDSHAKE : KE_WAIT_CONNECT);
|
||||||
|
|
||||||
|
@ -785,6 +777,7 @@ NKSN_EndMessage(NKSN_Instance inst)
|
||||||
{
|
{
|
||||||
assert(!inst->message.complete);
|
assert(!inst->message.complete);
|
||||||
|
|
||||||
|
/* Terminate the message */
|
||||||
if (!add_record(&inst->message, 1, NKE_RECORD_END_OF_MESSAGE, NULL, 0))
|
if (!add_record(&inst->message, 1, NKE_RECORD_END_OF_MESSAGE, NULL, 0))
|
||||||
return 0;
|
return 0;
|
||||||
|
|
||||||
|
@ -806,6 +799,7 @@ NKSN_GetRecord(NKSN_Instance inst, int *critical, int *type, int *body_length,
|
||||||
if (!get_record(&inst->message, critical, &type2, body_length, body, buffer_length))
|
if (!get_record(&inst->message, critical, &type2, body_length, body, buffer_length))
|
||||||
return 0;
|
return 0;
|
||||||
|
|
||||||
|
/* Hide the end-of-message record */
|
||||||
if (type2 == NKE_RECORD_END_OF_MESSAGE)
|
if (type2 == NKE_RECORD_END_OF_MESSAGE)
|
||||||
return 0;
|
return 0;
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue