added TemplateSafeEval, TemplateMultipleSafeEval and Env type hint

This commit is contained in:
Faraphel 2022-08-14 12:47:35 +02:00
parent 923f696e24
commit 506ee740f5
19 changed files with 115 additions and 54 deletions

View file

@ -4,13 +4,18 @@ from threading import Thread
from typing import Callable
# metadata
__version__ = (0, 12, 0)
__author__ = 'Faraphel'
# external links
discord_url = "https://discord.gg/HEYW5v8ZCd"
github_wiki_url = "https://github.com/Faraphel/MKWF-Install/wiki/help"
readthedocs_url = "https://mkwf-install.readthedocs.io/"
# constant declaration
Ko: int = 1_000
Mo: int = 1_000 * Ko
Go: int = 1_000 * Mo
@ -18,6 +23,13 @@ Go: int = 1_000 * Mo
minimum_space_available: int = 15*Go
# global type hint
TemplateSafeEval: str
TemplateMultipleSafeEval: str
Env: dict[str, any]
# useful functions
def threaded(func: Callable) -> Callable:
"""
Decorate a function to run in a separate thread

View file

@ -4,6 +4,7 @@ from PIL import Image, ImageDraw, ImageFont
if TYPE_CHECKING:
from source.mkw import ModConfig
from source import TemplateMultipleSafeEval
class Cup:
@ -64,7 +65,7 @@ class Cup:
# if the icon doesn't exist, use the default automatically generated one
return self.get_default_cticon(mod_config=mod_config)
def get_ctfile(self, mod_config: "ModConfig", template: str) -> str:
def get_ctfile(self, mod_config: "ModConfig", template: "TemplateMultipleSafeEval") -> str:
"""
Get the ctfile for this cup
:return: the ctfile

View file

@ -1,13 +1,16 @@
import shutil
from io import BytesIO
from pathlib import Path
from typing import Generator, IO
from typing import Generator, IO, TYPE_CHECKING
from source.mkw.ModConfig import ModConfig
from source.mkw.Patch.Patch import Patch
from source.wt import szs, lec, wit
from source.wt.wstrt import StrPath
if TYPE_CHECKING:
from source.mkw.Game import Game
class PathOutsideMod(Exception):
def __init__(self, forbidden_path: Path, allowed_range: Path):

View file

@ -1,6 +1,6 @@
import shutil
from pathlib import Path
from typing import Generator, Callable, Iterator, Iterable
from typing import Generator, Callable, Iterator, Iterable, TYPE_CHECKING
from PIL import Image
@ -16,6 +16,10 @@ from source.mkw.OriginalTrack import OriginalTrack
from source.safe_eval import safe_eval, multiple_safe_eval
from source.wt.szs import SZSPath
if TYPE_CHECKING:
from source import TemplateMultipleSafeEval, TemplateSafeEval, Env
CT_ICON_SIZE: int = 128
@ -74,46 +78,47 @@ class ModConfig:
"specific_settings", "lpar_template", "tags_templates")
def __init__(self, path: Path | str, name: str, nickname: str = None, version: str = None, variant: str = None,
tags_cups: list[Tag] = None,
tracks: list["Track | TrackGroup"] = None, original_track_prefix: bool = None,
swap_original_order: bool = None, keep_original_track: bool = None, enable_random_cup: bool = None,
track_file_template: str = None, multiplayer_disable_if: str = None, macros: dict[str, str] = None,
messages: dict[str, dict[str, str]] = None, global_settings: dict[str, dict[str, str]] = None,
specific_settings: dict[str, dict[str, str]] = None, lpar_template: str = None,
tags_templates: dict[str, str] = None, arenas: list["Arena"] = None):
tags_cups: list[Tag] = None, tracks: list["Track | TrackGroup"] = None,
original_track_prefix: bool = None, swap_original_order: bool = None, keep_original_track: bool = None,
enable_random_cup: bool = None, track_file_template: "TemplateMultipleSafeEval" = None,
multiplayer_disable_if: "TemplateSafeEval" = None, macros: dict[str, "TemplateSafeEval"] = None,
messages: dict[str, dict[str, "TemplateMultipleSafeEval"]] = None,
global_settings: dict[str, dict[str, str]] = None, specific_settings: dict[str, dict[str, str]] = None,
lpar_template: "TemplateMultipleSafeEval" = None,
tags_templates: dict[str, "TemplateMultipleSafeEval"] = None, arenas: list["Arena"] = None):
self.path = Path(path)
self.macros: dict = macros if macros is not None else {}
self.messages: dict = messages if messages is not None else {}
self.macros = macros if macros is not None else {}
self.messages = messages if messages is not None else {}
self.global_settings: dict = AbstractModSettings.get(merge_dict(
self.global_settings = AbstractModSettings.get(merge_dict(
default_global_settings, global_settings,
dict_keys=default_global_settings.keys() # Avoid modder to add their own settings to globals one
))
self.specific_settings: dict = AbstractModSettings.get(
self.specific_settings = AbstractModSettings.get(
specific_settings if specific_settings is not None else {}
)
self.name: str = name
self.nickname: str = nickname if nickname is not None else name
self.version: str = version if version is not None else "v1.0.0"
self.variant: str = variant if variant is not None else "01"
self.name = name
self.nickname = nickname if nickname is not None else name
self.version = version if version is not None else "v1.0.0"
self.variant = variant if variant is not None else "01"
self.tags_templates: dict[str, str] = tags_templates if tags_templates is not None else {}
self.tags_cups: list[Tag] = tags_cups if tags_cups is not None else []
self.tags_templates = tags_templates if tags_templates is not None else {}
self.tags_cups = tags_cups if tags_cups is not None else []
self._tracks: list["Track | TrackGroup"] = tracks if tracks is not None else []
self.track_file_template: str = track_file_template \
self._tracks = tracks if tracks is not None else []
self.track_file_template = track_file_template \
if track_file_template is not None else "{{ getattr(track, 'sha1', '_') }}"
self.multiplayer_disable_if: str = multiplayer_disable_if if multiplayer_disable_if is not None else "False"
self.lpar_template: str = lpar_template if lpar_template is not None else "normal.lpar"
self.multiplayer_disable_if = multiplayer_disable_if if multiplayer_disable_if is not None else "False"
self.lpar_template = lpar_template if lpar_template is not None else "normal.lpar"
self.arenas: list["Arena"] = arenas if arenas is not None else []
self.arenas = arenas if arenas is not None else []
self.original_track_prefix: bool = original_track_prefix if original_track_prefix is not None else True
self.swap_original_order: bool = swap_original_order if swap_original_order is not None else True
self.keep_original_track: bool = keep_original_track if keep_original_track is not None else True
self.enable_random_cup: bool = enable_random_cup if enable_random_cup is not None else True
self.original_track_prefix = original_track_prefix if original_track_prefix is not None else True
self.swap_original_order = swap_original_order if swap_original_order is not None else True
self.keep_original_track = keep_original_track if keep_original_track is not None else True
self.enable_random_cup = enable_random_cup if enable_random_cup is not None else True
def __repr__(self):
return f"<ModConfig name={self.name} version={self.version}>"
@ -168,7 +173,7 @@ class ModConfig:
messages=json.loads(messages_file.read_text(encoding="utf8")) if messages_file.exists() else None,
)
def get_safe_eval_env(self, base_env: dict[str, any] = None) -> dict[str, any]:
def get_safe_eval_env(self, base_env: "Env" = None) -> dict[str, any]:
"""
Return the env for the modconfig safe_eval function
:param base_env: additional environment
@ -182,14 +187,14 @@ class ModConfig:
base_env if base_env is not None else {}
)
def safe_eval(self, *args, env: dict[str, any] = None, **kwargs) -> any:
def safe_eval(self, *args, env: "Env" = None, **kwargs) -> any:
"""
Safe eval with useful modconfig environment
:return: the result of the evaluation
"""
return safe_eval(*args, env=self.get_safe_eval_env(base_env=env), macros=self.macros, **kwargs)
def multiple_safe_eval(self, *args, env: dict[str, any] = None, **kwargs) -> str:
def multiple_safe_eval(self, *args, env: "Env" = None, **kwargs) -> str:
"""
Multiple safe eval with useful modconfig environment
:return: the str result of the evaluation
@ -227,8 +232,9 @@ class ModConfig:
:return: track or tracks groups elements
"""
filter_template: str | None = self.global_settings["include_track_if"].value if not ignore_filter else None
settings_sort: str | None = self.global_settings["sort_tracks"].value
filter_template: "TemplateSafeEval | None" = self.global_settings["include_track_if"].value \
if not ignore_filter else None
settings_sort: "TemplateSafeEval | None" = self.global_settings["sort_tracks"].value
# filter_template_func is the function checking if the track should be included. If no parameter is set,
# then always return True
@ -309,7 +315,7 @@ class ModConfig:
yield from self.get_ordered_cups()
yield from self.get_unordered_cups()
def get_ctfile(self, template: str) -> str:
def get_ctfile(self, template: "TemplateMultipleSafeEval") -> str:
"""
Return the ct_file generated from the ModConfig
:template: template for the track name

View file

@ -5,6 +5,7 @@ from source.wt import ctc
if TYPE_CHECKING:
from source.mkw.Patch import Patch
from source import TemplateMultipleSafeEval
class CTFileLayer(AbstractLayer):
@ -14,7 +15,7 @@ class CTFileLayer(AbstractLayer):
mode = "ctfile"
def __init__(self, template: dict[str, str]):
def __init__(self, template: "TemplateMultipleSafeEval"):
self.template = template
def patch_bmg(self, patch: "Patch", decoded_content: str) -> str:

View file

@ -7,6 +7,7 @@ from source.wt import bmg
if TYPE_CHECKING:
from source.mkw.Patch import Patch
from source import TemplateMultipleSafeEval
class FormatOriginalTrackLayer(AbstractLayer):
@ -16,7 +17,7 @@ class FormatOriginalTrackLayer(AbstractLayer):
mode = "format-original-track"
def __init__(self, template: str):
def __init__(self, template: "TemplateMultipleSafeEval"):
self.template = template
def patch_bmg(self, patch: "Patch", decoded_content: str) -> str:

View file

@ -1,9 +1,11 @@
from typing import TYPE_CHECKING
from source.mkw.Patch.PatchOperation.BmgTxtEditor import AbstractLayer
if TYPE_CHECKING:
from source.mkw.Patch import Patch
from source import TemplateMultipleSafeEval
class IDLayer(AbstractLayer):
@ -13,7 +15,7 @@ class IDLayer(AbstractLayer):
mode = "id"
def __init__(self, template: dict[str, str]):
def __init__(self, template: dict[str, "TemplateMultipleSafeEval"]):
self.template = template
def patch_bmg(self, patch: "Patch", decoded_content: str) -> str:

View file

@ -5,6 +5,7 @@ from source.mkw.Patch.PatchOperation.BmgTxtEditor import AbstractLayer
if TYPE_CHECKING:
from source.mkw.Patch import Patch
from source import TemplateMultipleSafeEval
class RegexLayer(AbstractLayer):
@ -14,7 +15,7 @@ class RegexLayer(AbstractLayer):
mode = "regex"
def __init__(self, template: dict[str, str]):
def __init__(self, template: dict[str, "TemplateMultipleSafeEval"]):
self.template = template
def patch_bmg(self, patch: "Patch", decoded_content: str) -> str:

View file

@ -2,8 +2,8 @@ from typing import TYPE_CHECKING
from PIL import Image
from source.mkw.Patch import *
from source.mkw.Patch.PatchOperation.ImageEditor import AbstractLayer
from source.mkw.Patch import PathOutsidePatch
if TYPE_CHECKING:
from source.mkw.Patch import Patch

View file

@ -8,6 +8,7 @@ from source.mkw.Patch.PatchOperation.ImageEditor import AbstractLayer
if TYPE_CHECKING:
from source.mkw.Patch import Patch
from source import TemplateMultipleSafeEval
class TextLayer(AbstractLayer):
@ -24,7 +25,7 @@ class TextLayer(AbstractLayer):
self.font_path: str | None = font_path
self.font_size: int = font_size
self.color: tuple[int] = tuple(color)
self.text: str = text
self.text: "TemplateMultipleSafeEval" = text
def patch_image(self, patch: "Patch", image: Image.Image) -> Image.Image:
draw = ImageDraw.Draw(image)

View file

@ -1,8 +1,11 @@
from abc import ABC, abstractmethod
from typing import Generator
from typing import Generator, TYPE_CHECKING
from source.mkw import Slot, Tag, ModConfig
if TYPE_CHECKING:
from source import TemplateMultipleSafeEval
class TrackForbiddenCustomAttribute(Exception):
def __init__(self, attribute_name: str):
@ -39,7 +42,7 @@ class AbstractTrack(ABC):
yield self
@abstractmethod
def repr_format(self, mod_config: "ModConfig", template: str) -> str:
def repr_format(self, mod_config: "ModConfig", template: "TemplateMultipleSafeEval") -> str:
"""
return the representation of the track from the format
:param template: template for the way the text will be represented
@ -66,7 +69,7 @@ class AbstractTrack(ABC):
"""
...
def get_ctfile(self, mod_config: "ModConfig", template: str, hidden: bool = False) -> str:
def get_ctfile(self, mod_config: "ModConfig", template: "TemplateMultipleSafeEval", hidden: bool = False) -> str:
"""
return the ctfile of the track
:hidden: if the track is in a group

View file

@ -1,6 +1,11 @@
from typing import TYPE_CHECKING
from source.mkw import Slot, Tag
from source.mkw.Track.RealArenaTrack import RealArenaTrack
if TYPE_CHECKING:
from source import TemplateMultipleSafeEval
class ArenaForbiddenCustomAttribute(Exception):
def __init__(self, attribute_name: str):
@ -32,7 +37,7 @@ class Arena(RealArenaTrack):
def from_dict(cls, arena_dict: dict[str, any]) -> "Arena":
return cls(**arena_dict)
def get_ctfile(self, mod_config: "ModConfig", template: str) -> (str, str):
def get_ctfile(self, mod_config: "ModConfig", template: "TemplateMultipleSafeEval") -> (str, str):
"""
Return the ctfile for the arena and the redefinition of the slot property
:param mod_config: the mod_config object

View file

@ -1,5 +1,10 @@
from typing import TYPE_CHECKING
from source.mkw.Track.AbstractTrack import AbstractTrack
if TYPE_CHECKING:
from source.mkw.ModConfig import ModConfig
class DefaultTrack(AbstractTrack):
def repr_format(self, mod_config: "ModConfig", template: str) -> str:

View file

@ -1,3 +1,10 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from source import TemplateMultipleSafeEval
from source.mkw.ModConfig import ModConfig
class RealArenaTrack:
"""
class shared between all arena and track class that represent a "real" track or arena
@ -18,7 +25,7 @@ class RealArenaTrack:
return mod_config.multiple_safe_eval(mod_config.tags_templates[template_name][tag], env={"tag": tag})
return default
def repr_format(self, mod_config: "ModConfig", template: str) -> str:
def repr_format(self, mod_config: "ModConfig", template: "TemplateMultipleSafeEval") -> str:
return mod_config.multiple_safe_eval(
template,
env={

View file

@ -8,10 +8,9 @@ ModConfig: any
class TrackGroup:
def __init__(self, tracks: list["Track"] = None, tags: list[Tag] = None, name: str = None):
def __init__(self, tracks: list["Track"] = None, tags: list[Tag] = None):
self.tracks = tracks if tracks is not None else []
self.tags = tags if tags is not None else []
self.name = name if name is not None else ""
def get_tracks(self) -> Generator["Track", None, None]:
"""
@ -34,7 +33,6 @@ class TrackGroup:
return cls(
tracks=[CustomTrack.from_dict(track) for track in group_dict["group"]],
tags=group_dict.get("tags"),
name=group_dict.get("name"),
)
def get_ctfile(self, mod_config: "ModConfig") -> str:

View file

@ -1,5 +1,8 @@
import re
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from source import TemplateSafeEval
MACRO_START, MACRO_END = "##", "##"
@ -9,7 +12,7 @@ class NotImplementedMacro(Exception):
super().__init__(f"Invalid macro while parsing macros:\n{macro}")
def replace_macro(template: str, macros: dict[str, str]) -> str:
def replace_macro(template: str, macros: dict[str, "TemplateSafeEval"]) -> str:
"""
Replace all the macro defined in macro by their respective value
:param template: template where to replace the macro

View file

@ -1,12 +1,17 @@
import re
from typing import TYPE_CHECKING
from source.safe_eval import safe_eval
if TYPE_CHECKING:
from source import TemplateMultipleSafeEval, TemplateSafeEval, Env
TOKEN_START, TOKEN_END = "{{", "}}"
def multiple_safe_eval(template: str, env: dict[str, any] = None, macros: dict[str, str] = None) -> str:
def multiple_safe_eval(template: "TemplateMultipleSafeEval", env: "Env" = None,
macros: dict[str, "TemplateSafeEval"] = None) -> str:
"""
Similar to safe_eval, but expression need to be enclosed between "{{" and "}}".
Example : "{{ track.author }} is the track creator !"

View file

@ -1,9 +1,13 @@
import ast
import copy
from typing import TYPE_CHECKING
from source.safe_eval.macros import replace_macro
from source.safe_eval.safe_function import get_all_safe_functions
if TYPE_CHECKING:
from source import TemplateSafeEval, Env
class SafeEvalException(Exception):
def __init__(self, message: str):
@ -19,7 +23,7 @@ all_globals = {
}
def safe_eval(template: str, env: dict[str, any] = None, macros: dict[str, str] = None,
def safe_eval(template: "TemplateSafeEval", env: "Env" = None, macros: dict[str, "TemplateSafeEval"] = None,
return_lambda: bool = False, lambda_args: list[str] = None) -> any:
"""
Run a python code in an eval function, but avoid all potential dangerous function.

View file

@ -1,4 +1,7 @@
from typing import Callable, Generator
from typing import Callable, Generator, TYPE_CHECKING
if TYPE_CHECKING:
from source import TemplateSafeEval, Env
def get_all_safe_functions() -> Generator[list[Callable], None, None]:
@ -41,7 +44,7 @@ class safe_function:
return type(obj)
@staticmethod
def eval(template: str, env: dict | None = None):
def eval(template: "TemplateSafeEval", env: "Env | None" = None):
"""
Allow a recursive safe_eval, but without the lambda functionality
"""