added TemplateSafeEval, TemplateMultipleSafeEval and Env type hint

This commit is contained in:
Faraphel 2022-08-14 12:47:35 +02:00
parent 923f696e24
commit 506ee740f5
19 changed files with 115 additions and 54 deletions

View file

@ -4,13 +4,18 @@ from threading import Thread
from typing import Callable from typing import Callable
# metadata
__version__ = (0, 12, 0) __version__ = (0, 12, 0)
__author__ = 'Faraphel' __author__ = 'Faraphel'
# external links
discord_url = "https://discord.gg/HEYW5v8ZCd" discord_url = "https://discord.gg/HEYW5v8ZCd"
github_wiki_url = "https://github.com/Faraphel/MKWF-Install/wiki/help" github_wiki_url = "https://github.com/Faraphel/MKWF-Install/wiki/help"
readthedocs_url = "https://mkwf-install.readthedocs.io/"
# constant declaration
Ko: int = 1_000 Ko: int = 1_000
Mo: int = 1_000 * Ko Mo: int = 1_000 * Ko
Go: int = 1_000 * Mo Go: int = 1_000 * Mo
@ -18,6 +23,13 @@ Go: int = 1_000 * Mo
minimum_space_available: int = 15*Go minimum_space_available: int = 15*Go
# global type hint
TemplateSafeEval: str
TemplateMultipleSafeEval: str
Env: dict[str, any]
# useful functions
def threaded(func: Callable) -> Callable: def threaded(func: Callable) -> Callable:
""" """
Decorate a function to run in a separate thread Decorate a function to run in a separate thread

View file

@ -4,6 +4,7 @@ from PIL import Image, ImageDraw, ImageFont
if TYPE_CHECKING: if TYPE_CHECKING:
from source.mkw import ModConfig from source.mkw import ModConfig
from source import TemplateMultipleSafeEval
class Cup: class Cup:
@ -64,7 +65,7 @@ class Cup:
# if the icon doesn't exist, use the default automatically generated one # if the icon doesn't exist, use the default automatically generated one
return self.get_default_cticon(mod_config=mod_config) return self.get_default_cticon(mod_config=mod_config)
def get_ctfile(self, mod_config: "ModConfig", template: str) -> str: def get_ctfile(self, mod_config: "ModConfig", template: "TemplateMultipleSafeEval") -> str:
""" """
Get the ctfile for this cup Get the ctfile for this cup
:return: the ctfile :return: the ctfile

View file

@ -1,13 +1,16 @@
import shutil import shutil
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
from typing import Generator, IO from typing import Generator, IO, TYPE_CHECKING
from source.mkw.ModConfig import ModConfig from source.mkw.ModConfig import ModConfig
from source.mkw.Patch.Patch import Patch from source.mkw.Patch.Patch import Patch
from source.wt import szs, lec, wit from source.wt import szs, lec, wit
from source.wt.wstrt import StrPath from source.wt.wstrt import StrPath
if TYPE_CHECKING:
from source.mkw.Game import Game
class PathOutsideMod(Exception): class PathOutsideMod(Exception):
def __init__(self, forbidden_path: Path, allowed_range: Path): def __init__(self, forbidden_path: Path, allowed_range: Path):

View file

@ -1,6 +1,6 @@
import shutil import shutil
from pathlib import Path from pathlib import Path
from typing import Generator, Callable, Iterator, Iterable from typing import Generator, Callable, Iterator, Iterable, TYPE_CHECKING
from PIL import Image from PIL import Image
@ -16,6 +16,10 @@ from source.mkw.OriginalTrack import OriginalTrack
from source.safe_eval import safe_eval, multiple_safe_eval from source.safe_eval import safe_eval, multiple_safe_eval
from source.wt.szs import SZSPath from source.wt.szs import SZSPath
if TYPE_CHECKING:
from source import TemplateMultipleSafeEval, TemplateSafeEval, Env
CT_ICON_SIZE: int = 128 CT_ICON_SIZE: int = 128
@ -74,46 +78,47 @@ class ModConfig:
"specific_settings", "lpar_template", "tags_templates") "specific_settings", "lpar_template", "tags_templates")
def __init__(self, path: Path | str, name: str, nickname: str = None, version: str = None, variant: str = None, def __init__(self, path: Path | str, name: str, nickname: str = None, version: str = None, variant: str = None,
tags_cups: list[Tag] = None, tags_cups: list[Tag] = None, tracks: list["Track | TrackGroup"] = None,
tracks: list["Track | TrackGroup"] = None, original_track_prefix: bool = None, original_track_prefix: bool = None, swap_original_order: bool = None, keep_original_track: bool = None,
swap_original_order: bool = None, keep_original_track: bool = None, enable_random_cup: bool = None, enable_random_cup: bool = None, track_file_template: "TemplateMultipleSafeEval" = None,
track_file_template: str = None, multiplayer_disable_if: str = None, macros: dict[str, str] = None, multiplayer_disable_if: "TemplateSafeEval" = None, macros: dict[str, "TemplateSafeEval"] = None,
messages: dict[str, dict[str, str]] = None, global_settings: dict[str, dict[str, str]] = None, messages: dict[str, dict[str, "TemplateMultipleSafeEval"]] = None,
specific_settings: dict[str, dict[str, str]] = None, lpar_template: str = None, global_settings: dict[str, dict[str, str]] = None, specific_settings: dict[str, dict[str, str]] = None,
tags_templates: dict[str, str] = None, arenas: list["Arena"] = None): lpar_template: "TemplateMultipleSafeEval" = None,
tags_templates: dict[str, "TemplateMultipleSafeEval"] = None, arenas: list["Arena"] = None):
self.path = Path(path) self.path = Path(path)
self.macros: dict = macros if macros is not None else {} self.macros = macros if macros is not None else {}
self.messages: dict = messages if messages is not None else {} self.messages = messages if messages is not None else {}
self.global_settings: dict = AbstractModSettings.get(merge_dict( self.global_settings = AbstractModSettings.get(merge_dict(
default_global_settings, global_settings, default_global_settings, global_settings,
dict_keys=default_global_settings.keys() # Avoid modder to add their own settings to globals one dict_keys=default_global_settings.keys() # Avoid modder to add their own settings to globals one
)) ))
self.specific_settings: dict = AbstractModSettings.get( self.specific_settings = AbstractModSettings.get(
specific_settings if specific_settings is not None else {} specific_settings if specific_settings is not None else {}
) )
self.name: str = name self.name = name
self.nickname: str = nickname if nickname is not None else name self.nickname = nickname if nickname is not None else name
self.version: str = version if version is not None else "v1.0.0" self.version = version if version is not None else "v1.0.0"
self.variant: str = variant if variant is not None else "01" self.variant = variant if variant is not None else "01"
self.tags_templates: dict[str, str] = tags_templates if tags_templates is not None else {} self.tags_templates = tags_templates if tags_templates is not None else {}
self.tags_cups: list[Tag] = tags_cups if tags_cups is not None else [] self.tags_cups = tags_cups if tags_cups is not None else []
self._tracks: list["Track | TrackGroup"] = tracks if tracks is not None else [] self._tracks = tracks if tracks is not None else []
self.track_file_template: str = track_file_template \ self.track_file_template = track_file_template \
if track_file_template is not None else "{{ getattr(track, 'sha1', '_') }}" if track_file_template is not None else "{{ getattr(track, 'sha1', '_') }}"
self.multiplayer_disable_if: str = multiplayer_disable_if if multiplayer_disable_if is not None else "False" self.multiplayer_disable_if = multiplayer_disable_if if multiplayer_disable_if is not None else "False"
self.lpar_template: str = lpar_template if lpar_template is not None else "normal.lpar" self.lpar_template = lpar_template if lpar_template is not None else "normal.lpar"
self.arenas: list["Arena"] = arenas if arenas is not None else [] self.arenas = arenas if arenas is not None else []
self.original_track_prefix: bool = original_track_prefix if original_track_prefix is not None else True self.original_track_prefix = original_track_prefix if original_track_prefix is not None else True
self.swap_original_order: bool = swap_original_order if swap_original_order is not None else True self.swap_original_order = swap_original_order if swap_original_order is not None else True
self.keep_original_track: bool = keep_original_track if keep_original_track is not None else True self.keep_original_track = keep_original_track if keep_original_track is not None else True
self.enable_random_cup: bool = enable_random_cup if enable_random_cup is not None else True self.enable_random_cup = enable_random_cup if enable_random_cup is not None else True
def __repr__(self): def __repr__(self):
return f"<ModConfig name={self.name} version={self.version}>" return f"<ModConfig name={self.name} version={self.version}>"
@ -168,7 +173,7 @@ class ModConfig:
messages=json.loads(messages_file.read_text(encoding="utf8")) if messages_file.exists() else None, messages=json.loads(messages_file.read_text(encoding="utf8")) if messages_file.exists() else None,
) )
def get_safe_eval_env(self, base_env: dict[str, any] = None) -> dict[str, any]: def get_safe_eval_env(self, base_env: "Env" = None) -> dict[str, any]:
""" """
Return the env for the modconfig safe_eval function Return the env for the modconfig safe_eval function
:param base_env: additional environment :param base_env: additional environment
@ -182,14 +187,14 @@ class ModConfig:
base_env if base_env is not None else {} base_env if base_env is not None else {}
) )
def safe_eval(self, *args, env: dict[str, any] = None, **kwargs) -> any: def safe_eval(self, *args, env: "Env" = None, **kwargs) -> any:
""" """
Safe eval with useful modconfig environment Safe eval with useful modconfig environment
:return: the result of the evaluation :return: the result of the evaluation
""" """
return safe_eval(*args, env=self.get_safe_eval_env(base_env=env), macros=self.macros, **kwargs) return safe_eval(*args, env=self.get_safe_eval_env(base_env=env), macros=self.macros, **kwargs)
def multiple_safe_eval(self, *args, env: dict[str, any] = None, **kwargs) -> str: def multiple_safe_eval(self, *args, env: "Env" = None, **kwargs) -> str:
""" """
Multiple safe eval with useful modconfig environment Multiple safe eval with useful modconfig environment
:return: the str result of the evaluation :return: the str result of the evaluation
@ -227,8 +232,9 @@ class ModConfig:
:return: track or tracks groups elements :return: track or tracks groups elements
""" """
filter_template: str | None = self.global_settings["include_track_if"].value if not ignore_filter else None filter_template: "TemplateSafeEval | None" = self.global_settings["include_track_if"].value \
settings_sort: str | None = self.global_settings["sort_tracks"].value if not ignore_filter else None
settings_sort: "TemplateSafeEval | None" = self.global_settings["sort_tracks"].value
# filter_template_func is the function checking if the track should be included. If no parameter is set, # filter_template_func is the function checking if the track should be included. If no parameter is set,
# then always return True # then always return True
@ -309,7 +315,7 @@ class ModConfig:
yield from self.get_ordered_cups() yield from self.get_ordered_cups()
yield from self.get_unordered_cups() yield from self.get_unordered_cups()
def get_ctfile(self, template: str) -> str: def get_ctfile(self, template: "TemplateMultipleSafeEval") -> str:
""" """
Return the ct_file generated from the ModConfig Return the ct_file generated from the ModConfig
:template: template for the track name :template: template for the track name

View file

@ -5,6 +5,7 @@ from source.wt import ctc
if TYPE_CHECKING: if TYPE_CHECKING:
from source.mkw.Patch import Patch from source.mkw.Patch import Patch
from source import TemplateMultipleSafeEval
class CTFileLayer(AbstractLayer): class CTFileLayer(AbstractLayer):
@ -14,7 +15,7 @@ class CTFileLayer(AbstractLayer):
mode = "ctfile" mode = "ctfile"
def __init__(self, template: dict[str, str]): def __init__(self, template: "TemplateMultipleSafeEval"):
self.template = template self.template = template
def patch_bmg(self, patch: "Patch", decoded_content: str) -> str: def patch_bmg(self, patch: "Patch", decoded_content: str) -> str:

View file

@ -7,6 +7,7 @@ from source.wt import bmg
if TYPE_CHECKING: if TYPE_CHECKING:
from source.mkw.Patch import Patch from source.mkw.Patch import Patch
from source import TemplateMultipleSafeEval
class FormatOriginalTrackLayer(AbstractLayer): class FormatOriginalTrackLayer(AbstractLayer):
@ -16,7 +17,7 @@ class FormatOriginalTrackLayer(AbstractLayer):
mode = "format-original-track" mode = "format-original-track"
def __init__(self, template: str): def __init__(self, template: "TemplateMultipleSafeEval"):
self.template = template self.template = template
def patch_bmg(self, patch: "Patch", decoded_content: str) -> str: def patch_bmg(self, patch: "Patch", decoded_content: str) -> str:

View file

@ -1,9 +1,11 @@
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from source.mkw.Patch.PatchOperation.BmgTxtEditor import AbstractLayer from source.mkw.Patch.PatchOperation.BmgTxtEditor import AbstractLayer
if TYPE_CHECKING: if TYPE_CHECKING:
from source.mkw.Patch import Patch from source.mkw.Patch import Patch
from source import TemplateMultipleSafeEval
class IDLayer(AbstractLayer): class IDLayer(AbstractLayer):
@ -13,7 +15,7 @@ class IDLayer(AbstractLayer):
mode = "id" mode = "id"
def __init__(self, template: dict[str, str]): def __init__(self, template: dict[str, "TemplateMultipleSafeEval"]):
self.template = template self.template = template
def patch_bmg(self, patch: "Patch", decoded_content: str) -> str: def patch_bmg(self, patch: "Patch", decoded_content: str) -> str:

View file

@ -5,6 +5,7 @@ from source.mkw.Patch.PatchOperation.BmgTxtEditor import AbstractLayer
if TYPE_CHECKING: if TYPE_CHECKING:
from source.mkw.Patch import Patch from source.mkw.Patch import Patch
from source import TemplateMultipleSafeEval
class RegexLayer(AbstractLayer): class RegexLayer(AbstractLayer):
@ -14,7 +15,7 @@ class RegexLayer(AbstractLayer):
mode = "regex" mode = "regex"
def __init__(self, template: dict[str, str]): def __init__(self, template: dict[str, "TemplateMultipleSafeEval"]):
self.template = template self.template = template
def patch_bmg(self, patch: "Patch", decoded_content: str) -> str: def patch_bmg(self, patch: "Patch", decoded_content: str) -> str:

View file

@ -2,8 +2,8 @@ from typing import TYPE_CHECKING
from PIL import Image from PIL import Image
from source.mkw.Patch import *
from source.mkw.Patch.PatchOperation.ImageEditor import AbstractLayer from source.mkw.Patch.PatchOperation.ImageEditor import AbstractLayer
from source.mkw.Patch import PathOutsidePatch
if TYPE_CHECKING: if TYPE_CHECKING:
from source.mkw.Patch import Patch from source.mkw.Patch import Patch

View file

@ -8,6 +8,7 @@ from source.mkw.Patch.PatchOperation.ImageEditor import AbstractLayer
if TYPE_CHECKING: if TYPE_CHECKING:
from source.mkw.Patch import Patch from source.mkw.Patch import Patch
from source import TemplateMultipleSafeEval
class TextLayer(AbstractLayer): class TextLayer(AbstractLayer):
@ -24,7 +25,7 @@ class TextLayer(AbstractLayer):
self.font_path: str | None = font_path self.font_path: str | None = font_path
self.font_size: int = font_size self.font_size: int = font_size
self.color: tuple[int] = tuple(color) self.color: tuple[int] = tuple(color)
self.text: str = text self.text: "TemplateMultipleSafeEval" = text
def patch_image(self, patch: "Patch", image: Image.Image) -> Image.Image: def patch_image(self, patch: "Patch", image: Image.Image) -> Image.Image:
draw = ImageDraw.Draw(image) draw = ImageDraw.Draw(image)

View file

@ -1,8 +1,11 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Generator from typing import Generator, TYPE_CHECKING
from source.mkw import Slot, Tag, ModConfig from source.mkw import Slot, Tag, ModConfig
if TYPE_CHECKING:
from source import TemplateMultipleSafeEval
class TrackForbiddenCustomAttribute(Exception): class TrackForbiddenCustomAttribute(Exception):
def __init__(self, attribute_name: str): def __init__(self, attribute_name: str):
@ -39,7 +42,7 @@ class AbstractTrack(ABC):
yield self yield self
@abstractmethod @abstractmethod
def repr_format(self, mod_config: "ModConfig", template: str) -> str: def repr_format(self, mod_config: "ModConfig", template: "TemplateMultipleSafeEval") -> str:
""" """
return the representation of the track from the format return the representation of the track from the format
:param template: template for the way the text will be represented :param template: template for the way the text will be represented
@ -66,7 +69,7 @@ class AbstractTrack(ABC):
""" """
... ...
def get_ctfile(self, mod_config: "ModConfig", template: str, hidden: bool = False) -> str: def get_ctfile(self, mod_config: "ModConfig", template: "TemplateMultipleSafeEval", hidden: bool = False) -> str:
""" """
return the ctfile of the track return the ctfile of the track
:hidden: if the track is in a group :hidden: if the track is in a group

View file

@ -1,6 +1,11 @@
from typing import TYPE_CHECKING
from source.mkw import Slot, Tag from source.mkw import Slot, Tag
from source.mkw.Track.RealArenaTrack import RealArenaTrack from source.mkw.Track.RealArenaTrack import RealArenaTrack
if TYPE_CHECKING:
from source import TemplateMultipleSafeEval
class ArenaForbiddenCustomAttribute(Exception): class ArenaForbiddenCustomAttribute(Exception):
def __init__(self, attribute_name: str): def __init__(self, attribute_name: str):
@ -32,7 +37,7 @@ class Arena(RealArenaTrack):
def from_dict(cls, arena_dict: dict[str, any]) -> "Arena": def from_dict(cls, arena_dict: dict[str, any]) -> "Arena":
return cls(**arena_dict) return cls(**arena_dict)
def get_ctfile(self, mod_config: "ModConfig", template: str) -> (str, str): def get_ctfile(self, mod_config: "ModConfig", template: "TemplateMultipleSafeEval") -> (str, str):
""" """
Return the ctfile for the arena and the redefinition of the slot property Return the ctfile for the arena and the redefinition of the slot property
:param mod_config: the mod_config object :param mod_config: the mod_config object

View file

@ -1,5 +1,10 @@
from typing import TYPE_CHECKING
from source.mkw.Track.AbstractTrack import AbstractTrack from source.mkw.Track.AbstractTrack import AbstractTrack
if TYPE_CHECKING:
from source.mkw.ModConfig import ModConfig
class DefaultTrack(AbstractTrack): class DefaultTrack(AbstractTrack):
def repr_format(self, mod_config: "ModConfig", template: str) -> str: def repr_format(self, mod_config: "ModConfig", template: str) -> str:

View file

@ -1,3 +1,10 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from source import TemplateMultipleSafeEval
from source.mkw.ModConfig import ModConfig
class RealArenaTrack: class RealArenaTrack:
""" """
class shared between all arena and track class that represent a "real" track or arena class shared between all arena and track class that represent a "real" track or arena
@ -18,7 +25,7 @@ class RealArenaTrack:
return mod_config.multiple_safe_eval(mod_config.tags_templates[template_name][tag], env={"tag": tag}) return mod_config.multiple_safe_eval(mod_config.tags_templates[template_name][tag], env={"tag": tag})
return default return default
def repr_format(self, mod_config: "ModConfig", template: str) -> str: def repr_format(self, mod_config: "ModConfig", template: "TemplateMultipleSafeEval") -> str:
return mod_config.multiple_safe_eval( return mod_config.multiple_safe_eval(
template, template,
env={ env={

View file

@ -8,10 +8,9 @@ ModConfig: any
class TrackGroup: class TrackGroup:
def __init__(self, tracks: list["Track"] = None, tags: list[Tag] = None, name: str = None): def __init__(self, tracks: list["Track"] = None, tags: list[Tag] = None):
self.tracks = tracks if tracks is not None else [] self.tracks = tracks if tracks is not None else []
self.tags = tags if tags is not None else [] self.tags = tags if tags is not None else []
self.name = name if name is not None else ""
def get_tracks(self) -> Generator["Track", None, None]: def get_tracks(self) -> Generator["Track", None, None]:
""" """
@ -34,7 +33,6 @@ class TrackGroup:
return cls( return cls(
tracks=[CustomTrack.from_dict(track) for track in group_dict["group"]], tracks=[CustomTrack.from_dict(track) for track in group_dict["group"]],
tags=group_dict.get("tags"), tags=group_dict.get("tags"),
name=group_dict.get("name"),
) )
def get_ctfile(self, mod_config: "ModConfig") -> str: def get_ctfile(self, mod_config: "ModConfig") -> str:

View file

@ -1,5 +1,8 @@
import re import re
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from source import TemplateSafeEval
MACRO_START, MACRO_END = "##", "##" MACRO_START, MACRO_END = "##", "##"
@ -9,7 +12,7 @@ class NotImplementedMacro(Exception):
super().__init__(f"Invalid macro while parsing macros:\n{macro}") super().__init__(f"Invalid macro while parsing macros:\n{macro}")
def replace_macro(template: str, macros: dict[str, str]) -> str: def replace_macro(template: str, macros: dict[str, "TemplateSafeEval"]) -> str:
""" """
Replace all the macro defined in macro by their respective value Replace all the macro defined in macro by their respective value
:param template: template where to replace the macro :param template: template where to replace the macro

View file

@ -1,12 +1,17 @@
import re import re
from typing import TYPE_CHECKING
from source.safe_eval import safe_eval from source.safe_eval import safe_eval
if TYPE_CHECKING:
from source import TemplateMultipleSafeEval, TemplateSafeEval, Env
TOKEN_START, TOKEN_END = "{{", "}}" TOKEN_START, TOKEN_END = "{{", "}}"
def multiple_safe_eval(template: str, env: dict[str, any] = None, macros: dict[str, str] = None) -> str: def multiple_safe_eval(template: "TemplateMultipleSafeEval", env: "Env" = None,
macros: dict[str, "TemplateSafeEval"] = None) -> str:
""" """
Similar to safe_eval, but expression need to be enclosed between "{{" and "}}". Similar to safe_eval, but expression need to be enclosed between "{{" and "}}".
Example : "{{ track.author }} is the track creator !" Example : "{{ track.author }} is the track creator !"

View file

@ -1,9 +1,13 @@
import ast import ast
import copy import copy
from typing import TYPE_CHECKING
from source.safe_eval.macros import replace_macro from source.safe_eval.macros import replace_macro
from source.safe_eval.safe_function import get_all_safe_functions from source.safe_eval.safe_function import get_all_safe_functions
if TYPE_CHECKING:
from source import TemplateSafeEval, Env
class SafeEvalException(Exception): class SafeEvalException(Exception):
def __init__(self, message: str): def __init__(self, message: str):
@ -19,7 +23,7 @@ all_globals = {
} }
def safe_eval(template: str, env: dict[str, any] = None, macros: dict[str, str] = None, def safe_eval(template: "TemplateSafeEval", env: "Env" = None, macros: dict[str, "TemplateSafeEval"] = None,
return_lambda: bool = False, lambda_args: list[str] = None) -> any: return_lambda: bool = False, lambda_args: list[str] = None) -> any:
""" """
Run a python code in an eval function, but avoid all potential dangerous function. Run a python code in an eval function, but avoid all potential dangerous function.

View file

@ -1,4 +1,7 @@
from typing import Callable, Generator from typing import Callable, Generator, TYPE_CHECKING
if TYPE_CHECKING:
from source import TemplateSafeEval, Env
def get_all_safe_functions() -> Generator[list[Callable], None, None]: def get_all_safe_functions() -> Generator[list[Callable], None, None]:
@ -41,7 +44,7 @@ class safe_function:
return type(obj) return type(obj)
@staticmethod @staticmethod
def eval(template: str, env: dict | None = None): def eval(template: "TemplateSafeEval", env: "Env | None" = None):
""" """
Allow a recursive safe_eval, but without the lambda functionality Allow a recursive safe_eval, but without the lambda functionality
""" """