diff --git a/client.c b/client.c index f418c73..5cbd402 100644 --- a/client.c +++ b/client.c @@ -2985,7 +2985,7 @@ process_cmd_keygen(char *line) ; #ifdef HAVE_CMAC - cmac_length = CMC_GetKeyLength(type); + cmac_length = CMC_GetKeyLength(UTI_CmacNameToAlgorithm(type)); #else cmac_length = 0; #endif diff --git a/cmac.h b/cmac.h index c621d20..206edc1 100644 --- a/cmac.h +++ b/cmac.h @@ -28,10 +28,17 @@ #ifndef GOT_CMAC_H #define GOT_CMAC_H +/* Avoid overlapping with the hash enumeration */ +typedef enum { + CMC_INVALID = 0, + CMC_AES128 = 13, + CMC_AES256 = 14, +} CMC_Algorithm; + typedef struct CMC_Instance_Record *CMC_Instance; -extern unsigned int CMC_GetKeyLength(const char *cipher); -extern CMC_Instance CMC_CreateInstance(const char *cipher, const unsigned char *key, +extern unsigned int CMC_GetKeyLength(CMC_Algorithm algorithm); +extern CMC_Instance CMC_CreateInstance(CMC_Algorithm algorithm, const unsigned char *key, unsigned int length); extern unsigned int CMC_Hash(CMC_Instance inst, const unsigned char *in, unsigned int in_len, unsigned char *out, unsigned int out_len); diff --git a/cmac_nettle.c b/cmac_nettle.c index 46aadce..239d006 100644 --- a/cmac_nettle.c +++ b/cmac_nettle.c @@ -45,11 +45,11 @@ struct CMC_Instance_Record { /* ================================================== */ unsigned int -CMC_GetKeyLength(const char *cipher) +CMC_GetKeyLength(CMC_Algorithm algorithm) { - if (strcmp(cipher, "AES128") == 0) + if (algorithm == CMC_AES128) return AES128_KEY_SIZE; - else if (strcmp(cipher, "AES256") == 0) + else if (algorithm == CMC_AES256) return AES256_KEY_SIZE; return 0; } @@ -57,11 +57,11 @@ CMC_GetKeyLength(const char *cipher) /* ================================================== */ CMC_Instance -CMC_CreateInstance(const char *cipher, const unsigned char *key, unsigned int length) +CMC_CreateInstance(CMC_Algorithm algorithm, const unsigned char *key, unsigned int length) { CMC_Instance inst; - if (length == 0 || length != CMC_GetKeyLength(cipher)) + if (length == 0 || length != CMC_GetKeyLength(algorithm)) return NULL; inst = MallocNew(struct CMC_Instance_Record); diff --git a/keys.c b/keys.c index 749ca38..f74626e 100644 --- a/keys.c +++ b/keys.c @@ -202,6 +202,7 @@ KEY_Reload(void) char line[2048], *key_file, *key_value; const char *key_type; HSH_Algorithm hash_algorithm; + CMC_Algorithm cmac_algorithm; int hash_id; Key key; @@ -239,8 +240,8 @@ KEY_Reload(void) continue; } - cmac_key_length = CMC_GetKeyLength(key_type); hash_algorithm = UTI_HashNameToAlgorithm(key_type); + cmac_algorithm = UTI_CmacNameToAlgorithm(key_type); if (hash_algorithm != 0) { hash_id = HSH_GetHashId(hash_algorithm); @@ -253,18 +254,23 @@ KEY_Reload(void) memcpy(key.data.ntp_mac.value, key_value, key_length); key.data.ntp_mac.length = key_length; key.data.ntp_mac.hash_id = hash_id; - } else if (cmac_key_length > 0) { - if (cmac_key_length != key_length) { + } else if (cmac_algorithm != 0) { + cmac_key_length = CMC_GetKeyLength(cmac_algorithm); + if (cmac_key_length == 0) { + LOG(LOGS_WARN, "Unsupported %s in key %"PRIu32, "cipher", key.id); + continue; + } else if (cmac_key_length != key_length) { LOG(LOGS_WARN, "Invalid length of %s key %"PRIu32" (expected %u bits)", key_type, key.id, 8 * cmac_key_length); continue; } key.class = CMAC; - key.data.cmac = CMC_CreateInstance(key_type, (unsigned char *)key_value, key_length); + key.data.cmac = CMC_CreateInstance(cmac_algorithm, (unsigned char *)key_value, + key_length); assert(key.data.cmac); } else { - LOG(LOGS_WARN, "Unknown hash function or cipher in key %"PRIu32, key.id); + LOG(LOGS_WARN, "Invalid type in key %"PRIu32, key.id); continue; } diff --git a/stubs.c b/stubs.c index f773114..35c612f 100644 --- a/stubs.c +++ b/stubs.c @@ -432,13 +432,13 @@ NSD_SignAndSendPacket(uint32_t key_id, NTP_Packet *packet, NTP_PacketInfo *info, #ifndef HAVE_CMAC unsigned int -CMC_GetKeyLength(const char *cipher) +CMC_GetKeyLength(CMC_Algorithm algorithm) { return 0; } CMC_Instance -CMC_CreateInstance(const char *cipher, const unsigned char *key, unsigned int length) +CMC_CreateInstance(CMC_Algorithm algorithm, const unsigned char *key, unsigned int length) { return NULL; } diff --git a/test/unit/cmac.c b/test/unit/cmac.c index d639220..07fc3d5 100644 --- a/test/unit/cmac.c +++ b/test/unit/cmac.c @@ -22,6 +22,7 @@ #include #include #include +#include #include "test.h" #define MAX_KEY_LENGTH 64 @@ -49,6 +50,7 @@ test_unit(void) { "", "", 0, "", 0 } }; + CMC_Algorithm algorithm; CMC_Instance inst; unsigned int length; int i, j; @@ -57,21 +59,25 @@ test_unit(void) TEST_REQUIRE(0); #endif + TEST_CHECK(CMC_INVALID == 0); + for (i = 0; tests[i].name[0] != '\0'; i++) { - TEST_CHECK(CMC_GetKeyLength(tests[i].name) == tests[i].key_length); + algorithm = UTI_CmacNameToAlgorithm(tests[i].name); + TEST_CHECK(algorithm != 0); + TEST_CHECK(CMC_GetKeyLength(algorithm) == tests[i].key_length); DEBUG_LOG("testing %s", tests[i].name); for (j = 0; j <= 128; j++) { if (j == tests[i].key_length) continue; - TEST_CHECK(!CMC_CreateInstance(tests[i].name, tests[i].key, j)); + TEST_CHECK(!CMC_CreateInstance(algorithm, tests[i].key, j)); } - inst = CMC_CreateInstance(tests[i].name, tests[i].key, tests[i].key_length); + inst = CMC_CreateInstance(algorithm, tests[i].key, tests[i].key_length); TEST_CHECK(inst); - TEST_CHECK(!CMC_CreateInstance("nosuchcipher", tests[i].key, tests[i].key_length)); + TEST_CHECK(!CMC_CreateInstance(0, tests[i].key, tests[i].key_length)); for (j = 0; j <= sizeof (hash); j++) { memset(hash, 0, sizeof (hash)); diff --git a/util.c b/util.c index a71b8fc..9743b99 100644 --- a/util.c +++ b/util.c @@ -927,6 +927,18 @@ UTI_FloatHostToNetwork(double x) /* ================================================== */ +CMC_Algorithm +UTI_CmacNameToAlgorithm(const char *name) +{ + if (strcmp(name, "AES128") == 0) + return CMC_AES128; + else if (strcmp(name, "AES256") == 0) + return CMC_AES256; + return CMC_INVALID; +} + +/* ================================================== */ + HSH_Algorithm UTI_HashNameToAlgorithm(const char *name) { diff --git a/util.h b/util.h index f3cb99f..d52a453 100644 --- a/util.h +++ b/util.h @@ -32,6 +32,7 @@ #include "addressing.h" #include "ntp.h" #include "candm.h" +#include "cmac.h" #include "hash.h" /* Zero a timespec */ @@ -165,6 +166,7 @@ extern void UTI_TimespecHostToNetwork(const struct timespec *src, Timespec *dest extern double UTI_FloatNetworkToHost(Float x); extern Float UTI_FloatHostToNetwork(double x); +extern CMC_Algorithm UTI_CmacNameToAlgorithm(const char *name); extern HSH_Algorithm UTI_HashNameToAlgorithm(const char *name); /* Set FD_CLOEXEC on descriptor */