fully added encryption support

This commit is contained in:
study-faraphel 2025-01-04 00:34:00 +01:00
parent 2286375bae
commit 8038b8e40c
20 changed files with 213 additions and 95 deletions

View file

@ -32,4 +32,4 @@ fork : projet se basant sur le code source d'un logiciel déjà existant.
--- ---
Copyright © 2024 - Raphaël CARON Copyright © 2025 - Raphaël CARON

View file

@ -1,7 +1,16 @@
from . import base from . import base
from source import packets from source import packets
from ...packets import PeerPacket
from ...utils.crypto.type import CipherType
class DiscoveryEvent(base.BaseEvent): class DiscoveryEvent(base.BaseEvent):
"""
Event reacting to a machine trying to discover the network.
"""
def handle(self, packet: packets.DiscoveryPacket, address: tuple): def handle(self, packet: packets.DiscoveryPacket, address: tuple):
print("discovery event !") # create a peer packet containing our important information
peerPacket = PeerPacket(self.manager.communication.public_key)
# send our information back
self.manager.communication.send(peerPacket, CipherType.PLAIN, address)

View file

@ -0,0 +1,15 @@
from . import base
from source import packets
class PeerEvent(base.BaseEvent):
"""
Event reacting to receiving information about another machine
"""
def handle(self, packet: packets.PeerPacket, address: tuple):
# check if the peer is new
if address not in self.manager.peers:
# add the peer to the peers list
self.manager.peers[address] = packet
print("new peer discovered !")

View file

@ -1,3 +1,4 @@
from . import base from . import base
from .DiscoveryEvent import DiscoveryEvent from .DiscoveryEvent import DiscoveryEvent
from .PeerEvent import PeerEvent

View file

@ -1,9 +1,12 @@
import abc import abc
from source import packets from source import packets, managers
class BaseEvent(abc.ABC): class BaseEvent(abc.ABC):
def __init__(self, manager: "managers.Manager"):
self.manager = manager
@abc.abstractmethod @abc.abstractmethod
def handle(self, packet: packets.base.BasePacket, address: tuple) -> None: def handle(self, packet: packets.base.BasePacket, address: tuple) -> None:
""" """

View file

@ -1,9 +1,16 @@
import time
from . import base from . import base
from source import managers, packets from source import packets
from ...utils.crypto.type import CipherType
class UndefinedRole(base.BaseRole): class UndefinedRole(base.BaseRole):
def handle(self, manager: "managers.Manager"): def handle(self):
# discover new peers
packet = packets.DiscoveryPacket() packet = packets.DiscoveryPacket()
manager.communication.broadcast(packet) self.manager.communication.broadcast(packet, CipherType.PLAIN)
# wait
time.sleep(1)

View file

@ -4,8 +4,11 @@ from source import managers
class BaseRole(abc.ABC): class BaseRole(abc.ABC):
def __init__(self, manager: "managers.Manager"):
self.manager = manager
@abc.abstractmethod @abc.abstractmethod
def handle(self, manager: "managers.Manager") -> None: def handle(self) -> None:
""" """
Behavior of the role Behavior of the role
""" """

View file

@ -4,8 +4,9 @@ import zlib
import bidict import bidict
from source import packets from source import packets, utils
from source.managers import Manager from source.managers import Manager
from source.utils.crypto.rsa import rsa_create_key_pair
from source.utils.crypto.type import CipherType from source.utils.crypto.type import CipherType
@ -32,7 +33,10 @@ class CommunicationManager:
# create a dictionary to hold the types of packets and their headers. # create a dictionary to hold the types of packets and their headers.
self.packet_types: bidict.bidict[bytes, typing.Type[packets.base.BasePacket]] = bidict.bidict() self.packet_types: bidict.bidict[bytes, typing.Type[packets.base.BasePacket]] = bidict.bidict()
# the secret key used for AES communication # create a private and public key for RSA communication
self.private_key, self.public_key = rsa_create_key_pair()
# TODO(Faraphel): should be decided by the server when changing role, stored somewhere else
self.secret_key: bytes = b"secret key!" self.secret_key: bytes = b"secret key!"
def __del__(self): def __del__(self):
@ -70,7 +74,28 @@ class CommunicationManager:
# calculate its checksum using CRC32 # calculate its checksum using CRC32
checksum = zlib.crc32(data).to_bytes(4, byteorder='big') checksum = zlib.crc32(data).to_bytes(4, byteorder='big')
return checksum + header + data # get the packet data
packet_data = checksum + header + data
# encrypt the packet data depending on the cipher selected
match cipher_type:
case CipherType.PLAIN:
pass
case CipherType.RSA:
packet_data = utils.crypto.rsa.rsa_encrypt(packet_data, self.public_key)
case CipherType.AES_ECB:
packet_data = utils.crypto.aes.aes_ecb_encrypt(packet_data, self.secret_key)
case _:
raise ValueError(f"Unknown cipher: {cipher_type}")
# prepend the cipher type to the packet data
payload = cipher_type.value.to_bytes(length=2, byteorder="big") + packet_data
return payload
def packet_decode(self, payload: bytes) -> packets.base.BasePacket: def packet_decode(self, payload: bytes) -> packets.base.BasePacket:
""" """
@ -79,10 +104,27 @@ class CommunicationManager:
:return: the deserialized packet :return: the deserialized packet
""" """
cipher_type: CipherType = CipherType(int.from_bytes(payload[0:2], byteorder="big"))
packet_data: bytes = payload[2:]
# decrypt the packet data depending on the cipher used
match cipher_type:
case CipherType.PLAIN:
pass
case CipherType.RSA:
packet_data = utils.crypto.rsa.rsa_decrypt(packet_data, self.private_key)
case CipherType.AES_ECB:
packet_data = utils.crypto.aes.aes_ecb_decrypt(packet_data, self.secret_key)
case _:
raise ValueError(f"Unknown cipher: {cipher_type}")
# split the header and data from the raw payload # split the header and data from the raw payload
checksum: int = int.from_bytes(payload[:4], "big") checksum: int = int.from_bytes(packet_data[:4], "big")
header: bytes = payload[4:8] header: bytes = packet_data[4:8]
data: bytes = payload[8:] data: bytes = packet_data[8:]
# verify the checksum for corruption # verify the checksum for corruption
if zlib.crc32(data) != checksum: if zlib.crc32(data) != checksum:
@ -96,16 +138,22 @@ class CommunicationManager:
# unpack the packet # unpack the packet
return packet_type.unpack(data) return packet_type.unpack(data)
def broadcast(self, packet: packets.base.BasePacket, cipher_type: CipherType = None): def send(self, packet: packets.base.BasePacket, cipher_type: CipherType, address: tuple):
self.socket.sendto(self.packet_encode(packet, cipher_type), address)
def broadcast(self, packet: packets.base.BasePacket, cipher_type: CipherType):
""" """
Broadcast a message in the network Broadcast a message in the network
:param cipher_type: the type of cipher :param cipher_type: the type of cipher
:param packet: the message to broadcast :param packet: the message to broadcast
""" """
# TODO(Faraphel): should encrypt the data if required, prepend encryption mode # check that no asymmetric cipher mode is used
if cipher_type in utils.crypto.type.CIPHER_ASYMMETRIC_TYPES:
raise ValueError("Asymmetric cipher cannot be used in broadcast.")
# TODO(Faraphel): use a channel system (OR ESTABLISH ANOTHER PORT ???) # TODO(Faraphel): use a channel system (OR ESTABLISH ANOTHER PORT ???)
self.socket.sendto(self.packet_encode(packet, cipher_type), (self.broadcast_address, self.port)) self.send(packet, cipher_type, (self.broadcast_address, self.port))
def receive(self) -> tuple[packets.base.BasePacket, tuple]: def receive(self) -> tuple[packets.base.BasePacket, tuple]:
""" """
@ -117,5 +165,3 @@ class CommunicationManager:
payload, address = self.socket.recvfrom(65536) payload, address = self.socket.recvfrom(65536)
# decode the payload # decode the payload
return self.packet_decode(payload), address return self.packet_decode(payload), address
# TODO(Faraphel): should decrypt the data

View file

@ -7,6 +7,11 @@ from source.managers import Manager
class EventManager: class EventManager:
"""
Event Manager
Responsible for receiving packets from other peers and handling them.
"""
def __init__(self, manager: "Manager"): def __init__(self, manager: "Manager"):
self.manager = manager self.manager = manager
@ -37,7 +42,11 @@ class EventManager:
# use the event handler on the packet # use the event handler on the packet
event_handler.handle(packet, address) event_handler.handle(packet, address)
def loop(self): def loop(self) -> None:
"""
Handle events forever
"""
while True: while True:
try: try:
# wait for a new packet # wait for a new packet

View file

@ -5,21 +5,34 @@ from source.behaviors import events
class Manager: class Manager:
"""
Global manager
"""
def __init__(self, interface: str): def __init__(self, interface: str):
from . import CommunicationManager, EventManager, RoleManager from . import CommunicationManager, EventManager, RoleManager
# communication manager # communication manager
self.communication = CommunicationManager(self, interface) self.communication = CommunicationManager(self, interface)
self.communication.register_packet_type(b"DISC", packets.DiscoveryPacket) self.communication.register_packet_type(b"DISC", packets.DiscoveryPacket)
self.communication.register_packet_type(b"PEER", packets.PeerPacket)
# event manager # event manager
self.event = EventManager(self) self.event = EventManager(self)
self.event.register_event_handler(packets.DiscoveryPacket, events.DiscoveryEvent()) self.event.register_event_handler(packets.DiscoveryPacket, events.DiscoveryEvent(self))
self.event.register_event_handler(packets.PeerPacket, events.PeerEvent(self))
# role manager # role manager
self.role = RoleManager(self) self.role = RoleManager(self)
def loop(self): # set of addresses associated to their peer
self.peers: dict[tuple, packets.PeerPacket] = {}
def loop(self) -> None:
"""
Handle the event and role managers forever
"""
# run a thread for the event and the role manager # run a thread for the event and the role manager
event_thread = threading.Thread(target=self.event.loop) event_thread = threading.Thread(target=self.event.loop)
role_thread = threading.Thread(target=self.role.loop) role_thread = threading.Thread(target=self.role.loop)
@ -28,4 +41,4 @@ class Manager:
role_thread.start() role_thread.start()
event_thread.join() event_thread.join()
role_thread.join() role_thread.join()

View file

@ -3,18 +3,22 @@ from source.managers import Manager
class RoleManager: class RoleManager:
"""
Role Manager
Responsible for the passive behavior of the machine and sending packets
"""
def __init__(self, manager: "Manager"): def __init__(self, manager: "Manager"):
self.manager = manager self.manager = manager
# the currently used role # the currently used role
self.current: roles.base.BaseRole = roles.UndefinedRole() self.current: roles.base.BaseRole = roles.UndefinedRole(self.manager)
def handle(self) -> None: def handle(self) -> None:
""" """
Run the role Run the role
""" """
self.current.handle(self.manager) self.current.handle()
def loop(self) -> None: def loop(self) -> None:
""" """

View file

@ -1,21 +1,21 @@
import dataclasses
import msgpack import msgpack
from source.packets import base from source.packets import base
@dataclasses.dataclass
class AudioPacket(base.BasePacket): class AudioPacket(base.BasePacket):
""" """
Represent a packet of audio data Represent a packet of audio data
""" """
def __init__(self, data: bytes, rate: int, channels: int, encoding: int): data: bytes = dataclasses.field()
super().__init__()
self.data = data rate: int = dataclasses.field()
channels: int = dataclasses.field()
self.rate = rate encoding: int = dataclasses.field()
self.channels = channels
self.encoding = encoding
def pack(self) -> bytes: def pack(self) -> bytes:
return msgpack.packb(( return msgpack.packb((

View file

@ -1,19 +1,16 @@
import dataclasses
import msgpack import msgpack
from source.packets import base from source.packets import base
@dataclasses.dataclass
class DiscoveryPacket(base.BasePacket): class DiscoveryPacket(base.BasePacket):
""" """
Represent a packet used to discover new devices in the network. Represent a packet used to discover new devices in the network.
""" """
def __init__(self):
super().__init__()
def __repr__(self) -> str:
return f"<{self.__class__.__name__}>"
def pack(self) -> bytes: def pack(self) -> bytes:
return msgpack.packb(()) return msgpack.packb(())

View file

@ -0,0 +1,24 @@
import dataclasses
import msgpack
from . import base
@dataclasses.dataclass
class PeerPacket(base.BasePacket):
"""
Represent a packet used to send information about a peer
"""
# public RSA key of the machine
public_key: bytes = dataclasses.field(repr=False)
def pack(self) -> bytes:
return msgpack.packb((
self.public_key,
))
@classmethod
def unpack(cls, data: bytes):
return cls(*msgpack.unpackb(data))

View file

@ -2,3 +2,4 @@ from . import base
from .AudioPacket import AudioPacket from .AudioPacket import AudioPacket
from .DiscoveryPacket import DiscoveryPacket from .DiscoveryPacket import DiscoveryPacket
from .PeerPacket import PeerPacket

View file

@ -0,0 +1 @@
from . import crypto

View file

@ -0,0 +1,3 @@
from . import rsa
from . import aes
from . import type

View file

@ -2,6 +2,35 @@ from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import rsa, padding from cryptography.hazmat.primitives.asymmetric import rsa, padding
def rsa_create_key_pair() -> tuple[bytes, bytes]:
"""
Create a pair of private and public RSA key.
:return: a pair of private and public RSA key.
"""
# create a private key
private_key = rsa.generate_private_key(
public_exponent=65537,
key_size=2048
)
# serialize the private key
private_key_data = private_key.private_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption()
)
# get the public key from the private key
public_key = private_key.public_key()
# serialize the public key
public_key_data = public_key.public_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PublicFormat.PKCS1
)
return private_key_data, public_key_data
def rsa_encrypt(data: bytes, public_key_data: bytes) -> bytes: def rsa_encrypt(data: bytes, public_key_data: bytes) -> bytes:
""" """
Encrypt data with RSA using a public key Encrypt data with RSA using a public key

View file

@ -1,7 +1,17 @@
import enum import enum
import typing
class CipherType(enum.Enum): class CipherType(enum.Enum):
NONE = 0x00 PLAIN = 0x00
AES_ECB = 0x01 AES_ECB = 0x01
RSA = 0x02 RSA = 0x02
CIPHER_SYMMETRIC_TYPES: typing.Final[list[CipherType]] = [
CipherType.PLAIN,
CipherType.AES_ECB
]
CIPHER_ASYMMETRIC_TYPES: typing.Final[list[CipherType]] = [
CipherType.RSA
]

View file

@ -1,57 +0,0 @@
import typing
from source.utils.crypto import aes, rsa
from source.utils.crypto.type import CipherType
def encrypt(data: bytes, key: typing.Optional[bytes] = None, cipher_type: CipherType = CipherType.NONE) -> bytes:
"""
Encrypt data on various cipher type.
:param data: the data to cipher
:param key: the key to cipher the data
:param cipher_type: the type of cipher to use
:return:
"""
match cipher_type:
case CipherType.NONE:
return data
case CipherType.AES_ECB:
if key is None:
raise ValueError("The key cannot be None.")
return aes.aes_ecb_encrypt(data, key)
case CipherType.RSA:
if key is None:
raise ValueError("The key cannot be None.")
return rsa.rsa_encrypt(data, key)
case _:
raise KeyError("Unknown cipher mode.")
def decrypt(data: bytes, key: typing.Optional[bytes] = None, cipher_type: CipherType = CipherType.NONE) -> bytes:
"""
Encrypt data on various cipher type.
:param data: the data to decipher
:param key: the key to cipher the data
:param cipher_type: the type of cipher to use
:return:
"""
match cipher_type:
case CipherType.NONE:
return data
case CipherType.AES_ECB:
if key is None:
raise ValueError("The key cannot be None.")
return aes.aes_ecb_decrypt(data, key)
case CipherType.RSA:
if key is None:
raise ValueError("The key cannot be None.")
return rsa.rsa_decrypt(data, key)
case _:
raise KeyError("Unknown cipher mode.")