mirror of
https://github.com/Faraphel/Atlas-Install.git
synced 2025-07-03 19:28:25 +02:00
started the rewrote of safe_eval to use AST (making it way easier to read and edit) and could fix some security issue.
Also allow for lambda expression to avoid recompiling and checking the expression at every call
This commit is contained in:
parent
1372a0eac2
commit
998d1274ef
3 changed files with 166 additions and 1 deletions
|
@ -5,7 +5,7 @@ common_token_map = { # these operators and function are considered safe to use
|
||||||
operator: operator
|
operator: operator
|
||||||
for operator in
|
for operator in
|
||||||
[">=", "<=", "<<", ">>", "+", "-", "*", "/", "%", "**", ",", "(", ")", "[", "]", "==", "!=", "in", ">", "<",
|
[">=", "<=", "<<", ">>", "+", "-", "*", "/", "%", "**", ",", "(", ")", "[", "]", "==", "!=", "in", ">", "<",
|
||||||
"and", "or", "&", "|", "^", "~", ":", "{", "}", "isinstance", "issubclass", "not", "is", "if", "else", "abs", "int",
|
"and", "or", "&", "|", "^", "~", ":", "{", "}", "abs", "int",
|
||||||
"bin", "hex", "oct", "chr", "ord", "len", "str", "bool", "float", "round", "min", "max", "sum", "zip",
|
"bin", "hex", "oct", "chr", "ord", "len", "str", "bool", "float", "round", "min", "max", "sum", "zip",
|
||||||
"any", "all", "reversed", "enumerate", "list", "sorted", "hasattr", "for", "range", "type", "repr", "None",
|
"any", "all", "reversed", "enumerate", "list", "sorted", "hasattr", "for", "range", "type", "repr", "None",
|
||||||
"True", "False", "getattr", "dict"
|
"True", "False", "getattr", "dict"
|
||||||
|
|
112
source/safe_eval/__init__.py
Normal file
112
source/safe_eval/__init__.py
Normal file
|
@ -0,0 +1,112 @@
|
||||||
|
import ast
|
||||||
|
import copy
|
||||||
|
|
||||||
|
from source.safe_eval.safe_function import get_all_safe_functions
|
||||||
|
|
||||||
|
|
||||||
|
# dict of every value used by every safe_eval call
|
||||||
|
all_globals = {
|
||||||
|
"__builtins__": {},
|
||||||
|
"deepcopy": copy.deepcopy
|
||||||
|
} | {
|
||||||
|
func.__name__: func for func in get_all_safe_functions()
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def safe_eval(template: str, env: dict[str, any] = None,
|
||||||
|
return_lambda: bool = False, lambda_args: list[str] = None) -> any:
|
||||||
|
"""
|
||||||
|
Run a python code in an eval function, but avoid all potential dangerous function.
|
||||||
|
:env: additional variables that will be used when evaluating the template
|
||||||
|
:return_lambda: if enabled, return a lambda function instead of the result of the expression
|
||||||
|
:lambda_args: arguments that the final lambda function can receive
|
||||||
|
|
||||||
|
:return: the evaluated expression or the lambda expression
|
||||||
|
"""
|
||||||
|
|
||||||
|
if len(template) == 0: return ""
|
||||||
|
if env is None: env = {}
|
||||||
|
|
||||||
|
# prepare the execution environment
|
||||||
|
globals_ = all_globals | env
|
||||||
|
locals_ = {}
|
||||||
|
|
||||||
|
# convert the template to an ast expression
|
||||||
|
stmt: ast.stmt = ast.parse(template).body[0]
|
||||||
|
if not isinstance(stmt, ast.Expr):
|
||||||
|
raise Exception(f'Invalid ast type for safe_eval : "{type(stmt).__name__}"')
|
||||||
|
|
||||||
|
# check every node for disabled expression
|
||||||
|
for node in ast.walk(stmt):
|
||||||
|
match type(node):
|
||||||
|
|
||||||
|
# when accessing any attribute
|
||||||
|
case ast.Attribute:
|
||||||
|
# ban all magical function, disabling the __class__.__bases__[0] ... tricks
|
||||||
|
if "__" in node.attr:
|
||||||
|
raise Exception(f'Magic attribute are forbidden : "{node.attr}"')
|
||||||
|
|
||||||
|
# ban modification to environment
|
||||||
|
if isinstance(node.ctx, ast.Store):
|
||||||
|
raise Exception(f'Can\'t set value of attribute : "{node.attr}"')
|
||||||
|
|
||||||
|
# when accessing any variable
|
||||||
|
case ast.Name:
|
||||||
|
# ban modification to environment, but allow custom variable to be changed
|
||||||
|
if isinstance(node.ctx, ast.Store) and node.id in globals_ | locals_:
|
||||||
|
raise Exception(f'Can\'t set value of environment : "{node.id}"')
|
||||||
|
|
||||||
|
case ast.Lambda:
|
||||||
|
# lambda expression are disabled because they can allow forbidden action
|
||||||
|
# example: (lambda x: x.append(1))(track.tags)
|
||||||
|
raise Exception(f'Lambda expression are not allowed.')
|
||||||
|
|
||||||
|
# when calling any function
|
||||||
|
case ast.Call:
|
||||||
|
# ban the function and method from the environment
|
||||||
|
for callnode in ast.walk(node.func):
|
||||||
|
if isinstance(callnode, ast.Attribute):
|
||||||
|
for attrnode in ast.walk(callnode.value):
|
||||||
|
if isinstance(attrnode, ast.Name) and attrnode.id in globals_ | locals_:
|
||||||
|
raise Exception(f'Calling this function is not allowed : "{callnode.attr}"')
|
||||||
|
|
||||||
|
# when assigning a value with ":="
|
||||||
|
case ast.NamedExpr:
|
||||||
|
# embed the value into a deepcopy, to avoid interaction with class attribute
|
||||||
|
node.value = ast.Call(
|
||||||
|
func=ast.Name(id="deepcopy", ctx=ast.Load()),
|
||||||
|
args=[node.value], keywords=[],
|
||||||
|
)
|
||||||
|
|
||||||
|
# Forbidden type. Some of them can't be accessed with the eval mode, but just in case, still ban them
|
||||||
|
case (
|
||||||
|
ast.Assign | ast.AugAssign | # Assign should only be done by ":=" with check in eval
|
||||||
|
ast.Raise | ast.Assert | # Error should not be raised manually
|
||||||
|
ast.Delete | # Value should not be deleted
|
||||||
|
ast.Import | ast.ImportFrom | # Import could lead to extremely dangerous functions
|
||||||
|
ast.Lambda | ast.FunctionDef | # Defining functions can allow skipping some check
|
||||||
|
ast.Global | ast.Nonlocal | # Changing variables range could cause some issue
|
||||||
|
ast.ClassDef | # Declaring class could maybe allow for dangerous calls
|
||||||
|
ast.AsyncFor | ast.AsyncWith | ast.AsyncFunctionDef | ast.Await # Just in case
|
||||||
|
):
|
||||||
|
raise Exception(f'Forbidden syntax : "{type(node).__name__}"')
|
||||||
|
|
||||||
|
if return_lambda:
|
||||||
|
# if return_lambda is enabled, embed the whole expression into a lambda expression
|
||||||
|
stmt.value = ast.Lambda(
|
||||||
|
body=stmt.value,
|
||||||
|
args=ast.arguments(
|
||||||
|
args=[ast.arg(arg=lambda_arg) for lambda_arg in lambda_args],
|
||||||
|
posonlyargs=[], kwonlyargs=[],
|
||||||
|
kw_defaults=[], defaults=[],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# convert into a ast.Expression, object needed for the compilation
|
||||||
|
expression: ast.Expression = ast.Expression(stmt.value)
|
||||||
|
|
||||||
|
# if a node have been altered, fix the missing locations
|
||||||
|
ast.fix_missing_locations(expression)
|
||||||
|
|
||||||
|
# return the evaluated formula
|
||||||
|
return eval(compile(expression, "<string>", "eval"), globals_, locals_)
|
53
source/safe_eval/safe_function.py
Normal file
53
source/safe_eval/safe_function.py
Normal file
|
@ -0,0 +1,53 @@
|
||||||
|
|
||||||
|
from typing import Callable, Generator
|
||||||
|
|
||||||
|
|
||||||
|
# TODO: exception class
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_safe_functions() -> Generator[list[Callable], None, None]:
|
||||||
|
"""
|
||||||
|
Return all the safe function defined in safe_function
|
||||||
|
"""
|
||||||
|
for obj_name in filter(lambda obj_name: "__" not in obj_name, dir(safe_function)):
|
||||||
|
obj = getattr(safe_function, obj_name)
|
||||||
|
if callable(obj): yield obj
|
||||||
|
|
||||||
|
yield from safe_builtins
|
||||||
|
|
||||||
|
|
||||||
|
# these functions are builtins function that don't need additional check to be safe
|
||||||
|
safe_builtins: list[Callable] = [
|
||||||
|
abs, all, any, ascii, bin, bool, chr, dict, enumerate, float, hasattr, hex, int,
|
||||||
|
isinstance, issubclass, len, list, max, min, oct, ord, range, repr, reversed,
|
||||||
|
round, sorted, str, sum, tuple, type, zip,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class safe_function:
|
||||||
|
"""
|
||||||
|
Safer version of some python builtins function
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def getattr(obj: any, name: str, default=None) -> any:
|
||||||
|
"""
|
||||||
|
Same as normal getattr, but magic attribute are banned
|
||||||
|
"""
|
||||||
|
if "__" in name: raise Exception(f'Magic method are not allowed : "{name}"')
|
||||||
|
return getattr(obj, name, default)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def type(obj: any):
|
||||||
|
"""
|
||||||
|
Same as normal type, but the syntax with 3 arguments (to create new type) is banned
|
||||||
|
"""
|
||||||
|
return type(obj)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def eval(template: str, env: dict | None = None):
|
||||||
|
"""
|
||||||
|
Allow a recursive safe_eval, but without the lambda functionality
|
||||||
|
"""
|
||||||
|
from source.safe_eval import safe_eval
|
||||||
|
return safe_eval(template, env)
|
Loading…
Reference in a new issue