M2-PT-DRP/source/managers/CommunicationManager.py

167 lines
6 KiB
Python

import socket
import typing
import zlib
import bidict
from source import packets, utils
from source.managers import Manager
from source.utils.crypto.rsa import rsa_create_key_pair
from source.utils.crypto.type import CipherType
class CommunicationManager:
"""
Manage everything about communication
"""
def __init__(self, manager: "Manager", interface: str, broadcast_address: str = "ff02::1", port: int = 5555):
self.manager = manager
self.broadcast_address = broadcast_address
self.port = port
# create an IPv6 UDP socket
self.socket = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
# enable broadcast messages
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, True)
# use multicast on the selected interface
self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, socket.if_nametoindex(interface))
# bind to listen for any message on this port
self.socket.bind(("::", self.port))
# create a dictionary to hold the types of packets and their headers.
self.packet_types: bidict.bidict[bytes, typing.Type[packets.base.BasePacket]] = bidict.bidict()
# create a private and public key for RSA communication
self.private_key, self.public_key = rsa_create_key_pair()
# TODO(Faraphel): should be decided by the server when changing role, stored somewhere else
self.secret_key: bytes = b"secret key!"
def __del__(self):
# close the socket
self.socket.close()
def register_packet_type(self, header: bytes, packet_type: typing.Type[packets.base.BasePacket]) -> None:
"""
Register a new kind of packet that can be sent or received.
:param header: the binary header identifying the packet
:param packet_type: the class of the packet
"""
if len(header) != 4:
raise Exception("The header should be exactly 4 bytes long.")
self.packet_types[header] = packet_type
def packet_encode(self, packet: packets.base.BasePacket, cipher_type: CipherType) -> bytes:
"""
Encode a packet for diffusion
:param packet: a packet to encode to be sent
:param cipher_type: the type of cipher
:return: an encoded packet
"""
# get the header identifier of the type of this packet
header: typing.Optional[bytes] = self.packet_types.inverse.get(type(packet))
if header is None:
raise KeyError(f"Unrecognised packet type: {type(packet)}. Has it been registered ?")
# get the encoded packet data
data = packet.pack()
# calculate its checksum using CRC32
checksum = zlib.crc32(data).to_bytes(4, byteorder='big')
# get the packet data
packet_data = checksum + header + data
# encrypt the packet data depending on the cipher selected
match cipher_type:
case CipherType.PLAIN:
pass
case CipherType.RSA:
packet_data = utils.crypto.rsa.rsa_encrypt(packet_data, self.public_key)
case CipherType.AES_ECB:
packet_data = utils.crypto.aes.aes_ecb_encrypt(packet_data, self.secret_key)
case _:
raise ValueError(f"Unknown cipher: {cipher_type}")
# prepend the cipher type to the packet data
payload = cipher_type.value.to_bytes(length=2, byteorder="big") + packet_data
return payload
def packet_decode(self, payload: bytes) -> packets.base.BasePacket:
"""
Decode a payload into a packet
:param payload: the data of the packet
:return: the deserialized packet
"""
cipher_type: CipherType = CipherType(int.from_bytes(payload[0:2], byteorder="big"))
packet_data: bytes = payload[2:]
# decrypt the packet data depending on the cipher used
match cipher_type:
case CipherType.PLAIN:
pass
case CipherType.RSA:
packet_data = utils.crypto.rsa.rsa_decrypt(packet_data, self.private_key)
case CipherType.AES_ECB:
packet_data = utils.crypto.aes.aes_ecb_decrypt(packet_data, self.secret_key)
case _:
raise ValueError(f"Unknown cipher: {cipher_type}")
# split the header and data from the raw payload
checksum: int = int.from_bytes(packet_data[:4], "big")
header: bytes = packet_data[4:8]
data: bytes = packet_data[8:]
# verify the checksum for corruption
if zlib.crc32(data) != checksum:
raise ValueError("The checksum is invalid.")
# get the type of the packet from its header
packet_type: typing.Optional[typing.Type[packets.base.BasePacket]] = self.packet_types.get(header)
if header is None:
raise KeyError(f"Unrecognised packet header: {header}. Has it been registered ?")
# unpack the packet
return packet_type.unpack(data)
def send(self, packet: packets.base.BasePacket, cipher_type: CipherType, address: tuple):
self.socket.sendto(self.packet_encode(packet, cipher_type), address)
def broadcast(self, packet: packets.base.BasePacket, cipher_type: CipherType):
"""
Broadcast a message in the network
:param cipher_type: the type of cipher
:param packet: the message to broadcast
"""
# check that no asymmetric cipher mode is used
if cipher_type in utils.crypto.type.CIPHER_ASYMMETRIC_TYPES:
raise ValueError("Asymmetric cipher cannot be used in broadcast.")
# TODO(Faraphel): use a channel system (OR ESTABLISH ANOTHER PORT ???)
self.send(packet, cipher_type, (self.broadcast_address, self.port))
def receive(self) -> tuple[packets.base.BasePacket, tuple]:
"""
Receive a packet
:return: the packet content alongside the address of the sender
"""
# receive a message
payload, address = self.socket.recvfrom(65536)
# decode the payload
return self.packet_decode(payload), address