started the rewrote of safe_eval to use AST (making it way easier to read and edit) and could fix some security issue.

Also allow for lambda expression to avoid recompiling and checking the expression at every call
This commit is contained in:
Faraphel 2022-08-06 00:12:02 +02:00
parent 1372a0eac2
commit 998d1274ef
3 changed files with 166 additions and 1 deletions

View file

@ -5,7 +5,7 @@ common_token_map = { # these operators and function are considered safe to use
operator: operator operator: operator
for operator in for operator in
[">=", "<=", "<<", ">>", "+", "-", "*", "/", "%", "**", ",", "(", ")", "[", "]", "==", "!=", "in", ">", "<", [">=", "<=", "<<", ">>", "+", "-", "*", "/", "%", "**", ",", "(", ")", "[", "]", "==", "!=", "in", ">", "<",
"and", "or", "&", "|", "^", "~", ":", "{", "}", "isinstance", "issubclass", "not", "is", "if", "else", "abs", "int", "and", "or", "&", "|", "^", "~", ":", "{", "}", "abs", "int",
"bin", "hex", "oct", "chr", "ord", "len", "str", "bool", "float", "round", "min", "max", "sum", "zip", "bin", "hex", "oct", "chr", "ord", "len", "str", "bool", "float", "round", "min", "max", "sum", "zip",
"any", "all", "reversed", "enumerate", "list", "sorted", "hasattr", "for", "range", "type", "repr", "None", "any", "all", "reversed", "enumerate", "list", "sorted", "hasattr", "for", "range", "type", "repr", "None",
"True", "False", "getattr", "dict" "True", "False", "getattr", "dict"

View file

@ -0,0 +1,112 @@
import ast
import copy
from source.safe_eval.safe_function import get_all_safe_functions
# dict of every value used by every safe_eval call
all_globals = {
"__builtins__": {},
"deepcopy": copy.deepcopy
} | {
func.__name__: func for func in get_all_safe_functions()
}
def safe_eval(template: str, env: dict[str, any] = None,
return_lambda: bool = False, lambda_args: list[str] = None) -> any:
"""
Run a python code in an eval function, but avoid all potential dangerous function.
:env: additional variables that will be used when evaluating the template
:return_lambda: if enabled, return a lambda function instead of the result of the expression
:lambda_args: arguments that the final lambda function can receive
:return: the evaluated expression or the lambda expression
"""
if len(template) == 0: return ""
if env is None: env = {}
# prepare the execution environment
globals_ = all_globals | env
locals_ = {}
# convert the template to an ast expression
stmt: ast.stmt = ast.parse(template).body[0]
if not isinstance(stmt, ast.Expr):
raise Exception(f'Invalid ast type for safe_eval : "{type(stmt).__name__}"')
# check every node for disabled expression
for node in ast.walk(stmt):
match type(node):
# when accessing any attribute
case ast.Attribute:
# ban all magical function, disabling the __class__.__bases__[0] ... tricks
if "__" in node.attr:
raise Exception(f'Magic attribute are forbidden : "{node.attr}"')
# ban modification to environment
if isinstance(node.ctx, ast.Store):
raise Exception(f'Can\'t set value of attribute : "{node.attr}"')
# when accessing any variable
case ast.Name:
# ban modification to environment, but allow custom variable to be changed
if isinstance(node.ctx, ast.Store) and node.id in globals_ | locals_:
raise Exception(f'Can\'t set value of environment : "{node.id}"')
case ast.Lambda:
# lambda expression are disabled because they can allow forbidden action
# example: (lambda x: x.append(1))(track.tags)
raise Exception(f'Lambda expression are not allowed.')
# when calling any function
case ast.Call:
# ban the function and method from the environment
for callnode in ast.walk(node.func):
if isinstance(callnode, ast.Attribute):
for attrnode in ast.walk(callnode.value):
if isinstance(attrnode, ast.Name) and attrnode.id in globals_ | locals_:
raise Exception(f'Calling this function is not allowed : "{callnode.attr}"')
# when assigning a value with ":="
case ast.NamedExpr:
# embed the value into a deepcopy, to avoid interaction with class attribute
node.value = ast.Call(
func=ast.Name(id="deepcopy", ctx=ast.Load()),
args=[node.value], keywords=[],
)
# Forbidden type. Some of them can't be accessed with the eval mode, but just in case, still ban them
case (
ast.Assign | ast.AugAssign | # Assign should only be done by ":=" with check in eval
ast.Raise | ast.Assert | # Error should not be raised manually
ast.Delete | # Value should not be deleted
ast.Import | ast.ImportFrom | # Import could lead to extremely dangerous functions
ast.Lambda | ast.FunctionDef | # Defining functions can allow skipping some check
ast.Global | ast.Nonlocal | # Changing variables range could cause some issue
ast.ClassDef | # Declaring class could maybe allow for dangerous calls
ast.AsyncFor | ast.AsyncWith | ast.AsyncFunctionDef | ast.Await # Just in case
):
raise Exception(f'Forbidden syntax : "{type(node).__name__}"')
if return_lambda:
# if return_lambda is enabled, embed the whole expression into a lambda expression
stmt.value = ast.Lambda(
body=stmt.value,
args=ast.arguments(
args=[ast.arg(arg=lambda_arg) for lambda_arg in lambda_args],
posonlyargs=[], kwonlyargs=[],
kw_defaults=[], defaults=[],
)
)
# convert into a ast.Expression, object needed for the compilation
expression: ast.Expression = ast.Expression(stmt.value)
# if a node have been altered, fix the missing locations
ast.fix_missing_locations(expression)
# return the evaluated formula
return eval(compile(expression, "<string>", "eval"), globals_, locals_)

View file

@ -0,0 +1,53 @@
from typing import Callable, Generator
# TODO: exception class
def get_all_safe_functions() -> Generator[list[Callable], None, None]:
"""
Return all the safe function defined in safe_function
"""
for obj_name in filter(lambda obj_name: "__" not in obj_name, dir(safe_function)):
obj = getattr(safe_function, obj_name)
if callable(obj): yield obj
yield from safe_builtins
# these functions are builtins function that don't need additional check to be safe
safe_builtins: list[Callable] = [
abs, all, any, ascii, bin, bool, chr, dict, enumerate, float, hasattr, hex, int,
isinstance, issubclass, len, list, max, min, oct, ord, range, repr, reversed,
round, sorted, str, sum, tuple, type, zip,
]
class safe_function:
"""
Safer version of some python builtins function
"""
@staticmethod
def getattr(obj: any, name: str, default=None) -> any:
"""
Same as normal getattr, but magic attribute are banned
"""
if "__" in name: raise Exception(f'Magic method are not allowed : "{name}"')
return getattr(obj, name, default)
@staticmethod
def type(obj: any):
"""
Same as normal type, but the syntax with 3 arguments (to create new type) is banned
"""
return type(obj)
@staticmethod
def eval(template: str, env: dict | None = None):
"""
Allow a recursive safe_eval, but without the lambda functionality
"""
from source.safe_eval import safe_eval
return safe_eval(template, env)