Merge pull request 'Rewrote the project in Python' (#7) from python into main

Reviewed-on: #7
This commit is contained in:
faraphel 2025-01-27 16:06:31 +01:00
commit bffd973062
61 changed files with 1492 additions and 529 deletions

5
.gitignore vendored
View file

@ -1,2 +1,7 @@
# IDE
.idea
cmake-build-*
# Local
assets
storage

View file

@ -1,31 +0,0 @@
cmake_minimum_required(VERSION 3.28)
project(M2-PT-DRP LANGUAGES CXX)
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
find_package(PkgConfig REQUIRED)
pkg_check_modules(MPG123 REQUIRED libmpg123)
pkg_check_modules(PORTAUDIO REQUIRED portaudio-2.0)
add_executable(M2-PT-DRP
source/main.cpp
source/Client.cpp
source/Client.hpp
source/Server.cpp
source/Server.hpp
source/packets/AudioPacket.hpp
source/utils/audio.cpp
source/utils/audio.hpp
)
target_include_directories(M2-PT-DRP PRIVATE
${MPG123_INCLUDE_DIRS}
${PORTAUDIO_INCLUDE_DIRS}
)
target_link_libraries(M2-PT-DRP PRIVATE
${MPG123_LIBRARIES}
${PORTAUDIO_LIBRARIES}
)

View file

@ -28,8 +28,8 @@ Ce logiciel est distribué tel quel, sans aucune garantie de quelque nature que
## VI. Glossaire
fork : projet se basant sur le code code source d'un logiciel déjà existant.
fork : projet se basant sur le code source d'un logiciel déjà existant.
---
Copyright © 2024 - Raphaël CARON
Copyright © 2025 - Raphaël CARON

33
README.md Normal file
View file

@ -0,0 +1,33 @@
# M2 Projet Thématique - Diffusion Radio Proche-en-Proche
Un projet visant à créer un réseau de machine capable de diffuser une source
audio à jouer de manière synchronisé.
Les communications du réseau doivent être chiffré et il ne doit pas être possible
d'inséré une machine inconnu pour pertuber le réseau.
## Usage
Cet application nécessite que votre machine utilise `Python >= 3.13` avec `chrony`
pour synchroniser les machines entre elles.
Debian
```bash
# dependencies
sudo apt upgrade
sudo apt install -y git ffmpeg libportaudio2
# download the project
git clone https://git.faraphel.fr/study-faraphel/M2-PT-DRP
cd ./M2-PT-DRP/
# create a virtual environment
python3 -m venv ./.venv/
source ./.venv/bin/activate
# install python packages
pip3 install -r ./requirements.txt
# run the application
python3 -m source
```

Binary file not shown.

17
requirements.txt Normal file
View file

@ -0,0 +1,17 @@
# extended standard
bidict
pause
sortedcontainers
numpy
# networking
psutil
msgpack
# cryptography
cryptography
# audio
pydub
audioop-lts
sounddevice

View file

@ -1,194 +0,0 @@
#include "Client.hpp"
#include <cstring>
#include <iostream>
#include <list>
#include <map>
#include <netdb.h>
#include <queue>
#include <stdexcept>
#include <thread>
#include <sys/socket.h>
#include "packets/AudioPacket.hpp"
Client::Client() {
this->stream = nullptr;
this->audioLock = std::unique_lock(this->audioMutex);
this->streamChannels = 0;
this->streamSampleFormat = 0;
this->streamRate = 0;
}
void Client::updateStream(const int channels, const std::uint32_t sampleFormat, const double sampleRate) {
// check if any information changed. If no, ignore this
if (
this->streamChannels == channels &&
this->streamSampleFormat == sampleFormat &&
this->streamRate == sampleRate
)
return;
// close the current stream
// ignore errors that could happen if no audio is currently playing
Pa_CloseStream(&this->stream);
// open a new stream with the new settings
if (const PaError error = Pa_OpenDefaultStream(
&this->stream,
0,
channels,
sampleFormat,
sampleRate,
paFramesPerBufferUnspecified,
nullptr,
nullptr
) != paNoError)
throw std::runtime_error("[Client] Could not open the stream: " + std::string(Pa_GetErrorText(error)));
// update the new audio values
this->streamChannels = channels;
this->streamSampleFormat = sampleFormat;
this->streamRate = sampleRate;
}
Client::~Client() {
// stop any currently playing audio
Pa_StopStream(this->stream);
// close the audio stream
if (const PaError error = Pa_CloseStream(this->stream))
std::cerr << "[Client] Could not close the stream: " << std::string(Pa_GetErrorText(error)) << std::endl;
}
void Client::loop() {
// run an audio receiver alongside an audio player
this->receiverThread = std::thread(&Client::loopReceiver, this);
this->playerThread = std::thread(&Client::loopPlayer, this);
this->receiverThread.join();
this->playerThread.join();
}
void Client::loopReceiver() {
// create the socket
const int clientSocket = socket(
AF_INET6,
SOCK_DGRAM,
0
);
if (clientSocket < 0)
throw std::runtime_error("[Client] Could not create the socket: " + std::string(gai_strerror(errno)));
// get the broadcast address
addrinfo serverHints = {};
serverHints.ai_family = AF_INET6;
serverHints.ai_socktype = SOCK_DGRAM;
serverHints.ai_protocol = IPPROTO_UDP;
// TODO(Faraphel): port as argument
addrinfo *serverInfo;
if(getaddrinfo(
nullptr, // hostname
"5650", // our port
&serverHints,
&serverInfo
) != 0)
throw std::runtime_error("[Client] Could not get the address: " + std::string(gai_strerror(errno)));
// bind the socket to the address
if (bind(
clientSocket,
serverInfo->ai_addr,
serverInfo->ai_addrlen
) < 0)
throw std::runtime_error("[Client] Could not bind to the address: " + std::string(gai_strerror(errno)));
// free the server address
freeaddrinfo(serverInfo);
// prepare space for the server address
sockaddr_storage serverAddress {};
socklen_t serverAddressLength;
// prepare space for the received audio
AudioPacket audioPacket;
// receive new audio data
while (true) {
// receive new audio data
const ssize_t size = recvfrom(
clientSocket,
&audioPacket,
sizeof(audioPacket),
0,
reinterpret_cast<sockaddr *>(&serverAddress),
&serverAddressLength
);
if (size == -1) {
std::cerr << "[Client] Could not receive from the socket: " << gai_strerror(errno) << std::endl;
continue;
}
// save the audio data into the queue for the player
std::cout << "[Client] Received: " << size << " bytes" << std::endl;
this->audioQueue.push(audioPacket);
// notify that a new audio chunk is available
this->audioCondition.notify_one();
}
}
void Client::loopPlayer() {
while (true) {
// wait for a new element in the audio queue
this->audioCondition.wait(
this->audioLock,
[this] { return !this->audioQueue.empty(); }
);
// get the most recent audio chunk
const auto audioPacket = this->audioQueue.top();
// update the stream with the new audio settings
this->updateStream(
audioPacket.channels,
audioPacket.sampleFormat,
audioPacket.sampleRate
);
// wait until it must be played
std::this_thread::sleep_until(audioPacket.timePlay);
std::cout << "[Client] Playing: " << audioPacket.timePlay << std::endl;
// immediately stop playing music
// this avoids an offset created if this client's clock is too ahead of the others
// don't handle errors since audio might not be playing before
Pa_AbortStream(this->stream);
// play the new audio data
if (const int error = Pa_StartStream(this->stream) != paNoError)
throw std::runtime_error("[Client] Could not start the PortAudio stream: " + std::string(Pa_GetErrorText(error)));
// write the new audio data into the audio buffer
const int error = Pa_WriteStream(
this->stream,
audioPacket.content.data(),
audioPacket.contentSize / Pa_GetSampleSize(this->streamSampleFormat) / this->streamChannels
);
switch (error) {
// success
case paNoError:
// the output might be very slightly underflowed,
// causing a very small period where no noise will be played.
case paOutputUnderflowed:
break;
default:
std::cerr << "[Client] Could not write to the audio stream: " << Pa_GetErrorText(error) << std::endl;
}
// remove the audio chunk
this->audioQueue.pop();
}
}

View file

@ -1,65 +0,0 @@
#pragma once
#include <condition_variable>
#include <mutex>
#include <portaudio.h>
#include <queue>
#include <thread>
#include "packets/AudioPacket.hpp"
// TODO(Faraphel): should be moved somewhere else
struct AudioPacketsComparator {
bool operator() (const AudioPacket &a, const AudioPacket &b) const {
return a.timePlay > b.timePlay;
}
};
/**
* the audio Client.
* Receive audio packets and play them at a specific time.
*/
class Client {
public:
explicit Client();
~Client();
/**
* Update the current audio stream
* @param channels the number of channels
* @param sampleFormat the sample format type
* @param sampleRate the audio rate
*/
void updateStream(int channels, std::uint32_t sampleFormat, double sampleRate);
/**
* Indefinitely receive and play audio data.
*/
void loop();
private:
/**
* Indefinitely receive audio data.
*/
void loopReceiver();
/**
* Indefinitely play audio data.
*/
void loopPlayer();
PaStream* stream;
int streamChannels;
std::uint32_t streamSampleFormat;
double streamRate;
std::priority_queue<AudioPacket, std::vector<AudioPacket>, AudioPacketsComparator> audioQueue;
std::mutex audioMutex;
std::unique_lock<std::mutex> audioLock;
std::condition_variable audioCondition;
std::thread receiverThread;
std::thread playerThread;
};

View file

@ -1,127 +0,0 @@
#include "Server.hpp"
#include <iostream>
#include <cstdint>
#include <cstring>
#include <mpg123.h>
#include <netdb.h>
#include <stdexcept>
#include <thread>
#include <sys/socket.h>
#include <vector>
#include "packets/AudioPacket.hpp"
#include "utils/audio.hpp"
Server::Server() {
this->channels = 0;
this->encoding = 0;
this->sampleRate = 0;
// create a new mpg123 handle
int error;
this->mpgHandle = mpg123_new(nullptr, &error);
if (this->mpgHandle == nullptr)
throw std::runtime_error("[Server] Could not create a mpg123 handle.");
// open the mp3 file
// TODO(Faraphel): mp3 file as argument
if (mpg123_open(
this->mpgHandle,
// "./assets/Caravan Palace - Wonderland.mp3"
"./assets/Queen - Another One Bites the Dust.mp3"
) != MPG123_OK)
throw std::runtime_error("[Server] Could not open file.");
// get the format of the file
if (mpg123_getformat(
this->mpgHandle,
&this->sampleRate,
&this->channels,
&this->encoding
) != MPG123_OK)
throw std::runtime_error("[Server] Could not get the format of the file.");
}
Server::~Server() {
// delete the mpg123 handle
mpg123_close(this->mpgHandle);
mpg123_delete(this->mpgHandle);
}
void Server::loop() const {
// get the broadcast address
addrinfo broadcastHints {};
broadcastHints.ai_family = AF_INET6;
broadcastHints.ai_socktype = SOCK_DGRAM;
broadcastHints.ai_protocol = IPPROTO_UDP;
// TODO(Faraphel): ip / port as argument
addrinfo *broadcastInfo;
if(const int error = getaddrinfo(
"::1",
"5650",
&broadcastHints,
&broadcastInfo
) != 0)
throw std::runtime_error("[Server] Could not get the address: " + std::string(gai_strerror(error)));
const int broadcastSocket = socket(
broadcastInfo->ai_family,
broadcastInfo->ai_socktype,
broadcastInfo->ai_protocol
);
if (broadcastSocket == -1)
throw std::runtime_error("[Server] Could not create the socket: " + std::string(gai_strerror(errno)));
// read the file
AudioPacket audioPacket;
std::size_t done;
while (mpg123_read(
this->mpgHandle,
&audioPacket.content,
std::size(audioPacket.content),
&done
) == MPG123_OK) {
// set the target time
// TODO(Faraphel): dynamically change this delay to be the lowest possible
audioPacket.timePlay =
std::chrono::high_resolution_clock::now() +
std::chrono::milliseconds(5000);
// set the audio settings
audioPacket.channels = this->channels;
audioPacket.sampleFormat = encoding_mpg123_to_PulseAudio(this->encoding);
audioPacket.sampleRate = this->sampleRate;
// set the size of the content
audioPacket.contentSize = done;
// broadcast the audio data
if (sendto(
broadcastSocket,
&audioPacket,
sizeof(audioPacket),
0,
broadcastInfo->ai_addr,
broadcastInfo->ai_addrlen
) == -1) {
std::cerr << "[Server] Could not send audio packet: " << strerror(errno) << std::endl;
continue;
}
std::cout << "[Server] Sent: " << done << " bytes" << std::endl;
// wait for the duration of the audio chunk
std::this_thread::sleep_for(std::chrono::milliseconds(static_cast<uint64_t>(
(1 / static_cast<double>(this->sampleRate * this->channels * mpg123_encsize(this->encoding))) *
1000 *
static_cast<double>(done)
)));
}
// free the server address
freeaddrinfo(broadcastInfo);
}

View file

@ -1,25 +0,0 @@
#pragma once
#include <mpg123.h>
/**
* the audio Server.
* Read and broadcast audio data.
*/
class Server {
public:
explicit Server();
~Server();
/**
* Indefinitely read and broadcast audio data.
*/
void loop() const;
private:
mpg123_handle* mpgHandle;
long sampleRate;
int channels;
int encoding;
};

3
source/__init__.py Normal file
View file

@ -0,0 +1,3 @@
from . import managers
from . import behaviors
from . import packets

21
source/__main__.py Normal file
View file

@ -0,0 +1,21 @@
import argparse
from source.managers import Manager
parser = argparse.ArgumentParser(
prog="ISRI-DRP",
description="Create a network of machine that try to play an audio file synchronously."
)
parser.add_argument(
"-i", "--interface",
required=True,
help="The interface on which other peers are accessible."
)
arguments = parser.parse_args()
manager = Manager(arguments.interface)
manager.loop()

View file

View file

@ -0,0 +1,14 @@
from source import packets
from source.behaviors.events import base
class AudioEvent(base.BaseTrustedEvent):
"""
Event reacting to receiving audio data.
"""
def handle(self, packet: packets.AudioPacket, address: tuple):
super().handle(packet, address)
# add the audio chunk to the buffer of audio packet to play
self.manager.audio.add_audio(packet)

View file

@ -0,0 +1,20 @@
from . import base
from source import packets
from .. import roles
from ...utils.crypto.type import CipherType
class DiscoveryEvent(base.BaseEvent):
"""
Event reacting to a machine trying to discover the network.
"""
def handle(self, packet: packets.DiscoveryPacket, address: tuple):
# create a peer packet containing our important information
response = packets.PeerPacket(
self.manager.communication.public_key,
isinstance(self.manager.role.current, roles.MasterRole)
)
# send our information back
# don't use any encryption to share the RSA key for further communication
self.manager.communication.send(response, CipherType.PLAIN, address)

View file

@ -0,0 +1,21 @@
from source import packets
from source.behaviors import roles
from source.behaviors.events import base
class KeyEvent(base.BaseTrustedEvent):
"""
Event reacting to a machine sending us their secret key
"""
def handle(self, packet: packets.KeyPacket, address: tuple):
super().handle(packet, address)
# check if we are a slave
if not isinstance(self.manager.role.current, roles.SlaveRole):
return
# TODO(Faraphel): check if this come from our server ?
# use the secret key for further symmetric communication
self.manager.role.current.secret_key = packet.secret_key

View file

@ -0,0 +1,23 @@
from . import base
from source import packets, structures
class PeerEvent(base.BaseEvent):
"""
Event reacting to receiving information about another machine
"""
def handle(self, packet: packets.PeerPacket, address: tuple):
# ignore peers with a banned key
if self.manager.communication.is_peer_banned(packet.public_key):
return
# TODO(Faraphel): SHOULD NOT BE TRUSTED AUTOMATICALLY !
self.manager.communication.trust_peer(packet.public_key)
# update our peers database to add new peer information
self.manager.peer.peers[address] = structures.Peer(
public_key=packet.public_key,
master=packet.master,
trusted=self.manager.communication.is_peer_trusted(packet.public_key)
)

View file

@ -0,0 +1,22 @@
from source import packets
from source.behaviors import roles
from source.behaviors.events import base
from source.utils.crypto.type import CipherType
class RequestKeyEvent(base.BaseTrustedEvent):
"""
Event reacting to a machine trying to get our secret symmetric key for secure communication
"""
def handle(self, packet: packets.RequestKeyPacket, address: tuple):
super().handle(packet, address)
# check if we are a master
if not isinstance(self.manager.role.current, roles.MasterRole):
return
# create a packet containing our secret key
packet = packets.KeyPacket(self.manager.role.current.secret_key)
# send it back to the slave
self.manager.communication.send(packet, CipherType.RSA, address)

View file

@ -0,0 +1,7 @@
from . import base
from .DiscoveryEvent import DiscoveryEvent
from .PeerEvent import PeerEvent
from .AudioEvent import AudioEvent
from .RequestKeyEvent import RequestKeyEvent
from .KeyEvent import KeyEvent

View file

@ -0,0 +1,16 @@
import abc
from source import packets, managers
class BaseEvent(abc.ABC):
def __init__(self, manager: "managers.Manager"):
self.manager = manager
@abc.abstractmethod
def handle(self, packet: packets.base.BasePacket, address: tuple) -> None:
"""
Handle a packet
:param packet: the packet to handle
:param address: the address of the machine that sent the packet
"""

View file

@ -0,0 +1,18 @@
import abc
from source import packets
from source.behaviors.events.base import BaseEvent
from source.error import UntrustedPeerException
class BaseTrustedEvent(BaseEvent, abc.ABC):
"""
Event that can only be triggered if the distant peer is trusted
"""
def handle(self, packet: packets.base.BasePacket, address: tuple) -> None:
# get the peer that sent the message
peer = self.manager.peer.peers.get(address)
# check if it is trusted
if peer is None or not peer.trusted:
raise UntrustedPeerException(peer)

View file

@ -0,0 +1,2 @@
from .BaseEvent import BaseEvent
from .BaseTrustedEvent import BaseTrustedEvent

View file

@ -0,0 +1,66 @@
import os
from datetime import datetime, timedelta
import pause
import pydub
from pydub.utils import make_chunks
from source.behaviors.roles import base
from source.managers import Manager
from source.packets import AudioPacket
from source.utils.crypto.type import CipherType
class MasterRole(base.BaseActiveRole):
"""
Role used when the machine is declared as the master.
It will be the machine responsible for emitting data that the others peers should play together.
"""
TARGET_SIZE: int = 60 * 1024 # set an upper bound because of the IPv6 maximum packet size.
def __init__(self, manager: "Manager"):
super().__init__(manager)
# generate a random secret key for symmetric communication
self.secret_key = os.urandom(32)
# prepare the audio file that will be streamed
# TODO(Faraphel): use another audio source
self.audio = pydub.AudioSegment.from_file("./assets/Queen - Another One Bites the Dust.mp3")
self.play_time = datetime.now()
# calculate the number of bytes per milliseconds in the audio
bytes_per_ms = self.audio.frame_rate * self.audio.sample_width * self.audio.channels / 1000
# calculate the required chunk duration to reach that size
self.chunk_duration = timedelta(milliseconds=self.TARGET_SIZE / bytes_per_ms)
# split the audio into chunks
self.chunk_count = 0
self.chunks = make_chunks(self.audio, self.chunk_duration.total_seconds() * 1000)
def handle(self) -> None:
# TODO(Faraphel): communicate with chrony to add peers and enable server mode ?
# TODO(Faraphel): share the secret key generated with the other *allowed* peers ! How to select them ? A file ?
# TODO(Faraphel): check if another server is emitting sound in the network. Return to undefined if yes
# get the current chunk
chunk = self.chunks[self.chunk_count]
# broadcast it in the network
audio_packet = AudioPacket(
# TODO(Faraphel): adjust time depending on the network average latency ?
datetime.now() + timedelta(seconds=5), # play it in some seconds to let all the machines get the sample
chunk.channels,
chunk.frame_rate,
chunk.sample_width,
chunk.raw_data,
)
self.manager.communication.broadcast(audio_packet, CipherType.AES_CBC)
# increment the chunk count
self.chunk_count += 1
# wait for the next chunk time
pause.until(self.play_time + (self.chunk_duration * self.chunk_count))

View file

@ -0,0 +1,34 @@
import typing
from datetime import timedelta, datetime
from source import managers, packets, structures
from source.behaviors.roles import base, UndefinedRole
from source.utils.crypto.type import CipherType
class SlaveRole(base.BaseActiveRole):
"""
Role used when the machine is declared as a slave.
It shall listen for a master and check if everything is working properly
"""
def __init__(self, manager: "managers.Manager", master_address: tuple):
super().__init__(manager)
# the address of the server
self.master_address = master_address
def handle(self):
# get our master peer
master_peer: structures.Peer = self.manager.peer.peers[self.master_address]
# check if we know the network secret key
if self.secret_key is None:
# if we don't know it, request it
packet = packets.RequestKeyPacket()
self.manager.communication.send(packet, CipherType.RSA, self.master_address)
# check if the master interacted recently
if datetime.now() - master_peer.last_interaction > timedelta(seconds=10):
# if the master didn't react in a moment, return to undefined mode
self.manager.role.current = UndefinedRole(self.manager)

View file

@ -0,0 +1,59 @@
from datetime import datetime, timedelta
import pause
from source.behaviors import roles
from source.behaviors.roles import base
class UndefinedRole(base.BaseRole):
"""
Role used when the machine is looking for how it should insert itself in the network
"""
def handle(self) -> None:
# calculate a timeout of when stopping to look for new peers
timeout = self.manager.peer.peers.last_added + timedelta(seconds=5)
# if the timeout have not been reach, wait for it and restart
if not datetime.now() > timeout:
pause.until(timeout)
return
# SCENARIO 1 - empty network
# filter ourselves out of the remote peers
remote_peers = {
address: peer
for (address, peer) in self.manager.peer.peers.items()
if not self.manager.communication.is_address_local(address)
}
# if no other peers have been found
if len(remote_peers) == 0:
# declare ourselves as the master of the network
self.manager.role.current = roles.MasterRole(self.manager)
return
# SCENARIO 2 - network with a master
# list all the peers considered as masters
master_peers = {
address: peer
for (address, peer) in remote_peers.items()
if peer.master
}
# if there is a master, become a slave
if len(master_peers) >= 1:
# get the first master information
master_address, master_peer = master_peers[0]
# declare ourselves as a slave of the network
self.manager.role.current = roles.SlaveRole(self.manager, master_address)
return
# SCENARIO 3 - network with no master
# TODO(Faraphel): elect the machine with the lowest ping in the network
raise NotImplementedError("Not implemented: elect the machine with the lowest ping as a master.")

View file

@ -0,0 +1,5 @@
from . import base
from .MasterRole import MasterRole
from .SlaveRole import SlaveRole
from .UndefinedRole import UndefinedRole

View file

@ -0,0 +1,17 @@
import abc
from typing import Optional
from source import managers
from source.behaviors.roles.base import BaseRole
class BaseActiveRole(BaseRole, abc.ABC):
"""
Base for a role where the machine know what it should do in the network
"""
def __init__(self, manager: "managers.Manager"):
super().__init__(manager)
# an "active" machine shall be able to use symmetric encryption
self.secret_key: Optional[bytes] = None

View file

@ -0,0 +1,18 @@
import abc
from source import managers
class BaseRole(abc.ABC):
"""
Base for all the role the machine can have
"""
def __init__(self, manager: "managers.Manager"):
self.manager = manager
@abc.abstractmethod
def handle(self) -> None:
"""
Behavior of the role
"""

View file

@ -0,0 +1,2 @@
from .BaseRole import BaseRole
from .BaseActiveRole import BaseActiveRole

View file

@ -0,0 +1,8 @@
import typing
from source import structures
class UntrustedPeerException(Exception):
def __init__(self, peer: typing.Optional[structures.Peer]):
super().__init__(f"Peer not trusted: {peer}")

1
source/error/__init__.py Normal file
View file

@ -0,0 +1 @@
from .UntrustedPeerException import UntrustedPeerException

View file

@ -1,34 +0,0 @@
#include <mpg123.h>
#include <portaudio.h>
#include <stdexcept>
#include <thread>
#include "Client.hpp"
#include "Server.hpp"
int main(int argc, char* argv[]) {
// TODO(Faraphel): move in the Client
// initialize the mpg123 library
if (mpg123_init() != MPG123_OK)
throw std::runtime_error("Error while initializing mpg123.");
// initialize the PortAudio library
if (Pa_Initialize() != paNoError)
throw std::runtime_error("Could not initialize PortAudio.");
// start the client and server
Server server;
Client client;
std::thread serverThread(&Server::loop, &server);
std::thread clientThread(&Client::loop, &client);
serverThread.join();
clientThread.join();
// terminate the libraries
Pa_Terminate();
mpg123_exit();
}

View file

@ -0,0 +1,99 @@
import threading
import typing
from datetime import datetime
import numpy
import pause
import sortedcontainers
import sounddevice
from source import packets, managers
from source.utils.audio.audio import sample_width_to_type
class AudioManager:
"""
Manage playing audio data in the buffer
"""
def __init__(self, manager: "managers.Manager"):
self.stream: typing.Optional[sounddevice.OutputStream] = None
# buffer containing the set of audio chunk to play. Sort them by their time to play
self.buffer = sortedcontainers.SortedList(key=lambda audio: audio.time)
# thread support
self.lock = threading.Lock()
self.new_audio_event = threading.Event() # event triggered when a new audio have been added
def add_audio(self, audio: packets.AudioPacket) -> None:
"""
Add a new audio chunk to play
:param audio: the audio chunk to play
"""
with self.lock:
# add the audio packet to the buffer
self.buffer.add(audio)
# trigger the new audio event
self.new_audio_event.set()
def play_audio(self, audio: packets.AudioPacket) -> None:
# create a numpy array for our sample
sample = numpy.frombuffer(audio.data, dtype=sample_width_to_type(audio.sample_width))
# reshape it to have a sub-array for each channels
sample = sample.reshape((-1, audio.channels))
# normalize the sample to be between -1 and 1
sample = sample / (2 ** (audio.sample_width * 8 - 1))
# use float32 for the audio library
sample = sample.astype(numpy.float32)
# wait for the audio given time
pause.until(audio.time)
# update the stream if the audio use different settings
if (
self.stream is None or
self.stream.samplerate != audio.sample_rate or
self.stream.channels != audio.channels
):
self.stream = sounddevice.OutputStream(
samplerate=audio.sample_rate,
channels=audio.channels,
)
# play
self.stream.start()
# write the audio to the stream
self.stream.write(sample)
def handle(self) -> None:
"""
Play the audio chunk in the buffer at the given time
"""
# wait for a new audio packet
# TODO(Faraphel): use self.lock ? seem to softlock the application
if len(self.buffer) == 0:
self.new_audio_event.clear()
self.new_audio_event.wait()
# get the most recent audio packet to play
audio: packets.AudioPacket = self.buffer.pop(0)
# if the audio should have been played before, skip it
if audio.time < datetime.now():
return
# play the audio packet
self.play_audio(audio)
def loop(self) -> None:
"""
Handle forever
"""
while True:
self.handle()

View file

@ -0,0 +1,302 @@
import hashlib
import json
import socket
import typing
import zlib
from datetime import datetime
import bidict
import psutil
from source import packets, utils, structures
from source.behaviors import roles
from source.managers import Manager
from source.utils.crypto.rsa import rsa_create_key_pair
from source.utils.crypto.type import CipherType
class CommunicationManager:
"""
Manage the communication between the peers
"""
def __init__(self, manager: "Manager", interface: str, broadcast_address: str = "ff02::1", port: int = 5555):
self.manager = manager
self.broadcast_address = broadcast_address
self.interface = interface
self.port = port
# create an IPv6 UDP socket
self.socket = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
# enable broadcast messages
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, True)
# use multicast on the selected interface
self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, socket.if_nametoindex(interface))
# bind to listen for any message on this port
self.socket.bind(("::", self.port))
# create a dictionary to hold the types of packets and their headers.
self.packet_types: bidict.bidict[bytes, typing.Type[packets.base.BasePacket]] = bidict.bidict()
# load or create a private and public key for asymmetric communication
private_key_path = self.manager.storage / "private_key.der"
public_key_path = self.manager.storage / "public_key.der"
if public_key_path.exists() and private_key_path.exists():
self.private_key = private_key_path.read_bytes()
self.public_key = public_key_path.read_bytes()
else:
self.private_key, self.public_key = rsa_create_key_pair()
private_key_path.write_bytes(self.private_key)
public_key_path.write_bytes(self.public_key)
self._trusted_peers_path = self.manager.storage / "trusted-peers.json"
self._trusted_peers: set[str] = set()
if self._trusted_peers_path.exists():
self._trusted_peers = set(json.loads(self._trusted_peers_path.read_text()))
self._banned_peers_path = self.manager.storage / "banned-peers.json"
self._banned_peers: set[str] = set()
if self._banned_peers_path.exists():
self._banned_peers = set(json.loads(self._banned_peers_path.read_text()))
def __del__(self):
# close the socket
self.socket.close()
def get_secret_key(self) -> typing.Optional[bytes]:
"""
Get the symmetric secret key
:return: the symmetric secret key
"""
# check if we have an "active" role
if not isinstance(self.manager.role.current, roles.base.BaseActiveRole):
return None
# return its secret key
return self.manager.role.current.secret_key
def register_packet_type(self, header: bytes, packet_type: typing.Type[packets.base.BasePacket]) -> None:
"""
Register a new kind of packet that can be sent or received.
:param header: the binary header identifying the packet
:param packet_type: the class of the packet
"""
if len(header) != 4:
raise Exception("The header should be exactly 4 bytes long.")
self.packet_types[header] = packet_type
def packet_encode(self, packet: packets.base.BasePacket, cipher_type: CipherType) -> bytes:
"""
Encode a packet for diffusion
:param packet: a packet to encode to be sent
:param cipher_type: the type of cipher
:return: an encoded packet
"""
# get the header identifier of the type of this packet
header: typing.Optional[bytes] = self.packet_types.inverse.get(type(packet))
if header is None:
raise KeyError(f"Unrecognised packet type: {type(packet)}. Has it been registered ?")
# get the encoded packet data
data = packet.pack()
# calculate its checksum using CRC32
checksum = zlib.crc32(data).to_bytes(4, byteorder='big')
# get the packet data
packet_data = checksum + header + data
# encrypt the packet data depending on the cipher selected
secret_key = self.get_secret_key()
if secret_key is None and cipher_type in utils.crypto.type.CIPHER_SYMMETRIC_TYPES:
raise ValueError("Cannot cipher a packet with undefined secret key.")
match cipher_type:
case CipherType.PLAIN:
pass
case CipherType.AES_ECB:
packet_data = utils.crypto.aes.aes_ecb_encrypt(packet_data, secret_key)
case CipherType.AES_CBC:
packet_data = utils.crypto.aes.aes_cbc_encrypt(packet_data, secret_key)
case CipherType.RSA:
packet_data = utils.crypto.rsa.rsa_encrypt(packet_data, self.public_key)
case _:
raise ValueError(f"Unknown cipher: {cipher_type}")
# prepend the cipher type to the packet data
payload = cipher_type.value.to_bytes(length=2, byteorder="big") + packet_data
return payload
def packet_decode(self, payload: bytes) -> packets.base.BasePacket:
"""
Decode a payload into a packet
:param payload: the data of the packet
:return: the deserialized packet
"""
cipher_type: CipherType = CipherType(int.from_bytes(payload[0:2], byteorder="big"))
packet_data: bytes = payload[2:]
# decrypt the packet data depending on the cipher used
secret_key = self.get_secret_key()
if secret_key is None and cipher_type in utils.crypto.type.CIPHER_SYMMETRIC_TYPES:
raise ValueError("Cannot decipher a packet with undefined secret key.")
match cipher_type:
case CipherType.PLAIN:
pass
case CipherType.AES_ECB:
packet_data = utils.crypto.aes.aes_ecb_decrypt(packet_data, secret_key)
case CipherType.AES_CBC:
packet_data = utils.crypto.aes.aes_cbc_decrypt(packet_data, secret_key)
case CipherType.RSA:
packet_data = utils.crypto.rsa.rsa_decrypt(packet_data, self.private_key)
case _:
raise ValueError(f"Unknown cipher: {cipher_type}")
# split the header and data from the raw payload
checksum: int = int.from_bytes(packet_data[:4], "big")
header: bytes = packet_data[4:8]
data: bytes = packet_data[8:]
# verify the checksum for corruption
if zlib.crc32(data) != checksum:
raise ValueError("The checksum is invalid.")
# get the type of the packet from its header
packet_type: typing.Optional[typing.Type[packets.base.BasePacket]] = self.packet_types.get(header)
if header is None:
raise KeyError(f"Unrecognised packet header: {header}. Has it been registered ?")
# unpack the packet
return packet_type.unpack(data)
def send(self, packet: packets.base.BasePacket, cipher_type: CipherType, address: tuple):
self.socket.sendto(self.packet_encode(packet, cipher_type), address)
def broadcast(self, packet: packets.base.BasePacket, cipher_type: CipherType):
"""
Broadcast a message in the network
:param cipher_type: the type of cipher
:param packet: the message to broadcast
"""
# check that no asymmetric cipher mode is used
if cipher_type in utils.crypto.type.CIPHER_ASYMMETRIC_TYPES:
raise ValueError("Asymmetric cipher cannot be used in broadcast.")
# TODO(Faraphel): use a channel system (OR ESTABLISH ANOTHER PORT ???)
self.send(packet, cipher_type, (self.broadcast_address, self.port))
def receive(self) -> tuple[packets.base.BasePacket, tuple]:
"""
Receive a packet
:return: the packet content alongside the address of the sender
"""
# receive a message
payload, address = self.socket.recvfrom(65536)
# check if there is a peer associated with this address
peer: structures.Peer = self.manager.peer.peers.get(address)
if peer is not None:
# update the latest interaction date
peer.last_interaction = datetime.now()
# decode the payload
return self.packet_decode(payload), address
def get_local_addresses(self) -> typing.Iterator[tuple]:
"""
Get the local hosts addresses of the machine (on the selected interface)
:return: the local hosts addresses of the machine
"""
for address in psutil.net_if_addrs()[self.interface]:
# return the address family and the host (without the interface suffix)
yield address.family, address.address.split("%")[0]
def is_address_local(self, address: tuple) -> bool:
"""
Is the given address local
:return: true if the address is local, false otherwise
"""
host, _, _, scope = address
# check if the host is in our local hosts list
for local_address in self.get_local_addresses():
local_family, local_host = local_address
if host == local_host:
return True
return False
def save_trusted_peers(self) -> None:
"""
Save the list of trusted peers
"""
self._trusted_peers_path.write_text(json.dumps(list(self._trusted_peers)))
def save_banned_peers(self) -> None:
"""
Save the list of banned peers
"""
self._banned_peers_path.write_text(json.dumps(list(self._banned_peers)))
def trust_peer(self, public_key: bytes) -> None:
"""
Mark a peer as trusted for future connexion
Automatically save it to a file
:param public_key: the public key of the peer
"""
self._trusted_peers.add(hashlib.sha256(public_key).hexdigest())
self.save_trusted_peers()
def ban_peer(self, public_key: bytes) -> None:
"""
Ban a peer from being used for any future connexion
Automatically save it to a file
:param public_key: the public key of the peer
"""
self._banned_peers.add(hashlib.sha256(public_key).hexdigest())
self.save_banned_peers()
def is_peer_trusted(self, public_key: bytes) -> bool:
"""
Determinate is a peer is trusted or not
:param public_key: the public key of the peer
:return: True if the peer is trusted, False otherwise
"""
return hashlib.sha256(public_key).hexdigest() in self._trusted_peers
def is_peer_banned(self, public_key: bytes) -> bool:
"""
Determinate if a peer is banned or not
:param public_key: the public key of the peer
:return: True if the peer is banned, False otherwise
"""
return hashlib.sha256(public_key).hexdigest() in self._banned_peers

View file

@ -0,0 +1,63 @@
import traceback
import typing
import warnings
from source import packets
from source.behaviors import events
from source.error import UntrustedPeerException
from source.managers import Manager
class EventManager:
"""
Responsible for receiving packets from other peers and handling them.
"""
def __init__(self, manager: "Manager"):
self.manager = manager
# events
self.event_handlers: dict[typing.Type[packets.base.BasePacket], events.base.BaseEvent] = {}
def register_event_handler(self, packet_type: typing.Type[packets.base.BasePacket], event: events.base.BaseEvent) -> None:
"""
Register a new event to react to a specific packet type
:param packet_type: the type of packet to listen to
:param event: the event handler
"""
self.event_handlers[packet_type] = event
def handle(self, packet: packets.base.BasePacket, address: tuple) -> None:
"""
Handle the packet received
:param packet: the packet to handle
:param address: the address of the machine that sent the packet
"""
# get the event handler of this kind of packet
event_handler = self.event_handlers.get(type(packet))
if event_handler is None:
raise KeyError(f"Unrecognised packet type: {type(packet)}. Has it been registered ?")
# use the event handler on the packet
event_handler.handle(packet, address)
def loop(self) -> None:
"""
Handle events forever
"""
while True:
try:
# wait for a new packet
packet, address = self.manager.communication.receive()
print(f"Received message from {address}: {packet}")
# give it to the event handler
self.manager.event.handle(packet, address)
except UntrustedPeerException:
print("Ignored: untrusted peer.")
except Exception: # NOQA
warnings.warn(traceback.format_exc())

View file

@ -0,0 +1,63 @@
import threading
from pathlib import Path
from source import packets
from source.behaviors import events
class Manager:
"""
Global manager
"""
def __init__(self, interface: str):
from . import CommunicationManager, EventManager, RoleManager, AudioManager, PeerManager
self.storage = Path("./storage/")
self.storage.mkdir(exist_ok=True)
# communication manager
self.communication = CommunicationManager(self, interface)
self.communication.register_packet_type(b"DISC", packets.DiscoveryPacket)
self.communication.register_packet_type(b"PEER", packets.PeerPacket)
self.communication.register_packet_type(b"AUDI", packets.AudioPacket)
self.communication.register_packet_type(b"RQSK", packets.RequestKeyPacket)
self.communication.register_packet_type(b"GTSK", packets.KeyPacket)
# event manager
self.event = EventManager(self)
self.event.register_event_handler(packets.DiscoveryPacket, events.DiscoveryEvent(self))
self.event.register_event_handler(packets.PeerPacket, events.PeerEvent(self))
self.event.register_event_handler(packets.AudioPacket, events.AudioEvent(self))
self.event.register_event_handler(packets.RequestKeyPacket, events.RequestKeyEvent(self))
self.event.register_event_handler(packets.KeyPacket, events.KeyEvent(self))
# role manager
self.role = RoleManager(self)
# audio manager
self.audio = AudioManager(self)
# peer manager
self.peer = PeerManager(self)
def loop(self) -> None:
"""
Handle the sub-managers forever
"""
# run a thread for the event and the role manager
event_thread = threading.Thread(target=self.event.loop)
role_thread = threading.Thread(target=self.role.loop)
audio_thread = threading.Thread(target=self.audio.loop)
peer_thread = threading.Thread(target=self.peer.loop)
event_thread.start()
role_thread.start()
audio_thread.start()
peer_thread.start()
event_thread.join()
role_thread.join()
audio_thread.join()
peer_thread.join()

View file

@ -0,0 +1,30 @@
import time
from source import packets, structures
from source.managers import Manager
from source.utils.crypto.type import CipherType
from source.utils.dict import TimestampedDict
class PeerManager:
"""
Manage the peers network
"""
def __init__(self, manager: "Manager"):
self.manager = manager
# set of addresses associated to their peer
self.peers: TimestampedDict[tuple, structures.Peer] = TimestampedDict()
def handle(self) -> None:
# send requests to discover new peers
packet = packets.DiscoveryPacket()
self.manager.communication.broadcast(packet, CipherType.PLAIN)
def loop(self) -> None:
while True:
self.handle()
# TODO(Faraphel): adjust sleep time ? as much seconds as there are peer to avoid flooding the network ?
time.sleep(1)

View file

@ -0,0 +1,29 @@
from source.behaviors import roles
from source.managers import Manager
class RoleManager:
"""
Role Manager
Responsible for the passive behavior of the machine and sending packets
"""
def __init__(self, manager: "Manager"):
self.manager = manager
# the currently used role
self.current: roles.base.BaseRole = roles.UndefinedRole(self.manager)
def handle(self) -> None:
"""
Run the role
"""
self.current.handle()
def loop(self) -> None:
"""
Handle forever
"""
while True:
self.handle()

View file

@ -0,0 +1,7 @@
from .CommunicationManager import CommunicationManager
from .EventManager import EventManager
from .RoleManager import RoleManager
from .AudioManager import AudioManager
from .PeerManager import PeerManager
from .Manager import Manager

View file

@ -1,18 +0,0 @@
#pragma once
#include <chrono>
#include <cstdint>
struct AudioPacket {
// scheduling
// TODO(Faraphel): use a more "fixed" size format ?
std::chrono::time_point<std::chrono::high_resolution_clock> timePlay;
// audio settings
std::uint8_t channels;
std::uint32_t sampleFormat;
std::uint32_t sampleRate;
// content
std::uint16_t contentSize;
std::array<std::uint8_t, 65280> content;
};

View file

@ -0,0 +1,62 @@
import dataclasses
import zlib
from datetime import datetime
import msgpack
from source.packets import base
@dataclasses.dataclass
class AudioPacket(base.BasePacket):
"""
Represent a packet of audio data
"""
# expected time to play the audio
time: datetime = dataclasses.field()
# audio details
channels: int = dataclasses.field()
sample_rate: int = dataclasses.field()
sample_width: int = dataclasses.field()
# raw audio data
_data: bytes = dataclasses.field(repr=False)
# is the audio zlib compressed
compressed: bool = dataclasses.field(default=False)
def pack(self) -> bytes:
return msgpack.packb((
self.time.timestamp(),
self.channels,
self.sample_rate,
self.sample_width,
self._data,
self.compressed
))
def __post_init__(self):
# if the audio is not compressed, compress it
if not self.compressed:
self._data = zlib.compress(self._data)
self.compressed = True
@property
def data(self):
return zlib.decompress(self._data) if self.compressed else self._data
@classmethod
def unpack(cls, data: bytes):
# unpack the message
timestamp, channels, sample_rate, sample_width, data, compressed = msgpack.unpackb(data)
return cls(
datetime.fromtimestamp(timestamp),
channels,
sample_rate,
sample_width,
data,
compressed,
)

View file

@ -0,0 +1,19 @@
import dataclasses
import msgpack
from source.packets import base
@dataclasses.dataclass
class DiscoveryPacket(base.BasePacket):
"""
Represent a packet used to discover new devices in the network.
"""
def pack(self) -> bytes:
return msgpack.packb(())
@classmethod
def unpack(cls, data: bytes):
return cls()

View file

@ -0,0 +1,23 @@
import dataclasses
import msgpack
from source.packets import base
@dataclasses.dataclass
class KeyPacket(base.BasePacket):
"""
Represent a packet containing a secret symmetric key
"""
secret_key: bytes = dataclasses.field(repr=False)
def pack(self) -> bytes:
return msgpack.packb((
self.secret_key
))
@classmethod
def unpack(cls, data: bytes):
return cls(*msgpack.unpackb(data))

View file

@ -0,0 +1,31 @@
import dataclasses
import msgpack
from . import base
@dataclasses.dataclass
class PeerPacket(base.BasePacket):
"""
Represent a packet used to send information about a peer
"""
# public RSA key of the machine
public_key: bytes = dataclasses.field(repr=False)
# is the machine a master
master: bool = dataclasses.field()
# TODO(Faraphel): share our trusted / banned peers with the other peer so that only one machine need to trust / ban it
# to propagate it to the whole network ?
def pack(self) -> bytes:
return msgpack.packb((
self.public_key,
self.master
))
@classmethod
def unpack(cls, data: bytes):
return cls(*msgpack.unpackb(data))

View file

@ -0,0 +1,19 @@
import dataclasses
import msgpack
from source.packets import base
@dataclasses.dataclass
class RequestKeyPacket(base.BasePacket):
"""
Represent a packet used to request a secret symmetric key
"""
def pack(self) -> bytes:
return msgpack.packb(())
@classmethod
def unpack(cls, data: bytes):
return cls()

View file

@ -0,0 +1,7 @@
from . import base
from .AudioPacket import AudioPacket
from .DiscoveryPacket import DiscoveryPacket
from .PeerPacket import PeerPacket
from .RequestKeyPacket import RequestKeyPacket
from .KeyPacket import KeyPacket

View file

@ -0,0 +1,19 @@
import abc
class BasePacket(abc.ABC):
@abc.abstractmethod
def pack(self) -> bytes:
"""
Serialize the object to bytes.
:return: bytes representing the object
"""
@classmethod
@abc.abstractmethod
def unpack(cls, data: bytes) -> "BasePacket":
"""
Deserialize the object from bytes.
:param data: the data to deserialize
:return: the deserialized object
"""

View file

@ -0,0 +1 @@
from .BasePacket import BasePacket

20
source/structures/Peer.py Normal file
View file

@ -0,0 +1,20 @@
import dataclasses
from datetime import datetime
from typing import Optional
@dataclasses.dataclass
class Peer:
# is the peer a master
master: bool = dataclasses.field()
# public asymmetric key
public_key: bytes = dataclasses.field(repr=False)
# secret symmetric key
secret_key: Optional[bytes] = dataclasses.field(default=None, repr=False)
# is the machine trusted
trusted: bool = dataclasses.field(default=False)
# when did the peer last communication with us occurred
last_interaction: datetime = dataclasses.field(default_factory=datetime.now)

View file

@ -0,0 +1 @@
from .Peer import Peer

1
source/utils/__init__.py Normal file
View file

@ -0,0 +1 @@
from . import crypto

View file

@ -1,27 +0,0 @@
#include "audio.hpp"
#include <stdexcept>
#include <fmt123.h>
#include <portaudio.h>
std::uint32_t encoding_mpg123_to_PulseAudio(const int encoding_mpg123) {
switch (encoding_mpg123) {
case MPG123_ENC_UNSIGNED_8:
return paUInt8;
case MPG123_ENC_SIGNED_8:
return paInt8;
case MPG123_ENC_SIGNED_16:
return paInt16;
case MPG123_ENC_SIGNED_24:
return paInt24;
case MPG123_ENC_SIGNED_32:
return paInt32;
case MPG123_ENC_FLOAT:
case MPG123_ENC_FLOAT_32:
return paFloat32;
default:
throw std::runtime_error("Invalid encoding value.");
}
}

View file

@ -1,6 +0,0 @@
#pragma once
#include <cstdint>
std::uint32_t encoding_mpg123_to_PulseAudio(int encoding_mpg123);

View file

View file

@ -0,0 +1,19 @@
import numpy
def sample_width_to_type(sample_width: int):
"""
Return the numpy type to use depending on the sample width used in an audio sample
:param sample_width: the sample width
:return: the corresponding numpy type
"""
match sample_width:
case 1:
return numpy.int8
case 2:
return numpy.int16
case 4:
return numpy.int32
case _:
return numpy.int16

View file

@ -0,0 +1,3 @@
from . import rsa
from . import aes
from . import type

View file

@ -0,0 +1,99 @@
import os
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import padding
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
def aes_ecb_encrypt(data: bytes, key: bytes) -> bytes:
"""
Encrypt the message using AES in ECB mode.
:param data: the data to cipher
:param key: the key to use for the cipher
:return: the encrypted data
"""
# pad the data with PKCS7 for AES to work properly
padder = padding.PKCS7(128).padder()
padded_data = padder.update(data) + padder.finalize()
# create the AES cipher in ECB mode
cipher = Cipher(algorithms.AES(key), modes.ECB(), backend=default_backend())
encryptor = cipher.encryptor()
# encrypt the padded data
encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
return encrypted_data
def aes_ecb_decrypt(encrypted_data: bytes, key: bytes) -> bytes:
"""
Decrypt data encrypted with AES in CBC mode.
:param encrypted_data: the encrypted data
:param key: the key used to encrypt it
:return: the decrypted data
"""
# create the AES cipher in ECB mode
cipher = Cipher(algorithms.AES(key), modes.ECB(), backend=default_backend())
decryptor = cipher.decryptor()
# decrypt the encrypted data
decrypted_data = decryptor.update(encrypted_data) + decryptor.finalize()
# unpad the data
unpadder = padding.PKCS7(128).unpadder()
data = unpadder.update(decrypted_data) + unpadder.finalize()
return data
def aes_cbc_encrypt(data: bytes, key: bytes) -> bytes:
"""
Encrypt the message using AES in CBC mode.
:param data: the data to cipher
:param key: the key to use for the cipher
:return: the encrypted data
"""
# pad the data with PKCS7 for AES to work properly
padder = padding.PKCS7(128).padder()
padded_data = padder.update(data) + padder.finalize()
# create an initialisation vector
iv = os.urandom(16)
# create the AES cipher in CBC mode
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
encryptor = cipher.encryptor()
# encrypt the padded data
encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
# prepend the iv to the encrypted data
return iv + encrypted_data
def aes_cbc_decrypt(payload: bytes, key: bytes) -> bytes:
"""
Decrypt data encrypted with AES in CBC mode.
:param payload: the encrypted data
:param key: the key used to encrypt it
:return: the decrypted data
"""
# split the payload into the iv and the encrypted data
iv = payload[:16]
encrypted_data = payload[16:]
# create the AES cipher in CBC mode
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
decryptor = cipher.decryptor()
# decrypt the encrypted data
decrypted_data = decryptor.update(encrypted_data) + decryptor.finalize()
# unpad the data
unpadder = padding.PKCS7(128).unpadder()
data = unpadder.update(decrypted_data) + unpadder.finalize()
return data

View file

@ -0,0 +1,80 @@
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import rsa, padding
def rsa_create_key_pair() -> tuple[bytes, bytes]:
"""
Create a pair of private and public RSA key.
:return: a pair of private and public RSA key.
"""
# create a private key
private_key = rsa.generate_private_key(
public_exponent=65537,
key_size=2048
)
# serialize the private key
private_key_data = private_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption()
)
# get the public key from the private key
public_key = private_key.public_key()
# serialize the public key
public_key_data = public_key.public_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PublicFormat.PKCS1
)
return private_key_data, public_key_data
def rsa_encrypt(data: bytes, public_key_data: bytes) -> bytes:
"""
Encrypt data with RSA using a public key
:param data: the data to encrypt
:param public_key_data: the public key to encrypt with
:return: the encrypted data
"""
# load the public key
public_key = serialization.load_der_public_key(public_key_data)
# verify if the key is loaded
if not isinstance(public_key, rsa.RSAPublicKey):
raise ValueError("Could not load the public key.")
# encrypt the data with the key
return public_key.encrypt(
data,
padding.OAEP(
mgf=padding.MGF1(algorithm=hashes.SHA256()),
algorithm=hashes.SHA256(),
label=None
)
)
def rsa_decrypt(encrypted_data: bytes, private_key_data: bytes) -> bytes:
"""
Decrypt the data with the RSA private key
:param encrypted_data: the data to decrypt
:param private_key_data: the private key data
:return: the decrypted data
"""
# load the private key
private_key = serialization.load_der_private_key(private_key_data, None)
# verify if the key is loaded
if not isinstance(private_key, rsa.RSAPrivateKey):
raise ValueError("Could not load the private key.")
# decrypt the data
return private_key.decrypt(
encrypted_data,
padding.OAEP(
mgf=padding.MGF1(algorithm=hashes.SHA256()),
algorithm=hashes.SHA256(),
label=None
)
)

View file

@ -0,0 +1,18 @@
import enum
import typing
class CipherType(enum.Enum):
PLAIN = 0x00
AES_ECB = 0x01 # legacy
AES_CBC = 0x02
RSA = 0x10
CIPHER_SYMMETRIC_TYPES: typing.Final[list[CipherType]] = [
CipherType.AES_ECB,
CipherType.AES_CBC,
]
CIPHER_ASYMMETRIC_TYPES: typing.Final[list[CipherType]] = [
CipherType.RSA,
]

View file

@ -0,0 +1,42 @@
import collections
from datetime import datetime
class TimestampedDict(collections.UserDict):
"""
A dictionary with additional metadata
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) # NOQA
self._last_modified = datetime.now() # last time a value got modified
self._last_added = datetime.now() # last time a new value have been added
def __setitem__(self, key, value):
# if the key is already used, we only update a value
update = key in self
# set the value
super().__setitem__(key, value)
# update modification time
self._last_modified = datetime.now()
# if this is not an update, set the added time
if not update:
self._last_added = datetime.now()
def __delitem__(self, key):
super().__delitem__(key)
self._last_modified = datetime.now()
def update(self, *args, **kwargs):
super().update(*args, **kwargs) # NOQA
self._last_modified = datetime.now()
@property
def last_modified(self):
return self._last_modified
@property
def last_added(self):
return self._last_added

View file

@ -0,0 +1 @@
from .TimestampedDict import TimestampedDict