167 lines
6 KiB
Python
167 lines
6 KiB
Python
import socket
|
|
import typing
|
|
import zlib
|
|
|
|
import bidict
|
|
|
|
from source import packets, utils
|
|
from source.managers import Manager
|
|
from source.utils.crypto.rsa import rsa_create_key_pair
|
|
from source.utils.crypto.type import CipherType
|
|
|
|
|
|
class CommunicationManager:
|
|
"""
|
|
Manage everything about communication
|
|
"""
|
|
|
|
def __init__(self, manager: "Manager", interface: str, broadcast_address: str = "ff02::1", port: int = 5555):
|
|
self.manager = manager
|
|
|
|
self.broadcast_address = broadcast_address
|
|
self.port = port
|
|
|
|
# create an IPv6 UDP socket
|
|
self.socket = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
|
|
# enable broadcast messages
|
|
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, True)
|
|
# use multicast on the selected interface
|
|
self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, socket.if_nametoindex(interface))
|
|
# bind to listen for any message on this port
|
|
self.socket.bind(("::", self.port))
|
|
|
|
# create a dictionary to hold the types of packets and their headers.
|
|
self.packet_types: bidict.bidict[bytes, typing.Type[packets.base.BasePacket]] = bidict.bidict()
|
|
|
|
# create a private and public key for RSA communication
|
|
self.private_key, self.public_key = rsa_create_key_pair()
|
|
|
|
# TODO(Faraphel): should be decided by the server when changing role, stored somewhere else
|
|
self.secret_key: bytes = b"secret key!"
|
|
|
|
def __del__(self):
|
|
# close the socket
|
|
self.socket.close()
|
|
|
|
def register_packet_type(self, header: bytes, packet_type: typing.Type[packets.base.BasePacket]) -> None:
|
|
"""
|
|
Register a new kind of packet that can be sent or received.
|
|
:param header: the binary header identifying the packet
|
|
:param packet_type: the class of the packet
|
|
"""
|
|
|
|
if len(header) != 4:
|
|
raise Exception("The header should be exactly 4 bytes long.")
|
|
|
|
self.packet_types[header] = packet_type
|
|
|
|
def packet_encode(self, packet: packets.base.BasePacket, cipher_type: CipherType) -> bytes:
|
|
"""
|
|
Encode a packet for diffusion
|
|
:param packet: a packet to encode to be sent
|
|
:param cipher_type: the type of cipher
|
|
:return: an encoded packet
|
|
"""
|
|
|
|
# get the header identifier of the type of this packet
|
|
header: typing.Optional[bytes] = self.packet_types.inverse.get(type(packet))
|
|
if header is None:
|
|
raise KeyError(f"Unrecognised packet type: {type(packet)}. Has it been registered ?")
|
|
|
|
# get the encoded packet data
|
|
data = packet.pack()
|
|
|
|
# calculate its checksum using CRC32
|
|
checksum = zlib.crc32(data).to_bytes(4, byteorder='big')
|
|
|
|
# get the packet data
|
|
packet_data = checksum + header + data
|
|
|
|
# encrypt the packet data depending on the cipher selected
|
|
match cipher_type:
|
|
case CipherType.PLAIN:
|
|
pass
|
|
|
|
case CipherType.RSA:
|
|
packet_data = utils.crypto.rsa.rsa_encrypt(packet_data, self.public_key)
|
|
|
|
case CipherType.AES_ECB:
|
|
packet_data = utils.crypto.aes.aes_ecb_encrypt(packet_data, self.secret_key)
|
|
|
|
case _:
|
|
raise ValueError(f"Unknown cipher: {cipher_type}")
|
|
|
|
# prepend the cipher type to the packet data
|
|
payload = cipher_type.value.to_bytes(length=2, byteorder="big") + packet_data
|
|
|
|
return payload
|
|
|
|
|
|
def packet_decode(self, payload: bytes) -> packets.base.BasePacket:
|
|
"""
|
|
Decode a payload into a packet
|
|
:param payload: the data of the packet
|
|
:return: the deserialized packet
|
|
"""
|
|
|
|
cipher_type: CipherType = CipherType(int.from_bytes(payload[0:2], byteorder="big"))
|
|
packet_data: bytes = payload[2:]
|
|
|
|
# decrypt the packet data depending on the cipher used
|
|
match cipher_type:
|
|
case CipherType.PLAIN:
|
|
pass
|
|
|
|
case CipherType.RSA:
|
|
packet_data = utils.crypto.rsa.rsa_decrypt(packet_data, self.private_key)
|
|
|
|
case CipherType.AES_ECB:
|
|
packet_data = utils.crypto.aes.aes_ecb_decrypt(packet_data, self.secret_key)
|
|
|
|
case _:
|
|
raise ValueError(f"Unknown cipher: {cipher_type}")
|
|
|
|
# split the header and data from the raw payload
|
|
checksum: int = int.from_bytes(packet_data[:4], "big")
|
|
header: bytes = packet_data[4:8]
|
|
data: bytes = packet_data[8:]
|
|
|
|
# verify the checksum for corruption
|
|
if zlib.crc32(data) != checksum:
|
|
raise ValueError("The checksum is invalid.")
|
|
|
|
# get the type of the packet from its header
|
|
packet_type: typing.Optional[typing.Type[packets.base.BasePacket]] = self.packet_types.get(header)
|
|
if header is None:
|
|
raise KeyError(f"Unrecognised packet header: {header}. Has it been registered ?")
|
|
|
|
# unpack the packet
|
|
return packet_type.unpack(data)
|
|
|
|
def send(self, packet: packets.base.BasePacket, cipher_type: CipherType, address: tuple):
|
|
self.socket.sendto(self.packet_encode(packet, cipher_type), address)
|
|
|
|
def broadcast(self, packet: packets.base.BasePacket, cipher_type: CipherType):
|
|
"""
|
|
Broadcast a message in the network
|
|
:param cipher_type: the type of cipher
|
|
:param packet: the message to broadcast
|
|
"""
|
|
|
|
# check that no asymmetric cipher mode is used
|
|
if cipher_type in utils.crypto.type.CIPHER_ASYMMETRIC_TYPES:
|
|
raise ValueError("Asymmetric cipher cannot be used in broadcast.")
|
|
|
|
# TODO(Faraphel): use a channel system (OR ESTABLISH ANOTHER PORT ???)
|
|
self.send(packet, cipher_type, (self.broadcast_address, self.port))
|
|
|
|
def receive(self) -> tuple[packets.base.BasePacket, tuple]:
|
|
"""
|
|
Receive a packet
|
|
:return: the packet content alongside the address of the sender
|
|
"""
|
|
|
|
# receive a message
|
|
payload, address = self.socket.recvfrom(65536)
|
|
# decode the payload
|
|
return self.packet_decode(payload), address
|