diff --git a/nts_ke_server.c b/nts_ke_server.c index aadb15f..2fd576d 100644 --- a/nts_ke_server.c +++ b/nts_ke_server.c @@ -47,8 +47,7 @@ #define SERVER_TIMEOUT 2.0 -#define SERVER_COOKIE_SIV AEAD_AES_SIV_CMAC_256 -#define SERVER_COOKIE_NONCE_LENGTH 16 +#define MAX_COOKIE_NONCE_LENGTH 16 #define KEY_ID_INDEX_BITS 2 #define MAX_SERVER_KEYS (1U << KEY_ID_INDEX_BITS) @@ -61,17 +60,19 @@ typedef struct { uint32_t key_id; - unsigned char nonce[SERVER_COOKIE_NONCE_LENGTH]; } ServerCookieHeader; typedef struct { uint32_t id; unsigned char key[SIV_MAX_KEY_LENGTH]; + SIV_Algorithm siv_algorithm; SIV_Instance siv; + int nonce_length; } ServerKey; typedef struct { uint32_t key_id; + uint32_t siv_algorithm; unsigned char key[SIV_MAX_KEY_LENGTH]; IPAddr client_addr; uint16_t client_port; @@ -148,6 +149,23 @@ handle_client(int sock_fd, IPSockAddr *addr) /* ================================================== */ +static void +update_key_siv(ServerKey *key, SIV_Algorithm algorithm) +{ + if (!key->siv || key->siv_algorithm != algorithm) { + if (key->siv) + SIV_DestroyInstance(key->siv); + key->siv_algorithm = algorithm; + key->siv = SIV_CreateInstance(algorithm); + key->nonce_length = MIN(SIV_GetMaxNonceLength(key->siv), MAX_COOKIE_NONCE_LENGTH); + } + + if (!key->siv || !SIV_SetKey(key->siv, key->key, SIV_GetKeyLength(key->siv_algorithm))) + LOG_FATAL("Could not set SIV key"); +} + +/* ================================================== */ + static void handle_helper_request(int fd, int event, void *arg) { @@ -189,8 +207,7 @@ handle_helper_request(int fd, int event, void *arg) UTI_IPNetworkToHost(&req->client_addr, &client_addr.ip_addr); client_addr.port = ntohs(req->client_port); - if (!SIV_SetKey(key->siv, key->key, SIV_GetKeyLength(SERVER_COOKIE_SIV))) - LOG_FATAL("Could not set SIV key"); + update_key_siv(key, ntohl(req->siv_algorithm)); if (!handle_client(sock_fd, &client_addr)) { SCK_CloseSocket(sock_fd); @@ -240,6 +257,7 @@ accept_connection(int listening_fd, int event, void *arg) /* Include the current server key and client address in the request */ req.key_id = htonl(server_keys[current_server_key].id); + req.siv_algorithm = htonl(server_keys[current_server_key].siv_algorithm); assert(sizeof (req.key) == sizeof (server_keys[current_server_key].key)); memcpy(req.key, server_keys[current_server_key].key, sizeof (req.key)); UTI_IPHostToNetwork(&addr.ip_addr, &req.client_addr); @@ -471,29 +489,37 @@ handle_message(void *arg) static void generate_key(int index) { + SIV_Algorithm algorithm; ServerKey *key; int key_length; if (index < 0 || index >= MAX_SERVER_KEYS) assert(0); + /* Prefer AES-128-GCM-SIV if available. Note that if older keys loaded + from ntsdumpdir use a different algorithm, responding to NTP requests + with cookies encrypted with those keys will not work if the new algorithm + produces longer cookies (i.e. response would be longer than request). + Switching from AES-SIV-CMAC-256 to AES-128-GCM-SIV is ok. */ + algorithm = SIV_GetKeyLength(AEAD_AES_128_GCM_SIV) > 0 ? + AEAD_AES_128_GCM_SIV : AEAD_AES_SIV_CMAC_256; + key = &server_keys[index]; - key_length = SIV_GetKeyLength(SERVER_COOKIE_SIV); + + key_length = SIV_GetKeyLength(algorithm); if (key_length > sizeof (key->key)) assert(0); UTI_GetRandomBytesUrandom(key->key, key_length); - - if (!key->siv || !SIV_SetKey(key->siv, key->key, key_length)) - LOG_FATAL("Could not set SIV key"); - UTI_GetRandomBytes(&key->id, sizeof (key->id)); /* Encode the index in the lowest bits of the ID */ key->id &= -1U << KEY_ID_INDEX_BITS; key->id |= index; - DEBUG_LOG("Generated server key %"PRIX32, key->id); + update_key_siv(key, algorithm); + + DEBUG_LOG("Generated key %"PRIX32" (%d)", key->id, (int)key->siv_algorithm); last_server_key_ts = SCH_GetLastEventMonoTime(); } @@ -521,16 +547,18 @@ save_keys(void) if (!f) return; - key_length = SIV_GetKeyLength(SERVER_COOKIE_SIV); + key_length = SIV_GetKeyLength(server_keys[0].siv_algorithm); last_key_age = SCH_GetLastEventMonoTime() - last_server_key_ts; - if (fprintf(f, "%s%d %.1f\n", DUMP_IDENTIFIER, SERVER_COOKIE_SIV, last_key_age) < 0) + if (fprintf(f, "%s%d %.1f\n", DUMP_IDENTIFIER, (int)server_keys[0].siv_algorithm, + last_key_age) < 0) goto error; for (i = 0; i < MAX_SERVER_KEYS; i++) { index = (current_server_key + i + 1 + FUTURE_KEYS) % MAX_SERVER_KEYS; if (key_length > sizeof (server_keys[index].key) || + server_keys[index].siv_algorithm != server_keys[0].siv_algorithm || !UTI_BytesToHex(server_keys[index].key, key_length, buf, sizeof (buf)) || fprintf(f, "%08"PRIX32" %s\n", server_keys[index].id, buf) < 0) goto error; @@ -578,11 +606,11 @@ load_keys(void) if (!fgets(line, sizeof (line), f) || strcmp(line, DUMP_IDENTIFIER) != 0 || !fgets(line, sizeof (line), f) || UTI_SplitString(line, words, MAX_WORDS) != 2 || - sscanf(words[0], "%d", &algorithm) != 1 || algorithm != SERVER_COOKIE_SIV || + sscanf(words[0], "%d", &algorithm) != 1 || SIV_GetKeyLength(algorithm) <= 0 || sscanf(words[1], "%lf", &key_age) != 1) goto error; - key_length = SIV_GetKeyLength(SERVER_COOKIE_SIV); + key_length = SIV_GetKeyLength(algorithm); last_server_key_ts = SCH_GetLastEventMonoTime() - MAX(key_age, 0.0); for (i = 0; i < MAX_SERVER_KEYS && fgets(line, sizeof (line), f); i++) { @@ -599,10 +627,9 @@ load_keys(void) assert(sizeof (server_keys[index].key) == sizeof (key)); memcpy(server_keys[index].key, key, key_length); - if (!SIV_SetKey(server_keys[index].siv, server_keys[index].key, key_length)) - LOG_FATAL("Could not set SIV key"); + update_key_siv(&server_keys[index], algorithm); - DEBUG_LOG("Loaded key %"PRIX32, id); + DEBUG_LOG("Loaded key %"PRIX32" (%d)", id, (int)algorithm); current_server_key = (index + MAX_SERVER_KEYS - FUTURE_KEYS) % MAX_SERVER_KEYS; } @@ -761,7 +788,7 @@ NKS_Initialise(void) /* Generate random keys, even if they will be replaced by reloaded keys, or unused (in the helper) */ for (i = 0; i < MAX_SERVER_KEYS; i++) { - server_keys[i].siv = SIV_CreateInstance(SERVER_COOKIE_SIV); + server_keys[i].siv = NULL; generate_key(i); } @@ -854,7 +881,7 @@ NKS_ReloadKeys(void) int NKS_GenerateCookie(NKE_Context *context, NKE_Cookie *cookie) { - unsigned char plaintext[2 * NKE_MAX_KEY_LENGTH], *ciphertext; + unsigned char *nonce, plaintext[2 * NKE_MAX_KEY_LENGTH], *ciphertext; int plaintext_length, tag_length; ServerCookieHeader *header; ServerKey *key; @@ -879,7 +906,11 @@ NKS_GenerateCookie(NKE_Context *context, NKE_Cookie *cookie) header = (ServerCookieHeader *)cookie->cookie; header->key_id = htonl(key->id); - UTI_GetRandomBytes(header->nonce, sizeof (header->nonce)); + + nonce = cookie->cookie + sizeof (*header); + if (key->nonce_length > sizeof (cookie->cookie) - sizeof (*header)) + assert(0); + UTI_GetRandomBytes(nonce, key->nonce_length); plaintext_length = context->c2s.length + context->s2c.length; assert(plaintext_length <= sizeof (plaintext)); @@ -887,11 +918,11 @@ NKS_GenerateCookie(NKE_Context *context, NKE_Cookie *cookie) memcpy(plaintext + context->c2s.length, context->s2c.key, context->s2c.length); tag_length = SIV_GetTagLength(key->siv); - cookie->length = sizeof (*header) + plaintext_length + tag_length; + cookie->length = sizeof (*header) + key->nonce_length + plaintext_length + tag_length; assert(cookie->length <= sizeof (cookie->cookie)); - ciphertext = cookie->cookie + sizeof (*header); + ciphertext = cookie->cookie + sizeof (*header) + key->nonce_length; - if (!SIV_Encrypt(key->siv, header->nonce, sizeof (header->nonce), + if (!SIV_Encrypt(key->siv, nonce, key->nonce_length, "", 0, plaintext, plaintext_length, ciphertext, plaintext_length + tag_length)) { @@ -907,7 +938,7 @@ NKS_GenerateCookie(NKE_Context *context, NKE_Cookie *cookie) int NKS_DecodeCookie(NKE_Cookie *cookie, NKE_Context *context) { - unsigned char plaintext[2 * NKE_MAX_KEY_LENGTH], *ciphertext; + unsigned char *nonce, plaintext[2 * NKE_MAX_KEY_LENGTH], *ciphertext; int ciphertext_length, plaintext_length, tag_length; ServerCookieHeader *header; ServerKey *key; @@ -924,8 +955,6 @@ NKS_DecodeCookie(NKE_Cookie *cookie, NKE_Context *context) } header = (ServerCookieHeader *)cookie->cookie; - ciphertext = cookie->cookie + sizeof (*header); - ciphertext_length = cookie->length - sizeof (*header); key_id = ntohl(header->key_id); key = &server_keys[key_id % MAX_SERVER_KEYS]; @@ -935,18 +964,23 @@ NKS_DecodeCookie(NKE_Cookie *cookie, NKE_Context *context) } tag_length = SIV_GetTagLength(key->siv); - if (tag_length >= ciphertext_length) { + + if (cookie->length <= (int)sizeof (*header) + key->nonce_length + tag_length) { DEBUG_LOG("Invalid cookie length"); return 0; } + nonce = cookie->cookie + sizeof (*header); + ciphertext = cookie->cookie + sizeof (*header) + key->nonce_length; + ciphertext_length = cookie->length - sizeof (*header) - key->nonce_length; plaintext_length = ciphertext_length - tag_length; + if (plaintext_length > sizeof (plaintext) || plaintext_length % 2 != 0) { DEBUG_LOG("Invalid cookie length"); return 0; } - if (!SIV_Decrypt(key->siv, header->nonce, sizeof (header->nonce), + if (!SIV_Decrypt(key->siv, nonce, key->nonce_length, "", 0, ciphertext, ciphertext_length, plaintext, plaintext_length)) {