# M2-PT-DRP/source/managers/CommunicationManager.py
#
# (source listing: 304 lines, 11 KiB, Python)

import hashlib
import json
import socket
import typing
import zlib
from datetime import datetime
import bidict
from source import packets, utils, structures
from source.behaviors import roles
from source.managers import Manager
from source.utils.crypto.rsa import rsa_create_key_pair
from source.utils.crypto.type import CipherType
class CommunicationManager:
"""
Manage the communication between the peers
"""
def __init__(self, manager: "Manager", interface: str, broadcast_address: str = "ff02::1", port: int = 5555):
self.manager = manager
self.broadcast_address = broadcast_address
self.port = port
# create an IPv6 UDP socket
self.socket = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM)
# enable broadcast messages
self.socket.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, True)
# use multicast on the selected interface
self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_MULTICAST_IF, socket.if_nametoindex(interface))
# bind to listen for any message on this port
self.socket.bind(("::", self.port))
# create a dictionary to hold the types of packets and their headers.
self.packet_types: bidict.bidict[bytes, typing.Type[packets.base.BasePacket]] = bidict.bidict()
# load or create a private and public key for asymmetric communication
private_key_path = self.manager.storage / "private_key.der"
public_key_path = self.manager.storage / "public_key.der"
if public_key_path.exists() and private_key_path.exists():
self.private_key = private_key_path.read_bytes()
self.public_key = public_key_path.read_bytes()
else:
self.private_key, self.public_key = rsa_create_key_pair()
private_key_path.write_bytes(self.private_key)
public_key_path.write_bytes(self.public_key)
self._trusted_peers_path = self.manager.storage / "trusted-peers.json"
self._trusted_peers: set[str] = set()
if self._trusted_peers_path.exists():
self._trusted_peers = set(json.loads(self._trusted_peers_path.read_text()))
self._banned_peers_path = self.manager.storage / "banned-peers.json"
self._banned_peers: set[str] = set()
if self._banned_peers_path.exists():
self._banned_peers = set(json.loads(self._banned_peers_path.read_text()))
def __del__(self):
# close the socket
self.socket.close()
def get_secret_key(self) -> typing.Optional[bytes]:
"""
Get the symmetric secret key
:return: the symmetric secret key
"""
# check if we have an "active" role
if not isinstance(self.manager.role.current, roles.base.BaseActiveRole):
return None
# return its secret key
return self.manager.role.current.secret_key
def register_packet_type(self, header: bytes, packet_type: typing.Type[packets.base.BasePacket]) -> None:
"""
Register a new kind of packet that can be sent or received.
:param header: the binary header identifying the packet
:param packet_type: the class of the packet
"""
if len(header) != 4:
raise Exception("The header should be exactly 4 bytes long.")
self.packet_types[header] = packet_type
def packet_encode(self, packet: packets.base.BasePacket, cipher_type: CipherType) -> bytes:
"""
Encode a packet for diffusion
:param packet: a packet to encode to be sent
:param cipher_type: the type of cipher
:return: an encoded packet
"""
# get the header identifier of the type of this packet
header: typing.Optional[bytes] = self.packet_types.inverse.get(type(packet))
if header is None:
raise KeyError(f"Unrecognised packet type: {type(packet)}. Has it been registered ?")
# get the encoded packet data
data = packet.pack()
# calculate its checksum using CRC32
checksum = zlib.crc32(data).to_bytes(4, byteorder='big')
# get the packet data
packet_data = checksum + header + data
# encrypt the packet data depending on the cipher selected
secret_key = self.get_secret_key()
if secret_key is None and cipher_type in utils.crypto.type.CIPHER_SYMMETRIC_TYPES:
raise ValueError("Cannot cipher a packet with undefined secret key.")
match cipher_type:
case CipherType.PLAIN:
pass
case CipherType.AES_ECB:
packet_data = utils.crypto.aes.aes_ecb_encrypt(packet_data, secret_key)
case CipherType.AES_CBC:
packet_data = utils.crypto.aes.aes_cbc_encrypt(packet_data, secret_key)
case CipherType.RSA:
packet_data = utils.crypto.rsa.rsa_encrypt(packet_data, self.public_key)
case _:
raise ValueError(f"Unknown cipher: {cipher_type}")
# prepend the cipher type to the packet data
payload = cipher_type.value.to_bytes(length=2, byteorder="big") + packet_data
return payload
def packet_decode(self, payload: bytes) -> packets.base.BasePacket:
"""
Decode a payload into a packet
:param payload: the data of the packet
:return: the deserialized packet
"""
cipher_type: CipherType = CipherType(int.from_bytes(payload[0:2], byteorder="big"))
packet_data: bytes = payload[2:]
# decrypt the packet data depending on the cipher used
secret_key = self.get_secret_key()
if secret_key is None and cipher_type in utils.crypto.type.CIPHER_SYMMETRIC_TYPES:
raise ValueError("Cannot decipher a packet with undefined secret key.")
match cipher_type:
case CipherType.PLAIN:
pass
case CipherType.AES_ECB:
packet_data = utils.crypto.aes.aes_ecb_decrypt(packet_data, secret_key)
case CipherType.AES_CBC:
packet_data = utils.crypto.aes.aes_cbc_decrypt(packet_data, secret_key)
case CipherType.RSA:
packet_data = utils.crypto.rsa.rsa_decrypt(packet_data, self.private_key)
case _:
raise ValueError(f"Unknown cipher: {cipher_type}")
# split the header and data from the raw payload
checksum: int = int.from_bytes(packet_data[:4], "big")
header: bytes = packet_data[4:8]
data: bytes = packet_data[8:]
# verify the checksum for corruption
if zlib.crc32(data) != checksum:
raise ValueError("The checksum is invalid.")
# get the type of the packet from its header
packet_type: typing.Optional[typing.Type[packets.base.BasePacket]] = self.packet_types.get(header)
if header is None:
raise KeyError(f"Unrecognised packet header: {header}. Has it been registered ?")
# unpack the packet
return packet_type.unpack(data)
def send(self, packet: packets.base.BasePacket, cipher_type: CipherType, address: tuple):
self.socket.sendto(self.packet_encode(packet, cipher_type), address)
def broadcast(self, packet: packets.base.BasePacket, cipher_type: CipherType):
"""
Broadcast a message in the network
:param cipher_type: the type of cipher
:param packet: the message to broadcast
"""
# check that no asymmetric cipher mode is used
if cipher_type in utils.crypto.type.CIPHER_ASYMMETRIC_TYPES:
raise ValueError("Asymmetric cipher cannot be used in broadcast.")
# TODO(Faraphel): use a channel system (OR ESTABLISH ANOTHER PORT ???)
self.send(packet, cipher_type, (self.broadcast_address, self.port))
def receive(self) -> tuple[packets.base.BasePacket, tuple]:
"""
Receive a packet
:return: the packet content alongside the address of the sender
"""
# receive a message
payload, address = self.socket.recvfrom(65536)
# check if there is a peer associated with this address
peer: structures.Peer = self.manager.peer.peers.get(address)
if peer is not None:
# update the latest interaction date
peer.last_interaction = datetime.now()
# decode the payload
return self.packet_decode(payload), address
@staticmethod
def get_local_addresses() -> list[tuple]:
"""
Get the local addresses of the machine
:return: the local addresses of the machine
"""
return socket.getaddrinfo(socket.gethostname(), None)
def is_address_local(self, address: tuple) -> bool:
"""
Is the given address local
:return: true if the address is local, false otherwise
"""
host, _, _, scope = address
# check for all the interfaces of our machine
for interface in self.get_local_addresses():
# unpack the interface information
interface_family, _, _, _, interface_address = interface
interface_host, _, _, interface_scope = interface_address
# check if it matches the address interface
if host == interface_host and scope == interface_scope:
return True
# no matching interfaces have been found
return False
def save_trusted_peers(self) -> None:
"""
Save the list of trusted peers
"""
self._trusted_peers_path.write_text(json.dumps(list(self._trusted_peers)))
def save_banned_peers(self) -> None:
"""
Save the list of banned peers
"""
self._banned_peers_path.write_text(json.dumps(list(self._banned_peers)))
def trust_peer(self, public_key: bytes) -> None:
"""
Mark a peer as trusted for future connexion
Automatically save it to a file
:param public_key: the public key of the peer
"""
self._trusted_peers.add(hashlib.sha256(public_key).hexdigest())
self.save_trusted_peers()
def ban_peer(self, public_key: bytes) -> None:
"""
Ban a peer from being used for any future connexion
Automatically save it to a file
:param public_key: the public key of the peer
"""
self._banned_peers.add(hashlib.sha256(public_key).hexdigest())
self.save_banned_peers()
def is_peer_trusted(self, public_key: bytes) -> bool:
"""
Determinate is a peer is trusted or not
:param public_key: the public key of the peer
:return: True if the peer is trusted, False otherwise
"""
return hashlib.sha256(public_key).hexdigest() in self._trusted_peers
def is_peer_banned(self, public_key: bytes) -> bool:
"""
Determinate if a peer is banned or not
:param public_key: the public key of the peer
:return: True if the peer is banned, False otherwise
"""
return hashlib.sha256(public_key).hexdigest() in self._banned_peers