Python ast 模块，NodeTransformer() 实例源码
我们从 Python 开源项目中，提取了以下 50 个代码示例，用于说明如何使用 ast.NodeTransformer()。
def filterstr_to_filterfunc(filter_str: str, logged_in: bool) -> Callable[['Post'], bool]:
    """Compile an ``--only-if=...`` filter expression into a ``Post -> bool`` predicate.

    Every bare name in *filter_str* is rewritten into an attribute load on a
    variable called ``post``. The rewritten expression is compiled once, and the
    returned function evaluates it with ``post`` bound to the post under test.

    While building the predicate, ``invalidargumentexception`` is raised if the
    filter assigns to a name, uses a name that is not an attribute of ``Post``,
    or uses a name listed in ``Post.LOGIN_REQUIRING_PROPERTIES`` while
    *logged_in* is false. A ``SyntaxError`` from ``ast.parse`` is not caught.
    """
    label = '<--only-if parameter>'

    class _NameToPostAttribute(ast.NodeTransformer):
        def visit_Name(self, node: ast.Name):
            # pylint:disable=invalid-name,no-self-use
            attr = node.id
            if not isinstance(node.ctx, ast.Load):
                raise invalidargumentexception("Invalid filter: Modifying variables ({}) not allowed.".format(attr))
            if not hasattr(Post, attr):
                raise invalidargumentexception("Invalid filter: Name {} is not defined.".format(attr))
            if attr in Post.LOGIN_REQUIRING_PROPERTIES and not logged_in:
                raise invalidargumentexception("Invalid filter: Name {} requires being logged in.".format(attr))
            # `name` -> `post.name`, keeping the original source position.
            post_ref = ast.copy_location(ast.Name('post', ast.Load()), node)
            load = ast.copy_location(ast.Load(), node)
            return ast.copy_location(ast.Attribute(post_ref, attr, load), node)

    parsed = ast.parse(filter_str, filename=label, mode='eval')
    compiled_filter = compile(_NameToPostAttribute().visit(parsed), filename=label, mode='eval')

    def filterfunc(post: 'Post') -> bool:
        # NOTE(review): this evaluates user-supplied text. Only bare names are
        # restricted to Post attributes; attribute accesses and method calls
        # on those attributes are not checked, so the filter must be trusted.
        # pylint:disable=eval-used
        return bool(eval(compiled_filter, {'post': post}))

    return filterfunc
def subs(root, **kwargs):
    '''Substitute ast.Name nodes for numbers in root using the mapping in
    kwargs. Returns a new copy of root.

    A loaded (non-Store) ``Name`` whose id is a key of *kwargs* is replaced:

    - a numeric value (int, float, complex; bool counts as int) becomes an
      ``ast.Constant`` placed at the name's source position;
    - any other value is assumed to be an AST node and a deep copy of it is
      spliced in, so substitutions at different sites never share nodes.

    Function definitions are not descended into. *root* itself is not modified.
    '''
    root = copy.deepcopy(root)

    class Transformer(ast.NodeTransformer):
        def visit_FunctionDef(self, node):
            # Do not descend into nested function definitions.
            return node

        def visit_Name(self, node):
            if node.id not in kwargs or isinstance(node.ctx, ast.Store):
                return node
            replacement = kwargs[node.id]
            if isinstance(replacement, (int, float, complex)):
                # ast.Constant replaces ast.Num (deprecated since 3.8, removed in 3.14).
                return ast.copy_location(ast.Constant(value=replacement), node)
            # Deep copy so that one replacement used at several sites does not
            # leave a single node object shared across the tree.
            return copy.deepcopy(replacement)

    return Transformer().visit(root)
def sub_subscript(root, subs):
    """Return a copy of *root* with selected subscript expressions replaced.

    Each ``ast.Subscript`` whose key, as computed by ``subscript_to_tuple``,
    is in the mapping *subs* is replaced by a deep copy of ``subs[key]``.
    Subscripts that ``subscript_to_tuple`` rejects with ``ValueError`` are kept
    unchanged. Function definitions are not descended into. *root* itself is
    not modified.
    """
    root = copy.deepcopy(root)

    class Transformer(ast.NodeTransformer):
        def visit_FunctionDef(self, node):
            # Do not descend into nested function definitions.
            return node

        def visit_Subscript(self, node):
            # Rewrite inner subscripts first, so the outer key is computed on
            # the already-substituted children.
            self.generic_visit(node)
            try:
                node_tup = subscript_to_tuple(node)
            except ValueError:
                return node
            if node_tup not in subs:
                return node
            # Copy so that a replacement used at several sites never shares nodes.
            return copy.deepcopy(subs[node_tup])

    return Transformer().visit(root)
def generic_visit(self, node):
    """Visit node's children, then rewrite node.body for concurrent calls.

    For nodes with a list-valued ``body``:
    - the statements from ``self.get_waits()`` are inserted before the first
      direct-child ``return``;
    - call statements accepted by ``self.is_concurrent_call`` are registered
      via ``self.encounter_call``;
    - assignments accepted by ``self.is_valid_assignment`` are registered and
      rewritten into ``<call>.assign((name, index), *args)`` expression
      statements;
    - waits from ``self.get_waits()`` are inserted before each statement for
      which ``self.references_arg`` is true.

    The node is mutated in place; nothing is returned explicitly (None).
    """
    # super(NodeTransformer, self) starts the MRO lookup *after* NodeTransformer,
    # so this runs the next class's generic_visit (ast.NodeVisitor's, if
    # NodeTransformer is ast.NodeTransformer). That visits children but does not
    # splice replacement nodes returned by visit_* methods back into the tree.
    # NOTE(review): confirm that bypassing NodeTransformer.generic_visit is intended.
    super(NodeTransformer, self).generic_visit(node)
    # Only a list-valued `body` is rewritten; `orelse`, `handlers` and
    # `finalbody` statement lists are not touched here.
    if hasattr(node, 'body') and type(node.body) is list:
        # Positions of direct-child `return` statements (exact type match).
        returns = [i for i, child in enumerate(node.body) if type(child) is ast.Return]
        if len(returns) > 0:
            # Every wait is inserted at the same index, so they end up before the
            # first return in reverse order of get_waits().
            for wait in self.get_waits():
                node.body.insert(returns[0], wait)
        inserts = []
        for i, child in enumerate(node.body):
            if type(child) is ast.Expr and self.is_concurrent_call(child.value):
                # Bare call statement: only registered, not rewritten.
                self.encounter_call(child.value)
            elif self.is_valid_assignment(child):
                # Presumably `name[index] = call(args...)` (targets[0] is read as a
                # Subscript; verify against is_valid_assignment). It is rewritten to
                # the expression statement `call.assign((name, index), args...)`.
                call = child.value
                self.encounter_call(call)
                name = child.targets[0].value
                self.arguments.add(SchedulerRewriter.top_level_name(name))
                # NOTE(review): `.slice.value` assumes the pre-3.9 ast.Index wrapper;
                # on Python 3.9+ `slice` is the index expression itself.
                index = child.targets[0].slice.value
                call.func = ast.Attribute(call.func, 'assign', ast.Load())
                call.args = [ast.Tuple([name, index], ast.Load())] + call.args
                # In-place replacement keeps the body length, so the enumeration
                # indices stay valid.
                node.body[i] = ast.Expr(call)
            elif self.references_arg(child):
                # Collected in descending order, so inserting at a later position
                # never shifts a position still waiting to be processed.
                inserts.insert(0, i)
        for index in inserts:
            # get_waits() is called again for each recorded position.
            for wait in self.get_waits():
                node.body.insert(index, wait)
def visit(self, node):
    """Ensure statement only contains allowed nodes.

    Raises ``SyntaxError`` when *node* is not an instance of one of the types in
    ``self.ALLOWED``. The message shows ``self.statement`` with a caret under
    the offending column. Otherwise delegates to ``ast.NodeTransformer.visit``.
    """
    if not isinstance(node, self.ALLOWED):
        # Not every AST node carries a position (e.g. ast.Expression, operator
        # and context nodes). Fall back to column 0 so the intended SyntaxError
        # is raised instead of an AttributeError.
        col = getattr(node, 'col_offset', None) or 0
        raise SyntaxError('Not allowed in environment markers.\n%s\n%s' %
                          (self.statement,
                           (' ' * col) + '^'))
    return ast.NodeTransformer.visit(self, node)
def resolve_negative_literals(_ast):
    """Fold a unary minus applied to a numeric literal into one negative literal.

    ``UnaryOp(USub, Constant(1))`` (source ``-1``) becomes ``Constant(-1)``.
    Operands are rewritten first, so nested forms such as ``-(-1)`` and
    ``-x[-1]`` are folded too. Booleans are not treated as numbers, matching
    ``ast.Num``.

    The tree is mutated in place; the (possibly replaced) root is returned.
    """
    class RewriteUnaryOp(ast.NodeTransformer):
        def visit_UnaryOp(self, node):
            # Visit children first; without this, literals nested under another
            # UnaryOp were never reached.
            self.generic_visit(node)
            operand = node.operand
            # ast.Constant replaces ast.Num (deprecated since 3.8, removed in 3.14).
            if (isinstance(node.op, ast.USub)
                    and isinstance(operand, ast.Constant)
                    and isinstance(operand.value, (int, float, complex))
                    and not isinstance(operand.value, bool)):
                # Use true negation, not `0 - n`, so that -0.0 stays -0.0.
                operand.value = -operand.value
                return operand
            return node

    return RewriteUnaryOp().visit(_ast)
# Make the getters for a variable. This function returns a list of
# 4-tuples:
# (i) the tail of the function name for the getter
# (ii) the code for the arguments that the function takes
# (iii) the code for the return
# (iv) the output type
#
# Here is an example:
#
# Input: my_variable: {foo: num,bar: decimal[5]}
#
# Output:
#
# [('__foo','','.foo','num'),
# ('__bar','arg0: num,','.bar[arg0]','decimal')]
#
# The getters will have code:
# def get_my_variable__foo() -> num: return self.my_variable.foo
# def get_my_variable__bar(arg0: num) -> decimal: return self.my_variable.bar[arg0]
def apply_ast_transform(func, ast_transformer, *,
                        keep_original=True, globals_dict=None, debug=0):
    """
    Apply the AST transform class to a function

    Args:
        func: function whose AST (obtained via get_func_ast) is transformed.
        ast_transformer: a DecoratorAST subclass (instantiated with no
            arguments) or a DecoratorAST instance.
        keep_original: True to retain the old function in attribute `.f_original`
        globals_dict: pass any external function in your NodeTransformer into this
        debug: when truthy, print the AST before and after the transform.

    Returns:
        The object named ``func.__name__`` defined by executing the transformed
        AST, with an ``f_original`` attribute set to *func* or None.

    Raises:
        TypeError: if ast_transformer is neither a DecoratorAST subclass nor a
            DecoratorAST instance.
    """
    if inspect.isclass(ast_transformer) and issubclass(ast_transformer, DecoratorAST):
        ast_transformer = ast_transformer()
    elif not isinstance(ast_transformer, DecoratorAST):
        # An explicit check rather than `assert`: asserts are stripped under
        # `python -O`, which would let an arbitrary object through.
        raise TypeError(
            "ast_transformer must be a DecoratorAST subclass or instance, "
            "got {!r}".format(ast_transformer))
    old_ast = get_func_ast(func)
    if debug:
        print("======= OLD AST =======")
        ast_print(old_ast)
    new_ast = ast_transformer.visit(old_ast)
    if debug:
        print("======= NEW AST =======")
        ast_print(new_ast)
    # Nodes created by the transformer may lack line/column info that compile() needs.
    ast.fix_missing_locations(new_ast)
    co = compile(new_ast, '<ast_demo>', 'exec')
    # exec defines the new function in fake_locals, so it does not clash with
    # names in the real locals().
    # https://stackoverflow.com/questions/24733831/using-a-function-defined-in-an-execed-string-in-python-3
    # NOTE(review): with globals_dict=None, exec uses the globals of this
    # module (the calling frame), not func.__globals__.
    fake_locals = {}
    exec(co, globals_dict, fake_locals)
    new_f = fake_locals[func.__name__]
    new_f.f_original = func if keep_original else None
    return new_f
def test_transformer_call_visitor(self):
    """An exception raised by a nested visitor method is wrapped only once."""
    class BuggyTransformer(fatoptimizer.tools.NodeTransformer):
        def visit_Module(self, node):
            # visit_Module() calls indirectly visit_BinOp(),
            # but the exception must only be wrapped once
            self.generic_visit(node)

        def visit_BinOp(self, node):
            # The ast class is `BinOp`. With NodeVisitor-style dispatch on
            # 'visit_' + type(node).__name__, a method spelled `visit_Binop`
            # would never be called and the test would not exercise the bug path.
            raise Exception("bug")

    visitor = BuggyTransformer("<string>")
    self.check_call_visitor(visitor)
def test_transformer_pass_optimizer_error(self):
    """An OptimizerError raised by a nested visitor method is passed through."""
    class BuggyTransformer(fatoptimizer.tools.NodeTransformer):
        def visit_Module(self, node):
            # visit_Module() calls indirectly visit_BinOp()
            self.generic_visit(node)

        def visit_BinOp(self, node):
            # Must match the ast class name `BinOp`; a method spelled
            # `visit_Binop` is never dispatched to.
            raise fatoptimizer.OptimizerError

    visitor = BuggyTransformer("<string>")
    self.check_pass_optimizer_error(visitor)
def flatten_array_declarations(root):
    """Expand array declarations into one scalar declaration per element.

    An assignment of the form ``name = Type(args...)[shape...]`` is replaced by
    one ``<element> = Type(args...)`` assignment for every index tuple in the
    cartesian product of ``range(n)`` over the shape. The element name is
    ``flattened_array_name(name, indices)``, and the shape comes from
    ``slice_node_to_tuple_of_numbers`` (presumably a tuple of ints; used with
    ``range``). Keyword arguments of the original call are not carried over.
    Function definitions are not descended into. Other assignments are unchanged.
    """
    class Transformer(ast.NodeTransformer):
        def visit_FunctionDef(self, node):
            # Do not descend into function definitions.
            return node

        def visit_Assign(self, node):
            if not (isinstance(node.value, ast.Subscript) and isinstance(node.value.value, ast.Call)):
                return node
            subscr = node.value
            call = subscr.value
            if len(node.targets) > 1:
                error.error('Cannot use multiple assignment in array declaration.', node)
            variable_name = node.targets[0].id
            value_type = call.func.id
            declaration_args = call.args
            # Get the shape of the declared array.
            shape = slice_node_to_tuple_of_numbers(subscr.slice)
            new_assigns = []
            for indices in itertools.product(*[range(n) for n in shape]):
                index_name = flattened_array_name(variable_name, indices)
                new_index_name_node = ast.copy_location(ast.Name(index_name, ast.Store()), node)
                # The type name is read, so it needs a Load context.
                new_value_type_node = ast.copy_location(ast.Name(value_type, ast.Load()), node)
                # Each element gets its own copy of the arguments so no node is shared.
                new_declaration_args = [copy.deepcopy(arg) for arg in declaration_args]
                # Python 3's ast.Call takes (func, args, keywords); the Python 2
                # starargs/kwargs slots no longer exist.
                new_call_node = ast.copy_location(ast.Call(new_value_type_node, new_declaration_args, []), node)
                new_assign = ast.copy_location(ast.Assign([new_index_name_node], new_call_node), node)
                new_assigns.append(new_assign)
            return new_assigns

    return Transformer().visit(root)
def flatten_array_lookups(root):
    """Replace array lookups by the flattened scalar variable names.

    Each ``name[indices]`` subscript is replaced by the ``Name``
    ``flattened_array_name(name, indices)``, keeping the subscript's load/store
    context. The indices come from ``slice_node_to_tuple_of_numbers``.
    Function definitions are not descended into, consistent with
    ``flatten_array_declarations``.
    """
    class Transformer(ast.NodeTransformer):
        def visit_FunctionDef(self, node):
            # Do not descend into function definitions.
            return node

        def visit_Subscript(self, node):
            # This logic needs a Subscript (`.slice`, `.value`). Previously it
            # sat in visit_FunctionDef, where `node.slice` does not exist.
            self.generic_visit(node)
            # Get the indices being accessed.
            indices = slice_node_to_tuple_of_numbers(node.slice)
            variable_name = node.value.id
            index_name = flattened_array_name(variable_name, indices)
            return ast.copy_location(ast.Name(index_name, node.ctx), node)

    return Transformer().visit(root)
def add_input_indices(root, input_vars, index_var):
    """Index every reference to an input variable by *index_var*.

    A bare ``Name`` whose id is in *input_vars* becomes
    ``name[index_var]`` with the original load/store/del context. A subscript
    whose variable (per ``get_var_name``) is an input is delegated to
    ``extend_subscript_for_input``. Missing source locations are filled in
    before returning.
    """
    class AddInputIndicesVisitor(ast.NodeTransformer):
        def visit_Subscript(self, node):
            if get_var_name(node) in input_vars:
                return extend_subscript_for_input(node, index_var)
            return node

        def visit_Name(self, node):
            if node.id not in input_vars:
                return node
            # The base of a subscript is always read, even when the subscript
            # is stored to (`x = 1` -> `x[i] = 1`). Reusing a Store-context Name
            # as the base made compile() reject the tree.
            base = ast.copy_location(ast.Name(id=node.id, ctx=ast.Load()), node)
            return ast.copy_location(ast.Subscript(base, ast.Index(index_var), node.ctx), node)

    vis = AddInputIndicesVisitor()
    root = vis.visit(root)
    return ast.fix_missing_locations(root)
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 dio@foxmail.com 举报,一经查实,本站将立刻删除。