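# Spaces: Build error
# (The line above is presumably residue from the page this file was copied
# from, not part of the module; it is kept only as a comment.)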
import os

from Cython.TestUtils import TransformTest
from Cython.Compiler.ParseTreeTransforms import *
from Cython.Compiler.Nodes import *
from Cython.Compiler import Main, Symtab
class TestNormalizeTree(TransformTest):
    """Tree-shape tests for the parser output and the ``NormalizeTree`` transform.

    Each test feeds a small code fragment through the parser (and, except for
    the first test, through ``NormalizeTree``) and compares the node-type dump
    from ``self.treetypes()`` with an expected listing, line by line.

    NOTE(review): the nesting inside the expected-tree literals was restored
    from the node hierarchy using two spaces per level, and the Cython
    snippets use four-space block indentation starting at column 0 -- confirm
    against what ``self.treetypes()`` / ``self.run_pipeline()`` expect.
    """

    def test_parserbehaviour_is_what_we_coded_for(self):
        # Baseline: straight out of the parser, a single-statement ``if`` body
        # is a bare ExprStatNode (not wrapped in a StatListNode).
        t = self.fragment(u"if x: y").root
        self.assertLines(u"""
(root): StatListNode
  stats[0]: IfStatNode
    if_clauses[0]: IfClauseNode
      condition: NameNode
      body: ExprStatNode
        expr: NameNode
""", self.treetypes(t))

    def test_wrap_singlestat(self):
        # After NormalizeTree the single-statement body is wrapped in a
        # StatListNode.
        t = self.run_pipeline([NormalizeTree(None)], u"if x: y")
        self.assertLines(u"""
(root): StatListNode
  stats[0]: IfStatNode
    if_clauses[0]: IfClauseNode
      condition: NameNode
      body: StatListNode
        stats[0]: ExprStatNode
          expr: NameNode
""", self.treetypes(t))

    def test_wrap_multistat(self):
        # A two-statement body ends up as one StatListNode with two entries.
        t = self.run_pipeline([NormalizeTree(None)], u"""
if z:
    x
    y
""")
        self.assertLines(u"""
(root): StatListNode
  stats[0]: IfStatNode
    if_clauses[0]: IfClauseNode
      condition: NameNode
      body: StatListNode
        stats[0]: ExprStatNode
          expr: NameNode
        stats[1]: ExprStatNode
          expr: NameNode
""", self.treetypes(t))

    def test_statinexpr(self):
        # Tuple assignment: expected to stay a single SingleAssignmentNode
        # with TupleNode operands on both sides.
        t = self.run_pipeline([NormalizeTree(None)], u"""
a, b = x, y
""")
        self.assertLines(u"""
(root): StatListNode
  stats[0]: SingleAssignmentNode
    lhs: TupleNode
      args[0]: NameNode
      args[1]: NameNode
    rhs: TupleNode
      args[0]: NameNode
      args[1]: NameNode
""", self.treetypes(t))

    def test_wrap_offagain(self):
        # Top-level statements sit directly in the root StatListNode; only the
        # ``if`` body gets its own nested StatListNode.
        t = self.run_pipeline([NormalizeTree(None)], u"""
x
y
if z:
    x
""")
        self.assertLines(u"""
(root): StatListNode
  stats[0]: ExprStatNode
    expr: NameNode
  stats[1]: ExprStatNode
    expr: NameNode
  stats[2]: IfStatNode
    if_clauses[0]: IfClauseNode
      condition: NameNode
      body: StatListNode
        stats[0]: ExprStatNode
          expr: NameNode
""", self.treetypes(t))

    def test_pass_eliminated(self):
        # A lone ``pass`` must leave the root statement list empty.
        t = self.run_pipeline([NormalizeTree(None)], u"pass")
        # assertEqual (instead of assertTrue on a comparison) reports the
        # actual length when the check fails.
        self.assertEqual(0, len(t.stats))
class TestWithTransform(object):  # (TransformTest): # Disabled!
    """Expected expansions of ``with`` statements by ``WithTransform``.

    The class derives from ``object`` rather than ``TransformTest``, so it is
    currently disabled (see the marker on the class line).

    NOTE(review): indentation inside the literals below was restored from the
    ``try``/``except``/``finally`` grammar and starts at column 0; whether the
    original literals carried extra blank lines or a common indent cannot be
    seen from here -- confirm before re-enabling.
    """

    def test_simplified(self):
        # ``with`` without an ``as`` target: the result of __enter__() is
        # discarded in the expected expansion.
        t = self.run_pipeline([WithTransform(None)], u"""
with x:
    y = z ** 3
""")
        self.assertCode(u"""
$0_0 = x
$0_2 = $0_0.__exit__
$0_0.__enter__()
$0_1 = True
try:
    try:
        $1_0 = None
        y = z ** 3
    except:
        $0_1 = False
        if (not $0_2($1_0)):
            raise
finally:
    if $0_1:
        $0_2(None, None, None)
""", t)

    def test_basic(self):
        # ``with ... as y``: the result of __enter__() is kept in a temp and
        # assigned to the target inside the inner ``try``.
        t = self.run_pipeline([WithTransform(None)], u"""
with x as y:
    y = z ** 3
""")
        self.assertCode(u"""
$0_0 = x
$0_2 = $0_0.__exit__
$0_3 = $0_0.__enter__()
$0_1 = True
try:
    try:
        $1_0 = None
        y = $0_3
        y = z ** 3
    except:
        $0_1 = False
        if (not $0_2($1_0)):
            raise
finally:
    if $0_1:
        $0_2(None, None, None)
""", t)
class TestInterpretCompilerDirectives(TransformTest):
    """
    This class tests the parallel directives AST-rewriting and importing.
    """

    # Test the parallel directives (c)importing
    import_code = u"""
cimport cython.parallel
cimport cython.parallel as par
from cython cimport parallel as par2
from cython cimport parallel
from cython.parallel cimport threadid as tid
from cython.parallel cimport threadavailable as tavail
from cython.parallel cimport prange
"""

    # Local name bound by each statement in ``import_code`` -> the fully
    # qualified directive the transform is expected to record for it.
    expected_directives_dict = {
        u'cython.parallel': u'cython.parallel',
        u'par': u'cython.parallel',
        u'par2': u'cython.parallel',
        u'parallel': u'cython.parallel',
        u"tid": u"cython.parallel.threadid",
        u"tavail": u"cython.parallel.threadavailable",
        u"prange": u"cython.parallel.prange",
    }

    def setUp(self):
        super(TestInterpretCompilerDirectives, self).setUp()

        # Build a one-stage pipeline around a fresh InterpretCompilerDirectives
        # transform bound to a default compilation context.
        compilation_options = Main.CompilationOptions(Main.default_options)
        ctx = compilation_options.create_context()

        transform = InterpretCompilerDirectives(ctx, ctx.compiler_directives)
        transform.module_scope = Symtab.ModuleScope('__main__', None, ctx)
        self.pipeline = [transform]

        # Remember the global debug flag so tearDown can put it back.
        self.debug_exception_on_error = DebugFlags.debug_exception_on_error

    def tearDown(self):
        DebugFlags.debug_exception_on_error = self.debug_exception_on_error
        # setUp chains to the base class, so tearDown has to as well;
        # otherwise anything the base setUp arranged is never undone.
        super(TestInterpretCompilerDirectives, self).tearDown()

    def test_parallel_directives_cimports(self):
        # ``cimport`` flavour of every statement in import_code.
        self.run_pipeline(self.pipeline, self.import_code)
        parallel_directives = self.pipeline[0].parallel_directives
        self.assertEqual(parallel_directives, self.expected_directives_dict)

    def test_parallel_directives_imports(self):
        # Same statements with ``cimport`` textually replaced by ``import``;
        # the recorded directives must be identical.
        self.run_pipeline(self.pipeline,
                          self.import_code.replace(u'cimport', u'import'))
        parallel_directives = self.pipeline[0].parallel_directives
        self.assertEqual(parallel_directives, self.expected_directives_dict)
# TODO: Re-enable once they're more robust.
# While the switch below is False, the debugger test base class is replaced by
# ``object`` and the Cython.Debugger imports are never executed.
if False:
    from Cython.Debugger import DebugWriter
    from Cython.Debugger.Tests.TestLibCython import DebuggerTestCase
else:
    # skip test, don't let it inherit unittest.TestCase
    DebuggerTestCase = object
class TestDebugTransform(DebuggerTestCase):
    """Checks the XML debug information found at ``self.debug_dest``.

    ``self.debug_dest`` is not set in this class; it is presumably provided
    by ``DebuggerTestCase`` -- confirm before re-enabling.
    """

    def elem_hasattrs(self, elem, attrs):
        """Return True when every name in *attrs* is a key of ``elem.attrib``."""
        return all(attr in elem.attrib for attr in attrs)

    def test_debug_info(self):
        # Validates globals, functions, locals, arguments and step-into
        # entries of the debug XML; on failure the file is dumped to stdout
        # before the error is re-raised.
        try:
            assert os.path.exists(self.debug_dest)

            t = DebugWriter.etree.parse(self.debug_dest)
            # the xpath of the standard ElementTree is primitive, don't use
            # anything fancy
            L = list(t.find('/Module/Globals'))
            assert L
            xml_globals = dict((e.attrib['name'], e.attrib['type']) for e in L)
            # equal lengths <=> no duplicate global names in the XML
            self.assertEqual(len(L), len(xml_globals))

            L = list(t.find('/Module/Functions'))
            assert L
            xml_funcs = dict((e.attrib['qualified_name'], e) for e in L)
            # equal lengths <=> no duplicate qualified function names
            self.assertEqual(len(L), len(xml_funcs))

            # test globals
            self.assertEqual('CObject', xml_globals.get('c_var'))
            self.assertEqual('PythonObject', xml_globals.get('python_var'))

            # test functions
            funcnames = ('codefile.spam', 'codefile.ham', 'codefile.eggs',
                         'codefile.closure', 'codefile.inner')
            required_xml_attrs = 'name', 'cname', 'qualified_name'
            assert all(f in xml_funcs for f in funcnames)
            # Only ``spam`` is inspected below.  The previous code unpacked
            # all five entries of ``funcnames`` into three names, which
            # raises ValueError before any further check can run.
            spam = xml_funcs['codefile.spam']

            self.assertEqual(spam.attrib['name'], 'spam')
            self.assertNotEqual('spam', spam.attrib['cname'])
            assert self.elem_hasattrs(spam, required_xml_attrs)

            # test locals of functions
            spam_locals = list(spam.find('Locals'))
            assert spam_locals
            spam_locals.sort(key=lambda e: e.attrib['name'])
            names = [e.attrib['name'] for e in spam_locals]
            self.assertEqual(list('abcd'), names)
            assert self.elem_hasattrs(spam_locals[0], required_xml_attrs)

            # test arguments of functions
            spam_arguments = list(spam.find('Arguments'))
            assert spam_arguments
            self.assertEqual(1, len(spam_arguments))

            # test step-into functions
            step_into = spam.find('StepIntoFunctions')
            spam_stepinto = [x.attrib['name'] for x in step_into]
            assert spam_stepinto
            self.assertEqual(2, len(spam_stepinto))
            assert 'puts' in spam_stepinto
            assert 'some_c_function' in spam_stepinto
        except Exception:
            # Dump the XML to help diagnose the failure -- but only if the
            # file is there, so that a missing file does not replace the
            # original error with an IOError from open().
            if os.path.exists(self.debug_dest):
                with open(self.debug_dest) as f:
                    print(f.read())
            raise
# Run every test in this module when the file is executed as a script.
if __name__ == "__main__":
    import unittest
    unittest.main()