diff --git a/source/safe_eval.py b/source/safe_eval.py deleted file mode 100644 index 7d70595..0000000 --- a/source/safe_eval.py +++ /dev/null @@ -1,160 +0,0 @@ -import re -from typing import Callable - -common_token_map = { # these operators and function are considered safe to use in the template - operator: operator - for operator in - [">=", "<=", "<<", ">>", "+", "-", "*", "/", "%", "**", ",", "(", ")", "[", "]", "==", "!=", "in", ">", "<", - "and", "or", "&", "|", "^", "~", ":", "{", "}", "abs", "int", - "bin", "hex", "oct", "chr", "ord", "len", "str", "bool", "float", "round", "min", "max", "sum", "zip", - "any", "all", "reversed", "enumerate", "list", "sorted", "hasattr", "for", "range", "type", "repr", "None", - "True", "False", "getattr", "dict" - ] -} | { # these methods are considered safe, except for the magic methods - f".{method}": f".{method}" - for method in dir(str) + dir(list) + dir(int) + dir(float) + dir(dict) - if not method.startswith("__") -} - -TOKEN_START, TOKEN_END = "{{", "}}" -MACRO_START, MACRO_END = "##", "##" - - -class TemplateParsingError(Exception): - def __init__(self, token: str): - super().__init__(f"Invalid token while parsing safe_eval:\n{token}") - - -class NotImplementedMacro(Exception): - def __init__(self, macro: str): - super().__init__(f"Invalid macro while parsing macros:\n{macro}") - - -class SafeFunction: - @classmethod - def get_all_safe_methods(cls) -> dict[str, Callable]: - """ - get all the safe methods defined by the class - :return: all the safe methods defined by the class - """ - return { - method: getattr(cls, method) - for method in dir(cls) - if not method.startswith("__") and method not in ["get_all_safe_methods"] - } - - @staticmethod - def getattr(obj: any, attr: str, default: any = None) -> any: - """ - Safe getattr, raise an error if the attribute is a function - :param obj: object to get the attribute from - :param attr: attribute name - :param default: default value if the attribute is not found - :return: the attribute value - """ - attr = getattr(obj, attr) if default is None else getattr(obj, attr, default) - if callable(attr): raise AttributeError(f"getattr can't be used for functions (tried: {attr})") - return attr - - @staticmethod - def type(obj: any) -> any: - """ - Safe type, can only be used to determinate the type of an object - (It can be used to create new class by using all the 3 args) - """ - return type(obj) - - -def replace_macro(template: str, macros: dict[str, str]) -> str: - """ - Replace all the macro defined in macro by their respective value - :param template: template where to replace the macro - :param macros: dictionary associating macro with their replacement - :return: the template with macro replaced - """ - - def format_macro(match: re.Match) -> str: - if (macro := macros.get(match.group(1).strip())) is None: raise NotImplementedMacro(macro) - return macro - - # match everything between MACRO_START and MACRO_END. - return re.sub(rf"{MACRO_START}(.*?){MACRO_END}", format_macro, template) - - -def safe_eval(template: str, env: dict[str, any] = None, macros: dict[str, str] = None) -> str: - """ - Evaluate the template and return the result in a safe way - :param env: variables to use when using eval - :param template: template to evaluate - :param macros: additional macro to replace in the template - """ - if env is None: env = {} - if macros is None: macros = {} - - template = replace_macro(template, macros) - token_map: dict[str, str] = common_token_map | {var: var for var in env} - final_token: str = "" - - def matched(match: re.Match | str | None, value: str = None) -> bool: - """ - check if token is matched, if yes, add it to the final token and remove it from the processing token - :param match: match object - :param value: if the match is a string, the value to replace the text with - :return: True if matched, False otherwise - """ - nonlocal final_token, template - - # if there is no match or the string is empty, return False - if not match: return False - - if isinstance(match, re.Match): - template_raw = template[match.end():] - value = match.group() - - else: - if not template.startswith(match): return False - template_raw = template[len(match):] - - template = template_raw.lstrip() - final_token += value + (len(template_raw) - len(template)) * " " - return True - - while template: # while there is still tokens to process - # if the section is a string, add it to the final token - # example : "hello", "hello \" world" - if matched(re.match(r'^(["\'])((\\{2})*|(.*?[^\\](\\{2})*))\1', template)): - continue - - # if the section is a float or an int, add it to the final token - # example : 102, 102.59 - if matched(re.match(r'^[0-9]+(?:\.[0-9]+)?', template)): - continue - - # if the section is a variable, operator or function, replace it by its value - # example : track.special, + - for key, value in token_map.items(): - if matched(key, value): break - - # else, the token is invalid, so raise an error - else: - raise TemplateParsingError(template) - - # if final_token is set, eval final_token and return the result - if final_token: return str(eval(final_token.replace("\\", "\\\\"), SafeFunction.get_all_safe_methods(), env)) - else: return final_token - - -def multiple_safe_eval(template: str, env: dict[str, any] = None, macros: dict[str, str] = None) -> str: - def format_part_template(match: re.Match) -> str: - """ - when a token is found, replace it by the corresponding value - :param match: match in the format - :return: corresponding value - """ - # get the token string without the brackets, then strip it. Also double antislash - part_template = match.group(1).strip() - return safe_eval(part_template, env, macros) - - # pass everything between TOKEN_START and TOKEN_END in the function - return re.sub(rf"{TOKEN_START}(.*?){TOKEN_END}", format_part_template, template) - diff --git a/source/safe_eval/__init__.py b/source/safe_eval/__init__.py index dc9f696..77babdb 100644 --- a/source/safe_eval/__init__.py +++ b/source/safe_eval/__init__.py @@ -1,112 +1,2 @@ -import ast -import copy - -from source.safe_eval.safe_function import get_all_safe_functions - - -# dict of every value used by every safe_eval call -all_globals = { - "__builtins__": {}, - "deepcopy": copy.deepcopy -} | { - func.__name__: func for func in get_all_safe_functions() -} - - -def safe_eval(template: str, env: dict[str, any] = None, - return_lambda: bool = False, lambda_args: list[str] = None) -> any: - """ - Run a python code in an eval function, but avoid all potential dangerous function. - :env: additional variables that will be used when evaluating the template - :return_lambda: if enabled, return a lambda function instead of the result of the expression - :lambda_args: arguments that the final lambda function can receive - - :return: the evaluated expression or the lambda expression - """ - - if len(template) == 0: return "" - if env is None: env = {} - - # prepare the execution environment - globals_ = all_globals | env - locals_ = {} - - # convert the template to an ast expression - stmt: ast.stmt = ast.parse(template).body[0] - if not isinstance(stmt, ast.Expr): - raise Exception(f'Invalid ast type for safe_eval : "{type(stmt).__name__}"') - - # check every node for disabled expression - for node in ast.walk(stmt): - match type(node): - - # when accessing any attribute - case ast.Attribute: - # ban all magical function, disabling the __class__.__bases__[0] ... tricks - if "__" in node.attr: - raise Exception(f'Magic attribute are forbidden : "{node.attr}"') - - # ban modification to environment - if isinstance(node.ctx, ast.Store): - raise Exception(f'Can\'t set value of attribute : "{node.attr}"') - - # when accessing any variable - case ast.Name: - # ban modification to environment, but allow custom variable to be changed - if isinstance(node.ctx, ast.Store) and node.id in globals_ | locals_: - raise Exception(f'Can\'t set value of environment : "{node.id}"') - - case ast.Lambda: - # lambda expression are disabled because they can allow forbidden action - # example: (lambda x: x.append(1))(track.tags) - raise Exception(f'Lambda expression are not allowed.') - - # when calling any function - case ast.Call: - # ban the function and method from the environment - for callnode in ast.walk(node.func): - if isinstance(callnode, ast.Attribute): - for attrnode in ast.walk(callnode.value): - if isinstance(attrnode, ast.Name) and attrnode.id in globals_ | locals_: - raise Exception(f'Calling this function is not allowed : "{callnode.attr}"') - - # when assigning a value with ":=" - case ast.NamedExpr: - # embed the value into a deepcopy, to avoid interaction with class attribute - node.value = ast.Call( - func=ast.Name(id="deepcopy", ctx=ast.Load()), - args=[node.value], keywords=[], - ) - - # Forbidden type. Some of them can't be accessed with the eval mode, but just in case, still ban them - case ( - ast.Assign | ast.AugAssign | # Assign should only be done by ":=" with check in eval - ast.Raise | ast.Assert | # Error should not be raised manually - ast.Delete | # Value should not be deleted - ast.Import | ast.ImportFrom | # Import could lead to extremely dangerous functions - ast.Lambda | ast.FunctionDef | # Defining functions can allow skipping some check - ast.Global | ast.Nonlocal | # Changing variables range could cause some issue - ast.ClassDef | # Declaring class could maybe allow for dangerous calls - ast.AsyncFor | ast.AsyncWith | ast.AsyncFunctionDef | ast.Await # Just in case - ): - raise Exception(f'Forbidden syntax : "{type(node).__name__}"') - - if return_lambda: - # if return_lambda is enabled, embed the whole expression into a lambda expression - stmt.value = ast.Lambda( - body=stmt.value, - args=ast.arguments( - args=[ast.arg(arg=lambda_arg) for lambda_arg in lambda_args], - posonlyargs=[], kwonlyargs=[], - kw_defaults=[], defaults=[], - ) - ) - - # convert into a ast.Expression, object needed for the compilation - expression: ast.Expression = ast.Expression(stmt.value) - - # if a node have been altered, fix the missing locations - ast.fix_missing_locations(expression) - - # return the evaluated formula - return eval(compile(expression, "", "eval"), globals_, locals_) +from source.safe_eval.safe_eval import safe_eval +from source.safe_eval.multiple_safe_eval import multiple_safe_eval diff --git a/source/safe_eval/macros.py b/source/safe_eval/macros.py new file mode 100644 index 0000000..8e831b9 --- /dev/null +++ b/source/safe_eval/macros.py @@ -0,0 +1,25 @@ +import re + + +MACRO_START, MACRO_END = "##", "##" + + +class NotImplementedMacro(Exception): + def __init__(self, macro: str): + super().__init__(f"Invalid macro while parsing macros:\n{macro}") + + +def replace_macro(template: str, macros: dict[str, str]) -> str: + """ + Replace all the macro defined in macro by their respective value + :param template: template where to replace the macro + :param macros: dictionary associating macro with their replacement + :return: the template with macro replaced + """ + + def format_macro(match: re.Match) -> str: + if (macro := macros.get(match.group(1).strip())) is None: raise NotImplementedMacro(macro) + return macro + + # match everything between MACRO_START and MACRO_END. + return re.sub(rf"{MACRO_START}(.*?){MACRO_END}", format_macro, template) diff --git a/source/safe_eval/multiple_safe_eval.py b/source/safe_eval/multiple_safe_eval.py new file mode 100644 index 0000000..5f85248 --- /dev/null +++ b/source/safe_eval/multiple_safe_eval.py @@ -0,0 +1,26 @@ +import re + +from source.safe_eval import safe_eval + + +TOKEN_START, TOKEN_END = "{{", "}}" + + +def multiple_safe_eval(template: str, env: dict[str, any] = None, macros: dict[str, str] = None) -> str: + """ + Similar to safe_eval, but expression need to be enclosed between "{{" and "}}". + Example : "{{ track.author }} is the track creator !" + """ + + def format_part_template(match: re.Match) -> str: + """ + when a token is found, replace it by the corresponding value + :param match: match in the format + :return: corresponding value + """ + # get the token string without the brackets, then strip it. Also double antislash + part_template = match.group(1).strip() + return str(safe_eval(template=part_template, env=env, macros=macros)) + + # pass everything between TOKEN_START and TOKEN_END in the function + return re.sub(rf"{TOKEN_START}(.*?){TOKEN_END}", format_part_template, template) diff --git a/source/safe_eval/safe_eval.py b/source/safe_eval/safe_eval.py new file mode 100644 index 0000000..7a65904 --- /dev/null +++ b/source/safe_eval/safe_eval.py @@ -0,0 +1,121 @@ +import ast +import copy + +from source.safe_eval.macros import replace_macro +from source.safe_eval.safe_function import get_all_safe_functions + + +# TODO: exception class + + +# dict of every value used by every safe_eval call +all_globals = { + "__builtins__": {}, + "deepcopy": copy.deepcopy +} | { + func.__name__: func for func in get_all_safe_functions() +} + + +def safe_eval(template: str, env: dict[str, any] = None, macros: dict[str, str] = None, + return_lambda: bool = False, lambda_args: list[str] = None) -> any: + """ + Run a python code in an eval function, but avoid all potential dangerous function. + :env: additional variables that will be used when evaluating the template + :return_lambda: if enabled, return a lambda function instead of the result of the expression + :lambda_args: arguments that the final lambda function can receive + :macros: dictionary associating a macro name to a macro value + + :return: the evaluated expression or the lambda expression + """ + + if len(template) == 0: return "" + if env is None: env = {} + if macros is None: macros = {} + + # replace the macro in the template + template = replace_macro(template=template, macros=macros) + + # prepare the execution environment + globals_ = all_globals | env + locals_ = {} + + # convert the template to an ast expression + stmt: ast.stmt = ast.parse(template).body[0] + if not isinstance(stmt, ast.Expr): + raise Exception(f'Invalid ast type for safe_eval : "{type(stmt).__name__}"') + + # check every node for disabled expression + for node in ast.walk(stmt): + match type(node): + + # when accessing any attribute + case ast.Attribute: + # ban all magical function, disabling the __class__.__bases__[0] ... tricks + if "__" in node.attr: + raise Exception(f'Magic attribute are forbidden : "{node.attr}"') + + # ban modification to environment + if isinstance(node.ctx, ast.Store): + raise Exception(f'Can\'t set value of attribute : "{node.attr}"') + + # when accessing any variable + case ast.Name: + # ban modification to environment, but allow custom variable to be changed + if isinstance(node.ctx, ast.Store) and node.id in globals_ | locals_: + raise Exception(f'Can\'t set value of environment : "{node.id}"') + + case ast.Lambda: + # lambda expression are disabled because they can allow forbidden action + # example: (lambda x: x.append(1))(track.tags) + raise Exception(f'Lambda expression are not allowed.') + + # when calling any function + case ast.Call: + # ban the function and method from the environment + for callnode in ast.walk(node.func): + if isinstance(callnode, ast.Attribute): + for attrnode in ast.walk(callnode.value): + if isinstance(attrnode, ast.Name) and attrnode.id in globals_ | locals_: + raise Exception(f'Calling this function is not allowed : "{callnode.attr}"') + + # when assigning a value with ":=" + case ast.NamedExpr: + # embed the value into a deepcopy, to avoid interaction with class attribute + node.value = ast.Call( + func=ast.Name(id="deepcopy", ctx=ast.Load()), + args=[node.value], keywords=[], + ) + + # Forbidden type. Some of them can't be accessed with the eval mode, but just in case, still ban them + case ( + ast.Assign | ast.AugAssign | # Assign should only be done by ":=" with check in eval + ast.Raise | ast.Assert | # Error should not be raised manually + ast.Delete | # Value should not be deleted + ast.Import | ast.ImportFrom | # Import could lead to extremely dangerous functions + ast.Lambda | ast.FunctionDef | # Defining functions can allow skipping some check + ast.Global | ast.Nonlocal | # Changing variables range could cause some issue + ast.ClassDef | # Declaring class could maybe allow for dangerous calls + ast.AsyncFor | ast.AsyncWith | ast.AsyncFunctionDef | ast.Await # Just in case + ): + raise Exception(f'Forbidden syntax : "{type(node).__name__}"') + + if return_lambda: + # if return_lambda is enabled, embed the whole expression into a lambda expression + stmt.value = ast.Lambda( + body=stmt.value, + args=ast.arguments( + args=[ast.arg(arg=lambda_arg) for lambda_arg in lambda_args], + posonlyargs=[], kwonlyargs=[], + kw_defaults=[], defaults=[], + ) + ) + + # convert into a ast.Expression, object needed for the compilation + expression: ast.Expression = ast.Expression(stmt.value) + + # if a node have been altered, fix the missing locations + ast.fix_missing_locations(expression) + + # return the evaluated formula + return eval(compile(expression, "", "eval"), globals_, locals_) diff --git a/source/safe_eval/safe_function.py b/source/safe_eval/safe_function.py index 10e12de..3ebb3e3 100644 --- a/source/safe_eval/safe_function.py +++ b/source/safe_eval/safe_function.py @@ -1,10 +1,6 @@ - from typing import Callable, Generator -# TODO: exception class - - def get_all_safe_functions() -> Generator[list[Callable], None, None]: """ Return all the safe function defined in safe_function @@ -49,5 +45,5 @@ class safe_function: """ Allow a recursive safe_eval, but without the lambda functionality """ - from source.safe_eval import safe_eval - return safe_eval(template, env) + from source.safe_eval.safe_eval import safe_eval + return safe_eval(template=template, env=env)