aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSid Nayyar <sidnayyar@google.com>2024-04-16 12:33:15 +0100
committerGiuliano Procida <gprocida@google.com>2024-04-25 22:06:08 +0100
commitcee934c475d281bee29eb13541311531dfb64103 (patch)
tree3c2ae2dfb5f93bcb2723a33f2082eeea02bf6c9d
parent5ef4e13addf4037dc4f83ebcd76e926fb59884d7 (diff)
downloadstg-cee934c475d281bee29eb13541311531dfb64103.tar.gz
rust: add `Variant` node
These nodes will be used to represent 'fieldful' or tagged Rust enums. STG Rust ABI representation is unstable and is not yet subject to format versioning. PiperOrigin-RevId: 625282784 Change-Id: Ife43024161f7d47cc98584c1f9d7afd08f08a345
-rw-r--r--comparison.cc17
-rw-r--r--comparison.h1
-rw-r--r--equality.h7
-rw-r--r--fidelity.cc6
-rw-r--r--fingerprint.cc6
-rw-r--r--graph.h31
-rw-r--r--naming.cc6
-rw-r--r--naming.h1
-rw-r--r--proto_reader.cc7
-rw-r--r--proto_writer.cc13
-rw-r--r--stable_hash.cc7
-rw-r--r--stable_hash.h1
-rw-r--r--stg.proto17
-rw-r--r--substitution.h5
-rw-r--r--type_normalisation.cc5
-rw-r--r--type_resolution.cc10
-rw-r--r--unification.cc8
17 files changed, 143 insertions, 5 deletions
diff --git a/comparison.cc b/comparison.cc
index d5067b6..8e0dacb 100644
--- a/comparison.cc
+++ b/comparison.cc
@@ -560,6 +560,23 @@ Result Compare::operator()(const Enumeration& x1, const Enumeration& x2) {
return result;
}
+Result Compare::operator()(const Variant& x1, const Variant& x2) {
+ Result result;
+ // Compare two identically named variants recursively, holding diffs.
+ // Everything else treated as distinct. No recursion.
+ if (x1.name != x2.name) {
+ return result.MarkIncomparable();
+ }
+ result.diff_.holds_changes = true; // Anonymous variants are not allowed.
+
+ result.MaybeAddNodeDiff("bytesize", x1.bytesize, x2.bytesize);
+ const auto type_diff =
+ (*this)(x1.discriminant_type_id, x2.discriminant_type_id);
+ result.MaybeAddEdgeDiff("discriminant", type_diff);
+ CompareNodes(result, *this, x1.members, x2.members);
+ return result;
+}
+
Result Compare::operator()(const Function& x1, const Function& x2) {
Result result;
const auto type_diff = (*this)(x1.return_type_id, x2.return_type_id);
diff --git a/comparison.h b/comparison.h
index 7d49803..034c16a 100644
--- a/comparison.h
+++ b/comparison.h
@@ -288,6 +288,7 @@ struct Compare {
Result operator()(const VariantMember&, const VariantMember&);
Result operator()(const StructUnion&, const StructUnion&);
Result operator()(const Enumeration&, const Enumeration&);
+ Result operator()(const Variant&, const Variant&);
Result operator()(const Function&, const Function&);
Result operator()(const ElfSymbol&, const ElfSymbol&);
Result operator()(const Interface&, const Interface&);
diff --git a/equality.h b/equality.h
index 6fb4e72..eb47db9 100644
--- a/equality.h
+++ b/equality.h
@@ -194,6 +194,13 @@ struct Equals {
return result;
}
+ bool operator()(const Variant& x1, const Variant& x2) {
+ return x1.name == x2.name
+ && x1.bytesize == x2.bytesize
+ && (*this)(x1.discriminant_type_id, x2.discriminant_type_id)
+ && (*this)(x1.members, x2.members);
+ }
+
bool operator()(const Function& x1, const Function& x2) {
return (*this)(x1.parameters, x2.parameters)
&& (*this)(x1.return_type_id, x2.return_type_id);
diff --git a/fidelity.cc b/fidelity.cc
index 250c9d9..66ab9d2 100644
--- a/fidelity.cc
+++ b/fidelity.cc
@@ -56,6 +56,7 @@ struct Fidelity {
void operator()(const VariantMember&, Id);
void operator()(const StructUnion&, Id);
void operator()(const Enumeration&, Id);
+ void operator()(const Variant&, Id);
void operator()(const Function&, Id);
void operator()(const ElfSymbol&, Id);
void operator()(const Interface&, Id);
@@ -151,6 +152,11 @@ void Fidelity::operator()(const Enumeration& x, Id id) {
}
}
+void Fidelity::operator()(const Variant& x, Id id) {
+ types.emplace(describe(id).ToString(), TypeFidelity::FULLY_DEFINED);
+ (*this)(x.members);
+}
+
void Fidelity::operator()(const Function& x, Id) {
(*this)(x.return_type_id);
(*this)(x.parameters);
diff --git a/fingerprint.cc b/fingerprint.cc
index 6f1c265..05a4cdc 100644
--- a/fingerprint.cc
+++ b/fingerprint.cc
@@ -138,6 +138,12 @@ struct Hasher {
return h;
}
+ HashValue operator()(const Variant& x) {
+ auto h = hash('v', x.name, x.bytesize, (*this)(x.discriminant_type_id));
+ ToDo(x.members);
+ return h;
+ }
+
HashValue operator()(const Function& x) {
auto h = hash('F', (*this)(x.return_type_id));
for (const auto& parameter : x.parameters) {
diff --git a/graph.h b/graph.h
index bb51625..04ff7cd 100644
--- a/graph.h
+++ b/graph.h
@@ -243,6 +243,20 @@ struct Enumeration {
std::optional<Definition> definition;
};
+struct Variant {
+ Variant(const std::string& name, uint64_t bytesize, Id discriminant_type_id,
+ const std::vector<Id>& members)
+ : name(name),
+ bytesize(bytesize),
+ discriminant_type_id(discriminant_type_id),
+ members(members) {}
+
+ std::string name;
+ uint64_t bytesize;
+ Id discriminant_type_id;
+ std::vector<Id> members;
+};
+
struct Function {
Function(Id return_type_id, const std::vector<Id>& parameters)
: return_type_id(return_type_id), parameters(parameters) {}
@@ -382,6 +396,9 @@ class Graph {
} else if constexpr (std::is_same_v<Node, Enumeration>) {
reference = {Which::ENUMERATION, enumeration_.size()};
enumeration_.emplace_back(std::forward<Args>(args)...);
+ } else if constexpr (std::is_same_v<Node, Variant>) {
+ reference = {Which::VARIANT, variant_.size()};
+ variant_.emplace_back(std::forward<Args>(args)...);
} else if constexpr (std::is_same_v<Node, Function>) {
reference = {Which::FUNCTION, function_.size()};
function_.emplace_back(std::forward<Args>(args)...);
@@ -456,6 +473,7 @@ class Graph {
VARIANT_MEMBER,
STRUCT_UNION,
ENUMERATION,
+ VARIANT,
FUNCTION,
ELF_SYMBOL,
INTERFACE,
@@ -476,6 +494,7 @@ class Graph {
std::vector<VariantMember> variant_member_;
std::vector<StructUnion> struct_union_;
std::vector<Enumeration> enumeration_;
+ std::vector<Variant> variant_;
std::vector<Function> function_;
std::vector<ElfSymbol> elf_symbol_;
std::vector<Interface> interface_;
@@ -513,6 +532,8 @@ Result Graph::Apply(FunctionObject& function, Id id, Args&&... args) const {
return function(struct_union_[ix], std::forward<Args>(args)...);
case Which::ENUMERATION:
return function(enumeration_[ix], std::forward<Args>(args)...);
+ case Which::VARIANT:
+ return function(variant_[ix], std::forward<Args>(args)...);
case Which::FUNCTION:
return function(function_[ix], std::forward<Args>(args)...);
case Which::ELF_SYMBOL:
@@ -572,6 +593,9 @@ Result Graph::Apply2(
case Which::ENUMERATION:
return function(enumeration_[ix1], enumeration_[ix2],
std::forward<Args>(args)...);
+ case Which::VARIANT:
+ return function(variant_[ix1], variant_[ix2],
+ std::forward<Args>(args)...);
case Which::FUNCTION:
return function(function_[ix1], function_[ix2],
std::forward<Args>(args)...);
@@ -628,6 +652,13 @@ struct InterfaceKey {
return "enum " + x.name;
}
+ std::string operator()(const stg::Variant& x) const {
+ if (x.name.empty()) {
+ Die() << "anonymous variant interface type";
+ }
+ return "variant " + x.name;
+ }
+
std::string operator()(const stg::ElfSymbol& x) const {
return VersionedSymbolName(x);
}
diff --git a/naming.cc b/naming.cc
index c30de56..77548c1 100644
--- a/naming.cc
+++ b/naming.cc
@@ -225,6 +225,12 @@ Name Describe::operator()(const Enumeration& x) {
return Name{os.str()};
}
+Name Describe::operator()(const Variant& x) {
+ std::ostringstream os;
+ os << "variant " << x.name;
+ return Name{os.str()};
+}
+
Name Describe::operator()(const Function& x) {
std::ostringstream os;
os << '(';
diff --git a/naming.h b/naming.h
index 0e5c3a1..87496b3 100644
--- a/naming.h
+++ b/naming.h
@@ -71,6 +71,7 @@ struct Describe {
Name operator()(const VariantMember&);
Name operator()(const StructUnion&);
Name operator()(const Enumeration&);
+ Name operator()(const Variant&);
Name operator()(const Function&);
Name operator()(const ElfSymbol&);
Name operator()(const Interface&);
diff --git a/proto_reader.cc b/proto_reader.cc
index 88ec127..f2683ef 100644
--- a/proto_reader.cc
+++ b/proto_reader.cc
@@ -65,6 +65,7 @@ struct Transformer {
void AddNode(const BaseClass&);
void AddNode(const Method&);
void AddNode(const Member&);
+ void AddNode(const Variant&);
void AddNode(const StructUnion&);
void AddNode(const Enumeration&);
void AddNode(const VariantMember&);
@@ -115,6 +116,7 @@ Id Transformer::Transform(const proto::STG& x) {
AddNodes(x.variant_member());
AddNodes(x.struct_union());
AddNodes(x.enumeration());
+ AddNodes(x.variant());
AddNodes(x.function());
AddNodes(x.elf_symbol());
AddNodes(x.symbols());
@@ -224,6 +226,11 @@ void Transformer::AddNode(const Enumeration& x) {
}
}
+void Transformer::AddNode(const Variant& x) {
+ AddNode<stg::Variant>(GetId(x.id()), x.name(), x.bytesize(),
+ GetId(x.discriminant_type_id()), x.member_id());
+}
+
void Transformer::AddNode(const Function& x) {
AddNode<stg::Function>(GetId(x.id()), GetId(x.return_type_id()),
x.parameter_id());
diff --git a/proto_writer.cc b/proto_writer.cc
index a4ee972..ccbd683 100644
--- a/proto_writer.cc
+++ b/proto_writer.cc
@@ -77,6 +77,7 @@ struct Transform {
void operator()(const stg::VariantMember&, uint32_t);
void operator()(const stg::StructUnion&, uint32_t);
void operator()(const stg::Enumeration&, uint32_t);
+ void operator()(const stg::Variant&, uint32_t);
void operator()(const stg::Function&, uint32_t);
void operator()(const stg::ElfSymbol&, uint32_t);
void operator()(const stg::Interface&, uint32_t);
@@ -263,6 +264,18 @@ void Transform<MapId>::operator()(const stg::Enumeration& x, uint32_t id) {
}
template <typename MapId>
+void Transform<MapId>::operator()(const stg::Variant& x, uint32_t id) {
+ auto& variant = *stg.add_variant();
+ variant.set_id(id);
+ variant.set_name(x.name);
+ variant.set_bytesize(x.bytesize);
+ variant.set_discriminant_type_id((*this)(x.discriminant_type_id));
+ for (const auto id : x.members) {
+ variant.add_member_id((*this)(id));
+ }
+}
+
+template <typename MapId>
void Transform<MapId>::operator()(const stg::Function& x, uint32_t id) {
auto& function = *stg.add_function();
function.set_id(id);
diff --git a/stable_hash.cc b/stable_hash.cc
index 725cf61..a8f9366 100644
--- a/stable_hash.cc
+++ b/stable_hash.cc
@@ -159,6 +159,13 @@ HashValue StableHash::operator()(const Enumeration& x) {
hash, DecayHashCombineInReverse<8>(x.definition->enumerators, hash_enum));
}
+HashValue StableHash::operator()(const Variant& x) {
+ HashValue hash = hash_('V', x.name, x.bytesize);
+ hash = DecayHashCombine<8>(hash, (*this)(x.discriminant_type_id));
+ return DecayHashCombine<2>(hash,
+ DecayHashCombineInReverse<8>(x.members, *this));
+}
+
HashValue StableHash::operator()(const Function& x) {
return DecayHashCombine<2>(hash_('f', (*this)(x.return_type_id)),
DecayHashCombineInReverse<4>(x.parameters, *this));
diff --git a/stable_hash.h b/stable_hash.h
index 75cbb92..b0b9265 100644
--- a/stable_hash.h
+++ b/stable_hash.h
@@ -48,6 +48,7 @@ class StableHash {
HashValue operator()(const VariantMember&);
HashValue operator()(const StructUnion&);
HashValue operator()(const Enumeration&);
+ HashValue operator()(const Variant&);
HashValue operator()(const Function&);
HashValue operator()(const ElfSymbol&);
HashValue operator()(const Interface&);
diff --git a/stg.proto b/stg.proto
index 9c21d81..7703b79 100644
--- a/stg.proto
+++ b/stg.proto
@@ -201,6 +201,14 @@ message Enumeration {
optional Definition definition = 3;
}
+message Variant {
+ fixed32 id = 1;
+ string name = 2;
+ uint64 bytesize = 3;
+ fixed32 discriminant_type_id = 4;
+ repeated fixed32 member_id = 5;
+}
+
message Function {
fixed32 id = 1;
fixed32 return_type_id = 2;
@@ -278,8 +286,9 @@ message STG {
repeated VariantMember variant_member = 15;
repeated StructUnion struct_union = 16;
repeated Enumeration enumeration = 17;
- repeated Function function = 18;
- repeated ElfSymbol elf_symbol = 19;
- repeated Symbols symbols = 20;
- repeated Interface interface = 21;
+ repeated Variant variant = 18;
+ repeated Function function = 19;
+ repeated ElfSymbol elf_symbol = 20;
+ repeated Symbols symbols = 21;
+ repeated Interface interface = 22;
}
diff --git a/substitution.h b/substitution.h
index de0dad0..863115f 100644
--- a/substitution.h
+++ b/substitution.h
@@ -119,6 +119,11 @@ struct Substitute {
}
}
+ void operator()(Variant& x) {
+ Update(x.discriminant_type_id);
+ Update(x.members);
+ }
+
void operator()(Function& x) {
Update(x.parameters);
Update(x.return_type_id);
diff --git a/type_normalisation.cc b/type_normalisation.cc
index 70b0493..aeac699 100644
--- a/type_normalisation.cc
+++ b/type_normalisation.cc
@@ -145,6 +145,11 @@ struct FindQualifiedTypesAndFunctions {
}
}
+ void operator()(const Variant& x, Id) {
+ (*this)(x.discriminant_type_id);
+ (*this)(x.members);
+ }
+
void operator()(const Function& x, Id node_id) {
functions.emplace(node_id);
for (auto& id : x.parameters) {
diff --git a/type_resolution.cc b/type_resolution.cc
index d28db66..c9fe51d 100644
--- a/type_resolution.cc
+++ b/type_resolution.cc
@@ -45,7 +45,7 @@ struct NamedTypes {
seen.Reserve(graph.Limit());
}
- enum class Tag { STRUCT, UNION, ENUM, TYPEDEF };
+ enum class Tag { STRUCT, UNION, ENUM, TYPEDEF, VARIANT };
using Type = std::pair<Tag, std::string>;
struct Info {
std::vector<Id> definitions;
@@ -160,6 +160,14 @@ struct NamedTypes {
}
}
+ void operator()(const Variant& x, Id id) {
+ const auto& name = x.name;
+ auto& info = GetInfo(Tag::VARIANT, name);
+ info.definitions.push_back(id);
+ ++definitions;
+ (*this)(x.members);
+ }
+
void operator()(const Function& x, Id) {
(*this)(x.return_type_id);
(*this)(x.parameters);
diff --git a/unification.cc b/unification.cc
index 77891ea..0d77012 100644
--- a/unification.cc
+++ b/unification.cc
@@ -205,6 +205,14 @@ struct Unifier {
return result ? definition2.has_value() ? Right : Left : Neither;
}
+ Winner operator()(const Variant& x1, const Variant& x2) {
+ return x1.name == x2.name
+ && x1.bytesize == x2.bytesize
+ && (*this)(x1.discriminant_type_id, x2.discriminant_type_id)
+ && (*this)(x1.members, x2.members)
+ ? Right : Neither;
+ }
+
Winner operator()(const Function& x1, const Function& x2) {
return (*this)(x1.parameters, x2.parameters)
&& (*this)(x1.return_type_id, x2.return_type_id)