safe_eval and multiple_safe_eval are now 20x faster (~2.5s -> ~0.13s)

This commit is contained in:
Faraphel 2022-08-15 10:58:36 +02:00
parent 89de19c723
commit 8afd7e7926
13 changed files with 78 additions and 62 deletions

View file

@ -420,7 +420,7 @@ class ButtonInstall(ttk.Button):
message = message_texts.get(self.root.options["language"])
if message is None: message = message_texts.get("*")
if message is None: message = _('NO_MESSAGE_FROM_AUTHOR')
message = mod_config.multiple_safe_eval(message)
message = mod_config.multiple_safe_eval(message)()
messagebox.showinfo(
_("INSTALLATION_COMPLETED"),

View file

@ -78,10 +78,7 @@ class Window(AbstractPreviewWindow):
self.text_track_select.configure(state=tkinter.NORMAL)
self.text_track_select.delete(1.0, tkinter.END)
template_func = self.mod_config.safe_eval(
self.entry_template_input.get(),
return_lambda=True, lambda_args=["track"]
)
template_func = self.mod_config.safe_eval(self.entry_template_input.get(), args=["track"])
for track in self.mod_config.get_all_tracks(ignore_filter=True):
value: bool = template_func(track) is True

View file

@ -127,7 +127,7 @@ class ExtractedGame:
ct_file.write_text(mod_config.get_ctfile(template="-"))
lpar_dir: Path = mod_config.path.parent / "_LPAR/"
lpar: Path = lpar_dir / mod_config.multiple_safe_eval(mod_config.lpar_template)
lpar: Path = lpar_dir / mod_config.multiple_safe_eval(mod_config.lpar_template)()
if not lpar.is_relative_to(lpar_dir): raise PathOutsideMod(lpar, lpar_dir)
for lecode_file in (self.path / "files/rel/").glob("lecode-*.bin"):

View file

@ -187,14 +187,14 @@ class ModConfig:
base_env if base_env is not None else {}
)
def safe_eval(self, *args, env: "Env" = None, **kwargs) -> any:
def safe_eval(self, *args, env: "Env" = None, **kwargs) -> Callable:
"""
Safe eval with useful modconfig environment
:return: the result of the evaluation
"""
return safe_eval(*args, env=self.get_safe_eval_env(base_env=env), macros=self.macros, **kwargs)
def multiple_safe_eval(self, *args, env: "Env" = None, **kwargs) -> str:
def multiple_safe_eval(self, *args, env: "Env" = None, **kwargs) -> Callable:
"""
Multiple safe eval with useful modconfig environment
:return: the str result of the evaluation
@ -239,9 +239,7 @@ class ModConfig:
# filter_template_func is the function checking if the track should be included. If no parameter is set,
# then always return True
filter_template_func: Callable = self.safe_eval(
filter_template if filter_template is not None else "True",
return_lambda=True,
lambda_args=["track"]
filter_template if filter_template is not None else "True", args=["track"]
)
# if a sorting function is set, use it. If no function is set, but sorting is not disabled, use settings.
@ -249,9 +247,7 @@ class ModConfig:
if not ignore_sorting and (sorting_template is not None or settings_sort is not None):
# get the sorting_template_func. If not defined, use the settings one.
sorting_template_func: Callable = self.safe_eval(
template=sorting_template if sorting_template is not None else settings_sort,
return_lambda=True,
lambda_args=["track"]
template=sorting_template if sorting_template is not None else settings_sort, args=["track"]
)
# wrap the iterator inside a sort function
@ -435,8 +431,7 @@ class ModConfig:
track_directory = self.path.parent / "_TRACKS"
multiplayer_disable_if_func: Callable = self.safe_eval(
self.multiplayer_disable_if,
return_lambda=True, lambda_args=["track"]
self.multiplayer_disable_if, args=["track"]
)
for track in self.get_all_arenas_tracks():

View file

@ -77,4 +77,7 @@ class PatchObject(ABC):
:param extracted_game: the extracted game object
:return: should the patch be applied ?
"""
return self.patch.mod_config.safe_eval(self.configuration["if"], env={"extracted_game": extracted_game}) is True
return self.patch.mod_config.safe_eval(
self.configuration["if"],
env={"extracted_game": extracted_game}
)() is True

View file

@ -20,6 +20,6 @@ class IDLayer(AbstractLayer):
def patch_bmg(self, patch: "Patch", decoded_content: str) -> str:
return decoded_content + "\n" + ("\n".join(
[f" {id}\t= {patch.mod_config.multiple_safe_eval(repl)}" for id, repl in self.template.items()]
[f" {id}\t= {patch.mod_config.multiple_safe_eval(repl)()}" for id, repl in self.template.items()]
)) + "\n"
# add new bmg definition at the end of the bmg file, overwritting old id.

View file

@ -30,7 +30,7 @@ class RegexLayer(AbstractLayer):
for pattern, repl in self.template.items():
value = re.sub(
pattern,
patch.mod_config.multiple_safe_eval(repl),
patch.mod_config.multiple_safe_eval(repl)(),
value,
flags=re.DOTALL
)

View file

@ -45,7 +45,7 @@ class TextLayer(AbstractLayer):
)
draw.text(
self.get_layer_position(image),
text=patch.mod_config.multiple_safe_eval(self.text),
text=patch.mod_config.multiple_safe_eval(self.text)(),
fill=self.color,
font=font
)

View file

@ -24,4 +24,7 @@ class CustomTrack(RealArenaTrack, AbstractTrack):
return cls(**track_dict)
def is_new(self, mod_config: "ModConfig") -> bool:
return mod_config.safe_eval(mod_config.global_settings["replace_random_new"].value, env={"track": self}) is True
return mod_config.safe_eval(
mod_config.global_settings["replace_random_new"].value,
args=["track"]
)(track=self) is True

View file

@ -22,16 +22,20 @@ class RealArenaTrack:
:return: formatted representation of the tag
"""
for tag in filter(lambda tag: tag in mod_config.tags_templates[template_name], self.tags):
return mod_config.multiple_safe_eval(mod_config.tags_templates[template_name][tag], env={"tag": tag})
return mod_config.multiple_safe_eval(
mod_config.tags_templates[template_name][tag],
args=["tag"],
)(tag=tag)
return default
def repr_format(self, mod_config: "ModConfig", template: "TemplateMultipleSafeEval") -> str:
return mod_config.multiple_safe_eval(
template,
env={
"track": self,
"get_tag_template": lambda *args, **kwargs: self.get_tag_template(mod_config, *args, **kwargs)
}
args=["track", "get_tag_template"]
)(
track=self,
get_tag_template=lambda *args, **kwargs: self.get_tag_template(mod_config, *args, **kwargs)
# get_tag_template can't be in env because it is dependent of the track self
)
def get_filename(self, mod_config: "ModConfig") -> str:

View file

@ -1,5 +1,4 @@
import re
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Iterable, Callable
from source.safe_eval import safe_eval
@ -11,21 +10,39 @@ TOKEN_START, TOKEN_END = "{{", "}}"
def multiple_safe_eval(template: "TemplateMultipleSafeEval", env: "Env" = None,
macros: dict[str, "TemplateSafeEval"] = None) -> str:
macros: dict[str, "TemplateSafeEval"] = None, args: Iterable[str] = None) -> Callable:
"""
Similar to safe_eval, but expression need to be enclosed between "{{" and "}}".
Example : "{{ track.author }} is the track creator !"
"""
def format_part_template(match: re.Match) -> str:
"""
when a token is found, replace it by the corresponding value
:param match: match in the format
:return: corresponding value
"""
# get the token string without the brackets, then strip it. Also double antislash
part_template = match.group(1).strip()
return str(safe_eval(template=part_template, env=env, macros=macros))
lambda_templates: list[str | Callable] = []
while len(template) > 0:
token_start: int = template.find(TOKEN_START) # the position of the "{{"
part_template_start: int = token_start + len(TOKEN_START) # the position just after the start token
part_template_end: int = template.find(TOKEN_END) # the position before the end token
token_end: int = part_template_end + len(TOKEN_END) # the end position of the "}}"
# if there is no more template, add all the template into the lambda
if token_start < 0 or part_template_end < 0:
lambda_templates.append(template)
template = ""
# if there is still a template part, add the remaining text, then add the lambda template between the tokens.
else:
lambda_templates.append(template[:token_start])
lambda_templates.append(safe_eval(
template=template[part_template_start:part_template_end].strip(),
env=env,
macros=macros,
args=args,
))
template = template[token_end:]
return lambda *args, **kwargs: "".join([
str(part(*args, **kwargs)) if callable(part) else part
for part in lambda_templates
])
# pass everything between TOKEN_START and TOKEN_END in the function
return re.sub(rf"{TOKEN_START}(.*?){TOKEN_END}", format_part_template, template)

View file

@ -1,6 +1,6 @@
import ast
import copy
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Iterable, Callable
from source.safe_eval.macros import replace_macro
from source.safe_eval.safe_function import get_all_safe_functions
@ -28,29 +28,27 @@ all_globals = {
def safe_eval(template: "TemplateSafeEval", env: "Env" = None, macros: dict[str, "TemplateSafeEval"] = None,
return_lambda: bool = False, lambda_args: list[str] = None) -> any:
args: Iterable[str] = None) -> Callable:
"""
Run a python code in an eval function, but avoid all potential dangerous function.
:env: additional variables that will be used when evaluating the template
:return_lambda: if enabled, return a lambda function instead of the result of the expression
:lambda_args: arguments that the final lambda function can receive
:macros: dictionary associating a macro name to a macro value
:return: the evaluated expression or the lambda expression
:return: the lambda expression
"""
if len(template) == 0: return ""
if len(template) == 0: return lambda *_, **__: ""
if env is None: env = {}
if macros is None: macros = {}
if lambda_args is None: lambda_args = []
args = tuple(args) if args is not None else () # allow the argument to be any iterable
template_key: tuple = (template, args) # unique identifiant for every template (need to be hashable)
# if the safe_eval return a callable and have already been called, return the cached callable
if return_lambda is True and template in self.safe_eval_cache:
return self.safe_eval_cache[template]
if template_key in self.safe_eval_cache: return self.safe_eval_cache[template_key]
# replace the macro in the template
template = replace_macro(template=template, macros=macros)
# escape backslash to avoid unreadable expression
template = template.replace("\\", "\\\\")
@ -113,16 +111,15 @@ def safe_eval(template: "TemplateSafeEval", env: "Env" = None, macros: dict[str,
):
raise SafeEvalException(f'Forbidden syntax : "{type(node).__name__}"')
if return_lambda:
# if return_lambda is enabled, embed the whole expression into a lambda expression
stmt.value = ast.Lambda(
body=stmt.value,
args=ast.arguments(
args=[ast.arg(arg=lambda_arg) for lambda_arg in lambda_args],
posonlyargs=[], kwonlyargs=[],
kw_defaults=[], defaults=[],
)
# embed the whole expression into a lambda expression
stmt.value = ast.Lambda(
body=stmt.value,
args=ast.arguments(
args=[ast.arg(arg=arg) for arg in args],
posonlyargs=[], kwonlyargs=[],
kw_defaults=[], defaults=[],
)
)
# convert into a ast.Expression, object needed for the compilation
expression: ast.Expression = ast.Expression(stmt.value)
@ -131,6 +128,6 @@ def safe_eval(template: "TemplateSafeEval", env: "Env" = None, macros: dict[str,
ast.fix_missing_locations(expression)
# return the evaluated formula
result = eval(compile(expression, "<safe_eval>", "eval"), globals_, locals_)
if return_lambda: self.safe_eval_cache[template] = result # cache the callable for potential latter call
return result
lambda_template = eval(compile(expression, "<safe_eval>", "eval"), globals_, locals_)
self.safe_eval_cache[template_key] = lambda_template # cache the callable for potential latter call
return lambda_template

View file

@ -49,4 +49,4 @@ class safe_function:
Allow a recursive safe_eval, but without the lambda functionality
"""
from source.safe_eval.safe_eval import safe_eval
return safe_eval(template=template, env=env)
return safe_eval(template=template, env=env)()