Module safe_eval
[hide private]
[frames] | no frames]

Source Code for Module safe_eval

  1  #    Copyright (C) 2007 Jeremy S. Sanders 
  2  #    Email: Jeremy Sanders <jeremy@jeremysanders.net> 
  3  # 
  4  #    This program is free software; you can redistribute it and/or modify 
  5  #    it under the terms of the GNU General Public License as published by 
  6  #    the Free Software Foundation; either version 2 of the License, or 
  7  #    (at your option) any later version. 
  8  # 
  9  #    This program is distributed in the hope that it will be useful, 
 10  #    but WITHOUT ANY WARRANTY; without even the implied warranty of 
 11  #    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the 
 12  #    GNU General Public License for more details. 
 13  # 
 14  #    You should have received a copy of the GNU General Public License 
 15  #    along with this program; if not, write to the Free Software 
 16  #    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA 
 17  ############################################################################### 
 18   
 19  # $Id: safe_eval.py 725 2008-02-12 12:27:38Z jeremysanders $ 
 20   
 21  """ 
 22  'Safe' python code evaluation 
 23   
 24  Based on the public domain code of Babar K. Zafar 
 25  http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/496746 
 26  (version 0.1 or 1.2 May 27 2006) 
 27   
  28  The idea is to examine the compiled ast tree and check for invalid 
 29  entries 
 30   
 31  I have removed the timeout checking as this probably isn't a serious 
 32  problem for veusz documents 
 33  """ 
 34   
 35  import parser 
 36  import inspect, compiler.ast 
 37  import thread, time 
 38  import __builtin__ 
 39  import os.path 
 40   
 41  #import numpy as N 
 42   
 43  #---------------------------------------------------------------------- 
 44  # Module globals. 
 45  #---------------------------------------------------------------------- 
 46   
 47  # Toggle module level debugging mode. 
 48  DEBUG = False 
 49   
 50  # List of all AST node classes in compiler/ast.py. 
 51  all_ast_nodes = [name for (name, obj) in inspect.getmembers(compiler.ast) 
 52                   if inspect.isclass(obj) and 
 53                   issubclass(obj, compiler.ast.Node)] 
 54   
 55  # List of all builtin functions and types (ignoring exception classes). 
 56  all_builtins = [name for (name, obj) in inspect.getmembers(__builtin__) 
 57                  if inspect.isbuiltin(obj) or 
 58                  (inspect.isclass(obj) and not issubclass(obj, Exception))] 
 59   
 60  #---------------------------------------------------------------------- 
  61  # Utilities. 
 62  #---------------------------------------------------------------------- 
 63   
def classname(obj):
    """Return the name of the class recorded in obj.__class__."""
    # __class__ (rather than type()) is used on purpose: for old-style
    # instances type() would report 'instance' instead of the real class.
    cls = obj.__class__
    return cls.__name__
66
def get_node_lineno(node):
    """Return node.lineno, or 0 when that attribute is None or otherwise falsy."""
    return node.lineno or 0
#----------------------------------------------------------------------
# Restricted AST nodes & builtins.
#----------------------------------------------------------------------

# Code is rejected when its AST contains any of the node types listed
# here.  The commented-out names are the node types that stay allowed.
unallowed_ast_nodes = (
    # 'Add', 'And',
    # 'AssAttr', 'AssList', 'AssName', 'AssTuple',
    # 'Assert', 'Assign', 'AugAssign',
    'Backquote',
    # 'Bitand', 'Bitor', 'Bitxor', 'Break',
    # 'CallFunc', 'Class', 'Compare', 'Const', 'Continue',
    # 'Decorators', 'Dict', 'Discard', 'Div',
    # 'Ellipsis', 'EmptyNode',
    'Exec',
    # 'Expression', 'FloorDiv',
    # 'For',
    'From',
    # 'Function',
    # 'GenExpr', 'GenExprFor', 'GenExprIf', 'GenExprInner',
    # 'Getattr', 'Global', 'If',
    'Import',
    # 'Invert',
    # 'Keyword', 'Lambda', 'LeftShift',
    # 'List', 'ListComp', 'ListCompFor', 'ListCompIf', 'Mod',
    # 'Module',
    # 'Mul', 'Name', 'Node', 'Not', 'Or', 'Pass', 'Power',
    # 'Print', 'Printnl',
    'Raise',
    # 'Return', 'RightShift', 'Slice', 'Sliceobj',
    # 'Stmt', 'Sub', 'Subscript',
    'TryExcept', 'TryFinally',
    # 'Tuple', 'UnaryAdd', 'UnarySub',
    # 'While','Yield'
)

# Code is rejected when it refers to any of the builtins listed here.
# The commented-out names are the builtins that stay allowed.
unallowed_builtins = (
    '__import__',
    # 'abs', 'apply', 'basestring', 'bool', 'buffer',
    # 'callable', 'chr', 'classmethod', 'cmp', 'coerce',
    'compile',
    # 'complex',
    'delattr',
    # 'dict',
    'dir',
    # 'divmod', 'enumerate',
    'eval', 'execfile', 'file',
    # 'filter', 'float', 'frozenset',
    'getattr', 'globals', 'hasattr',
    # 'hash', 'hex', 'id',
    'input',
    # 'int', 'intern', 'isinstance', 'issubclass', 'iter',
    # 'len', 'list',
    'locals',
    # 'long', 'map', 'max', 'min', 'object', 'oct',
    'open',
    # 'ord', 'pow', 'property', 'range',
    'raw_input',
    # 'reduce',
    'reload',
    # 'repr', 'reversed', 'round', 'set',
    'setattr',
    # 'slice', 'sorted', 'staticmethod', 'str', 'sum', 'super',
    # 'tuple', 'type', 'unichr', 'unicode',
    'vars',
    # 'xrange', 'zip'
)

# Catch typos in the two tables above: every listed name has to exist.
for ast_name in unallowed_ast_nodes:
    assert ast_name in all_ast_nodes
for name in unallowed_builtins:
    assert name in all_builtins

# Re-bind both tables as dicts (name -> True) for faster lookup.
unallowed_ast_nodes = dict.fromkeys(unallowed_ast_nodes, True)
unallowed_builtins = dict.fromkeys(unallowed_builtins, True)

#----------------------------------------------------------------------
# Restricted attributes.
#----------------------------------------------------------------------

# In addition to these we deny access to all lowlevel attrs (__xxx__).
unallowed_attr = (
    'im_class', 'im_func', 'im_self',
    'func_code', 'func_defaults', 'func_globals', 'func_name',
    'tb_frame', 'tb_next',
    'f_back', 'f_builtins', 'f_code', 'f_exc_traceback',
    'f_exc_type', 'f_exc_value', 'f_globals', 'f_locals',
)
unallowed_attr = dict.fromkeys(unallowed_attr, True)

def is_unallowed_attr(name):
    """Return True when access to the attribute called `name` must be denied.

    '__file__' is always allowed.  Every other name that both starts and
    ends with a double underscore is denied, as is every name listed in
    'unallowed_attr'.
    """
    if name == '__file__':
        return False
    is_lowlevel = name.startswith('__') and name.endswith('__')
    return is_lowlevel or (name in unallowed_attr)
167 168 #---------------------------------------------------------------------- 169 # SafeEvalVisitor. 170 #---------------------------------------------------------------------- 171
class SafeEvalError(object):
    """
    Base class for all errors recorded while walking the AST.

    Instances are plain records (this is not an exception class).

    Attributes:
      errmsg = short description of the nature of the error
      lineno = line offset to where the error occurred in the source code
    """

    def __init__(self, errmsg, lineno):
        self.errmsg = errmsg
        self.lineno = lineno

    def __str__(self):
        fields = (self.lineno, self.errmsg)
        return "line %d : %s" % fields
184
class SafeEvalASTNodeError(SafeEvalError):
    """The code contains an expression/statement whose AST node type is restricted."""
class SafeEvalBuiltinError(SafeEvalError):
    """The code contains an expression/statement that tried to access a restricted builtin."""
class SafeEvalAttrError(SafeEvalError):
    """The code contains an expression/statement that tried to access a restricted attribute."""
194
class SafeEvalVisitor(object):
    """
    Data-driven visitor which walks the AST of some code and records an
    error for every expression/statement whose node type is listed in
    'unallowed_ast_nodes', for every name listed in 'unallowed_builtins'
    and for every name or attribute rejected by is_unallowed_attr().

    Interface:
      walk(ast) = validate AST and return True if AST is 'safe'

    Attributes:
      errors = list of SafeEvalError, non-empty if walk() returned False

    Implementation:

    The constructor creates one 'visit<NodeName>' callback per AST node
    type and points it at self.ok or self.fail according to
    'unallowed_ast_nodes'; callbacks written by hand in the class body
    are left alone.  visit() dispatches every node to its callback and
    the callbacks take care of reporting errors.
    """

    def __init__(self):
        """Initialize visitor by generating callbacks for all AST node types."""
        self.errors = []
        for ast_name in all_ast_nodes:
            callback_name = 'visit' + ast_name
            # Don't reset any overridden callbacks.
            if getattr(self, callback_name, None):
                continue
            if ast_name in unallowed_ast_nodes:
                callback = self.fail
            else:
                callback = self.ok
            setattr(self, callback_name, callback)

    def walk(self, ast):
        """Validate each node in AST and return True if AST is 'safe'."""
        self.visit(ast)
        return len(self.errors) == 0

    def visit(self, node, *args):
        """Recursively validate node and all of its children."""
        callback = getattr(self, 'visit' + classname(node))
        if DEBUG:
            self.trace(node)
        callback(node, *args)
        for child in node.getChildNodes():
            self.visit(child, *args)

    def visitName(self, node, *args):
        """Disallow any attempts to access a restricted builtin/attr."""
        name = node.getChildren()[0]
        lineno = get_node_lineno(node)
        if name in unallowed_builtins:
            error = SafeEvalBuiltinError(
                "access to builtin '%s' is denied" % name, lineno)
        elif is_unallowed_attr(name):
            error = SafeEvalAttrError(
                "access to attribute '%s' is denied" % name, lineno)
        else:
            return
        self.errors.append(error)

    def visitGetattr(self, node, *args):
        """Disallow any attempts to access a restricted attribute."""
        name = node.attrname
        lineno = get_node_lineno(node)
        if not is_unallowed_attr(name):
            return
        self.errors.append(SafeEvalAttrError(
            "access to attribute '%s' is denied" % name, lineno))

    def ok(self, node, *args):
        """Default callback for 'harmless' AST nodes: records nothing."""

    def fail(self, node, *args):
        """Default callback for unallowed AST nodes: records an error."""
        message = "execution of '%s' statements is denied" % classname(node)
        self.errors.append(
            SafeEvalASTNodeError(message, get_node_lineno(node)))

    def trace(self, node):
        """Debugging utility for tracing the validation of AST nodes."""
        # Each print takes one pre-formatted string so that the line is
        # valid both as a print statement and as a print() call.
        print(classname(node))
        for attr in dir(node):
            if attr[:2] == '__':
                continue
            print("%s %-15.15s %s" % (' ' * 4, attr, getattr(node, attr)))
280 281 ########################################################################## 282 # Veusz evaluation functions 283 ########################################################################## 284
def checkContextOkay(context):
    """Check the context statements will be executed in.

    'context' is the dict intended to be used as the execution
    environment.  None of its values may be a builtin function or a
    module.

    Returns True if context is okay.

    Raises SafeEvalContextException, carrying the offending keys and one
    message per offence, if any unallowed object is found.
    """

    ctx_errkeys, ctx_errors = [], []
    for (key, obj) in context.items():
        if inspect.isbuiltin(obj):
            ctx_errkeys.append(key)
            ctx_errors.append("key '%s' : unallowed builtin %s" % (key, obj))
        if inspect.ismodule(obj):
            ctx_errkeys.append(key)
            ctx_errors.append("key '%s' : unallowed module %s" % (key, obj))

    if ctx_errors:
        raise SafeEvalContextException(ctx_errkeys, ctx_errors)

    # The documented contract is "returns True if okay"; previously the
    # function fell off the end and returned None.
    return True
302 303 ## set up environment in dict 304 #veusz_eval_context = {} 305 306 # XXX disabled numpy support 307 ## add callables (not modules) and floats which don't override builtins 308 #for name, val in N.__dict__.iteritems(): 309 # if ( (callable(val) or type(val)==float) and 310 # name not in __builtin__.__dict__ and 311 # name[:1] != '_' and name[-1:] != '_' ): 312 # veusz_eval_context[name] = val 313 314 ## useful safe functions 315 #veusz_eval_context['os_path_join'] = os.path.join 316 #veusz_eval_context['os_path_dirname'] = os.path.dirname 317
def checkCode(code):
    """Check code, returning errors (if any) or None if okay.

    The result is a one-element list holding the SyntaxError when the
    code cannot be parsed, the list of errors collected by
    SafeEvalVisitor when validation fails, and None otherwise.
    """
    try:
        tree = compiler.parse(code)
    except SyntaxError as err:
        return [err]

    visitor = SafeEvalVisitor()
    if not visitor.walk(tree):
        return visitor.errors
    return None
331 332 #---------------------------------------------------------------------- 333 # Safe 'eval' replacement. 334 #---------------------------------------------------------------------- 335
class SafeEvalException(Exception):
    """Common base class for every safe-eval related exception."""
339
class SafeEvalCodeException(SafeEvalException):
    """
    Exception class for reporting all errors which occurred while
    validating AST for source code in safe_eval().

    Attributes:
      code   = raw source code which failed to validate
      errors = list of SafeEvalError
    """

    def __init__(self, code, errors):
        # Hand the arguments to the base class as well so that 'args' is
        # populated; with an empty 'args' the exception cannot be rebuilt
        # by pickle/copy and its repr() carries no information.
        SafeEvalException.__init__(self, code, errors)
        self.code, self.errors = code, errors

    def __str__(self):
        return '\n'.join([str(err) for err in self.errors])
353
class SafeEvalContextException(SafeEvalException):
    """
    Exception class for reporting unallowed objects found in the dict
    intended to be used as the local environment in safe_eval().

    Attributes:
      keys   = list of keys of the unallowed objects
      errors = list of strings describing the nature of the error
               for each key in 'keys'
    """

    def __init__(self, keys, errors):
        # Hand the arguments to the base class as well so that 'args' is
        # populated; with an empty 'args' the exception cannot be rebuilt
        # by pickle/copy and its repr() carries no information.
        SafeEvalException.__init__(self, keys, errors)
        self.keys, self.errors = keys, errors

    def __str__(self):
        return '\n'.join([str(err) for err in self.errors])
368
class SafeEvalTimeoutException(SafeEvalException):
    """
    Exception class for reporting that code evaluation exceeded
    the given time limit.

    Attributes:
      timeout = time limit in seconds
    """

    def __init__(self, timeout):
        # Hand the argument to the base class as well so that 'args' is
        # populated; with an empty 'args' the exception cannot be rebuilt
        # by pickle/copy and its repr() carries no information.
        SafeEvalException.__init__(self, timeout)
        self.timeout = timeout

    def __str__(self):
        return "Timeout limit execeeded (%s secs) during exec" % self.timeout
381
def _interrupt_main_after(timeout_secs, finished):
    """Watchdog thread body: interrupt the main thread once time is up.

    'finished' is a list shared with exec_timed(); it is empty while the
    code is still running and becomes non-empty as soon as execution has
    ended, at which point the watchdog exits without doing anything.
    """
    deadline = time.time() + timeout_secs
    while not finished:
        remaining = deadline - time.time()
        if remaining <= 0:
            thread.interrupt_main()
            return
        # Sleep in short slices so the thread notices 'finished' promptly.
        time.sleep(min(remaining, 0.1))


def exec_timed(code, context, timeout_secs):
    """
    Dynamically execute 'code' using 'context' as the global environment.
    SafeEvalTimeoutException is raised if execution does not finish within
    the given time limit.

    code         = source code (or code object) handed to exec
    context      = dict used as the global environment; exec may modify it
    timeout_secs = time limit in seconds, must be > 0 (may be fractional)

    Any exception raised by the executed code propagates unchanged.  A
    KeyboardInterrupt arriving while the code runs is reported as
    SafeEvalTimeoutException, because that is how the watchdog thread
    signals the main thread.
    """
    assert(timeout_secs > 0)

    # Shared flag; a list is used so the watchdog needs no closure.
    finished = []

    try:
        try:
            thread.start_new_thread(_interrupt_main_after,
                                    (timeout_secs, finished))
            # Call form of exec: equivalent to 'exec code in context'.
            exec(code, context)
        except KeyboardInterrupt:
            raise SafeEvalTimeoutException(timeout_secs)
    finally:
        # Always stop the watchdog.  Previously the flag was only set on
        # success, so code that raised left the watchdog running and it
        # interrupted the main thread later on.
        finished.append(True)

def timed_safe_eval(code, context = None, timeout_secs = 5):
    """
    Validate source code and make sure it contains no unauthorized
    expression/statements as configured via 'unallowed_ast_nodes' and
    'unallowed_builtins'. By default this means that code is not
    allowed to import modules or access dangerous builtins like 'open' or
    'eval'. If code is considered 'safe' it will be executed via
    'exec' using 'context' as the global environment. More details on
    how code is executed can be found in the Python Reference Manual
    section 6.14 (ignore the remark on '__builtins__'). The 'context'
    environment is also validated and is not allowed to contain modules
    or builtins. When 'context' is omitted (or None) a fresh empty dict
    is used for this call. The following exceptions will be raised on
    errors:

      if 'context' contains unallowed objects =
        SafeEvalContextException

      if code didn't validate and is considered 'unsafe' =
        SafeEvalCodeException

      if code did not execute within the given time limit =
        SafeEvalTimeoutException
    """
    # A shared '{}' default would be modified by exec and so would carry
    # names from one call over into the next; build a new dict per call.
    if context is None:
        context = {}

    # Raises SafeEvalContextException when the context is not acceptable.
    checkContextOkay(context)

    ast = compiler.parse(code)
    checker = SafeEvalVisitor()

    if not checker.walk(ast):
        raise SafeEvalCodeException(code, checker.errors)
    exec_timed(code, context, timeout_secs)
449 450 #---------------------------------------------------------------------- 451 # Basic tests. 452 #---------------------------------------------------------------------- 453 454 import unittest 455
class TestSafeEval(unittest.TestCase):
    """Basic checks of timed_safe_eval(): validation, timeouts, context."""

    def test_builtin(self):
        # attempt to access an unsafe builtin
        self.assertRaises(SafeEvalException,
            timed_safe_eval, "open('test.txt', 'w')")

    def test_getattr(self):
        # attempt to get around direct attr access
        self.assertRaises(SafeEvalException,
            timed_safe_eval, "getattr(int, '__abs__')")

    def test_func_globals(self):
        # attempt to access global environment where fun was defined
        self.assertRaises(SafeEvalException,
            timed_safe_eval, "def x(): pass; print x.func_globals")

    def test_lowlevel(self):
        # lowlevel tricks to access 'object'
        self.assertRaises(SafeEvalException,
            timed_safe_eval, "().__class__.mro()[1].__subclasses__()")

    def test_timeout_ok(self):
        # attempt to execute 'slow' code which finishes within the time limit
        def test(): time.sleep(2)
        env = {'test':test}
        timed_safe_eval("test()", env, timeout_secs = 5)

    def test_timeout_exceed(self):
        # attempt to execute code which never terminates
        self.assertRaises(SafeEvalException,
            timed_safe_eval, "while 1: pass")

    def test_invalid_context(self):
        # can't pass an environment with modules or builtins
        # Use the __builtin__ module: the bare name __builtins__ is a dict
        # (without an 'open' attribute) whenever this module is imported
        # rather than run as a script.
        env = {'f' : __builtin__.open, 'g' : time}
        self.assertRaises(SafeEvalException,
            timed_safe_eval, "print 1", env)

    def test_callback(self):
        # modify local variable via callback
        self.value = 0
        def test(): self.value = 1
        env = {'test':test}
        timed_safe_eval("test()", env)
        self.assertEqual(self.value, 1)

# Run the unit tests when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()

#----------------------------------------------------------------------
# End unittests
#----------------------------------------------------------------------
