aboutsummaryrefslogtreecommitdiff
path: root/cc/toolchains/impl/args_utils.bzl
blob: 55b4841d95f5c03b588df4dc62732c679e9978fc (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# Copyright 2024 The Bazel Authors. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Helper functions for working with args."""

load(":variables.bzl", "get_type")

# Limit which packages are allowed to load() this file: the //cc/toolchains
# package itself, and everything under //tests/rule_based_toolchain.
visibility([
    "//cc/toolchains",
    "//tests/rule_based_toolchain/...",
])

def get_action_type(args_list, action_type):
    """Looks up the ArgsListInfo.by_action entry for a single action type.

    Args:
        args_list: (ArgsListInfo) The args list whose by_action entries are
          searched.
        action_type: (ActionTypeInfo) The action type to find.
    Returns:
        The first entry of args_list.by_action whose action equals
        action_type. If there is no such entry, an empty placeholder with
        that action, no args and no files.
    """
    matching = [
        entry
        for entry in args_list.by_action
        if entry.action == action_type
    ]
    if matching:
        return matching[0]

    # Nothing registered for this action: hand back an empty entry so callers
    # can read the same fields without special-casing a missing action.
    return struct(action = action_type, args = (), files = depset([]))

def _lookup_type(name, nested_args, variables, overrides, actions, label, fail):
    """Resolves a variable's type from the point of view of one nested_args.

    Thin wrapper around get_type that forwards the arguments shared by every
    lookup made during validation.

    Args:
        name: (str) The name of the variable to look up.
        nested_args: (NestedArgsInfo) The nested_args performing the lookup.
          Its label is forwarded to get_type as nested_label.
        variables: (Dict[str, VariableInfo]) A mapping from variable name to
          the metadata (variable type and valid actions).
        overrides: (Dict) Type overrides currently in effect, keyed by
          variable name.
        actions: (List[ActionTypeInfo]) The actions the variable must be valid
          for.
        label: (Label) The label of the rule being validated. Forwarded to
          get_type as args_label.
        fail: The fail function, forwarded to get_type.
    Returns:
        Whatever get_type returns for this variable.
    """
    return get_type(
        name = name,
        variables = variables,
        overrides = overrides,
        actions = actions,
        args_label = label,
        nested_label = nested_args.label,
        fail = fail,
    )

def validate_nested_args(*, nested_args, variables, actions, label, fail = fail):
    """Validates the typing for a nested_args invocation.

    Visits nested_args and every descendant reachable through `.nested`. For
    each one it checks that the variable named by iterate_over (if any) is a
    list or an optional list, and that every requirement in requires_types
    holds. Requirements are checked in two passes: those with
    after_option_unwrap = False first, then (after the variables in
    unwrap_options have been narrowed from Option[T] to T) those with
    after_option_unwrap = True.

    Type narrowing done by a nested_args is recorded in a private copy of the
    overrides mapping, so it is seen by that nested_args' descendants but
    never leaks into its parent or siblings.

    Args:
        nested_args: (NestedArgsInfo) The nested_args to validate
        variables: (Dict[str, VariableInfo]) A mapping from variable name to
          the metadata (variable type and valid actions).
        actions: (List[ActionTypeInfo]) The actions we require these variables
          to be valid for.
        label: (Label) The label of the rule we're currently validating.
          Used for error messages.
        fail: The fail function. Use for testing only.
    """

    # Starlark has neither recursion nor while loops, so the tree is walked
    # with an explicit stack driven by a (very large) bounded for loop.
    pending = [(nested_args, {})]

    for _ in range(9999999):
        if not pending:
            break
        current, overrides = pending.pop()

        iterated = current.iterate_over
        if iterated != None or current.unwrap_options:
            # This node is about to narrow some types; work on a copy so the
            # mapping shared with the parent and siblings stays untouched.
            overrides = dict(**overrides)

        if iterated != None:
            iterated_type = _lookup_type(iterated, current, variables, overrides, actions, label, fail)
            if iterated_type["name"] == "list":
                # Inside the iteration, List[T] is seen as T.
                overrides[iterated] = iterated_type["elements"]
            elif iterated_type["name"] == "option" and iterated_type["elements"]["name"] == "list":
                # Inside the iteration, Option[List[T]] is seen as T.
                overrides[iterated] = iterated_type["elements"]["elements"]
            else:
                fail("Attempting to iterate over %s, but it was not a list - it was a %s" % (iterated, iterated_type["repr"]))

        # Pass 1 (unwrapped = False): check the requirements that apply before
        # options are unwrapped, then unwrap Option[T] to T where requested.
        # Pass 2 (unwrapped = True): check the remaining requirements against
        # the unwrapped types.
        for unwrapped in (False, True):
            for var_name, requirements in current.requires_types.items():
                for requirement in requirements:
                    if requirement.after_option_unwrap != unwrapped:
                        continue
                    var_type = _lookup_type(var_name, current, variables, overrides, actions, label, fail)
                    if var_type["name"] in requirement.valid_types:
                        continue
                    fail("{msg}, but {var_name} has type {type}".format(
                        msg = requirement.msg,
                        var_name = var_name,
                        type = var_type["repr"],
                    ))

            if unwrapped:
                # Options were already unwrapped at the end of the first pass.
                continue

            for var in current.unwrap_options:
                option_type = _lookup_type(var, current, variables, overrides, actions, label, fail)
                if option_type["name"] == "option":
                    overrides[var] = option_type["elements"]

        # Children inherit the overrides as narrowed by this node.
        pending.extend([(child, overrides) for child in current.nested])