aboutsummaryrefslogtreecommitdiff
path: root/tests/rule_based_toolchain/generate_factory.bzl
diff options
context:
space:
mode:
Diffstat (limited to 'tests/rule_based_toolchain/generate_factory.bzl')
-rw-r--r--tests/rule_based_toolchain/generate_factory.bzl130
1 files changed, 130 insertions, 0 deletions
diff --git a/tests/rule_based_toolchain/generate_factory.bzl b/tests/rule_based_toolchain/generate_factory.bzl
new file mode 100644
index 0000000..c58bb51
--- /dev/null
+++ b/tests/rule_based_toolchain/generate_factory.bzl
@@ -0,0 +1,130 @@
+# Copyright 2024 The Bazel Authors. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Generates provider factories."""
+
+load("@bazel_skylib//lib:structs.bzl", "structs")
+load("@rules_testing//lib:truth.bzl", "subjects")
+
+visibility("private")
+
+def generate_factory(type, name, attrs):
+ """Generates a factory for a custom struct.
+
+ There are three reasons we need to do so:
+ 1. It's very difficult to read providers printed by these types.
+ eg. If you have a 10 layer deep diamond dependency graph, and try to
+ print the top value, the bottom value will be printed 2^10 times.
+ 2. Collections of subjects are not well supported by rules_testing
+ eg. `FeatureInfo(flag_sets = [FlagSetInfo(...)])`
+ (You can do it, but the inner values are just regular bazel structs and
+ you can't do fluent assertions on them).
+ 3. Recursive types are not supported at all
+ eg. `FeatureInfo(implies = depset([FeatureInfo(...)]))`
+
+ To solve this, we create a factory that:
+ * Validates that the types of the children are correct.
+ * Inlines providers to their labels when unambiguous.
+
+ For example, given:
+
+ ```
+ foo = FeatureInfo(name = "foo", label = Label("//:foo"))
+ bar = FeatureInfo(..., implies = depset([foo]))
+ ```
+
+ It would convert itself a subject for the following struct:
+ `FeatureInfo(..., implies = depset([Label("//:foo")]))`
+
+ Args:
+ type: (type) The type to create a factory for (eg. FooInfo)
+ name: (str) The name of the type (eg. "FooInfo")
+ attrs: (dict[str, Factory]) The attributes associated with this type.
+
+ Returns:
+ A struct `FooFactory` suitable for use with
+ * `analysis_test(provider_subject_factories=[FooFactory])`
+ * `generate_factory(..., attrs=dict(foo = FooFactory))`
+ * `ProviderSequence(FooFactory)`
+ * `DepsetSequence(FooFactory)`
+ """
+ attrs["label"] = subjects.label
+
+ want_keys = sorted(attrs.keys())
+
+ def validate(*, value, meta):
+ if value == None:
+ meta.add_failure("Wanted a %s but got" % name, value)
+ got_keys = sorted(structs.to_dict(value).keys())
+ subjects.collection(got_keys, meta = meta.derive(details = [
+ "Value was not a %s - it has a different set of fields" % name,
+ ])).contains_exactly(want_keys).in_order()
+
+ def type_factory(value, *, meta):
+ validate(value = value, meta = meta)
+
+ transformed_value = {}
+ transformed_factories = {}
+ for field, factory in attrs.items():
+ field_value = getattr(value, field)
+
+ # If it's a type generated by generate_factory, inline it.
+ if hasattr(factory, "factory"):
+ factory.validate(value = field_value, meta = meta.derive(field))
+ transformed_value[field] = field_value.label
+ transformed_factories[field] = subjects.label
+ else:
+ transformed_value[field] = field_value
+ transformed_factories[field] = factory
+
+ return subjects.struct(
+ struct(**transformed_value),
+ meta = meta,
+ attrs = transformed_factories,
+ )
+
+ return struct(
+ type = type,
+ name = name,
+ factory = type_factory,
+ validate = validate,
+ )
+
+def _provider_collection(element_factory, fn):
+ def factory(value, *, meta):
+ value = fn(value)
+
+ # Validate that it really is the correct type
+ for i in range(len(value)):
+ element_factory.validate(
+ value = value[i],
+ meta = meta.derive("offset({})".format(i)),
+ )
+
+ # Inline the providers to just labels.
+ return subjects.collection([v.label for v in value], meta = meta)
+
+ return factory
+
+# This acts like a class, so we name it like one.
+# buildifier: disable=name-conventions
+ProviderSequence = lambda element_factory: _provider_collection(
+ element_factory,
+ fn = lambda x: list(x),
+)
+
+# buildifier: disable=name-conventions
+ProviderDepset = lambda element_factory: _provider_collection(
+ element_factory,
+ fn = lambda x: x.to_list(),
+)