mirror of
https://github.com/Faraphel/Atlas-Install.git
synced 2025-07-02 18:58:27 +02:00
safe_eval and multiple_safe_eval are now 20x faster (~2.5s -> ~0.13s)
This commit is contained in:
parent
89de19c723
commit
8afd7e7926
13 changed files with 78 additions and 62 deletions
|
@ -420,7 +420,7 @@ class ButtonInstall(ttk.Button):
|
|||
message = message_texts.get(self.root.options["language"])
|
||||
if message is None: message = message_texts.get("*")
|
||||
if message is None: message = _('NO_MESSAGE_FROM_AUTHOR')
|
||||
message = mod_config.multiple_safe_eval(message)
|
||||
message = mod_config.multiple_safe_eval(message)()
|
||||
|
||||
messagebox.showinfo(
|
||||
_("INSTALLATION_COMPLETED"),
|
||||
|
|
|
@ -78,10 +78,7 @@ class Window(AbstractPreviewWindow):
|
|||
self.text_track_select.configure(state=tkinter.NORMAL)
|
||||
self.text_track_select.delete(1.0, tkinter.END)
|
||||
|
||||
template_func = self.mod_config.safe_eval(
|
||||
self.entry_template_input.get(),
|
||||
return_lambda=True, lambda_args=["track"]
|
||||
)
|
||||
template_func = self.mod_config.safe_eval(self.entry_template_input.get(), args=["track"])
|
||||
|
||||
for track in self.mod_config.get_all_tracks(ignore_filter=True):
|
||||
value: bool = template_func(track) is True
|
||||
|
|
|
@ -127,7 +127,7 @@ class ExtractedGame:
|
|||
ct_file.write_text(mod_config.get_ctfile(template="-"))
|
||||
|
||||
lpar_dir: Path = mod_config.path.parent / "_LPAR/"
|
||||
lpar: Path = lpar_dir / mod_config.multiple_safe_eval(mod_config.lpar_template)
|
||||
lpar: Path = lpar_dir / mod_config.multiple_safe_eval(mod_config.lpar_template)()
|
||||
if not lpar.is_relative_to(lpar_dir): raise PathOutsideMod(lpar, lpar_dir)
|
||||
|
||||
for lecode_file in (self.path / "files/rel/").glob("lecode-*.bin"):
|
||||
|
|
|
@ -187,14 +187,14 @@ class ModConfig:
|
|||
base_env if base_env is not None else {}
|
||||
)
|
||||
|
||||
def safe_eval(self, *args, env: "Env" = None, **kwargs) -> any:
|
||||
def safe_eval(self, *args, env: "Env" = None, **kwargs) -> Callable:
|
||||
"""
|
||||
Safe eval with useful modconfig environment
|
||||
:return: the result of the evaluation
|
||||
"""
|
||||
return safe_eval(*args, env=self.get_safe_eval_env(base_env=env), macros=self.macros, **kwargs)
|
||||
|
||||
def multiple_safe_eval(self, *args, env: "Env" = None, **kwargs) -> str:
|
||||
def multiple_safe_eval(self, *args, env: "Env" = None, **kwargs) -> Callable:
|
||||
"""
|
||||
Multiple safe eval with useful modconfig environment
|
||||
:return: the str result of the evaluation
|
||||
|
@ -239,9 +239,7 @@ class ModConfig:
|
|||
# filter_template_func is the function checking if the track should be included. If no parameter is set,
|
||||
# then always return True
|
||||
filter_template_func: Callable = self.safe_eval(
|
||||
filter_template if filter_template is not None else "True",
|
||||
return_lambda=True,
|
||||
lambda_args=["track"]
|
||||
filter_template if filter_template is not None else "True", args=["track"]
|
||||
)
|
||||
|
||||
# if a sorting function is set, use it. If no function is set, but sorting is not disabled, use settings.
|
||||
|
@ -249,9 +247,7 @@ class ModConfig:
|
|||
if not ignore_sorting and (sorting_template is not None or settings_sort is not None):
|
||||
# get the sorting_template_func. If not defined, use the settings one.
|
||||
sorting_template_func: Callable = self.safe_eval(
|
||||
template=sorting_template if sorting_template is not None else settings_sort,
|
||||
return_lambda=True,
|
||||
lambda_args=["track"]
|
||||
template=sorting_template if sorting_template is not None else settings_sort, args=["track"]
|
||||
)
|
||||
|
||||
# wrap the iterator inside a sort function
|
||||
|
@ -435,8 +431,7 @@ class ModConfig:
|
|||
|
||||
track_directory = self.path.parent / "_TRACKS"
|
||||
multiplayer_disable_if_func: Callable = self.safe_eval(
|
||||
self.multiplayer_disable_if,
|
||||
return_lambda=True, lambda_args=["track"]
|
||||
self.multiplayer_disable_if, args=["track"]
|
||||
)
|
||||
|
||||
for track in self.get_all_arenas_tracks():
|
||||
|
|
|
@ -77,4 +77,7 @@ class PatchObject(ABC):
|
|||
:param extracted_game: the extracted game object
|
||||
:return: should the patch be applied ?
|
||||
"""
|
||||
return self.patch.mod_config.safe_eval(self.configuration["if"], env={"extracted_game": extracted_game}) is True
|
||||
return self.patch.mod_config.safe_eval(
|
||||
self.configuration["if"],
|
||||
env={"extracted_game": extracted_game}
|
||||
)() is True
|
||||
|
|
|
@ -20,6 +20,6 @@ class IDLayer(AbstractLayer):
|
|||
|
||||
def patch_bmg(self, patch: "Patch", decoded_content: str) -> str:
|
||||
return decoded_content + "\n" + ("\n".join(
|
||||
[f" {id}\t= {patch.mod_config.multiple_safe_eval(repl)}" for id, repl in self.template.items()]
|
||||
[f" {id}\t= {patch.mod_config.multiple_safe_eval(repl)()}" for id, repl in self.template.items()]
|
||||
)) + "\n"
|
||||
# add new bmg definition at the end of the bmg file, overwritting old id.
|
||||
|
|
|
@ -30,7 +30,7 @@ class RegexLayer(AbstractLayer):
|
|||
for pattern, repl in self.template.items():
|
||||
value = re.sub(
|
||||
pattern,
|
||||
patch.mod_config.multiple_safe_eval(repl),
|
||||
patch.mod_config.multiple_safe_eval(repl)(),
|
||||
value,
|
||||
flags=re.DOTALL
|
||||
)
|
||||
|
|
|
@ -45,7 +45,7 @@ class TextLayer(AbstractLayer):
|
|||
)
|
||||
draw.text(
|
||||
self.get_layer_position(image),
|
||||
text=patch.mod_config.multiple_safe_eval(self.text),
|
||||
text=patch.mod_config.multiple_safe_eval(self.text)(),
|
||||
fill=self.color,
|
||||
font=font
|
||||
)
|
||||
|
|
|
@ -24,4 +24,7 @@ class CustomTrack(RealArenaTrack, AbstractTrack):
|
|||
return cls(**track_dict)
|
||||
|
||||
def is_new(self, mod_config: "ModConfig") -> bool:
|
||||
return mod_config.safe_eval(mod_config.global_settings["replace_random_new"].value, env={"track": self}) is True
|
||||
return mod_config.safe_eval(
|
||||
mod_config.global_settings["replace_random_new"].value,
|
||||
args=["track"]
|
||||
)(track=self) is True
|
||||
|
|
|
@ -22,16 +22,20 @@ class RealArenaTrack:
|
|||
:return: formatted representation of the tag
|
||||
"""
|
||||
for tag in filter(lambda tag: tag in mod_config.tags_templates[template_name], self.tags):
|
||||
return mod_config.multiple_safe_eval(mod_config.tags_templates[template_name][tag], env={"tag": tag})
|
||||
return mod_config.multiple_safe_eval(
|
||||
mod_config.tags_templates[template_name][tag],
|
||||
args=["tag"],
|
||||
)(tag=tag)
|
||||
return default
|
||||
|
||||
def repr_format(self, mod_config: "ModConfig", template: "TemplateMultipleSafeEval") -> str:
|
||||
return mod_config.multiple_safe_eval(
|
||||
template,
|
||||
env={
|
||||
"track": self,
|
||||
"get_tag_template": lambda *args, **kwargs: self.get_tag_template(mod_config, *args, **kwargs)
|
||||
}
|
||||
args=["track", "get_tag_template"]
|
||||
)(
|
||||
track=self,
|
||||
get_tag_template=lambda *args, **kwargs: self.get_tag_template(mod_config, *args, **kwargs)
|
||||
# get_tag_template can't be in env because it is dependent of the track self
|
||||
)
|
||||
|
||||
def get_filename(self, mod_config: "ModConfig") -> str:
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Iterable, Callable
|
||||
|
||||
from source.safe_eval import safe_eval
|
||||
|
||||
|
@ -11,21 +10,39 @@ TOKEN_START, TOKEN_END = "{{", "}}"
|
|||
|
||||
|
||||
def multiple_safe_eval(template: "TemplateMultipleSafeEval", env: "Env" = None,
|
||||
macros: dict[str, "TemplateSafeEval"] = None) -> str:
|
||||
macros: dict[str, "TemplateSafeEval"] = None, args: Iterable[str] = None) -> Callable:
|
||||
"""
|
||||
Similar to safe_eval, but expression need to be enclosed between "{{" and "}}".
|
||||
Example : "{{ track.author }} is the track creator !"
|
||||
"""
|
||||
|
||||
def format_part_template(match: re.Match) -> str:
|
||||
"""
|
||||
when a token is found, replace it by the corresponding value
|
||||
:param match: match in the format
|
||||
:return: corresponding value
|
||||
"""
|
||||
# get the token string without the brackets, then strip it. Also double antislash
|
||||
part_template = match.group(1).strip()
|
||||
return str(safe_eval(template=part_template, env=env, macros=macros))
|
||||
lambda_templates: list[str | Callable] = []
|
||||
|
||||
while len(template) > 0:
|
||||
token_start: int = template.find(TOKEN_START) # the position of the "{{"
|
||||
part_template_start: int = token_start + len(TOKEN_START) # the position just after the start token
|
||||
part_template_end: int = template.find(TOKEN_END) # the position before the end token
|
||||
token_end: int = part_template_end + len(TOKEN_END) # the end position of the "}}"
|
||||
|
||||
# if there is no more template, add all the template into the lambda
|
||||
if token_start < 0 or part_template_end < 0:
|
||||
lambda_templates.append(template)
|
||||
template = ""
|
||||
|
||||
# if there is still a template part, add the remaining text, then add the lambda template between the tokens.
|
||||
else:
|
||||
lambda_templates.append(template[:token_start])
|
||||
lambda_templates.append(safe_eval(
|
||||
template=template[part_template_start:part_template_end].strip(),
|
||||
env=env,
|
||||
macros=macros,
|
||||
args=args,
|
||||
))
|
||||
template = template[token_end:]
|
||||
|
||||
return lambda *args, **kwargs: "".join([
|
||||
str(part(*args, **kwargs)) if callable(part) else part
|
||||
for part in lambda_templates
|
||||
])
|
||||
|
||||
|
||||
# pass everything between TOKEN_START and TOKEN_END in the function
|
||||
return re.sub(rf"{TOKEN_START}(.*?){TOKEN_END}", format_part_template, template)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import ast
|
||||
import copy
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Iterable, Callable
|
||||
|
||||
from source.safe_eval.macros import replace_macro
|
||||
from source.safe_eval.safe_function import get_all_safe_functions
|
||||
|
@ -28,29 +28,27 @@ all_globals = {
|
|||
|
||||
|
||||
def safe_eval(template: "TemplateSafeEval", env: "Env" = None, macros: dict[str, "TemplateSafeEval"] = None,
|
||||
return_lambda: bool = False, lambda_args: list[str] = None) -> any:
|
||||
args: Iterable[str] = None) -> Callable:
|
||||
"""
|
||||
Run a python code in an eval function, but avoid all potential dangerous function.
|
||||
:env: additional variables that will be used when evaluating the template
|
||||
:return_lambda: if enabled, return a lambda function instead of the result of the expression
|
||||
:lambda_args: arguments that the final lambda function can receive
|
||||
:macros: dictionary associating a macro name to a macro value
|
||||
|
||||
:return: the evaluated expression or the lambda expression
|
||||
:return: the lambda expression
|
||||
"""
|
||||
|
||||
if len(template) == 0: return ""
|
||||
if len(template) == 0: return lambda *_, **__: ""
|
||||
if env is None: env = {}
|
||||
if macros is None: macros = {}
|
||||
if lambda_args is None: lambda_args = []
|
||||
args = tuple(args) if args is not None else () # allow the argument to be any iterable
|
||||
|
||||
template_key: tuple = (template, args) # unique identifiant for every template (need to be hashable)
|
||||
# if the safe_eval return a callable and have already been called, return the cached callable
|
||||
if return_lambda is True and template in self.safe_eval_cache:
|
||||
return self.safe_eval_cache[template]
|
||||
if template_key in self.safe_eval_cache: return self.safe_eval_cache[template_key]
|
||||
|
||||
# replace the macro in the template
|
||||
template = replace_macro(template=template, macros=macros)
|
||||
|
||||
# escape backslash to avoid unreadable expression
|
||||
template = template.replace("\\", "\\\\")
|
||||
|
||||
|
@ -113,16 +111,15 @@ def safe_eval(template: "TemplateSafeEval", env: "Env" = None, macros: dict[str,
|
|||
):
|
||||
raise SafeEvalException(f'Forbidden syntax : "{type(node).__name__}"')
|
||||
|
||||
if return_lambda:
|
||||
# if return_lambda is enabled, embed the whole expression into a lambda expression
|
||||
stmt.value = ast.Lambda(
|
||||
body=stmt.value,
|
||||
args=ast.arguments(
|
||||
args=[ast.arg(arg=lambda_arg) for lambda_arg in lambda_args],
|
||||
posonlyargs=[], kwonlyargs=[],
|
||||
kw_defaults=[], defaults=[],
|
||||
)
|
||||
# embed the whole expression into a lambda expression
|
||||
stmt.value = ast.Lambda(
|
||||
body=stmt.value,
|
||||
args=ast.arguments(
|
||||
args=[ast.arg(arg=arg) for arg in args],
|
||||
posonlyargs=[], kwonlyargs=[],
|
||||
kw_defaults=[], defaults=[],
|
||||
)
|
||||
)
|
||||
|
||||
# convert into a ast.Expression, object needed for the compilation
|
||||
expression: ast.Expression = ast.Expression(stmt.value)
|
||||
|
@ -131,6 +128,6 @@ def safe_eval(template: "TemplateSafeEval", env: "Env" = None, macros: dict[str,
|
|||
ast.fix_missing_locations(expression)
|
||||
|
||||
# return the evaluated formula
|
||||
result = eval(compile(expression, "<safe_eval>", "eval"), globals_, locals_)
|
||||
if return_lambda: self.safe_eval_cache[template] = result # cache the callable for potential latter call
|
||||
return result
|
||||
lambda_template = eval(compile(expression, "<safe_eval>", "eval"), globals_, locals_)
|
||||
self.safe_eval_cache[template_key] = lambda_template # cache the callable for potential latter call
|
||||
return lambda_template
|
||||
|
|
|
@ -49,4 +49,4 @@ class safe_function:
|
|||
Allow a recursive safe_eval, but without the lambda functionality
|
||||
"""
|
||||
from source.safe_eval.safe_eval import safe_eval
|
||||
return safe_eval(template=template, env=env)
|
||||
return safe_eval(template=template, env=env)()
|
||||
|
|
Loading…
Reference in a new issue