diff --git a/nts_ke_client.c b/nts_ke_client.c index 090d713..c9592f0 100644 --- a/nts_ke_client.c +++ b/nts_ke_client.c @@ -339,7 +339,8 @@ NKC_Start(NKC_Instance inst) return 0; /* Start a NTS-KE session */ - if (!NKSN_StartSession(inst->session, sock_fd, client_credentials, CLIENT_TIMEOUT)) { + if (!NKSN_StartSession(inst->session, sock_fd, inst->name, + client_credentials, CLIENT_TIMEOUT)) { SCK_CloseSocket(sock_fd); return 0; } diff --git a/nts_ke_server.c b/nts_ke_server.c index a3fa9f9..3093ee0 100644 --- a/nts_ke_server.c +++ b/nts_ke_server.c @@ -113,7 +113,7 @@ handle_client(int sock_fd, IPSockAddr *addr) instp = ARR_GetElement(sessions, i); if (!*instp) { /* NULL handler arg will be replaced with the session instance */ - inst = NKSN_CreateInstance(1, UTI_IPSockAddrToString(addr), handle_message, NULL); + inst = NKSN_CreateInstance(1, NULL, handle_message, NULL); *instp = inst; break; } else if (NKSN_IsStopped(*instp)) { @@ -128,7 +128,8 @@ handle_client(int sock_fd, IPSockAddr *addr) return 0; } - if (!NKSN_StartSession(inst, sock_fd, server_credentials, SERVER_TIMEOUT)) + if (!NKSN_StartSession(inst, sock_fd, UTI_IPSockAddrToString(addr), + server_credentials, SERVER_TIMEOUT)) return 0; return 1; diff --git a/nts_ke_session.c b/nts_ke_session.c index d9b0d4b..224355f 100644 --- a/nts_ke_session.c +++ b/nts_ke_session.c @@ -66,12 +66,13 @@ typedef enum { struct NKSN_Instance_Record { int server; - char *name; + char *server_name; NKSN_MessageHandler handler; void *handler_arg; KeState state; int sock_fd; + char *label; gnutls_session_t tls_session; SCH_TimeoutID timeout_id; @@ -261,6 +262,9 @@ stop_session(NKSN_Instance inst) SCK_CloseSocket(inst->sock_fd); inst->sock_fd = INVALID_SOCK_FD; + Free(inst->label); + inst->label = NULL; + gnutls_deinit(inst->tls_session); inst->tls_session = NULL; @@ -275,7 +279,7 @@ session_timeout(void *arg) { NKSN_Instance inst = arg; - LOG(inst->server ? LOGS_DEBUG : LOGS_ERR, "NTS-KE session with %s timed out", inst->name); + LOG(inst->server ? LOGS_DEBUG : LOGS_ERR, "NTS-KE session with %s timed out", inst->label); inst->timeout_id = 0; stop_session(inst); @@ -360,12 +364,12 @@ handle_event(NKSN_Instance inst, int event) r = get_socket_error(inst->sock_fd); if (r) { - LOG(LOGS_ERR, "Could not connect to %s : %s", inst->name, strerror(r)); + LOG(LOGS_ERR, "Could not connect to %s : %s", inst->label, strerror(r)); stop_session(inst); return 0; } - DEBUG_LOG("Connected to %s", inst->name); + DEBUG_LOG("Connected to %s", inst->label); change_state(inst, KE_HANDSHAKE); return 0; @@ -376,7 +380,7 @@ handle_event(NKSN_Instance inst, int event) if (r < 0) { if (gnutls_error_is_fatal(r)) { LOG(inst->server ? LOGS_DEBUG : LOGS_ERR, - "TLS handshake with %s failed : %s", inst->name, gnutls_strerror(r)); + "TLS handshake with %s failed : %s", inst->label, gnutls_strerror(r)); stop_session(inst); return 0; } @@ -390,12 +394,12 @@ handle_event(NKSN_Instance inst, int event) if (DEBUG) { char *description = gnutls_session_get_desc(inst->tls_session); DEBUG_LOG("Handshake with %s completed %s", - inst->name, description ? description : ""); + inst->label, description ? description : ""); gnutls_free(description); } if (!check_alpn(inst)) { - LOG(inst->server ? LOGS_DEBUG : LOGS_ERR, "NTS-KE not supported by %s", inst->name); + LOG(inst->server ? LOGS_DEBUG : LOGS_ERR, "NTS-KE not supported by %s", inst->label); stop_session(inst); return 0; } @@ -413,13 +417,13 @@ handle_event(NKSN_Instance inst, int event) if (r < 0) { if (gnutls_error_is_fatal(r)) { LOG(inst->server ? LOGS_DEBUG : LOGS_ERR, - "Could not send NTS-KE message to %s : %s", inst->name, gnutls_strerror(r)); + "Could not send NTS-KE message to %s : %s", inst->label, gnutls_strerror(r)); stop_session(inst); } return 0; } - DEBUG_LOG("Sent %d bytes to %s", r, inst->name); + DEBUG_LOG("Sent %d bytes to %s", r, inst->label); message->sent += r; if (message->sent < message->length) @@ -448,13 +452,13 @@ handle_event(NKSN_Instance inst, int event) if (gnutls_error_is_fatal(r) || r == GNUTLS_E_REHANDSHAKE) { LOG(inst->server ? LOGS_DEBUG : LOGS_ERR, "Could not receive NTS-KE message from %s : %s", - inst->name, gnutls_strerror(r)); + inst->label, gnutls_strerror(r)); stop_session(inst); } return 0; } - DEBUG_LOG("Received %d bytes from %s", r, inst->name); + DEBUG_LOG("Received %d bytes from %s", r, inst->label); message->length += r; @@ -462,7 +466,7 @@ handle_event(NKSN_Instance inst, int event) if (!check_message_format(message, r == 0)) { LOG(inst->server ? LOGS_DEBUG : LOGS_ERR, - "Received invalid NTS-KE message from %s", inst->name); + "Received invalid NTS-KE message from %s", inst->label); stop_session(inst); return 0; } @@ -480,7 +484,7 @@ handle_event(NKSN_Instance inst, int event) if (r < 0) { if (gnutls_error_is_fatal(r)) { - DEBUG_LOG("Shutdown with %s failed : %s", inst->name, gnutls_strerror(r)); + DEBUG_LOG("Shutdown with %s failed : %s", inst->label, gnutls_strerror(r)); stop_session(inst); return 0; } @@ -620,7 +624,7 @@ NKSN_DestroyCertCredentials(void *credentials) /* ================================================== */ NKSN_Instance -NKSN_CreateInstance(int server_mode, const char *name, +NKSN_CreateInstance(int server_mode, const char *server_name, NKSN_MessageHandler handler, void *handler_arg) { NKSN_Instance inst; @@ -628,7 +632,7 @@ NKSN_CreateInstance(int server_mode, const char *name, inst = MallocNew(struct NKSN_Instance_Record); inst->server = server_mode; - inst->name = Strdup(name); + inst->server_name = server_name ? Strdup(server_name) : NULL; inst->handler = handler; inst->handler_arg = handler_arg; /* Replace NULL arg with the session itself */ @@ -637,6 +641,7 @@ NKSN_CreateInstance(int server_mode, const char *name, inst->state = KE_STOPPED; inst->sock_fd = INVALID_SOCK_FD; + inst->label = NULL; inst->tls_session = NULL; inst->timeout_id = 0; @@ -650,19 +655,19 @@ NKSN_DestroyInstance(NKSN_Instance inst) { stop_session(inst); - Free(inst->name); + Free(inst->server_name); Free(inst); } /* ================================================== */ int -NKSN_StartSession(NKSN_Instance inst, int sock_fd, void *credentials, double timeout) +NKSN_StartSession(NKSN_Instance inst, int sock_fd, const char *label, + void *credentials, double timeout) { assert(inst->state == KE_STOPPED); - inst->tls_session = create_tls_session(inst->server, sock_fd, - inst->server ? NULL : inst->name, + inst->tls_session = create_tls_session(inst->server, sock_fd, inst->server_name, credentials, priority_cache); if (!inst->tls_session) return 0; @@ -670,6 +675,7 @@ NKSN_StartSession(NKSN_Instance inst, int sock_fd, void *credentials, double tim inst->sock_fd = sock_fd; SCH_AddFileHandler(sock_fd, SCH_FILE_INPUT, read_write_socket, inst); + inst->label = Strdup(label); inst->timeout_id = SCH_AddTimeoutByDelay(timeout, session_timeout, inst); reset_message(&inst->message); diff --git a/nts_ke_session.h b/nts_ke_session.h index cbfb7f5..bf1b469 100644 --- a/nts_ke_session.h +++ b/nts_ke_session.h @@ -45,15 +45,15 @@ extern void *NKSN_CreateCertCredentials(char *cert, char *key, char *trusted_cer extern void NKSN_DestroyCertCredentials(void *credentials); /* Create an instance */ -extern NKSN_Instance NKSN_CreateInstance(int server_mode, const char *name, +extern NKSN_Instance NKSN_CreateInstance(int server_mode, const char *server_name, NKSN_MessageHandler handler, void *handler_arg); /* Destroy an instance */ extern void NKSN_DestroyInstance(NKSN_Instance inst); /* Start a new NTS-KE session */ -extern int NKSN_StartSession(NKSN_Instance inst, int sock_fd, void *credentials, - double timeout); +extern int NKSN_StartSession(NKSN_Instance inst, int sock_fd, const char *label, + void *credentials, double timeout); /* Begin an NTS-KE message. A request should be made right after starting the session and response should be made in the message handler. */ diff --git a/test/unit/nts_ke_server.c b/test/unit/nts_ke_server.c index b5660e9..d1628d6 100644 --- a/test/unit/nts_ke_server.c +++ b/test/unit/nts_ke_server.c @@ -156,7 +156,7 @@ test_unit(void) unlink("ntskeys"); NKS_Initialise(0); - session = NKSN_CreateInstance(1, "", handle_message, NULL); + session = NKSN_CreateInstance(1, NULL, handle_message, NULL); for (i = 0; i < 10000; i++) { valid = random() % 2; diff --git a/test/unit/nts_ke_session.c b/test/unit/nts_ke_session.c index 74a6607..b907593 100644 --- a/test/unit/nts_ke_session.c +++ b/test/unit/nts_ke_session.c @@ -142,7 +142,7 @@ test_unit(void) for (i = 0; i < 50; i++) { SCH_Initialise(); - server = NKSN_CreateInstance(1, "client", handle_request, NULL); + server = NKSN_CreateInstance(1, NULL, handle_request, NULL); client = NKSN_CreateInstance(0, "test", handle_response, NULL); server_cred = NKSN_CreateCertCredentials("nts_ke.crt", "nts_ke.key", NULL); @@ -152,8 +152,8 @@ test_unit(void) TEST_CHECK(fcntl(sock_fds[0], F_SETFL, O_NONBLOCK) == 0); TEST_CHECK(fcntl(sock_fds[1], F_SETFL, O_NONBLOCK) == 0); - TEST_CHECK(NKSN_StartSession(server, sock_fds[0], server_cred, 4.0)); - TEST_CHECK(NKSN_StartSession(client, sock_fds[1], client_cred, 4.0)); + TEST_CHECK(NKSN_StartSession(server, sock_fds[0], "client", server_cred, 4.0)); + TEST_CHECK(NKSN_StartSession(client, sock_fds[1], "server", client_cred, 4.0)); send_message(client);