mirror of
https://github.com/Faraphel/Atlas-Install.git
synced 2025-07-02 18:58:27 +02:00
Replaced the old safe_eval with the new safe_eval (in tests, normal mode went from ~26s to ~13s, and lambda mode from ~26s to ~0.03s)
This commit is contained in:
parent
998d1274ef
commit
42fef0b2e3
6 changed files with 176 additions and 278 deletions
|
@ -1,160 +0,0 @@
|
|||
import re
|
||||
from typing import Callable
|
||||
|
||||
common_token_map = { # these operators and function are considered safe to use in the template
|
||||
operator: operator
|
||||
for operator in
|
||||
[">=", "<=", "<<", ">>", "+", "-", "*", "/", "%", "**", ",", "(", ")", "[", "]", "==", "!=", "in", ">", "<",
|
||||
"and", "or", "&", "|", "^", "~", ":", "{", "}", "abs", "int",
|
||||
"bin", "hex", "oct", "chr", "ord", "len", "str", "bool", "float", "round", "min", "max", "sum", "zip",
|
||||
"any", "all", "reversed", "enumerate", "list", "sorted", "hasattr", "for", "range", "type", "repr", "None",
|
||||
"True", "False", "getattr", "dict"
|
||||
]
|
||||
} | { # these methods are considered safe, except for the magic methods
|
||||
f".{method}": f".{method}"
|
||||
for method in dir(str) + dir(list) + dir(int) + dir(float) + dir(dict)
|
||||
if not method.startswith("__")
|
||||
}
|
||||
|
||||
TOKEN_START, TOKEN_END = "{{", "}}"
|
||||
MACRO_START, MACRO_END = "##", "##"
|
||||
|
||||
|
||||
class TemplateParsingError(Exception):
|
||||
def __init__(self, token: str):
|
||||
super().__init__(f"Invalid token while parsing safe_eval:\n{token}")
|
||||
|
||||
|
||||
class NotImplementedMacro(Exception):
|
||||
def __init__(self, macro: str):
|
||||
super().__init__(f"Invalid macro while parsing macros:\n{macro}")
|
||||
|
||||
|
||||
class SafeFunction:
|
||||
@classmethod
|
||||
def get_all_safe_methods(cls) -> dict[str, Callable]:
|
||||
"""
|
||||
get all the safe methods defined by the class
|
||||
:return: all the safe methods defined by the class
|
||||
"""
|
||||
return {
|
||||
method: getattr(cls, method)
|
||||
for method in dir(cls)
|
||||
if not method.startswith("__") and method not in ["get_all_safe_methods"]
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def getattr(obj: any, attr: str, default: any = None) -> any:
|
||||
"""
|
||||
Safe getattr, raise an error if the attribute is a function
|
||||
:param obj: object to get the attribute from
|
||||
:param attr: attribute name
|
||||
:param default: default value if the attribute is not found
|
||||
:return: the attribute value
|
||||
"""
|
||||
attr = getattr(obj, attr) if default is None else getattr(obj, attr, default)
|
||||
if callable(attr): raise AttributeError(f"getattr can't be used for functions (tried: {attr})")
|
||||
return attr
|
||||
|
||||
@staticmethod
|
||||
def type(obj: any) -> any:
|
||||
"""
|
||||
Safe type, can only be used to determinate the type of an object
|
||||
(It can be used to create new class by using all the 3 args)
|
||||
"""
|
||||
return type(obj)
|
||||
|
||||
|
||||
def replace_macro(template: str, macros: dict[str, str]) -> str:
|
||||
"""
|
||||
Replace all the macro defined in macro by their respective value
|
||||
:param template: template where to replace the macro
|
||||
:param macros: dictionary associating macro with their replacement
|
||||
:return: the template with macro replaced
|
||||
"""
|
||||
|
||||
def format_macro(match: re.Match) -> str:
|
||||
if (macro := macros.get(match.group(1).strip())) is None: raise NotImplementedMacro(macro)
|
||||
return macro
|
||||
|
||||
# match everything between MACRO_START and MACRO_END.
|
||||
return re.sub(rf"{MACRO_START}(.*?){MACRO_END}", format_macro, template)
|
||||
|
||||
|
||||
def safe_eval(template: str, env: dict[str, any] = None, macros: dict[str, str] = None) -> str:
|
||||
"""
|
||||
Evaluate the template and return the result in a safe way
|
||||
:param env: variables to use when using eval
|
||||
:param template: template to evaluate
|
||||
:param macros: additional macro to replace in the template
|
||||
"""
|
||||
if env is None: env = {}
|
||||
if macros is None: macros = {}
|
||||
|
||||
template = replace_macro(template, macros)
|
||||
token_map: dict[str, str] = common_token_map | {var: var for var in env}
|
||||
final_token: str = ""
|
||||
|
||||
def matched(match: re.Match | str | None, value: str = None) -> bool:
|
||||
"""
|
||||
check if token is matched, if yes, add it to the final token and remove it from the processing token
|
||||
:param match: match object
|
||||
:param value: if the match is a string, the value to replace the text with
|
||||
:return: True if matched, False otherwise
|
||||
"""
|
||||
nonlocal final_token, template
|
||||
|
||||
# if there is no match or the string is empty, return False
|
||||
if not match: return False
|
||||
|
||||
if isinstance(match, re.Match):
|
||||
template_raw = template[match.end():]
|
||||
value = match.group()
|
||||
|
||||
else:
|
||||
if not template.startswith(match): return False
|
||||
template_raw = template[len(match):]
|
||||
|
||||
template = template_raw.lstrip()
|
||||
final_token += value + (len(template_raw) - len(template)) * " "
|
||||
return True
|
||||
|
||||
while template: # while there is still tokens to process
|
||||
# if the section is a string, add it to the final token
|
||||
# example : "hello", "hello \" world"
|
||||
if matched(re.match(r'^(["\'])((\\{2})*|(.*?[^\\](\\{2})*))\1', template)):
|
||||
continue
|
||||
|
||||
# if the section is a float or an int, add it to the final token
|
||||
# example : 102, 102.59
|
||||
if matched(re.match(r'^[0-9]+(?:\.[0-9]+)?', template)):
|
||||
continue
|
||||
|
||||
# if the section is a variable, operator or function, replace it by its value
|
||||
# example : track.special, +
|
||||
for key, value in token_map.items():
|
||||
if matched(key, value): break
|
||||
|
||||
# else, the token is invalid, so raise an error
|
||||
else:
|
||||
raise TemplateParsingError(template)
|
||||
|
||||
# if final_token is set, eval final_token and return the result
|
||||
if final_token: return str(eval(final_token.replace("\\", "\\\\"), SafeFunction.get_all_safe_methods(), env))
|
||||
else: return final_token
|
||||
|
||||
|
||||
def multiple_safe_eval(template: str, env: dict[str, any] = None, macros: dict[str, str] = None) -> str:
|
||||
def format_part_template(match: re.Match) -> str:
|
||||
"""
|
||||
when a token is found, replace it by the corresponding value
|
||||
:param match: match in the format
|
||||
:return: corresponding value
|
||||
"""
|
||||
# get the token string without the brackets, then strip it. Also double antislash
|
||||
part_template = match.group(1).strip()
|
||||
return safe_eval(part_template, env, macros)
|
||||
|
||||
# pass everything between TOKEN_START and TOKEN_END in the function
|
||||
return re.sub(rf"{TOKEN_START}(.*?){TOKEN_END}", format_part_template, template)
|
||||
|
|
@ -1,112 +1,2 @@
|
|||
import ast
|
||||
import copy
|
||||
|
||||
from source.safe_eval.safe_function import get_all_safe_functions
|
||||
|
||||
|
||||
# dict of every value used by every safe_eval call
|
||||
all_globals = {
|
||||
"__builtins__": {},
|
||||
"deepcopy": copy.deepcopy
|
||||
} | {
|
||||
func.__name__: func for func in get_all_safe_functions()
|
||||
}
|
||||
|
||||
|
||||
def safe_eval(template: str, env: dict[str, any] = None,
|
||||
return_lambda: bool = False, lambda_args: list[str] = None) -> any:
|
||||
"""
|
||||
Run a python code in an eval function, but avoid all potential dangerous function.
|
||||
:env: additional variables that will be used when evaluating the template
|
||||
:return_lambda: if enabled, return a lambda function instead of the result of the expression
|
||||
:lambda_args: arguments that the final lambda function can receive
|
||||
|
||||
:return: the evaluated expression or the lambda expression
|
||||
"""
|
||||
|
||||
if len(template) == 0: return ""
|
||||
if env is None: env = {}
|
||||
|
||||
# prepare the execution environment
|
||||
globals_ = all_globals | env
|
||||
locals_ = {}
|
||||
|
||||
# convert the template to an ast expression
|
||||
stmt: ast.stmt = ast.parse(template).body[0]
|
||||
if not isinstance(stmt, ast.Expr):
|
||||
raise Exception(f'Invalid ast type for safe_eval : "{type(stmt).__name__}"')
|
||||
|
||||
# check every node for disabled expression
|
||||
for node in ast.walk(stmt):
|
||||
match type(node):
|
||||
|
||||
# when accessing any attribute
|
||||
case ast.Attribute:
|
||||
# ban all magical function, disabling the __class__.__bases__[0] ... tricks
|
||||
if "__" in node.attr:
|
||||
raise Exception(f'Magic attribute are forbidden : "{node.attr}"')
|
||||
|
||||
# ban modification to environment
|
||||
if isinstance(node.ctx, ast.Store):
|
||||
raise Exception(f'Can\'t set value of attribute : "{node.attr}"')
|
||||
|
||||
# when accessing any variable
|
||||
case ast.Name:
|
||||
# ban modification to environment, but allow custom variable to be changed
|
||||
if isinstance(node.ctx, ast.Store) and node.id in globals_ | locals_:
|
||||
raise Exception(f'Can\'t set value of environment : "{node.id}"')
|
||||
|
||||
case ast.Lambda:
|
||||
# lambda expression are disabled because they can allow forbidden action
|
||||
# example: (lambda x: x.append(1))(track.tags)
|
||||
raise Exception(f'Lambda expression are not allowed.')
|
||||
|
||||
# when calling any function
|
||||
case ast.Call:
|
||||
# ban the function and method from the environment
|
||||
for callnode in ast.walk(node.func):
|
||||
if isinstance(callnode, ast.Attribute):
|
||||
for attrnode in ast.walk(callnode.value):
|
||||
if isinstance(attrnode, ast.Name) and attrnode.id in globals_ | locals_:
|
||||
raise Exception(f'Calling this function is not allowed : "{callnode.attr}"')
|
||||
|
||||
# when assigning a value with ":="
|
||||
case ast.NamedExpr:
|
||||
# embed the value into a deepcopy, to avoid interaction with class attribute
|
||||
node.value = ast.Call(
|
||||
func=ast.Name(id="deepcopy", ctx=ast.Load()),
|
||||
args=[node.value], keywords=[],
|
||||
)
|
||||
|
||||
# Forbidden type. Some of them can't be accessed with the eval mode, but just in case, still ban them
|
||||
case (
|
||||
ast.Assign | ast.AugAssign | # Assign should only be done by ":=" with check in eval
|
||||
ast.Raise | ast.Assert | # Error should not be raised manually
|
||||
ast.Delete | # Value should not be deleted
|
||||
ast.Import | ast.ImportFrom | # Import could lead to extremely dangerous functions
|
||||
ast.Lambda | ast.FunctionDef | # Defining functions can allow skipping some check
|
||||
ast.Global | ast.Nonlocal | # Changing variables range could cause some issue
|
||||
ast.ClassDef | # Declaring class could maybe allow for dangerous calls
|
||||
ast.AsyncFor | ast.AsyncWith | ast.AsyncFunctionDef | ast.Await # Just in case
|
||||
):
|
||||
raise Exception(f'Forbidden syntax : "{type(node).__name__}"')
|
||||
|
||||
if return_lambda:
|
||||
# if return_lambda is enabled, embed the whole expression into a lambda expression
|
||||
stmt.value = ast.Lambda(
|
||||
body=stmt.value,
|
||||
args=ast.arguments(
|
||||
args=[ast.arg(arg=lambda_arg) for lambda_arg in lambda_args],
|
||||
posonlyargs=[], kwonlyargs=[],
|
||||
kw_defaults=[], defaults=[],
|
||||
)
|
||||
)
|
||||
|
||||
# convert into a ast.Expression, object needed for the compilation
|
||||
expression: ast.Expression = ast.Expression(stmt.value)
|
||||
|
||||
# if a node have been altered, fix the missing locations
|
||||
ast.fix_missing_locations(expression)
|
||||
|
||||
# return the evaluated formula
|
||||
return eval(compile(expression, "<string>", "eval"), globals_, locals_)
|
||||
from source.safe_eval.safe_eval import safe_eval
|
||||
from source.safe_eval.multiple_safe_eval import multiple_safe_eval
|
||||
|
|
25
source/safe_eval/macros.py
Normal file
25
source/safe_eval/macros.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
import re
|
||||
|
||||
|
||||
MACRO_START, MACRO_END = "##", "##"


class NotImplementedMacro(Exception):
    """
    Raised when a template uses a macro that has no replacement defined
    """

    def __init__(self, macro: str):
        super().__init__(f"Invalid macro while parsing macros:\n{macro}")


def replace_macro(template: str, macros: dict[str, str]) -> str:
    """
    Replace all the macro defined in macros by their respective value
    :param template: template where to replace the macro
    :param macros: dictionary associating macro with their replacement
    :return: the template with macro replaced
    :raise NotImplementedMacro: if the template uses a macro that is not a key of macros
    """

    def format_macro(match: re.Match) -> str:
        # keep the macro name apart from the looked-up value : the value is None exactly when the
        # macro is unknown, so the error has to be built from the name to be meaningful
        name = match.group(1).strip()
        if (macro := macros.get(name)) is None: raise NotImplementedMacro(name)
        return macro

    # match everything between MACRO_START and MACRO_END.
    # the delimiters are escaped so they are always matched literally by the regex
    return re.sub(rf"{re.escape(MACRO_START)}(.*?){re.escape(MACRO_END)}", format_macro, template)
|
26
source/safe_eval/multiple_safe_eval.py
Normal file
26
source/safe_eval/multiple_safe_eval.py
Normal file
|
@ -0,0 +1,26 @@
|
|||
import re
|
||||
|
||||
from source.safe_eval import safe_eval
|
||||
|
||||
|
||||
TOKEN_START, TOKEN_END = "{{", "}}"
|
||||
|
||||
|
||||
def multiple_safe_eval(template: str, env: dict[str, any] = None, macros: dict[str, str] = None) -> str:
    """
    Evaluate every expression enclosed between TOKEN_START and TOKEN_END inside the template.
    Example : "{{ track.author }} is the track creator !"

    :param template: text that can contain any number of enclosed expressions
    :param env: variables forwarded to safe_eval for every expression
    :param macros: macros forwarded to safe_eval for every expression
    :return: the template where every enclosed expression is replaced by str() of its safe_eval result
    """

    # everything between TOKEN_START and TOKEN_END is captured (non greedy) as one expression
    token_pattern = rf"{TOKEN_START}(.*?){TOKEN_END}"

    def evaluate_token(found: re.Match) -> str:
        """
        compute the replacement text of a single enclosed expression
        :param found: match of one enclosed expression
        :return: the evaluated expression, converted to text
        """
        # group(1) is the text without the delimiters, the surrounding blanks are dropped
        expression = found.group(1).strip()
        result = safe_eval(template=expression, env=env, macros=macros)
        return str(result)

    return re.sub(token_pattern, evaluate_token, template)
|
121
source/safe_eval/safe_eval.py
Normal file
121
source/safe_eval/safe_eval.py
Normal file
|
@ -0,0 +1,121 @@
|
|||
import ast
|
||||
import copy
|
||||
|
||||
from source.safe_eval.macros import replace_macro
|
||||
from source.safe_eval.safe_function import get_all_safe_functions
|
||||
|
||||
|
||||
# TODO: exception class
|
||||
|
||||
|
||||
# namespace shared by every safe_eval call : the builtins are emptied, and only deepcopy
# and the functions yielded by get_all_safe_functions are exposed, keyed by their __name__
all_globals = {
    "__builtins__": {},
    "deepcopy": copy.deepcopy,
    **{safe_func.__name__: safe_func for safe_func in get_all_safe_functions()},
}
|
||||
|
||||
|
||||
def safe_eval(template: str, env: dict[str, any] = None, macros: dict[str, str] = None,
|
||||
return_lambda: bool = False, lambda_args: list[str] = None) -> any:
|
||||
"""
|
||||
Run a python code in an eval function, but avoid all potential dangerous function.
|
||||
:env: additional variables that will be used when evaluating the template
|
||||
:return_lambda: if enabled, return a lambda function instead of the result of the expression
|
||||
:lambda_args: arguments that the final lambda function can receive
|
||||
:macros: dictionary associating a macro name to a macro value
|
||||
|
||||
:return: the evaluated expression or the lambda expression
|
||||
"""
|
||||
|
||||
if len(template) == 0: return ""
|
||||
if env is None: env = {}
|
||||
if macros is None: macros = {}
|
||||
|
||||
# replace the macro in the template
|
||||
template = replace_macro(template=template, macros=macros)
|
||||
|
||||
# prepare the execution environment
|
||||
globals_ = all_globals | env
|
||||
locals_ = {}
|
||||
|
||||
# convert the template to an ast expression
|
||||
stmt: ast.stmt = ast.parse(template).body[0]
|
||||
if not isinstance(stmt, ast.Expr):
|
||||
raise Exception(f'Invalid ast type for safe_eval : "{type(stmt).__name__}"')
|
||||
|
||||
# check every node for disabled expression
|
||||
for node in ast.walk(stmt):
|
||||
match type(node):
|
||||
|
||||
# when accessing any attribute
|
||||
case ast.Attribute:
|
||||
# ban all magical function, disabling the __class__.__bases__[0] ... tricks
|
||||
if "__" in node.attr:
|
||||
raise Exception(f'Magic attribute are forbidden : "{node.attr}"')
|
||||
|
||||
# ban modification to environment
|
||||
if isinstance(node.ctx, ast.Store):
|
||||
raise Exception(f'Can\'t set value of attribute : "{node.attr}"')
|
||||
|
||||
# when accessing any variable
|
||||
case ast.Name:
|
||||
# ban modification to environment, but allow custom variable to be changed
|
||||
if isinstance(node.ctx, ast.Store) and node.id in globals_ | locals_:
|
||||
raise Exception(f'Can\'t set value of environment : "{node.id}"')
|
||||
|
||||
case ast.Lambda:
|
||||
# lambda expression are disabled because they can allow forbidden action
|
||||
# example: (lambda x: x.append(1))(track.tags)
|
||||
raise Exception(f'Lambda expression are not allowed.')
|
||||
|
||||
# when calling any function
|
||||
case ast.Call:
|
||||
# ban the function and method from the environment
|
||||
for callnode in ast.walk(node.func):
|
||||
if isinstance(callnode, ast.Attribute):
|
||||
for attrnode in ast.walk(callnode.value):
|
||||
if isinstance(attrnode, ast.Name) and attrnode.id in globals_ | locals_:
|
||||
raise Exception(f'Calling this function is not allowed : "{callnode.attr}"')
|
||||
|
||||
# when assigning a value with ":="
|
||||
case ast.NamedExpr:
|
||||
# embed the value into a deepcopy, to avoid interaction with class attribute
|
||||
node.value = ast.Call(
|
||||
func=ast.Name(id="deepcopy", ctx=ast.Load()),
|
||||
args=[node.value], keywords=[],
|
||||
)
|
||||
|
||||
# Forbidden type. Some of them can't be accessed with the eval mode, but just in case, still ban them
|
||||
case (
|
||||
ast.Assign | ast.AugAssign | # Assign should only be done by ":=" with check in eval
|
||||
ast.Raise | ast.Assert | # Error should not be raised manually
|
||||
ast.Delete | # Value should not be deleted
|
||||
ast.Import | ast.ImportFrom | # Import could lead to extremely dangerous functions
|
||||
ast.Lambda | ast.FunctionDef | # Defining functions can allow skipping some check
|
||||
ast.Global | ast.Nonlocal | # Changing variables range could cause some issue
|
||||
ast.ClassDef | # Declaring class could maybe allow for dangerous calls
|
||||
ast.AsyncFor | ast.AsyncWith | ast.AsyncFunctionDef | ast.Await # Just in case
|
||||
):
|
||||
raise Exception(f'Forbidden syntax : "{type(node).__name__}"')
|
||||
|
||||
if return_lambda:
|
||||
# if return_lambda is enabled, embed the whole expression into a lambda expression
|
||||
stmt.value = ast.Lambda(
|
||||
body=stmt.value,
|
||||
args=ast.arguments(
|
||||
args=[ast.arg(arg=lambda_arg) for lambda_arg in lambda_args],
|
||||
posonlyargs=[], kwonlyargs=[],
|
||||
kw_defaults=[], defaults=[],
|
||||
)
|
||||
)
|
||||
|
||||
# convert into a ast.Expression, object needed for the compilation
|
||||
expression: ast.Expression = ast.Expression(stmt.value)
|
||||
|
||||
# if a node have been altered, fix the missing locations
|
||||
ast.fix_missing_locations(expression)
|
||||
|
||||
# return the evaluated formula
|
||||
return eval(compile(expression, "<string>", "eval"), globals_, locals_)
|
|
@ -1,10 +1,6 @@
|
|||
|
||||
from typing import Callable, Generator
|
||||
|
||||
|
||||
# TODO: exception class
|
||||
|
||||
|
||||
def get_all_safe_functions() -> Generator[list[Callable], None, None]:
|
||||
"""
|
||||
Return all the safe function defined in safe_function
|
||||
|
@ -49,5 +45,5 @@ class safe_function:
|
|||
"""
|
||||
Allow a recursive safe_eval, but without the lambda functionality
|
||||
"""
|
||||
from source.safe_eval import safe_eval
|
||||
return safe_eval(template, env)
|
||||
from source.safe_eval.safe_eval import safe_eval
|
||||
return safe_eval(template=template, env=env)
|
||||
|
|
Loading…
Reference in a new issue