diff --git a/source/network/packet/PacketBoatPlaced.py b/source/network/packet/PacketBoatPlaced.py index ab09090..18f93fe 100644 --- a/source/network/packet/PacketBoatPlaced.py +++ b/source/network/packet/PacketBoatPlaced.py @@ -6,13 +6,11 @@ from source.network.packet.abc import Packet @dataclass class PacketBoatPlaced(Packet): + packet_size: int = 0 + def to_bytes(self): return b"" @classmethod def from_bytes(cls, data: bytes): return cls() - - @classmethod - def from_connection(cls, connection: socket.socket) -> "PacketBoatPlaced": - return cls.from_bytes(connection.recv(0)) diff --git a/source/network/packet/PacketBombPlaced.py b/source/network/packet/PacketBombPlaced.py index d33f2b6..d45bd71 100644 --- a/source/network/packet/PacketBombPlaced.py +++ b/source/network/packet/PacketBombPlaced.py @@ -9,6 +9,8 @@ from source.type import Point2D class PacketBombPlaced(Packet): position: Point2D = field() + packet_size: int = 2 + def to_bytes(self): x, y = self.position return x.to_bytes(1, "big") + y.to_bytes(1, "big") @@ -19,7 +21,3 @@ class PacketBombPlaced(Packet): int.from_bytes(data[0:1], "big"), int.from_bytes(data[1:2], "big"), )) - - @classmethod - def from_connection(cls, connection: socket.socket) -> "PacketBombPlaced": - return cls.from_bytes(connection.recv(2)) diff --git a/source/network/packet/PacketBombState.py b/source/network/packet/PacketBombState.py index b09b089..7605cd6 100644 --- a/source/network/packet/PacketBombState.py +++ b/source/network/packet/PacketBombState.py @@ -11,6 +11,8 @@ class PacketBombState(Packet): position: Point2D = field() bomb_state: BombState = field() + packet_size: int = 3 + def to_bytes(self): x, y = self.position @@ -29,7 +31,3 @@ class PacketBombState(Packet): ), bomb_state=BombState.from_bytes(data[2:3]) ) - - @classmethod - def from_connection(cls, connection: socket.socket) -> "PacketBombState": - return cls.from_bytes(connection.recv(3)) diff --git a/source/network/packet/PacketChat.py b/source/network/packet/PacketChat.py index ba593bd..bb03c74 100644 --- a/source/network/packet/PacketChat.py +++ b/source/network/packet/PacketChat.py @@ -8,13 +8,11 @@ from source.network.packet.abc import Packet class PacketChat(Packet): message: str = field() + packet_size: int = 256 + def to_bytes(self): return self.message.encode("utf-8") @classmethod def from_bytes(cls, data: bytes): return cls(message=data.decode("utf-8")) - - @classmethod - def from_connection(cls, connection: socket.socket) -> "PacketChat": - return cls.from_bytes(connection.recv(256)) \ No newline at end of file diff --git a/source/network/packet/abc/Packet.py b/source/network/packet/abc/Packet.py index e22079c..61fd750 100644 --- a/source/network/packet/abc/Packet.py +++ b/source/network/packet/abc/Packet.py @@ -5,19 +5,13 @@ from typing import Type, Optional class Packet(ABC): packed_header: bytes + packet_size: int packet_id: int = 0 def __init_subclass__(cls, **kwargs): cls.packet_header = Packet.packet_id.to_bytes(1, "big") Packet.packet_id = Packet.packet_id + 1 - @classmethod - def cls_from_header(cls, packet_header: bytes) -> Type["Packet"]: - return next(filter( - lambda subcls: subcls.packet_header == packet_header, - cls.__subclasses__() - )) - @abstractmethod def to_bytes(self) -> bytes: pass @@ -27,16 +21,24 @@ class Packet(ABC): def from_bytes(cls, data: bytes) -> "Packet": pass + @classmethod + def cls_from_header(cls, packet_header: bytes) -> Type["Packet"]: + return next(filter( + lambda subcls: subcls.packet_header == packet_header, + cls.__subclasses__() + )) + def send_connection(self, connection: socket.socket) -> None: connection.send(self.packet_header) connection.send(self.to_bytes()) @classmethod - def from_connection(cls, connection: socket.socket) -> Optional[Type["Packet"]]: + def from_connection(cls, connection: socket.socket) -> Optional["Packet"]: packet_header: Optional[bytes] = None try: packet_header = connection.recv(1) except socket.timeout: pass - if not packet_header: return None + if not packet_header: return None # si le header du packet est invalide, ignore + subcls = cls.cls_from_header(packet_header) - return cls.cls_from_header(packet_header).from_connection(connection) + return subcls.from_bytes(connection.recv(subcls.packet_size))