From 506ee740f5a77d8840274477e559254d862b8729 Mon Sep 17 00:00:00 2001 From: Faraphel Date: Sun, 14 Aug 2022 12:47:35 +0200 Subject: [PATCH] added TemplateSafeEval, TemplateMultipleSafeEval and Env type hint --- source/__init__.py | 12 ++++ source/mkw/Cup.py | 3 +- source/mkw/ExtractedGame.py | 5 +- source/mkw/ModConfig.py | 72 ++++++++++--------- .../BmgTxtEditor/CTFileLayer.py | 3 +- .../BmgTxtEditor/FormatOriginalTrackLayer.py | 3 +- .../PatchOperation/BmgTxtEditor/IDLayer.py | 4 +- .../PatchOperation/BmgTxtEditor/RegexLayer.py | 3 +- .../PatchOperation/ImageEditor/ImageLayer.py | 2 +- .../PatchOperation/ImageEditor/TextLayer.py | 3 +- source/mkw/Track/AbstractTrack.py | 9 ++- source/mkw/Track/Arena.py | 7 +- source/mkw/Track/DefaultTrack.py | 5 ++ source/mkw/Track/RealArenaTrack.py | 9 ++- source/mkw/Track/TrackGroup.py | 4 +- source/safe_eval/macros.py | 5 +- source/safe_eval/multiple_safe_eval.py | 7 +- source/safe_eval/safe_eval.py | 6 +- source/safe_eval/safe_function.py | 7 +- 19 files changed, 115 insertions(+), 54 deletions(-) diff --git a/source/__init__.py b/source/__init__.py index ecb8f23..10b47a6 100644 --- a/source/__init__.py +++ b/source/__init__.py @@ -4,13 +4,18 @@ from threading import Thread from typing import Callable +# metadata __version__ = (0, 12, 0) __author__ = 'Faraphel' +# external links discord_url = "https://discord.gg/HEYW5v8ZCd" github_wiki_url = "https://github.com/Faraphel/MKWF-Install/wiki/help" +readthedocs_url = "https://mkwf-install.readthedocs.io/" + +# constant declaration Ko: int = 1_000 Mo: int = 1_000 * Ko Go: int = 1_000 * Mo @@ -18,6 +23,13 @@ Go: int = 1_000 * Mo minimum_space_available: int = 15*Go +# global type hint +TemplateSafeEval: str +TemplateMultipleSafeEval: str +Env: dict[str, any] + + +# useful functions def threaded(func: Callable) -> Callable: """ Decorate a function to run in a separate thread diff --git a/source/mkw/Cup.py b/source/mkw/Cup.py index 16a0b61..48bd608 100644 --- a/source/mkw/Cup.py +++ b/source/mkw/Cup.py @@ -4,6 +4,7 @@ from PIL import Image, ImageDraw, ImageFont if TYPE_CHECKING: from source.mkw import ModConfig + from source import TemplateMultipleSafeEval class Cup: @@ -64,7 +65,7 @@ class Cup: # if the icon doesn't exist, use the default automatically generated one return self.get_default_cticon(mod_config=mod_config) - def get_ctfile(self, mod_config: "ModConfig", template: str) -> str: + def get_ctfile(self, mod_config: "ModConfig", template: "TemplateMultipleSafeEval") -> str: """ Get the ctfile for this cup :return: the ctfile diff --git a/source/mkw/ExtractedGame.py b/source/mkw/ExtractedGame.py index 7ba5b1b..fd01022 100644 --- a/source/mkw/ExtractedGame.py +++ b/source/mkw/ExtractedGame.py @@ -1,13 +1,16 @@ import shutil from io import BytesIO from pathlib import Path -from typing import Generator, IO +from typing import Generator, IO, TYPE_CHECKING from source.mkw.ModConfig import ModConfig from source.mkw.Patch.Patch import Patch from source.wt import szs, lec, wit from source.wt.wstrt import StrPath +if TYPE_CHECKING: + from source.mkw.Game import Game + class PathOutsideMod(Exception): def __init__(self, forbidden_path: Path, allowed_range: Path): diff --git a/source/mkw/ModConfig.py b/source/mkw/ModConfig.py index e3d065e..7442c7d 100644 --- a/source/mkw/ModConfig.py +++ b/source/mkw/ModConfig.py @@ -1,6 +1,6 @@ import shutil from pathlib import Path -from typing import Generator, Callable, Iterator, Iterable +from typing import Generator, Callable, Iterator, Iterable, TYPE_CHECKING from PIL import Image @@ -16,6 +16,10 @@ from source.mkw.OriginalTrack import OriginalTrack from source.safe_eval import safe_eval, multiple_safe_eval from source.wt.szs import SZSPath +if TYPE_CHECKING: + from source import TemplateMultipleSafeEval, TemplateSafeEval, Env + + CT_ICON_SIZE: int = 128 @@ -74,46 +78,47 @@ class ModConfig: "specific_settings", "lpar_template", "tags_templates") def __init__(self, path: Path | str, name: str, nickname: str = None, version: str = None, variant: str = None, - tags_cups: list[Tag] = None, - tracks: list["Track | TrackGroup"] = None, original_track_prefix: bool = None, - swap_original_order: bool = None, keep_original_track: bool = None, enable_random_cup: bool = None, - track_file_template: str = None, multiplayer_disable_if: str = None, macros: dict[str, str] = None, - messages: dict[str, dict[str, str]] = None, global_settings: dict[str, dict[str, str]] = None, - specific_settings: dict[str, dict[str, str]] = None, lpar_template: str = None, - tags_templates: dict[str, str] = None, arenas: list["Arena"] = None): + tags_cups: list[Tag] = None, tracks: list["Track | TrackGroup"] = None, + original_track_prefix: bool = None, swap_original_order: bool = None, keep_original_track: bool = None, + enable_random_cup: bool = None, track_file_template: "TemplateMultipleSafeEval" = None, + multiplayer_disable_if: "TemplateSafeEval" = None, macros: dict[str, "TemplateSafeEval"] = None, + messages: dict[str, dict[str, "TemplateMultipleSafeEval"]] = None, + global_settings: dict[str, dict[str, str]] = None, specific_settings: dict[str, dict[str, str]] = None, + lpar_template: "TemplateMultipleSafeEval" = None, + tags_templates: dict[str, "TemplateMultipleSafeEval"] = None, arenas: list["Arena"] = None): self.path = Path(path) - self.macros: dict = macros if macros is not None else {} - self.messages: dict = messages if messages is not None else {} + self.macros = macros if macros is not None else {} + self.messages = messages if messages is not None else {} - self.global_settings: dict = AbstractModSettings.get(merge_dict( + self.global_settings = AbstractModSettings.get(merge_dict( default_global_settings, global_settings, dict_keys=default_global_settings.keys() # Avoid modder to add their own settings to globals one )) - self.specific_settings: dict = AbstractModSettings.get( + self.specific_settings = AbstractModSettings.get( specific_settings if specific_settings is not None else {} ) - self.name: str = name - self.nickname: str = nickname if nickname is not None else name - self.version: str = version if version is not None else "v1.0.0" - self.variant: str = variant if variant is not None else "01" + self.name = name + self.nickname = nickname if nickname is not None else name + self.version = version if version is not None else "v1.0.0" + self.variant = variant if variant is not None else "01" - self.tags_templates: dict[str, str] = tags_templates if tags_templates is not None else {} - self.tags_cups: list[Tag] = tags_cups if tags_cups is not None else [] + self.tags_templates = tags_templates if tags_templates is not None else {} + self.tags_cups = tags_cups if tags_cups is not None else [] - self._tracks: list["Track | TrackGroup"] = tracks if tracks is not None else [] - self.track_file_template: str = track_file_template \ + self._tracks = tracks if tracks is not None else [] + self.track_file_template = track_file_template \ if track_file_template is not None else "{{ getattr(track, 'sha1', '_') }}" - self.multiplayer_disable_if: str = multiplayer_disable_if if multiplayer_disable_if is not None else "False" - self.lpar_template: str = lpar_template if lpar_template is not None else "normal.lpar" + self.multiplayer_disable_if = multiplayer_disable_if if multiplayer_disable_if is not None else "False" + self.lpar_template = lpar_template if lpar_template is not None else "normal.lpar" - self.arenas: list["Arena"] = arenas if arenas is not None else [] + self.arenas = arenas if arenas is not None else [] - self.original_track_prefix: bool = original_track_prefix if original_track_prefix is not None else True - self.swap_original_order: bool = swap_original_order if swap_original_order is not None else True - self.keep_original_track: bool = keep_original_track if keep_original_track is not None else True - self.enable_random_cup: bool = enable_random_cup if enable_random_cup is not None else True + self.original_track_prefix = original_track_prefix if original_track_prefix is not None else True + self.swap_original_order = swap_original_order if swap_original_order is not None else True + self.keep_original_track = keep_original_track if keep_original_track is not None else True + self.enable_random_cup = enable_random_cup if enable_random_cup is not None else True def __repr__(self): return f"" @@ -168,7 +173,7 @@ class ModConfig: messages=json.loads(messages_file.read_text(encoding="utf8")) if messages_file.exists() else None, ) - def get_safe_eval_env(self, base_env: dict[str, any] = None) -> dict[str, any]: + def get_safe_eval_env(self, base_env: "Env" = None) -> dict[str, any]: """ Return the env for the modconfig safe_eval function :param base_env: additional environment @@ -182,14 +187,14 @@ class ModConfig: base_env if base_env is not None else {} ) - def safe_eval(self, *args, env: dict[str, any] = None, **kwargs) -> any: + def safe_eval(self, *args, env: "Env" = None, **kwargs) -> any: """ Safe eval with useful modconfig environment :return: the result of the evaluation """ return safe_eval(*args, env=self.get_safe_eval_env(base_env=env), macros=self.macros, **kwargs) - def multiple_safe_eval(self, *args, env: dict[str, any] = None, **kwargs) -> str: + def multiple_safe_eval(self, *args, env: "Env" = None, **kwargs) -> str: """ Multiple safe eval with useful modconfig environment :return: the str result of the evaluation @@ -227,8 +232,9 @@ class ModConfig: :return: track or tracks groups elements """ - filter_template: str | None = self.global_settings["include_track_if"].value if not ignore_filter else None - settings_sort: str | None = self.global_settings["sort_tracks"].value + filter_template: "TemplateSafeEval | None" = self.global_settings["include_track_if"].value \ + if not ignore_filter else None + settings_sort: "TemplateSafeEval | None" = self.global_settings["sort_tracks"].value # filter_template_func is the function checking if the track should be included. If no parameter is set, # then always return True @@ -309,7 +315,7 @@ class ModConfig: yield from self.get_ordered_cups() yield from self.get_unordered_cups() - def get_ctfile(self, template: str) -> str: + def get_ctfile(self, template: "TemplateMultipleSafeEval") -> str: """ Return the ct_file generated from the ModConfig :template: template for the track name diff --git a/source/mkw/Patch/PatchOperation/BmgTxtEditor/CTFileLayer.py b/source/mkw/Patch/PatchOperation/BmgTxtEditor/CTFileLayer.py index a65c70d..9041ad9 100644 --- a/source/mkw/Patch/PatchOperation/BmgTxtEditor/CTFileLayer.py +++ b/source/mkw/Patch/PatchOperation/BmgTxtEditor/CTFileLayer.py @@ -5,6 +5,7 @@ from source.wt import ctc if TYPE_CHECKING: from source.mkw.Patch import Patch + from source import TemplateMultipleSafeEval class CTFileLayer(AbstractLayer): @@ -14,7 +15,7 @@ class CTFileLayer(AbstractLayer): mode = "ctfile" - def __init__(self, template: dict[str, str]): + def __init__(self, template: "TemplateMultipleSafeEval"): self.template = template def patch_bmg(self, patch: "Patch", decoded_content: str) -> str: diff --git a/source/mkw/Patch/PatchOperation/BmgTxtEditor/FormatOriginalTrackLayer.py b/source/mkw/Patch/PatchOperation/BmgTxtEditor/FormatOriginalTrackLayer.py index 5bd321e..46389f1 100644 --- a/source/mkw/Patch/PatchOperation/BmgTxtEditor/FormatOriginalTrackLayer.py +++ b/source/mkw/Patch/PatchOperation/BmgTxtEditor/FormatOriginalTrackLayer.py @@ -7,6 +7,7 @@ from source.wt import bmg if TYPE_CHECKING: from source.mkw.Patch import Patch + from source import TemplateMultipleSafeEval class FormatOriginalTrackLayer(AbstractLayer): @@ -16,7 +17,7 @@ class FormatOriginalTrackLayer(AbstractLayer): mode = "format-original-track" - def __init__(self, template: str): + def __init__(self, template: "TemplateMultipleSafeEval"): self.template = template def patch_bmg(self, patch: "Patch", decoded_content: str) -> str: diff --git a/source/mkw/Patch/PatchOperation/BmgTxtEditor/IDLayer.py b/source/mkw/Patch/PatchOperation/BmgTxtEditor/IDLayer.py index 47a96cf..e645706 100644 --- a/source/mkw/Patch/PatchOperation/BmgTxtEditor/IDLayer.py +++ b/source/mkw/Patch/PatchOperation/BmgTxtEditor/IDLayer.py @@ -1,9 +1,11 @@ from typing import TYPE_CHECKING + from source.mkw.Patch.PatchOperation.BmgTxtEditor import AbstractLayer if TYPE_CHECKING: from source.mkw.Patch import Patch + from source import TemplateMultipleSafeEval class IDLayer(AbstractLayer): @@ -13,7 +15,7 @@ class IDLayer(AbstractLayer): mode = "id" - def __init__(self, template: dict[str, str]): + def __init__(self, template: dict[str, "TemplateMultipleSafeEval"]): self.template = template def patch_bmg(self, patch: "Patch", decoded_content: str) -> str: diff --git a/source/mkw/Patch/PatchOperation/BmgTxtEditor/RegexLayer.py b/source/mkw/Patch/PatchOperation/BmgTxtEditor/RegexLayer.py index 1e9696a..668989a 100644 --- a/source/mkw/Patch/PatchOperation/BmgTxtEditor/RegexLayer.py +++ b/source/mkw/Patch/PatchOperation/BmgTxtEditor/RegexLayer.py @@ -5,6 +5,7 @@ from source.mkw.Patch.PatchOperation.BmgTxtEditor import AbstractLayer if TYPE_CHECKING: from source.mkw.Patch import Patch + from source import TemplateMultipleSafeEval class RegexLayer(AbstractLayer): @@ -14,7 +15,7 @@ class RegexLayer(AbstractLayer): mode = "regex" - def __init__(self, template: dict[str, str]): + def __init__(self, template: dict[str, "TemplateMultipleSafeEval"]): self.template = template def patch_bmg(self, patch: "Patch", decoded_content: str) -> str: diff --git a/source/mkw/Patch/PatchOperation/ImageEditor/ImageLayer.py b/source/mkw/Patch/PatchOperation/ImageEditor/ImageLayer.py index fd24c80..bdce754 100644 --- a/source/mkw/Patch/PatchOperation/ImageEditor/ImageLayer.py +++ b/source/mkw/Patch/PatchOperation/ImageEditor/ImageLayer.py @@ -2,8 +2,8 @@ from typing import TYPE_CHECKING from PIL import Image -from source.mkw.Patch import * from source.mkw.Patch.PatchOperation.ImageEditor import AbstractLayer +from source.mkw.Patch import PathOutsidePatch if TYPE_CHECKING: from source.mkw.Patch import Patch diff --git a/source/mkw/Patch/PatchOperation/ImageEditor/TextLayer.py b/source/mkw/Patch/PatchOperation/ImageEditor/TextLayer.py index 5dbb5ce..a73e295 100644 --- a/source/mkw/Patch/PatchOperation/ImageEditor/TextLayer.py +++ b/source/mkw/Patch/PatchOperation/ImageEditor/TextLayer.py @@ -8,6 +8,7 @@ from source.mkw.Patch.PatchOperation.ImageEditor import AbstractLayer if TYPE_CHECKING: from source.mkw.Patch import Patch + from source import TemplateMultipleSafeEval class TextLayer(AbstractLayer): @@ -24,7 +25,7 @@ class TextLayer(AbstractLayer): self.font_path: str | None = font_path self.font_size: int = font_size self.color: tuple[int] = tuple(color) - self.text: str = text + self.text: "TemplateMultipleSafeEval" = text def patch_image(self, patch: "Patch", image: Image.Image) -> Image.Image: draw = ImageDraw.Draw(image) diff --git a/source/mkw/Track/AbstractTrack.py b/source/mkw/Track/AbstractTrack.py index cbaf168..4cf97cb 100644 --- a/source/mkw/Track/AbstractTrack.py +++ b/source/mkw/Track/AbstractTrack.py @@ -1,8 +1,11 @@ from abc import ABC, abstractmethod -from typing import Generator +from typing import Generator, TYPE_CHECKING from source.mkw import Slot, Tag, ModConfig +if TYPE_CHECKING: + from source import TemplateMultipleSafeEval + class TrackForbiddenCustomAttribute(Exception): def __init__(self, attribute_name: str): @@ -39,7 +42,7 @@ class AbstractTrack(ABC): yield self @abstractmethod - def repr_format(self, mod_config: "ModConfig", template: str) -> str: + def repr_format(self, mod_config: "ModConfig", template: "TemplateMultipleSafeEval") -> str: """ return the representation of the track from the format :param template: template for the way the text will be represented @@ -66,7 +69,7 @@ class AbstractTrack(ABC): """ ... - def get_ctfile(self, mod_config: "ModConfig", template: str, hidden: bool = False) -> str: + def get_ctfile(self, mod_config: "ModConfig", template: "TemplateMultipleSafeEval", hidden: bool = False) -> str: """ return the ctfile of the track :hidden: if the track is in a group diff --git a/source/mkw/Track/Arena.py b/source/mkw/Track/Arena.py index 18f69a8..f574df7 100644 --- a/source/mkw/Track/Arena.py +++ b/source/mkw/Track/Arena.py @@ -1,6 +1,11 @@ +from typing import TYPE_CHECKING + from source.mkw import Slot, Tag from source.mkw.Track.RealArenaTrack import RealArenaTrack +if TYPE_CHECKING: + from source import TemplateMultipleSafeEval + class ArenaForbiddenCustomAttribute(Exception): def __init__(self, attribute_name: str): @@ -32,7 +37,7 @@ class Arena(RealArenaTrack): def from_dict(cls, arena_dict: dict[str, any]) -> "Arena": return cls(**arena_dict) - def get_ctfile(self, mod_config: "ModConfig", template: str) -> (str, str): + def get_ctfile(self, mod_config: "ModConfig", template: "TemplateMultipleSafeEval") -> (str, str): """ Return the ctfile for the arena and the redefinition of the slot property :param mod_config: the mod_config object diff --git a/source/mkw/Track/DefaultTrack.py b/source/mkw/Track/DefaultTrack.py index 79d61b5..52f6889 100644 --- a/source/mkw/Track/DefaultTrack.py +++ b/source/mkw/Track/DefaultTrack.py @@ -1,5 +1,10 @@ +from typing import TYPE_CHECKING + from source.mkw.Track.AbstractTrack import AbstractTrack +if TYPE_CHECKING: + from source.mkw.ModConfig import ModConfig + class DefaultTrack(AbstractTrack): def repr_format(self, mod_config: "ModConfig", template: str) -> str: diff --git a/source/mkw/Track/RealArenaTrack.py b/source/mkw/Track/RealArenaTrack.py index dab527a..52ef202 100644 --- a/source/mkw/Track/RealArenaTrack.py +++ b/source/mkw/Track/RealArenaTrack.py @@ -1,3 +1,10 @@ +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from source import TemplateMultipleSafeEval + from source.mkw.ModConfig import ModConfig + + class RealArenaTrack: """ class shared between all arena and track class that represent a "real" track or arena @@ -18,7 +25,7 @@ class RealArenaTrack: return mod_config.multiple_safe_eval(mod_config.tags_templates[template_name][tag], env={"tag": tag}) return default - def repr_format(self, mod_config: "ModConfig", template: str) -> str: + def repr_format(self, mod_config: "ModConfig", template: "TemplateMultipleSafeEval") -> str: return mod_config.multiple_safe_eval( template, env={ diff --git a/source/mkw/Track/TrackGroup.py b/source/mkw/Track/TrackGroup.py index a4c9ceb..4513ccc 100644 --- a/source/mkw/Track/TrackGroup.py +++ b/source/mkw/Track/TrackGroup.py @@ -8,10 +8,9 @@ ModConfig: any class TrackGroup: - def __init__(self, tracks: list["Track"] = None, tags: list[Tag] = None, name: str = None): + def __init__(self, tracks: list["Track"] = None, tags: list[Tag] = None): self.tracks = tracks if tracks is not None else [] self.tags = tags if tags is not None else [] - self.name = name if name is not None else "" def get_tracks(self) -> Generator["Track", None, None]: """ @@ -34,7 +33,6 @@ class TrackGroup: return cls( tracks=[CustomTrack.from_dict(track) for track in group_dict["group"]], tags=group_dict.get("tags"), - name=group_dict.get("name"), ) def get_ctfile(self, mod_config: "ModConfig") -> str: diff --git a/source/safe_eval/macros.py b/source/safe_eval/macros.py index 8e831b9..d6a0c8e 100644 --- a/source/safe_eval/macros.py +++ b/source/safe_eval/macros.py @@ -1,5 +1,8 @@ import re +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from source import TemplateSafeEval MACRO_START, MACRO_END = "##", "##" @@ -9,7 +12,7 @@ class NotImplementedMacro(Exception): super().__init__(f"Invalid macro while parsing macros:\n{macro}") -def replace_macro(template: str, macros: dict[str, str]) -> str: +def replace_macro(template: str, macros: dict[str, "TemplateSafeEval"]) -> str: """ Replace all the macro defined in macro by their respective value :param template: template where to replace the macro diff --git a/source/safe_eval/multiple_safe_eval.py b/source/safe_eval/multiple_safe_eval.py index 5f85248..5affac7 100644 --- a/source/safe_eval/multiple_safe_eval.py +++ b/source/safe_eval/multiple_safe_eval.py @@ -1,12 +1,17 @@ import re +from typing import TYPE_CHECKING from source.safe_eval import safe_eval +if TYPE_CHECKING: + from source import TemplateMultipleSafeEval, TemplateSafeEval, Env + TOKEN_START, TOKEN_END = "{{", "}}" -def multiple_safe_eval(template: str, env: dict[str, any] = None, macros: dict[str, str] = None) -> str: +def multiple_safe_eval(template: "TemplateMultipleSafeEval", env: "Env" = None, + macros: dict[str, "TemplateSafeEval"] = None) -> str: """ Similar to safe_eval, but expression need to be enclosed between "{{" and "}}". Example : "{{ track.author }} is the track creator !" diff --git a/source/safe_eval/safe_eval.py b/source/safe_eval/safe_eval.py index 381e958..00b0798 100644 --- a/source/safe_eval/safe_eval.py +++ b/source/safe_eval/safe_eval.py @@ -1,9 +1,13 @@ import ast import copy +from typing import TYPE_CHECKING from source.safe_eval.macros import replace_macro from source.safe_eval.safe_function import get_all_safe_functions +if TYPE_CHECKING: + from source import TemplateSafeEval, Env + class SafeEvalException(Exception): def __init__(self, message: str): @@ -19,7 +23,7 @@ all_globals = { } -def safe_eval(template: str, env: dict[str, any] = None, macros: dict[str, str] = None, +def safe_eval(template: "TemplateSafeEval", env: "Env" = None, macros: dict[str, "TemplateSafeEval"] = None, return_lambda: bool = False, lambda_args: list[str] = None) -> any: """ Run a python code in an eval function, but avoid all potential dangerous function. diff --git a/source/safe_eval/safe_function.py b/source/safe_eval/safe_function.py index 3ebb3e3..cfccd73 100644 --- a/source/safe_eval/safe_function.py +++ b/source/safe_eval/safe_function.py @@ -1,4 +1,7 @@ -from typing import Callable, Generator +from typing import Callable, Generator, TYPE_CHECKING + +if TYPE_CHECKING: + from source import TemplateSafeEval, Env def get_all_safe_functions() -> Generator[list[Callable], None, None]: @@ -41,7 +44,7 @@ class safe_function: return type(obj) @staticmethod - def eval(template: str, env: dict | None = None): + def eval(template: "TemplateSafeEval", env: "Env | None" = None): """ Allow a recursive safe_eval, but without the lambda functionality """