added backend for AES / RSA encryption, implemented variable size packets.

This commit is contained in:
study-faraphel 2024-11-26 09:29:56 +01:00
parent 523c86237c
commit 61a57a7529
46 changed files with 1275 additions and 282 deletions

View file

@ -54,7 +54,6 @@ add_executable(M2-PT-DRP
source/packets/base/PacketContent.cpp
source/packets/base/PacketContent.hpp
source/packets/base/SecurityMode.hpp
source/packets/base/PacketData.hpp
source/packets/info/InfoPacketData.hpp
source/utils/time/Chrony.cpp
source/utils/time/Chrony.hpp
@ -64,8 +63,21 @@ add_executable(M2-PT-DRP
source/Peer.cpp
source/RemotePeer.cpp
source/Context.cpp
source/utils/crypto.cpp
source/utils/crypto.hpp
source/test.cpp
source/utils/crypto/aes/AesKey.cpp
source/utils/crypto/aes/AesKey.hpp
source/utils/crypto/rsa/RsaPublicKey.cpp
source/utils/crypto/rsa/RsaPublicKey.hpp
source/utils/crypto/rsa/RsaPrivateKey.cpp
source/utils/crypto/rsa/RsaPrivateKey.hpp
source/utils/crypto/rsa/RsaPrivateKey.hpp
source/utils/crypto/rsa/RsaKeyPair.cpp
source/utils/crypto/rsa/RsaKeyPair.hpp
source/utils/serialize/basics.cpp
source/utils/serialize/basics.hpp
source/packets/audio/AudioPacketData.cpp
source/packets/info/InfoPacketData.cpp
source/packets/search/SearchPacketData.cpp
)
target_include_directories(M2-PT-DRP PRIVATE
source

View file

@ -1,8 +1,13 @@
#include "Context.hpp"
#include "utils/crypto/rsa/RsaKeyPair.hpp"
Context::Context(const std::array<std::uint8_t, 2048> &privateKey, const std::array<std::uint8_t, 2048> &publicKey) : me(publicKey) {
this->cryptoRsaPrivateKey = privateKey;
Context::Context() {
const auto keyPair = drp::util::crypto::RsaKeyPair(2048);
this->me = Peer(keyPair.getPublicKey());
this->cryptoRsaPrivateKey = keyPair.getPrivateKey();
this->socket = -1;
this->broadcastAddressInfo = nullptr;

View file

@ -5,6 +5,8 @@
#include <netdb.h>
#include "RemotePeer.hpp"
#include "utils/crypto/aes/AesKey.hpp"
#include "utils/crypto/rsa/RsaPrivateKey.hpp"
/**
@ -14,7 +16,7 @@
*/
class Context {
public:
explicit Context(const std::array<std::uint8_t, 2048>& privateKey, const std::array<std::uint8_t, 2048>& publicKey);
explicit Context();
int socket; /// current socket file descriptor, used to communicate
addrinfo* broadcastAddressInfo; /// address used to broadcast messages
@ -24,5 +26,6 @@ public:
std::list<std::shared_ptr<RemotePeer>> remotePeers {}; /// information about other machines
std::chrono::high_resolution_clock::time_point latestPeerDiscovery; /// time of the latest discovered machine
std::array<std::uint8_t, 2048> cryptoRsaPrivateKey {}; /// the RSA private key
drp::util::crypto::RsaPrivateKey cryptoRsaPrivateKey {}; /// the RSA private key
drp::util::crypto::AesKey256 cryptoAesKey = {}; /// the AES secret key
};

View file

@ -20,14 +20,14 @@
#include "behavior/tasks/client/ClientTask.hpp"
#include "behavior/tasks/server/ServerTask.hpp"
#include "behavior/tasks/undefined/UndefinedTask.hpp"
#include "utils/crypto.hpp"
#include "utils/crypto/aes/AesKey.hpp"
#include "utils/crypto/rsa/RsaKeyPair.hpp"
Manager::Manager(const std::string& address, const std::string& port, const bool useIpv6) {
std::cout << "Broadcast address: " << address << ":" << port << " (" << (useIpv6 ? "IPv6" : "IPv4") << ")" << std::endl;
auto [privateKey, publicKey] = newRsaKeys<2048>();
this->context = std::make_shared<Context>(privateKey, publicKey);
this->context = std::make_shared<Context>();
// register the different events type
this->eventRegistry = {
@ -61,7 +61,7 @@ Manager::Manager(const std::string& address, const std::string& port, const bool
throw std::runtime_error("Could not create the socket: " + std::string(strerror(errno)));
// allow IPv6 multicast loopback so that we can receive our own messages.
const int socketLoopback = 1;
constexpr int socketLoopback = 1;
if (setsockopt(
context->socket,
IPPROTO_IPV6,
@ -123,7 +123,7 @@ void Manager::loop() {
}
void Manager::loopSender() {
void Manager::loopSender() const {
while (true) {
std::cout << "[Sender] Handling status: " + std::to_string(static_cast<int>(this->context->me.status)) << std::endl;
@ -142,20 +142,20 @@ void Manager::loopSender() {
}
void Manager::loopReceiver() {
void Manager::loopReceiver() const {
// prepare space for the sender address
sockaddr_storage fromAddress {};
socklen_t fromAddressLength = sizeof(fromAddress);
drp::packet::base::Packet packet {};
drp::packet::base::PacketContent packetContent {};
std::array<std::uint8_t, drp::packet::base::maxPacketLength> buffer {};
// client loop
while (true) {
// receive new data
const ssize_t size = recvfrom(
this->context->socket,
&packet,
sizeof(packet),
buffer.data(),
buffer.size(),
0,
reinterpret_cast<sockaddr*>(&fromAddress),
&fromAddressLength
@ -163,13 +163,17 @@ void Manager::loopReceiver() {
if (size == -1)
throw std::runtime_error("[Receiver] Could not receive the packet: " + std::string(strerror(errno)));
// deserialize the packet
std::vector data(buffer.begin(), buffer.end());
const auto packet = drp::packet::base::Packet::deserialize(data);
// if the packet channel is neither 0 (all) nor the current one, ignore it
if (packet.channel != 0 && packet.channel != this->context->me.channel)
continue;
// decrypt the packet
// TODO(Faraphel): handle exception ?
packetContent = packet.getContent();
drp::packet::base::PacketContent packetContent = packet.getContent(*this->context);
// look for a saved peer with the same address
auto remotePeer = std::ranges::find_if(
@ -187,18 +191,18 @@ void Manager::loopReceiver() {
// get the corresponding event class
std::shared_ptr<drp::event::BaseEvent> event;
try {
event = this->eventRegistry.at(static_cast<drp::event::EventType>(packetContent.eventType));
event = this->eventRegistry.at(packetContent.eventType);
} catch (const std::out_of_range& exception) {
std::cerr << "[Receiver] Unsupported event type." << std::endl;
continue;
}
std::cout << "[Receiver] handling event: " << std::to_string(packetContent.eventType) << std::endl;
std::cout << "[Receiver] handling event: " << static_cast<std::uint8_t>(packetContent.eventType) << std::endl;
// ask the event class to handle the event
event->handle(
*this->context,
packetContent,
packetContent.data,
fromAddress,
fromAddressLength
);

View file

@ -21,8 +21,8 @@ public:
Manager(const std::string& address, const std::string& port, bool useIpv6 = false);
void loop();
void loopSender();
void loopReceiver();
void loopSender() const;
void loopReceiver() const;
private:
std::thread senderThread; /// the thread sending communication

View file

@ -1,10 +1,12 @@
#include "Peer.hpp"
Peer::Peer() : Peer(std::array<std::uint8_t, 2048>()) {}
#include "utils/serialize/basics.hpp"
Peer::Peer(const std::array<std::uint8_t, 2048>& cryptoRsaPublicKey) {
Peer::Peer() = default;
Peer::Peer(const drp::util::crypto::RsaPublicKey& cryptoRsaPublicKey) {
this->id = randomDistribution(randomGenerator);
this->channel = 0;
this->serverEnabled = false;
@ -15,12 +17,61 @@ Peer::Peer(const std::array<std::uint8_t, 2048>& cryptoRsaPublicKey) {
this->cryptoRsaPublicKey = cryptoRsaPublicKey;
}
Peer::Peer(
std::uint32_t id,
bool serverEnabled,
drp::task::TaskType status,
std::uint8_t channel,
const std::chrono::high_resolution_clock::duration& latencyAverage,
const drp::util::crypto::RsaPublicKey& cryptoRsaPublicKey
) {
this->id = id;
this->serverEnabled = serverEnabled;
this->status = status;
this->channel = channel;
this->latencyAverage = latencyAverage;
this->cryptoRsaPublicKey = cryptoRsaPublicKey;
}
std::random_device Peer::randomDevice = std::random_device();
std::mt19937 Peer::randomGenerator = std::mt19937(randomDevice());
std::vector<std::uint8_t> Peer::serialize() const {
std::vector<std::uint8_t> data;
std::uniform_int_distribution<std::uint32_t> Peer::randomDistribution = std::uniform_int_distribution<std::uint32_t>(
1,
// serialized the members
const auto serializedId = drp::util::serialize::serializeObject<std::uint32_t>(this->id);
const auto serializedServerEnabled = drp::util::serialize::serializeObject<std::uint8_t>(this->serverEnabled);
const auto serializedStatus = drp::util::serialize::serializeObject<std::uint8_t>(static_cast<std::uint8_t>(this->status));
const auto serializedChannel = drp::util::serialize::serializeObject<std::uint8_t>(this->channel);
const auto serializedLatencyAverage = drp::util::serialize::serializeObject(this->latencyAverage);
const auto serializedPublicKey = this->cryptoRsaPublicKey.serialize();
// store them in the data
data.insert(data.end(), serializedId.begin(), serializedId.end());
data.insert(data.end(), serializedServerEnabled.begin(), serializedServerEnabled.end());
data.insert(data.end(), serializedStatus.begin(), serializedStatus.end());
data.insert(data.end(), serializedChannel.begin(), serializedChannel.end());
data.insert(data.end(), serializedLatencyAverage.begin(), serializedLatencyAverage.end());
data.insert(data.end(), serializedPublicKey.begin(), serializedPublicKey.end());
return data;
}
Peer Peer::deserialize(std::vector<std::uint8_t>& data) {
// deserialize the members
const auto id = drp::util::serialize::deserializeObject<std::uint32_t>(data);
const auto serverEnabled = drp::util::serialize::deserializeObject<std::uint8_t>(data);
const auto status = static_cast<drp::task::TaskType>(drp::util::serialize::deserializeObject<std::uint8_t>(data));
const auto channel = drp::util::serialize::deserializeObject<std::uint8_t>(data);
const auto latencyAverage = drp::util::serialize::deserializeObject<std::chrono::high_resolution_clock::duration>(data);
const auto publicKey = drp::util::crypto::RsaPublicKey::deserialize(data);
return Peer(id, serverEnabled, status, channel, latencyAverage, publicKey);
}
std::mt19937 Peer::randomGenerator = std::mt19937(std::random_device{}());
std::uniform_int_distribution<std::uint32_t> Peer::randomDistribution = std::uniform_int_distribution(
std::numeric_limits<std::uint32_t>::min(),
std::numeric_limits<std::uint32_t>::max()
);

View file

@ -3,9 +3,9 @@
#include <chrono>
#include <cstdint>
#include <random>
#include <openssl/types.h>
#include "behavior/tasks/types.hpp"
#include "utils/crypto/rsa/RsaPublicKey.hpp"
/**
@ -15,21 +15,33 @@
class Peer {
public:
Peer();
explicit Peer(const std::array<std::uint8_t, 2048>& cryptoRsaPublicKey);
explicit Peer(const drp::util::crypto::RsaPublicKey& cryptoRsaPublicKey);
explicit Peer(
std::uint32_t id,
bool serverEnabled,
drp::task::TaskType status,
std::uint8_t channel,
const std::chrono::high_resolution_clock::duration& latencyAverage,
const drp::util::crypto::RsaPublicKey& cryptoRsaPublicKey
);
std::uint32_t id;
[[nodiscard]] std::vector<std::uint8_t> serialize() const;
static Peer deserialize(std::vector<std::uint8_t> &data);
bool serverEnabled;
drp::task::TaskType status;
std::uint8_t channel;
// identification
std::uint32_t id {}; // TODO(Faraphel): shall be removed in the future
// network
bool serverEnabled {};
drp::task::TaskType status {};
std::uint8_t channel {};
std::chrono::high_resolution_clock::duration latencyAverage {};
std::array<std::uint8_t, 2048> cryptoRsaPublicKey {};
// cryptography
drp::util::crypto::RsaPublicKey cryptoRsaPublicKey {};
private:
// random
static std::random_device randomDevice;
static std::mt19937 randomGenerator;
static std::uniform_int_distribution<std::uint32_t> randomDistribution;
};

View file

@ -3,8 +3,6 @@
#include <iostream>
#include <bits/unique_lock.h>
#include "packets/audio/AudioPacketData.hpp"
#include "packets/base/PacketData.hpp"
#include "utils/audio.hpp"
@ -34,12 +32,12 @@ AudioEvent::~AudioEvent() {
void AudioEvent::handle(
Context& context,
const packet::base::PacketContent& content,
std::vector<std::uint8_t>& data,
const sockaddr_storage& fromAddress,
const socklen_t fromAddressLength
) {
// get the audio data in the content
const auto audioData = packet::AudioPacketData::fromGeneric(content);
const auto audioData = packet::audio::AudioPacketData::deserialize(data);
// save it in the audio queue
this->audioQueue.push(audioData);
@ -123,7 +121,7 @@ void AudioEvent::loopPlay() {
const int error = Pa_WriteStream(
this->stream,
audioPacket.content.data(),
audioPacket.contentSize / Pa_GetSampleSize(this->streamSampleFormat) / this->streamChannels
audioPacket.content.size() / Pa_GetSampleSize(this->streamSampleFormat) / this->streamChannels
);
switch (error) {
// success

View file

@ -1,9 +1,8 @@
#pragma once
#include <condition_variable>
#include <portaudio.h>
#include <queue>
#include <bits/std_mutex.h>
#include <bits/unique_lock.h>
#include "AudioPacketsComparator.hpp"
#include "../base/BaseEvent.hpp"
@ -22,7 +21,7 @@ public:
void handle(
Context& context,
const packet::base::PacketContent& content,
std::vector<std::uint8_t>& data,
const sockaddr_storage& fromAddress,
socklen_t fromAddressLength
) override;
@ -34,7 +33,7 @@ private:
int streamChannels;
std::uint32_t streamSampleFormat;
double streamRate;
std::priority_queue<packet::AudioPacketData, std::vector<packet::AudioPacketData>, AudioPacketsComparator> audioQueue;
std::priority_queue<packet::audio::AudioPacketData, std::vector<packet::audio::AudioPacketData>, AudioPacketsComparator> audioQueue;
std::mutex audioMutex;
std::unique_lock<std::mutex> audioLock;

View file

@ -4,7 +4,7 @@
namespace drp::event {
bool AudioPacketsComparator::operator()(const packet::AudioPacketData& a, const packet::AudioPacketData& b) const {
bool AudioPacketsComparator::operator()(const packet::audio::AudioPacketData& a, const packet::audio::AudioPacketData& b) const {
return a.timePlay > b.timePlay;
}

View file

@ -7,7 +7,7 @@ namespace drp::event {
struct AudioPacketsComparator {
bool operator() (const packet::AudioPacketData& a, const packet::AudioPacketData& b) const;
bool operator() (const packet::audio::AudioPacketData& a, const packet::audio::AudioPacketData& b) const;
};

View file

@ -12,7 +12,7 @@ public:
virtual ~BaseEvent() = default;
virtual void handle(
Context& context,
const packet::base::PacketContent& content,
std::vector<std::uint8_t>& data,
const sockaddr_storage& fromAddress,
socklen_t fromAddressLength
) = 0;

View file

@ -11,14 +11,14 @@ namespace drp::event {
void InfoEvent::handle(
Context& context,
const packet::base::PacketContent& content,
std::vector<std::uint8_t>& data,
const sockaddr_storage& fromAddress,
const socklen_t fromAddressLength
) {
std::cout << "[Event - Info] Received peer information." << std::endl;
// get the peer information
const auto packetData = packet::info::InfoPacketData::fromGeneric(content);
const auto packetData = packet::info::InfoPacketData::deserialize(data);
const Peer packetPeer = packetData.peer;
// check if the peer address is already in the map

View file

@ -10,7 +10,7 @@ class InfoEvent : public BaseEvent {
public:
void handle(
Context& context,
const packet::base::PacketContent& content,
std::vector<std::uint8_t>& data,
const sockaddr_storage& fromAddress,
socklen_t fromAddressLength
) override;

View file

@ -8,7 +8,7 @@ namespace drp::event {
void PongEvent::handle(
Context& context,
const packet::base::PacketContent& content,
std::vector<std::uint8_t>& data,
const sockaddr_storage& fromAddress,
const socklen_t fromAddressLength
) {

View file

@ -10,7 +10,7 @@ class PongEvent : public BaseEvent {
public:
void handle(
Context& context,
const packet::base::PacketContent& content,
std::vector<std::uint8_t>& data,
const sockaddr_storage& fromAddress,
socklen_t fromAddressLength
) override;

View file

@ -15,29 +15,34 @@ namespace drp {
void event::SearchEvent::handle(
Context& context,
const packet::base::PacketContent& content,
std::vector<std::uint8_t>& data,
const sockaddr_storage& fromAddress,
const socklen_t fromAddressLength
) {
packet::base::Packet packet {};
packet::base::PacketContent packetContent {};
// create the packet header (available to read for everyone)
packet.channel = 0;
packet.securityMode = static_cast<std::uint8_t>(packet::base::SecurityMode::PLAIN);
// create the packet data containing our information
packet::info::InfoPacketData packetData {};
packetData.peer = context.me;
packet.setContent(packetData.toGeneric());
packetContent.eventType = EventType::INFO;
packetContent.data = packetData.serialize();
packet.setContent(context, packet::base::SecurityMode::PLAIN, packetContent);
// TODO(Faraphel): send back the timestamp too
const auto serializedPacket = packet.serialize();
// send back our information
if (sendto(
context.socket,
&packet,
sizeof(packet),
serializedPacket.data(),
serializedPacket.size(),
0,
reinterpret_cast<const sockaddr*>(&fromAddress),
fromAddressLength

View file

@ -1,4 +1,5 @@
#pragma once
#include "../base/BaseEvent.hpp"
@ -9,7 +10,7 @@ class SearchEvent : public BaseEvent {
public:
void handle(
Context& context,
const packet::base::PacketContent& content,
std::vector<std::uint8_t>& data,
const sockaddr_storage& fromAddress,
socklen_t fromAddressLength
) override;

View file

@ -6,10 +6,19 @@
namespace drp::task {
/**
* The base to define a task.
* A task is a state for the machine, defining how it shall behave.
*/
class BaseTask {
public:
virtual ~BaseTask() = default;
/**
* The handle of the task.
* Contain the behavior of that specific task.
* @param context the context to use.
*/
virtual void handle(Context& context) = 0;
};

View file

@ -19,6 +19,15 @@
namespace drp::task {
/*
ServerTask::use(Context& context) {
context.server = serverCandidate;
context.me.status = TaskType::SERVER;
return;
}
*/
ServerTask::ServerTask() {
this->channels = 0;
this->encoding = 0;
@ -62,19 +71,25 @@ void ServerTask::handle(Context& context) {
// prepare the packet structure
packet::base::Packet packet {};
packet::AudioPacketData audioPacket;
packet::base::PacketContent packetContent {};
std::size_t done;
// create a packet
// TODO(Faraphel): should not be broadcast plaintext
packet::audio::AudioPacketData audioPacket;
packet.channel = 0;
packet.securityMode = static_cast<std::uint8_t>(packet::base::SecurityMode::PLAIN);
// set the audio settings
audioPacket.channels = this->channels;
audioPacket.sampleFormat = util::encoding_mpg123_to_PulseAudio(this->encoding);
audioPacket.sampleRate = this->sampleRate;
std::vector<std::uint8_t> content(64992);
// read the file
if (mpg123_read(
this->mpgHandle,
&audioPacket.content,
std::size(audioPacket.content),
content.data(),
content.size(),
&done
) != MPG123_OK) {
std::cerr << "[Task - Server] Could not read audio data from file." << std::endl;
@ -82,27 +97,28 @@ void ServerTask::handle(Context& context) {
return;
}
// resize the content to fit
content.resize(done);
audioPacket.content = content;
// set the target time
// TODO(Faraphel): dynamically change this delay to be the lowest possible
audioPacket.timePlay =
std::chrono::high_resolution_clock::now() +
std::chrono::milliseconds(5000);
// set the size of the content
audioPacket.contentSize = done;
packetContent.eventType = event::EventType::AUDIO;
packetContent.data = audioPacket.serialize();
// set the audio settings
audioPacket.channels = this->channels;
audioPacket.sampleFormat = util::encoding_mpg123_to_PulseAudio(this->encoding);
audioPacket.sampleRate = this->sampleRate;
packet.setContent(context, packet::base::SecurityMode::PLAIN, packetContent);
packet.setContent(audioPacket.toGeneric());
const auto serializedPacket = packet.serialize();
// broadcast the audio data
if (sendto(
context.socket,
&packet,
sizeof(packet),
serializedPacket.data(),
serializedPacket.size(),
0,
context.broadcastAddressInfo->ai_addr,
context.broadcastAddressInfo->ai_addrlen

View file

@ -15,7 +15,13 @@ namespace drp::task {
class ServerTask : public BaseTask {
public:
explicit ServerTask();
~ServerTask();
~ServerTask() override;
/**
* Set this task as the current one.
* @param context the context to apply the state on.
*/
// void use(Context &context);
void handle(Context& context) override;

View file

@ -76,24 +76,29 @@ void UndefinedTask::handle(Context& context) {
// prepare a search message
packet::base::Packet packet {};
packet::base::PacketContent packetContent {};
packet::search::SearchPacketData packetData {};
// broadcast message
packet.channel = 0;
packet.securityMode = static_cast<std::uint8_t>(packet::base::SecurityMode::PLAIN);
// search message with the time of the message being sent
packetData.timestamp = std::chrono::high_resolution_clock::now();
packet.setContent(packetData.toGeneric());
packetContent.eventType = event::EventType::SEARCH;
packetContent.data = packetData.serialize();
packet.setContent(context, packet::base::SecurityMode::PLAIN, packetContent);
std::cout << "[Task - Undefined] Looking for new peers." << std::endl;
const auto serializedPacket = packet.serialize();
// send the search message
if (sendto(
context.socket,
&packet,
sizeof(packet),
serializedPacket.data(),
serializedPacket.size(),
0,
context.broadcastAddressInfo->ai_addr,
context.broadcastAddressInfo->ai_addrlen

View file

@ -0,0 +1,43 @@
#include "AudioPacketData.hpp"
#include "utils/serialize/basics.hpp"
namespace drp::packet::audio {
std::vector<std::uint8_t> AudioPacketData::serialize() const {
// serialize the members
const auto serializedTimePlay = util::serialize::serializeObject(this->timePlay);
const auto serializedChannels = util::serialize::serializeObject(this->channels);
const auto serializedSampleFormat = util::serialize::serializeObject(this->sampleFormat);
const auto serializedSampleRate = util::serialize::serializeObject(this->sampleRate);
const auto serializedContent = util::serialize::serializeVector(this->content);
// create a buffer to store our members
std::vector<std::uint8_t> data;
// store our members
data.insert(data.end(), serializedTimePlay.begin(), serializedTimePlay.end());
data.insert(data.end(), serializedChannels.begin(), serializedChannels.end());
data.insert(data.end(), serializedSampleFormat.begin(), serializedSampleFormat.end());
data.insert(data.end(), serializedSampleRate.begin(), serializedSampleRate.end());
data.insert(data.end(), serializedContent.begin(), serializedContent.end());
return data;
}
AudioPacketData AudioPacketData::deserialize(std::vector<std::uint8_t>& data) {
// deserialize the members
const auto packetTimePlay = util::serialize::deserializeObject<std::chrono::time_point<std::chrono::high_resolution_clock>>(data);
const auto packetChannels = util::serialize::deserializeObject<std::uint8_t>(data);
const auto packetSampleFormat = util::serialize::deserializeObject<std::uint32_t>(data);
const auto packetSampleRate = util::serialize::deserializeObject<std::uint32_t>(data);
const auto packetContent = util::serialize::deserializeVector<std::uint8_t>(data);
return {packetTimePlay, packetChannels, packetSampleFormat, packetSampleRate, packetContent};
}
}

View file

@ -3,19 +3,19 @@
#include <chrono>
#include <cstdint>
#include "behavior/events/types.hpp"
#include "../base/PacketData.hpp"
namespace drp::packet {
namespace drp::packet::audio {
/**
* Represent the content of an audio packet.
* Contains a chunk of audio and its metadata to play it.
*/
class AudioPacketData : public base::PacketData<event::EventType::AUDIO, AudioPacketData> {
class AudioPacketData {
public:
[[nodiscard]] std::vector<std::uint8_t> serialize() const;
static AudioPacketData deserialize(std::vector<std::uint8_t>& data);
// scheduling
// TODO(Faraphel): use a more "fixed" size format ?
std::chrono::time_point<std::chrono::high_resolution_clock> timePlay;
@ -24,8 +24,7 @@ public:
std::uint32_t sampleFormat {};
std::uint32_t sampleRate {};
// content
std::uint16_t contentSize {};
std::array<std::uint8_t, 65280> content {};
std::vector<std::uint8_t> content {};
};

View file

@ -3,64 +3,63 @@
#include <stdexcept>
#include "SecurityMode.hpp"
#include "utils/serialize/basics.hpp"
namespace drp::packet::base {
/*
GenericPacketContent decryptPacketContentAes(const GenericPacket& packet) {
GenericPacketContent decryptedPacketContent {};
Packet::Packet() = default;
const auto& [key, iv] = keysAes[serverAddress];
EVP_CIPHER_CTX *ctx = EVP_CIPHER_CTX_new();
if (EVP_DecryptInit_ex(
ctx,
EVP_aes_256_cbc(),
nullptr,
key,
iv
) != 1)
throw std::runtime_error("[Client] Could not initialize the EVP_CIPHER_CTX.");
int packetContentLength;
if (EVP_DecryptUpdate(
ctx,
reinterpret_cast<std::uint8_t*>(&decryptedPacketContent),
&packetContentLength,
reinterpret_cast<const std::uint8_t*>(&packet.encryptedContent),
sizeof(packet)
) != 1)
throw std::runtime_error("[Client] Could not encrypt the plaintext.");
if (EVP_DecryptFinal_ex(
ctx,
reinterpret_cast<std::uint8_t*>(&decryptedPacketContent + packetContentLength),
&packetContentLength
) != 1)
throw std::runtime_error("[Client] Could not decrypt the final plaintext.");
EVP_CIPHER_CTX_free(ctx);
return decryptedPacketContent;
Packet::Packet(const std::uint8_t channel, const SecurityMode securityMode, const std::vector<uint8_t>& content) {
this->channel = channel;
this->securityMode = securityMode;
this->content = content;
}
*/
PacketContent Packet::getContent() const {
// TODO(Faraphel): implement RSA and AES
// additional "context" argument to hold cryptographic keys ?
PacketContent Packet::getContent(const Context& context) const {
std::vector<std::uint8_t> content;
switch (static_cast<SecurityMode>(this->securityMode)) {
switch (this->securityMode) {
case SecurityMode::PLAIN:
return this->_content;
// copy the content
content = this->content;
break;
case SecurityMode::AES:
// return decryptPacketContentAes(packet);
// decrypt the content
content = context.cryptoAesKey.decrypt(this->content);
break;
case SecurityMode::RSA:
throw std::runtime_error("Not implemented.");
default:
throw std::runtime_error("Unsupported security mode.");
}
// deserialize the content
return PacketContent::deserialize(content);
}
void Packet::setContent(const Context& context, const SecurityMode securityMode, const PacketContent& packetContent) {
this->securityMode = securityMode;
const std::vector<std::uint8_t> content = packetContent.serialize();
switch (this->securityMode) {
case SecurityMode::PLAIN:
// directly save the serialized content
this->content = content;
break;
case SecurityMode::AES:
// encrypt it with the defined AES key.
this->content = context.cryptoAesKey.encrypt(content);
break;
case SecurityMode::RSA:
throw std::runtime_error("Not implemented.");
@ -70,4 +69,31 @@ PacketContent Packet::getContent() const {
}
std::vector<std::uint8_t> Packet::serialize() const {
// serialize the members
const auto serializedChannel = util::serialize::serializeObject(this->channel);
const auto serializedSecurityMode = util::serialize::serializeObject(static_cast<std::uint8_t>(this->securityMode));
const auto serializedContent = util::serialize::serializeVector(this->content);
// create a buffer to store our members
std::vector<std::uint8_t> data;
// store our members
data.insert(data.end(), serializedChannel.begin(), serializedChannel.end());
data.insert(data.end(), serializedSecurityMode.begin(), serializedSecurityMode.end());
data.insert(data.end(), serializedContent.begin(), serializedContent.end());
return data;
}
Packet Packet::deserialize(std::vector<std::uint8_t>& data) {
// deserialize the members
const auto packetChannel = util::serialize::deserializeObject<std::uint8_t>(data);
const auto packetSecurityMode = static_cast<SecurityMode>(util::serialize::deserializeObject<std::uint8_t>(data));
const auto packetContent = util::serialize::deserializeVector<std::uint8_t>(data);
return Packet(packetChannel, packetSecurityMode, packetContent);
}
}

View file

@ -1,9 +1,10 @@
#pragma once
#include <array>
#include <cstdint>
#include "Context.hpp"
#include "PacketContent.hpp"
#include "SecurityMode.hpp"
namespace drp::packet::base {
@ -16,13 +17,22 @@ namespace drp::packet::base {
* @param securityMode the type of security used in the packet.
* @param _content the content of the packet. It is encrypted accordingly to the securityMode.
*/
struct Packet {
std::uint8_t channel;
std::uint8_t securityMode;
PacketContent _content;
class Packet {
public:
Packet();
explicit Packet(std::uint8_t channel, SecurityMode securityMode, const std::vector<uint8_t>& content);
[[nodiscard]] PacketContent getContent() const;
void setContent(const PacketContent& content);
[[nodiscard]] PacketContent getContent(const Context& context) const;
void setContent(const Context& context, SecurityMode securityMode, const PacketContent& packetContent);
[[nodiscard]] std::vector<std::uint8_t> serialize() const;
static Packet deserialize(std::vector<std::uint8_t>& data);
std::uint8_t channel {};
private:
SecurityMode securityMode {};
std::vector<std::uint8_t> content;
};

View file

@ -1,30 +1,46 @@
#include "PacketContent.hpp"
#include <cstring>
#include <stdexcept>
#include "Packet.hpp"
#include "SecurityMode.hpp"
#include "behavior/events/types.hpp"
#include "utils/serialize/basics.hpp"
namespace drp::packet::base {
void Packet::setContent(const PacketContent &content) {
// TODO(Faraphel): implement RSA and AES
switch (static_cast<SecurityMode>(this->securityMode)) {
case SecurityMode::PLAIN:
this->_content = content;
return;
PacketContent::PacketContent() = default;
case SecurityMode::AES:
throw std::runtime_error("Not implemented.");
PacketContent::PacketContent(const event::EventType eventType, const std::vector<std::uint8_t>& data) {
this->eventType = eventType;
this->data = data;
}
case SecurityMode::RSA:
throw std::runtime_error("Not implemented.");
default:
throw std::runtime_error("Unsupported security mode.");
}
std::vector<std::uint8_t> PacketContent::serialize() const {
// serialize the members
const auto serializedEventType = util::serialize::serializeObject(static_cast<std::uint8_t>(this->eventType));
const auto serializedData = util::serialize::serializeVector(this->data);
// create a buffer to store our members
std::vector<std::uint8_t> data;
// store our members
data.insert(data.end(), serializedEventType.begin(), serializedEventType.end());
data.insert(data.end(), serializedData.begin(), serializedData.end());
return data;
}
PacketContent PacketContent::deserialize(std::vector<std::uint8_t>& data) {
// deserialize the members
const auto contentEventType = static_cast<event::EventType>(util::serialize::deserializeObject<std::uint8_t>(data));
const auto contentData = util::serialize::deserializeVector<std::uint8_t>(data);
return {contentEventType, contentData};
}

View file

@ -1,16 +1,17 @@
#pragma once
#include <array>
#include <cstdint>
#include <limits>
#include <vector>
#include "behavior/events/types.hpp"
namespace drp::packet::base {
// the maximum data length
// a packet can't be larger than 65565 (uint16 max)
// reserve some space for metadata and settings
constexpr std::uint16_t dataLength = 65504;
// The maximum length of a packet. Cannot be larger than 65565 (uint16 max).
constexpr std::uint16_t maxPacketLength = std::numeric_limits<std::uint16_t>::max();
/**
@ -20,8 +21,14 @@ constexpr std::uint16_t dataLength = 65504;
*/
class PacketContent {
public:
std::uint8_t eventType;
std::array<std::uint8_t, dataLength> data;
PacketContent();
PacketContent(event::EventType eventType, const std::vector<std::uint8_t>& data);
[[nodiscard]] std::vector<std::uint8_t> serialize() const;
static PacketContent deserialize(std::vector<std::uint8_t>& data);
event::EventType eventType {};
std::vector<std::uint8_t> data;
};

View file

@ -1,48 +0,0 @@
#pragma once
#include <cstring>
#include "PacketContent.hpp"
#include "behavior/events/types.hpp"
namespace drp::packet::base {
/**
* Represent the actual data contained inside a packet, with no header.
* Can be used to implement and communicate anything.
* @tparam eventType the event type associated with this type of packet data.
* Allow the receiver to redirect this packet to the correct handler.
*/
template<event::EventType eventType, class packetClass>
class PacketData {
public:
/**
* Convert this packet data to a generic packet content.
* @return a generic packet content.
*/
[[nodiscard]] PacketContent toGeneric() const {
// create an empty generic packet content
PacketContent content {};
// set its content
content.eventType = static_cast<std::uint8_t>(eventType);
std::memcpy(content.data.data(), this, content.data.size());
return content;
}
/**
* Get the data from a generic packet data.
* @param content a generic packet content.
* @return the actual packet data.
*/
static packetClass fromGeneric(const PacketContent& content) {
packetClass data;
std::memcpy(&data, content.data.data(), sizeof(packetClass));
return data;
}
};
}

View file

@ -0,0 +1,23 @@
#include "InfoPacketData.hpp"
namespace drp::packet::info {
std::vector<std::uint8_t> InfoPacketData::serialize() const {
// serialize the members
const auto serializedPeer = this->peer.serialize();
return serializedPeer;
}
InfoPacketData InfoPacketData::deserialize(std::vector<std::uint8_t>& data) {
// deserialize the members
const auto packetPeer = Peer::deserialize(data);
return {packetPeer};
}
}

View file

@ -1,8 +1,6 @@
#pragma once
#include "RemotePeer.hpp"
#include "behavior/events/types.hpp"
#include "../base/PacketData.hpp"
namespace drp::packet::info {
@ -12,8 +10,11 @@ namespace drp::packet::info {
* Represent the content of an info packet.
* Contains information about the peer sending it.
*/
class InfoPacketData : public base::PacketData<event::EventType::INFO, InfoPacketData> {
class InfoPacketData {
public:
[[nodiscard]] std::vector<std::uint8_t> serialize() const;
static InfoPacketData deserialize(std::vector<std::uint8_t>& data);
Peer peer;
};

View file

@ -0,0 +1,25 @@
#include "SearchPacketData.hpp"
#include "utils/serialize/basics.hpp"
namespace drp::packet::search {
std::vector<std::uint8_t> SearchPacketData::serialize() const {
// serialize the members
const auto serializedTimestamp = util::serialize::serializeObject(this->timestamp);
return serializedTimestamp;
}
SearchPacketData SearchPacketData::deserialize(std::vector<std::uint8_t>& data) {
// deserialize the members
const auto packetTimestamp = util::serialize::deserializeObject<std::chrono::time_point<std::chrono::high_resolution_clock>>(data);
return SearchPacketData(packetTimestamp);
}
}

View file

@ -2,8 +2,6 @@
#include <chrono>
#include "../base/PacketData.hpp"
namespace drp::packet::search {
@ -12,8 +10,11 @@ namespace drp::packet::search {
* Represent a discovery request.
* Sent by someone to get information about other available machine in the network.
*/
class SearchPacketData : public base::PacketData<event::EventType::SEARCH, SearchPacketData> {
class SearchPacketData {
public:
[[nodiscard]] std::vector<std::uint8_t> serialize() const;
static SearchPacketData deserialize(std::vector<std::uint8_t>& data);
std::chrono::time_point<std::chrono::system_clock> timestamp; /// timestamp when the search request was sent
};

119
source/test.cpp Normal file
View file

@ -0,0 +1,119 @@
#include <iostream>
#include <vector>
#include "Peer.hpp"
#include "utils/crypto/aes/AesKey.hpp"
#include "utils/crypto/rsa/RsaKeyPair.hpp"
#include "utils/serialize/basics.hpp"
int mainAes() {
const auto aesKey = drp::util::crypto::AesKey256();
// plain
std::string text = "hello world!";
const std::vector<std::uint8_t> plainData(text.begin(), text.end());
std::cout << "plain: ";
for (const auto& byte : plainData)
std::cout << std::to_string(byte) << "-";
std::cout << std::endl;
// encrypted
const auto encryptedData = aesKey.encrypt(plainData);
std::cout << "encrypted: ";
for (const auto& byte : encryptedData)
std::cout << std::to_string(byte) << "-";
std::cout << std::endl;
// decrypted
const auto decryptedData = aesKey.decrypt(encryptedData);
std::cout << "decrypted: ";
for (const auto& byte : decryptedData)
std::cout << std::to_string(byte) << "-";
std::cout << std::endl;
return 0;
}
int mainRsa() {
const auto rsaKey = drp::util::crypto::RsaKeyPair(2048);
const auto rsaPrivateKey = rsaKey.getPrivateKey();
const auto rsaPublicKey = rsaKey.getPublicKey();
// plain
std::string text = "hello world!";
const std::vector<std::uint8_t> plainData(text.begin(), text.end());
std::cout << "plain: ";
for (const auto& byte : plainData)
std::cout << std::to_string(byte) << "-";
std::cout << std::endl;
// encrypted
const auto encryptedData = rsaPublicKey.encrypt(plainData);
std::cout << "encrypted: ";
for (const auto& byte : encryptedData)
std::cout << std::to_string(byte) << "-";
std::cout << std::endl;
// decrypted
const auto decryptedData = rsaPrivateKey.decrypt(encryptedData);
std::cout << "decrypted: ";
for (const auto& byte : decryptedData)
std::cout << std::to_string(byte) << "-";
std::cout << std::endl;
return 0;
}
int mainSerialize() {
std::string text = "hello world!";
const std::vector<char> plainData(text.begin(), text.end());
std::cout << "plain: ";
for (const auto& byte : plainData)
std::cout << std::to_string(byte) << "-";
std::cout << std::endl;
std::vector<std::uint8_t> serializedData = drp::util::serialize::serializeVector(plainData);
std::cout << "serialized: ";
for (const auto& byte : serializedData)
std::cout << std::to_string(byte) << "-";
std::cout << std::endl;
std::vector<char> deserializedData = drp::util::serialize::deserializeVector<char>(serializedData);
std::cout << "deserialized: ";
for (const auto& byte : deserializedData)
std::cout << std::to_string(byte) << "-";
std::cout << std::endl;
return 0;
}
int mainRsaSerialize() {
Peer peer;
auto serializedPeer = peer.serialize();
std::cout << "serialized: ";
for (const auto& byte : serializedPeer)
std::cout << std::to_string(byte) << "-";
std::cout << std::endl;
const auto deserializedPeer = Peer::deserialize(serializedPeer);
return 0;
}
int main_test() {
// mainAes();
// mainRsa();
// mainSerialize();
// mainRsaSerialize();
return 0;
}

View file

@ -1 +0,0 @@
#include "crypto.hpp"

View file

@ -1,70 +0,0 @@
#pragma once
#include <vector>
#include <cstdint>
#include <memory>
#include <stdexcept>
#include <openssl/evp.h>
#include <openssl/pem.h>
template<std::size_t size>
std::pair<std::array<std::uint8_t, size>, std::array<std::uint8_t, size>> newRsaKeys() {
// create the context
const auto context = std::unique_ptr<EVP_PKEY_CTX, decltype(&EVP_PKEY_CTX_free)>(
EVP_PKEY_CTX_new_id(EVP_PKEY_RSA, nullptr),
EVP_PKEY_CTX_free
);
if (context == nullptr)
throw std::runtime_error("Could not create an EVP context.");
if (EVP_PKEY_keygen_init(context.get()) <= 0)
throw std::runtime_error("Could not initialize the EVP context.");
// configure the context
if (EVP_PKEY_CTX_set_rsa_keygen_bits(context.get(), size) <= 0)
throw std::runtime_error("Error setting RSA key size.");
// create the private key
EVP_PKEY* rawKeyPair = nullptr;
if (EVP_PKEY_keygen(context.get(), &rawKeyPair) <= 0)
throw std::runtime_error("Could not generate RSA private key.");
const std::unique_ptr<EVP_PKEY, decltype(&EVP_PKEY_free)> keyPair(
rawKeyPair,
EVP_PKEY_free
);
// extract the private and public key
const std::shared_ptr<BIO> privateBio(
BIO_new(BIO_s_mem()),
BIO_free
);
if (!PEM_write_bio_PrivateKey(
privateBio.get(),
keyPair.get(),
nullptr,
nullptr,
0,
nullptr,
nullptr
))
throw std::runtime_error("Could not generate RSA private key.");
std::array<std::uint8_t, size> privateKey;
BIO_read(privateBio.get(), privateKey.data(), BIO_pending(privateBio.get()));
const std::shared_ptr<BIO> publicBio(
BIO_new(BIO_s_mem()),
BIO_free
);
if (!PEM_write_bio_PUBKEY(publicBio.get(), keyPair.get()))
throw std::runtime_error("Could not generate RSA public key.");
std::array<std::uint8_t, size> publicKey;
BIO_read(publicBio.get(), publicKey.data(), BIO_pending(publicBio.get()));
return {privateKey, publicKey};
}

View file

@ -0,0 +1 @@
#include "AesKey.hpp"

View file

@ -0,0 +1,166 @@
#pragma once
#include <array>
#include <cstdint>
#include <functional>
#include <memory>
#include <random>
#include <openssl/evp.h>
namespace drp::util::crypto {
/**
* Represent an AES key.
* Allow for encrypting and decrypting data.
* @tparam keySize the size of the key (in bytes)
* @tparam ivSize the size of the initialisation vector (in bytes)
*/
template<std::size_t keySize, std::size_t ivSize, const EVP_CIPHER*(cipherFunction)()>
class AesKey {
public:
/**
* Create a random AES key.
*/
AesKey() {
// generate a random key
for (auto& byte : this->_data)
byte = randomDistribution(randomGenerator);
}
explicit AesKey(const std::array<std::uint8_t, keySize>& data) {
this->_data = data;
}
/**
* Encrypt data with this key.
* @param plainData the data to encrypt
* @return the encrypted data. It will always be longer than the original data.
*/
[[nodiscard]] std::vector<std::uint8_t> encrypt(const std::vector<std::uint8_t>& plainData) const {
// create an initialization vector
std::array<std::uint8_t, ivSize> iv {};
for (auto& byte : iv)
byte = randomDistribution(randomGenerator);
// create the cipher context
const auto context = std::unique_ptr<EVP_CIPHER_CTX, decltype(&EVP_CIPHER_CTX_free)>(
EVP_CIPHER_CTX_new(),
EVP_CIPHER_CTX_free
);
if (context == nullptr)
throw std::runtime_error("Error creating EVP_CIPHER_CTX");
// initialize the encryptor
if (EVP_EncryptInit_ex(
context.get(),
cipherFunction(),
nullptr,
this->_data.data(),
iv.data()
) != 1)
throw std::runtime_error("Error initializing encryption");
std::vector<std::uint8_t> encryptedData(ivSize + plainData.size() + EVP_CIPHER_block_size(EVP_aes_256_cbc()));
std::copy(iv.begin(), iv.end(), encryptedData.begin());
int length;
// encrypt the data
if (EVP_EncryptUpdate(
context.get(),
encryptedData.data() + ivSize,
&length,
plainData.data(),
static_cast<int>(plainData.size())
) != 1)
throw std::runtime_error("Error encrypting data");
int encryptedDataLength = length;
// finalize the encryption
if (EVP_EncryptFinal_ex(
context.get(),
encryptedData.data() + ivSize + encryptedDataLength,
&length
) != 1)
throw std::runtime_error("Error finalizing encryption");
encryptedDataLength += length;
encryptedData.resize(ivSize + encryptedDataLength);
return encryptedData;
}
/**
* Decrypt data with this key.
* @param rawEncryptedData the encrypted data to decrypt.
* @return the decrypted data.
*/
[[nodiscard]] std::vector<std::uint8_t> decrypt(const std::vector<std::uint8_t>& rawEncryptedData) const {
// create a cipher context
const auto context = std::unique_ptr<EVP_CIPHER_CTX, decltype(&EVP_CIPHER_CTX_free)>(
EVP_CIPHER_CTX_new(),
EVP_CIPHER_CTX_free
);
std::array<std::uint8_t, ivSize> iv;
std::copy(rawEncryptedData.begin(), rawEncryptedData.begin() + ivSize, iv.data());
const std::vector encryptedData(rawEncryptedData.begin() + ivSize, rawEncryptedData.end());
// initialize the decryptor
if (EVP_DecryptInit_ex(
context.get(),
cipherFunction(),
nullptr,
this->_data.data(),
iv.data()
) != 1)
throw std::runtime_error("Error initializing decryptor");
std::vector<std::uint8_t> plainData(encryptedData.size());
int length;
// decrypt the data
if (EVP_DecryptUpdate(
context.get(),
plainData.data(),
&length,
encryptedData.data(),
static_cast<int>(encryptedData.size())
) != 1)
throw std::runtime_error("Error decrypting data");
// finalize the decryption
if (EVP_DecryptFinal_ex(
context.get(),
plainData.data(),
&length
) != 1)
throw std::runtime_error("Error finalizing decryptor");
plainData.resize(length);
return plainData;
}
private:
static std::mt19937 randomGenerator;
static std::uniform_int_distribution<std::uint8_t> randomDistribution;
std::array<std::uint8_t, keySize> _data;
};
template<std::size_t keySize, std::size_t ivSize, const EVP_CIPHER*(cipherFunction)()>
std::mt19937 AesKey<keySize, ivSize, cipherFunction>::randomGenerator = std::mt19937(std::random_device{}());
template<std::size_t keySize, std::size_t ivSize, const EVP_CIPHER*(cipherFunction)()>
std::uniform_int_distribution<std::uint8_t> AesKey<keySize, ivSize, cipherFunction>::randomDistribution = std::uniform_int_distribution(
std::numeric_limits<std::uint8_t>::min(),
std::numeric_limits<std::uint8_t>::max()
);
using AesKey256 = AesKey<256/8, 16, EVP_aes_256_cbc>;
}

View file

@ -0,0 +1,98 @@
#include "RsaKeyPair.hpp"
#include <openssl/x509.h>
namespace drp::util::crypto {
RsaKeyPair::RsaKeyPair(const std::size_t size, const int padMode) {
// create the context
const std::unique_ptr<EVP_PKEY_CTX, decltype(&EVP_PKEY_CTX_free)> context(
EVP_PKEY_CTX_new_id(EVP_PKEY_RSA, nullptr),
EVP_PKEY_CTX_free
);
if (context == nullptr)
throw std::runtime_error("Could not create an EVP context.");
if (EVP_PKEY_keygen_init(context.get()) <= 0)
throw std::runtime_error("Could not initialize the EVP context.");
// configure the context
if (EVP_PKEY_CTX_set_rsa_keygen_bits(context.get(), static_cast<int>(size)) <= 0)
throw std::runtime_error("Error setting RSA key size.");
// create the key pair
EVP_PKEY* rawKeyPair = nullptr;
if (EVP_PKEY_keygen(context.get(), &rawKeyPair) <= 0)
throw std::runtime_error("Could not generate RSA key pair.");
if (rawKeyPair == nullptr)
throw std::runtime_error("Could not generate RSA key pair.");
std::unique_ptr<EVP_PKEY, decltype(&EVP_PKEY_free)> keyPair(
rawKeyPair,
EVP_PKEY_free
);
// extract the private key
const std::unique_ptr<BIO, decltype(&BIO_free)> privateBio(
BIO_new(BIO_s_mem()),
BIO_free
);
if (privateBio == nullptr)
throw std::runtime_error("Could not create RSA private key.");
if (!i2d_PrivateKey_bio(
privateBio.get(),
keyPair.get()
))
throw std::runtime_error("Could not generate RSA private key.");
std::vector<std::uint8_t> privateKeyData(BIO_pending(privateBio.get()));
if (BIO_read(
privateBio.get(),
privateKeyData.data(),
static_cast<int>(privateKeyData.size())
) <= 0)
throw std::runtime_error("Could not read RSA private key.");
this->privateKey = RsaPrivateKey(privateKeyData, padMode);
// extract the public key
const std::unique_ptr<BIO, decltype(&BIO_free)> publicBio(
BIO_new(BIO_s_mem()),
BIO_free
);
if (publicBio == nullptr)
throw std::runtime_error("Could not create RSA public key.");
if (!i2d_PUBKEY_bio(
publicBio.get(),
keyPair.get()
))
throw std::runtime_error("Could not generate RSA public key.");
std::vector<std::uint8_t> publicKeyData(BIO_pending(publicBio.get()));
if (BIO_read(
publicBio.get(),
publicKeyData.data(),
static_cast<int>(publicKeyData.size())
) <= 0)
throw std::runtime_error("Could not read RSA public key.");
this->publicKey = RsaPublicKey(publicKeyData, padMode);
}
RsaPublicKey RsaKeyPair::getPublicKey() const {
return this->publicKey;
}
RsaPrivateKey RsaKeyPair::getPrivateKey() const {
return this->privateKey;
}
}

View file

@ -0,0 +1,44 @@
#pragma once
#include <iostream>
#include <memory>
#include <openssl/evp.h>
#include <openssl/rsa.h>
#include <openssl/types.h>
#include "RsaPrivateKey.hpp"
#include "RsaPublicKey.hpp"
namespace drp::util::crypto {
/**
* Represent a pair of RSA key.
*/
class RsaKeyPair {
public:
/**
* Generate a pair of public and private RSA keys.
*/
explicit RsaKeyPair(std::size_t size, int padMode = RSA_PKCS1_OAEP_PADDING);
/**
* Get the public key.
* @return the public key.
*/
[[nodiscard]] RsaPublicKey getPublicKey() const;
/**
* Get the private key.
* @return the private key.
*/
[[nodiscard]] RsaPrivateKey getPrivateKey() const;
private:
RsaPrivateKey privateKey;
RsaPublicKey publicKey;
};
}

View file

@ -0,0 +1,121 @@
#include "RsaPrivateKey.hpp"
#include <openssl/rsa.h>
#include <openssl/evp.h>
#include <openssl/x509.h>
#include "utils/serialize/basics.hpp"
namespace drp::util::crypto {
RsaPrivateKey::RsaPrivateKey() = default;
RsaPrivateKey::RsaPrivateKey(const std::vector<std::uint8_t>& data, const int padMode) {
this->_data = data;
this->padMode = padMode;
}
std::vector<std::uint8_t> RsaPrivateKey::decrypt(const std::vector<std::uint8_t>& encryptedData) const {
const auto key = this->getOpenSslKey();
// initialize the encryption context
const std::unique_ptr<EVP_PKEY_CTX, decltype(&EVP_PKEY_CTX_free)> context(
EVP_PKEY_CTX_new(key.get(), nullptr),
EVP_PKEY_CTX_free
);
if (context == nullptr)
throw std::runtime_error("Could not create EVP_PKEY_CTX.");
// initialize the encryption operation
if (EVP_PKEY_decrypt_init(context.get()) <= 0)
throw std::runtime_error("Could not initialize encryption.");
// set the padding
if (EVP_PKEY_CTX_set_rsa_padding(context.get(), RSA_PKCS1_OAEP_PADDING) <= 0)
throw std::runtime_error("Could not set RSA padding.");
// get the size of the output buffer
std::size_t decryptedDataLength;
if (EVP_PKEY_decrypt(
context.get(),
nullptr,
&decryptedDataLength,
encryptedData.data(),
encryptedData.size()
) <= 0)
throw std::runtime_error("Could not determine output length.");
std::vector<std::uint8_t> decryptedData(decryptedDataLength);
// encrypt the data
if (EVP_PKEY_decrypt(
context.get(),
decryptedData.data(),
&decryptedDataLength,
encryptedData.data(),
encryptedData.size()
) <= 0)
throw std::runtime_error("Could not decrypt data.");
decryptedData.resize(decryptedDataLength);
return decryptedData;
}
std::vector<std::uint8_t> RsaPrivateKey::serialize() const {
// serialize the members
const auto serializedData = serialize::serializeVector(this->_data);
const auto serializedPadMode = serialize::serializeObject(static_cast<std::uint8_t>(this->padMode));
// create a buffer to store our members
std::vector<std::uint8_t> data;
// store our members
data.insert(data.end(), serializedData.begin(), serializedData.end());
data.insert(data.end(), serializedPadMode.begin(), serializedPadMode.end());
return data;
}
RsaPrivateKey RsaPrivateKey::deserialize(std::vector<std::uint8_t>& data) {
// deserialize the members
const auto keyData = serialize::deserializeVector<std::uint8_t>(data);
const auto keyPadding = static_cast<int>(serialize::deserializeObject<std::uint8_t>(data));
return RsaPrivateKey(keyData, keyPadding);
}
[[nodiscard]] std::unique_ptr<EVP_PKEY, decltype(&EVP_PKEY_free)> RsaPrivateKey::getOpenSslKey() const {
// get the bio from the private key data
const std::unique_ptr<BIO, decltype(&BIO_free)> bio(
BIO_new_mem_buf(
this->_data.data(),
static_cast<int>(this->_data.size())
),
BIO_free
);
if (bio == nullptr)
throw std::runtime_error("Could not create BIO for private key.");
// get the key from the bio
EVP_PKEY* rawKey = nullptr;
if (!d2i_PrivateKey_bio(
bio.get(),
&rawKey
))
throw std::runtime_error("Could not deserialize RSA private key.");
return {
rawKey,
EVP_PKEY_free
};
}
}

View file

@ -0,0 +1,32 @@
#pragma once
#include <memory>
#include <vector>
#include <openssl/evp.h>
namespace drp::util::crypto {
/**
* Represent an RSA private key.
*/
class RsaPrivateKey {
public:
RsaPrivateKey();
explicit RsaPrivateKey(const std::vector<std::uint8_t>& data, int padMode);
[[nodiscard]] std::vector<std::uint8_t> decrypt(const std::vector<std::uint8_t>& encryptedData) const;
[[nodiscard]] std::vector<std::uint8_t> serialize() const;
static RsaPrivateKey deserialize(std::vector<std::uint8_t>& data);
private:
[[nodiscard]] std::unique_ptr<EVP_PKEY, decltype(&EVP_PKEY_free)> getOpenSslKey() const;
int padMode {};
std::vector<std::uint8_t> _data;
};
}

View file

@ -0,0 +1,121 @@
#include "RsaPublicKey.hpp"
#include <ranges>
#include <openssl/rsa.h>
#include <openssl/x509.h>
#include "utils/serialize/basics.hpp"
namespace drp::util::crypto {
RsaPublicKey::RsaPublicKey() = default;
RsaPublicKey::RsaPublicKey(const std::vector<uint8_t>& data, const int padMode) {
this->_data = data;
this->padMode = padMode;
}
[[nodiscard]] std::vector<uint8_t> RsaPublicKey::encrypt(const std::vector<std::uint8_t>& plainData) const {
const auto key = this->getOpenSslKey();
// initialize the encryption context
const std::unique_ptr<EVP_PKEY_CTX, decltype(&EVP_PKEY_CTX_free)> context(
EVP_PKEY_CTX_new(key.get(), nullptr),
EVP_PKEY_CTX_free
);
if (context == nullptr)
throw std::runtime_error("Could not create EVP_PKEY_CTX.");
// initialize the encryption operation
if (EVP_PKEY_encrypt_init(context.get()) <= 0)
throw std::runtime_error("Could not initialize encryption.");
// set the padding
if (EVP_PKEY_CTX_set_rsa_padding(context.get(), this->padMode) <= 0)
throw std::runtime_error("Could not set RSA padding.");
// get the size of the output buffer
std::size_t encryptedDataLength;
if (EVP_PKEY_encrypt(
context.get(),
nullptr,
&encryptedDataLength,
plainData.data(),
plainData.size()
) <= 0)
throw std::runtime_error("Could not determine output length.");
std::vector<std::uint8_t> encryptedData(encryptedDataLength);
// encrypt the data
if (EVP_PKEY_encrypt(
context.get(),
encryptedData.data(),
&encryptedDataLength,
plainData.data(),
plainData.size()
) <= 0)
throw std::runtime_error("Could not encrypt data.");
encryptedData.resize(encryptedDataLength);
return encryptedData;
}
std::vector<std::uint8_t> RsaPublicKey::serialize() const {
// serialize the members
const auto serializedData = serialize::serializeVector(this->_data);
const auto serializedPadMode = serialize::serializeObject<std::uint8_t>(static_cast<std::uint8_t>(this->padMode));
// create a buffer to store our members
std::vector<std::uint8_t> data;
// store our members
data.insert(data.end(), serializedData.begin(), serializedData.end());
data.insert(data.end(), serializedPadMode.begin(), serializedPadMode.end());
return data;
}
RsaPublicKey RsaPublicKey::deserialize(std::vector<std::uint8_t>& data) {
// deserialize the members
const auto keyData = serialize::deserializeVector<std::uint8_t>(data);
const auto keyPadding = static_cast<int>(serialize::deserializeObject<std::uint8_t>(data));
return RsaPublicKey(keyData, keyPadding);
}
[[nodiscard]] std::unique_ptr<EVP_PKEY, decltype(&EVP_PKEY_free)> RsaPublicKey::getOpenSslKey() const {
// get the bio from the public key data
const std::unique_ptr<BIO, decltype(&BIO_free)> bio(
BIO_new_mem_buf(
this->_data.data(),
static_cast<int>(this->_data.size())
),
BIO_free
);
if (bio == nullptr)
throw std::runtime_error("Could not create BIO for public key.");
// get the key from the bio
EVP_PKEY* rawKey = nullptr;
if (!d2i_PUBKEY_bio(
bio.get(),
&rawKey
))
throw std::runtime_error("Could not deserialize RSA public key.");
return {
rawKey,
EVP_PKEY_free
};
}
}

View file

@ -0,0 +1,37 @@
#pragma once
#include <vector>
#include "RsaPrivateKey.hpp"
namespace drp::util::crypto {
/**
* Represent an RSA public key.
*/
class RsaPublicKey {
public:
RsaPublicKey();
explicit RsaPublicKey(const std::vector<uint8_t>& data, int padMode);
/**
* Encrypt data with the public key. Can only be decrypted with the corresponding private key.
* @param plainData the plain data.
* @return the encrypted data.
*/
[[nodiscard]] std::vector<uint8_t> encrypt(const std::vector<std::uint8_t>& plainData) const;
[[nodiscard]] std::vector<std::uint8_t> serialize() const;
static RsaPublicKey deserialize(std::vector<std::uint8_t>& data);
private:
[[nodiscard]] std::unique_ptr<EVP_PKEY, decltype(&EVP_PKEY_free)> getOpenSslKey() const;
int padMode {};
std::vector<uint8_t> _data;
};
}

View file

View file

@ -0,0 +1,96 @@
#pragma once
#include <cstring>
#include <cstdint>
#include <vector>
namespace drp::util::serialize {
/**
* Serialize a basic object to a vector of bytes
* @tparam Type the type of the object.
* @param object the object to serialize.
* @return the object as a vector of bytes.
*/
template<typename Type>
std::vector<std::uint8_t> serializeObject(const Type& object) {
// create a vector with enough space for the object
std::vector<std::uint8_t> buffer(sizeof(Type));
// copy the object data to the buffer
std::memcpy(buffer.data(), &object, sizeof(Type));
return buffer;
}
/**
* Deserialize a vector of bytes into an object.
* @warning the data used in parameter will be stripped of the data used to deserialize the object.
* @tparam Type the type of the object
* @param data the data of the object
* @return the object
*/
template<typename Type>
Type deserializeObject(std::vector<std::uint8_t>& data) {
// create an object based on the data
Type object;
std::memcpy(&object, data.data(), sizeof(Type));
// remove the space used by the object in the data
data.erase(data.begin(), data.begin() + sizeof(Type));
return object;
}
/**
* Serialize a vector of anything to a vector of bytes.
* @tparam Type the type of the data contained in the vector.
* @tparam SizeType the range of the size of the vector.
* @param object the vector to serialize.
* @return the vector data as a vector of bytes.
*/
template<typename Type, typename SizeType = std::uint32_t>
std::vector<std::uint8_t> serializeVector(const std::vector<Type>& object) {
// create a vector with enough size for the size and the data of the vector
std::vector<std::uint8_t> buffer(sizeof(SizeType) + object.size() * sizeof(Type));
// save the size of the vector
auto size = static_cast<SizeType>(object.size());
std::memcpy(buffer.data(), &size, sizeof(SizeType));
// save the content of the vector
std::memcpy(buffer.data() + sizeof(SizeType), object.data(), object.size() * sizeof(Type));
return buffer;
}
/**
* Deserialize a vector of bytes into a vector of object.
* @warning the data used in parameter will be stripped of the data used to deserialize the object.
* @tparam Type the type of the object
* @tparam SizeType the type of the size of the vector
* @param data the data in the vector
* @return the vector of object
*/
template<typename Type, typename SizeType = std::uint32_t>
std::vector<Type> deserializeVector(std::vector<std::uint8_t>& data) {
// get the size of the vector
SizeType size;
std::memcpy(&size, data.data(), sizeof(SizeType));
// create a vector with enough size of the data
std::vector<Type> object(size);
// restore the data into the object
std::memcpy(object.data(), data.data() + sizeof(SizeType), size * sizeof(Type));
// remove the data used for the deserialization from the source
data.erase(data.begin(), data.begin() + sizeof(SizeType) + size * sizeof(Type));
return object;
}
}