diff --git a/LICENSE.md b/LICENSE.md index c90865d..8911863 100644 --- a/LICENSE.md +++ b/LICENSE.md @@ -32,4 +32,4 @@ fork : projet se basant sur le code source d'un logiciel déjà existant. --- -Copyright © 2024 - Raphaël CARON \ No newline at end of file +Copyright © 2025 - Raphaël CARON \ No newline at end of file diff --git a/source/behaviors/events/DiscoveryEvent.py b/source/behaviors/events/DiscoveryEvent.py index 03ad7e6..bf1a022 100644 --- a/source/behaviors/events/DiscoveryEvent.py +++ b/source/behaviors/events/DiscoveryEvent.py @@ -1,7 +1,16 @@ from . import base from source import packets +from ...packets import PeerPacket +from ...utils.crypto.type import CipherType class DiscoveryEvent(base.BaseEvent): + """ + Event reacting to a machine trying to discover the network. + """ + def handle(self, packet: packets.DiscoveryPacket, address: tuple): - print("discovery event !") + # create a peer packet containing our important information + peerPacket = PeerPacket(self.manager.communication.public_key) + # send our information back + self.manager.communication.send(peerPacket, CipherType.PLAIN, address) diff --git a/source/behaviors/events/PeerEvent.py b/source/behaviors/events/PeerEvent.py new file mode 100644 index 0000000..7587962 --- /dev/null +++ b/source/behaviors/events/PeerEvent.py @@ -0,0 +1,15 @@ +from . import base +from source import packets + + +class PeerEvent(base.BaseEvent): + """ + Event reacting to receiving information about another machine + """ + + def handle(self, packet: packets.PeerPacket, address: tuple): + # check if the peer is new + if address not in self.manager.peers: + # add the peer to the peers list + self.manager.peers[address] = packet + print("new peer discovered !") diff --git a/source/behaviors/events/__init__.py b/source/behaviors/events/__init__.py index 60163d0..a3276f6 100644 --- a/source/behaviors/events/__init__.py +++ b/source/behaviors/events/__init__.py @@ -1,3 +1,4 @@ from . import base from .DiscoveryEvent import DiscoveryEvent +from .PeerEvent import PeerEvent diff --git a/source/behaviors/events/base/BaseEvent.py b/source/behaviors/events/base/BaseEvent.py index 21ed149..6261377 100644 --- a/source/behaviors/events/base/BaseEvent.py +++ b/source/behaviors/events/base/BaseEvent.py @@ -1,9 +1,12 @@ import abc -from source import packets +from source import packets, managers class BaseEvent(abc.ABC): + def __init__(self, manager: "managers.Manager"): + self.manager = manager + @abc.abstractmethod def handle(self, packet: packets.base.BasePacket, address: tuple) -> None: """ diff --git a/source/behaviors/roles/UndefinedRole.py b/source/behaviors/roles/UndefinedRole.py index 3e374f9..d711eb7 100644 --- a/source/behaviors/roles/UndefinedRole.py +++ b/source/behaviors/roles/UndefinedRole.py @@ -1,9 +1,16 @@ +import time + from . import base -from source import managers, packets +from source import packets +from ...utils.crypto.type import CipherType class UndefinedRole(base.BaseRole): - def handle(self, manager: "managers.Manager"): + def handle(self): + # discover new peers packet = packets.DiscoveryPacket() - manager.communication.broadcast(packet) + self.manager.communication.broadcast(packet, CipherType.PLAIN) + + # wait + time.sleep(1) \ No newline at end of file diff --git a/source/behaviors/roles/base/BaseRole.py b/source/behaviors/roles/base/BaseRole.py index d1c732f..9b8bc7f 100644 --- a/source/behaviors/roles/base/BaseRole.py +++ b/source/behaviors/roles/base/BaseRole.py @@ -4,8 +4,11 @@ from source import managers class BaseRole(abc.ABC): + def __init__(self, manager: "managers.Manager"): + self.manager = manager + @abc.abstractmethod - def handle(self, manager: "managers.Manager") -> None: + def handle(self) -> None: """ Behavior of the role """ diff --git a/source/managers/CommunicationManager.py b/source/managers/CommunicationManager.py index 2a54281..d2aa527 100644 --- a/source/managers/CommunicationManager.py +++ b/source/managers/CommunicationManager.py @@ -4,8 +4,9 @@ import zlib import bidict -from source import packets +from source import packets, utils from source.managers import Manager +from source.utils.crypto.rsa import rsa_create_key_pair from source.utils.crypto.type import CipherType @@ -32,7 +33,10 @@ class CommunicationManager: # create a dictionary to hold the types of packets and their headers. self.packet_types: bidict.bidict[bytes, typing.Type[packets.base.BasePacket]] = bidict.bidict() - # the secret key used for AES communication + # create a private and public key for RSA communication + self.private_key, self.public_key = rsa_create_key_pair() + + # TODO(Faraphel): should be decided by the server when changing role, stored somewhere else self.secret_key: bytes = b"secret key!" def __del__(self): @@ -70,7 +74,28 @@ class CommunicationManager: # calculate its checksum using CRC32 checksum = zlib.crc32(data).to_bytes(4, byteorder='big') - return checksum + header + data + # get the packet data + packet_data = checksum + header + data + + # encrypt the packet data depending on the cipher selected + match cipher_type: + case CipherType.PLAIN: + pass + + case CipherType.RSA: + packet_data = utils.crypto.rsa.rsa_encrypt(packet_data, self.public_key) + + case CipherType.AES_ECB: + packet_data = utils.crypto.aes.aes_ecb_encrypt(packet_data, self.secret_key) + + case _: + raise ValueError(f"Unknown cipher: {cipher_type}") + + # prepend the cipher type to the packet data + payload = cipher_type.value.to_bytes(length=2, byteorder="big") + packet_data + + return payload + def packet_decode(self, payload: bytes) -> packets.base.BasePacket: """ @@ -79,10 +104,27 @@ class CommunicationManager: :return: the deserialized packet """ + cipher_type: CipherType = CipherType(int.from_bytes(payload[0:2], byteorder="big")) + packet_data: bytes = payload[2:] + + # decrypt the packet data depending on the cipher used + match cipher_type: + case CipherType.PLAIN: + pass + + case CipherType.RSA: + packet_data = utils.crypto.rsa.rsa_decrypt(packet_data, self.private_key) + + case CipherType.AES_ECB: + packet_data = utils.crypto.aes.aes_ecb_decrypt(packet_data, self.secret_key) + + case _: + raise ValueError(f"Unknown cipher: {cipher_type}") + # split the header and data from the raw payload - checksum: int = int.from_bytes(payload[:4], "big") - header: bytes = payload[4:8] - data: bytes = payload[8:] + checksum: int = int.from_bytes(packet_data[:4], "big") + header: bytes = packet_data[4:8] + data: bytes = packet_data[8:] # verify the checksum for corruption if zlib.crc32(data) != checksum: @@ -96,16 +138,22 @@ class CommunicationManager: # unpack the packet return packet_type.unpack(data) - def broadcast(self, packet: packets.base.BasePacket, cipher_type: CipherType = None): + def send(self, packet: packets.base.BasePacket, cipher_type: CipherType, address: tuple): + self.socket.sendto(self.packet_encode(packet, cipher_type), address) + + def broadcast(self, packet: packets.base.BasePacket, cipher_type: CipherType): """ Broadcast a message in the network :param cipher_type: the type of cipher :param packet: the message to broadcast """ - # TODO(Faraphel): should encrypt the data if required, prepend encryption mode + # check that no asymmetric cipher mode is used + if cipher_type in utils.crypto.type.CIPHER_ASYMMETRIC_TYPES: + raise ValueError("Asymmetric cipher cannot be used in broadcast.") + # TODO(Faraphel): use a channel system (OR ESTABLISH ANOTHER PORT ???) - self.socket.sendto(self.packet_encode(packet, cipher_type), (self.broadcast_address, self.port)) + self.send(packet, cipher_type, (self.broadcast_address, self.port)) def receive(self) -> tuple[packets.base.BasePacket, tuple]: """ @@ -117,5 +165,3 @@ class CommunicationManager: payload, address = self.socket.recvfrom(65536) # decode the payload return self.packet_decode(payload), address - - # TODO(Faraphel): should decrypt the data diff --git a/source/managers/EventManager.py b/source/managers/EventManager.py index ffe915a..ae97316 100644 --- a/source/managers/EventManager.py +++ b/source/managers/EventManager.py @@ -7,6 +7,11 @@ from source.managers import Manager class EventManager: + """ + Event Manager + Responsible for receiving packets from other peers and handling them. + """ + def __init__(self, manager: "Manager"): self.manager = manager @@ -37,7 +42,11 @@ class EventManager: # use the event handler on the packet event_handler.handle(packet, address) - def loop(self): + def loop(self) -> None: + """ + Handle events forever + """ + while True: try: # wait for a new packet diff --git a/source/managers/Manager.py b/source/managers/Manager.py index 4f9603a..1362b8f 100644 --- a/source/managers/Manager.py +++ b/source/managers/Manager.py @@ -5,21 +5,34 @@ from source.behaviors import events class Manager: + """ + Global manager + """ + def __init__(self, interface: str): from . import CommunicationManager, EventManager, RoleManager # communication manager self.communication = CommunicationManager(self, interface) self.communication.register_packet_type(b"DISC", packets.DiscoveryPacket) + self.communication.register_packet_type(b"PEER", packets.PeerPacket) # event manager self.event = EventManager(self) - self.event.register_event_handler(packets.DiscoveryPacket, events.DiscoveryEvent()) + self.event.register_event_handler(packets.DiscoveryPacket, events.DiscoveryEvent(self)) + self.event.register_event_handler(packets.PeerPacket, events.PeerEvent(self)) # role manager self.role = RoleManager(self) - def loop(self): + # set of addresses associated to their peer + self.peers: dict[tuple, packets.PeerPacket] = {} + + def loop(self) -> None: + """ + Handle the event and role managers forever + """ + # run a thread for the event and the role manager event_thread = threading.Thread(target=self.event.loop) role_thread = threading.Thread(target=self.role.loop) @@ -28,4 +41,4 @@ class Manager: role_thread.start() event_thread.join() - role_thread.join() \ No newline at end of file + role_thread.join() diff --git a/source/managers/RoleManager.py b/source/managers/RoleManager.py index 2d93452..921e5ec 100644 --- a/source/managers/RoleManager.py +++ b/source/managers/RoleManager.py @@ -3,18 +3,22 @@ from source.managers import Manager class RoleManager: + """ + Role Manager + Responsible for the passive behavior of the machine and sending packets + """ def __init__(self, manager: "Manager"): self.manager = manager # the currently used role - self.current: roles.base.BaseRole = roles.UndefinedRole() + self.current: roles.base.BaseRole = roles.UndefinedRole(self.manager) def handle(self) -> None: """ Run the role """ - self.current.handle(self.manager) + self.current.handle() def loop(self) -> None: """ diff --git a/source/packets/AudioPacket.py b/source/packets/AudioPacket.py index b12f94e..ce7c9af 100644 --- a/source/packets/AudioPacket.py +++ b/source/packets/AudioPacket.py @@ -1,21 +1,21 @@ +import dataclasses + import msgpack from source.packets import base +@dataclasses.dataclass class AudioPacket(base.BasePacket): """ Represent a packet of audio data """ - def __init__(self, data: bytes, rate: int, channels: int, encoding: int): - super().__init__() + data: bytes = dataclasses.field() - self.data = data - - self.rate = rate - self.channels = channels - self.encoding = encoding + rate: int = dataclasses.field() + channels: int = dataclasses.field() + encoding: int = dataclasses.field() def pack(self) -> bytes: return msgpack.packb(( diff --git a/source/packets/DiscoveryPacket.py b/source/packets/DiscoveryPacket.py index 16bf3f9..e48a749 100644 --- a/source/packets/DiscoveryPacket.py +++ b/source/packets/DiscoveryPacket.py @@ -1,19 +1,16 @@ +import dataclasses + import msgpack from source.packets import base +@dataclasses.dataclass class DiscoveryPacket(base.BasePacket): """ Represent a packet used to discover new devices in the network. """ - def __init__(self): - super().__init__() - - def __repr__(self) -> str: - return f"<{self.__class__.__name__}>" - def pack(self) -> bytes: return msgpack.packb(()) diff --git a/source/packets/PeerPacket.py b/source/packets/PeerPacket.py new file mode 100644 index 0000000..51a04d9 --- /dev/null +++ b/source/packets/PeerPacket.py @@ -0,0 +1,24 @@ +import dataclasses + +import msgpack + +from . import base + + +@dataclasses.dataclass +class PeerPacket(base.BasePacket): + """ + Represent a packet used to send information about a peer + """ + + # public RSA key of the machine + public_key: bytes = dataclasses.field(repr=False) + + def pack(self) -> bytes: + return msgpack.packb(( + self.public_key, + )) + + @classmethod + def unpack(cls, data: bytes): + return cls(*msgpack.unpackb(data)) diff --git a/source/packets/__init__.py b/source/packets/__init__.py index 2ce4785..c047519 100644 --- a/source/packets/__init__.py +++ b/source/packets/__init__.py @@ -2,3 +2,4 @@ from . import base from .AudioPacket import AudioPacket from .DiscoveryPacket import DiscoveryPacket +from .PeerPacket import PeerPacket diff --git a/source/utils/__init__.py b/source/utils/__init__.py index e69de29..a2795d2 100644 --- a/source/utils/__init__.py +++ b/source/utils/__init__.py @@ -0,0 +1 @@ +from . import crypto diff --git a/source/utils/crypto/__init__.py b/source/utils/crypto/__init__.py index e69de29..c1998b8 100644 --- a/source/utils/crypto/__init__.py +++ b/source/utils/crypto/__init__.py @@ -0,0 +1,3 @@ +from . import rsa +from . import aes +from . import type diff --git a/source/utils/crypto/rsa.py b/source/utils/crypto/rsa.py index 77a4cd5..e49a995 100644 --- a/source/utils/crypto/rsa.py +++ b/source/utils/crypto/rsa.py @@ -2,6 +2,35 @@ from cryptography.hazmat.primitives import serialization, hashes from cryptography.hazmat.primitives.asymmetric import rsa, padding +def rsa_create_key_pair() -> tuple[bytes, bytes]: + """ + Create a pair of private and public RSA key. + :return: a pair of private and public RSA key. + """ + + # create a private key + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048 + ) + # serialize the private key + private_key_data = private_key.private_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption() + ) + + # get the public key from the private key + public_key = private_key.public_key() + # serialize the public key + public_key_data = public_key.public_bytes( + encoding=serialization.Encoding.DER, + format=serialization.PublicFormat.PKCS1 + ) + + return private_key_data, public_key_data + + def rsa_encrypt(data: bytes, public_key_data: bytes) -> bytes: """ Encrypt data with RSA using a public key diff --git a/source/utils/crypto/type.py b/source/utils/crypto/type.py index 1e8fbd7..b9aa530 100644 --- a/source/utils/crypto/type.py +++ b/source/utils/crypto/type.py @@ -1,7 +1,17 @@ import enum +import typing class CipherType(enum.Enum): - NONE = 0x00 + PLAIN = 0x00 AES_ECB = 0x01 RSA = 0x02 + + +CIPHER_SYMMETRIC_TYPES: typing.Final[list[CipherType]] = [ + CipherType.PLAIN, + CipherType.AES_ECB +] +CIPHER_ASYMMETRIC_TYPES: typing.Final[list[CipherType]] = [ + CipherType.RSA +] \ No newline at end of file diff --git a/source/utils/crypto/universal.py b/source/utils/crypto/universal.py deleted file mode 100644 index fb6b7b1..0000000 --- a/source/utils/crypto/universal.py +++ /dev/null @@ -1,57 +0,0 @@ -import typing - -from source.utils.crypto import aes, rsa -from source.utils.crypto.type import CipherType - - -def encrypt(data: bytes, key: typing.Optional[bytes] = None, cipher_type: CipherType = CipherType.NONE) -> bytes: - """ - Encrypt data on various cipher type. - :param data: the data to cipher - :param key: the key to cipher the data - :param cipher_type: the type of cipher to use - :return: - """ - - match cipher_type: - case CipherType.NONE: - return data - - case CipherType.AES_ECB: - if key is None: - raise ValueError("The key cannot be None.") - return aes.aes_ecb_encrypt(data, key) - - case CipherType.RSA: - if key is None: - raise ValueError("The key cannot be None.") - return rsa.rsa_encrypt(data, key) - - case _: - raise KeyError("Unknown cipher mode.") - -def decrypt(data: bytes, key: typing.Optional[bytes] = None, cipher_type: CipherType = CipherType.NONE) -> bytes: - """ - Encrypt data on various cipher type. - :param data: the data to decipher - :param key: the key to cipher the data - :param cipher_type: the type of cipher to use - :return: - """ - - match cipher_type: - case CipherType.NONE: - return data - - case CipherType.AES_ECB: - if key is None: - raise ValueError("The key cannot be None.") - return aes.aes_ecb_decrypt(data, key) - - case CipherType.RSA: - if key is None: - raise ValueError("The key cannot be None.") - return rsa.rsa_decrypt(data, key) - - case _: - raise KeyError("Unknown cipher mode.")