diff --git a/CMakeLists.txt b/CMakeLists.txt index 16d777a..915326d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -54,7 +54,6 @@ add_executable(M2-PT-DRP source/packets/base/PacketContent.cpp source/packets/base/PacketContent.hpp source/packets/base/SecurityMode.hpp - source/packets/base/PacketData.hpp source/packets/info/InfoPacketData.hpp source/utils/time/Chrony.cpp source/utils/time/Chrony.hpp @@ -64,8 +63,21 @@ add_executable(M2-PT-DRP source/Peer.cpp source/RemotePeer.cpp source/Context.cpp - source/utils/crypto.cpp - source/utils/crypto.hpp + source/test.cpp + source/utils/crypto/aes/AesKey.cpp + source/utils/crypto/aes/AesKey.hpp + source/utils/crypto/rsa/RsaPublicKey.cpp + source/utils/crypto/rsa/RsaPublicKey.hpp + source/utils/crypto/rsa/RsaPrivateKey.cpp + source/utils/crypto/rsa/RsaPrivateKey.hpp + source/utils/crypto/rsa/RsaPrivateKey.hpp + source/utils/crypto/rsa/RsaKeyPair.cpp + source/utils/crypto/rsa/RsaKeyPair.hpp + source/utils/serialize/basics.cpp + source/utils/serialize/basics.hpp + source/packets/audio/AudioPacketData.cpp + source/packets/info/InfoPacketData.cpp + source/packets/search/SearchPacketData.cpp ) target_include_directories(M2-PT-DRP PRIVATE source diff --git a/source/Context.cpp b/source/Context.cpp index 4f33a23..d365f1f 100644 --- a/source/Context.cpp +++ b/source/Context.cpp @@ -1,8 +1,13 @@ #include "Context.hpp" +#include "utils/crypto/rsa/RsaKeyPair.hpp" -Context::Context(const std::array &privateKey, const std::array &publicKey) : me(publicKey) { - this->cryptoRsaPrivateKey = privateKey; + +Context::Context() { + const auto keyPair = drp::util::crypto::RsaKeyPair(2048); + + this->me = Peer(keyPair.getPublicKey()); + this->cryptoRsaPrivateKey = keyPair.getPrivateKey(); this->socket = -1; this->broadcastAddressInfo = nullptr; diff --git a/source/Context.hpp b/source/Context.hpp index 5c55adb..f900649 100644 --- a/source/Context.hpp +++ b/source/Context.hpp @@ -5,6 +5,8 @@ #include #include "RemotePeer.hpp" +#include "utils/crypto/aes/AesKey.hpp" +#include "utils/crypto/rsa/RsaPrivateKey.hpp" /** @@ -14,7 +16,7 @@ */ class Context { public: - explicit Context(const std::array& privateKey, const std::array& publicKey); + explicit Context(); int socket; /// current socket file descriptor, used to communicate addrinfo* broadcastAddressInfo; /// address used to broadcast messages @@ -24,5 +26,6 @@ public: std::list> remotePeers {}; /// information about other machines std::chrono::high_resolution_clock::time_point latestPeerDiscovery; /// time of the latest discovered machine - std::array cryptoRsaPrivateKey {}; /// the RSA private key + drp::util::crypto::RsaPrivateKey cryptoRsaPrivateKey {}; /// the RSA private key + drp::util::crypto::AesKey256 cryptoAesKey = {}; /// the AES secret key }; diff --git a/source/Manager.cpp b/source/Manager.cpp index b112d18..1a69589 100644 --- a/source/Manager.cpp +++ b/source/Manager.cpp @@ -20,14 +20,14 @@ #include "behavior/tasks/client/ClientTask.hpp" #include "behavior/tasks/server/ServerTask.hpp" #include "behavior/tasks/undefined/UndefinedTask.hpp" -#include "utils/crypto.hpp" +#include "utils/crypto/aes/AesKey.hpp" +#include "utils/crypto/rsa/RsaKeyPair.hpp" Manager::Manager(const std::string& address, const std::string& port, const bool useIpv6) { std::cout << "Broadcast address: " << address << ":" << port << " (" << (useIpv6 ? "IPv6" : "IPv4") << ")" << std::endl; - auto [privateKey, publicKey] = newRsaKeys<2048>(); - this->context = std::make_shared(privateKey, publicKey); + this->context = std::make_shared(); // register the different events type this->eventRegistry = { @@ -61,7 +61,7 @@ Manager::Manager(const std::string& address, const std::string& port, const bool throw std::runtime_error("Could not create the socket: " + std::string(strerror(errno))); // allow IPv6 multicast loopback so that we can receive our own messages. - const int socketLoopback = 1; + constexpr int socketLoopback = 1; if (setsockopt( context->socket, IPPROTO_IPV6, @@ -123,7 +123,7 @@ void Manager::loop() { } -void Manager::loopSender() { +void Manager::loopSender() const { while (true) { std::cout << "[Sender] Handling status: " + std::to_string(static_cast(this->context->me.status)) << std::endl; @@ -142,20 +142,20 @@ void Manager::loopSender() { } -void Manager::loopReceiver() { +void Manager::loopReceiver() const { // prepare space for the sender address sockaddr_storage fromAddress {}; socklen_t fromAddressLength = sizeof(fromAddress); - drp::packet::base::Packet packet {}; - drp::packet::base::PacketContent packetContent {}; + + std::array buffer {}; // client loop while (true) { // receive new data const ssize_t size = recvfrom( this->context->socket, - &packet, - sizeof(packet), + buffer.data(), + buffer.size(), 0, reinterpret_cast(&fromAddress), &fromAddressLength @@ -163,13 +163,17 @@ void Manager::loopReceiver() { if (size == -1) throw std::runtime_error("[Receiver] Could not receive the packet: " + std::string(strerror(errno))); + // deserialize the packet + std::vector data(buffer.begin(), buffer.end()); + const auto packet = drp::packet::base::Packet::deserialize(data); + // if the packet channel is neither 0 (all) nor the current one, ignore it if (packet.channel != 0 && packet.channel != this->context->me.channel) continue; // decrypt the packet // TODO(Faraphel): handle exception ? - packetContent = packet.getContent(); + drp::packet::base::PacketContent packetContent = packet.getContent(*this->context); // look for a saved peer with the same address auto remotePeer = std::ranges::find_if( @@ -187,18 +191,18 @@ void Manager::loopReceiver() { // get the corresponding event class std::shared_ptr event; try { - event = this->eventRegistry.at(static_cast(packetContent.eventType)); + event = this->eventRegistry.at(packetContent.eventType); } catch (const std::out_of_range& exception) { std::cerr << "[Receiver] Unsupported event type." << std::endl; continue; } - std::cout << "[Receiver] handling event: " << std::to_string(packetContent.eventType) << std::endl; + std::cout << "[Receiver] handling event: " << static_cast(packetContent.eventType) << std::endl; // ask the event class to handle the event event->handle( *this->context, - packetContent, + packetContent.data, fromAddress, fromAddressLength ); diff --git a/source/Manager.hpp b/source/Manager.hpp index cc1fb1f..4a6edac 100644 --- a/source/Manager.hpp +++ b/source/Manager.hpp @@ -21,8 +21,8 @@ public: Manager(const std::string& address, const std::string& port, bool useIpv6 = false); void loop(); - void loopSender(); - void loopReceiver(); + void loopSender() const; + void loopReceiver() const; private: std::thread senderThread; /// the thread sending communication diff --git a/source/Peer.cpp b/source/Peer.cpp index b1f088f..47c2b2b 100644 --- a/source/Peer.cpp +++ b/source/Peer.cpp @@ -1,10 +1,12 @@ #include "Peer.hpp" - -Peer::Peer() : Peer(std::array()) {} +#include "utils/serialize/basics.hpp" -Peer::Peer(const std::array& cryptoRsaPublicKey) { +Peer::Peer() = default; + + +Peer::Peer(const drp::util::crypto::RsaPublicKey& cryptoRsaPublicKey) { this->id = randomDistribution(randomGenerator); this->channel = 0; this->serverEnabled = false; @@ -15,12 +17,61 @@ Peer::Peer(const std::array& cryptoRsaPublicKey) { this->cryptoRsaPublicKey = cryptoRsaPublicKey; } +Peer::Peer( + std::uint32_t id, + bool serverEnabled, + drp::task::TaskType status, + std::uint8_t channel, + const std::chrono::high_resolution_clock::duration& latencyAverage, + const drp::util::crypto::RsaPublicKey& cryptoRsaPublicKey +) { + this->id = id; + this->serverEnabled = serverEnabled; + this->status = status; + this->channel = channel; + this->latencyAverage = latencyAverage; + this->cryptoRsaPublicKey = cryptoRsaPublicKey; +} -std::random_device Peer::randomDevice = std::random_device(); -std::mt19937 Peer::randomGenerator = std::mt19937(randomDevice()); +std::vector Peer::serialize() const { + std::vector data; -std::uniform_int_distribution Peer::randomDistribution = std::uniform_int_distribution( - 1, + // serialized the members + const auto serializedId = drp::util::serialize::serializeObject(this->id); + const auto serializedServerEnabled = drp::util::serialize::serializeObject(this->serverEnabled); + const auto serializedStatus = drp::util::serialize::serializeObject(static_cast(this->status)); + const auto serializedChannel = drp::util::serialize::serializeObject(this->channel); + const auto serializedLatencyAverage = drp::util::serialize::serializeObject(this->latencyAverage); + const auto serializedPublicKey = this->cryptoRsaPublicKey.serialize(); + + // store them in the data + data.insert(data.end(), serializedId.begin(), serializedId.end()); + data.insert(data.end(), serializedServerEnabled.begin(), serializedServerEnabled.end()); + data.insert(data.end(), serializedStatus.begin(), serializedStatus.end()); + data.insert(data.end(), serializedChannel.begin(), serializedChannel.end()); + data.insert(data.end(), serializedLatencyAverage.begin(), serializedLatencyAverage.end()); + data.insert(data.end(), serializedPublicKey.begin(), serializedPublicKey.end()); + + return data; +} + +Peer Peer::deserialize(std::vector& data) { + // deserialize the members + const auto id = drp::util::serialize::deserializeObject(data); + const auto serverEnabled = drp::util::serialize::deserializeObject(data); + const auto status = static_cast(drp::util::serialize::deserializeObject(data)); + const auto channel = drp::util::serialize::deserializeObject(data); + const auto latencyAverage = drp::util::serialize::deserializeObject(data); + const auto publicKey = drp::util::crypto::RsaPublicKey::deserialize(data); + + return Peer(id, serverEnabled, status, channel, latencyAverage, publicKey); +} + + +std::mt19937 Peer::randomGenerator = std::mt19937(std::random_device{}()); + +std::uniform_int_distribution Peer::randomDistribution = std::uniform_int_distribution( + std::numeric_limits::min(), std::numeric_limits::max() ); \ No newline at end of file diff --git a/source/Peer.hpp b/source/Peer.hpp index a9ca63b..ddee339 100644 --- a/source/Peer.hpp +++ b/source/Peer.hpp @@ -3,9 +3,9 @@ #include #include #include -#include #include "behavior/tasks/types.hpp" +#include "utils/crypto/rsa/RsaPublicKey.hpp" /** @@ -15,21 +15,33 @@ class Peer { public: Peer(); - explicit Peer(const std::array& cryptoRsaPublicKey); + explicit Peer(const drp::util::crypto::RsaPublicKey& cryptoRsaPublicKey); + explicit Peer( + std::uint32_t id, + bool serverEnabled, + drp::task::TaskType status, + std::uint8_t channel, + const std::chrono::high_resolution_clock::duration& latencyAverage, + const drp::util::crypto::RsaPublicKey& cryptoRsaPublicKey + ); - std::uint32_t id; + [[nodiscard]] std::vector serialize() const; + static Peer deserialize(std::vector &data); - bool serverEnabled; - drp::task::TaskType status; - std::uint8_t channel; + // identification + std::uint32_t id {}; // TODO(Faraphel): shall be removed in the future + // network + bool serverEnabled {}; + drp::task::TaskType status {}; + std::uint8_t channel {}; std::chrono::high_resolution_clock::duration latencyAverage {}; - std::array cryptoRsaPublicKey {}; + // cryptography + drp::util::crypto::RsaPublicKey cryptoRsaPublicKey {}; private: // random - static std::random_device randomDevice; static std::mt19937 randomGenerator; static std::uniform_int_distribution randomDistribution; }; diff --git a/source/behavior/events/audio/AudioEvent.cpp b/source/behavior/events/audio/AudioEvent.cpp index 6a187be..da08089 100644 --- a/source/behavior/events/audio/AudioEvent.cpp +++ b/source/behavior/events/audio/AudioEvent.cpp @@ -3,8 +3,6 @@ #include #include -#include "packets/audio/AudioPacketData.hpp" -#include "packets/base/PacketData.hpp" #include "utils/audio.hpp" @@ -34,12 +32,12 @@ AudioEvent::~AudioEvent() { void AudioEvent::handle( Context& context, - const packet::base::PacketContent& content, + std::vector& data, const sockaddr_storage& fromAddress, const socklen_t fromAddressLength ) { // get the audio data in the content - const auto audioData = packet::AudioPacketData::fromGeneric(content); + const auto audioData = packet::audio::AudioPacketData::deserialize(data); // save it in the audio queue this->audioQueue.push(audioData); @@ -123,7 +121,7 @@ void AudioEvent::loopPlay() { const int error = Pa_WriteStream( this->stream, audioPacket.content.data(), - audioPacket.contentSize / Pa_GetSampleSize(this->streamSampleFormat) / this->streamChannels + audioPacket.content.size() / Pa_GetSampleSize(this->streamSampleFormat) / this->streamChannels ); switch (error) { // success diff --git a/source/behavior/events/audio/AudioEvent.hpp b/source/behavior/events/audio/AudioEvent.hpp index 6839ac0..4bc855f 100644 --- a/source/behavior/events/audio/AudioEvent.hpp +++ b/source/behavior/events/audio/AudioEvent.hpp @@ -1,9 +1,8 @@ #pragma once + #include #include #include -#include -#include #include "AudioPacketsComparator.hpp" #include "../base/BaseEvent.hpp" @@ -22,7 +21,7 @@ public: void handle( Context& context, - const packet::base::PacketContent& content, + std::vector& data, const sockaddr_storage& fromAddress, socklen_t fromAddressLength ) override; @@ -34,7 +33,7 @@ private: int streamChannels; std::uint32_t streamSampleFormat; double streamRate; - std::priority_queue, AudioPacketsComparator> audioQueue; + std::priority_queue, AudioPacketsComparator> audioQueue; std::mutex audioMutex; std::unique_lock audioLock; diff --git a/source/behavior/events/audio/AudioPacketsComparator.cpp b/source/behavior/events/audio/AudioPacketsComparator.cpp index c3fbe6c..2d738a3 100644 --- a/source/behavior/events/audio/AudioPacketsComparator.cpp +++ b/source/behavior/events/audio/AudioPacketsComparator.cpp @@ -4,7 +4,7 @@ namespace drp::event { -bool AudioPacketsComparator::operator()(const packet::AudioPacketData& a, const packet::AudioPacketData& b) const { +bool AudioPacketsComparator::operator()(const packet::audio::AudioPacketData& a, const packet::audio::AudioPacketData& b) const { return a.timePlay > b.timePlay; } diff --git a/source/behavior/events/audio/AudioPacketsComparator.hpp b/source/behavior/events/audio/AudioPacketsComparator.hpp index 3af9b70..cee4346 100644 --- a/source/behavior/events/audio/AudioPacketsComparator.hpp +++ b/source/behavior/events/audio/AudioPacketsComparator.hpp @@ -7,7 +7,7 @@ namespace drp::event { struct AudioPacketsComparator { - bool operator() (const packet::AudioPacketData& a, const packet::AudioPacketData& b) const; + bool operator() (const packet::audio::AudioPacketData& a, const packet::audio::AudioPacketData& b) const; }; diff --git a/source/behavior/events/base/BaseEvent.hpp b/source/behavior/events/base/BaseEvent.hpp index 4a9df86..42a7422 100644 --- a/source/behavior/events/base/BaseEvent.hpp +++ b/source/behavior/events/base/BaseEvent.hpp @@ -12,7 +12,7 @@ public: virtual ~BaseEvent() = default; virtual void handle( Context& context, - const packet::base::PacketContent& content, + std::vector& data, const sockaddr_storage& fromAddress, socklen_t fromAddressLength ) = 0; diff --git a/source/behavior/events/info/InfoEvent.cpp b/source/behavior/events/info/InfoEvent.cpp index ca3382e..3b02e48 100644 --- a/source/behavior/events/info/InfoEvent.cpp +++ b/source/behavior/events/info/InfoEvent.cpp @@ -11,14 +11,14 @@ namespace drp::event { void InfoEvent::handle( Context& context, - const packet::base::PacketContent& content, + std::vector& data, const sockaddr_storage& fromAddress, const socklen_t fromAddressLength ) { std::cout << "[Event - Info] Received peer information." << std::endl; // get the peer information - const auto packetData = packet::info::InfoPacketData::fromGeneric(content); + const auto packetData = packet::info::InfoPacketData::deserialize(data); const Peer packetPeer = packetData.peer; // check if the peer address is already in the map diff --git a/source/behavior/events/info/InfoEvent.hpp b/source/behavior/events/info/InfoEvent.hpp index f3720da..a98eb29 100644 --- a/source/behavior/events/info/InfoEvent.hpp +++ b/source/behavior/events/info/InfoEvent.hpp @@ -10,7 +10,7 @@ class InfoEvent : public BaseEvent { public: void handle( Context& context, - const packet::base::PacketContent& content, + std::vector& data, const sockaddr_storage& fromAddress, socklen_t fromAddressLength ) override; diff --git a/source/behavior/events/pong/PongEvent.cpp b/source/behavior/events/pong/PongEvent.cpp index 10e0399..0d6efa2 100644 --- a/source/behavior/events/pong/PongEvent.cpp +++ b/source/behavior/events/pong/PongEvent.cpp @@ -8,7 +8,7 @@ namespace drp::event { void PongEvent::handle( Context& context, - const packet::base::PacketContent& content, + std::vector& data, const sockaddr_storage& fromAddress, const socklen_t fromAddressLength ) { diff --git a/source/behavior/events/pong/PongEvent.hpp b/source/behavior/events/pong/PongEvent.hpp index 505cab9..b496fbb 100644 --- a/source/behavior/events/pong/PongEvent.hpp +++ b/source/behavior/events/pong/PongEvent.hpp @@ -10,7 +10,7 @@ class PongEvent : public BaseEvent { public: void handle( Context& context, - const packet::base::PacketContent& content, + std::vector& data, const sockaddr_storage& fromAddress, socklen_t fromAddressLength ) override; diff --git a/source/behavior/events/search/SearchEvent.cpp b/source/behavior/events/search/SearchEvent.cpp index 739feed..be99fef 100644 --- a/source/behavior/events/search/SearchEvent.cpp +++ b/source/behavior/events/search/SearchEvent.cpp @@ -15,29 +15,34 @@ namespace drp { void event::SearchEvent::handle( Context& context, - const packet::base::PacketContent& content, + std::vector& data, const sockaddr_storage& fromAddress, const socklen_t fromAddressLength ) { packet::base::Packet packet {}; + packet::base::PacketContent packetContent {}; // create the packet header (available to read for everyone) packet.channel = 0; - packet.securityMode = static_cast(packet::base::SecurityMode::PLAIN); // create the packet data containing our information packet::info::InfoPacketData packetData {}; packetData.peer = context.me; - packet.setContent(packetData.toGeneric()); + packetContent.eventType = EventType::INFO; + packetContent.data = packetData.serialize(); + + packet.setContent(context, packet::base::SecurityMode::PLAIN, packetContent); // TODO(Faraphel): send back the timestamp too + const auto serializedPacket = packet.serialize(); + // send back our information if (sendto( context.socket, - &packet, - sizeof(packet), + serializedPacket.data(), + serializedPacket.size(), 0, reinterpret_cast(&fromAddress), fromAddressLength diff --git a/source/behavior/events/search/SearchEvent.hpp b/source/behavior/events/search/SearchEvent.hpp index 78264f1..37b9934 100644 --- a/source/behavior/events/search/SearchEvent.hpp +++ b/source/behavior/events/search/SearchEvent.hpp @@ -1,4 +1,5 @@ #pragma once + #include "../base/BaseEvent.hpp" @@ -9,7 +10,7 @@ class SearchEvent : public BaseEvent { public: void handle( Context& context, - const packet::base::PacketContent& content, + std::vector& data, const sockaddr_storage& fromAddress, socklen_t fromAddressLength ) override; diff --git a/source/behavior/tasks/base/BaseTask.hpp b/source/behavior/tasks/base/BaseTask.hpp index 5ae38c1..3f74da7 100644 --- a/source/behavior/tasks/base/BaseTask.hpp +++ b/source/behavior/tasks/base/BaseTask.hpp @@ -6,10 +6,19 @@ namespace drp::task { +/** + * The base to define a task. + * A task is a state for the machine, defining how it shall behave. + */ class BaseTask { public: virtual ~BaseTask() = default; + /** + * The handle of the task. + * Contain the behavior of that specific task. + * @param context the context to use. + */ virtual void handle(Context& context) = 0; }; diff --git a/source/behavior/tasks/server/ServerTask.cpp b/source/behavior/tasks/server/ServerTask.cpp index 77f2142..e3e7fc4 100644 --- a/source/behavior/tasks/server/ServerTask.cpp +++ b/source/behavior/tasks/server/ServerTask.cpp @@ -19,6 +19,15 @@ namespace drp::task { +/* +ServerTask::use(Context& context) { + context.server = serverCandidate; + context.me.status = TaskType::SERVER; + return; +} +*/ + + ServerTask::ServerTask() { this->channels = 0; this->encoding = 0; @@ -62,19 +71,25 @@ void ServerTask::handle(Context& context) { // prepare the packet structure packet::base::Packet packet {}; - packet::AudioPacketData audioPacket; + packet::base::PacketContent packetContent {}; std::size_t done; // create a packet - // TODO(Faraphel): should not be broadcast plaintext + packet::audio::AudioPacketData audioPacket; packet.channel = 0; - packet.securityMode = static_cast(packet::base::SecurityMode::PLAIN); + + // set the audio settings + audioPacket.channels = this->channels; + audioPacket.sampleFormat = util::encoding_mpg123_to_PulseAudio(this->encoding); + audioPacket.sampleRate = this->sampleRate; + + std::vector content(64992); // read the file if (mpg123_read( this->mpgHandle, - &audioPacket.content, - std::size(audioPacket.content), + content.data(), + content.size(), &done ) != MPG123_OK) { std::cerr << "[Task - Server] Could not read audio data from file." << std::endl; @@ -82,27 +97,28 @@ void ServerTask::handle(Context& context) { return; } + // resize the content to fit + content.resize(done); + audioPacket.content = content; + // set the target time // TODO(Faraphel): dynamically change this delay to be the lowest possible audioPacket.timePlay = std::chrono::high_resolution_clock::now() + std::chrono::milliseconds(5000); - // set the size of the content - audioPacket.contentSize = done; + packetContent.eventType = event::EventType::AUDIO; + packetContent.data = audioPacket.serialize(); - // set the audio settings - audioPacket.channels = this->channels; - audioPacket.sampleFormat = util::encoding_mpg123_to_PulseAudio(this->encoding); - audioPacket.sampleRate = this->sampleRate; + packet.setContent(context, packet::base::SecurityMode::PLAIN, packetContent); - packet.setContent(audioPacket.toGeneric()); + const auto serializedPacket = packet.serialize(); // broadcast the audio data if (sendto( context.socket, - &packet, - sizeof(packet), + serializedPacket.data(), + serializedPacket.size(), 0, context.broadcastAddressInfo->ai_addr, context.broadcastAddressInfo->ai_addrlen diff --git a/source/behavior/tasks/server/ServerTask.hpp b/source/behavior/tasks/server/ServerTask.hpp index a9c9251..f86e17d 100644 --- a/source/behavior/tasks/server/ServerTask.hpp +++ b/source/behavior/tasks/server/ServerTask.hpp @@ -15,7 +15,13 @@ namespace drp::task { class ServerTask : public BaseTask { public: explicit ServerTask(); - ~ServerTask(); + ~ServerTask() override; + + /** + * Set this task as the current one. + * @param context the context to apply the state on. + */ + // void use(Context &context); void handle(Context& context) override; diff --git a/source/behavior/tasks/undefined/UndefinedTask.cpp b/source/behavior/tasks/undefined/UndefinedTask.cpp index 8f6c613..b0fa504 100644 --- a/source/behavior/tasks/undefined/UndefinedTask.cpp +++ b/source/behavior/tasks/undefined/UndefinedTask.cpp @@ -76,24 +76,29 @@ void UndefinedTask::handle(Context& context) { // prepare a search message packet::base::Packet packet {}; + packet::base::PacketContent packetContent {}; packet::search::SearchPacketData packetData {}; // broadcast message packet.channel = 0; - packet.securityMode = static_cast(packet::base::SecurityMode::PLAIN); // search message with the time of the message being sent packetData.timestamp = std::chrono::high_resolution_clock::now(); - packet.setContent(packetData.toGeneric()); + packetContent.eventType = event::EventType::SEARCH; + packetContent.data = packetData.serialize(); + + packet.setContent(context, packet::base::SecurityMode::PLAIN, packetContent); std::cout << "[Task - Undefined] Looking for new peers." << std::endl; + const auto serializedPacket = packet.serialize(); + // send the search message if (sendto( context.socket, - &packet, - sizeof(packet), + serializedPacket.data(), + serializedPacket.size(), 0, context.broadcastAddressInfo->ai_addr, context.broadcastAddressInfo->ai_addrlen diff --git a/source/packets/audio/AudioPacketData.cpp b/source/packets/audio/AudioPacketData.cpp new file mode 100644 index 0000000..8552c2e --- /dev/null +++ b/source/packets/audio/AudioPacketData.cpp @@ -0,0 +1,43 @@ +#include "AudioPacketData.hpp" + +#include "utils/serialize/basics.hpp" + + +namespace drp::packet::audio { + + +std::vector AudioPacketData::serialize() const { + // serialize the members + const auto serializedTimePlay = util::serialize::serializeObject(this->timePlay); + const auto serializedChannels = util::serialize::serializeObject(this->channels); + const auto serializedSampleFormat = util::serialize::serializeObject(this->sampleFormat); + const auto serializedSampleRate = util::serialize::serializeObject(this->sampleRate); + const auto serializedContent = util::serialize::serializeVector(this->content); + + // create a buffer to store our members + std::vector data; + + // store our members + data.insert(data.end(), serializedTimePlay.begin(), serializedTimePlay.end()); + data.insert(data.end(), serializedChannels.begin(), serializedChannels.end()); + data.insert(data.end(), serializedSampleFormat.begin(), serializedSampleFormat.end()); + data.insert(data.end(), serializedSampleRate.begin(), serializedSampleRate.end()); + data.insert(data.end(), serializedContent.begin(), serializedContent.end()); + + return data; +} + + +AudioPacketData AudioPacketData::deserialize(std::vector& data) { + // deserialize the members + const auto packetTimePlay = util::serialize::deserializeObject>(data); + const auto packetChannels = util::serialize::deserializeObject(data); + const auto packetSampleFormat = util::serialize::deserializeObject(data); + const auto packetSampleRate = util::serialize::deserializeObject(data); + const auto packetContent = util::serialize::deserializeVector(data); + + return {packetTimePlay, packetChannels, packetSampleFormat, packetSampleRate, packetContent}; +} + + +} diff --git a/source/packets/audio/AudioPacketData.hpp b/source/packets/audio/AudioPacketData.hpp index 7a4480d..65f0830 100644 --- a/source/packets/audio/AudioPacketData.hpp +++ b/source/packets/audio/AudioPacketData.hpp @@ -3,19 +3,19 @@ #include #include -#include "behavior/events/types.hpp" -#include "../base/PacketData.hpp" - -namespace drp::packet { +namespace drp::packet::audio { /** * Represent the content of an audio packet. * Contains a chunk of audio and its metadata to play it. */ -class AudioPacketData : public base::PacketData { +class AudioPacketData { public: + [[nodiscard]] std::vector serialize() const; + static AudioPacketData deserialize(std::vector& data); + // scheduling // TODO(Faraphel): use a more "fixed" size format ? std::chrono::time_point timePlay; @@ -24,8 +24,7 @@ public: std::uint32_t sampleFormat {}; std::uint32_t sampleRate {}; // content - std::uint16_t contentSize {}; - std::array content {}; + std::vector content {}; }; diff --git a/source/packets/base/Packet.cpp b/source/packets/base/Packet.cpp index c80a2ac..fc883af 100644 --- a/source/packets/base/Packet.cpp +++ b/source/packets/base/Packet.cpp @@ -3,64 +3,63 @@ #include #include "SecurityMode.hpp" +#include "utils/serialize/basics.hpp" namespace drp::packet::base { -/* -GenericPacketContent decryptPacketContentAes(const GenericPacket& packet) { - GenericPacketContent decryptedPacketContent {}; +Packet::Packet() = default; - const auto& [key, iv] = keysAes[serverAddress]; - - EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new(); - if (EVP_DecryptInit_ex( - ctx, - EVP_aes_256_cbc(), - nullptr, - key, - iv - ) != 1) - throw std::runtime_error("[Client] Could not initialize the EVP_CIPHER_CTX."); - - int packetContentLength; - - if (EVP_DecryptUpdate( - ctx, - reinterpret_cast(&decryptedPacketContent), - &packetContentLength, - reinterpret_cast(&packet.encryptedContent), - sizeof(packet) - ) != 1) - throw std::runtime_error("[Client] Could not encrypt the plaintext."); - - if (EVP_DecryptFinal_ex( - ctx, - reinterpret_cast(&decryptedPacketContent + packetContentLength), - &packetContentLength - ) != 1) - throw std::runtime_error("[Client] Could not decrypt the final plaintext."); - - EVP_CIPHER_CTX_free(ctx); - - return decryptedPacketContent; +Packet::Packet(const std::uint8_t channel, const SecurityMode securityMode, const std::vector& content) { + this->channel = channel; + this->securityMode = securityMode; + this->content = content; } -*/ -PacketContent Packet::getContent() const { - // TODO(Faraphel): implement RSA and AES - // additional "context" argument to hold cryptographic keys ? +PacketContent Packet::getContent(const Context& context) const { + std::vector content; - switch (static_cast(this->securityMode)) { + switch (this->securityMode) { case SecurityMode::PLAIN: - return this->_content; + // copy the content + content = this->content; + break; case SecurityMode::AES: - // return decryptPacketContentAes(packet); + // decrypt the content + content = context.cryptoAesKey.decrypt(this->content); + break; + + case SecurityMode::RSA: throw std::runtime_error("Not implemented."); + default: + throw std::runtime_error("Unsupported security mode."); + } + + // deserialize the content + return PacketContent::deserialize(content); +} + + +void Packet::setContent(const Context& context, const SecurityMode securityMode, const PacketContent& packetContent) { + this->securityMode = securityMode; + + const std::vector content = packetContent.serialize(); + + switch (this->securityMode) { + case SecurityMode::PLAIN: + // directly save the serialized content + this->content = content; + break; + + case SecurityMode::AES: + // encrypt it with the defined AES key. + this->content = context.cryptoAesKey.encrypt(content); + break; + case SecurityMode::RSA: throw std::runtime_error("Not implemented."); @@ -70,4 +69,31 @@ PacketContent Packet::getContent() const { } +std::vector Packet::serialize() const { + // serialize the members + const auto serializedChannel = util::serialize::serializeObject(this->channel); + const auto serializedSecurityMode = util::serialize::serializeObject(static_cast(this->securityMode)); + const auto serializedContent = util::serialize::serializeVector(this->content); + + // create a buffer to store our members + std::vector data; + + // store our members + data.insert(data.end(), serializedChannel.begin(), serializedChannel.end()); + data.insert(data.end(), serializedSecurityMode.begin(), serializedSecurityMode.end()); + data.insert(data.end(), serializedContent.begin(), serializedContent.end()); + + return data; +} + +Packet Packet::deserialize(std::vector& data) { + // deserialize the members + const auto packetChannel = util::serialize::deserializeObject(data); + const auto packetSecurityMode = static_cast(util::serialize::deserializeObject(data)); + const auto packetContent = util::serialize::deserializeVector(data); + + return Packet(packetChannel, packetSecurityMode, packetContent); +} + + } diff --git a/source/packets/base/Packet.hpp b/source/packets/base/Packet.hpp index fa6352c..63132e2 100644 --- a/source/packets/base/Packet.hpp +++ b/source/packets/base/Packet.hpp @@ -1,9 +1,10 @@ #pragma once -#include #include +#include "Context.hpp" #include "PacketContent.hpp" +#include "SecurityMode.hpp" namespace drp::packet::base { @@ -16,13 +17,22 @@ namespace drp::packet::base { * @param securityMode the type of security used in the packet. * @param _content the content of the packet. It is encrypted accordingly to the securityMode. */ -struct Packet { - std::uint8_t channel; - std::uint8_t securityMode; - PacketContent _content; +class Packet { +public: + Packet(); + explicit Packet(std::uint8_t channel, SecurityMode securityMode, const std::vector& content); - [[nodiscard]] PacketContent getContent() const; - void setContent(const PacketContent& content); + [[nodiscard]] PacketContent getContent(const Context& context) const; + void setContent(const Context& context, SecurityMode securityMode, const PacketContent& packetContent); + + [[nodiscard]] std::vector serialize() const; + static Packet deserialize(std::vector& data); + + std::uint8_t channel {}; + +private: + SecurityMode securityMode {}; + std::vector content; }; diff --git a/source/packets/base/PacketContent.cpp b/source/packets/base/PacketContent.cpp index cfa3cdd..b424d16 100644 --- a/source/packets/base/PacketContent.cpp +++ b/source/packets/base/PacketContent.cpp @@ -1,30 +1,46 @@ #include "PacketContent.hpp" +#include #include #include "Packet.hpp" #include "SecurityMode.hpp" +#include "behavior/events/types.hpp" +#include "utils/serialize/basics.hpp" namespace drp::packet::base { -void Packet::setContent(const PacketContent &content) { - // TODO(Faraphel): implement RSA and AES - switch (static_cast(this->securityMode)) { - case SecurityMode::PLAIN: - this->_content = content; - return; +PacketContent::PacketContent() = default; - case SecurityMode::AES: - throw std::runtime_error("Not implemented."); +PacketContent::PacketContent(const event::EventType eventType, const std::vector& data) { + this->eventType = eventType; + this->data = data; +} - case SecurityMode::RSA: - throw std::runtime_error("Not implemented."); - default: - throw std::runtime_error("Unsupported security mode."); - } +std::vector PacketContent::serialize() const { + // serialize the members + const auto serializedEventType = util::serialize::serializeObject(static_cast(this->eventType)); + const auto serializedData = util::serialize::serializeVector(this->data); + + // create a buffer to store our members + std::vector data; + + // store our members + data.insert(data.end(), serializedEventType.begin(), serializedEventType.end()); + data.insert(data.end(), serializedData.begin(), serializedData.end()); + + return data; +} + +PacketContent PacketContent::deserialize(std::vector& data) { + // deserialize the members + const auto contentEventType = static_cast(util::serialize::deserializeObject(data)); + const auto contentData = util::serialize::deserializeVector(data); + + return {contentEventType, contentData}; } diff --git a/source/packets/base/PacketContent.hpp b/source/packets/base/PacketContent.hpp index cf47ed1..ac15b75 100644 --- a/source/packets/base/PacketContent.hpp +++ b/source/packets/base/PacketContent.hpp @@ -1,16 +1,17 @@ #pragma once -#include #include +#include +#include + +#include "behavior/events/types.hpp" namespace drp::packet::base { -// the maximum data length -// a packet can't be larger than 65565 (uint16 max) -// reserve some space for metadata and settings -constexpr std::uint16_t dataLength = 65504; +// The maximum length of a packet. Cannot be larger than 65565 (uint16 max). +constexpr std::uint16_t maxPacketLength = std::numeric_limits::max(); /** @@ -20,8 +21,14 @@ constexpr std::uint16_t dataLength = 65504; */ class PacketContent { public: - std::uint8_t eventType; - std::array data; + PacketContent(); + PacketContent(event::EventType eventType, const std::vector& data); + + [[nodiscard]] std::vector serialize() const; + static PacketContent deserialize(std::vector& data); + + event::EventType eventType {}; + std::vector data; }; diff --git a/source/packets/base/PacketData.hpp b/source/packets/base/PacketData.hpp deleted file mode 100644 index 6b7f934..0000000 --- a/source/packets/base/PacketData.hpp +++ /dev/null @@ -1,48 +0,0 @@ -#pragma once - -#include - -#include "PacketContent.hpp" -#include "behavior/events/types.hpp" - - -namespace drp::packet::base { - - -/** - * Represent the actual data contained inside a packet, with no header. - * Can be used to implement and communicate anything. - * @tparam eventType the event type associated with this type of packet data. - * Allow the receiver to redirect this packet to the correct handler. - */ -template -class PacketData { -public: - /** - * Convert this packet data to a generic packet content. - * @return a generic packet content. - */ - [[nodiscard]] PacketContent toGeneric() const { - // create an empty generic packet content - PacketContent content {}; - // set its content - content.eventType = static_cast(eventType); - std::memcpy(content.data.data(), this, content.data.size()); - - return content; - } - - /** - * Get the data from a generic packet data. - * @param content a generic packet content. - * @return the actual packet data. - */ - static packetClass fromGeneric(const PacketContent& content) { - packetClass data; - std::memcpy(&data, content.data.data(), sizeof(packetClass)); - return data; - } -}; - - -} diff --git a/source/packets/info/InfoPacketData.cpp b/source/packets/info/InfoPacketData.cpp new file mode 100644 index 0000000..6514363 --- /dev/null +++ b/source/packets/info/InfoPacketData.cpp @@ -0,0 +1,23 @@ +#include "InfoPacketData.hpp" + + +namespace drp::packet::info { + + +std::vector InfoPacketData::serialize() const { + // serialize the members + const auto serializedPeer = this->peer.serialize(); + + return serializedPeer; +} + + +InfoPacketData InfoPacketData::deserialize(std::vector& data) { + // deserialize the members + const auto packetPeer = Peer::deserialize(data); + + return {packetPeer}; +} + + +} \ No newline at end of file diff --git a/source/packets/info/InfoPacketData.hpp b/source/packets/info/InfoPacketData.hpp index 7a137fd..72c534d 100644 --- a/source/packets/info/InfoPacketData.hpp +++ b/source/packets/info/InfoPacketData.hpp @@ -1,8 +1,6 @@ #pragma once #include "RemotePeer.hpp" -#include "behavior/events/types.hpp" -#include "../base/PacketData.hpp" namespace drp::packet::info { @@ -12,8 +10,11 @@ namespace drp::packet::info { * Represent the content of an info packet. * Contains information about the peer sending it. */ -class InfoPacketData : public base::PacketData { +class InfoPacketData { public: + [[nodiscard]] std::vector serialize() const; + static InfoPacketData deserialize(std::vector& data); + Peer peer; }; diff --git a/source/packets/search/SearchPacketData.cpp b/source/packets/search/SearchPacketData.cpp new file mode 100644 index 0000000..9865bdf --- /dev/null +++ b/source/packets/search/SearchPacketData.cpp @@ -0,0 +1,25 @@ +#include "SearchPacketData.hpp" + +#include "utils/serialize/basics.hpp" + + +namespace drp::packet::search { + + +std::vector SearchPacketData::serialize() const { + // serialize the members + const auto serializedTimestamp = util::serialize::serializeObject(this->timestamp); + + return serializedTimestamp; +} + + +SearchPacketData SearchPacketData::deserialize(std::vector& data) { + // deserialize the members + const auto packetTimestamp = util::serialize::deserializeObject>(data); + + return SearchPacketData(packetTimestamp); +} + + +} diff --git a/source/packets/search/SearchPacketData.hpp b/source/packets/search/SearchPacketData.hpp index dcd3297..216fecb 100644 --- a/source/packets/search/SearchPacketData.hpp +++ b/source/packets/search/SearchPacketData.hpp @@ -2,8 +2,6 @@ #include -#include "../base/PacketData.hpp" - namespace drp::packet::search { @@ -12,8 +10,11 @@ namespace drp::packet::search { * Represent a discovery request. * Sent by someone to get information about other available machine in the network. */ -class SearchPacketData : public base::PacketData { +class SearchPacketData { public: + [[nodiscard]] std::vector serialize() const; + static SearchPacketData deserialize(std::vector& data); + std::chrono::time_point timestamp; /// timestamp when the search request was sent }; diff --git a/source/test.cpp b/source/test.cpp new file mode 100644 index 0000000..212fcf0 --- /dev/null +++ b/source/test.cpp @@ -0,0 +1,119 @@ +#include +#include + +#include "Peer.hpp" +#include "utils/crypto/aes/AesKey.hpp" +#include "utils/crypto/rsa/RsaKeyPair.hpp" +#include "utils/serialize/basics.hpp" + + +int mainAes() { + const auto aesKey = drp::util::crypto::AesKey256(); + + // plain + std::string text = "hello world!"; + const std::vector plainData(text.begin(), text.end()); + std::cout << "plain: "; + for (const auto& byte : plainData) + std::cout << std::to_string(byte) << "-"; + std::cout << std::endl; + + // encrypted + const auto encryptedData = aesKey.encrypt(plainData); + std::cout << "encrypted: "; + for (const auto& byte : encryptedData) + std::cout << std::to_string(byte) << "-"; + std::cout << std::endl; + + // decrypted + const auto decryptedData = aesKey.decrypt(encryptedData); + std::cout << "decrypted: "; + for (const auto& byte : decryptedData) + std::cout << std::to_string(byte) << "-"; + std::cout << std::endl; + + return 0; +} + + +int mainRsa() { + const auto rsaKey = drp::util::crypto::RsaKeyPair(2048); + const auto rsaPrivateKey = rsaKey.getPrivateKey(); + const auto rsaPublicKey = rsaKey.getPublicKey(); + + // plain + std::string text = "hello world!"; + const std::vector plainData(text.begin(), text.end()); + std::cout << "plain: "; + for (const auto& byte : plainData) + std::cout << std::to_string(byte) << "-"; + std::cout << std::endl; + + // encrypted + const auto encryptedData = rsaPublicKey.encrypt(plainData); + std::cout << "encrypted: "; + for (const auto& byte : encryptedData) + std::cout << std::to_string(byte) << "-"; + std::cout << std::endl; + + // decrypted + const auto decryptedData = rsaPrivateKey.decrypt(encryptedData); + std::cout << "decrypted: "; + for (const auto& byte : decryptedData) + std::cout << std::to_string(byte) << "-"; + std::cout << std::endl; + + return 0; +} + + +int mainSerialize() { + std::string text = "hello world!"; + const std::vector plainData(text.begin(), text.end()); + + std::cout << "plain: "; + for (const auto& byte : plainData) + std::cout << std::to_string(byte) << "-"; + std::cout << std::endl; + + std::vector serializedData = drp::util::serialize::serializeVector(plainData); + + std::cout << "serialized: "; + for (const auto& byte : serializedData) + std::cout << std::to_string(byte) << "-"; + std::cout << std::endl; + + std::vector deserializedData = drp::util::serialize::deserializeVector(serializedData); + + std::cout << "deserialized: "; + for (const auto& byte : deserializedData) + std::cout << std::to_string(byte) << "-"; + std::cout << std::endl; + + return 0; +} + + +int mainRsaSerialize() { + Peer peer; + auto serializedPeer = peer.serialize(); + + std::cout << "serialized: "; + for (const auto& byte : serializedPeer) + std::cout << std::to_string(byte) << "-"; + std::cout << std::endl; + + const auto deserializedPeer = Peer::deserialize(serializedPeer); + + return 0; +} + + +int main_test() { + // mainAes(); + // mainRsa(); + // mainSerialize(); + // mainRsaSerialize(); + + return 0; +} \ No newline at end of file diff --git a/source/utils/crypto.cpp b/source/utils/crypto.cpp deleted file mode 100644 index 05f8386..0000000 --- a/source/utils/crypto.cpp +++ /dev/null @@ -1 +0,0 @@ -#include "crypto.hpp" diff --git a/source/utils/crypto.hpp b/source/utils/crypto.hpp deleted file mode 100644 index 3df5127..0000000 --- a/source/utils/crypto.hpp +++ /dev/null @@ -1,70 +0,0 @@ -#pragma once - -#include -#include -#include -#include -#include -#include - - -template -std::pair, std::array> newRsaKeys() { - // create the context - const auto context = std::unique_ptr( - EVP_PKEY_CTX_new_id(EVP_PKEY_RSA, nullptr), - EVP_PKEY_CTX_free - ); - - if (context == nullptr) - throw std::runtime_error("Could not create an EVP context."); - - if (EVP_PKEY_keygen_init(context.get()) <= 0) - throw std::runtime_error("Could not initialize the EVP context."); - - // configure the context - if (EVP_PKEY_CTX_set_rsa_keygen_bits(context.get(), size) <= 0) - throw std::runtime_error("Error setting RSA key size."); - - // create the private key - EVP_PKEY* rawKeyPair = nullptr; - if (EVP_PKEY_keygen(context.get(), &rawKeyPair) <= 0) - throw std::runtime_error("Could not generate RSA private key."); - - const std::unique_ptr keyPair( - rawKeyPair, - EVP_PKEY_free - ); - - // extract the private and public key - const std::shared_ptr privateBio( - BIO_new(BIO_s_mem()), - BIO_free - ); - - if (!PEM_write_bio_PrivateKey( - privateBio.get(), - keyPair.get(), - nullptr, - nullptr, - 0, - nullptr, - nullptr - )) - throw std::runtime_error("Could not generate RSA private key."); - - std::array privateKey; - BIO_read(privateBio.get(), privateKey.data(), BIO_pending(privateBio.get())); - - const std::shared_ptr publicBio( - BIO_new(BIO_s_mem()), - BIO_free - ); - if (!PEM_write_bio_PUBKEY(publicBio.get(), keyPair.get())) - throw std::runtime_error("Could not generate RSA public key."); - - std::array publicKey; - BIO_read(publicBio.get(), publicKey.data(), BIO_pending(publicBio.get())); - - return {privateKey, publicKey}; -} diff --git a/source/utils/crypto/aes/AesKey.cpp b/source/utils/crypto/aes/AesKey.cpp new file mode 100644 index 0000000..d330336 --- /dev/null +++ b/source/utils/crypto/aes/AesKey.cpp @@ -0,0 +1 @@ +#include "AesKey.hpp" diff --git a/source/utils/crypto/aes/AesKey.hpp b/source/utils/crypto/aes/AesKey.hpp new file mode 100644 index 0000000..d4e53cc --- /dev/null +++ b/source/utils/crypto/aes/AesKey.hpp @@ -0,0 +1,166 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + + +namespace drp::util::crypto { + + +/** + * Represent an AES key. + * Allow for encrypting and decrypting data. + * @tparam keySize the size of the key (in bytes) + * @tparam ivSize the size of the initialisation vector (in bytes) + */ +template +class AesKey { +public: + /** + * Create a random AES key. + */ + AesKey() { + // generate a random key + for (auto& byte : this->_data) + byte = randomDistribution(randomGenerator); + } + + explicit AesKey(const std::array& data) { + this->_data = data; + } + + /** + * Encrypt data with this key. + * @param plainData the data to encrypt + * @return the encrypted data. It will always be longer than the original data. + */ + [[nodiscard]] std::vector encrypt(const std::vector& plainData) const { + // create an initialization vector + std::array iv {}; + for (auto& byte : iv) + byte = randomDistribution(randomGenerator); + + // create the cipher context + const auto context = std::unique_ptr( + EVP_CIPHER_CTX_new(), + EVP_CIPHER_CTX_free + ); + + if (context == nullptr) + throw std::runtime_error("Error creating EVP_CIPHER_CTX"); + + // initialize the encryptor + if (EVP_EncryptInit_ex( + context.get(), + cipherFunction(), + nullptr, + this->_data.data(), + iv.data() + ) != 1) + throw std::runtime_error("Error initializing encryption"); + + std::vector encryptedData(ivSize + plainData.size() + EVP_CIPHER_block_size(EVP_aes_256_cbc())); + std::copy(iv.begin(), iv.end(), encryptedData.begin()); + int length; + + // encrypt the data + if (EVP_EncryptUpdate( + context.get(), + encryptedData.data() + ivSize, + &length, + plainData.data(), + static_cast(plainData.size()) + ) != 1) + throw std::runtime_error("Error encrypting data"); + int encryptedDataLength = length; + + // finalize the encryption + if (EVP_EncryptFinal_ex( + context.get(), + encryptedData.data() + ivSize + encryptedDataLength, + &length + ) != 1) + throw std::runtime_error("Error finalizing encryption"); + encryptedDataLength += length; + encryptedData.resize(ivSize + encryptedDataLength); + + return encryptedData; + } + + /** + * Decrypt data with this key. + * @param rawEncryptedData the encrypted data to decrypt. + * @return the decrypted data. + */ + [[nodiscard]] std::vector decrypt(const std::vector& rawEncryptedData) const { + // create a cipher context + const auto context = std::unique_ptr( + EVP_CIPHER_CTX_new(), + EVP_CIPHER_CTX_free + ); + + std::array iv; + std::copy(rawEncryptedData.begin(), rawEncryptedData.begin() + ivSize, iv.data()); + const std::vector encryptedData(rawEncryptedData.begin() + ivSize, rawEncryptedData.end()); + + // initialize the decryptor + if (EVP_DecryptInit_ex( + context.get(), + cipherFunction(), + nullptr, + this->_data.data(), + iv.data() + ) != 1) + throw std::runtime_error("Error initializing decryptor"); + + std::vector plainData(encryptedData.size()); + int length; + + // decrypt the data + if (EVP_DecryptUpdate( + context.get(), + plainData.data(), + &length, + encryptedData.data(), + static_cast(encryptedData.size()) + ) != 1) + throw std::runtime_error("Error decrypting data"); + + // finalize the decryption + if (EVP_DecryptFinal_ex( + context.get(), + plainData.data(), + &length + ) != 1) + throw std::runtime_error("Error finalizing decryptor"); + plainData.resize(length); + + return plainData; + } + +private: + static std::mt19937 randomGenerator; + static std::uniform_int_distribution randomDistribution; + + std::array _data; +}; + + +template +std::mt19937 AesKey::randomGenerator = std::mt19937(std::random_device{}()); + +template +std::uniform_int_distribution AesKey::randomDistribution = std::uniform_int_distribution( + std::numeric_limits::min(), + std::numeric_limits::max() +); + + +using AesKey256 = AesKey<256/8, 16, EVP_aes_256_cbc>; + + +} \ No newline at end of file diff --git a/source/utils/crypto/rsa/RsaKeyPair.cpp b/source/utils/crypto/rsa/RsaKeyPair.cpp new file mode 100644 index 0000000..7694996 --- /dev/null +++ b/source/utils/crypto/rsa/RsaKeyPair.cpp @@ -0,0 +1,98 @@ +#include "RsaKeyPair.hpp" + +#include + + +namespace drp::util::crypto { + + +RsaKeyPair::RsaKeyPair(const std::size_t size, const int padMode) { + // create the context + const std::unique_ptr context( + EVP_PKEY_CTX_new_id(EVP_PKEY_RSA, nullptr), + EVP_PKEY_CTX_free + ); + + if (context == nullptr) + throw std::runtime_error("Could not create an EVP context."); + + if (EVP_PKEY_keygen_init(context.get()) <= 0) + throw std::runtime_error("Could not initialize the EVP context."); + + // configure the context + if (EVP_PKEY_CTX_set_rsa_keygen_bits(context.get(), static_cast(size)) <= 0) + throw std::runtime_error("Error setting RSA key size."); + + // create the key pair + EVP_PKEY* rawKeyPair = nullptr; + if (EVP_PKEY_keygen(context.get(), &rawKeyPair) <= 0) + throw std::runtime_error("Could not generate RSA key pair."); + if (rawKeyPair == nullptr) + throw std::runtime_error("Could not generate RSA key pair."); + + std::unique_ptr keyPair( + rawKeyPair, + EVP_PKEY_free + ); + + // extract the private key + const std::unique_ptr privateBio( + BIO_new(BIO_s_mem()), + BIO_free + ); + if (privateBio == nullptr) + throw std::runtime_error("Could not create RSA private key."); + + if (!i2d_PrivateKey_bio( + privateBio.get(), + keyPair.get() + )) + throw std::runtime_error("Could not generate RSA private key."); + + std::vector privateKeyData(BIO_pending(privateBio.get())); + if (BIO_read( + privateBio.get(), + privateKeyData.data(), + static_cast(privateKeyData.size()) + ) <= 0) + throw std::runtime_error("Could not read RSA private key."); + + this->privateKey = RsaPrivateKey(privateKeyData, padMode); + + // extract the public key + const std::unique_ptr publicBio( + BIO_new(BIO_s_mem()), + BIO_free + ); + if (publicBio == nullptr) + throw std::runtime_error("Could not create RSA public key."); + + if (!i2d_PUBKEY_bio( + publicBio.get(), + keyPair.get() + )) + throw std::runtime_error("Could not generate RSA public key."); + + std::vector publicKeyData(BIO_pending(publicBio.get())); + if (BIO_read( + publicBio.get(), + publicKeyData.data(), + static_cast(publicKeyData.size()) + ) <= 0) + throw std::runtime_error("Could not read RSA public key."); + + this->publicKey = RsaPublicKey(publicKeyData, padMode); +} + + +RsaPublicKey RsaKeyPair::getPublicKey() const { + return this->publicKey; +} + + +RsaPrivateKey RsaKeyPair::getPrivateKey() const { + return this->privateKey; +} + + +} diff --git a/source/utils/crypto/rsa/RsaKeyPair.hpp b/source/utils/crypto/rsa/RsaKeyPair.hpp new file mode 100644 index 0000000..36dbbb5 --- /dev/null +++ b/source/utils/crypto/rsa/RsaKeyPair.hpp @@ -0,0 +1,44 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include "RsaPrivateKey.hpp" +#include "RsaPublicKey.hpp" + + +namespace drp::util::crypto { + + +/** + * Represent a pair of RSA key. + */ +class RsaKeyPair { +public: + /** + * Generate a pair of public and private RSA keys. + */ + explicit RsaKeyPair(std::size_t size, int padMode = RSA_PKCS1_OAEP_PADDING); + + /** + * Get the public key. + * @return the public key. + */ + [[nodiscard]] RsaPublicKey getPublicKey() const; + + /** + * Get the private key. + * @return the private key. + */ + [[nodiscard]] RsaPrivateKey getPrivateKey() const; + +private: + RsaPrivateKey privateKey; + RsaPublicKey publicKey; +}; + + +} \ No newline at end of file diff --git a/source/utils/crypto/rsa/RsaPrivateKey.cpp b/source/utils/crypto/rsa/RsaPrivateKey.cpp new file mode 100644 index 0000000..fb71230 --- /dev/null +++ b/source/utils/crypto/rsa/RsaPrivateKey.cpp @@ -0,0 +1,121 @@ +#include "RsaPrivateKey.hpp" + +#include +#include +#include + +#include "utils/serialize/basics.hpp" + + +namespace drp::util::crypto { + + +RsaPrivateKey::RsaPrivateKey() = default; + +RsaPrivateKey::RsaPrivateKey(const std::vector& data, const int padMode) { + this->_data = data; + this->padMode = padMode; +} + + +std::vector RsaPrivateKey::decrypt(const std::vector& encryptedData) const { + const auto key = this->getOpenSslKey(); + + // initialize the encryption context + const std::unique_ptr context( + EVP_PKEY_CTX_new(key.get(), nullptr), + EVP_PKEY_CTX_free + ); + + if (context == nullptr) + throw std::runtime_error("Could not create EVP_PKEY_CTX."); + + // initialize the encryption operation + if (EVP_PKEY_decrypt_init(context.get()) <= 0) + throw std::runtime_error("Could not initialize encryption."); + + // set the padding + if (EVP_PKEY_CTX_set_rsa_padding(context.get(), RSA_PKCS1_OAEP_PADDING) <= 0) + throw std::runtime_error("Could not set RSA padding."); + + // get the size of the output buffer + std::size_t decryptedDataLength; + if (EVP_PKEY_decrypt( + context.get(), + nullptr, + &decryptedDataLength, + encryptedData.data(), + encryptedData.size() + ) <= 0) + throw std::runtime_error("Could not determine output length."); + + std::vector decryptedData(decryptedDataLength); + + // encrypt the data + if (EVP_PKEY_decrypt( + context.get(), + decryptedData.data(), + &decryptedDataLength, + encryptedData.data(), + encryptedData.size() + ) <= 0) + throw std::runtime_error("Could not decrypt data."); + + decryptedData.resize(decryptedDataLength); + + return decryptedData; +} + + +std::vector RsaPrivateKey::serialize() const { + // serialize the members + const auto serializedData = serialize::serializeVector(this->_data); + const auto serializedPadMode = serialize::serializeObject(static_cast(this->padMode)); + + // create a buffer to store our members + std::vector data; + + // store our members + data.insert(data.end(), serializedData.begin(), serializedData.end()); + data.insert(data.end(), serializedPadMode.begin(), serializedPadMode.end()); + + return data; +} + +RsaPrivateKey RsaPrivateKey::deserialize(std::vector& data) { + // deserialize the members + const auto keyData = serialize::deserializeVector(data); + const auto keyPadding = static_cast(serialize::deserializeObject(data)); + + return RsaPrivateKey(keyData, keyPadding); +} + + +[[nodiscard]] std::unique_ptr RsaPrivateKey::getOpenSslKey() const { + // get the bio from the private key data + const std::unique_ptr bio( + BIO_new_mem_buf( + this->_data.data(), + static_cast(this->_data.size()) + ), + BIO_free + ); + if (bio == nullptr) + throw std::runtime_error("Could not create BIO for private key."); + + // get the key from the bio + EVP_PKEY* rawKey = nullptr; + if (!d2i_PrivateKey_bio( + bio.get(), + &rawKey + )) + throw std::runtime_error("Could not deserialize RSA private key."); + + return { + rawKey, + EVP_PKEY_free + }; +} + + +} diff --git a/source/utils/crypto/rsa/RsaPrivateKey.hpp b/source/utils/crypto/rsa/RsaPrivateKey.hpp new file mode 100644 index 0000000..6476fde --- /dev/null +++ b/source/utils/crypto/rsa/RsaPrivateKey.hpp @@ -0,0 +1,32 @@ +#pragma once + +#include +#include +#include + + +namespace drp::util::crypto { + + +/** + * Represent an RSA private key. + */ +class RsaPrivateKey { +public: + RsaPrivateKey(); + explicit RsaPrivateKey(const std::vector& data, int padMode); + + [[nodiscard]] std::vector decrypt(const std::vector& encryptedData) const; + + [[nodiscard]] std::vector serialize() const; + static RsaPrivateKey deserialize(std::vector& data); + +private: + [[nodiscard]] std::unique_ptr getOpenSslKey() const; + + int padMode {}; + std::vector _data; +}; + + +} diff --git a/source/utils/crypto/rsa/RsaPublicKey.cpp b/source/utils/crypto/rsa/RsaPublicKey.cpp new file mode 100644 index 0000000..ef36603 --- /dev/null +++ b/source/utils/crypto/rsa/RsaPublicKey.cpp @@ -0,0 +1,121 @@ +#include "RsaPublicKey.hpp" + +#include +#include +#include + +#include "utils/serialize/basics.hpp" + + +namespace drp::util::crypto { + + +RsaPublicKey::RsaPublicKey() = default; + +RsaPublicKey::RsaPublicKey(const std::vector& data, const int padMode) { + this->_data = data; + this->padMode = padMode; +} + + +[[nodiscard]] std::vector RsaPublicKey::encrypt(const std::vector& plainData) const { + const auto key = this->getOpenSslKey(); + + // initialize the encryption context + const std::unique_ptr context( + EVP_PKEY_CTX_new(key.get(), nullptr), + EVP_PKEY_CTX_free + ); + + if (context == nullptr) + throw std::runtime_error("Could not create EVP_PKEY_CTX."); + + // initialize the encryption operation + if (EVP_PKEY_encrypt_init(context.get()) <= 0) + throw std::runtime_error("Could not initialize encryption."); + + // set the padding + if (EVP_PKEY_CTX_set_rsa_padding(context.get(), this->padMode) <= 0) + throw std::runtime_error("Could not set RSA padding."); + + // get the size of the output buffer + std::size_t encryptedDataLength; + if (EVP_PKEY_encrypt( + context.get(), + nullptr, + &encryptedDataLength, + plainData.data(), + plainData.size() + ) <= 0) + throw std::runtime_error("Could not determine output length."); + + std::vector encryptedData(encryptedDataLength); + + // encrypt the data + if (EVP_PKEY_encrypt( + context.get(), + encryptedData.data(), + &encryptedDataLength, + plainData.data(), + plainData.size() + ) <= 0) + throw std::runtime_error("Could not encrypt data."); + + encryptedData.resize(encryptedDataLength); + + return encryptedData; +} + + +std::vector RsaPublicKey::serialize() const { + // serialize the members + const auto serializedData = serialize::serializeVector(this->_data); + const auto serializedPadMode = serialize::serializeObject(static_cast(this->padMode)); + + // create a buffer to store our members + std::vector data; + + // store our members + data.insert(data.end(), serializedData.begin(), serializedData.end()); + data.insert(data.end(), serializedPadMode.begin(), serializedPadMode.end()); + + return data; +} + +RsaPublicKey RsaPublicKey::deserialize(std::vector& data) { + // deserialize the members + const auto keyData = serialize::deserializeVector(data); + const auto keyPadding = static_cast(serialize::deserializeObject(data)); + + return RsaPublicKey(keyData, keyPadding); +} + + +[[nodiscard]] std::unique_ptr RsaPublicKey::getOpenSslKey() const { + // get the bio from the public key data + const std::unique_ptr bio( + BIO_new_mem_buf( + this->_data.data(), + static_cast(this->_data.size()) + ), + BIO_free + ); + if (bio == nullptr) + throw std::runtime_error("Could not create BIO for public key."); + + // get the key from the bio + EVP_PKEY* rawKey = nullptr; + if (!d2i_PUBKEY_bio( + bio.get(), + &rawKey + )) + throw std::runtime_error("Could not deserialize RSA public key."); + + return { + rawKey, + EVP_PKEY_free + }; +} + + +} diff --git a/source/utils/crypto/rsa/RsaPublicKey.hpp b/source/utils/crypto/rsa/RsaPublicKey.hpp new file mode 100644 index 0000000..b97eba3 --- /dev/null +++ b/source/utils/crypto/rsa/RsaPublicKey.hpp @@ -0,0 +1,37 @@ +#pragma once + +#include + +#include "RsaPrivateKey.hpp" + + +namespace drp::util::crypto { + + +/** + * Represent an RSA public key. + */ +class RsaPublicKey { +public: + RsaPublicKey(); + explicit RsaPublicKey(const std::vector& data, int padMode); + + /** + * Encrypt data with the public key. Can only be decrypted with the corresponding private key. + * @param plainData the plain data. + * @return the encrypted data. + */ + [[nodiscard]] std::vector encrypt(const std::vector& plainData) const; + + [[nodiscard]] std::vector serialize() const; + static RsaPublicKey deserialize(std::vector& data); + +private: + [[nodiscard]] std::unique_ptr getOpenSslKey() const; + + int padMode {}; + std::vector _data; +}; + + +} diff --git a/source/utils/serialize/basics.cpp b/source/utils/serialize/basics.cpp new file mode 100644 index 0000000..e69de29 diff --git a/source/utils/serialize/basics.hpp b/source/utils/serialize/basics.hpp new file mode 100644 index 0000000..e38ca73 --- /dev/null +++ b/source/utils/serialize/basics.hpp @@ -0,0 +1,96 @@ +#pragma once + +#include +#include +#include + + +namespace drp::util::serialize { + + +/** + * Serialize a basic object to a vector of bytes + * @tparam Type the type of the object. + * @param object the object to serialize. + * @return the object as a vector of bytes. + */ +template +std::vector serializeObject(const Type& object) { + // create a vector with enough space for the object + std::vector buffer(sizeof(Type)); + // copy the object data to the buffer + std::memcpy(buffer.data(), &object, sizeof(Type)); + + return buffer; +} + + +/** + * Deserialize a vector of bytes into an object. + * @warning the data used in parameter will be stripped of the data used to deserialize the object. + * @tparam Type the type of the object + * @param data the data of the object + * @return the object + */ +template +Type deserializeObject(std::vector& data) { + // create an object based on the data + Type object; + std::memcpy(&object, data.data(), sizeof(Type)); + + // remove the space used by the object in the data + data.erase(data.begin(), data.begin() + sizeof(Type)); + + return object; +} + + +/** + * Serialize a vector of anything to a vector of bytes. + * @tparam Type the type of the data contained in the vector. + * @tparam SizeType the range of the size of the vector. + * @param object the vector to serialize. + * @return the vector data as a vector of bytes. + */ +template +std::vector serializeVector(const std::vector& object) { + // create a vector with enough size for the size and the data of the vector + std::vector buffer(sizeof(SizeType) + object.size() * sizeof(Type)); + + // save the size of the vector + auto size = static_cast(object.size()); + std::memcpy(buffer.data(), &size, sizeof(SizeType)); + // save the content of the vector + std::memcpy(buffer.data() + sizeof(SizeType), object.data(), object.size() * sizeof(Type)); + + return buffer; +} + + +/** + * Deserialize a vector of bytes into a vector of object. + * @warning the data used in parameter will be stripped of the data used to deserialize the object. + * @tparam Type the type of the object + * @tparam SizeType the type of the size of the vector + * @param data the data in the vector + * @return the vector of object + */ +template +std::vector deserializeVector(std::vector& data) { + // get the size of the vector + SizeType size; + std::memcpy(&size, data.data(), sizeof(SizeType)); + // create a vector with enough size of the data + std::vector object(size); + + // restore the data into the object + std::memcpy(object.data(), data.data() + sizeof(SizeType), size * sizeof(Type)); + + // remove the data used for the deserialization from the source + data.erase(data.begin(), data.begin() + sizeof(SizeType) + size * sizeof(Type)); + + return object; +} + + +}