diff --git a/source/safe_eval.py b/source/safe_eval.py index ffaf65b..7d70595 100644 --- a/source/safe_eval.py +++ b/source/safe_eval.py @@ -5,7 +5,7 @@ common_token_map = { # these operators and function are considered safe to use operator: operator for operator in [">=", "<=", "<<", ">>", "+", "-", "*", "/", "%", "**", ",", "(", ")", "[", "]", "==", "!=", "in", ">", "<", - "and", "or", "&", "|", "^", "~", ":", "{", "}", "isinstance", "issubclass", "not", "is", "if", "else", "abs", "int", + "and", "or", "&", "|", "^", "~", ":", "{", "}", "abs", "int", "bin", "hex", "oct", "chr", "ord", "len", "str", "bool", "float", "round", "min", "max", "sum", "zip", "any", "all", "reversed", "enumerate", "list", "sorted", "hasattr", "for", "range", "type", "repr", "None", "True", "False", "getattr", "dict" diff --git a/source/safe_eval/__init__.py b/source/safe_eval/__init__.py new file mode 100644 index 0000000..dc9f696 --- /dev/null +++ b/source/safe_eval/__init__.py @@ -0,0 +1,112 @@ +import ast +import copy + +from source.safe_eval.safe_function import get_all_safe_functions + + +# dict of every value used by every safe_eval call +all_globals = { + "__builtins__": {}, + "deepcopy": copy.deepcopy +} | { + func.__name__: func for func in get_all_safe_functions() +} + + +def safe_eval(template: str, env: dict[str, any] = None, + return_lambda: bool = False, lambda_args: list[str] = None) -> any: + """ + Run a python code in an eval function, but avoid all potential dangerous function. + :env: additional variables that will be used when evaluating the template + :return_lambda: if enabled, return a lambda function instead of the result of the expression + :lambda_args: arguments that the final lambda function can receive + + :return: the evaluated expression or the lambda expression + """ + + if len(template) == 0: return "" + if env is None: env = {} + + # prepare the execution environment + globals_ = all_globals | env + locals_ = {} + + # convert the template to an ast expression + stmt: ast.stmt = ast.parse(template).body[0] + if not isinstance(stmt, ast.Expr): + raise Exception(f'Invalid ast type for safe_eval : "{type(stmt).__name__}"') + + # check every node for disabled expression + for node in ast.walk(stmt): + match type(node): + + # when accessing any attribute + case ast.Attribute: + # ban all magical function, disabling the __class__.__bases__[0] ... tricks + if "__" in node.attr: + raise Exception(f'Magic attribute are forbidden : "{node.attr}"') + + # ban modification to environment + if isinstance(node.ctx, ast.Store): + raise Exception(f'Can\'t set value of attribute : "{node.attr}"') + + # when accessing any variable + case ast.Name: + # ban modification to environment, but allow custom variable to be changed + if isinstance(node.ctx, ast.Store) and node.id in globals_ | locals_: + raise Exception(f'Can\'t set value of environment : "{node.id}"') + + case ast.Lambda: + # lambda expression are disabled because they can allow forbidden action + # example: (lambda x: x.append(1))(track.tags) + raise Exception(f'Lambda expression are not allowed.') + + # when calling any function + case ast.Call: + # ban the function and method from the environment + for callnode in ast.walk(node.func): + if isinstance(callnode, ast.Attribute): + for attrnode in ast.walk(callnode.value): + if isinstance(attrnode, ast.Name) and attrnode.id in globals_ | locals_: + raise Exception(f'Calling this function is not allowed : "{callnode.attr}"') + + # when assigning a value with ":=" + case ast.NamedExpr: + # embed the value into a deepcopy, to avoid interaction with class attribute + node.value = ast.Call( + func=ast.Name(id="deepcopy", ctx=ast.Load()), + args=[node.value], keywords=[], + ) + + # Forbidden type. Some of them can't be accessed with the eval mode, but just in case, still ban them + case ( + ast.Assign | ast.AugAssign | # Assign should only be done by ":=" with check in eval + ast.Raise | ast.Assert | # Error should not be raised manually + ast.Delete | # Value should not be deleted + ast.Import | ast.ImportFrom | # Import could lead to extremely dangerous functions + ast.Lambda | ast.FunctionDef | # Defining functions can allow skipping some check + ast.Global | ast.Nonlocal | # Changing variables range could cause some issue + ast.ClassDef | # Declaring class could maybe allow for dangerous calls + ast.AsyncFor | ast.AsyncWith | ast.AsyncFunctionDef | ast.Await # Just in case + ): + raise Exception(f'Forbidden syntax : "{type(node).__name__}"') + + if return_lambda: + # if return_lambda is enabled, embed the whole expression into a lambda expression + stmt.value = ast.Lambda( + body=stmt.value, + args=ast.arguments( + args=[ast.arg(arg=lambda_arg) for lambda_arg in lambda_args], + posonlyargs=[], kwonlyargs=[], + kw_defaults=[], defaults=[], + ) + ) + + # convert into a ast.Expression, object needed for the compilation + expression: ast.Expression = ast.Expression(stmt.value) + + # if a node have been altered, fix the missing locations + ast.fix_missing_locations(expression) + + # return the evaluated formula + return eval(compile(expression, "", "eval"), globals_, locals_) diff --git a/source/safe_eval/safe_function.py b/source/safe_eval/safe_function.py new file mode 100644 index 0000000..10e12de --- /dev/null +++ b/source/safe_eval/safe_function.py @@ -0,0 +1,53 @@ + +from typing import Callable, Generator + + +# TODO: exception class + + +def get_all_safe_functions() -> Generator[list[Callable], None, None]: + """ + Return all the safe function defined in safe_function + """ + for obj_name in filter(lambda obj_name: "__" not in obj_name, dir(safe_function)): + obj = getattr(safe_function, obj_name) + if callable(obj): yield obj + + yield from safe_builtins + + +# these functions are builtins function that don't need additional check to be safe +safe_builtins: list[Callable] = [ + abs, all, any, ascii, bin, bool, chr, dict, enumerate, float, hasattr, hex, int, + isinstance, issubclass, len, list, max, min, oct, ord, range, repr, reversed, + round, sorted, str, sum, tuple, type, zip, +] + + +class safe_function: + """ + Safer version of some python builtins function + """ + + @staticmethod + def getattr(obj: any, name: str, default=None) -> any: + """ + Same as normal getattr, but magic attribute are banned + """ + if "__" in name: raise Exception(f'Magic method are not allowed : "{name}"') + return getattr(obj, name, default) + + @staticmethod + def type(obj: any): + """ + Same as normal type, but the syntax with 3 arguments (to create new type) is banned + """ + return type(obj) + + @staticmethod + def eval(template: str, env: dict | None = None): + """ + Allow a recursive safe_eval, but without the lambda functionality + """ + from source.safe_eval import safe_eval + return safe_eval(template, env)