import socket import typing import zlib from datetime import datetime import bidict from source import packets, utils, structures from source.managers import Manager from source.utils.crypto.rsa import rsa_create_key_pair from source.utils.crypto.type import CipherType class CommunicationManager: """ Manage the communication between the peers """ def __init__(self, manager: "Manager", interface: str, broadcast_address: str = "ff02::1", port: int = 5555): self.manager = manager self.broadcast_address = broadcast_address self.port = port # create an IPv6 UDP socket self.socket = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) # enable broadcast messages self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, True) # use multicast on the selected interface self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, socket.if_nametoindex(interface)) # bind to listen for any message on this port self.socket.bind(("::", self.port)) # create a dictionary to hold the types of packets and their headers. self.packet_types: bidict.bidict[bytes, typing.Type[packets.base.BasePacket]] = bidict.bidict() # create a private and public key for RSA communication self.private_key, self.public_key = rsa_create_key_pair() # TODO(Faraphel): should be decided by the server when changing role, stored somewhere else self.secret_key: typing.Optional[bytes] = None def __del__(self): # close the socket self.socket.close() def register_packet_type(self, header: bytes, packet_type: typing.Type[packets.base.BasePacket]) -> None: """ Register a new kind of packet that can be sent or received. :param header: the binary header identifying the packet :param packet_type: the class of the packet """ if len(header) != 4: raise Exception("The header should be exactly 4 bytes long.") self.packet_types[header] = packet_type def packet_encode(self, packet: packets.base.BasePacket, cipher_type: CipherType) -> bytes: """ Encode a packet for diffusion :param packet: a packet to encode to be sent :param cipher_type: the type of cipher :return: an encoded packet """ # get the header identifier of the type of this packet header: typing.Optional[bytes] = self.packet_types.inverse.get(type(packet)) if header is None: raise KeyError(f"Unrecognised packet type: {type(packet)}. Has it been registered ?") # get the encoded packet data data = packet.pack() # calculate its checksum using CRC32 checksum = zlib.crc32(data).to_bytes(4, byteorder='big') # get the packet data packet_data = checksum + header + data # encrypt the packet data depending on the cipher selected match cipher_type: case CipherType.PLAIN: pass case CipherType.AES_ECB: packet_data = utils.crypto.aes.aes_ecb_encrypt(packet_data, self.secret_key) case CipherType.AES_CBC: packet_data = utils.crypto.aes.aes_cbc_encrypt(packet_data, self.secret_key) case CipherType.RSA: packet_data = utils.crypto.rsa.rsa_encrypt(packet_data, self.public_key) case _: raise ValueError(f"Unknown cipher: {cipher_type}") # prepend the cipher type to the packet data payload = cipher_type.value.to_bytes(length=2, byteorder="big") + packet_data return payload def packet_decode(self, payload: bytes) -> packets.base.BasePacket: """ Decode a payload into a packet :param payload: the data of the packet :return: the deserialized packet """ cipher_type: CipherType = CipherType(int.from_bytes(payload[0:2], byteorder="big")) packet_data: bytes = payload[2:] # decrypt the packet data depending on the cipher used match cipher_type: case CipherType.PLAIN: pass case CipherType.AES_ECB: packet_data = utils.crypto.aes.aes_ecb_decrypt(packet_data, self.secret_key) case CipherType.AES_CBC: packet_data = utils.crypto.aes.aes_cbc_decrypt(packet_data, self.secret_key) case CipherType.RSA: packet_data = utils.crypto.rsa.rsa_decrypt(packet_data, self.private_key) case _: raise ValueError(f"Unknown cipher: {cipher_type}") # split the header and data from the raw payload checksum: int = int.from_bytes(packet_data[:4], "big") header: bytes = packet_data[4:8] data: bytes = packet_data[8:] # verify the checksum for corruption if zlib.crc32(data) != checksum: raise ValueError("The checksum is invalid.") # get the type of the packet from its header packet_type: typing.Optional[typing.Type[packets.base.BasePacket]] = self.packet_types.get(header) if header is None: raise KeyError(f"Unrecognised packet header: {header}. Has it been registered ?") # unpack the packet return packet_type.unpack(data) def send(self, packet: packets.base.BasePacket, cipher_type: CipherType, address: tuple): self.socket.sendto(self.packet_encode(packet, cipher_type), address) def broadcast(self, packet: packets.base.BasePacket, cipher_type: CipherType): """ Broadcast a message in the network :param cipher_type: the type of cipher :param packet: the message to broadcast """ # check that no asymmetric cipher mode is used if cipher_type in utils.crypto.type.CIPHER_ASYMMETRIC_TYPES: raise ValueError("Asymmetric cipher cannot be used in broadcast.") # TODO(Faraphel): use a channel system (OR ESTABLISH ANOTHER PORT ???) self.send(packet, cipher_type, (self.broadcast_address, self.port)) def receive(self) -> tuple[packets.base.BasePacket, tuple]: """ Receive a packet :return: the packet content alongside the address of the sender """ # receive a message payload, address = self.socket.recvfrom(65536) # check if there is a peer associated with this address peer: structures.Peer = self.manager.peer.peers.get(address) if peer is not None: # update the latest interaction date peer.last_interaction = datetime.now() # decode the payload return self.packet_decode(payload), address @staticmethod def get_local_addresses() -> list[tuple]: """ Get the local addresses of the machine :return: the local addresses of the machine """ return socket.getaddrinfo(socket.gethostname(), None) def is_address_local(self, address: tuple) -> bool: """ Is the given address local :return: true if the address is local, false otherwise """ host, _, _, scope = address # check for all the interfaces of our machine for interface in self.get_local_addresses(): # unpack the interface information interface_family, _, _, _, interface_address = interface interface_host, _, _, interface_scope = interface_address # check if it matches the address interface if host == interface_host and scope == interface_scope: return True # no matching interfaces have been found return False