# about | summary | refs | log | tree | commit | diff
# path: root/tests/unittest_transforms.py
# blob: 63ac10dd295eda4fcd67ab089e61513b14dde8f8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
# Copyright (c) 2015-2018, 2020 Claudiu Popa <pcmanticore@gmail.com>
# Copyright (c) 2015-2016 Ceridwen <ceridwenv@gmail.com>
# Copyright (c) 2016 Jakub Wilk <jwilk@jwilk.net>
# Copyright (c) 2018 Bryce Guinta <bryce.paul.guinta@gmail.com>
# Copyright (c) 2019 Ashley Whetter <ashley@awhetter.co.uk>
# Copyright (c) 2020-2021 hippo91 <guillaume.peillex@gmail.com>
# Copyright (c) 2021 Pierre Sassoulas <pierre.sassoulas@gmail.com>
# Copyright (c) 2021 Marc Mueller <30130371+cdce8p@users.noreply.github.com>

# Licensed under the LGPL: https://www.gnu.org/licenses/old-licenses/lgpl-2.1.en.html
# For details: https://github.com/PyCQA/astroid/blob/main/LICENSE


import contextlib
import time
import unittest
from typing import Callable, Iterator, Optional

from astroid import MANAGER, builder, nodes, parse, transforms
from astroid.manager import AstroidManager
from astroid.nodes.node_classes import Call, Compare, Const, Name
from astroid.nodes.scoped_nodes import FunctionDef, Module


@contextlib.contextmanager
def add_transform(
    manager: AstroidManager,
    node: type,
    transform: Callable,
    predicate: Optional[Callable] = None,
) -> Iterator:
    """Register *transform* on *manager* for the duration of a ``with`` block.

    The transform is registered for nodes of class *node* (optionally gated by
    *predicate*) on entry, and unregistered with exactly the same arguments on
    exit — even if the body of the ``with`` statement raises.
    """
    # The same positional triple is used for both registration calls so the
    # unregister call always matches what was registered.
    registration = (node, transform, predicate)
    manager.register_transform(*registration)
    try:
        yield
    finally:
        manager.unregister_transform(*registration)


class TestTransforms(unittest.TestCase):
    """Exercise ``transforms.TransformVisitor`` and manager-registered transforms."""

    def setUp(self) -> None:
        # Each test gets its own visitor, so transforms registered through
        # ``self.transformer`` are not shared between tests.
        self.transformer = transforms.TransformVisitor()

    def parse_transform(self, code: str) -> Module:
        """Parse *code* and run only ``self.transformer``'s transforms on it.

        The module is built with ``apply_transforms=False`` so the global
        manager's transforms are skipped; the tree returned by
        ``self.transformer.visit`` is what the caller gets back.
        """
        module = parse(code, apply_transforms=False)
        return self.transformer.visit(module)

    def test_function_inlining_transform(self) -> None:
        """A ``Call`` transform can replace a call with its inferred value."""

        def transform_call(node: Call) -> Const:
            # Let's do some function inlining: swap the call for the value
            # that inference says it produces.
            return next(node.infer())

        self.transformer.register_transform(nodes.Call, transform_call)

        module = self.parse_transform(
            """
        def test(): return 42
        test() #@
        """
        )

        self.assertIsInstance(module.body[1], nodes.Expr)
        self.assertIsInstance(module.body[1].value, nodes.Const)
        self.assertEqual(module.body[1].value.value, 42)

    def test_recursive_transforms_into_astroid_fields(self) -> None:
        """The transformer recurses into each node's ``_astroid_fields``."""
        # The Compare transform below relies on both of its operands having
        # already been replaced by Consts via the Name transform.

        def transform_compare(node: Compare) -> Const:
            # Let's check the values of the ops
            _, right = node.ops[0]
            # Assume they are Consts and they were transformed before us.
            return nodes.const_factory(node.left.value < right.value)

        def transform_name(node: Name) -> Const:
            # Should be Consts
            return next(node.infer())

        self.transformer.register_transform(nodes.Compare, transform_compare)
        self.transformer.register_transform(nodes.Name, transform_name)

        module = self.parse_transform(
            """
        a = 42
        b = 24
        a < b
        """
        )

        self.assertIsInstance(module.body[2], nodes.Expr)
        self.assertIsInstance(module.body[2].value, nodes.Const)
        self.assertFalse(module.body[2].value.value)

    def test_transform_patches_locals(self) -> None:
        """A transform may mutate its node in place and return ``None``."""

        def transform_function(node: FunctionDef) -> None:
            # Append ``value = 42`` to the function body. Nothing is returned;
            # the test expects the mutated function to remain in the tree.
            assign = nodes.Assign()
            name = nodes.AssignName(name="value")
            assign.targets = [name]
            assign.value = nodes.const_factory(42)
            node.body.append(assign)

        self.transformer.register_transform(nodes.FunctionDef, transform_function)

        module = self.parse_transform(
            """
        def test():
            pass
        """
        )

        func = module.body[0]
        self.assertEqual(len(func.body), 2)
        self.assertIsInstance(func.body[1], nodes.Assign)
        self.assertEqual(func.body[1].as_string(), "value = 42")

    def test_predicates(self) -> None:
        """A predicate restricts which nodes a transform is applied to."""

        def transform_call(node: Call) -> Const:
            return next(node.infer())

        def should_inline(node: Call) -> bool:
            # Only calls to functions whose name starts with "inlineme".
            return node.func.name.startswith("inlineme")

        self.transformer.register_transform(nodes.Call, transform_call, should_inline)

        module = self.parse_transform(
            """
        def inlineme_1():
            return 24
        def dont_inline_me():
            return 42
        def inlineme_2():
            return 2
        inlineme_1()
        dont_inline_me()
        inlineme_2()
        """
        )
        values = module.body[-3:]
        self.assertIsInstance(values[0], nodes.Expr)
        self.assertIsInstance(values[0].value, nodes.Const)
        self.assertEqual(values[0].value.value, 24)
        self.assertIsInstance(values[1], nodes.Expr)
        self.assertIsInstance(values[1].value, nodes.Call)
        self.assertIsInstance(values[2], nodes.Expr)
        self.assertIsInstance(values[2].value, nodes.Const)
        self.assertEqual(values[2].value.value, 2)

    def test_transforms_are_separated(self) -> None:
        """Manager transforms may run inference on the finished tree."""
        # Test that the transforming is done at a separate
        # step, which means that we are not doing inference
        # on a partially constructed tree anymore, which was the
        # source of crashes in the past when certain inference rules
        # were used in a transform.
        def transform_function(node: FunctionDef) -> Optional[Const]:
            if node.decorators:
                for decorator in node.decorators.nodes:
                    inferred = next(decorator.infer())
                    if inferred.qname() == "abc.abstractmethod":
                        return next(node.infer_call_result())
            # Not decorated with abc.abstractmethod: keep the function.
            return None

        with add_transform(MANAGER, nodes.FunctionDef, transform_function):
            module = builder.parse(
                """
            import abc
            from abc import abstractmethod

            class A(object):
                @abc.abstractmethod
                def ala(self):
                    return 24

                @abstractmethod
                def bala(self):
                    return 42
            """
            )

        cls = module["A"]
        ala = cls.body[0]
        bala = cls.body[1]
        self.assertIsInstance(ala, nodes.Const)
        self.assertEqual(ala.value, 24)
        self.assertIsInstance(bala, nodes.Const)
        self.assertEqual(bala.value, 42)

    def test_transforms_are_called_for_builtin_modules(self) -> None:
        """Manager transforms run on modules built from a live module object."""
        # Test that transforms are called for builtin modules.
        def transform_function(node: FunctionDef) -> FunctionDef:
            name = nodes.AssignName(name="value")
            node.args.args = [name]
            return node

        def predicate(node: FunctionDef) -> bool:
            # Only touch functions that belong to the ``time`` module.
            return node.root().name == "time"

        with add_transform(MANAGER, nodes.FunctionDef, transform_function, predicate):
            builder_instance = builder.AstroidBuilder()
            module = builder_instance.module_build(time)

        asctime = module["asctime"]
        self.assertEqual(len(asctime.args.args), 1)
        self.assertIsInstance(asctime.args.args[0], nodes.AssignName)
        self.assertEqual(asctime.args.args[0].name, "value")

    def test_builder_apply_transforms(self) -> None:
        """``AstroidBuilder(apply_transforms=False)`` skips manager transforms."""

        def transform_function(node: FunctionDef) -> Const:
            return nodes.const_factory(42)

        with add_transform(MANAGER, nodes.FunctionDef, transform_function):
            astroid_builder = builder.AstroidBuilder(apply_transforms=False)
            module = astroid_builder.string_build("""def test(): pass""")

        # The transform wasn't applied.
        self.assertIsInstance(module.body[0], nodes.FunctionDef)

    def test_transform_crashes_on_is_subtype_of(self) -> None:
        """Calling ``is_subtype_of`` inside a transform must not crash."""
        # Test that we don't crash when having is_subtype_of
        # in a transform, as per issue #188. This happened
        # before, when the transforms weren't in their own step.
        # NOTE(review): the snippet below defines no class, so this ClassDef
        # transform appears never to be invoked by ``self.transformer`` —
        # confirm whether this test still exercises issue #188.
        def transform_class(cls: nodes.ClassDef) -> nodes.ClassDef:
            # Both branches return ``cls`` unchanged; the point is only to
            # call ``is_subtype_of`` from within a transform.
            if cls.is_subtype_of("django.db.models.base.Model"):
                return cls
            return cls

        self.transformer.register_transform(nodes.ClassDef, transform_class)

        self.parse_transform(
            """
            # Change environ to automatically call putenv() if it exists
            import os
            putenv = os.putenv
            try:
                # This will fail if there's no putenv
                putenv
            except NameError:
                pass
            else:
                import UserDict
        """
        )


# Allow running this test module directly with the stdlib unittest runner.
if __name__ == "__main__":
    unittest.main()