summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMartin Geisler <mgeisler@google.com>2024-04-09 13:35:45 +0200
committerMartin Geisler <mgeisler@google.com>2024-04-09 20:40:35 +0200
commit51f31ccd15c1b26d5f9a23adb7ebd680c73f6e52 (patch)
tree23b070c61b3b54e6fa7298a5ee7e2cce8d154a58
parent64c67b1b32441f8096c819e80b8d2a77b1fb7b28 (diff)
downloadmls-rs-upstream.tar.gz
Import 'mls-rs' crateupstream
Request Document: go/android-rust-importing-crates For CL Reviewers: go/android3p#cl-review For Build Team: go/ab-third-party-imports Bug: http://b/330708876 Test: m libmls_rs Change-Id: Ib0a891a4d7bf582ebea9ba7a1447ea959e42e0d3
-rw-r--r--.cargo/config.toml2
-rw-r--r--.cargo_vcs_info.json6
-rw-r--r--Cargo.lock1782
-rw-r--r--Cargo.toml389
l---------LICENSE1
-rw-r--r--LICENSE-apache176
-rw-r--r--LICENSE-mit9
-rw-r--r--METADATA21
-rw-r--r--MODULE_LICENSE_APACHE20
-rw-r--r--OWNERS2
-rw-r--r--README.md71
-rw-r--r--benches/group_add.rs87
-rw-r--r--benches/group_application.rs48
-rw-r--r--benches/group_commit.rs29
-rw-r--r--benches/group_receive_commit.rs39
-rw-r--r--benches/group_serialize.rs30
-rw-r--r--cargo_embargo.json16
-rw-r--r--examples/basic_server_usage.rs193
-rw-r--r--examples/basic_usage.rs79
-rw-r--r--examples/custom.rs448
-rw-r--r--examples/large_group.rs183
-rw-r--r--examples/x509.rs38
-rw-r--r--src/client.rs1049
-rw-r--r--src/client_builder.rs1029
-rw-r--r--src/client_config.rs68
-rw-r--r--src/crypto.rs43
-rw-r--r--src/extension.rs52
-rw-r--r--src/extension/built_in.rs330
-rw-r--r--src/external_client.rs142
-rw-r--r--src/external_client/builder.rs602
-rw-r--r--src/external_client/config.rs54
-rw-r--r--src/external_client/group.rs1354
-rw-r--r--src/grease.rs227
-rw-r--r--src/group/ciphertext_processor.rs410
-rw-r--r--src/group/ciphertext_processor/message_key.rs57
-rw-r--r--src/group/ciphertext_processor/reuse_guard.rs133
-rw-r--r--src/group/ciphertext_processor/sender_data_key.rs360
-rw-r--r--src/group/commit.rs1601
-rw-r--r--src/group/confirmation_tag.rs150
-rw-r--r--src/group/context.rs98
-rw-r--r--src/group/epoch.rs165
-rw-r--r--src/group/exported_tree.rs51
-rw-r--r--src/group/external_commit.rs266
-rw-r--r--src/group/framing.rs741
-rw-r--r--src/group/group_info.rs95
-rw-r--r--src/group/interop_test_vectors.rs9
-rw-r--r--src/group/interop_test_vectors/framing.rs461
-rw-r--r--src/group/interop_test_vectors/passive_client.rs732
-rw-r--r--src/group/interop_test_vectors/serialization.rs169
-rw-r--r--src/group/interop_test_vectors/tree_kem.rs185
-rw-r--r--src/group/interop_test_vectors/tree_modifications.rs177
-rw-r--r--src/group/key_schedule.rs988
-rw-r--r--src/group/membership_tag.rs163
-rw-r--r--src/group/message_processor.rs1039
-rw-r--r--src/group/message_signature.rs274
-rw-r--r--src/group/message_verifier.rs680
-rw-r--r--src/group/mls_rules.rs283
-rw-r--r--src/group/mod.rs4236
-rw-r--r--src/group/padding.rs109
-rw-r--r--src/group/proposal.rs578
-rw-r--r--src/group/proposal_cache.rs4216
-rw-r--r--src/group/proposal_filter.rs23
-rw-r--r--src/group/proposal_filter/bundle.rs633
-rw-r--r--src/group/proposal_filter/filtering.rs580
-rw-r--r--src/group/proposal_filter/filtering_common.rs579
-rw-r--r--src/group/proposal_filter/filtering_lite.rs225
-rw-r--r--src/group/proposal_ref.rs226
-rw-r--r--src/group/resumption.rs299
-rw-r--r--src/group/roster.rs91
-rw-r--r--src/group/secret_tree.rs1115
-rw-r--r--src/group/snapshot.rs325
-rw-r--r--src/group/state.rs43
-rw-r--r--src/group/state_repo.rs573
-rw-r--r--src/group/state_repo_light.rs132
-rw-r--r--src/group/test_utils.rs521
-rw-r--r--src/group/transcript_hash.rs293
-rw-r--r--src/group/util.rs202
-rw-r--r--src/hash_reference.rs166
-rw-r--r--src/identity.rs182
-rw-r--r--src/identity/basic.rs99
-rw-r--r--src/iter.rs96
-rw-r--r--src/key_package/generator.rs339
-rw-r--r--src/key_package/mod.rs332
-rw-r--r--src/key_package/validator.rs39
-rw-r--r--src/lib.rs218
-rw-r--r--src/message.rs0
-rw-r--r--src/psk.rs200
-rw-r--r--src/psk/resolver.rs95
-rw-r--r--src/psk/secret.rs239
-rw-r--r--src/signer.rs357
-rw-r--r--src/storage_provider.rs14
-rw-r--r--src/storage_provider/group_state.rs43
-rw-r--r--src/storage_provider/in_memory.rs11
-rw-r--r--src/storage_provider/in_memory/group_state_storage.rs354
-rw-r--r--src/storage_provider/in_memory/key_package_storage.rs120
-rw-r--r--src/storage_provider/in_memory/psk_storage.rs83
-rw-r--r--src/storage_provider/key_package.rs5
-rw-r--r--src/storage_provider/sqlite.rs5
-rw-r--r--src/test_utils/benchmarks.rs140
-rw-r--r--src/test_utils/fuzz_tests.rs109
-rw-r--r--src/test_utils/mod.rs184
-rw-r--r--src/tree_kem/capabilities.rs5
-rw-r--r--src/tree_kem/hpke_encryption.rs172
-rw-r--r--src/tree_kem/interop_test_vectors.rs199
-rw-r--r--src/tree_kem/kem.rs699
-rw-r--r--src/tree_kem/leaf_node.rs688
-rw-r--r--src/tree_kem/leaf_node_validator.rs708
-rw-r--r--src/tree_kem/lifetime.rs119
-rw-r--r--src/tree_kem/math.rs383
-rw-r--r--src/tree_kem/mod.rs1490
-rw-r--r--src/tree_kem/node.rs577
-rw-r--r--src/tree_kem/parent_hash.rs431
-rw-r--r--src/tree_kem/path_secret.rs265
-rw-r--r--src/tree_kem/private.rs310
-rw-r--r--src/tree_kem/tree_hash.rs432
-rw-r--r--src/tree_kem/tree_index.rs505
-rw-r--r--src/tree_kem/tree_utils.rs191
-rw-r--r--src/tree_kem/tree_validator.rs356
-rw-r--r--src/tree_kem/update_path.rs274
-rw-r--r--test_utils/src/scenario_utils.rs338
-rw-r--r--tests/client_tests.rs847
-rw-r--r--webdriver.json9
122 files changed, 45313 insertions, 0 deletions
diff --git a/.cargo/config.toml b/.cargo/config.toml
new file mode 100644
index 0000000..3b8e38a
--- /dev/null
+++ b/.cargo/config.toml
@@ -0,0 +1,2 @@
+[target.'cfg(all(target_arch = "wasm32", target_os = "unknown"))']
+rustflags = ["--cfg", "mls_build_async"]
diff --git a/.cargo_vcs_info.json b/.cargo_vcs_info.json
new file mode 100644
index 0000000..e1e5d8f
--- /dev/null
+++ b/.cargo_vcs_info.json
@@ -0,0 +1,6 @@
+{
+ "git": {
+ "sha1": "f56ef7cac2ba0c0c9bcaa7e8bb1e8413b24352bf"
+ },
+ "path_in_vcs": "mls-rs"
+} \ No newline at end of file
diff --git a/Cargo.lock b/Cargo.lock
new file mode 100644
index 0000000..817b99a
--- /dev/null
+++ b/Cargo.lock
@@ -0,0 +1,1782 @@
+# This file is automatically @generated by Cargo.
+# It is not intended for manual editing.
+version = 3
+
+[[package]]
+name = "ahash"
+version = "0.8.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "91429305e9f0a25f6205c5b8e0d2db09e0708a7a6df0f42212bb56c32c8ac97a"
+dependencies = [
+ "cfg-if",
+ "once_cell",
+ "version_check",
+ "zerocopy",
+]
+
+[[package]]
+name = "aho-corasick"
+version = "1.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0"
+dependencies = [
+ "memchr",
+]
+
+[[package]]
+name = "allocator-api2"
+version = "0.2.16"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5"
+
+[[package]]
+name = "anes"
+version = "0.1.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299"
+
+[[package]]
+name = "anstyle"
+version = "1.0.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7079075b41f533b8c61d2a4d073c4676e1f8b249ff94a393b0595db304e0dd87"
+
+[[package]]
+name = "arbitrary"
+version = "1.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7d5a26814d8dcb93b0e5a0ff3c6d80a8843bafb21b39e8e18a6f05471870e110"
+dependencies = [
+ "derive_arbitrary",
+]
+
+[[package]]
+name = "assert_matches"
+version = "1.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9b34d609dfbaf33d6889b2b7106d3ca345eacad44200913df5ba02bfd31d2ba9"
+
+[[package]]
+name = "async-trait"
+version = "0.1.77"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c980ee35e870bd1a4d2c8294d4c04d0499e67bca1e4b5cefcc693c2fa00caea9"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.50",
+]
+
+[[package]]
+name = "atomic-polyfill"
+version = "1.0.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8cf2bce30dfe09ef0bfaef228b9d414faaf7e563035494d7fe092dba54b300f4"
+dependencies = [
+ "critical-section",
+]
+
+[[package]]
+name = "autocfg"
+version = "1.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
+
+[[package]]
+name = "bitflags"
+version = "2.4.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07"
+
+[[package]]
+name = "bumpalo"
+version = "3.14.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec"
+
+[[package]]
+name = "cast"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
+
+[[package]]
+name = "cc"
+version = "1.0.83"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0"
+dependencies = [
+ "libc",
+]
+
+[[package]]
+name = "cfg-if"
+version = "1.0.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
+
+[[package]]
+name = "ciborium"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "effd91f6c78e5a4ace8a5d3c0b6bfaec9e2baaef55f3efc00e45fb2e477ee926"
+dependencies = [
+ "ciborium-io",
+ "ciborium-ll",
+ "serde",
+]
+
+[[package]]
+name = "ciborium-io"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cdf919175532b369853f5d5e20b26b43112613fd6fe7aee757e35f7a44642656"
+
+[[package]]
+name = "ciborium-ll"
+version = "0.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "defaa24ecc093c77630e6c15e17c51f5e187bf35ee514f4e2d67baaa96dae22b"
+dependencies = [
+ "ciborium-io",
+ "half",
+]
+
+[[package]]
+name = "clap"
+version = "4.4.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bfaff671f6b22ca62406885ece523383b9b64022e341e53e009a62ebc47a45f2"
+dependencies = [
+ "clap_builder",
+]
+
+[[package]]
+name = "clap_builder"
+version = "4.4.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a216b506622bb1d316cd51328dce24e07bdff4a6128a47c7e7fad11878d5adbb"
+dependencies = [
+ "anstyle",
+ "clap_lex",
+]
+
+[[package]]
+name = "clap_lex"
+version = "0.6.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "702fc72eb24e5a1e48ce58027a675bc24edd52096d5397d4aea7c6dd9eca0bd1"
+
+[[package]]
+name = "console_error_panic_hook"
+version = "0.1.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a06aeb73f470f66dcdbf7223caeebb85984942f22f1adb2a088cf9668146bbbc"
+dependencies = [
+ "cfg-if",
+ "wasm-bindgen",
+]
+
+[[package]]
+name = "const-oid"
+version = "0.9.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "28c122c3980598d243d63d9a704629a2d748d101f278052ff068be5a4423ab6f"
+
+[[package]]
+name = "criterion"
+version = "0.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f"
+dependencies = [
+ "anes",
+ "cast",
+ "ciborium",
+ "clap",
+ "criterion-plot",
+ "futures",
+ "is-terminal",
+ "itertools 0.10.5",
+ "num-traits",
+ "once_cell",
+ "oorandom",
+ "plotters",
+ "rayon",
+ "regex",
+ "serde",
+ "serde_derive",
+ "serde_json",
+ "tinytemplate",
+ "walkdir",
+]
+
+[[package]]
+name = "criterion-plot"
+version = "0.5.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1"
+dependencies = [
+ "cast",
+ "itertools 0.10.5",
+]
+
+[[package]]
+name = "critical-section"
+version = "1.1.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7059fff8937831a9ae6f0fe4d658ffabf58f2ca96aa9dec1c889f936f705f216"
+
+[[package]]
+name = "crossbeam-deque"
+version = "0.8.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef"
+dependencies = [
+ "cfg-if",
+ "crossbeam-epoch",
+ "crossbeam-utils",
+]
+
+[[package]]
+name = "crossbeam-epoch"
+version = "0.9.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7"
+dependencies = [
+ "autocfg",
+ "cfg-if",
+ "crossbeam-utils",
+ "memoffset",
+ "scopeguard",
+]
+
+[[package]]
+name = "crossbeam-utils"
+version = "0.8.16"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294"
+dependencies = [
+ "cfg-if",
+]
+
+[[package]]
+name = "darling"
+version = "0.20.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0209d94da627ab5605dcccf08bb18afa5009cfbef48d8a8b7d7bdbc79be25c5e"
+dependencies = [
+ "darling_core",
+ "darling_macro",
+]
+
+[[package]]
+name = "darling_core"
+version = "0.20.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "177e3443818124b357d8e76f53be906d60937f0d3a90773a664fa63fa253e621"
+dependencies = [
+ "fnv",
+ "ident_case",
+ "proc-macro2",
+ "quote",
+ "strsim",
+ "syn 2.0.50",
+]
+
+[[package]]
+name = "darling_macro"
+version = "0.20.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "836a9bbc7ad63342d6d6e7b815ccab164bc77a2d95d84bc3117a8c0d5c98e2d5"
+dependencies = [
+ "darling_core",
+ "quote",
+ "syn 2.0.50",
+]
+
+[[package]]
+name = "debug_tree"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2d1ec383f2d844902d3c34e4253ba11ae48513cdaddc565cf1a6518db09a8e57"
+dependencies = [
+ "once_cell",
+]
+
+[[package]]
+name = "der"
+version = "0.7.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fffa369a668c8af7dbf8b5e56c9f744fbd399949ed171606040001947de40b1c"
+dependencies = [
+ "const-oid",
+ "der_derive",
+ "zeroize",
+]
+
+[[package]]
+name = "der_derive"
+version = "0.7.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5fe87ce4529967e0ba1dcf8450bab64d97dfd5010a6256187ffe2e43e6f0e049"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.50",
+]
+
+[[package]]
+name = "derive_arbitrary"
+version = "1.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "67e77553c4162a157adbf834ebae5b415acbecbeafc7a74b0e886657506a7611"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.50",
+]
+
+[[package]]
+name = "either"
+version = "1.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07"
+
+[[package]]
+name = "errno"
+version = "0.3.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245"
+dependencies = [
+ "libc",
+ "windows-sys 0.52.0",
+]
+
+[[package]]
+name = "ext-trait"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d772df1c1a777963712fb68e014235e80863d6a91a85c4e06ba2d16243a310e5"
+dependencies = [
+ "ext-trait-proc_macros",
+]
+
+[[package]]
+name = "ext-trait-proc_macros"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1ab7934152eaf26aa5aa9f7371408ad5af4c31357073c9e84c3b9d7f11ad639a"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 1.0.109",
+]
+
+[[package]]
+name = "extension-traits"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a296e5a895621edf9fa8329c83aa1cb69a964643e36cf54d8d7a69b789089537"
+dependencies = [
+ "ext-trait",
+]
+
+[[package]]
+name = "fallible-iterator"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649"
+
+[[package]]
+name = "fallible-streaming-iterator"
+version = "0.1.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a"
+
+[[package]]
+name = "fnv"
+version = "1.0.7"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3f9eec918d3f24069decb9af1554cad7c880e2da24a9afd88aca000531ab82c1"
+
+[[package]]
+name = "foreign-types"
+version = "0.3.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1"
+dependencies = [
+ "foreign-types-shared",
+]
+
+[[package]]
+name = "foreign-types-shared"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b"
+
+[[package]]
+name = "futures"
+version = "0.3.29"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "da0290714b38af9b4a7b094b8a37086d1b4e61f2df9122c3cad2577669145335"
+dependencies = [
+ "futures-channel",
+ "futures-core",
+ "futures-executor",
+ "futures-io",
+ "futures-sink",
+ "futures-task",
+ "futures-util",
+]
+
+[[package]]
+name = "futures-channel"
+version = "0.3.29"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ff4dd66668b557604244583e3e1e1eada8c5c2e96a6d0d6653ede395b78bbacb"
+dependencies = [
+ "futures-core",
+ "futures-sink",
+]
+
+[[package]]
+name = "futures-core"
+version = "0.3.29"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "eb1d22c66e66d9d72e1758f0bd7d4fd0bee04cad842ee34587d68c07e45d088c"
+
+[[package]]
+name = "futures-executor"
+version = "0.3.29"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0f4fb8693db0cf099eadcca0efe2a5a22e4550f98ed16aba6c48700da29597bc"
+dependencies = [
+ "futures-core",
+ "futures-task",
+ "futures-util",
+]
+
+[[package]]
+name = "futures-io"
+version = "0.3.29"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8bf34a163b5c4c52d0478a4d757da8fb65cabef42ba90515efee0f6f9fa45aaa"
+
+[[package]]
+name = "futures-macro"
+version = "0.3.29"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "53b153fd91e4b0147f4aced87be237c98248656bb01050b96bf3ee89220a8ddb"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.50",
+]
+
+[[package]]
+name = "futures-sink"
+version = "0.3.29"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e36d3378ee38c2a36ad710c5d30c2911d752cb941c00c72dbabfb786a7970817"
+
+[[package]]
+name = "futures-task"
+version = "0.3.29"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "efd193069b0ddadc69c46389b740bbccdd97203899b48d09c5f7969591d6bae2"
+
+[[package]]
+name = "futures-test"
+version = "0.3.29"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "73ad78d6c79a3c76f8bc7496240d0586e069ed6797824fdd8c41d7c42b145b8d"
+dependencies = [
+ "futures-core",
+ "futures-executor",
+ "futures-io",
+ "futures-macro",
+ "futures-sink",
+ "futures-task",
+ "futures-util",
+ "pin-project",
+ "pin-utils",
+]
+
+[[package]]
+name = "futures-util"
+version = "0.3.29"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a19526d624e703a3179b3d322efec918b6246ea0fa51d41124525f00f1cc8104"
+dependencies = [
+ "futures-channel",
+ "futures-core",
+ "futures-io",
+ "futures-sink",
+ "futures-task",
+ "memchr",
+ "pin-project-lite",
+ "pin-utils",
+ "slab",
+]
+
+[[package]]
+name = "getrandom"
+version = "0.2.11"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f"
+dependencies = [
+ "cfg-if",
+ "js-sys",
+ "libc",
+ "wasi",
+ "wasm-bindgen",
+]
+
+[[package]]
+name = "half"
+version = "1.8.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7"
+
+[[package]]
+name = "hashbrown"
+version = "0.14.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604"
+dependencies = [
+ "ahash",
+ "allocator-api2",
+]
+
+[[package]]
+name = "hashlink"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "692eaaf7f7607518dd3cef090f1474b61edc5301d8012f09579920df68b725ee"
+dependencies = [
+ "hashbrown",
+]
+
+[[package]]
+name = "heck"
+version = "0.4.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
+
+[[package]]
+name = "hermit-abi"
+version = "0.3.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7"
+
+[[package]]
+name = "hex"
+version = "0.4.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
+dependencies = [
+ "serde",
+]
+
+[[package]]
+name = "ident_case"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b9e0384b61958566e926dc50660321d12159025e767c18e043daf26b70104c39"
+
+[[package]]
+name = "is-terminal"
+version = "0.4.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b"
+dependencies = [
+ "hermit-abi",
+ "rustix",
+ "windows-sys 0.48.0",
+]
+
+[[package]]
+name = "itertools"
+version = "0.10.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473"
+dependencies = [
+ "either",
+]
+
+[[package]]
+name = "itertools"
+version = "0.12.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "25db6b064527c5d482d0423354fcd07a89a2dfe07b67892e62411946db7f07b0"
+dependencies = [
+ "either",
+]
+
+[[package]]
+name = "itoa"
+version = "1.0.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38"
+
+[[package]]
+name = "js-sys"
+version = "0.3.66"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cee9c64da59eae3b50095c18d3e74f8b73c0b86d2792824ff01bbce68ba229ca"
+dependencies = [
+ "wasm-bindgen",
+]
+
+[[package]]
+name = "libc"
+version = "0.2.153"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd"
+
+[[package]]
+name = "libsqlite3-sys"
+version = "0.28.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0c10584274047cb335c23d3e61bcef8e323adae7c5c8c760540f73610177fc3f"
+dependencies = [
+ "cc",
+ "pkg-config",
+ "vcpkg",
+]
+
+[[package]]
+name = "linux-raw-sys"
+version = "0.4.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456"
+
+[[package]]
+name = "log"
+version = "0.4.20"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f"
+
+[[package]]
+name = "macro_rules_attribute"
+version = "0.1.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cf0c9b980bf4f3a37fd7b1c066941dd1b1d0152ce6ee6e8fe8c49b9f6810d862"
+dependencies = [
+ "macro_rules_attribute-proc_macro",
+ "paste",
+]
+
+[[package]]
+name = "macro_rules_attribute-proc_macro"
+version = "0.1.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "58093314a45e00c77d5c508f76e77c3396afbbc0d01506e7fae47b018bac2b1d"
+
+[[package]]
+name = "maybe-async"
+version = "0.2.10"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5cf92c10c7e361d6b99666ec1c6f9805b0bea2c3bd8c78dc6fe98ac5bd78db11"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.50",
+]
+
+[[package]]
+name = "memchr"
+version = "2.6.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167"
+
+[[package]]
+name = "memoffset"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c"
+dependencies = [
+ "autocfg",
+]
+
+[[package]]
+name = "mls-rs"
+version = "0.39.1"
+dependencies = [
+ "arbitrary",
+ "assert_matches",
+ "async-trait",
+ "cfg-if",
+ "criterion",
+ "debug_tree",
+ "futures",
+ "futures-test",
+ "getrandom",
+ "hex",
+ "itertools 0.12.0",
+ "maybe-async",
+ "mls-rs-codec",
+ "mls-rs-core",
+ "mls-rs-crypto-openssl",
+ "mls-rs-crypto-webcrypto",
+ "mls-rs-identity-x509",
+ "mls-rs-provider-sqlite",
+ "once_cell",
+ "portable-atomic",
+ "portable-atomic-util",
+ "rand",
+ "rand_core",
+ "rayon",
+ "safer-ffi",
+ "safer-ffi-gen",
+ "serde",
+ "serde_json",
+ "spin",
+ "thiserror",
+ "wasm-bindgen",
+ "wasm-bindgen-test",
+ "zeroize",
+]
+
+[[package]]
+name = "mls-rs-codec"
+version = "0.5.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2a76178499583f153a23c86ba45346baac05def4f7ecac23bf770f3c270dfee5"
+dependencies = [
+ "mls-rs-codec-derive",
+ "thiserror",
+ "wasm-bindgen",
+]
+
+[[package]]
+name = "mls-rs-codec-derive"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "264e7261ff5d3d0d0bbc3af01df1d294727440861bf9ce3d30e93c90191ab9f3"
+dependencies = [
+ "darling",
+ "proc-macro2",
+ "quote",
+ "syn 2.0.50",
+]
+
+[[package]]
+name = "mls-rs-core"
+version = "0.18.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "55cdac7a57ab7953e63ee1c6a05d273ca435a32c149968d11d8aba300cb4143f"
+dependencies = [
+ "arbitrary",
+ "async-trait",
+ "hex",
+ "maybe-async",
+ "mls-rs-codec",
+ "safer-ffi",
+ "safer-ffi-gen",
+ "serde",
+ "serde_bytes",
+ "thiserror",
+ "wasm-bindgen",
+ "zeroize",
+]
+
+[[package]]
+name = "mls-rs-crypto-hpke"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e3ee07057dfb114bc254d05dc163b72e2f54538a1c18434396e9af9966833881"
+dependencies = [
+ "async-trait",
+ "cfg-if",
+ "maybe-async",
+ "mls-rs-core",
+ "mls-rs-crypto-traits",
+ "thiserror",
+ "zeroize",
+]
+
+[[package]]
+name = "mls-rs-crypto-openssl"
+version = "0.9.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8976f9651ef302e2f369b7d21febf07cd610c391ea4d910889aa3f9a3e58a058"
+dependencies = [
+ "async-trait",
+ "maybe-async",
+ "mls-rs-core",
+ "mls-rs-crypto-hpke",
+ "mls-rs-crypto-traits",
+ "mls-rs-identity-x509",
+ "openssl",
+ "thiserror",
+ "zeroize",
+]
+
+[[package]]
+name = "mls-rs-crypto-traits"
+version = "0.10.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "af670b03ef8413b9bf7af33cd921ac87ef4e9daa26e4eaf0200b2e8954471d43"
+dependencies = [
+ "async-trait",
+ "maybe-async",
+ "mls-rs-core",
+]
+
+[[package]]
+name = "mls-rs-crypto-webcrypto"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5956ca840ac872b6d26171f3045a700ec51358e87064e945fb577f1bab574943"
+dependencies = [
+ "async-trait",
+ "const-oid",
+ "der",
+ "js-sys",
+ "maybe-async",
+ "mls-rs-core",
+ "mls-rs-crypto-hpke",
+ "mls-rs-crypto-traits",
+ "serde",
+ "serde-wasm-bindgen",
+ "thiserror",
+ "wasm-bindgen",
+ "wasm-bindgen-futures",
+ "web-sys",
+ "zeroize",
+]
+
+[[package]]
+name = "mls-rs-identity-x509"
+version = "0.11.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8da83f76a159e3828350d3a5279e911cacf18ec83998008c922781904da1a6f8"
+dependencies = [
+ "async-trait",
+ "maybe-async",
+ "mls-rs-core",
+ "thiserror",
+ "wasm-bindgen",
+]
+
+[[package]]
+name = "mls-rs-provider-sqlite"
+version = "0.11.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c9dafe81888e7ad5b289dd37226bdc64d9878524958a14b839af055ba5e6391e"
+dependencies = [
+ "async-trait",
+ "hex",
+ "maybe-async",
+ "mls-rs-core",
+ "rand",
+ "rusqlite",
+ "thiserror",
+ "zeroize",
+]
+
+[[package]]
+name = "num-traits"
+version = "0.2.17"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c"
+dependencies = [
+ "autocfg",
+]
+
+[[package]]
+name = "once_cell"
+version = "1.18.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d"
+dependencies = [
+ "atomic-polyfill",
+ "critical-section",
+]
+
+[[package]]
+name = "oorandom"
+version = "11.1.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575"
+
+[[package]]
+name = "openssl"
+version = "0.10.60"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "79a4c6c3a2b158f7f8f2a2fc5a969fa3a068df6fc9dbb4a43845436e3af7c800"
+dependencies = [
+ "bitflags",
+ "cfg-if",
+ "foreign-types",
+ "libc",
+ "once_cell",
+ "openssl-macros",
+ "openssl-sys",
+]
+
+[[package]]
+name = "openssl-macros"
+version = "0.1.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.50",
+]
+
+[[package]]
+name = "openssl-sys"
+version = "0.9.96"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3812c071ba60da8b5677cc12bcb1d42989a65553772897a7e0355545a819838f"
+dependencies = [
+ "cc",
+ "libc",
+ "pkg-config",
+ "vcpkg",
+]
+
+[[package]]
+name = "paste"
+version = "1.0.14"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c"
+
+[[package]]
+name = "pin-project"
+version = "1.1.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "fda4ed1c6c173e3fc7a83629421152e01d7b1f9b7f65fb301e490e8cfc656422"
+dependencies = [
+ "pin-project-internal",
+]
+
+[[package]]
+name = "pin-project-internal"
+version = "1.1.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4359fd9c9171ec6e8c62926d6faaf553a8dc3f64e1507e76da7911b4f6a04405"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.50",
+]
+
+[[package]]
+name = "pin-project-lite"
+version = "0.2.13"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58"
+
+[[package]]
+name = "pin-utils"
+version = "0.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
+
+[[package]]
+name = "pkg-config"
+version = "0.3.27"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964"
+
+[[package]]
+name = "plotters"
+version = "0.3.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d2c224ba00d7cadd4d5c660deaf2098e5e80e07846537c51f9cfa4be50c1fd45"
+dependencies = [
+ "num-traits",
+ "plotters-backend",
+ "plotters-svg",
+ "wasm-bindgen",
+ "web-sys",
+]
+
+[[package]]
+name = "plotters-backend"
+version = "0.3.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9e76628b4d3a7581389a35d5b6e2139607ad7c75b17aed325f210aa91f4a9609"
+
+[[package]]
+name = "plotters-svg"
+version = "0.3.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "38f6d39893cca0701371e3c27294f09797214b86f1fb951b89ade8ec04e2abab"
+dependencies = [
+ "plotters-backend",
+]
+
+[[package]]
+name = "portable-atomic"
+version = "1.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3bccab0e7fd7cc19f820a1c8c91720af652d0c88dc9664dd72aef2614f04af3b"
+dependencies = [
+ "critical-section",
+]
+
+[[package]]
+name = "portable-atomic-util"
+version = "0.1.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5b04f7d7db7264d097636257272594f781611c464553fc76ebc899847d50c4cc"
+dependencies = [
+ "portable-atomic",
+]
+
+[[package]]
+name = "ppv-lite86"
+version = "0.2.17"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
+
+[[package]]
+name = "prettyplease"
+version = "0.1.25"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6c8646e95016a7a6c4adea95bafa8a16baab64b583356217f2c85db4a39d9a86"
+dependencies = [
+ "proc-macro2",
+ "syn 1.0.109",
+]
+
+[[package]]
+name = "proc-macro-error"
+version = "1.0.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c"
+dependencies = [
+ "proc-macro-error-attr",
+ "proc-macro2",
+ "quote",
+ "syn 1.0.109",
+ "version_check",
+]
+
+[[package]]
+name = "proc-macro-error-attr"
+version = "1.0.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "version_check",
+]
+
+[[package]]
+name = "proc-macro2"
+version = "1.0.78"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e2422ad645d89c99f8f3e6b88a9fdeca7fabeac836b1002371c4367c8f984aae"
+dependencies = [
+ "unicode-ident",
+]
+
+[[package]]
+name = "quote"
+version = "1.0.35"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef"
+dependencies = [
+ "proc-macro2",
+]
+
+[[package]]
+name = "rand"
+version = "0.8.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
+dependencies = [
+ "libc",
+ "rand_chacha",
+ "rand_core",
+]
+
+[[package]]
+name = "rand_chacha"
+version = "0.3.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
+dependencies = [
+ "ppv-lite86",
+ "rand_core",
+]
+
+[[package]]
+name = "rand_core"
+version = "0.6.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
+dependencies = [
+ "getrandom",
+]
+
+[[package]]
+name = "rayon"
+version = "1.8.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1"
+dependencies = [
+ "either",
+ "rayon-core",
+]
+
+[[package]]
+name = "rayon-core"
+version = "1.12.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed"
+dependencies = [
+ "crossbeam-deque",
+ "crossbeam-utils",
+]
+
+[[package]]
+name = "regex"
+version = "1.10.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "380b951a9c5e80ddfd6136919eef32310721aa4aacd4889a8d39124b026ab343"
+dependencies = [
+ "aho-corasick",
+ "memchr",
+ "regex-automata",
+ "regex-syntax",
+]
+
+[[package]]
+name = "regex-automata"
+version = "0.4.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f"
+dependencies = [
+ "aho-corasick",
+ "memchr",
+ "regex-syntax",
+]
+
+[[package]]
+name = "regex-syntax"
+version = "0.8.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f"
+
+[[package]]
+name = "rusqlite"
+version = "0.31.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "b838eba278d213a8beaf485bd313fd580ca4505a00d5871caeb1457c55322cae"
+dependencies = [
+ "bitflags",
+ "fallible-iterator",
+ "fallible-streaming-iterator",
+ "hashlink",
+ "libsqlite3-sys",
+ "smallvec",
+]
+
+[[package]]
+name = "rustix"
+version = "0.38.31"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6ea3e1a662af26cd7a3ba09c0297a31af215563ecf42817c98df621387f4e949"
+dependencies = [
+ "bitflags",
+ "errno",
+ "libc",
+ "linux-raw-sys",
+ "windows-sys 0.52.0",
+]
+
+[[package]]
+name = "ryu"
+version = "1.0.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741"
+
+[[package]]
+name = "safer-ffi"
+version = "0.1.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "395ace5aff9629c7268ca8255aceb945525b2cb644015f3caec5131a6a537c11"
+dependencies = [
+ "libc",
+ "macro_rules_attribute",
+ "paste",
+ "safer_ffi-proc_macros",
+ "scopeguard",
+ "uninit",
+ "unwind_safe",
+ "with_builtin_macros",
+]
+
+[[package]]
+name = "safer-ffi-gen"
+version = "0.9.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bdc3e72a8e99de537461ab5e9331d32c08dd558493e844cb70753a9890fdbc48"
+dependencies = [
+ "once_cell",
+ "safer-ffi",
+ "safer-ffi-gen-macro",
+]
+
+[[package]]
+name = "safer-ffi-gen-macro"
+version = "0.9.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "836aa8cd7b269dcdd3d81cca1ddc136aa1d2b05f30b6a34c0ff075152b2e3771"
+dependencies = [
+ "heck",
+ "proc-macro-error",
+ "proc-macro2",
+ "quote",
+ "syn 2.0.50",
+ "thiserror",
+]
+
+[[package]]
+name = "safer_ffi-proc_macros"
+version = "0.1.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9255504d5467bae9e07d58b8de446ba6739b29bf72e1fa35b2387e30d29dcbfe"
+dependencies = [
+ "macro_rules_attribute",
+ "prettyplease",
+ "proc-macro2",
+ "quote",
+ "syn 1.0.109",
+]
+
+[[package]]
+name = "same-file"
+version = "1.0.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
+dependencies = [
+ "winapi-util",
+]
+
+[[package]]
+name = "scoped-tls"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e1cf6437eb19a8f4a6cc0f7dca544973b0b78843adbfeb3683d1a94a0024a294"
+
+[[package]]
+name = "scopeguard"
+version = "1.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
+
+[[package]]
+name = "serde"
+version = "1.0.193"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89"
+dependencies = [
+ "serde_derive",
+]
+
+[[package]]
+name = "serde-wasm-bindgen"
+version = "0.6.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "17ba92964781421b6cef36bf0d7da26d201e96d84e1b10e7ae6ed416e516906d"
+dependencies = [
+ "js-sys",
+ "serde",
+ "wasm-bindgen",
+]
+
+[[package]]
+name = "serde_bytes"
+version = "0.11.14"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8b8497c313fd43ab992087548117643f6fcd935cbf36f176ffda0aacf9591734"
+dependencies = [
+ "serde",
+]
+
+[[package]]
+name = "serde_derive"
+version = "1.0.193"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.50",
+]
+
+[[package]]
+name = "serde_json"
+version = "1.0.108"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3d1c7e3eac408d115102c4c24ad393e0821bb3a5df4d506a80f85f7a742a526b"
+dependencies = [
+ "itoa",
+ "ryu",
+ "serde",
+]
+
+[[package]]
+name = "slab"
+version = "0.4.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67"
+dependencies = [
+ "autocfg",
+]
+
+[[package]]
+name = "smallvec"
+version = "1.11.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970"
+
+[[package]]
+name = "spin"
+version = "0.9.8"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67"
+dependencies = [
+ "portable-atomic",
+]
+
+[[package]]
+name = "strsim"
+version = "0.10.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623"
+
+[[package]]
+name = "syn"
+version = "1.0.109"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "unicode-ident",
+]
+
+[[package]]
+name = "syn"
+version = "2.0.50"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "74f1bdc9872430ce9b75da68329d1c1746faf50ffac5f19e02b71e37ff881ffb"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "unicode-ident",
+]
+
+[[package]]
+name = "thiserror"
+version = "1.0.57"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1e45bcbe8ed29775f228095caf2cd67af7a4ccf756ebff23a306bf3e8b47b24b"
+dependencies = [
+ "thiserror-impl",
+]
+
+[[package]]
+name = "thiserror-impl"
+version = "1.0.57"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a953cb265bef375dae3de6663da4d3804eee9682ea80d8e2542529b73c531c81"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.50",
+]
+
+[[package]]
+name = "tinytemplate"
+version = "1.2.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc"
+dependencies = [
+ "serde",
+ "serde_json",
+]
+
+[[package]]
+name = "unicode-ident"
+version = "1.0.12"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
+
+[[package]]
+name = "uninit"
+version = "0.5.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3e130f2ed46ca5d8ec13c7ff95836827f92f5f5f37fd2b2bf16f33c408d98bb6"
+dependencies = [
+ "extension-traits",
+]
+
+[[package]]
+name = "unwind_safe"
+version = "0.1.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0976c77def3f1f75c4ef892a292c31c0bbe9e3d0702c63044d7c76db298171a3"
+
+[[package]]
+name = "vcpkg"
+version = "0.2.15"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426"
+
+[[package]]
+name = "version_check"
+version = "0.9.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
+
+[[package]]
+name = "walkdir"
+version = "2.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d71d857dc86794ca4c280d616f7da00d2dbfd8cd788846559a6813e6aa4b54ee"
+dependencies = [
+ "same-file",
+ "winapi-util",
+]
+
+[[package]]
+name = "wasi"
+version = "0.11.0+wasi-snapshot-preview1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
+
+[[package]]
+name = "wasm-bindgen"
+version = "0.2.89"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0ed0d4f68a3015cc185aff4db9506a015f4b96f95303897bfa23f846db54064e"
+dependencies = [
+ "cfg-if",
+ "wasm-bindgen-macro",
+]
+
+[[package]]
+name = "wasm-bindgen-backend"
+version = "0.2.89"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1b56f625e64f3a1084ded111c4d5f477df9f8c92df113852fa5a374dbda78826"
+dependencies = [
+ "bumpalo",
+ "log",
+ "once_cell",
+ "proc-macro2",
+ "quote",
+ "syn 2.0.50",
+ "wasm-bindgen-shared",
+]
+
+[[package]]
+name = "wasm-bindgen-futures"
+version = "0.4.39"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ac36a15a220124ac510204aec1c3e5db8a22ab06fd6706d881dc6149f8ed9a12"
+dependencies = [
+ "cfg-if",
+ "js-sys",
+ "wasm-bindgen",
+ "web-sys",
+]
+
+[[package]]
+name = "wasm-bindgen-macro"
+version = "0.2.89"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0162dbf37223cd2afce98f3d0785506dcb8d266223983e4b5b525859e6e182b2"
+dependencies = [
+ "quote",
+ "wasm-bindgen-macro-support",
+]
+
+[[package]]
+name = "wasm-bindgen-macro-support"
+version = "0.2.89"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.50",
+ "wasm-bindgen-backend",
+ "wasm-bindgen-shared",
+]
+
+[[package]]
+name = "wasm-bindgen-shared"
+version = "0.2.89"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7ab9b36309365056cd639da3134bf87fa8f3d86008abf99e612384a6eecd459f"
+
+[[package]]
+name = "wasm-bindgen-test"
+version = "0.3.39"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2cf9242c0d27999b831eae4767b2a146feb0b27d332d553e605864acd2afd403"
+dependencies = [
+ "console_error_panic_hook",
+ "js-sys",
+ "scoped-tls",
+ "wasm-bindgen",
+ "wasm-bindgen-futures",
+ "wasm-bindgen-test-macro",
+]
+
+[[package]]
+name = "wasm-bindgen-test-macro"
+version = "0.3.39"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "794645f5408c9a039fd09f4d113cdfb2e7eba5ff1956b07bcf701cf4b394fe89"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.50",
+]
+
+[[package]]
+name = "web-sys"
+version = "0.3.66"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "50c24a44ec86bb68fbecd1b3efed7e85ea5621b39b35ef2766b66cd984f8010f"
+dependencies = [
+ "js-sys",
+ "wasm-bindgen",
+]
+
+[[package]]
+name = "winapi"
+version = "0.3.9"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
+dependencies = [
+ "winapi-i686-pc-windows-gnu",
+ "winapi-x86_64-pc-windows-gnu",
+]
+
+[[package]]
+name = "winapi-i686-pc-windows-gnu"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
+
+[[package]]
+name = "winapi-util"
+version = "0.1.6"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596"
+dependencies = [
+ "winapi",
+]
+
+[[package]]
+name = "winapi-x86_64-pc-windows-gnu"
+version = "0.4.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
+
+[[package]]
+name = "windows-sys"
+version = "0.48.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9"
+dependencies = [
+ "windows-targets 0.48.5",
+]
+
+[[package]]
+name = "windows-sys"
+version = "0.52.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d"
+dependencies = [
+ "windows-targets 0.52.0",
+]
+
+[[package]]
+name = "windows-targets"
+version = "0.48.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c"
+dependencies = [
+ "windows_aarch64_gnullvm 0.48.5",
+ "windows_aarch64_msvc 0.48.5",
+ "windows_i686_gnu 0.48.5",
+ "windows_i686_msvc 0.48.5",
+ "windows_x86_64_gnu 0.48.5",
+ "windows_x86_64_gnullvm 0.48.5",
+ "windows_x86_64_msvc 0.48.5",
+]
+
+[[package]]
+name = "windows-targets"
+version = "0.52.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd"
+dependencies = [
+ "windows_aarch64_gnullvm 0.52.0",
+ "windows_aarch64_msvc 0.52.0",
+ "windows_i686_gnu 0.52.0",
+ "windows_i686_msvc 0.52.0",
+ "windows_x86_64_gnu 0.52.0",
+ "windows_x86_64_gnullvm 0.52.0",
+ "windows_x86_64_msvc 0.52.0",
+]
+
+[[package]]
+name = "windows_aarch64_gnullvm"
+version = "0.48.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8"
+
+[[package]]
+name = "windows_aarch64_gnullvm"
+version = "0.52.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea"
+
+[[package]]
+name = "windows_aarch64_msvc"
+version = "0.48.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc"
+
+[[package]]
+name = "windows_aarch64_msvc"
+version = "0.52.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef"
+
+[[package]]
+name = "windows_i686_gnu"
+version = "0.48.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e"
+
+[[package]]
+name = "windows_i686_gnu"
+version = "0.52.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313"
+
+[[package]]
+name = "windows_i686_msvc"
+version = "0.48.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406"
+
+[[package]]
+name = "windows_i686_msvc"
+version = "0.52.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a"
+
+[[package]]
+name = "windows_x86_64_gnu"
+version = "0.48.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e"
+
+[[package]]
+name = "windows_x86_64_gnu"
+version = "0.52.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd"
+
+[[package]]
+name = "windows_x86_64_gnullvm"
+version = "0.48.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc"
+
+[[package]]
+name = "windows_x86_64_gnullvm"
+version = "0.52.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e"
+
+[[package]]
+name = "windows_x86_64_msvc"
+version = "0.48.5"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538"
+
+[[package]]
+name = "windows_x86_64_msvc"
+version = "0.52.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04"
+
+[[package]]
+name = "with_builtin_macros"
+version = "0.0.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "a59d55032495429b87f9d69954c6c8602e4d3f3e0a747a12dea6b0b23de685da"
+dependencies = [
+ "with_builtin_macros-proc_macros",
+]
+
+[[package]]
+name = "with_builtin_macros-proc_macros"
+version = "0.0.3"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "15bd7679c15e22924f53aee34d4e448c45b674feb6129689af88593e129f8f42"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 1.0.109",
+]
+
+[[package]]
+name = "zerocopy"
+version = "0.7.28"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7d6f15f7ade05d2a4935e34a457b936c23dc70a05cc1d97133dc99e7a3fe0f0e"
+dependencies = [
+ "zerocopy-derive",
+]
+
+[[package]]
+name = "zerocopy-derive"
+version = "0.7.28"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "dbbad221e3f78500350ecbd7dfa4e63ef945c05f4c61cb7f4d3f84cd0bba649b"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.50",
+]
+
+[[package]]
+name = "zeroize"
+version = "1.7.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d"
+dependencies = [
+ "serde",
+ "zeroize_derive",
+]
+
+[[package]]
+name = "zeroize_derive"
+version = "1.4.2"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "ce36e65b0d2999d2aafac989fb249189a141aee1f53c612c1f37d72631959f69"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn 2.0.50",
+]
diff --git a/Cargo.toml b/Cargo.toml
new file mode 100644
index 0000000..134a0d6
--- /dev/null
+++ b/Cargo.toml
@@ -0,0 +1,389 @@
+# THIS FILE IS AUTOMATICALLY GENERATED BY CARGO
+#
+# When uploading crates to the registry Cargo will automatically
+# "normalize" Cargo.toml files for maximal compatibility
+# with all versions of Cargo and also rewrite `path` dependencies
+# to registry (e.g., crates.io) dependencies.
+#
+# If you are reading this file be aware that the original Cargo.toml
+# will likely look very different (and much more reasonable).
+# See Cargo.toml.orig for the original contents.
+
+[package]
+edition = "2021"
+rust-version = "1.68.2"
+name = "mls-rs"
+version = "0.39.1"
+exclude = ["test_data"]
+description = "An implementation of Messaging Layer Security (RFC 9420)"
+homepage = "https://github.com/awslabs/mls-rs"
+readme = "README.md"
+keywords = [
+ "crypto",
+ "cryptography",
+ "security",
+ "mls",
+ "e2ee",
+]
+categories = [
+ "no-std",
+ "cryptography",
+]
+license = "Apache-2.0 OR MIT"
+repository = "https://github.com/awslabs/mls-rs"
+
+[package.metadata.docs.rs]
+features = [
+ "external_client",
+ "sqlite",
+]
+rustdoc-args = [
+ "--cfg",
+ "docsrs",
+]
+
+[[example]]
+name = "basic_usage"
+required-features = []
+
+[[example]]
+name = "x509"
+required-features = ["x509"]
+
+[[example]]
+name = "large_group"
+required-features = []
+
+[[example]]
+name = "custom"
+required-features = ["std"]
+
+[[example]]
+name = "basic_server_usage"
+required-features = ["external_client"]
+
+[[test]]
+name = "client_tests"
+required-features = ["test_util"]
+
+[[bench]]
+name = "group_add"
+harness = false
+required-features = ["benchmark_util"]
+
+[[bench]]
+name = "group_commit"
+harness = false
+required-features = ["benchmark_util"]
+
+[[bench]]
+name = "group_receive_commit"
+harness = false
+required-features = ["benchmark_util"]
+
+[[bench]]
+name = "group_application"
+harness = false
+required-features = ["benchmark_util"]
+
+[[bench]]
+name = "group_serialize"
+harness = false
+required-features = ["benchmark_util"]
+
+[dependencies.arbitrary]
+version = "1"
+features = ["derive"]
+optional = true
+
+[dependencies.cfg-if]
+version = "1"
+
+[dependencies.debug_tree]
+version = "0.4.0"
+optional = true
+
+[dependencies.hex]
+version = "^0.4.3"
+features = [
+ "serde",
+ "alloc",
+]
+optional = true
+default-features = false
+
+[dependencies.itertools]
+version = "0.12.0"
+features = ["use_alloc"]
+default-features = false
+
+[dependencies.maybe-async]
+version = "0.2.10"
+
+[dependencies.mls-rs-codec]
+version = "0.5.2"
+default-features = false
+
+[dependencies.mls-rs-core]
+version = "0.18.0"
+default-features = false
+
+[dependencies.mls-rs-crypto-openssl]
+version = "0.9.0"
+optional = true
+
+[dependencies.mls-rs-identity-x509]
+version = "0.11.0"
+optional = true
+default-features = false
+
+[dependencies.mls-rs-provider-sqlite]
+version = "0.11.0"
+optional = true
+default-features = false
+
+[dependencies.once_cell]
+version = "1.18"
+optional = true
+
+[dependencies.rayon]
+version = "1"
+optional = true
+
+[dependencies.safer-ffi]
+version = "0.1.3"
+optional = true
+default-features = false
+
+[dependencies.safer-ffi-gen]
+version = "0.9.2"
+optional = true
+default-features = false
+
+[dependencies.serde]
+version = "1.0"
+features = [
+ "alloc",
+ "derive",
+]
+optional = true
+default-features = false
+
+[dependencies.spin]
+version = "0.9.8"
+features = [
+ "mutex",
+ "spin_mutex",
+]
+default-features = false
+
+[dependencies.thiserror]
+version = "1.0.40"
+optional = true
+
+[dependencies.zeroize]
+version = "1"
+features = [
+ "alloc",
+ "zeroize_derive",
+]
+default-features = false
+
+[dev-dependencies.assert_matches]
+version = "1.5.0"
+
+[dev-dependencies.criterion]
+version = "0.5.1"
+features = [
+ "async_futures",
+ "html_reports",
+]
+default-features = false
+
+[dev-dependencies.hex]
+version = "^0.4.3"
+features = [
+ "serde",
+ "alloc",
+]
+default-features = false
+
+[dev-dependencies.rand]
+version = "0.8"
+
+[dev-dependencies.serde]
+version = "1.0"
+features = [
+ "alloc",
+ "derive",
+]
+default-features = false
+
+[dev-dependencies.serde_json]
+version = "^1.0"
+
+[features]
+arbitrary = [
+ "std",
+ "dep:arbitrary",
+ "mls-rs-core/arbitrary",
+]
+benchmark_util = [
+ "test_util",
+ "default",
+ "dep:mls-rs-crypto-openssl",
+]
+by_ref_proposal = []
+custom_proposal = []
+default = [
+ "std",
+ "rayon",
+ "rfc_compliant",
+ "tree_index",
+ "fast_serialize",
+]
+external_client = ["std"]
+fast_serialize = ["mls-rs-core/fast_serialize"]
+ffi = [
+ "dep:safer-ffi",
+ "dep:safer-ffi-gen",
+ "mls-rs-core/ffi",
+]
+fuzz_util = [
+ "test_util",
+ "default",
+ "dep:once_cell",
+ "dep:mls-rs-crypto-openssl",
+]
+grease = ["std"]
+out_of_order = ["private_message"]
+prior_epoch = []
+private_message = []
+psk = []
+rayon = [
+ "std",
+ "dep:rayon",
+]
+rfc_compliant = [
+ "state_update",
+ "private_message",
+ "custom_proposal",
+ "out_of_order",
+ "psk",
+ "x509",
+ "prior_epoch",
+ "by_ref_proposal",
+ "mls-rs-core/rfc_compliant",
+]
+secret_tree_access = []
+serde = [
+ "mls-rs-core/serde",
+ "zeroize/serde",
+ "dep:serde",
+ "dep:hex",
+]
+sqlcipher = [
+ "sqlite",
+ "mls-rs-provider-sqlite/sqlcipher",
+]
+sqlcipher-bundled = [
+ "sqlite",
+ "mls-rs-provider-sqlite/sqlcipher-bundled",
+]
+sqlite = [
+ "std",
+ "mls-rs-provider-sqlite/sqlite",
+]
+sqlite-bundled = [
+ "sqlite",
+ "mls-rs-provider-sqlite/sqlite-bundled",
+]
+state_update = []
+std = [
+ "mls-rs-core/std",
+ "mls-rs-codec/std",
+ "mls-rs-identity-x509?/std",
+ "hex/std",
+ "futures/std",
+ "itertools/use_std",
+ "safer-ffi-gen?/std",
+ "zeroize/std",
+ "dep:debug_tree",
+ "dep:thiserror",
+ "serde?/std",
+]
+test_util = []
+tree_index = []
+x509 = [
+ "mls-rs-core/x509",
+ "dep:mls-rs-identity-x509",
+]
+
+[target."cfg(mls_build_async)".dependencies.async-trait]
+version = "0.1.74"
+
+[target."cfg(mls_build_async)".dependencies.futures]
+version = "0.3.25"
+features = ["alloc"]
+default-features = false
+
+[target."cfg(mls_build_async)".dev-dependencies.futures-test]
+version = "0.3.25"
+
+[target."cfg(not(target_arch = \"wasm32\"))".dev-dependencies.criterion]
+version = "0.5.1"
+features = [
+ "async_futures",
+ "html_reports",
+]
+
+[target."cfg(not(target_arch = \"wasm32\"))".dev-dependencies.mls-rs-crypto-openssl]
+version = "0.9.0"
+
+[target."cfg(not(target_has_atomic = \"ptr\"))".dependencies.portable-atomic]
+version = "1.5.1"
+features = ["critical-section"]
+default-features = false
+
+[target."cfg(not(target_has_atomic = \"ptr\"))".dependencies.portable-atomic-util]
+version = "0.1.2"
+features = ["alloc"]
+default-features = false
+
+[target."cfg(not(target_has_atomic = \"ptr\"))".dependencies.spin]
+version = "0.9.8"
+features = ["portable_atomic"]
+default-features = false
+
+[target."cfg(target_arch = \"wasm32\")".dependencies.getrandom]
+version = "0.2"
+features = [
+ "js",
+ "custom",
+]
+default-features = false
+
+[target."cfg(target_arch = \"wasm32\")".dependencies.rand_core]
+version = "0.6"
+features = ["alloc"]
+default-features = false
+
+[target."cfg(target_arch = \"wasm32\")".dependencies.wasm-bindgen]
+version = "^0.2.79"
+
+[target."cfg(target_arch = \"wasm32\")".dev-dependencies.criterion]
+version = "0.5.1"
+features = [
+ "plotters",
+ "cargo_bench_support",
+ "async_futures",
+ "html_reports",
+]
+default-features = false
+
+[target."cfg(target_arch = \"wasm32\")".dev-dependencies.mls-rs-crypto-webcrypto]
+version = "0.4.0"
+
+[target."cfg(target_arch = \"wasm32\")".dev-dependencies.wasm-bindgen-test]
+version = "0.3.26"
+default-features = false
diff --git a/LICENSE b/LICENSE
new file mode 120000
index 0000000..4ce7dad
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1 @@
+LICENSE-apache \ No newline at end of file
diff --git a/LICENSE-apache b/LICENSE-apache
new file mode 100644
index 0000000..831fbc5
--- /dev/null
+++ b/LICENSE-apache
@@ -0,0 +1,176 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, orother modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
diff --git a/LICENSE-mit b/LICENSE-mit
new file mode 100644
index 0000000..e547c4a
--- /dev/null
+++ b/LICENSE-mit
@@ -0,0 +1,9 @@
+MIT License
+
+Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
diff --git a/METADATA b/METADATA
new file mode 100644
index 0000000..9e35532
--- /dev/null
+++ b/METADATA
@@ -0,0 +1,21 @@
+name: "mls-rs"
+description: "An implementation of Messaging Layer Security (RFC 9420)"
+third_party {
+ identifier {
+ type: "crates.io"
+ value: "mls-rs"
+ }
+ identifier {
+ type: "Archive"
+ value: "https://static.crates.io/crates/mls-rs/mls-rs-0.39.1.crate"
+ primary_source: true
+ }
+ version: "0.39.1"
+ # Dual-licensed, using the least restrictive per go/thirdpartylicenses#same.
+ license_type: NOTICE
+ last_upgrade_date {
+ year: 2024
+ month: 4
+ day: 9
+ }
+}
diff --git a/MODULE_LICENSE_APACHE2 b/MODULE_LICENSE_APACHE2
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/MODULE_LICENSE_APACHE2
diff --git a/OWNERS b/OWNERS
new file mode 100644
index 0000000..48bea6e
--- /dev/null
+++ b/OWNERS
@@ -0,0 +1,2 @@
+# Bug component: 688011
+include platform/prebuilts/rust:main:/OWNERS
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..d675533
--- /dev/null
+++ b/README.md
@@ -0,0 +1,71 @@
+# mls-rs &emsp; [![Build Status]][actions] [![Latest Version]][crates.io] [![API Documentation]][docs.rs] [![codecov](https://codecov.io/gh/awslabs/mls-rs/graph/badge.svg?token=6655ESMTZT)](https://codecov.io/gh/awslabs/mls-rs)
+
+[build status]: https://img.shields.io/github/checks-status/awslabs/mls-rs/main
+[actions]: https://github.com/awslabs/mls-rs/actions?query=branch%3Amain++
+[latest version]: https://img.shields.io/crates/v/mls-rs.svg
+[crates.io]: https://crates.io/crates/mls-rs
+[api documentation]: https://docs.rs/mls-rs/badge.svg
+[docs.rs]: https://docs.rs/mls-rs
+
+<!-- cargo-sync-readme start -->
+
+An implementation of the [IETF Messaging Layer Security](https://messaginglayersecurity.rocks)
+end-to-end encryption (E2EE) protocol.
+
+## What is MLS?
+
+MLS is a new IETF end-to-end encryption standard that is designed to
+provide transport agnostic, asynchronous, and highly performant
+communication between a group of clients.
+
+## MLS Protocol Features
+
+- Multi-party E2EE [group evolution](https://www.rfc-editor.org/rfc/rfc9420.html#name-cryptographic-state-and-evo)
+ via a propose-then-commit mechanism.
+- Asynchronous by design with pre-computed [key packages](https://www.rfc-editor.org/rfc/rfc9420.html#name-key-packages),
+ allowing members to be added to a group while offline.
+- Customizable credential system with built in support for X.509 certificates.
+- [Extension system](https://www.rfc-editor.org/rfc/rfc9420.html#name-extensions)
+ allowing for application specific data to be negotiated via the protocol.
+- Strong forward secrecy and post compromise security.
+- Crypto agility via support for multiple [cipher suites](https://www.rfc-editor.org/rfc/rfc9420.html#name-cipher-suites).
+- Pre-shared key support.
+- Subgroup branching.
+- Group reinitialization for breaking changes such as protocol upgrades.
+
+## Features
+
+- Easy to use client interface that can manage multiple MLS identities and groups.
+- 100% RFC 9420 conformance with support for all default credential, proposal,
+ and extension types.
+- Support for WASM builds.
+- Configurable storage for key packages, secrets and group state
+ via traits along with provided "in memory" and SQLite implementations.
+- Support for custom user proposal and extension types.
+- Ability to create user defined credentials with custom validation
+ routines that can bridge to existing credential schemes.
+- OpenSSL and Rust Crypto based cipher suite implementations.
+- Crypto agility with support for user defined cipher suite.
+- Extensive test suite including security and interop focused tests against
+ pre-computed test vectors.
+
+## Crypto Providers
+
+For cipher suite descriptions see the RFC documentation [here](https://www.rfc-editor.org/rfc/rfc9420.html#name-mls-cipher-suites)
+
+| Name | Cipher Suites | X509 Support |
+| ----------- | ------------- | --------------- |
+| OpenSSL | 1-7 | Stable |
+| AWS-LC | 1,2,3,5,7 | Stable |
+| Rust Crypto | 1,2,3 | ⚠️ Experimental |
+| Web Crypto | ⚠️ Experimental 2,5,7 | Unsupported |
+
+## Security Notice
+
+This library has been validated for conformance to the RFC 9420 specification but has not yet received a full security audit by a 3rd party.
+
+<!-- cargo-sync-readme end -->
+
+## License
+
+This library is licensed under the Apache-2.0 or the MIT License.
diff --git a/benches/group_add.rs b/benches/group_add.rs
new file mode 100644
index 0000000..e318107
--- /dev/null
+++ b/benches/group_add.rs
@@ -0,0 +1,87 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use criterion::{BatchSize, BenchmarkId, Criterion};
+use mls_rs::{
+ client_builder::MlsConfig,
+ identity::{
+ basic::{BasicCredential, BasicIdentityProvider},
+ SigningIdentity,
+ },
+ mls_rules::{CommitOptions, DefaultMlsRules},
+ CipherSuite, CipherSuiteProvider, Client, CryptoProvider,
+};
+use mls_rs_crypto_openssl::OpensslCryptoProvider;
+
+fn bench(c: &mut Criterion) {
+ let alice = make_client("alice")
+ .create_group(Default::default())
+ .unwrap();
+
+ const MAX_ADD_COUNT: usize = 1000;
+
+ let key_packages = (0..MAX_ADD_COUNT)
+ .map(|i| {
+ make_client(&format!("bob-{i}"))
+ .generate_key_package_message()
+ .unwrap()
+ })
+ .collect::<Vec<_>>();
+
+ let mut group = c.benchmark_group("group_add");
+
+ std::iter::successors(Some(1), |&i| Some(i * 10))
+ .take_while(|&i| i <= MAX_ADD_COUNT)
+ .for_each(|size| {
+ group.bench_with_input(BenchmarkId::from_parameter(size), &size, |b, &size| {
+ b.iter_batched_ref(
+ || alice.clone(),
+ |alice| {
+ key_packages[..size]
+ .iter()
+ .cloned()
+ .fold(alice.commit_builder(), |builder, key_package| {
+ builder.add_member(key_package).unwrap()
+ })
+ .build()
+ .unwrap();
+ },
+ BatchSize::SmallInput,
+ );
+ });
+ });
+
+ group.finish();
+}
+
+criterion::criterion_group!(benches, bench);
+criterion::criterion_main!(benches);
+
+fn make_client(name: &str) -> Client<impl MlsConfig> {
+ let crypto_provider = OpensslCryptoProvider::new();
+ let cipher_suite = CipherSuite::CURVE25519_AES128;
+
+ let (secret_key, public_key) = crypto_provider
+ .cipher_suite_provider(cipher_suite)
+ .unwrap()
+ .signature_key_generate()
+ .unwrap();
+
+ Client::builder()
+ .crypto_provider(crypto_provider)
+ .identity_provider(BasicIdentityProvider)
+ .mls_rules(
+ DefaultMlsRules::new()
+ .with_commit_options(CommitOptions::new().with_ratchet_tree_extension(false)),
+ )
+ .signing_identity(
+ SigningIdentity::new(
+ BasicCredential::new(name.as_bytes().to_vec()).into_credential(),
+ public_key,
+ ),
+ secret_key,
+ cipher_suite,
+ )
+ .build()
+}
diff --git a/benches/group_application.rs b/benches/group_application.rs
new file mode 100644
index 0000000..3e829c7
--- /dev/null
+++ b/benches/group_application.rs
@@ -0,0 +1,48 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use criterion::{BatchSize, BenchmarkId, Criterion, Throughput};
+use mls_rs::test_utils::benchmarks::load_group_states;
+use mls_rs::CipherSuite;
+use rand::RngCore;
+
+fn bench(c: &mut Criterion) {
+ let cipher_suite = CipherSuite::CURVE25519_AES128;
+ let group_states = load_group_states(cipher_suite).pop().unwrap();
+
+ let mut bytes = vec![0; 1000000];
+ rand::thread_rng().fill_bytes(&mut bytes);
+
+ let bytes = &bytes;
+ let mut n = 100;
+ let mut bench_group = c.benchmark_group("group_application");
+
+ while n <= 1000000 {
+ bench_group.throughput(Throughput::Bytes(n as u64));
+ bench_group.bench_with_input(
+ BenchmarkId::new(format!("{cipher_suite:?}"), n),
+ &n,
+ |b, _| {
+ b.iter_batched_ref(
+ || group_states.clone(),
+ move |group_states| {
+ let msg = group_states
+ .sender
+ .encrypt_application_message(&bytes[..n], vec![])
+ .unwrap();
+
+ group_states.receiver.process_incoming_message(msg).unwrap();
+ },
+ BatchSize::SmallInput,
+ )
+ },
+ );
+
+ n *= 10;
+ }
+ bench_group.finish();
+}
+
+criterion::criterion_group!(benches, bench);
+criterion::criterion_main!(benches);
diff --git a/benches/group_commit.rs b/benches/group_commit.rs
new file mode 100644
index 0000000..4817ad3
--- /dev/null
+++ b/benches/group_commit.rs
@@ -0,0 +1,29 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use criterion::{BatchSize, BenchmarkId, Criterion};
+use mls_rs::{test_utils::benchmarks::load_group_states, CipherSuite};
+
+fn bench(c: &mut Criterion) {
+ let cipher_suite = CipherSuite::CURVE25519_AES128;
+ let group_states = load_group_states(cipher_suite);
+ let mut bench_group = c.benchmark_group("group_commit");
+
+ for (i, group_states) in group_states.into_iter().enumerate() {
+ bench_group.bench_with_input(
+ BenchmarkId::new(format!("{cipher_suite:?}"), i),
+ &i,
+ |b, _| {
+ b.iter_batched_ref(
+ || group_states.sender.clone(),
+ move |sender| sender.commit(vec![]).unwrap(),
+ BatchSize::SmallInput,
+ )
+ },
+ );
+ }
+}
+
+criterion::criterion_group!(benches, bench);
+criterion::criterion_main!(benches);
diff --git a/benches/group_receive_commit.rs b/benches/group_receive_commit.rs
new file mode 100644
index 0000000..9a8a765
--- /dev/null
+++ b/benches/group_receive_commit.rs
@@ -0,0 +1,39 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use criterion::{BatchSize, BenchmarkId, Criterion};
+use mls_rs::{test_utils::benchmarks::load_group_states, CipherSuite};
+
+fn bench(c: &mut Criterion) {
+ let cipher_suite = CipherSuite::CURVE25519_AES128;
+ let group_states = load_group_states(cipher_suite);
+ let mut bench_group = c.benchmark_group("group_receive_commit");
+
+ for (i, mut group_states) in group_states.into_iter().enumerate() {
+ bench_group.bench_with_input(
+ BenchmarkId::new(format!("{cipher_suite:?}"), i),
+ &i,
+ |b, _| {
+ b.iter_batched_ref(
+ || {
+ let commit = group_states.sender.commit(Vec::new()).unwrap();
+ group_states.sender.clear_pending_commit();
+ (commit, group_states.receiver.clone())
+ },
+ move |(commit, receiver)| {
+ receiver
+ .process_incoming_message(commit.commit_message.clone())
+ .unwrap();
+ },
+ BatchSize::SmallInput,
+ )
+ },
+ );
+ }
+
+ bench_group.finish();
+}
+
+criterion::criterion_group!(benches, bench);
+criterion::criterion_main!(benches);
diff --git a/benches/group_serialize.rs b/benches/group_serialize.rs
new file mode 100644
index 0000000..a8b69e6
--- /dev/null
+++ b/benches/group_serialize.rs
@@ -0,0 +1,30 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use mls_rs::{test_utils::benchmarks::load_group_states, CipherSuite};
+
+use criterion::{BenchmarkId, Criterion};
+
+fn bench_serialize(c: &mut Criterion) {
+ use criterion::BatchSize;
+
+ let cs = CipherSuite::CURVE25519_AES128;
+ let group_states = load_group_states(cs);
+ let mut bench_group = c.benchmark_group("group_serialize");
+
+ for (i, group_states) in group_states.into_iter().enumerate() {
+ bench_group.bench_with_input(BenchmarkId::new(format!("{cs:?}"), i), &i, |b, _| {
+ b.iter_batched_ref(
+ || group_states.sender.clone(),
+ move |sender| sender.write_to_storage().unwrap(),
+ BatchSize::SmallInput,
+ )
+ });
+ }
+
+ bench_group.finish();
+}
+
+criterion::criterion_group!(benches, bench_serialize);
+criterion::criterion_main!(benches);
diff --git a/cargo_embargo.json b/cargo_embargo.json
new file mode 100644
index 0000000..99a56a9
--- /dev/null
+++ b/cargo_embargo.json
@@ -0,0 +1,16 @@
+{
+ "run_cargo": false,
+ "features": [
+ "std",
+ "rayon",
+ "state_update",
+ "private_message",
+ "custom_proposal",
+ "out_of_order",
+ "psk",
+ "prior_epoch",
+ "by_ref_proposal",
+ "tree_index",
+ "fast_serialize"
+ ]
+}
diff --git a/examples/basic_server_usage.rs b/examples/basic_server_usage.rs
new file mode 100644
index 0000000..fba71da
--- /dev/null
+++ b/examples/basic_server_usage.rs
@@ -0,0 +1,193 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use mls_rs::{
+ client_builder::MlsConfig,
+ error::MlsError,
+ external_client::{
+ builder::MlsConfig as ExternalMlsConfig, ExternalClient, ExternalReceivedMessage,
+ ExternalSnapshot,
+ },
+ group::{CachedProposal, ReceivedMessage},
+ identity::{
+ basic::{BasicCredential, BasicIdentityProvider},
+ SigningIdentity,
+ },
+ CipherSuite, CipherSuiteProvider, Client, CryptoProvider, ExtensionList, MlsMessage,
+};
+use mls_rs_core::crypto::SignatureSecretKey;
+
+const CIPHERSUITE: CipherSuite = CipherSuite::CURVE25519_AES128;
+
+fn cipher_suite_provider() -> impl CipherSuiteProvider {
+ crypto_provider()
+ .cipher_suite_provider(CIPHERSUITE)
+ .unwrap()
+}
+
+fn crypto_provider() -> impl CryptoProvider + Clone {
+ mls_rs_crypto_openssl::OpensslCryptoProvider::default()
+}
+
+#[derive(Default)]
+struct BasicServer {
+ group_state: Vec<u8>,
+ cached_proposals: Vec<Vec<u8>>,
+ message_queue: Vec<Vec<u8>>,
+}
+
+impl BasicServer {
+ // Client uploads group data after creating the group
+ fn create_group(group_info: &[u8]) -> Result<Self, MlsError> {
+ let server = make_server();
+ let group_info = MlsMessage::from_bytes(group_info)?;
+
+ let group = server.observe_group(group_info, None)?;
+
+ Ok(Self {
+ group_state: group.snapshot().to_bytes()?,
+ ..Default::default()
+ })
+ }
+
+ // Client uploads a proposal. This doesn't change the server's group state, so clients can
+ // upload prposals without synchronization (`cached_proposals` and `message_queue` collect
+ // all proposals in any order).
+ fn upload_proposal(&mut self, proposal: Vec<u8>) -> Result<(), MlsError> {
+ let server = make_server();
+ let group_state = ExternalSnapshot::from_bytes(&self.group_state)?;
+ let mut group = server.load_group(group_state)?;
+
+ let proposal_msg = MlsMessage::from_bytes(&proposal)?;
+ let res = group.process_incoming_message(proposal_msg)?;
+
+ let ExternalReceivedMessage::Proposal(proposal_desc) = res else {
+ panic!("expected proposal message!")
+ };
+
+ self.cached_proposals
+ .push(proposal_desc.cached_proposal().to_bytes()?);
+
+ self.message_queue.push(proposal);
+
+ Ok(())
+ }
+
+ // Client uploads a commit. This changes the server's group state, so in a real application,
+ // it must be synchronized. That is, only one `upload_commit` operation can succeed.
+ fn upload_commit(&mut self, commit: Vec<u8>) -> Result<(), MlsError> {
+ let server = make_server();
+ let group_state = ExternalSnapshot::from_bytes(&self.group_state)?;
+ let mut group = server.load_group(group_state)?;
+
+ for p in &self.cached_proposals {
+ group.insert_proposal(CachedProposal::from_bytes(p)?);
+ }
+
+ let commit_msg = MlsMessage::from_bytes(&commit)?;
+ let res = group.process_incoming_message(commit_msg)?;
+
+ let ExternalReceivedMessage::Commit(_commit_desc) = res else {
+ panic!("expected commit message!")
+ };
+
+ self.cached_proposals = Vec::new();
+ self.group_state = group.snapshot().to_bytes()?;
+ self.message_queue.push(commit);
+
+ Ok(())
+ }
+
+ pub fn download_messages(&self, i: usize) -> &[Vec<u8>] {
+ &self.message_queue[i..]
+ }
+}
+
+fn make_server() -> ExternalClient<impl ExternalMlsConfig> {
+ ExternalClient::builder()
+ .identity_provider(BasicIdentityProvider)
+ .crypto_provider(crypto_provider())
+ .build()
+}
+
+fn make_client(name: &str) -> Result<Client<impl MlsConfig>, MlsError> {
+ let (secret, signing_identity) = make_identity(name);
+
+ Ok(Client::builder()
+ .identity_provider(BasicIdentityProvider)
+ .crypto_provider(crypto_provider())
+ .signing_identity(signing_identity, secret, CIPHERSUITE)
+ .build())
+}
+
+fn make_identity(name: &str) -> (SignatureSecretKey, SigningIdentity) {
+ let cipher_suite = cipher_suite_provider();
+ let (secret, public) = cipher_suite.signature_key_generate().unwrap();
+
+ // Create a basic credential for the session.
+ // NOTE: BasicCredential is for demonstration purposes and not recommended for production.
+ // X.509 credentials are recommended.
+ let basic_identity = BasicCredential::new(name.as_bytes().to_vec());
+ let identity = SigningIdentity::new(basic_identity.into_credential(), public);
+
+ (secret, identity)
+}
+
+fn main() -> Result<(), MlsError> {
+ // Create clients for Alice and Bob
+ let alice = make_client("alice")?;
+ let bob = make_client("bob")?;
+
+ // Alice creates a group with bob
+ let mut alice_group = alice.create_group(ExtensionList::default())?;
+ let bob_key_package = bob.generate_key_package_message()?;
+
+ let welcome = &alice_group
+ .commit_builder()
+ .add_member(bob_key_package)?
+ .build()?
+ .welcome_messages[0];
+
+ let (mut bob_group, _) = bob.join_group(None, welcome)?;
+ alice_group.apply_pending_commit()?;
+
+ // Server starts observing Alice's group
+ let group_info = alice_group.group_info_message(true)?.to_bytes()?;
+ let mut server = BasicServer::create_group(&group_info)?;
+
+ // Bob uploads a proposal
+ let proposal = bob_group
+ .propose_group_context_extensions(ExtensionList::new(), Vec::new())?
+ .to_bytes()?;
+
+ server.upload_proposal(proposal)?;
+
+ // Alice downloads all messages and commits
+ for m in server.download_messages(0) {
+ alice_group.process_incoming_message(MlsMessage::from_bytes(m)?)?;
+ }
+
+ let commit = alice_group
+ .commit(b"changing extensions".to_vec())?
+ .commit_message
+ .to_bytes()?;
+
+ server.upload_commit(commit)?;
+
+ // Alice waits for an ACK from the server and applies the commit
+ alice_group.apply_pending_commit()?;
+
+ // Bob downloads the commit
+ let message = server.download_messages(1).first().unwrap();
+
+ let res = bob_group.process_incoming_message(MlsMessage::from_bytes(message)?)?;
+
+ let ReceivedMessage::Commit(commit_desc) = res else {
+ panic!("expected commit message")
+ };
+
+ assert_eq!(&commit_desc.authenticated_data, b"changing extensions");
+
+ Ok(())
+}
diff --git a/examples/basic_usage.rs b/examples/basic_usage.rs
new file mode 100644
index 0000000..c49af8f
--- /dev/null
+++ b/examples/basic_usage.rs
@@ -0,0 +1,79 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use mls_rs::{
+ client_builder::MlsConfig,
+ error::MlsError,
+ identity::{
+ basic::{BasicCredential, BasicIdentityProvider},
+ SigningIdentity,
+ },
+ CipherSuite, CipherSuiteProvider, Client, CryptoProvider, ExtensionList,
+};
+
+const CIPHERSUITE: CipherSuite = CipherSuite::CURVE25519_AES128;
+
+fn make_client<P: CryptoProvider + Clone>(
+ crypto_provider: P,
+ name: &str,
+) -> Result<Client<impl MlsConfig>, MlsError> {
+ let cipher_suite = crypto_provider.cipher_suite_provider(CIPHERSUITE).unwrap();
+
+ // Generate a signature key pair.
+ let (secret, public) = cipher_suite.signature_key_generate().unwrap();
+
+ // Create a basic credential for the session.
+ // NOTE: BasicCredential is for demonstration purposes and not recommended for production.
+ // X.509 credentials are recommended.
+ let basic_identity = BasicCredential::new(name.as_bytes().to_vec());
+ let signing_identity = SigningIdentity::new(basic_identity.into_credential(), public);
+
+ Ok(Client::builder()
+ .identity_provider(BasicIdentityProvider)
+ .crypto_provider(crypto_provider)
+ .signing_identity(signing_identity, secret, CIPHERSUITE)
+ .build())
+}
+
+fn main() -> Result<(), MlsError> {
+ let crypto_provider = mls_rs_crypto_openssl::OpensslCryptoProvider::default();
+
+ // Create clients for Alice and Bob
+ let alice = make_client(crypto_provider.clone(), "alice")?;
+ let bob = make_client(crypto_provider.clone(), "bob")?;
+
+ // Alice creates a new group.
+ let mut alice_group = alice.create_group(ExtensionList::default())?;
+
+ // Bob generates a key package that Alice needs to add Bob to the group.
+ let bob_key_package = bob.generate_key_package_message()?;
+
+ // Alice issues a commit that adds Bob to the group.
+ let alice_commit = alice_group
+ .commit_builder()
+ .add_member(bob_key_package)?
+ .build()?;
+
+ // Alice confirms that the commit was accepted by the group so it can be applied locally.
+ // This would normally happen after a server confirmed your commit was accepted and can
+ // be broadcasted.
+ alice_group.apply_pending_commit()?;
+
+ // Bob joins the group with the welcome message created as part of Alice's commit.
+ let (mut bob_group, _) = bob.join_group(None, &alice_commit.welcome_messages[0])?;
+
+ // Alice encrypts an application message to Bob.
+ let msg = alice_group.encrypt_application_message(b"hello world", Default::default())?;
+
+ // Bob decrypts the application message from Alice.
+ let msg = bob_group.process_incoming_message(msg)?;
+
+ println!("Received message: {:?}", msg);
+
+ // Alice and bob write the group state to their configured storage engine
+ alice_group.write_to_storage()?;
+ bob_group.write_to_storage()?;
+
+ Ok(())
+}
diff --git a/examples/custom.rs b/examples/custom.rs
new file mode 100644
index 0000000..f5d9327
--- /dev/null
+++ b/examples/custom.rs
@@ -0,0 +1,448 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+/// The example shows how how to create an MLS extension implementing an access control policy
+/// based on the concept of users, similar to
+/// https://bifurcation.github.io/ietf-mimi-protocol/draft-ralston-mimi-protocol.html.
+///
+/// A user, e.g. "bob@b.example", owns zero or more MLS members, e.g. Bob's tablet and PC.
+/// Users do not have MLS cryptographic state, while MLS members do. At any point in time,
+/// the MLS group has a fixed set of users and for each user, zero or more MLS members they
+/// own. Each user also has a role, e.g. a regular user or moderator (which may possibly change
+/// over time).
+///
+/// The goal is to implement the following rule:
+/// 1. Each MLS member belongs to a user in the group.
+///
+/// To this end, we implement the following:
+/// * A GroupContext extension containing the current list of users. MLS guarantees agreement
+/// on the list.
+/// * An AddUser proposal that modifies the user list.
+/// * An MLS credential type for MLS members with the owning user's public key and signature.
+/// When MLS members join using MLS Add proposals, the signature is verified.
+/// * Proposal validation rules that enforce 1. above.
+///
+use assert_matches::assert_matches;
+use mls_rs::{
+ client_builder::{MlsConfig, PaddingMode},
+ error::MlsError,
+ group::{
+ proposal::{MlsCustomProposal, Proposal},
+ Roster, Sender,
+ },
+ mls_rules::{
+ CommitDirection, CommitOptions, CommitSource, EncryptionOptions, ProposalBundle,
+ ProposalSource,
+ },
+ CipherSuite, CipherSuiteProvider, Client, CryptoProvider, ExtensionList, IdentityProvider,
+ MlsRules,
+};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::{
+ crypto::{SignaturePublicKey, SignatureSecretKey},
+ error::IntoAnyError,
+ extension::{ExtensionError, ExtensionType, MlsCodecExtension},
+ group::ProposalType,
+ identity::{Credential, CredentialType, CustomCredential, MlsCredential, SigningIdentity},
+ time::MlsTime,
+};
+
+use std::fmt::Display;
+
+const CIPHER_SUITE: CipherSuite = CipherSuite::CURVE25519_AES128;
+
+const ROSTER_EXTENSION_V1: ExtensionType = ExtensionType::new(65000);
+const ADD_USER_PROPOSAL_V1: ProposalType = ProposalType::new(65001);
+const CREDENTIAL_V1: CredentialType = CredentialType::new(65002);
+
+fn crypto() -> impl CryptoProvider + Clone {
+ mls_rs_crypto_openssl::OpensslCryptoProvider::new()
+}
+
+fn cipher_suite() -> impl CipherSuiteProvider {
+ crypto().cipher_suite_provider(CIPHER_SUITE).unwrap()
+}
+
+#[derive(MlsSize, MlsDecode, MlsEncode)]
+#[repr(u8)]
+enum UserRole {
+ Regular = 1u8,
+ Moderator = 2u8,
+}
+
+#[derive(MlsSize, MlsDecode, MlsEncode)]
+struct UserCredential {
+ name: String,
+ role: UserRole,
+ public_key: SignaturePublicKey,
+}
+
+#[derive(MlsSize, MlsDecode, MlsEncode)]
+struct MemberCredential {
+ name: String,
+ user_public_key: SignaturePublicKey, // Identifies the user
+ signature: Vec<u8>,
+}
+
+#[derive(MlsSize, MlsEncode)]
+struct MemberCredentialTBS<'a> {
+ name: &'a str,
+ user_public_key: &'a SignaturePublicKey,
+ public_key: &'a SignaturePublicKey,
+}
+
+/// The roster will be stored in the custom RosterExtension, an extension in the MLS GroupContext
+#[derive(MlsSize, MlsDecode, MlsEncode)]
+struct RosterExtension {
+ roster: Vec<UserCredential>,
+}
+
+impl MlsCodecExtension for RosterExtension {
+ fn extension_type() -> ExtensionType {
+ ROSTER_EXTENSION_V1
+ }
+}
+
+/// The custom AddUser proposal will be used to update the RosterExtension
+#[derive(MlsSize, MlsDecode, MlsEncode)]
+struct AddUserProposal {
+ new_user: UserCredential,
+}
+
+impl MlsCustomProposal for AddUserProposal {
+ fn proposal_type() -> ProposalType {
+ ADD_USER_PROPOSAL_V1
+ }
+}
+
+/// MlsRules tell MLS how to handle our custom proposal
+#[derive(Debug, Clone, Copy)]
+struct CustomMlsRules;
+
+impl MlsRules for CustomMlsRules {
+ type Error = CustomError;
+
+ fn filter_proposals(
+ &self,
+ _: CommitDirection,
+ _: CommitSource,
+ _: &Roster,
+ extension_list: &ExtensionList,
+ mut proposals: ProposalBundle,
+ ) -> Result<ProposalBundle, Self::Error> {
+ // Find our extension
+ let mut roster: RosterExtension =
+ extension_list.get_as().ok().flatten().ok_or(CustomError)?;
+
+ // Find AddUser proposals
+ let add_user_proposals = proposals
+ .custom_proposals()
+ .iter()
+ .filter(|p| p.proposal.proposal_type() == ADD_USER_PROPOSAL_V1);
+
+ for add_user_info in add_user_proposals {
+ let add_user = AddUserProposal::from_custom_proposal(&add_user_info.proposal)?;
+
+ // Eventually we should check for duplicates
+ roster.roster.push(add_user.new_user);
+ }
+
+ // Issue GroupContextExtensions proposal to modify our roster (eventually we don't have to do this if there were no AddUser proposals)
+ let mut new_extensions = extension_list.clone();
+ new_extensions.set_from(roster)?;
+ let gce_proposal = Proposal::GroupContextExtensions(new_extensions);
+ proposals.add(gce_proposal, Sender::Member(0), ProposalSource::Local);
+
+ Ok(proposals)
+ }
+
+ fn commit_options(
+ &self,
+ _: &Roster,
+ _: &ExtensionList,
+ _: &ProposalBundle,
+ ) -> Result<CommitOptions, Self::Error> {
+ Ok(CommitOptions::new())
+ }
+
+ fn encryption_options(
+ &self,
+ _: &Roster,
+ _: &ExtensionList,
+ ) -> Result<EncryptionOptions, Self::Error> {
+ Ok(EncryptionOptions::new(false, PaddingMode::None))
+ }
+}
+
+// The IdentityProvider will tell MLS how to validate members' identities. We will use custom identity
+// type to store our User structs.
+impl MlsCredential for MemberCredential {
+ type Error = CustomError;
+
+ fn credential_type() -> CredentialType {
+ CREDENTIAL_V1
+ }
+
+ fn into_credential(self) -> Result<Credential, Self::Error> {
+ Ok(Credential::Custom(CustomCredential::new(
+ Self::credential_type(),
+ self.mls_encode_to_vec()?,
+ )))
+ }
+}
+
+#[derive(Debug, Clone, Copy)]
+struct CustomIdentityProvider;
+
+impl IdentityProvider for CustomIdentityProvider {
+ type Error = CustomError;
+
+ fn validate_member(
+ &self,
+ signing_identity: &SigningIdentity,
+ _: Option<MlsTime>,
+ extensions: Option<&ExtensionList>,
+ ) -> Result<(), Self::Error> {
+ let Some(extensions) = extensions else {
+ return Ok(());
+ };
+
+ let roster = extensions
+ .get_as::<RosterExtension>()
+ .ok()
+ .flatten()
+ .ok_or(CustomError)?;
+
+ // Retrieve the MemberCredential from the MLS credential
+ let Credential::Custom(custom) = &signing_identity.credential else {
+ return Err(CustomError);
+ };
+
+ if custom.credential_type != CREDENTIAL_V1 {
+ return Err(CustomError);
+ }
+
+ let member = MemberCredential::mls_decode(&mut &*custom.data)?;
+
+ // Validate the MemberCredential
+
+ let tbs = MemberCredentialTBS {
+ name: &member.name,
+ user_public_key: &member.user_public_key,
+ public_key: &signing_identity.signature_key,
+ }
+ .mls_encode_to_vec()?;
+
+ cipher_suite()
+ .verify(&member.user_public_key, &member.signature, &tbs)
+ .map_err(|_| CustomError)?;
+
+ let user_in_roster = roster
+ .roster
+ .iter()
+ .any(|u| u.public_key == member.user_public_key);
+
+ if !user_in_roster {
+ return Err(CustomError);
+ }
+
+ Ok(())
+ }
+
+ fn identity(
+ &self,
+ signing_identity: &SigningIdentity,
+ _: &ExtensionList,
+ ) -> Result<Vec<u8>, Self::Error> {
+ Ok(signing_identity.mls_encode_to_vec()?)
+ }
+
+ fn supported_types(&self) -> Vec<CredentialType> {
+ vec![CREDENTIAL_V1]
+ }
+
+ fn valid_successor(
+ &self,
+ _: &SigningIdentity,
+ _: &SigningIdentity,
+ _: &ExtensionList,
+ ) -> Result<bool, Self::Error> {
+ Ok(true)
+ }
+
+ fn validate_external_sender(
+ &self,
+ _: &SigningIdentity,
+ _: Option<MlsTime>,
+ _: Option<&ExtensionList>,
+ ) -> Result<(), Self::Error> {
+ Ok(())
+ }
+}
+
+// Convenience structs to create users and members
+
+struct User {
+ credential: UserCredential,
+ signer: SignatureSecretKey,
+}
+
+impl User {
+ fn new(name: &str, role: UserRole) -> Result<Self, CustomError> {
+ let (signer, public_key) = cipher_suite()
+ .signature_key_generate()
+ .map_err(|_| CustomError)?;
+
+ let credential = UserCredential {
+ name: name.into(),
+ role,
+ public_key,
+ };
+
+ Ok(Self { credential, signer })
+ }
+}
+
+struct Member {
+ credential: MemberCredential,
+ public_key: SignaturePublicKey,
+ signer: SignatureSecretKey,
+}
+
+impl Member {
+ fn new(name: &str, user: &User) -> Result<Self, CustomError> {
+ let (signer, public_key) = cipher_suite()
+ .signature_key_generate()
+ .map_err(|_| CustomError)?;
+
+ let tbs = MemberCredentialTBS {
+ name,
+ user_public_key: &user.credential.public_key,
+ public_key: &public_key,
+ }
+ .mls_encode_to_vec()?;
+
+ let signature = cipher_suite()
+ .sign(&user.signer, &tbs)
+ .map_err(|_| CustomError)?;
+
+ let credential = MemberCredential {
+ name: name.into(),
+ user_public_key: user.credential.public_key.clone(),
+ signature,
+ };
+
+ Ok(Self {
+ credential,
+ signer,
+ public_key,
+ })
+ }
+}
+
+// Set up Client to use our custom providers
+fn make_client(member: Member) -> Result<Client<impl MlsConfig>, CustomError> {
+ let mls_credential = member.credential.into_credential()?;
+ let signing_identity = SigningIdentity::new(mls_credential, member.public_key);
+
+ Ok(Client::builder()
+ .identity_provider(CustomIdentityProvider)
+ .mls_rules(CustomMlsRules)
+ .custom_proposal_type(ADD_USER_PROPOSAL_V1)
+ .extension_type(ROSTER_EXTENSION_V1)
+ .crypto_provider(crypto())
+ .signing_identity(signing_identity, member.signer, CIPHER_SUITE)
+ .build())
+}
+
+fn main() -> Result<(), CustomError> {
+ let alice = User::new("alice", UserRole::Moderator)?;
+ let bob = User::new("bob", UserRole::Regular)?;
+
+ let alice_tablet = Member::new("alice tablet", &alice)?;
+ let alice_pc = Member::new("alice pc", &alice)?;
+ let bob_tablet = Member::new("bob tablet", &bob)?;
+
+ // Alice creates the group with our RosterExtension containing her user
+ let mut context_extensions = ExtensionList::new();
+ let roster = vec![alice.credential];
+ context_extensions.set_from(RosterExtension { roster })?;
+
+ let mut alice_tablet_group = make_client(alice_tablet)?.create_group(context_extensions)?;
+
+ // Alice can add her other device
+ let alice_pc_client = make_client(alice_pc)?;
+ let key_package = alice_pc_client.generate_key_package_message()?;
+
+ let welcome = alice_tablet_group
+ .commit_builder()
+ .add_member(key_package)?
+ .build()?
+ .welcome_messages
+ .remove(0);
+
+ alice_tablet_group.apply_pending_commit()?;
+ let (mut alice_pc_group, _) = alice_pc_client.join_group(None, &welcome)?;
+
+ // Alice cannot add bob's devices yet
+ let bob_tablet_client = make_client(bob_tablet)?;
+ let key_package = bob_tablet_client.generate_key_package_message()?;
+
+ let res = alice_tablet_group
+ .commit_builder()
+ .add_member(key_package.clone())?
+ .build();
+
+ assert_matches!(res, Err(MlsError::IdentityProviderError(_)));
+
+ // Alice can add bob's user and device
+ let add_bob = AddUserProposal {
+ new_user: bob.credential,
+ };
+
+ let commit = alice_tablet_group
+ .commit_builder()
+ .custom_proposal(add_bob.to_custom_proposal()?)
+ .add_member(key_package)?
+ .build()?;
+
+ bob_tablet_client.join_group(None, &commit.welcome_messages[0])?;
+ alice_tablet_group.apply_pending_commit()?;
+ alice_pc_group.process_incoming_message(commit.commit_message)?;
+
+ Ok(())
+}
+
+#[derive(Debug, thiserror::Error)]
+struct CustomError;
+
+impl IntoAnyError for CustomError {
+ fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
+ Ok(Box::new(self))
+ }
+}
+
+impl Display for CustomError {
+ fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+ f.write_str("Custom Error")
+ }
+}
+
+impl From<MlsError> for CustomError {
+ fn from(_: MlsError) -> Self {
+ Self
+ }
+}
+
+impl From<mls_rs_codec::Error> for CustomError {
+ fn from(_: mls_rs_codec::Error) -> Self {
+ Self
+ }
+}
+
+impl From<ExtensionError> for CustomError {
+ fn from(_: ExtensionError) -> Self {
+ Self
+ }
+}
diff --git a/examples/large_group.rs b/examples/large_group.rs
new file mode 100644
index 0000000..c437743
--- /dev/null
+++ b/examples/large_group.rs
@@ -0,0 +1,183 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use mls_rs::{
+ client_builder::MlsConfig,
+ error::MlsError,
+ identity::{
+ basic::{BasicCredential, BasicIdentityProvider},
+ SigningIdentity,
+ },
+ mls_rules::{CommitOptions, DefaultMlsRules},
+ CipherSuite, CipherSuiteProvider, Client, CryptoProvider, Group,
+};
+
+const CIPHERSUITE: CipherSuite = CipherSuite::CURVE25519_AES128;
+const GROUP_SIZES: [usize; 8] = [2, 3, 5, 9, 17, 33, 65, 129];
+
+enum Case {
+ Best,
+ Worst,
+}
+
+fn bench_commit_size<P: CryptoProvider + Clone>(
+ case_group: Case,
+ crypto_provider: &P,
+) -> Result<(Vec<usize>, Vec<usize>), MlsError> {
+ let mut small_bench = vec![];
+ let mut large_bench = vec![];
+
+ for num_groups in GROUP_SIZES.iter().copied() {
+ let (small_commit, large_commit) = match case_group {
+ Case::Best => {
+ let mut groups = make_groups_best_case(num_groups, crypto_provider)?;
+ let small_commit = groups[num_groups - 1].commit(vec![])?.commit_message;
+ let large_commit = groups[0].commit(vec![])?.commit_message;
+ (small_commit, large_commit)
+ }
+ Case::Worst => {
+ let mut groups = make_groups_worst_case(num_groups, crypto_provider)?;
+ let small_commit = groups[num_groups - 1].commit(vec![])?.commit_message;
+ let large_commit = groups[0].commit(vec![])?.commit_message;
+ (small_commit, large_commit)
+ }
+ };
+
+ small_bench.push(small_commit.to_bytes()?.len());
+ large_bench.push(large_commit.to_bytes()?.len());
+ }
+
+ Ok((small_bench, large_bench))
+}
+
+// Bob[0] crates a group. Repeat for `i=0` to `num_groups - 1` times : Bob[i] adds Bob[i+1]
+fn make_groups_best_case<P: CryptoProvider + Clone>(
+ num_groups: usize,
+ crypto_provider: &P,
+) -> Result<Vec<Group<impl MlsConfig>>, MlsError> {
+ let bob_client = make_client(crypto_provider.clone(), &make_name(0))?;
+
+ let bob_group = bob_client.create_group(Default::default())?;
+
+ let mut groups = vec![bob_group];
+
+ for i in 0..(num_groups - 1) {
+ let bob_client = make_client(crypto_provider.clone(), &make_name(i + 1))?;
+
+ // The new client generates a key package.
+ let bob_kpkg = bob_client.generate_key_package_message()?;
+
+ // Last group sends a commit adding the new client to the group.
+ let commit = groups
+ .last_mut()
+ .unwrap()
+ .commit_builder()
+ .add_member(bob_kpkg)?
+ .build()?;
+
+ // All other groups process the commit.
+ for group in groups.iter_mut().rev().skip(1) {
+ group.process_incoming_message(commit.commit_message.clone())?;
+ }
+
+ // The last group applies the generated commit.
+ groups.last_mut().unwrap().apply_pending_commit()?;
+
+ // The new member joins.
+ let (bob_group, _info) = bob_client.join_group(None, &commit.welcome_messages[0])?;
+
+ groups.push(bob_group);
+ }
+
+ Ok(groups)
+}
+
+// Alice creates a group by adding `num_groups - 1` clients in one commit.
+fn make_groups_worst_case<P: CryptoProvider + Clone>(
+ num_groups: usize,
+ crypto_provider: &P,
+) -> Result<Vec<Group<impl MlsConfig>>, MlsError> {
+ let alice_client = make_client(crypto_provider.clone(), &make_name(0))?;
+
+ let mut alice_group = alice_client.create_group(Default::default())?;
+
+ let bob_clients = (0..(num_groups - 1))
+ .map(|i| make_client(crypto_provider.clone(), &make_name(i + 1)))
+ .collect::<Result<Vec<_>, _>>()?;
+
+ // Alice adds all Bob's clients in a single commit.
+ let mut commit_builder = alice_group.commit_builder();
+
+ for bob_client in &bob_clients {
+ let bob_kpkg = bob_client.generate_key_package_message()?;
+ commit_builder = commit_builder.add_member(bob_kpkg)?;
+ }
+
+ let welcome_message = &commit_builder.build()?.welcome_messages[0];
+
+ alice_group.apply_pending_commit()?;
+
+ // Bob's clients join the group.
+ let mut groups = vec![alice_group];
+
+ for bob_client in &bob_clients {
+ let (bob_group, _info) = bob_client.join_group(None, welcome_message)?;
+ groups.push(bob_group);
+ }
+
+ Ok(groups)
+}
+
+fn make_client<P: CryptoProvider + Clone>(
+ crypto_provider: P,
+ name: &str,
+) -> Result<Client<impl MlsConfig>, MlsError> {
+ let cipher_suite = crypto_provider.cipher_suite_provider(CIPHERSUITE).unwrap();
+
+ // Generate a signature key pair.
+ let (secret, public) = cipher_suite.signature_key_generate().unwrap();
+
+ // Create a basic credential for the session.
+ // NOTE: BasicCredential is for demonstration purposes and not recommended for production.
+ // X.509 credentials are recommended.
+ let basic_identity = BasicCredential::new(name.as_bytes().to_vec());
+ let signing_identity = SigningIdentity::new(basic_identity.into_credential(), public);
+
+ Ok(Client::builder()
+ .identity_provider(BasicIdentityProvider)
+ .crypto_provider(crypto_provider)
+ .mls_rules(
+ DefaultMlsRules::new()
+ .with_commit_options(CommitOptions::new().with_path_required(true)),
+ )
+ .signing_identity(signing_identity, secret, CIPHERSUITE)
+ .build())
+}
+
+fn make_name(i: usize) -> String {
+ format!("bob {i:08}")
+}
+
+fn main() -> Result<(), MlsError> {
+ let crypto_provider = mls_rs_crypto_openssl::OpensslCryptoProvider::default();
+
+ println!("Demonstrate that performance depends on a) group evolution and b) a members position in the tree.\n");
+
+ let (small_bench_bc, large_bench_bc) = bench_commit_size(Case::Best, &crypto_provider)?;
+ let (small_bench_wc, large_bench_wc) = bench_commit_size(Case::Worst, &crypto_provider)?;
+
+ println!("\nBest case a), worst case b) : commit size is θ(log(n)) bytes.");
+ println!("group sizes n :\n{GROUP_SIZES:?}\ncommit sizes :\n{large_bench_bc:?}");
+
+ println!("\nWorst case a), worst case b) : commit size is θ(n) bytes.");
+ println!("group sizes n :\n{GROUP_SIZES:?}\ncommit sizes :\n{large_bench_wc:?}");
+
+ println!(
+ "\nBest case b) : if n-1 is a power of 2, commit size is θ(1) bytes, independent of a)."
+ );
+ println!("group sizes n :\n{GROUP_SIZES:?}\ncommit sizes, best case a) :\n{small_bench_bc:?}");
+ println!("commit sizes, worst case a) :\n{small_bench_wc:?}");
+
+ Ok(())
+}
diff --git a/examples/x509.rs b/examples/x509.rs
new file mode 100644
index 0000000..42316ce
--- /dev/null
+++ b/examples/x509.rs
@@ -0,0 +1,38 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use mls_rs::{CipherSuite, Client};
+use mls_rs_crypto_openssl::OpensslCryptoProvider;
+
+const CIPHERSUITE: CipherSuite = CipherSuite::CURVE25519_AES128;
+
+fn main() {
+ let crypto_provider = OpensslCryptoProvider::new();
+
+ let secret_key = mls_rs_crypto_openssl::x509::signature_secret_key_from_bytes(include_bytes!(
+ "../../mls-rs-crypto-openssl/test_data/x509/leaf/key.pem"
+ ))
+ .unwrap();
+
+ let signing_identity = mls_rs_crypto_openssl::x509::signing_identity_from_certificate(
+ include_bytes!("../../mls-rs-crypto-openssl/test_data/x509/leaf/cert.der"),
+ )
+ .unwrap();
+
+ let alice_client = Client::builder()
+ .crypto_provider(crypto_provider)
+ .identity_provider(
+ mls_rs_crypto_openssl::x509::identity_provider_from_certificate(include_bytes!(
+ "../../mls-rs-crypto-openssl/test_data/x509/root_ca/cert.der"
+ ))
+ .unwrap(),
+ )
+ .signing_identity(signing_identity, secret_key, CIPHERSUITE)
+ .build();
+
+ let mut alice_group = alice_client.create_group(Default::default()).unwrap();
+
+ alice_group.commit(Vec::new()).unwrap();
+ alice_group.apply_pending_commit().unwrap();
+}
diff --git a/src/client.rs b/src/client.rs
new file mode 100644
index 0000000..a7031bb
--- /dev/null
+++ b/src/client.rs
@@ -0,0 +1,1049 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::cipher_suite::CipherSuite;
+use crate::client_builder::{recreate_config, BaseConfig, ClientBuilder, MakeConfig};
+use crate::client_config::ClientConfig;
+use crate::group::framing::MlsMessage;
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::group::{
+ framing::{Content, MlsMessagePayload, PublicMessage, Sender, WireFormat},
+ message_signature::AuthenticatedContent,
+ proposal::{AddProposal, Proposal},
+};
+use crate::group::{snapshot::Snapshot, ExportedTree, Group, NewMemberInfo};
+use crate::identity::SigningIdentity;
+use crate::key_package::{KeyPackageGeneration, KeyPackageGenerator};
+use crate::protocol_version::ProtocolVersion;
+use crate::tree_kem::node::NodeIndex;
+use alloc::vec::Vec;
+use mls_rs_codec::MlsDecode;
+use mls_rs_core::crypto::{CryptoProvider, SignatureSecretKey};
+use mls_rs_core::error::{AnyError, IntoAnyError};
+use mls_rs_core::extension::{ExtensionError, ExtensionList, ExtensionType};
+use mls_rs_core::group::{GroupStateStorage, ProposalType};
+use mls_rs_core::identity::CredentialType;
+use mls_rs_core::key_package::KeyPackageStorage;
+
+use crate::group::external_commit::ExternalCommitBuilder;
+
+#[cfg(feature = "by_ref_proposal")]
+use alloc::boxed::Box;
+
+#[derive(Debug)]
+#[cfg_attr(feature = "std", derive(thiserror::Error))]
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::enum_to_error_code)]
+#[non_exhaustive]
+pub enum MlsError {
+ #[cfg_attr(feature = "std", error(transparent))]
+ IdentityProviderError(AnyError),
+ #[cfg_attr(feature = "std", error(transparent))]
+ CryptoProviderError(AnyError),
+ #[cfg_attr(feature = "std", error(transparent))]
+ KeyPackageRepoError(AnyError),
+ #[cfg_attr(feature = "std", error(transparent))]
+ GroupStorageError(AnyError),
+ #[cfg_attr(feature = "std", error(transparent))]
+ PskStoreError(AnyError),
+ #[cfg_attr(feature = "std", error(transparent))]
+ MlsRulesError(AnyError),
+ #[cfg_attr(feature = "std", error(transparent))]
+ SerializationError(AnyError),
+ #[cfg_attr(feature = "std", error(transparent))]
+ ExtensionError(AnyError),
+ #[cfg_attr(feature = "std", error("Cipher suite does not match"))]
+ CipherSuiteMismatch,
+ #[cfg_attr(feature = "std", error("Invalid commit, missing required path"))]
+ CommitMissingPath,
+ #[cfg_attr(feature = "std", error("plaintext message for incorrect epoch"))]
+ InvalidEpoch,
+ #[cfg_attr(feature = "std", error("invalid signature found"))]
+ InvalidSignature,
+ #[cfg_attr(feature = "std", error("invalid confirmation tag"))]
+ InvalidConfirmationTag,
+ #[cfg_attr(feature = "std", error("invalid membership tag"))]
+ InvalidMembershipTag,
+ #[cfg_attr(feature = "std", error("corrupt private key, missing required values"))]
+ InvalidTreeKemPrivateKey,
+ #[cfg_attr(feature = "std", error("key package not found, unable to process"))]
+ WelcomeKeyPackageNotFound,
+ #[cfg_attr(feature = "std", error("leaf not found in tree for index {0}"))]
+ LeafNotFound(u32),
+ #[cfg_attr(feature = "std", error("message from self can't be processed"))]
+ CantProcessMessageFromSelf,
+ #[cfg_attr(
+ feature = "std",
+ error("pending proposals found, commit required before application messages can be sent")
+ )]
+ CommitRequired,
+ #[cfg_attr(
+ feature = "std",
+ error("ratchet tree not provided or discovered in GroupInfo")
+ )]
+ RatchetTreeNotFound,
+ #[cfg_attr(feature = "std", error("External sender cannot commit"))]
+ ExternalSenderCannotCommit,
+ #[cfg_attr(feature = "std", error("Unsupported protocol version {0:?}"))]
+ UnsupportedProtocolVersion(ProtocolVersion),
+ #[cfg_attr(feature = "std", error("Protocol version mismatch"))]
+ ProtocolVersionMismatch,
+ #[cfg_attr(feature = "std", error("Unsupported cipher suite {0:?}"))]
+ UnsupportedCipherSuite(CipherSuite),
+ #[cfg_attr(feature = "std", error("Signing key of external sender is unknown"))]
+ UnknownSigningIdentityForExternalSender,
+ #[cfg_attr(
+ feature = "std",
+ error("External proposals are disabled for this group")
+ )]
+ ExternalProposalsDisabled,
+ #[cfg_attr(
+ feature = "std",
+ error("Signing identity is not allowed to externally propose")
+ )]
+ InvalidExternalSigningIdentity,
+ #[cfg_attr(feature = "std", error("Missing ExternalPub extension"))]
+ MissingExternalPubExtension,
+ #[cfg_attr(feature = "std", error("Epoch not found"))]
+ EpochNotFound,
+ #[cfg_attr(feature = "std", error("Unencrypted application message"))]
+ UnencryptedApplicationMessage,
+ #[cfg_attr(
+ feature = "std",
+ error("NewMemberCommit sender type can only be used to send Commit content")
+ )]
+ ExpectedCommitForNewMemberCommit,
+ #[cfg_attr(
+ feature = "std",
+ error("NewMemberProposal sender type can only be used to send add proposals")
+ )]
+ ExpectedAddProposalForNewMemberProposal,
+ #[cfg_attr(
+ feature = "std",
+ error("External commit missing ExternalInit proposal")
+ )]
+ ExternalCommitMissingExternalInit,
+ #[cfg_attr(
+ feature = "std",
+ error(
+ "A ReIinit has been applied. The next action must be creating or receiving a welcome."
+ )
+ )]
+ GroupUsedAfterReInit,
+ #[cfg_attr(feature = "std", error("Pending ReIinit not found."))]
+ PendingReInitNotFound,
+ #[cfg_attr(
+ feature = "std",
+ error("The extensions in the welcome message and in the reinit do not match.")
+ )]
+ ReInitExtensionsMismatch,
+ #[cfg_attr(feature = "std", error("signer not found for given identity"))]
+ SignerNotFound,
+ #[cfg_attr(feature = "std", error("commit already pending"))]
+ ExistingPendingCommit,
+ #[cfg_attr(feature = "std", error("pending commit not found"))]
+ PendingCommitNotFound,
+ #[cfg_attr(feature = "std", error("unexpected message type for action"))]
+ UnexpectedMessageType,
+ #[cfg_attr(
+ feature = "std",
+ error("membership tag on MlsPlaintext for non-member sender")
+ )]
+ MembershipTagForNonMember,
+ #[cfg_attr(feature = "std", error("No member found for given identity id."))]
+ MemberNotFound,
+ #[cfg_attr(feature = "std", error("group not found"))]
+ GroupNotFound,
+ #[cfg_attr(feature = "std", error("unexpected PSK ID"))]
+ UnexpectedPskId,
+ #[cfg_attr(feature = "std", error("invalid sender for content type"))]
+ InvalidSender,
+ #[cfg_attr(feature = "std", error("GroupID mismatch"))]
+ GroupIdMismatch,
+ #[cfg_attr(feature = "std", error("storage retention can not be zero"))]
+ NonZeroRetentionRequired,
+ #[cfg_attr(feature = "std", error("Too many PSK IDs to compute PSK secret"))]
+ TooManyPskIds,
+ #[cfg_attr(feature = "std", error("Missing required Psk"))]
+ MissingRequiredPsk,
+ #[cfg_attr(feature = "std", error("Old group state not found"))]
+ OldGroupStateNotFound,
+ #[cfg_attr(feature = "std", error("leaf secret already consumed"))]
+ InvalidLeafConsumption,
+ #[cfg_attr(feature = "std", error("key not available, invalid generation {0}"))]
+ KeyMissing(u32),
+ #[cfg_attr(
+ feature = "std",
+ error("requested generation {0} is too far ahead of current generation")
+ )]
+ InvalidFutureGeneration(u32),
+ #[cfg_attr(feature = "std", error("leaf node has no children"))]
+ LeafNodeNoChildren,
+ #[cfg_attr(feature = "std", error("root node has no parent"))]
+ LeafNodeNoParent,
+ #[cfg_attr(feature = "std", error("index out of range"))]
+ InvalidTreeIndex,
+ #[cfg_attr(feature = "std", error("time overflow"))]
+ TimeOverflow,
+ #[cfg_attr(feature = "std", error("invalid leaf_node_source"))]
+ InvalidLeafNodeSource,
+ #[cfg_attr(feature = "std", error("key package has expired or is not valid yet"))]
+ InvalidLifetime,
+ #[cfg_attr(feature = "std", error("required extension not found"))]
+ RequiredExtensionNotFound(ExtensionType),
+ #[cfg_attr(feature = "std", error("required proposal not found"))]
+ RequiredProposalNotFound(ProposalType),
+ #[cfg_attr(feature = "std", error("required credential not found"))]
+ RequiredCredentialNotFound(CredentialType),
+ #[cfg_attr(feature = "std", error("capabilities must describe extensions used"))]
+ ExtensionNotInCapabilities(ExtensionType),
+ #[cfg_attr(feature = "std", error("expected non-blank node"))]
+ ExpectedNode,
+ #[cfg_attr(feature = "std", error("node index is out of bounds {0}"))]
+ InvalidNodeIndex(NodeIndex),
+ #[cfg_attr(feature = "std", error("unexpected empty node found"))]
+ UnexpectedEmptyNode,
+ #[cfg_attr(
+ feature = "std",
+ error("duplicate signature key, hpke key or identity found at index {0}")
+ )]
+ DuplicateLeafData(u32),
+ #[cfg_attr(
+ feature = "std",
+ error("In-use credential type not supported by new leaf at index")
+ )]
+ InUseCredentialTypeUnsupportedByNewLeaf,
+ #[cfg_attr(
+ feature = "std",
+ error("Not all members support the credential type used by new leaf")
+ )]
+ CredentialTypeOfNewLeafIsUnsupported,
+ #[cfg_attr(
+ feature = "std",
+ error("the length of the update path is different than the length of the direct path")
+ )]
+ WrongPathLen,
+ #[cfg_attr(
+ feature = "std",
+ error("same HPKE leaf key before and after applying the update path for leaf {0}")
+ )]
+ SameHpkeKey(u32),
+ #[cfg_attr(feature = "std", error("init key is not valid for cipher suite"))]
+ InvalidInitKey,
+ #[cfg_attr(
+ feature = "std",
+ error("init key can not be equal to leaf node public key")
+ )]
+ InitLeafKeyEquality,
+ #[cfg_attr(feature = "std", error("different identity in update for leaf {0}"))]
+ DifferentIdentityInUpdate(u32),
+ #[cfg_attr(feature = "std", error("update path pub key mismatch"))]
+ PubKeyMismatch,
+ #[cfg_attr(feature = "std", error("tree hash mismatch"))]
+ TreeHashMismatch,
+ #[cfg_attr(feature = "std", error("bad update: no suitable secret key"))]
+ UpdateErrorNoSecretKey,
+ #[cfg_attr(feature = "std", error("invalid lca, not found on direct path"))]
+ LcaNotFoundInDirectPath,
+ #[cfg_attr(feature = "std", error("update path parent hash mismatch"))]
+ ParentHashMismatch,
+ #[cfg_attr(feature = "std", error("unexpected pattern of unmerged leaves"))]
+ UnmergedLeavesMismatch,
+ #[cfg_attr(feature = "std", error("empty tree"))]
+ UnexpectedEmptyTree,
+ #[cfg_attr(feature = "std", error("trailing blanks"))]
+ UnexpectedTrailingBlanks,
+ // Proposal Rules errors
+ #[cfg_attr(
+ feature = "std",
+ error("Commiter must not include any update proposals generated by the commiter")
+ )]
+ InvalidCommitSelfUpdate,
+ #[cfg_attr(feature = "std", error("A PreSharedKey proposal must have a PSK of type External or type Resumption and usage Application"))]
+ InvalidTypeOrUsageInPreSharedKeyProposal,
+ #[cfg_attr(feature = "std", error("psk nonce length does not match cipher suite"))]
+ InvalidPskNonceLength,
+ #[cfg_attr(
+ feature = "std",
+ error("ReInit proposal protocol version is less than the version of the original group")
+ )]
+ InvalidProtocolVersionInReInit,
+ #[cfg_attr(feature = "std", error("More than one proposal applying to leaf: {0}"))]
+ MoreThanOneProposalForLeaf(u32),
+ #[cfg_attr(
+ feature = "std",
+ error("More than one GroupContextExtensions proposal")
+ )]
+ MoreThanOneGroupContextExtensionsProposal,
+ #[cfg_attr(feature = "std", error("Invalid proposal type for sender"))]
+ InvalidProposalTypeForSender,
+ #[cfg_attr(
+ feature = "std",
+ error("External commit must have exactly one ExternalInit proposal")
+ )]
+ ExternalCommitMustHaveExactlyOneExternalInit,
+ #[cfg_attr(feature = "std", error("External commit must have a new leaf"))]
+ ExternalCommitMustHaveNewLeaf,
+ #[cfg_attr(
+ feature = "std",
+ error("External commit contains removal of other identity")
+ )]
+ ExternalCommitRemovesOtherIdentity,
+ #[cfg_attr(
+ feature = "std",
+ error("External commit contains more than one Remove proposal")
+ )]
+ ExternalCommitWithMoreThanOneRemove,
+ #[cfg_attr(feature = "std", error("Duplicate PSK IDs"))]
+ DuplicatePskIds,
+ #[cfg_attr(
+ feature = "std",
+ error("Invalid proposal type {0:?} in external commit")
+ )]
+ InvalidProposalTypeInExternalCommit(ProposalType),
+ #[cfg_attr(feature = "std", error("Committer can not remove themselves"))]
+ CommitterSelfRemoval,
+ #[cfg_attr(
+ feature = "std",
+ error("Only members can commit proposals by reference")
+ )]
+ OnlyMembersCanCommitProposalsByRef,
+ #[cfg_attr(feature = "std", error("Other proposal with ReInit"))]
+ OtherProposalWithReInit,
+ #[cfg_attr(feature = "std", error("Unsupported group extension {0:?}"))]
+ UnsupportedGroupExtension(ExtensionType),
+ #[cfg_attr(feature = "std", error("Unsupported custom proposal type {0:?}"))]
+ UnsupportedCustomProposal(ProposalType),
+ #[cfg_attr(feature = "std", error("by-ref proposal not found"))]
+ ProposalNotFound,
+ #[cfg_attr(
+ feature = "std",
+ error("Removing non-existing member (or removing a member twice)")
+ )]
+ RemovingNonExistingMember,
+ #[cfg_attr(feature = "std", error("Updated identity not a valid successor"))]
+ InvalidSuccessor,
+ #[cfg_attr(
+ feature = "std",
+ error("Updating non-existing member (or updating a member twice)")
+ )]
+ UpdatingNonExistingMember,
+ #[cfg_attr(feature = "std", error("Failed generating next path secret"))]
+ FailedGeneratingPathSecret,
+ #[cfg_attr(feature = "std", error("Invalid group info"))]
+ InvalidGroupInfo,
+ #[cfg_attr(feature = "std", error("Invalid welcome message"))]
+ InvalidWelcomeMessage,
+}
+
+impl IntoAnyError for MlsError {
+ #[cfg(feature = "std")]
+ fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
+ Ok(self.into())
+ }
+}
+
+impl From<mls_rs_codec::Error> for MlsError {
+ #[inline]
+ fn from(e: mls_rs_codec::Error) -> Self {
+ MlsError::SerializationError(e.into_any_error())
+ }
+}
+
+impl From<ExtensionError> for MlsError {
+ #[inline]
+ fn from(e: ExtensionError) -> Self {
+ MlsError::ExtensionError(e.into_any_error())
+ }
+}
+
+/// MLS client used to create key packages and manage groups.
+///
+/// [`Client::builder`] can be used to instantiate it.
+///
+/// Clients are able to support multiple protocol versions, ciphersuites
+/// and underlying identities used to join groups and generate key packages.
+/// Applications may decide to create one or many clients depending on their
+/// specific needs.
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type(opaque))]
+#[derive(Clone, Debug)]
+pub struct Client<C> {
+ pub(crate) config: C,
+ pub(crate) signing_identity: Option<(SigningIdentity, CipherSuite)>,
+ pub(crate) signer: Option<SignatureSecretKey>,
+ pub(crate) version: ProtocolVersion,
+}
+
+impl Client<()> {
+ /// Returns a [`ClientBuilder`]
+ /// used to configure client preferences and providers.
+ pub fn builder() -> ClientBuilder<BaseConfig> {
+ ClientBuilder::new()
+ }
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
+impl<C> Client<C>
+where
+ C: ClientConfig + Clone,
+{
+ pub(crate) fn new(
+ config: C,
+ signer: Option<SignatureSecretKey>,
+ signing_identity: Option<(SigningIdentity, CipherSuite)>,
+ version: ProtocolVersion,
+ ) -> Self {
+ Client {
+ config,
+ signer,
+ signing_identity,
+ version,
+ }
+ }
+
+ #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
+ pub fn to_builder(&self) -> ClientBuilder<MakeConfig<C>> {
+ ClientBuilder::from_config(recreate_config(
+ self.config.clone(),
+ self.signer.clone(),
+ self.signing_identity.clone(),
+ self.version,
+ ))
+ }
+
+ /// Creates a new key package message that can be used to to add this
+ /// client to a [Group](crate::group::Group). Each call to this function
+ /// will produce a unique value that is signed by `signing_identity`.
+ ///
+ /// The secret keys for the resulting key package message will be stored in
+ /// the [KeyPackageStorage](crate::KeyPackageStorage)
+ /// that was used to configure the client and will
+ /// automatically be erased when this key package is used to
+ /// [join a group](Client::join_group).
+ ///
+ /// # Warning
+ ///
+ /// A key package message may only be used once.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn generate_key_package_message(&self) -> Result<MlsMessage, MlsError> {
+ Ok(self.generate_key_package().await?.key_package_message())
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn generate_key_package(&self) -> Result<KeyPackageGeneration, MlsError> {
+ let (signing_identity, cipher_suite) = self.signing_identity()?;
+
+ let cipher_suite_provider = self
+ .config
+ .crypto_provider()
+ .cipher_suite_provider(cipher_suite)
+ .ok_or(MlsError::UnsupportedCipherSuite(cipher_suite))?;
+
+ let key_package_generator = KeyPackageGenerator {
+ protocol_version: self.version,
+ cipher_suite_provider: &cipher_suite_provider,
+ signing_key: self.signer()?,
+ signing_identity,
+ identity_provider: &self.config.identity_provider(),
+ };
+
+ let key_pkg_gen = key_package_generator
+ .generate(
+ self.config.lifetime(),
+ self.config.capabilities(),
+ self.config.key_package_extensions(),
+ self.config.leaf_node_extensions(),
+ )
+ .await?;
+
+ let (id, key_package_data) = key_pkg_gen.to_storage()?;
+
+ self.config
+ .key_package_repo()
+ .insert(id, key_package_data)
+ .await
+ .map_err(|e| MlsError::KeyPackageRepoError(e.into_any_error()))?;
+
+ Ok(key_pkg_gen)
+ }
+
+ /// Create a group with a specific group_id.
+ ///
+ /// This function behaves the same way as
+ /// [create_group](Client::create_group) except that it
+ /// specifies a specific unique group identifier to be used.
+ ///
+ /// # Warning
+ ///
+ /// It is recommended to use [create_group](Client::create_group)
+ /// instead of this function because it guarantees that group_id values
+ /// are globally unique.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn create_group_with_id(
+ &self,
+ group_id: Vec<u8>,
+ group_context_extensions: ExtensionList,
+ ) -> Result<Group<C>, MlsError> {
+ let (signing_identity, cipher_suite) = self.signing_identity()?;
+
+ Group::new(
+ self.config.clone(),
+ Some(group_id),
+ cipher_suite,
+ self.version,
+ signing_identity.clone(),
+ group_context_extensions,
+ self.signer()?.clone(),
+ )
+ .await
+ }
+
+ /// Create a MLS group.
+ ///
+ /// The `cipher_suite` provided must be supported by the
+ /// [CipherSuiteProvider](crate::CipherSuiteProvider)
+ /// that was used to build the client.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn create_group(
+ &self,
+ group_context_extensions: ExtensionList,
+ ) -> Result<Group<C>, MlsError> {
+ let (signing_identity, cipher_suite) = self.signing_identity()?;
+
+ Group::new(
+ self.config.clone(),
+ None,
+ cipher_suite,
+ self.version,
+ signing_identity.clone(),
+ group_context_extensions,
+ self.signer()?.clone(),
+ )
+ .await
+ }
+
+ /// Join a MLS group via a welcome message created by a
+ /// [Commit](crate::group::CommitOutput).
+ ///
+ /// `tree_data` is required to be provided out of band if the client that
+ /// created `welcome_message` did not use the `ratchet_tree_extension`
+ /// according to [`MlsRules::commit_options`](`crate::MlsRules::commit_options`).
+ /// at the time the welcome message was created. `tree_data` can
+ /// be exported from a group using the
+ /// [export tree function](crate::group::Group::export_tree).
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn join_group(
+ &self,
+ tree_data: Option<ExportedTree<'_>>,
+ welcome_message: &MlsMessage,
+ ) -> Result<(Group<C>, NewMemberInfo), MlsError> {
+ Group::join(
+ welcome_message,
+ tree_data,
+ self.config.clone(),
+ self.signer()?.clone(),
+ )
+ .await
+ }
+
+ /// 0-RTT add to an existing [group](crate::group::Group)
+ ///
+ /// External commits allow for immediate entry into a
+ /// [group](crate::group::Group), even if all of the group members
+ /// are currently offline and unable to process messages. Sending an
+ /// external commit is only allowed for groups that have provided
+ /// a public `group_info_message` containing an
+ /// [ExternalPubExt](crate::extension::ExternalPubExt), which can be
+ /// generated by an existing group member using the
+ /// [group_info_message](crate::group::Group::group_info_message)
+ /// function.
+ ///
+ /// `tree_data` may be provided following the same rules as [Client::join_group]
+ ///
+ /// If PSKs are provided in `external_psks`, the
+ /// [PreSharedKeyStorage](crate::PreSharedKeyStorage)
+ /// used to configure the client will be searched to resolve their values.
+ ///
+ /// `to_remove` may be used to remove an existing member provided that the
+ /// identity of the existing group member at that [index](crate::group::Member::index)
+ /// is a [valid successor](crate::IdentityProvider::valid_successor)
+ /// of `signing_identity` as defined by the
+ /// [IdentityProvider](crate::IdentityProvider) that this client
+ /// was configured with.
+ ///
+ /// # Warning
+ ///
+ /// Only one external commit can be performed against a given group info.
+ /// There may also be security trade-offs to this approach.
+ ///
+ // TODO: Add a comment about forward secrecy and a pointer to the future
+ // book chapter on this topic
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn commit_external(
+ &self,
+ group_info_msg: MlsMessage,
+ ) -> Result<(Group<C>, MlsMessage), MlsError> {
+ ExternalCommitBuilder::new(
+ self.signer()?.clone(),
+ self.signing_identity()?.0.clone(),
+ self.config.clone(),
+ )
+ .build(group_info_msg)
+ .await
+ }
+
+ pub fn external_commit_builder(&self) -> Result<ExternalCommitBuilder<C>, MlsError> {
+ Ok(ExternalCommitBuilder::new(
+ self.signer()?.clone(),
+ self.signing_identity()?.0.clone(),
+ self.config.clone(),
+ ))
+ }
+
+ /// Load an existing group state into this client using the
+ /// [GroupStateStorage](crate::GroupStateStorage) that
+ /// this client was configured to use.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[inline(never)]
+ pub async fn load_group(&self, group_id: &[u8]) -> Result<Group<C>, MlsError> {
+ let snapshot = self
+ .config
+ .group_state_storage()
+ .state(group_id)
+ .await
+ .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?
+ .ok_or(MlsError::GroupNotFound)?;
+
+ let snapshot = Snapshot::mls_decode(&mut &*snapshot)?;
+
+ Group::from_snapshot(self.config.clone(), snapshot).await
+ }
+
+ /// Request to join an existing [group](crate::group::Group).
+ ///
+ /// An existing group member will need to perform a
+ /// [commit](crate::Group::commit) to complete the add and the resulting
+ /// welcome message can be used by [join_group](Client::join_group).
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn external_add_proposal(
+ &self,
+ group_info: &MlsMessage,
+ tree_data: Option<crate::group::ExportedTree<'_>>,
+ authenticated_data: Vec<u8>,
+ ) -> Result<MlsMessage, MlsError> {
+ let protocol_version = group_info.version;
+
+ if !self.config.version_supported(protocol_version) && protocol_version == self.version {
+ return Err(MlsError::UnsupportedProtocolVersion(protocol_version));
+ }
+
+ let group_info = group_info
+ .as_group_info()
+ .ok_or(MlsError::UnexpectedMessageType)?;
+
+ let cipher_suite = group_info.group_context.cipher_suite;
+
+ let cipher_suite_provider = self
+ .config
+ .crypto_provider()
+ .cipher_suite_provider(cipher_suite)
+ .ok_or(MlsError::UnsupportedCipherSuite(cipher_suite))?;
+
+ crate::group::validate_group_info_joiner(
+ protocol_version,
+ group_info,
+ tree_data,
+ &self.config.identity_provider(),
+ &cipher_suite_provider,
+ )
+ .await?;
+
+ let key_package = self.generate_key_package().await?.key_package;
+
+ (key_package.cipher_suite == cipher_suite)
+ .then_some(())
+ .ok_or(MlsError::UnsupportedCipherSuite(cipher_suite))?;
+
+ let message = AuthenticatedContent::new_signed(
+ &cipher_suite_provider,
+ &group_info.group_context,
+ Sender::NewMemberProposal,
+ Content::Proposal(Box::new(Proposal::Add(Box::new(AddProposal {
+ key_package,
+ })))),
+ self.signer()?,
+ WireFormat::PublicMessage,
+ authenticated_data,
+ )
+ .await?;
+
+ let plaintext = PublicMessage {
+ content: message.content,
+ auth: message.auth,
+ membership_tag: None,
+ };
+
+ Ok(MlsMessage {
+ version: protocol_version,
+ payload: MlsMessagePayload::Plain(plaintext),
+ })
+ }
+
+ fn signer(&self) -> Result<&SignatureSecretKey, MlsError> {
+ self.signer.as_ref().ok_or(MlsError::SignerNotFound)
+ }
+
+ #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
+ pub fn signing_identity(&self) -> Result<(&SigningIdentity, CipherSuite), MlsError> {
+ self.signing_identity
+ .as_ref()
+ .map(|(id, cs)| (id, *cs))
+ .ok_or(MlsError::SignerNotFound)
+ }
+
+ /// Returns key package extensions used by this client
+ pub fn key_package_extensions(&self) -> ExtensionList {
+ self.config.key_package_extensions()
+ }
+
+ /// The [KeyPackageStorage] that this client was configured to use.
+ #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
+ pub fn key_package_store(&self) -> <C as ClientConfig>::KeyPackageRepository {
+ self.config.key_package_repo()
+ }
+
+ /// The [PreSharedKeyStorage](crate::PreSharedKeyStorage) that
+ /// this client was configured to use.
+ #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
+ pub fn secret_store(&self) -> <C as ClientConfig>::PskStore {
+ self.config.secret_store()
+ }
+
+ /// The [GroupStateStorage] that this client was configured to use.
+ #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
+ pub fn group_state_storage(&self) -> <C as ClientConfig>::GroupStateStorage {
+ self.config.group_state_storage()
+ }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+ use super::*;
+ use crate::identity::test_utils::get_test_signing_identity;
+
+ pub use crate::client_builder::test_utils::{TestClientBuilder, TestClientConfig};
+
+ pub const TEST_PROTOCOL_VERSION: ProtocolVersion = ProtocolVersion::MLS_10;
+ pub const TEST_CIPHER_SUITE: CipherSuite = CipherSuite::P256_AES128;
+ pub const TEST_CUSTOM_PROPOSAL_TYPE: ProposalType = ProposalType::new(65001);
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn test_client_with_key_pkg(
+ protocol_version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+ identity: &str,
+ ) -> (Client<TestClientConfig>, MlsMessage) {
+ test_client_with_key_pkg_custom(protocol_version, cipher_suite, identity, |_| {}).await
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn test_client_with_key_pkg_custom<F>(
+ protocol_version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+ identity: &str,
+ mut config: F,
+ ) -> (Client<TestClientConfig>, MlsMessage)
+ where
+ F: FnMut(&mut TestClientConfig),
+ {
+ let (identity, secret_key) =
+ get_test_signing_identity(cipher_suite, identity.as_bytes()).await;
+
+ let mut client = TestClientBuilder::new_for_test()
+ .used_protocol_version(protocol_version)
+ .signing_identity(identity.clone(), secret_key, cipher_suite)
+ .build();
+
+ config(&mut client.config);
+
+ let key_package = client.generate_key_package_message().await.unwrap();
+
+ (client, key_package)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::test_utils::*;
+
+ use super::*;
+ use crate::{
+ crypto::test_utils::TestCryptoProvider,
+ identity::test_utils::{get_test_basic_credential, get_test_signing_identity},
+ tree_kem::leaf_node::LeafNodeSource,
+ };
+ use assert_matches::assert_matches;
+
+ use crate::{
+ group::{
+ message_processor::ProposalMessageDescription,
+ proposal::Proposal,
+ test_utils::{test_group, test_group_custom_config},
+ ReceivedMessage,
+ },
+ psk::{ExternalPskId, PreSharedKey},
+ };
+
+ use alloc::vec;
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_keygen() {
+ // This is meant to test the inputs to the internal key package generator
+ // See KeyPackageGenerator tests for key generation specific tests
+ for (protocol_version, cipher_suite) in ProtocolVersion::all().flat_map(|p| {
+ TestCryptoProvider::all_supported_cipher_suites()
+ .into_iter()
+ .map(move |cs| (p, cs))
+ }) {
+ let (identity, secret_key) = get_test_signing_identity(cipher_suite, b"foo").await;
+
+ let client = TestClientBuilder::new_for_test()
+ .signing_identity(identity.clone(), secret_key, cipher_suite)
+ .build();
+
+ // TODO: Tests around extensions
+ let key_package = client.generate_key_package_message().await.unwrap();
+
+ assert_eq!(key_package.version, protocol_version);
+
+ let key_package = key_package.into_key_package().unwrap();
+
+ assert_eq!(key_package.cipher_suite, cipher_suite);
+
+ assert_eq!(
+ &key_package.leaf_node.signing_identity.credential,
+ &get_test_basic_credential(b"foo".to_vec())
+ );
+
+ assert_eq!(key_package.leaf_node.signing_identity, identity);
+
+ let capabilities = key_package.leaf_node.ungreased_capabilities();
+ assert_eq!(capabilities, client.config.capabilities());
+
+ let client_lifetime = client.config.lifetime();
+ assert_matches!(key_package.leaf_node.leaf_node_source, LeafNodeSource::KeyPackage(lifetime) if (lifetime.not_after - lifetime.not_before) == (client_lifetime.not_after - client_lifetime.not_before));
+ }
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn new_member_add_proposal_adds_to_group() {
+ let mut alice_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ let (bob_identity, secret_key) = get_test_signing_identity(TEST_CIPHER_SUITE, b"bob").await;
+
+ let bob = TestClientBuilder::new_for_test()
+ .signing_identity(bob_identity.clone(), secret_key, TEST_CIPHER_SUITE)
+ .build();
+
+ let proposal = bob
+ .external_add_proposal(
+ &alice_group.group.group_info_message(true).await.unwrap(),
+ None,
+ vec![],
+ )
+ .await
+ .unwrap();
+
+ let message = alice_group
+ .group
+ .process_incoming_message(proposal)
+ .await
+ .unwrap();
+
+ assert_matches!(
+ message,
+ ReceivedMessage::Proposal(ProposalMessageDescription {
+ proposal: Proposal::Add(p), ..}
+ ) if p.key_package.leaf_node.signing_identity == bob_identity
+ );
+
+ alice_group.group.commit(vec![]).await.unwrap();
+ alice_group.group.apply_pending_commit().await.unwrap();
+
+ // Check that the new member is in the group
+ assert!(alice_group
+ .group
+ .roster()
+ .members_iter()
+ .any(|member| member.signing_identity == bob_identity))
+ }
+
+ #[cfg(feature = "psk")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn join_via_external_commit(do_remove: bool, with_psk: bool) -> Result<(), MlsError> {
+ // An external commit cannot be the first commit in a group as it requires
+ // interim_transcript_hash to be computed from the confirmed_transcript_hash and
+ // confirmation_tag, which is not the case for the initial interim_transcript_hash.
+
+ let psk = PreSharedKey::from(b"psk".to_vec());
+ let psk_id = ExternalPskId::new(b"psk id".to_vec());
+
+ let mut alice_group =
+ test_group_custom_config(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, |c| {
+ c.psk(psk_id.clone(), psk.clone())
+ })
+ .await;
+
+ let (mut bob_group, _) = alice_group
+ .join_with_custom_config("bob", false, |c| {
+ c.0.psk_store.insert(psk_id.clone(), psk.clone());
+ })
+ .await
+ .unwrap();
+
+ let group_info_msg = alice_group
+ .group
+ .group_info_message_allowing_ext_commit(true)
+ .await
+ .unwrap();
+
+ let new_client_id = if do_remove { "bob" } else { "charlie" };
+
+ let (new_client_identity, secret_key) =
+ get_test_signing_identity(TEST_CIPHER_SUITE, new_client_id.as_bytes()).await;
+
+ let new_client = TestClientBuilder::new_for_test()
+ .psk(psk_id.clone(), psk)
+ .signing_identity(new_client_identity.clone(), secret_key, TEST_CIPHER_SUITE)
+ .build();
+
+ let mut builder = new_client.external_commit_builder().unwrap();
+
+ if do_remove {
+ builder = builder.with_removal(1);
+ }
+
+ if with_psk {
+ builder = builder.with_external_psk(psk_id);
+ }
+
+ let (new_group, external_commit) = builder.build(group_info_msg).await?;
+
+ let num_members = if do_remove { 2 } else { 3 };
+
+ assert_eq!(new_group.roster().members_iter().count(), num_members);
+
+ let _ = alice_group
+ .group
+ .process_incoming_message(external_commit.clone())
+ .await
+ .unwrap();
+
+ let bob_current_epoch = bob_group.group.current_epoch();
+
+ let message = bob_group
+ .group
+ .process_incoming_message(external_commit)
+ .await
+ .unwrap();
+
+ assert!(alice_group.group.roster().members_iter().count() == num_members);
+
+ if !do_remove {
+ assert!(bob_group.group.roster().members_iter().count() == num_members);
+ } else {
+ // Bob was removed so his epoch must stay the same
+ assert_eq!(bob_group.group.current_epoch(), bob_current_epoch);
+
+ #[cfg(feature = "state_update")]
+ assert_matches!(message, ReceivedMessage::Commit(desc) if !desc.state_update.active);
+
+ #[cfg(not(feature = "state_update"))]
+ assert_matches!(message, ReceivedMessage::Commit(_));
+ }
+
+ // Comparing epoch authenticators is sufficient to check that members are in sync.
+ assert_eq!(
+ alice_group.group.epoch_authenticator().unwrap(),
+ new_group.epoch_authenticator().unwrap()
+ );
+
+ Ok(())
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_external_commit() {
+ // New member can join
+ join_via_external_commit(false, false).await.unwrap();
+ // New member can remove an old copy of themselves
+ join_via_external_commit(true, false).await.unwrap();
+ // New member can inject a PSK
+ join_via_external_commit(false, true).await.unwrap();
+ // All works together
+ join_via_external_commit(true, true).await.unwrap();
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn creating_an_external_commit_requires_a_group_info_message() {
+ let (alice_identity, secret_key) =
+ get_test_signing_identity(TEST_CIPHER_SUITE, b"alice").await;
+
+ let alice = TestClientBuilder::new_for_test()
+ .signing_identity(alice_identity.clone(), secret_key, TEST_CIPHER_SUITE)
+ .build();
+
+ let msg = alice.generate_key_package_message().await.unwrap();
+ let res = alice.commit_external(msg).await.map(|_| ());
+
+ assert_matches!(res, Err(MlsError::UnexpectedMessageType));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn external_commit_with_invalid_group_info_fails() {
+ let mut alice_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let mut bob_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ bob_group.group.commit(vec![]).await.unwrap();
+ bob_group.group.apply_pending_commit().await.unwrap();
+
+ let group_info_msg = bob_group
+ .group
+ .group_info_message_allowing_ext_commit(true)
+ .await
+ .unwrap();
+
+ let (carol_identity, secret_key) =
+ get_test_signing_identity(TEST_CIPHER_SUITE, b"carol").await;
+
+ let carol = TestClientBuilder::new_for_test()
+ .signing_identity(carol_identity, secret_key, TEST_CIPHER_SUITE)
+ .build();
+
+ let (_, external_commit) = carol
+ .external_commit_builder()
+ .unwrap()
+ .build(group_info_msg)
+ .await
+ .unwrap();
+
+ // If Carol tries to join Alice's group using the group info from Bob's group, that fails.
+ let res = alice_group
+ .group
+ .process_incoming_message(external_commit)
+ .await;
+ assert_matches!(res, Err(_));
+ }
+
+ #[test]
+ fn builder_can_be_obtained_from_client_to_edit_properties_for_new_client() {
+ let alice = TestClientBuilder::new_for_test()
+ .extension_type(33.into())
+ .build();
+ let bob = alice.to_builder().extension_type(34.into()).build();
+ assert_eq!(bob.config.supported_extensions(), [33, 34].map(Into::into));
+ }
+}
diff --git a/src/client_builder.rs b/src/client_builder.rs
new file mode 100644
index 0000000..186c436
--- /dev/null
+++ b/src/client_builder.rs
@@ -0,0 +1,1029 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+//! Definitions to build a [`Client`].
+//!
+//! See [`ClientBuilder`].
+
+use crate::{
+ cipher_suite::CipherSuite,
+ client::Client,
+ client_config::ClientConfig,
+ extension::{ExtensionType, MlsExtension},
+ group::{
+ mls_rules::{DefaultMlsRules, MlsRules},
+ proposal::ProposalType,
+ },
+ identity::CredentialType,
+ identity::SigningIdentity,
+ protocol_version::ProtocolVersion,
+ psk::{ExternalPskId, PreSharedKey},
+ storage_provider::in_memory::{
+ InMemoryGroupStateStorage, InMemoryKeyPackageStorage, InMemoryPreSharedKeyStorage,
+ },
+ tree_kem::{Capabilities, Lifetime},
+ Sealed,
+};
+
+#[cfg(feature = "std")]
+use crate::time::MlsTime;
+
+use alloc::vec::Vec;
+
+#[cfg(feature = "sqlite")]
+use mls_rs_provider_sqlite::{
+ SqLiteDataStorageEngine, SqLiteDataStorageError,
+ {
+ connection_strategy::ConnectionStrategy,
+ storage::{SqLiteGroupStateStorage, SqLiteKeyPackageStorage, SqLitePreSharedKeyStorage},
+ },
+};
+
+#[cfg(feature = "private_message")]
+pub use crate::group::padding::PaddingMode;
+
+/// Base client configuration type when instantiating `ClientBuilder`
+pub type BaseConfig = Config<
+ InMemoryKeyPackageStorage,
+ InMemoryPreSharedKeyStorage,
+ InMemoryGroupStateStorage,
+ Missing,
+ DefaultMlsRules,
+ Missing,
+>;
+
+/// Base client configuration type when instantiating `ClientBuilder`
+pub type BaseInMemoryConfig = Config<
+ InMemoryKeyPackageStorage,
+ InMemoryPreSharedKeyStorage,
+ InMemoryGroupStateStorage,
+ Missing,
+ Missing,
+ Missing,
+>;
+
+pub type EmptyConfig = Config<Missing, Missing, Missing, Missing, Missing, Missing>;
+
+/// Base client configuration that is backed by SQLite storage.
+#[cfg(feature = "sqlite")]
+pub type BaseSqlConfig = Config<
+ SqLiteKeyPackageStorage,
+ SqLitePreSharedKeyStorage,
+ SqLiteGroupStateStorage,
+ Missing,
+ DefaultMlsRules,
+ Missing,
+>;
+
+/// Builder for [`Client`]
+///
+/// This is returned by [`Client::builder`] and allows to tweak settings the `Client` will use. At a
+/// minimum, the builder must be told the [`CryptoProvider`] and [`IdentityProvider`] to use. Other
+/// settings have default values. This means that the following
+/// methods must be called before [`ClientBuilder::build`]:
+///
+/// - To specify the [`CryptoProvider`]: [`ClientBuilder::crypto_provider`]
+/// - To specify the [`IdentityProvider`]: [`ClientBuilder::identity_provider`]
+///
+/// # Example
+///
+/// ```
+/// use mls_rs::{
+/// Client,
+/// identity::{SigningIdentity, basic::{BasicIdentityProvider, BasicCredential}},
+/// CipherSuite,
+/// };
+///
+/// use mls_rs_crypto_openssl::OpensslCryptoProvider;
+///
+/// // Replace by code to load the certificate and secret key
+/// let secret_key = b"never hard-code secrets".to_vec().into();
+/// let public_key = b"test invalid public key".to_vec().into();
+/// let basic_identity = BasicCredential::new(b"name".to_vec());
+/// let signing_identity = SigningIdentity::new(basic_identity.into_credential(), public_key);
+///
+///
+/// let _client = Client::builder()
+/// .crypto_provider(OpensslCryptoProvider::default())
+/// .identity_provider(BasicIdentityProvider::new())
+/// .signing_identity(signing_identity, secret_key, CipherSuite::CURVE25519_AES128)
+/// .build();
+/// ```
+///
+/// # Spelling out a `Client` type
+///
+/// There are two main ways to spell out a `Client` type if needed (e.g. function return type).
+///
+/// The first option uses `impl MlsConfig`:
+/// ```
+/// use mls_rs::{
+/// Client,
+/// client_builder::MlsConfig,
+/// identity::{SigningIdentity, basic::{BasicIdentityProvider, BasicCredential}},
+/// CipherSuite,
+/// };
+///
+/// use mls_rs_crypto_openssl::OpensslCryptoProvider;
+///
+/// fn make_client() -> Client<impl MlsConfig> {
+/// // Replace by code to load the certificate and secret key
+/// let secret_key = b"never hard-code secrets".to_vec().into();
+/// let public_key = b"test invalid public key".to_vec().into();
+/// let basic_identity = BasicCredential::new(b"name".to_vec());
+/// let signing_identity = SigningIdentity::new(basic_identity.into_credential(), public_key);
+///
+/// Client::builder()
+/// .crypto_provider(OpensslCryptoProvider::default())
+/// .identity_provider(BasicIdentityProvider::new())
+/// .signing_identity(signing_identity, secret_key, CipherSuite::CURVE25519_AES128)
+/// .build()
+/// }
+///```
+///
+/// The second option is more verbose and consists in writing the full `Client` type:
+/// ```
+/// use mls_rs::{
+/// Client,
+/// client_builder::{BaseConfig, WithIdentityProvider, WithCryptoProvider},
+/// identity::{SigningIdentity, basic::{BasicIdentityProvider, BasicCredential}},
+/// CipherSuite,
+/// };
+///
+/// use mls_rs_crypto_openssl::OpensslCryptoProvider;
+///
+/// type MlsClient = Client<
+/// WithIdentityProvider<
+/// BasicIdentityProvider,
+/// WithCryptoProvider<OpensslCryptoProvider, BaseConfig>,
+/// >,
+/// >;
+///
+/// fn make_client_2() -> MlsClient {
+/// // Replace by code to load the certificate and secret key
+/// let secret_key = b"never hard-code secrets".to_vec().into();
+/// let public_key = b"test invalid public key".to_vec().into();
+/// let basic_identity = BasicCredential::new(b"name".to_vec());
+/// let signing_identity = SigningIdentity::new(basic_identity.into_credential(), public_key);
+///
+/// Client::builder()
+/// .crypto_provider(OpensslCryptoProvider::default())
+/// .identity_provider(BasicIdentityProvider::new())
+/// .signing_identity(signing_identity, secret_key, CipherSuite::CURVE25519_AES128)
+/// .build()
+/// }
+///
+/// ```
+#[derive(Debug)]
+pub struct ClientBuilder<C>(C);
+
+impl Default for ClientBuilder<BaseConfig> {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+impl<C> ClientBuilder<C> {
+ pub(crate) fn from_config(c: C) -> Self {
+ Self(c)
+ }
+}
+
+impl ClientBuilder<BaseConfig> {
+ /// Create a new client builder with default in-memory providers
+ pub fn new() -> Self {
+ Self(Config(ConfigInner {
+ settings: Default::default(),
+ key_package_repo: Default::default(),
+ psk_store: Default::default(),
+ group_state_storage: Default::default(),
+ identity_provider: Missing,
+ mls_rules: DefaultMlsRules::new(),
+ crypto_provider: Missing,
+ signer: Default::default(),
+ signing_identity: Default::default(),
+ version: ProtocolVersion::MLS_10,
+ }))
+ }
+}
+
+impl ClientBuilder<EmptyConfig> {
+ pub fn new_empty() -> Self {
+ Self(Config(ConfigInner {
+ settings: Default::default(),
+ key_package_repo: Missing,
+ psk_store: Missing,
+ group_state_storage: Missing,
+ identity_provider: Missing,
+ mls_rules: Missing,
+ crypto_provider: Missing,
+ signer: Default::default(),
+ signing_identity: Default::default(),
+ version: ProtocolVersion::MLS_10,
+ }))
+ }
+}
+
+#[cfg(feature = "sqlite")]
+impl ClientBuilder<BaseSqlConfig> {
+ /// Create a new client builder with SQLite storage providers.
+ pub fn new_sqlite<CS: ConnectionStrategy>(
+ storage: SqLiteDataStorageEngine<CS>,
+ ) -> Result<Self, SqLiteDataStorageError> {
+ Ok(Self(Config(ConfigInner {
+ settings: Default::default(),
+ key_package_repo: storage.key_package_storage()?,
+ psk_store: storage.pre_shared_key_storage()?,
+ group_state_storage: storage.group_state_storage()?,
+ identity_provider: Missing,
+ mls_rules: DefaultMlsRules::new(),
+ crypto_provider: Missing,
+ signer: Default::default(),
+ signing_identity: Default::default(),
+ version: ProtocolVersion::MLS_10,
+ })))
+ }
+}
+
+impl<C: IntoConfig> ClientBuilder<C> {
+ /// Add an extension type to the list of extension types supported by the client.
+ pub fn extension_type(self, type_: ExtensionType) -> ClientBuilder<IntoConfigOutput<C>> {
+ self.extension_types(Some(type_))
+ }
+
+ /// Add multiple extension types to the list of extension types supported by the client.
+ pub fn extension_types<I>(self, types: I) -> ClientBuilder<IntoConfigOutput<C>>
+ where
+ I: IntoIterator<Item = ExtensionType>,
+ {
+ let mut c = self.0.into_config();
+ c.0.settings.extension_types.extend(types);
+ ClientBuilder(c)
+ }
+
+ /// Add a custom proposal type to the list of proposals types supported by the client.
+ pub fn custom_proposal_type(self, type_: ProposalType) -> ClientBuilder<IntoConfigOutput<C>> {
+ self.custom_proposal_types(Some(type_))
+ }
+
+ /// Add multiple custom proposal types to the list of proposal types supported by the client.
+ pub fn custom_proposal_types<I>(self, types: I) -> ClientBuilder<IntoConfigOutput<C>>
+ where
+ I: IntoIterator<Item = ProposalType>,
+ {
+ let mut c = self.0.into_config();
+ c.0.settings.custom_proposal_types.extend(types);
+ ClientBuilder(c)
+ }
+
+ /// Add a protocol version to the list of protocol versions supported by the client.
+ ///
+ /// If no protocol version is explicitly added, the client will support all protocol versions
+ /// supported by this crate.
+ pub fn protocol_version(self, version: ProtocolVersion) -> ClientBuilder<IntoConfigOutput<C>> {
+ self.protocol_versions(Some(version))
+ }
+
+ /// Add multiple protocol versions to the list of protocol versions supported by the client.
+ ///
+ /// If no protocol version is explicitly added, the client will support all protocol versions
+ /// supported by this crate.
+ pub fn protocol_versions<I>(self, versions: I) -> ClientBuilder<IntoConfigOutput<C>>
+ where
+ I: IntoIterator<Item = ProtocolVersion>,
+ {
+ let mut c = self.0.into_config();
+ c.0.settings.protocol_versions.extend(versions);
+ ClientBuilder(c)
+ }
+
+ /// Add a key package extension to the list of key package extensions supported by the client.
+ pub fn key_package_extension<T>(
+ self,
+ extension: T,
+ ) -> Result<ClientBuilder<IntoConfigOutput<C>>, ExtensionError>
+ where
+ T: MlsExtension,
+ Self: Sized,
+ {
+ let mut c = self.0.into_config();
+ c.0.settings.key_package_extensions.set_from(extension)?;
+ Ok(ClientBuilder(c))
+ }
+
+ /// Add multiple key package extensions to the list of key package extensions supported by the
+ /// client.
+ pub fn key_package_extensions(
+ self,
+ extensions: ExtensionList,
+ ) -> ClientBuilder<IntoConfigOutput<C>> {
+ let mut c = self.0.into_config();
+ c.0.settings.key_package_extensions.append(extensions);
+ ClientBuilder(c)
+ }
+
+ /// Add a leaf node extension to the list of leaf node extensions supported by the client.
+ pub fn leaf_node_extension<T>(
+ self,
+ extension: T,
+ ) -> Result<ClientBuilder<IntoConfigOutput<C>>, ExtensionError>
+ where
+ T: MlsExtension,
+ Self: Sized,
+ {
+ let mut c = self.0.into_config();
+ c.0.settings.leaf_node_extensions.set_from(extension)?;
+ Ok(ClientBuilder(c))
+ }
+
+ /// Add multiple leaf node extensions to the list of leaf node extensions supported by the
+ /// client.
+ pub fn leaf_node_extensions(
+ self,
+ extensions: ExtensionList,
+ ) -> ClientBuilder<IntoConfigOutput<C>> {
+ let mut c = self.0.into_config();
+ c.0.settings.leaf_node_extensions.append(extensions);
+ ClientBuilder(c)
+ }
+
+ /// Set the lifetime duration in seconds of key packages generated by the client.
+ pub fn key_package_lifetime(self, duration_in_s: u64) -> ClientBuilder<IntoConfigOutput<C>> {
+ let mut c = self.0.into_config();
+ c.0.settings.lifetime_in_s = duration_in_s;
+ ClientBuilder(c)
+ }
+
+ /// Set the key package repository to be used by the client.
+ ///
+ /// By default, an in-memory repository is used.
+ pub fn key_package_repo<K>(self, key_package_repo: K) -> ClientBuilder<WithKeyPackageRepo<K, C>>
+ where
+ K: KeyPackageStorage,
+ {
+ let Config(c) = self.0.into_config();
+
+ ClientBuilder(Config(ConfigInner {
+ settings: c.settings,
+ key_package_repo,
+ psk_store: c.psk_store,
+ group_state_storage: c.group_state_storage,
+ identity_provider: c.identity_provider,
+ mls_rules: c.mls_rules,
+ crypto_provider: c.crypto_provider,
+ signer: c.signer,
+ signing_identity: c.signing_identity,
+ version: c.version,
+ }))
+ }
+
+ /// Set the PSK store to be used by the client.
+ ///
+ /// By default, an in-memory store is used.
+ pub fn psk_store<P>(self, psk_store: P) -> ClientBuilder<WithPskStore<P, C>>
+ where
+ P: PreSharedKeyStorage,
+ {
+ let Config(c) = self.0.into_config();
+
+ ClientBuilder(Config(ConfigInner {
+ settings: c.settings,
+ key_package_repo: c.key_package_repo,
+ psk_store,
+ group_state_storage: c.group_state_storage,
+ identity_provider: c.identity_provider,
+ mls_rules: c.mls_rules,
+ crypto_provider: c.crypto_provider,
+ signer: c.signer,
+ signing_identity: c.signing_identity,
+ version: c.version,
+ }))
+ }
+
+ /// Set the group state storage to be used by the client.
+ ///
+ /// By default, an in-memory storage is used.
+ pub fn group_state_storage<G>(
+ self,
+ group_state_storage: G,
+ ) -> ClientBuilder<WithGroupStateStorage<G, C>>
+ where
+ G: GroupStateStorage,
+ {
+ let Config(c) = self.0.into_config();
+
+ ClientBuilder(Config(ConfigInner {
+ settings: c.settings,
+ key_package_repo: c.key_package_repo,
+ psk_store: c.psk_store,
+ group_state_storage,
+ identity_provider: c.identity_provider,
+ crypto_provider: c.crypto_provider,
+ mls_rules: c.mls_rules,
+ signer: c.signer,
+ signing_identity: c.signing_identity,
+ version: c.version,
+ }))
+ }
+
+ /// Set the identity validator to be used by the client.
+ pub fn identity_provider<I>(
+ self,
+ identity_provider: I,
+ ) -> ClientBuilder<WithIdentityProvider<I, C>>
+ where
+ I: IdentityProvider,
+ {
+ let Config(c) = self.0.into_config();
+
+ ClientBuilder(Config(ConfigInner {
+ settings: c.settings,
+ key_package_repo: c.key_package_repo,
+ psk_store: c.psk_store,
+ group_state_storage: c.group_state_storage,
+ identity_provider,
+ mls_rules: c.mls_rules,
+ crypto_provider: c.crypto_provider,
+ signer: c.signer,
+ signing_identity: c.signing_identity,
+ version: c.version,
+ }))
+ }
+
+ /// Set the crypto provider to be used by the client.
+ pub fn crypto_provider<Cp>(
+ self,
+ crypto_provider: Cp,
+ ) -> ClientBuilder<WithCryptoProvider<Cp, C>>
+ where
+ Cp: CryptoProvider,
+ {
+ let Config(c) = self.0.into_config();
+
+ ClientBuilder(Config(ConfigInner {
+ settings: c.settings,
+ key_package_repo: c.key_package_repo,
+ psk_store: c.psk_store,
+ group_state_storage: c.group_state_storage,
+ identity_provider: c.identity_provider,
+ mls_rules: c.mls_rules,
+ crypto_provider,
+ signer: c.signer,
+ signing_identity: c.signing_identity,
+ version: c.version,
+ }))
+ }
+
+ /// Set the user-defined proposal rules to be used by the client.
+ ///
+ /// User-defined rules are used when sending and receiving commits before
+ /// enforcing general MLS protocol rules. If the rule set returns an error when
+ /// receiving a commit, the entire commit is considered invalid. If the
+ /// rule set would return an error when sending a commit, individual proposals
+ /// may be filtered out to compensate.
+ pub fn mls_rules<Pr>(self, mls_rules: Pr) -> ClientBuilder<WithMlsRules<Pr, C>>
+ where
+ Pr: MlsRules,
+ {
+ let Config(c) = self.0.into_config();
+
+ ClientBuilder(Config(ConfigInner {
+ settings: c.settings,
+ key_package_repo: c.key_package_repo,
+ psk_store: c.psk_store,
+ group_state_storage: c.group_state_storage,
+ identity_provider: c.identity_provider,
+ mls_rules,
+ crypto_provider: c.crypto_provider,
+ signer: c.signer,
+ signing_identity: c.signing_identity,
+ version: c.version,
+ }))
+ }
+
+ /// Set the protocol version used by the client. By default, the client uses version MLS 1.0
+ pub fn used_protocol_version(
+ self,
+ version: ProtocolVersion,
+ ) -> ClientBuilder<IntoConfigOutput<C>> {
+ let mut c = self.0.into_config();
+ c.0.version = version;
+ ClientBuilder(c)
+ }
+
+ /// Set the signing identity used by the client as well as the matching signer and cipher suite.
+ /// This must be called in order to create groups and key packages.
+ pub fn signing_identity(
+ self,
+ signing_identity: SigningIdentity,
+ signer: SignatureSecretKey,
+ cipher_suite: CipherSuite,
+ ) -> ClientBuilder<IntoConfigOutput<C>> {
+ let mut c = self.0.into_config();
+ c.0.signer = Some(signer);
+ c.0.signing_identity = Some((signing_identity, cipher_suite));
+ ClientBuilder(c)
+ }
+
+ /// Set the signer used by the client. This must be called in order to join groups.
+ pub fn signer(self, signer: SignatureSecretKey) -> ClientBuilder<IntoConfigOutput<C>> {
+ let mut c = self.0.into_config();
+ c.0.signer = Some(signer);
+ ClientBuilder(c)
+ }
+
+ #[cfg(any(test, feature = "test_util"))]
+ pub(crate) fn key_package_not_before(
+ self,
+ key_package_not_before: u64,
+ ) -> ClientBuilder<IntoConfigOutput<C>> {
+ let mut c = self.0.into_config();
+ c.0.settings.key_package_not_before = Some(key_package_not_before);
+ ClientBuilder(c)
+ }
+}
+
+impl<C: IntoConfig> ClientBuilder<C>
+where
+ C::KeyPackageRepository: KeyPackageStorage + Clone,
+ C::PskStore: PreSharedKeyStorage + Clone,
+ C::GroupStateStorage: GroupStateStorage + Clone,
+ C::IdentityProvider: IdentityProvider + Clone,
+ C::MlsRules: MlsRules + Clone,
+ C::CryptoProvider: CryptoProvider + Clone,
+{
+ pub(crate) fn build_config(self) -> IntoConfigOutput<C> {
+ let mut c = self.0.into_config();
+
+ if c.0.settings.protocol_versions.is_empty() {
+ c.0.settings.protocol_versions = ProtocolVersion::all().collect();
+ }
+
+ c
+ }
+
+ /// Build a client.
+ ///
+ /// See [`ClientBuilder`] documentation if the return type of this function needs to be spelled
+ /// out.
+ pub fn build(self) -> Client<IntoConfigOutput<C>> {
+ let mut c = self.build_config();
+ let version = c.0.version;
+ let signer = c.0.signer.take();
+ let signing_identity = c.0.signing_identity.take();
+
+ Client::new(c, signer, signing_identity, version)
+ }
+}
+
+impl<C: IntoConfig<PskStore = InMemoryPreSharedKeyStorage>> ClientBuilder<C> {
+ /// Add a PSK to the in-memory PSK store.
+ pub fn psk(
+ self,
+ psk_id: ExternalPskId,
+ psk: PreSharedKey,
+ ) -> ClientBuilder<IntoConfigOutput<C>> {
+ let mut c = self.0.into_config();
+ c.0.psk_store.insert(psk_id, psk);
+ ClientBuilder(c)
+ }
+}
+
+/// Marker type for required `ClientBuilder` services that have not been specified yet.
+#[derive(Debug)]
+pub struct Missing;
+
+/// Change the key package repository used by a client configuration.
+///
+/// See [`ClientBuilder::key_package_repo`].
+pub type WithKeyPackageRepo<K, C> = Config<
+ K,
+ <C as IntoConfig>::PskStore,
+ <C as IntoConfig>::GroupStateStorage,
+ <C as IntoConfig>::IdentityProvider,
+ <C as IntoConfig>::MlsRules,
+ <C as IntoConfig>::CryptoProvider,
+>;
+
+/// Change the PSK store used by a client configuration.
+///
+/// See [`ClientBuilder::psk_store`].
+pub type WithPskStore<P, C> = Config<
+ <C as IntoConfig>::KeyPackageRepository,
+ P,
+ <C as IntoConfig>::GroupStateStorage,
+ <C as IntoConfig>::IdentityProvider,
+ <C as IntoConfig>::MlsRules,
+ <C as IntoConfig>::CryptoProvider,
+>;
+
+/// Change the group state storage used by a client configuration.
+///
+/// See [`ClientBuilder::group_state_storage`].
+pub type WithGroupStateStorage<G, C> = Config<
+ <C as IntoConfig>::KeyPackageRepository,
+ <C as IntoConfig>::PskStore,
+ G,
+ <C as IntoConfig>::IdentityProvider,
+ <C as IntoConfig>::MlsRules,
+ <C as IntoConfig>::CryptoProvider,
+>;
+
+/// Change the identity validator used by a client configuration.
+///
+/// See [`ClientBuilder::identity_provider`].
+pub type WithIdentityProvider<I, C> = Config<
+ <C as IntoConfig>::KeyPackageRepository,
+ <C as IntoConfig>::PskStore,
+ <C as IntoConfig>::GroupStateStorage,
+ I,
+ <C as IntoConfig>::MlsRules,
+ <C as IntoConfig>::CryptoProvider,
+>;
+
+/// Change the proposal rules used by a client configuration.
+///
+/// See [`ClientBuilder::mls_rules`].
+pub type WithMlsRules<Pr, C> = Config<
+ <C as IntoConfig>::KeyPackageRepository,
+ <C as IntoConfig>::PskStore,
+ <C as IntoConfig>::GroupStateStorage,
+ <C as IntoConfig>::IdentityProvider,
+ Pr,
+ <C as IntoConfig>::CryptoProvider,
+>;
+
+/// Change the crypto provider used by a client configuration.
+///
+/// See [`ClientBuilder::crypto_provider`].
+pub type WithCryptoProvider<Cp, C> = Config<
+ <C as IntoConfig>::KeyPackageRepository,
+ <C as IntoConfig>::PskStore,
+ <C as IntoConfig>::GroupStateStorage,
+ <C as IntoConfig>::IdentityProvider,
+ <C as IntoConfig>::MlsRules,
+ Cp,
+>;
+
+/// Helper alias for `Config`.
+pub type IntoConfigOutput<C> = Config<
+ <C as IntoConfig>::KeyPackageRepository,
+ <C as IntoConfig>::PskStore,
+ <C as IntoConfig>::GroupStateStorage,
+ <C as IntoConfig>::IdentityProvider,
+ <C as IntoConfig>::MlsRules,
+ <C as IntoConfig>::CryptoProvider,
+>;
+
+/// Helper alias to make a `Config` from a `ClientConfig`
+pub type MakeConfig<C> = Config<
+ <C as ClientConfig>::KeyPackageRepository,
+ <C as ClientConfig>::PskStore,
+ <C as ClientConfig>::GroupStateStorage,
+ <C as ClientConfig>::IdentityProvider,
+ <C as ClientConfig>::MlsRules,
+ <C as ClientConfig>::CryptoProvider,
+>;
+
+impl<Kpr, Ps, Gss, Ip, Pr, Cp> ClientConfig for ConfigInner<Kpr, Ps, Gss, Ip, Pr, Cp>
+where
+ Kpr: KeyPackageStorage + Clone,
+ Ps: PreSharedKeyStorage + Clone,
+ Gss: GroupStateStorage + Clone,
+ Ip: IdentityProvider + Clone,
+ Pr: MlsRules + Clone,
+ Cp: CryptoProvider + Clone,
+{
+ type KeyPackageRepository = Kpr;
+ type PskStore = Ps;
+ type GroupStateStorage = Gss;
+ type IdentityProvider = Ip;
+ type MlsRules = Pr;
+ type CryptoProvider = Cp;
+
+ fn supported_extensions(&self) -> Vec<ExtensionType> {
+ self.settings.extension_types.clone()
+ }
+
+ fn supported_protocol_versions(&self) -> Vec<ProtocolVersion> {
+ self.settings.protocol_versions.clone()
+ }
+
+ fn key_package_repo(&self) -> Self::KeyPackageRepository {
+ self.key_package_repo.clone()
+ }
+
+ fn mls_rules(&self) -> Self::MlsRules {
+ self.mls_rules.clone()
+ }
+
+ fn secret_store(&self) -> Self::PskStore {
+ self.psk_store.clone()
+ }
+
+ fn group_state_storage(&self) -> Self::GroupStateStorage {
+ self.group_state_storage.clone()
+ }
+
+ fn identity_provider(&self) -> Self::IdentityProvider {
+ self.identity_provider.clone()
+ }
+
+ fn crypto_provider(&self) -> Self::CryptoProvider {
+ self.crypto_provider.clone()
+ }
+
+ fn key_package_extensions(&self) -> ExtensionList {
+ self.settings.key_package_extensions.clone()
+ }
+
+ fn leaf_node_extensions(&self) -> ExtensionList {
+ self.settings.leaf_node_extensions.clone()
+ }
+
+ fn lifetime(&self) -> Lifetime {
+ #[cfg(feature = "std")]
+ let now_timestamp = MlsTime::now().seconds_since_epoch();
+
+ #[cfg(not(feature = "std"))]
+ let now_timestamp = 0;
+
+ #[cfg(test)]
+ let now_timestamp = self
+ .settings
+ .key_package_not_before
+ .unwrap_or(now_timestamp);
+
+ Lifetime {
+ not_before: now_timestamp,
+ not_after: now_timestamp + self.settings.lifetime_in_s,
+ }
+ }
+
+ fn supported_custom_proposals(&self) -> Vec<crate::group::proposal::ProposalType> {
+ self.settings.custom_proposal_types.clone()
+ }
+}
+
+impl<Kpr, Ps, Gss, Ip, Pr, Cp> Sealed for Config<Kpr, Ps, Gss, Ip, Pr, Cp> {}
+
+impl<Kpr, Ps, Gss, Ip, Pr, Cp> MlsConfig for Config<Kpr, Ps, Gss, Ip, Pr, Cp>
+where
+ Kpr: KeyPackageStorage + Clone,
+
+ Ps: PreSharedKeyStorage + Clone,
+ Gss: GroupStateStorage + Clone,
+ Ip: IdentityProvider + Clone,
+ Pr: MlsRules + Clone,
+ Cp: CryptoProvider + Clone,
+{
+ type Output = ConfigInner<Kpr, Ps, Gss, Ip, Pr, Cp>;
+
+ fn get(&self) -> &Self::Output {
+ &self.0
+ }
+}
+
+/// Helper trait to allow consuming crates to easily write a client type as `Client<impl MlsConfig>`
+///
+/// It is not meant to be implemented by consuming crates. `T: MlsConfig` implies `T: ClientConfig`.
+pub trait MlsConfig: Clone + Send + Sync + Sealed {
+ #[doc(hidden)]
+ type Output: ClientConfig;
+
+ #[doc(hidden)]
+ fn get(&self) -> &Self::Output;
+}
+
+/// Blanket implementation so that `T: MlsConfig` implies `T: ClientConfig`
+impl<T: MlsConfig> ClientConfig for T {
+ type KeyPackageRepository = <T::Output as ClientConfig>::KeyPackageRepository;
+ type PskStore = <T::Output as ClientConfig>::PskStore;
+ type GroupStateStorage = <T::Output as ClientConfig>::GroupStateStorage;
+ type IdentityProvider = <T::Output as ClientConfig>::IdentityProvider;
+ type MlsRules = <T::Output as ClientConfig>::MlsRules;
+ type CryptoProvider = <T::Output as ClientConfig>::CryptoProvider;
+
+ fn supported_extensions(&self) -> Vec<ExtensionType> {
+ self.get().supported_extensions()
+ }
+
+ fn supported_custom_proposals(&self) -> Vec<ProposalType> {
+ self.get().supported_custom_proposals()
+ }
+
+ fn supported_protocol_versions(&self) -> Vec<ProtocolVersion> {
+ self.get().supported_protocol_versions()
+ }
+
+ fn key_package_repo(&self) -> Self::KeyPackageRepository {
+ self.get().key_package_repo()
+ }
+
+ fn mls_rules(&self) -> Self::MlsRules {
+ self.get().mls_rules()
+ }
+
+ fn secret_store(&self) -> Self::PskStore {
+ self.get().secret_store()
+ }
+
+ fn group_state_storage(&self) -> Self::GroupStateStorage {
+ self.get().group_state_storage()
+ }
+
+ fn identity_provider(&self) -> Self::IdentityProvider {
+ self.get().identity_provider()
+ }
+
+ fn crypto_provider(&self) -> Self::CryptoProvider {
+ self.get().crypto_provider()
+ }
+
+ fn key_package_extensions(&self) -> ExtensionList {
+ self.get().key_package_extensions()
+ }
+
+ fn leaf_node_extensions(&self) -> ExtensionList {
+ self.get().leaf_node_extensions()
+ }
+
+ fn lifetime(&self) -> Lifetime {
+ self.get().lifetime()
+ }
+
+ fn capabilities(&self) -> Capabilities {
+ self.get().capabilities()
+ }
+
+ fn version_supported(&self, version: ProtocolVersion) -> bool {
+ self.get().version_supported(version)
+ }
+
+ fn supported_credential_types(&self) -> Vec<CredentialType> {
+ self.get().supported_credential_types()
+ }
+}
+
+#[derive(Clone, Debug)]
+pub(crate) struct Settings {
+ pub(crate) extension_types: Vec<ExtensionType>,
+ pub(crate) protocol_versions: Vec<ProtocolVersion>,
+ pub(crate) custom_proposal_types: Vec<ProposalType>,
+ pub(crate) key_package_extensions: ExtensionList,
+ pub(crate) leaf_node_extensions: ExtensionList,
+ pub(crate) lifetime_in_s: u64,
+ #[cfg(any(test, feature = "test_util"))]
+ pub(crate) key_package_not_before: Option<u64>,
+}
+
+impl Default for Settings {
+ fn default() -> Self {
+ Self {
+ extension_types: Default::default(),
+ protocol_versions: Default::default(),
+ key_package_extensions: Default::default(),
+ leaf_node_extensions: Default::default(),
+ lifetime_in_s: 365 * 24 * 3600,
+ custom_proposal_types: Default::default(),
+ #[cfg(any(test, feature = "test_util"))]
+ key_package_not_before: None,
+ }
+ }
+}
+
+pub(crate) fn recreate_config<T: ClientConfig>(
+ c: T,
+ signer: Option<SignatureSecretKey>,
+ signing_identity: Option<(SigningIdentity, CipherSuite)>,
+ version: ProtocolVersion,
+) -> MakeConfig<T> {
+ Config(ConfigInner {
+ settings: Settings {
+ extension_types: c.supported_extensions(),
+ protocol_versions: c.supported_protocol_versions(),
+ custom_proposal_types: c.supported_custom_proposals(),
+ key_package_extensions: c.key_package_extensions(),
+ leaf_node_extensions: c.leaf_node_extensions(),
+ lifetime_in_s: {
+ let l = c.lifetime();
+ l.not_after - l.not_before
+ },
+ #[cfg(any(test, feature = "test_util"))]
+ key_package_not_before: None,
+ },
+ key_package_repo: c.key_package_repo(),
+ psk_store: c.secret_store(),
+ group_state_storage: c.group_state_storage(),
+ identity_provider: c.identity_provider(),
+ mls_rules: c.mls_rules(),
+ crypto_provider: c.crypto_provider(),
+ signer,
+ signing_identity,
+ version,
+ })
+}
+
+/// Definitions meant to be private that are inaccessible outside this crate. They need to be marked
+/// `pub` because they appear in public definitions.
+mod private {
+ use mls_rs_core::{
+ crypto::{CipherSuite, SignatureSecretKey},
+ identity::SigningIdentity,
+ protocol_version::ProtocolVersion,
+ };
+
+ use crate::client_builder::{IntoConfigOutput, Settings};
+
+ #[derive(Clone, Debug)]
+ pub struct Config<Kpr, Ps, Gss, Ip, Pr, Cp>(pub(crate) ConfigInner<Kpr, Ps, Gss, Ip, Pr, Cp>);
+
+ #[derive(Clone, Debug)]
+ pub struct ConfigInner<Kpr, Ps, Gss, Ip, Pr, Cp> {
+ pub(crate) settings: Settings,
+ pub(crate) key_package_repo: Kpr,
+ pub(crate) psk_store: Ps,
+ pub(crate) group_state_storage: Gss,
+ pub(crate) identity_provider: Ip,
+ pub(crate) mls_rules: Pr,
+ pub(crate) crypto_provider: Cp,
+ pub(crate) signer: Option<SignatureSecretKey>,
+ pub(crate) signing_identity: Option<(SigningIdentity, CipherSuite)>,
+ pub(crate) version: ProtocolVersion,
+ }
+
+ pub trait IntoConfig {
+ type KeyPackageRepository;
+ type PskStore;
+ type GroupStateStorage;
+ type IdentityProvider;
+ type MlsRules;
+ type CryptoProvider;
+
+ fn into_config(self) -> IntoConfigOutput<Self>;
+ }
+
+ impl<Kpr, Ps, Gss, Ip, Pr, Cp> IntoConfig for Config<Kpr, Ps, Gss, Ip, Pr, Cp> {
+ type KeyPackageRepository = Kpr;
+ type PskStore = Ps;
+ type GroupStateStorage = Gss;
+ type IdentityProvider = Ip;
+ type MlsRules = Pr;
+ type CryptoProvider = Cp;
+
+ fn into_config(self) -> Self {
+ self
+ }
+ }
+}
+
+use mls_rs_core::{
+ crypto::{CryptoProvider, SignatureSecretKey},
+ extension::{ExtensionError, ExtensionList},
+ group::GroupStateStorage,
+ identity::IdentityProvider,
+ key_package::KeyPackageStorage,
+ psk::PreSharedKeyStorage,
+};
+use private::{Config, ConfigInner, IntoConfig};
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+ use crate::{
+ client_builder::{BaseConfig, ClientBuilder, WithIdentityProvider},
+ crypto::test_utils::TestCryptoProvider,
+ identity::{
+ basic::BasicIdentityProvider,
+ test_utils::{get_test_signing_identity, BasicWithCustomProvider},
+ },
+ CipherSuite,
+ };
+
+ use super::WithCryptoProvider;
+
+ pub type TestClientConfig = WithIdentityProvider<
+ BasicWithCustomProvider,
+ WithCryptoProvider<TestCryptoProvider, BaseConfig>,
+ >;
+
+ pub type TestClientBuilder = ClientBuilder<TestClientConfig>;
+
+ impl TestClientBuilder {
+ pub fn new_for_test() -> Self {
+ ClientBuilder::new()
+ .crypto_provider(TestCryptoProvider::new())
+ .identity_provider(BasicWithCustomProvider::new(BasicIdentityProvider::new()))
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn with_random_signing_identity(
+ self,
+ identity: &str,
+ cipher_suite: CipherSuite,
+ ) -> Self {
+ let (signing_identity, signer) =
+ get_test_signing_identity(cipher_suite, identity.as_bytes()).await;
+ self.signing_identity(signing_identity, signer, cipher_suite)
+ }
+ }
+}
diff --git a/src/client_config.rs b/src/client_config.rs
new file mode 100644
index 0000000..339f335
--- /dev/null
+++ b/src/client_config.rs
@@ -0,0 +1,68 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::{
+ extension::ExtensionType,
+ group::{mls_rules::MlsRules, proposal::ProposalType},
+ identity::CredentialType,
+ protocol_version::ProtocolVersion,
+ tree_kem::{leaf_node::ConfigProperties, Capabilities, Lifetime},
+ ExtensionList,
+};
+use alloc::vec::Vec;
+use mls_rs_core::{
+ crypto::CryptoProvider, group::GroupStateStorage, identity::IdentityProvider,
+ key_package::KeyPackageStorage, psk::PreSharedKeyStorage,
+};
+
+pub trait ClientConfig: Send + Sync + Clone {
+ type KeyPackageRepository: KeyPackageStorage + Clone;
+ type PskStore: PreSharedKeyStorage + Clone;
+ type GroupStateStorage: GroupStateStorage + Clone;
+ type IdentityProvider: IdentityProvider + Clone;
+ type MlsRules: MlsRules + Clone;
+ type CryptoProvider: CryptoProvider + Clone;
+
+ fn supported_extensions(&self) -> Vec<ExtensionType>;
+ fn supported_custom_proposals(&self) -> Vec<ProposalType>;
+ fn supported_protocol_versions(&self) -> Vec<ProtocolVersion>;
+
+ fn key_package_repo(&self) -> Self::KeyPackageRepository;
+
+ fn mls_rules(&self) -> Self::MlsRules;
+
+ fn secret_store(&self) -> Self::PskStore;
+ fn group_state_storage(&self) -> Self::GroupStateStorage;
+ fn identity_provider(&self) -> Self::IdentityProvider;
+ fn crypto_provider(&self) -> Self::CryptoProvider;
+
+ fn key_package_extensions(&self) -> ExtensionList;
+ fn leaf_node_extensions(&self) -> ExtensionList;
+ fn lifetime(&self) -> Lifetime;
+
+ fn capabilities(&self) -> Capabilities {
+ Capabilities {
+ protocol_versions: self.supported_protocol_versions(),
+ cipher_suites: self.crypto_provider().supported_cipher_suites(),
+ extensions: self.supported_extensions(),
+ proposals: self.supported_custom_proposals(),
+ credentials: self.supported_credential_types(),
+ }
+ }
+
+ fn version_supported(&self, version: ProtocolVersion) -> bool {
+ self.supported_protocol_versions().contains(&version)
+ }
+
+ fn supported_credential_types(&self) -> Vec<CredentialType> {
+ self.identity_provider().supported_types()
+ }
+
+ fn leaf_properties(&self) -> ConfigProperties {
+ ConfigProperties {
+ capabilities: self.capabilities(),
+ extensions: self.leaf_node_extensions(),
+ }
+ }
+}
diff --git a/src/crypto.rs b/src/crypto.rs
new file mode 100644
index 0000000..795476a
--- /dev/null
+++ b/src/crypto.rs
@@ -0,0 +1,43 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+pub(crate) use mls_rs_core::crypto::CipherSuiteProvider;
+
+pub use mls_rs_core::crypto::{
+ HpkeCiphertext, HpkeContextR, HpkeContextS, HpkePublicKey, HpkeSecretKey, SignaturePublicKey,
+ SignatureSecretKey,
+};
+
+pub use mls_rs_core::secret::Secret;
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+ use cfg_if::cfg_if;
+ use mls_rs_core::crypto::CryptoProvider;
+
+ cfg_if! {
+ if #[cfg(target_arch = "wasm32")] {
+ pub use mls_rs_crypto_webcrypto::WebCryptoProvider as TestCryptoProvider;
+ } else {
+ pub use mls_rs_crypto_openssl::OpensslCryptoProvider as TestCryptoProvider;
+ }
+ }
+
+ use crate::cipher_suite::CipherSuite;
+
+ pub fn test_cipher_suite_provider(
+ cipher_suite: CipherSuite,
+ ) -> <TestCryptoProvider as CryptoProvider>::CipherSuiteProvider {
+ TestCryptoProvider::new()
+ .cipher_suite_provider(cipher_suite)
+ .unwrap()
+ }
+
+ #[allow(unused)]
+ pub fn try_test_cipher_suite_provider(
+ cipher_suite: u16,
+ ) -> Option<<TestCryptoProvider as CryptoProvider>::CipherSuiteProvider> {
+ TestCryptoProvider::new().cipher_suite_provider(CipherSuite::from(cipher_suite))
+ }
+}
diff --git a/src/extension.rs b/src/extension.rs
new file mode 100644
index 0000000..4cba416
--- /dev/null
+++ b/src/extension.rs
@@ -0,0 +1,52 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+pub use mls_rs_core::extension::{ExtensionType, MlsCodecExtension, MlsExtension};
+
+pub(crate) use built_in::*;
+
+/// Default extension types required by the MLS RFC.
+pub mod built_in;
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+ use alloc::vec::Vec;
+ use core::convert::Infallible;
+ use core::fmt::Debug;
+ use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+ use mls_rs_core::extension::MlsExtension;
+
+ use super::*;
+
+ pub const TEST_EXTENSION_TYPE: u16 = 42;
+
+ #[derive(MlsSize, MlsEncode, MlsDecode, Clone, Debug, PartialEq)]
+ pub(crate) struct TestExtension {
+ pub(crate) foo: u8,
+ }
+
+ impl From<u8> for TestExtension {
+ fn from(value: u8) -> Self {
+ Self { foo: value }
+ }
+ }
+
+ impl MlsExtension for TestExtension {
+ type SerializationError = Infallible;
+
+ type DeserializationError = Infallible;
+
+ fn extension_type() -> ExtensionType {
+ ExtensionType::from(TEST_EXTENSION_TYPE)
+ }
+
+ fn to_bytes(&self) -> Result<Vec<u8>, Self::SerializationError> {
+ Ok([self.foo].to_vec())
+ }
+
+ fn from_bytes(data: &[u8]) -> Result<Self, Self::DeserializationError> {
+ Ok(TestExtension { foo: data[0] })
+ }
+ }
+}
diff --git a/src/extension/built_in.rs b/src/extension/built_in.rs
new file mode 100644
index 0000000..361a112
--- /dev/null
+++ b/src/extension/built_in.rs
@@ -0,0 +1,330 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::extension::{ExtensionType, MlsCodecExtension};
+
+use mls_rs_core::{group::ProposalType, identity::CredentialType};
+
+#[cfg(feature = "by_ref_proposal")]
+use mls_rs_core::{
+ extension::ExtensionList,
+ identity::{IdentityProvider, SigningIdentity},
+ time::MlsTime,
+};
+
+use crate::group::ExportedTree;
+
+use mls_rs_core::crypto::HpkePublicKey;
+
+/// Application specific identifier.
+///
+/// A custom application level identifier that can be optionally stored
+/// within the `leaf_node_extensions` of a group [Member](crate::group::Member).
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+pub struct ApplicationIdExt {
+ /// Application level identifier presented by this extension.
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ pub identifier: Vec<u8>,
+}
+
+impl Debug for ApplicationIdExt {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("ApplicationIdExt")
+ .field(
+ "identifier",
+ &mls_rs_core::debug::pretty_bytes(&self.identifier),
+ )
+ .finish()
+ }
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
+impl ApplicationIdExt {
+ /// Create a new application level identifier extension.
+ pub fn new(identifier: Vec<u8>) -> Self {
+ ApplicationIdExt { identifier }
+ }
+
+ /// Get the application level identifier presented by this extension.
+ #[cfg(feature = "ffi")]
+ pub fn identifier(&self) -> &[u8] {
+ &self.identifier
+ }
+}
+
+impl MlsCodecExtension for ApplicationIdExt {
+ fn extension_type() -> ExtensionType {
+ ExtensionType::APPLICATION_ID
+ }
+}
+
+/// Representation of an MLS ratchet tree.
+///
+/// Used to provide new members
+/// a copy of the current group state in-band.
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+pub struct RatchetTreeExt {
+ pub tree_data: ExportedTree<'static>,
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
+impl RatchetTreeExt {
+ /// Required custom extension types.
+ #[cfg(feature = "ffi")]
+ pub fn tree_data(&self) -> &ExportedTree<'static> {
+ &self.tree_data
+ }
+}
+
+impl MlsCodecExtension for RatchetTreeExt {
+ fn extension_type() -> ExtensionType {
+ ExtensionType::RATCHET_TREE
+ }
+}
+
+/// Require members to have certain capabilities.
+///
+/// Used within a
+/// [Group Context Extensions Proposal](crate::group::proposal::Proposal)
+/// in order to require that all current and future members of a group MUST
+/// support specific extensions, proposals, or credentials.
+///
+/// # Warning
+///
+/// Extension, proposal, and credential types defined by the MLS RFC and
+/// provided are considered required by default and should NOT be used
+/// within this extension.
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode, Default)]
+pub struct RequiredCapabilitiesExt {
+ pub extensions: Vec<ExtensionType>,
+ pub proposals: Vec<ProposalType>,
+ pub credentials: Vec<CredentialType>,
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
+impl RequiredCapabilitiesExt {
+ /// Create a required capabilities extension.
+ pub fn new(
+ extensions: Vec<ExtensionType>,
+ proposals: Vec<ProposalType>,
+ credentials: Vec<CredentialType>,
+ ) -> Self {
+ Self {
+ extensions,
+ proposals,
+ credentials,
+ }
+ }
+
+ /// Required custom extension types.
+ #[cfg(feature = "ffi")]
+ pub fn extensions(&self) -> &[ExtensionType] {
+ &self.extensions
+ }
+
+ /// Required custom proposal types.
+ #[cfg(feature = "ffi")]
+ pub fn proposals(&self) -> &[ProposalType] {
+ &self.proposals
+ }
+
+ /// Required custom credential types.
+ #[cfg(feature = "ffi")]
+ pub fn credentials(&self) -> &[CredentialType] {
+ &self.credentials
+ }
+}
+
+impl MlsCodecExtension for RequiredCapabilitiesExt {
+ fn extension_type() -> ExtensionType {
+ ExtensionType::REQUIRED_CAPABILITIES
+ }
+}
+
+/// External public key used for [External Commits](crate::Client::commit_external).
+///
+/// This proposal type is optionally provided as part of a
+/// [Group Info](crate::group::Group::group_info_message).
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+pub struct ExternalPubExt {
+ /// Public key to be used for an external commit.
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ pub external_pub: HpkePublicKey,
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
+impl ExternalPubExt {
+ /// Get the public key to be used for an external commit.
+ #[cfg(feature = "ffi")]
+ pub fn external_pub(&self) -> &HpkePublicKey {
+ &self.external_pub
+ }
+}
+
+impl MlsCodecExtension for ExternalPubExt {
+ fn extension_type() -> ExtensionType {
+ ExtensionType::EXTERNAL_PUB
+ }
+}
+
+/// Enable proposals by an [ExternalClient](crate::external_client::ExternalClient).
+#[cfg(feature = "by_ref_proposal")]
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+#[non_exhaustive]
+pub struct ExternalSendersExt {
+ pub allowed_senders: Vec<SigningIdentity>,
+}
+
+#[cfg(feature = "by_ref_proposal")]
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
+impl ExternalSendersExt {
+ pub fn new(allowed_senders: Vec<SigningIdentity>) -> Self {
+ Self { allowed_senders }
+ }
+
+ #[cfg(feature = "ffi")]
+ pub fn allowed_senders(&self) -> &[SigningIdentity] {
+ &self.allowed_senders
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn verify_all<I: IdentityProvider>(
+ &self,
+ provider: &I,
+ timestamp: Option<MlsTime>,
+ group_context_extensions: &ExtensionList,
+ ) -> Result<(), I::Error> {
+ for id in self.allowed_senders.iter() {
+ provider
+ .validate_external_sender(id, timestamp, Some(group_context_extensions))
+ .await?;
+ }
+
+ Ok(())
+ }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+impl MlsCodecExtension for ExternalSendersExt {
+ fn extension_type() -> ExtensionType {
+ ExtensionType::EXTERNAL_SENDERS
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ use crate::tree_kem::node::NodeVec;
+ #[cfg(feature = "by_ref_proposal")]
+ use crate::{
+ client::test_utils::TEST_CIPHER_SUITE, identity::test_utils::get_test_signing_identity,
+ };
+
+ use mls_rs_core::extension::MlsExtension;
+
+ use mls_rs_core::identity::BasicCredential;
+
+ use alloc::vec;
+
+ #[cfg(target_arch = "wasm32")]
+ use wasm_bindgen_test::wasm_bindgen_test as test;
+
+ #[test]
+ fn test_application_id_extension() {
+ let test_id = vec![0u8; 32];
+ let test_extension = ApplicationIdExt {
+ identifier: test_id.clone(),
+ };
+
+ let as_extension = test_extension.into_extension().unwrap();
+
+ assert_eq!(as_extension.extension_type, ExtensionType::APPLICATION_ID);
+
+ let restored = ApplicationIdExt::from_extension(&as_extension).unwrap();
+ assert_eq!(restored.identifier, test_id);
+ }
+
+ #[test]
+ fn test_ratchet_tree() {
+ let ext = RatchetTreeExt {
+ tree_data: ExportedTree::new(NodeVec::from(vec![None, None])),
+ };
+
+ let as_extension = ext.clone().into_extension().unwrap();
+ assert_eq!(as_extension.extension_type, ExtensionType::RATCHET_TREE);
+
+ let restored = RatchetTreeExt::from_extension(&as_extension).unwrap();
+ assert_eq!(ext, restored)
+ }
+
+ #[test]
+ fn test_required_capabilities() {
+ let ext = RequiredCapabilitiesExt {
+ extensions: vec![0.into(), 1.into()],
+ proposals: vec![42.into(), 43.into()],
+ credentials: vec![BasicCredential::credential_type()],
+ };
+
+ let as_extension = ext.clone().into_extension().unwrap();
+
+ assert_eq!(
+ as_extension.extension_type,
+ ExtensionType::REQUIRED_CAPABILITIES
+ );
+
+ let restored = RequiredCapabilitiesExt::from_extension(&as_extension).unwrap();
+ assert_eq!(ext, restored)
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_external_senders() {
+ let identity = get_test_signing_identity(TEST_CIPHER_SUITE, &[1]).await.0;
+ let ext = ExternalSendersExt::new(vec![identity]);
+
+ let as_extension = ext.clone().into_extension().unwrap();
+
+ assert_eq!(as_extension.extension_type, ExtensionType::EXTERNAL_SENDERS);
+
+ let restored = ExternalSendersExt::from_extension(&as_extension).unwrap();
+ assert_eq!(ext, restored)
+ }
+
+ #[test]
+ fn test_external_pub() {
+ let ext = ExternalPubExt {
+ external_pub: vec![0, 1, 2, 3].into(),
+ };
+
+ let as_extension = ext.clone().into_extension().unwrap();
+ assert_eq!(as_extension.extension_type, ExtensionType::EXTERNAL_PUB);
+
+ let restored = ExternalPubExt::from_extension(&as_extension).unwrap();
+ assert_eq!(ext, restored)
+ }
+}
diff --git a/src/external_client.rs b/src/external_client.rs
new file mode 100644
index 0000000..0c882ac
--- /dev/null
+++ b/src/external_client.rs
@@ -0,0 +1,142 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::{
+ client::MlsError,
+ group::{framing::MlsMessage, message_processor::validate_key_package, ExportedTree},
+ KeyPackage,
+};
+
+pub mod builder;
+mod config;
+mod group;
+
+pub(crate) use config::ExternalClientConfig;
+use mls_rs_core::{
+ crypto::{CryptoProvider, SignatureSecretKey},
+ identity::SigningIdentity,
+};
+
+use builder::{ExternalBaseConfig, ExternalClientBuilder};
+
+pub use group::{ExternalGroup, ExternalReceivedMessage, ExternalSnapshot};
+
+/// A client capable of observing a group's state without having
+/// private keys required to read content.
+///
+/// This structure is useful when an application is sending
+/// plaintext control messages in order to allow a central server
+/// to facilitate communication between users.
+///
+/// # Warning
+///
+/// This structure will only be able to observe groups that were
+/// created by clients that have the `encrypt_control_messages`
+/// option returned by [`MlsRules::encryption_options`](`crate::MlsRules::encryption_options`)
+/// set to `false`. Any control messages that are sent encrypted
+/// over the wire will break the ability of this client to track
+/// the resulting group state.
+pub struct ExternalClient<C> {
+ config: C,
+ signing_data: Option<(SignatureSecretKey, SigningIdentity)>,
+}
+
+impl ExternalClient<()> {
+ pub fn builder() -> ExternalClientBuilder<ExternalBaseConfig> {
+ ExternalClientBuilder::new()
+ }
+}
+
+impl<C> ExternalClient<C>
+where
+ C: ExternalClientConfig + Clone,
+{
+ pub(crate) fn new(
+ config: C,
+ signing_data: Option<(SignatureSecretKey, SigningIdentity)>,
+ ) -> Self {
+ Self {
+ config,
+ signing_data,
+ }
+ }
+
+ /// Begin observing a group based on a GroupInfo message created by
+ /// [Group::group_info_message](crate::group::Group::group_info_message)
+ ///
+ ///`tree_data` is required to be provided out of band if the client that
+ /// created GroupInfo message did not did not use the `ratchet_tree_extension`
+ /// according to [`MlsRules::commit_options`](crate::MlsRules::commit_options)
+ /// at the time the welcome message
+ /// was created. `tree_data` can be exported from a group using the
+ /// [export tree function](crate::group::Group::export_tree).
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn observe_group(
+ &self,
+ group_info: MlsMessage,
+ tree_data: Option<ExportedTree<'_>>,
+ ) -> Result<ExternalGroup<C>, MlsError> {
+ ExternalGroup::join(
+ self.config.clone(),
+ self.signing_data.clone(),
+ group_info,
+ tree_data,
+ )
+ .await
+ }
+
+ /// Load an existing observed group by loading a snapshot that was
+ /// generated by
+ /// [ExternalGroup::snapshot](self::ExternalGroup::snapshot).
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn load_group(
+ &self,
+ snapshot: ExternalSnapshot,
+ ) -> Result<ExternalGroup<C>, MlsError> {
+ ExternalGroup::from_snapshot(self.config.clone(), snapshot).await
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn validate_key_package(
+ &self,
+ key_package: MlsMessage,
+ ) -> Result<KeyPackage, MlsError> {
+ let version = key_package.version;
+
+ let key_package = key_package
+ .into_key_package()
+ .ok_or(MlsError::UnexpectedMessageType)?;
+
+ let cs = self
+ .config
+ .crypto_provider()
+ .cipher_suite_provider(key_package.cipher_suite)
+ .ok_or(MlsError::UnsupportedCipherSuite(key_package.cipher_suite))?;
+
+ let id = self.config.identity_provider();
+
+ validate_key_package(&key_package, version, &cs, &id).await?;
+
+ Ok(key_package)
+ }
+}
+
+#[cfg(test)]
+pub(crate) mod tests_utils {
+ use crate::{
+ client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+ key_package::test_utils::test_key_package_message,
+ };
+
+ pub use super::builder::test_utils::*;
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn external_client_can_validate_key_package() {
+ let kp = test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "john").await;
+ let server = TestExternalClientBuilder::new_for_test().build();
+ let validated_kp = server.validate_key_package(kp.clone()).await.unwrap();
+
+ assert_eq!(kp.into_key_package().unwrap(), validated_kp);
+ }
+}
diff --git a/src/external_client/builder.rs b/src/external_client/builder.rs
new file mode 100644
index 0000000..04c9768
--- /dev/null
+++ b/src/external_client/builder.rs
@@ -0,0 +1,602 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+//! Definitions to build an [`ExternalClient`].
+//!
+//! See [`ExternalClientBuilder`].
+
+use crate::{
+ crypto::SignaturePublicKey,
+ extension::ExtensionType,
+ external_client::{ExternalClient, ExternalClientConfig},
+ group::{
+ mls_rules::{DefaultMlsRules, MlsRules},
+ proposal::ProposalType,
+ },
+ identity::CredentialType,
+ protocol_version::ProtocolVersion,
+ tree_kem::Capabilities,
+ CryptoProvider, Sealed,
+};
+use std::{
+ collections::HashMap,
+ fmt::{self, Debug},
+};
+
+/// Base client configuration type when instantiating `ExternalClientBuilder`
+pub type ExternalBaseConfig = Config<Missing, DefaultMlsRules, Missing>;
+
+/// Builder for [`ExternalClient`]
+///
+/// This is returned by [`ExternalClient::builder`] and allows to tweak settings the
+/// `ExternalClient` will use. At a minimum, the builder must be told the [`CryptoProvider`]
+/// and [`IdentityProvider`] to use. Other settings have default values. This
+/// means that the following methods must be called before [`ExternalClientBuilder::build`]:
+///
+/// - To specify the [`CryptoProvider`]: [`ExternalClientBuilder::crypto_provider`]
+/// - To specify the [`IdentityProvider`]: [`ExternalClientBuilder::identity_provider`]
+///
+/// # Example
+///
+/// ```
+/// use mls_rs::{
+/// external_client::ExternalClient,
+/// identity::basic::BasicIdentityProvider,
+/// };
+///
+/// use mls_rs_crypto_openssl::OpensslCryptoProvider;
+///
+/// let _client = ExternalClient::builder()
+/// .crypto_provider(OpensslCryptoProvider::default())
+/// .identity_provider(BasicIdentityProvider::new())
+/// .build();
+/// ```
+///
+/// # Spelling out an `ExternalClient` type
+///
+/// There are two main ways to spell out an `ExternalClient` type if needed (e.g. function return type).
+///
+/// The first option uses `impl MlsConfig`:
+/// ```
+/// use mls_rs::{
+/// external_client::{ExternalClient, builder::MlsConfig},
+/// identity::basic::BasicIdentityProvider,
+/// };
+///
+/// use mls_rs_crypto_openssl::OpensslCryptoProvider;
+///
+/// fn make_client() -> ExternalClient<impl MlsConfig> {
+/// ExternalClient::builder()
+/// .crypto_provider(OpensslCryptoProvider::default())
+/// .identity_provider(BasicIdentityProvider::new())
+/// .build()
+/// }
+///```
+///
+/// The second option is more verbose and consists in writing the full `ExternalClient` type:
+/// ```
+/// use mls_rs::{
+/// external_client::{ExternalClient, builder::{ExternalBaseConfig, WithIdentityProvider, WithCryptoProvider}},
+/// identity::basic::BasicIdentityProvider,
+/// };
+///
+/// use mls_rs_crypto_openssl::OpensslCryptoProvider;
+///
+/// type MlsClient = ExternalClient<WithIdentityProvider<
+/// BasicIdentityProvider,
+/// WithCryptoProvider<OpensslCryptoProvider, ExternalBaseConfig>,
+/// >>;
+///
+/// fn make_client_2() -> MlsClient {
+/// ExternalClient::builder()
+/// .crypto_provider(OpensslCryptoProvider::new())
+/// .identity_provider(BasicIdentityProvider::new())
+/// .build()
+/// }
+///
+/// ```
+#[derive(Debug)]
+pub struct ExternalClientBuilder<C>(C);
+
+impl Default for ExternalClientBuilder<ExternalBaseConfig> {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+impl ExternalClientBuilder<ExternalBaseConfig> {
+ pub fn new() -> Self {
+ Self(Config(ConfigInner {
+ settings: Default::default(),
+ identity_provider: Missing,
+ mls_rules: DefaultMlsRules::new(),
+ crypto_provider: Missing,
+ signing_data: None,
+ }))
+ }
+}
+
+impl<C: IntoConfig> ExternalClientBuilder<C> {
+ /// Add an extension type to the list of extension types supported by the client.
+ pub fn extension_type(
+ self,
+ type_: ExtensionType,
+ ) -> ExternalClientBuilder<IntoConfigOutput<C>> {
+ self.extension_types(Some(type_))
+ }
+
+ /// Add multiple extension types to the list of extension types supported by the client.
+ pub fn extension_types<I>(self, types: I) -> ExternalClientBuilder<IntoConfigOutput<C>>
+ where
+ I: IntoIterator<Item = ExtensionType>,
+ {
+ let mut c = self.0.into_config();
+ c.0.settings.extension_types.extend(types);
+ ExternalClientBuilder(c)
+ }
+
+ /// Add a custom proposal type to the list of proposals types supported by the client.
+ pub fn custom_proposal_type(
+ self,
+ type_: ProposalType,
+ ) -> ExternalClientBuilder<IntoConfigOutput<C>> {
+ self.custom_proposal_types(Some(type_))
+ }
+
+ /// Add multiple custom proposal types to the list of proposal types supported by the client.
+ pub fn custom_proposal_types<I>(self, types: I) -> ExternalClientBuilder<IntoConfigOutput<C>>
+ where
+ I: IntoIterator<Item = ProposalType>,
+ {
+ let mut c = self.0.into_config();
+ c.0.settings.custom_proposal_types.extend(types);
+ ExternalClientBuilder(c)
+ }
+
+ /// Add a protocol version to the list of protocol versions supported by the client.
+ ///
+ /// If no protocol version is explicitly added, the client will support all protocol versions
+ /// supported by this crate.
+ pub fn protocol_version(
+ self,
+ version: ProtocolVersion,
+ ) -> ExternalClientBuilder<IntoConfigOutput<C>> {
+ self.protocol_versions(Some(version))
+ }
+
+ /// Add multiple protocol versions to the list of protocol versions supported by the client.
+ ///
+ /// If no protocol version is explicitly added, the client will support all protocol versions
+ /// supported by this crate.
+ pub fn protocol_versions<I>(self, versions: I) -> ExternalClientBuilder<IntoConfigOutput<C>>
+ where
+ I: IntoIterator<Item = ProtocolVersion>,
+ {
+ let mut c = self.0.into_config();
+ c.0.settings.protocol_versions.extend(versions);
+ ExternalClientBuilder(c)
+ }
+
+ /// Add an external signing key to be used by the client.
+ pub fn external_signing_key(
+ self,
+ id: Vec<u8>,
+ key: SignaturePublicKey,
+ ) -> ExternalClientBuilder<IntoConfigOutput<C>> {
+ let mut c = self.0.into_config();
+ c.0.settings.external_signing_keys.insert(id, key);
+ ExternalClientBuilder(c)
+ }
+
+ /// Specify the number of epochs before the current one to keep.
+ ///
+ /// By default, all epochs are kept.
+ pub fn max_epoch_jitter(self, max_jitter: u64) -> ExternalClientBuilder<IntoConfigOutput<C>> {
+ let mut c = self.0.into_config();
+ c.0.settings.max_epoch_jitter = Some(max_jitter);
+ ExternalClientBuilder(c)
+ }
+
+ /// Specify whether processed proposals should be cached by the external group. In case they
+ /// are not cached by the group, they should be cached externally and inserted using
+ /// `ExternalGroup::insert_proposal` before processing the next commit.
+ pub fn cache_proposals(
+ self,
+ cache_proposals: bool,
+ ) -> ExternalClientBuilder<IntoConfigOutput<C>> {
+ let mut c = self.0.into_config();
+ c.0.settings.cache_proposals = cache_proposals;
+ ExternalClientBuilder(c)
+ }
+
+ /// Set the identity validator to be used by the client.
+ pub fn identity_provider<I>(
+ self,
+ identity_provider: I,
+ ) -> ExternalClientBuilder<WithIdentityProvider<I, C>>
+ where
+ I: IdentityProvider,
+ {
+ let Config(c) = self.0.into_config();
+ ExternalClientBuilder(Config(ConfigInner {
+ settings: c.settings,
+ identity_provider,
+ mls_rules: c.mls_rules,
+ crypto_provider: c.crypto_provider,
+ signing_data: c.signing_data,
+ }))
+ }
+
+ /// Set the crypto provider to be used by the client.
+ ///
+ // TODO add a comment once we have a default provider
+ pub fn crypto_provider<Cp>(
+ self,
+ crypto_provider: Cp,
+ ) -> ExternalClientBuilder<WithCryptoProvider<Cp, C>>
+ where
+ Cp: CryptoProvider,
+ {
+ let Config(c) = self.0.into_config();
+ ExternalClientBuilder(Config(ConfigInner {
+ settings: c.settings,
+ identity_provider: c.identity_provider,
+ mls_rules: c.mls_rules,
+ crypto_provider,
+ signing_data: c.signing_data,
+ }))
+ }
+
+ /// Set the user-defined proposal rules to be used by the client.
+ ///
+ /// User-defined rules are used when sending and receiving commits before
+ /// enforcing general MLS protocol rules. If the rule set returns an error when
+ /// receiving a commit, the entire commit is considered invalid. If the
+ /// rule set would return an error when sending a commit, individual proposals
+ /// may be filtered out to compensate.
+ pub fn mls_rules<Pr>(self, mls_rules: Pr) -> ExternalClientBuilder<WithMlsRules<Pr, C>>
+ where
+ Pr: MlsRules,
+ {
+ let Config(c) = self.0.into_config();
+ ExternalClientBuilder(Config(ConfigInner {
+ settings: c.settings,
+ identity_provider: c.identity_provider,
+ mls_rules,
+ crypto_provider: c.crypto_provider,
+ signing_data: c.signing_data,
+ }))
+ }
+
+ /// Set the signature secret key used by the client to send external proposals.
+ pub fn signer(
+ self,
+ signer: SignatureSecretKey,
+ signing_identity: SigningIdentity,
+ ) -> ExternalClientBuilder<IntoConfigOutput<C>> {
+ let mut c = self.0.into_config();
+ c.0.signing_data = Some((signer, signing_identity));
+ ExternalClientBuilder(c)
+ }
+}
+
+impl<C: IntoConfig> ExternalClientBuilder<C>
+where
+ C::IdentityProvider: IdentityProvider + Clone,
+ C::MlsRules: MlsRules + Clone,
+ C::CryptoProvider: CryptoProvider + Clone,
+{
+ pub(crate) fn build_config(self) -> IntoConfigOutput<C> {
+ let mut c = self.0.into_config();
+
+ if c.0.settings.protocol_versions.is_empty() {
+ c.0.settings.protocol_versions = ProtocolVersion::all().collect();
+ }
+
+ c
+ }
+
+ /// Build an external client.
+ ///
+ /// See [`ExternalClientBuilder`] documentation if the return type of this function needs to be
+ /// spelled out.
+ pub fn build(self) -> ExternalClient<IntoConfigOutput<C>> {
+ let mut c = self.build_config();
+ let signing_data = c.0.signing_data.take();
+ ExternalClient::new(c, signing_data)
+ }
+}
+
+/// Marker type for required `ExternalClientBuilder` services that have not been specified yet.
+#[derive(Debug)]
+pub struct Missing;
+
+/// Change the identity validator used by a client configuration.
+///
+/// See [`ExternalClientBuilder::identity_provider`].
+pub type WithIdentityProvider<I, C> =
+ Config<I, <C as IntoConfig>::MlsRules, <C as IntoConfig>::CryptoProvider>;
+
+/// Change the proposal filter used by a client configuration.
+///
+/// See [`ExternalClientBuilder::mls_rules`].
+pub type WithMlsRules<Pr, C> =
+ Config<<C as IntoConfig>::IdentityProvider, Pr, <C as IntoConfig>::CryptoProvider>;
+
+/// Change the crypto provider used by a client configuration.
+///
+/// See [`ExternalClientBuilder::crypto_provider`].
+pub type WithCryptoProvider<Cp, C> =
+ Config<<C as IntoConfig>::IdentityProvider, <C as IntoConfig>::MlsRules, Cp>;
+
+/// Helper alias for `Config`.
+pub type IntoConfigOutput<C> = Config<
+ <C as IntoConfig>::IdentityProvider,
+ <C as IntoConfig>::MlsRules,
+ <C as IntoConfig>::CryptoProvider,
+>;
+
+impl<Ip, Pr, Cp> ExternalClientConfig for ConfigInner<Ip, Pr, Cp>
+where
+ Ip: IdentityProvider + Clone,
+ Pr: MlsRules + Clone,
+ Cp: CryptoProvider + Clone,
+{
+ type IdentityProvider = Ip;
+ type MlsRules = Pr;
+ type CryptoProvider = Cp;
+
+ fn supported_extensions(&self) -> Vec<ExtensionType> {
+ self.settings.extension_types.clone()
+ }
+
+ fn supported_protocol_versions(&self) -> Vec<ProtocolVersion> {
+ self.settings.protocol_versions.clone()
+ }
+
+ fn identity_provider(&self) -> Self::IdentityProvider {
+ self.identity_provider.clone()
+ }
+
+ fn crypto_provider(&self) -> Self::CryptoProvider {
+ self.crypto_provider.clone()
+ }
+
+ fn external_signing_key(&self, external_key_id: &[u8]) -> Option<SignaturePublicKey> {
+ self.settings
+ .external_signing_keys
+ .get(external_key_id)
+ .cloned()
+ }
+
+ fn mls_rules(&self) -> Self::MlsRules {
+ self.mls_rules.clone()
+ }
+
+ fn max_epoch_jitter(&self) -> Option<u64> {
+ self.settings.max_epoch_jitter
+ }
+
+ fn cache_proposals(&self) -> bool {
+ self.settings.cache_proposals
+ }
+
+ fn supported_custom_proposals(&self) -> Vec<ProposalType> {
+ self.settings.custom_proposal_types.clone()
+ }
+}
+
+impl<Ip, Mpf, Cp> Sealed for Config<Ip, Mpf, Cp> {}
+
+impl<Ip, Pr, Cp> MlsConfig for Config<Ip, Pr, Cp>
+where
+ Ip: IdentityProvider + Clone,
+ Pr: MlsRules + Clone,
+ Cp: CryptoProvider + Clone,
+{
+ type Output = ConfigInner<Ip, Pr, Cp>;
+
+ fn get(&self) -> &Self::Output {
+ &self.0
+ }
+}
+
+/// Helper trait to allow consuming crates to easily write an external client type as
+/// `ExternalClient<impl MlsConfig>`
+///
+/// It is not meant to be implemented by consuming crates. `T: MlsConfig` implies
+/// `T: ExternalClientConfig`.
+pub trait MlsConfig: Send + Sync + Clone + Sealed {
+ #[doc(hidden)]
+ type Output: ExternalClientConfig;
+
+ #[doc(hidden)]
+ fn get(&self) -> &Self::Output;
+}
+
+/// Blanket implementation so that `T: MlsConfig` implies `T: ExternalClientConfig`
+impl<T: MlsConfig> ExternalClientConfig for T {
+ type IdentityProvider = <T::Output as ExternalClientConfig>::IdentityProvider;
+ type MlsRules = <T::Output as ExternalClientConfig>::MlsRules;
+ type CryptoProvider = <T::Output as ExternalClientConfig>::CryptoProvider;
+
+ fn supported_extensions(&self) -> Vec<ExtensionType> {
+ self.get().supported_extensions()
+ }
+
+ fn supported_protocol_versions(&self) -> Vec<ProtocolVersion> {
+ self.get().supported_protocol_versions()
+ }
+
+ fn supported_custom_proposals(&self) -> Vec<ProposalType> {
+ self.get().supported_custom_proposals()
+ }
+
+ fn identity_provider(&self) -> Self::IdentityProvider {
+ self.get().identity_provider()
+ }
+
+ fn crypto_provider(&self) -> Self::CryptoProvider {
+ self.get().crypto_provider()
+ }
+
+ fn external_signing_key(&self, external_key_id: &[u8]) -> Option<SignaturePublicKey> {
+ self.get().external_signing_key(external_key_id)
+ }
+
+ fn mls_rules(&self) -> Self::MlsRules {
+ self.get().mls_rules()
+ }
+
+ fn cache_proposals(&self) -> bool {
+ self.get().cache_proposals()
+ }
+
+ fn max_epoch_jitter(&self) -> Option<u64> {
+ self.get().max_epoch_jitter()
+ }
+
+ fn capabilities(&self) -> Capabilities {
+ self.get().capabilities()
+ }
+
+ fn version_supported(&self, version: ProtocolVersion) -> bool {
+ self.get().version_supported(version)
+ }
+
+ fn supported_credentials(&self) -> Vec<CredentialType> {
+ self.get().supported_credentials()
+ }
+}
+
+#[derive(Clone)]
+pub(crate) struct Settings {
+ pub(crate) extension_types: Vec<ExtensionType>,
+ pub(crate) custom_proposal_types: Vec<ProposalType>,
+ pub(crate) protocol_versions: Vec<ProtocolVersion>,
+ pub(crate) external_signing_keys: HashMap<Vec<u8>, SignaturePublicKey>,
+ pub(crate) max_epoch_jitter: Option<u64>,
+ pub(crate) cache_proposals: bool,
+}
+
+impl Debug for Settings {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("Settings")
+ .field("extension_types", &self.extension_types)
+ .field("custom_proposal_types", &self.custom_proposal_types)
+ .field("protocol_versions", &self.protocol_versions)
+ .field(
+ "external_signing_keys",
+ &mls_rs_core::debug::pretty_with(|f| {
+ f.debug_map()
+ .entries(
+ self.external_signing_keys
+ .iter()
+ .map(|(k, v)| (mls_rs_core::debug::pretty_bytes(k), v)),
+ )
+ .finish()
+ }),
+ )
+ .field("max_epoch_jitter", &self.max_epoch_jitter)
+ .field("cache_proposals", &self.cache_proposals)
+ .finish()
+ }
+}
+
+impl Default for Settings {
+ fn default() -> Self {
+ Self {
+ cache_proposals: true,
+ extension_types: vec![],
+ protocol_versions: vec![],
+ external_signing_keys: Default::default(),
+ max_epoch_jitter: None,
+ custom_proposal_types: vec![],
+ }
+ }
+}
+
+/// Definitions meant to be private that are inaccessible outside this crate. They need to be marked
+/// `pub` because they appear in public definitions.
+mod private {
+ use mls_rs_core::{crypto::SignatureSecretKey, identity::SigningIdentity};
+
+ use super::{IntoConfigOutput, Settings};
+
+ #[derive(Clone, Debug)]
+ pub struct Config<Ip, Pr, Cp>(pub(crate) ConfigInner<Ip, Pr, Cp>);
+
+ #[derive(Clone, Debug)]
+ pub struct ConfigInner<Ip, Mpf, Cp> {
+ pub(crate) settings: Settings,
+ pub(crate) identity_provider: Ip,
+ pub(crate) mls_rules: Mpf,
+ pub(crate) crypto_provider: Cp,
+ pub(crate) signing_data: Option<(SignatureSecretKey, SigningIdentity)>,
+ }
+
+ pub trait IntoConfig {
+ type IdentityProvider;
+ type MlsRules;
+ type CryptoProvider;
+
+ fn into_config(self) -> IntoConfigOutput<Self>;
+ }
+
+ impl<Ip, Pr, Cp> IntoConfig for Config<Ip, Pr, Cp> {
+ type IdentityProvider = Ip;
+ type MlsRules = Pr;
+ type CryptoProvider = Cp;
+
+ fn into_config(self) -> Self {
+ self
+ }
+ }
+}
+
+use mls_rs_core::{
+ crypto::SignatureSecretKey,
+ identity::{IdentityProvider, SigningIdentity},
+};
+use private::{Config, ConfigInner, IntoConfig};
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+ use crate::{
+ cipher_suite::CipherSuite, crypto::test_utils::TestCryptoProvider,
+ identity::basic::BasicIdentityProvider,
+ };
+
+ use super::{
+ ExternalBaseConfig, ExternalClientBuilder, WithCryptoProvider, WithIdentityProvider,
+ };
+
+ pub type TestExternalClientConfig = WithIdentityProvider<
+ BasicIdentityProvider,
+ WithCryptoProvider<TestCryptoProvider, ExternalBaseConfig>,
+ >;
+
+ pub type TestExternalClientBuilder = ExternalClientBuilder<TestExternalClientConfig>;
+
+ impl TestExternalClientBuilder {
+ pub fn new_for_test() -> Self {
+ ExternalClientBuilder::new()
+ .crypto_provider(TestCryptoProvider::default())
+ .identity_provider(BasicIdentityProvider::new())
+ }
+
+ pub fn new_for_test_disabling_cipher_suite(cipher_suite: CipherSuite) -> Self {
+ let crypto_provider = TestCryptoProvider::with_enabled_cipher_suites(
+ TestCryptoProvider::all_supported_cipher_suites()
+ .into_iter()
+ .filter(|cs| cs != &cipher_suite)
+ .collect(),
+ );
+
+ ExternalClientBuilder::new()
+ .crypto_provider(crypto_provider)
+ .identity_provider(BasicIdentityProvider::new())
+ }
+ }
+}
diff --git a/src/external_client/config.rs b/src/external_client/config.rs
new file mode 100644
index 0000000..649be99
--- /dev/null
+++ b/src/external_client/config.rs
@@ -0,0 +1,54 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use mls_rs_core::identity::IdentityProvider;
+
+use crate::{
+ crypto::SignaturePublicKey,
+ extension::ExtensionType,
+ group::{mls_rules::MlsRules, proposal::ProposalType},
+ identity::CredentialType,
+ protocol_version::ProtocolVersion,
+ tree_kem::Capabilities,
+ CryptoProvider,
+};
+
+pub trait ExternalClientConfig: Send + Sync + Clone {
+ type IdentityProvider: IdentityProvider + Clone;
+ type MlsRules: MlsRules + Clone;
+ type CryptoProvider: CryptoProvider;
+
+ fn supported_extensions(&self) -> Vec<ExtensionType>;
+ fn supported_custom_proposals(&self) -> Vec<ProposalType>;
+ fn supported_protocol_versions(&self) -> Vec<ProtocolVersion>;
+ fn identity_provider(&self) -> Self::IdentityProvider;
+ fn crypto_provider(&self) -> Self::CryptoProvider;
+ fn external_signing_key(&self, external_key_id: &[u8]) -> Option<SignaturePublicKey>;
+
+ fn mls_rules(&self) -> Self::MlsRules;
+
+ fn cache_proposals(&self) -> bool;
+
+ fn max_epoch_jitter(&self) -> Option<u64> {
+ None
+ }
+
+ fn capabilities(&self) -> Capabilities {
+ Capabilities {
+ protocol_versions: self.supported_protocol_versions(),
+ cipher_suites: self.crypto_provider().supported_cipher_suites(),
+ extensions: self.supported_extensions(),
+ proposals: self.supported_custom_proposals(),
+ credentials: self.supported_credentials(),
+ }
+ }
+
+ fn version_supported(&self, version: ProtocolVersion) -> bool {
+ self.supported_protocol_versions().contains(&version)
+ }
+
+ fn supported_credentials(&self) -> Vec<CredentialType> {
+ self.identity_provider().supported_types()
+ }
+}
diff --git a/src/external_client/group.rs b/src/external_client/group.rs
new file mode 100644
index 0000000..8939948
--- /dev/null
+++ b/src/external_client/group.rs
@@ -0,0 +1,1354 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::{
+ crypto::SignatureSecretKey, error::IntoAnyError, extension::ExtensionList, group::Member,
+ identity::IdentityProvider,
+};
+
+use crate::{
+ cipher_suite::CipherSuite,
+ client::MlsError,
+ external_client::ExternalClientConfig,
+ group::{
+ cipher_suite_provider,
+ confirmation_tag::ConfirmationTag,
+ framing::PublicMessage,
+ member_from_leaf_node,
+ message_processor::{
+ ApplicationMessageDescription, CommitMessageDescription, EventOrContent,
+ MessageProcessor, ProposalMessageDescription, ProvisionalState,
+ },
+ snapshot::RawGroupState,
+ state::GroupState,
+ transcript_hash::InterimTranscriptHash,
+ validate_group_info_joiner, ContentType, ExportedTree, GroupContext, GroupInfo, Roster,
+ Welcome,
+ },
+ identity::SigningIdentity,
+ protocol_version::ProtocolVersion,
+ psk::AlwaysFoundPskStorage,
+ tree_kem::{node::LeafIndex, path_secret::PathSecret, TreeKemPrivate},
+ CryptoProvider, KeyPackage, MlsMessage,
+};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::{
+ group::{
+ framing::{Content, MlsMessagePayload},
+ message_processor::CachedProposal,
+ message_signature::AuthenticatedContent,
+ proposal::Proposal,
+ proposal_ref::ProposalRef,
+ Sender,
+ },
+ WireFormat,
+};
+
+#[cfg(all(feature = "by_ref_proposal", feature = "custom_proposal"))]
+use crate::group::proposal::CustomProposal;
+
+#[cfg(feature = "by_ref_proposal")]
+use mls_rs_core::{crypto::CipherSuiteProvider, psk::ExternalPskId};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::{
+ extension::ExternalSendersExt,
+ group::proposal::{AddProposal, ReInitProposal, RemoveProposal},
+};
+
+#[cfg(all(feature = "by_ref_proposal", feature = "psk"))]
+use crate::{
+ group::proposal::PreSharedKeyProposal,
+ psk::{
+ JustPreSharedKeyID, PreSharedKeyID, PskGroupId, PskNonce, ResumptionPSKUsage, ResumptionPsk,
+ },
+};
+
+#[cfg(feature = "private_message")]
+use crate::group::framing::PrivateMessage;
+
+use alloc::boxed::Box;
+
+/// The result of processing an [ExternalGroup](ExternalGroup) message using
+/// [process_incoming_message](ExternalGroup::process_incoming_message)
+#[derive(Clone, Debug)]
+#[allow(clippy::large_enum_variant)]
+pub enum ExternalReceivedMessage {
+ /// State update as the result of a successful commit.
+ Commit(CommitMessageDescription),
+ /// Received proposal and its unique identifier.
+ Proposal(ProposalMessageDescription),
+ /// Encrypted message that can not be processed.
+ Ciphertext(ContentType),
+ /// Validated GroupInfo object
+ GroupInfo(GroupInfo),
+ /// Validated welcome message
+ Welcome,
+ /// Validated key package
+ KeyPackage(KeyPackage),
+}
+
+/// A handle to an observed group that can track plaintext control messages
+/// and the resulting group state.
+#[derive(Clone)]
+pub struct ExternalGroup<C>
+where
+ C: ExternalClientConfig,
+{
+ pub(crate) config: C,
+ pub(crate) cipher_suite_provider: <C::CryptoProvider as CryptoProvider>::CipherSuiteProvider,
+ pub(crate) state: GroupState,
+ pub(crate) signing_data: Option<(SignatureSecretKey, SigningIdentity)>,
+}
+
+impl<C: ExternalClientConfig + Clone> ExternalGroup<C> {
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn join(
+ config: C,
+ signing_data: Option<(SignatureSecretKey, SigningIdentity)>,
+ group_info: MlsMessage,
+ tree_data: Option<ExportedTree<'_>>,
+ ) -> Result<Self, MlsError> {
+ let protocol_version = group_info.version;
+
+ if !config.version_supported(protocol_version) {
+ return Err(MlsError::UnsupportedProtocolVersion(protocol_version));
+ }
+
+ let group_info = group_info
+ .into_group_info()
+ .ok_or(MlsError::UnexpectedMessageType)?;
+
+ let cipher_suite_provider = cipher_suite_provider(
+ config.crypto_provider(),
+ group_info.group_context.cipher_suite,
+ )?;
+
+ let public_tree = validate_group_info_joiner(
+ protocol_version,
+ &group_info,
+ tree_data,
+ &config.identity_provider(),
+ &cipher_suite_provider,
+ )
+ .await?;
+
+ let interim_transcript_hash = InterimTranscriptHash::create(
+ &cipher_suite_provider,
+ &group_info.group_context.confirmed_transcript_hash,
+ &group_info.confirmation_tag,
+ )
+ .await?;
+
+ Ok(Self {
+ config,
+ signing_data,
+ state: GroupState::new(
+ group_info.group_context,
+ public_tree,
+ interim_transcript_hash,
+ group_info.confirmation_tag,
+ ),
+ cipher_suite_provider,
+ })
+ }
+
+ /// Process a message that was sent to the group.
+ ///
+ /// * Proposals will be stored in the group state and processed by the
+ /// same rules as a standard group.
+ ///
+ /// * Commits will result in the same outcome as a standard group.
+ /// However, the integrity of the resulting group state can only be partially
+ /// verified, since the external group does have access to the group
+ /// secrets required to do a complete check.
+ ///
+ /// * Application messages are always encrypted so they result in a no-op
+ /// that returns [ExternalReceivedMessage::Ciphertext]
+ ///
+ /// # Warning
+ ///
+ /// Processing an encrypted commit or proposal message has the same result
+ /// as processing an encrypted application message. Proper tracking of
+ /// the group state requires that all proposal and commit messages are
+ /// readable.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn process_incoming_message(
+ &mut self,
+ message: MlsMessage,
+ ) -> Result<ExternalReceivedMessage, MlsError> {
+ MessageProcessor::process_incoming_message(
+ self,
+ message,
+ #[cfg(feature = "by_ref_proposal")]
+ self.config.cache_proposals(),
+ )
+ .await
+ }
+
+ /// Replay a proposal message into the group skipping all validation steps.
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn insert_proposal_from_message(
+ &mut self,
+ message: MlsMessage,
+ ) -> Result<(), MlsError> {
+ let ptxt = match message.payload {
+ MlsMessagePayload::Plain(p) => Ok(p),
+ _ => Err(MlsError::UnexpectedMessageType),
+ }?;
+
+ let auth_content: AuthenticatedContent = ptxt.into();
+
+ let proposal_ref =
+ ProposalRef::from_content(&self.cipher_suite_provider, &auth_content).await?;
+
+ let sender = auth_content.content.sender;
+
+ let proposal = match auth_content.content.content {
+ Content::Proposal(p) => Ok(*p),
+ _ => Err(MlsError::UnexpectedMessageType),
+ }?;
+
+ self.group_state_mut()
+ .proposals
+ .insert(proposal_ref, proposal, sender);
+
+ Ok(())
+ }
+
+ /// Force insert a proposal directly into the internal state of the group
+ /// with no validation.
+ #[cfg(feature = "by_ref_proposal")]
+ pub fn insert_proposal(&mut self, proposal: CachedProposal) {
+ self.group_state_mut().proposals.insert(
+ proposal.proposal_ref,
+ proposal.proposal,
+ proposal.sender,
+ )
+ }
+
+ /// Create an external proposal to request that a group add a new member
+ ///
+ /// # Warning
+ ///
+ /// In order for the proposal generated by this function to be successfully
+ /// committed, the group needs to have `signing_identity` as an entry
+ /// within an [ExternalSendersExt](crate::extension::built_in::ExternalSendersExt)
+ /// as part of its group context extensions.
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn propose_add(
+ &mut self,
+ key_package: MlsMessage,
+ authenticated_data: Vec<u8>,
+ ) -> Result<MlsMessage, MlsError> {
+ let key_package = key_package
+ .into_key_package()
+ .ok_or(MlsError::UnexpectedMessageType)?;
+
+ self.propose(
+ Proposal::Add(alloc::boxed::Box::new(AddProposal { key_package })),
+ authenticated_data,
+ )
+ .await
+ }
+
+ /// Create an external proposal to request that a group remove an existing member
+ ///
+ /// # Warning
+ ///
+ /// In order for the proposal generated by this function to be successfully
+ /// committed, the group needs to have `signing_identity` as an entry
+ /// within an [ExternalSendersExt](crate::extension::built_in::ExternalSendersExt)
+ /// as part of its group context extensions.
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn propose_remove(
+ &mut self,
+ index: u32,
+ authenticated_data: Vec<u8>,
+ ) -> Result<MlsMessage, MlsError> {
+ let to_remove = LeafIndex(index);
+
+ // Verify that this leaf is actually in the tree
+ self.group_state().public_tree.get_leaf_node(to_remove)?;
+
+ self.propose(
+ Proposal::Remove(RemoveProposal { to_remove }),
+ authenticated_data,
+ )
+ .await
+ }
+
+ /// Create an external proposal to request that a group inserts an external
+ /// pre shared key into its state.
+ ///
+ /// # Warning
+ ///
+ /// In order for the proposal generated by this function to be successfully
+ /// committed, the group needs to have `signing_identity` as an entry
+ /// within an [ExternalSendersExt](crate::extension::built_in::ExternalSendersExt)
+ /// as part of its group context extensions.
+ #[cfg(all(feature = "by_ref_proposal", feature = "psk"))]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn propose_external_psk(
+ &mut self,
+ psk: ExternalPskId,
+ authenticated_data: Vec<u8>,
+ ) -> Result<MlsMessage, MlsError> {
+ let proposal = self.psk_proposal(JustPreSharedKeyID::External(psk))?;
+ self.propose(proposal, authenticated_data).await
+ }
+
+ /// Create an external proposal to request that a group adds a pre shared key
+ /// from a previous epoch to the current group state.
+ ///
+ /// # Warning
+ ///
+ /// In order for the proposal generated by this function to be successfully
+ /// committed, the group needs to have `signing_identity` as an entry
+ /// within an [ExternalSendersExt](crate::extension::built_in::ExternalSendersExt)
+ /// as part of its group context extensions.
+ #[cfg(all(feature = "by_ref_proposal", feature = "psk"))]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn propose_resumption_psk(
+ &mut self,
+ psk_epoch: u64,
+ authenticated_data: Vec<u8>,
+ ) -> Result<MlsMessage, MlsError> {
+ let key_id = ResumptionPsk {
+ psk_epoch,
+ usage: ResumptionPSKUsage::Application,
+ psk_group_id: PskGroupId(self.group_context().group_id().to_vec()),
+ };
+
+ let proposal = self.psk_proposal(JustPreSharedKeyID::Resumption(key_id))?;
+ self.propose(proposal, authenticated_data).await
+ }
+
+ #[cfg(all(feature = "by_ref_proposal", feature = "psk"))]
+ fn psk_proposal(&self, key_id: JustPreSharedKeyID) -> Result<Proposal, MlsError> {
+ Ok(Proposal::Psk(PreSharedKeyProposal {
+ psk: PreSharedKeyID {
+ key_id,
+ psk_nonce: PskNonce::random(&self.cipher_suite_provider)
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?,
+ },
+ }))
+ }
+
+ /// Create an external proposal to request that a group sets extensions stored in the group
+ /// state.
+ ///
+ /// # Warning
+ ///
+ /// In order for the proposal generated by this function to be successfully
+ /// committed, the group needs to have `signing_identity` as an entry
+ /// within an [ExternalSendersExt](crate::extension::built_in::ExternalSendersExt)
+ /// as part of its group context extensions.
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn propose_group_context_extensions(
+ &mut self,
+ extensions: ExtensionList,
+ authenticated_data: Vec<u8>,
+ ) -> Result<MlsMessage, MlsError> {
+ let proposal = Proposal::GroupContextExtensions(extensions);
+ self.propose(proposal, authenticated_data).await
+ }
+
+ /// Create an external proposal to request that a group is reinitialized.
+ ///
+ /// # Warning
+ ///
+ /// In order for the proposal generated by this function to be successfully
+ /// committed, the group needs to have `signing_identity` as an entry
+ /// within an [ExternalSendersExt](crate::extension::built_in::ExternalSendersExt)
+ /// as part of its group context extensions.
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn propose_reinit(
+ &mut self,
+ group_id: Option<Vec<u8>>,
+ version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+ extensions: ExtensionList,
+ authenticated_data: Vec<u8>,
+ ) -> Result<MlsMessage, MlsError> {
+ let group_id = group_id.map(Ok).unwrap_or_else(|| {
+ self.cipher_suite_provider
+ .random_bytes_vec(self.cipher_suite_provider.kdf_extract_size())
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+ })?;
+
+ let proposal = Proposal::ReInit(ReInitProposal {
+ group_id,
+ version,
+ cipher_suite,
+ extensions,
+ });
+
+ self.propose(proposal, authenticated_data).await
+ }
+
+ /// Create a custom proposal message.
+ ///
+ /// # Warning
+ ///
+ /// In order for the proposal generated by this function to be successfully
+ /// committed, the group needs to have `signing_identity` as an entry
+ /// within an [ExternalSendersExt](crate::extension::built_in::ExternalSendersExt)
+ /// as part of its group context extensions.
+ #[cfg(all(feature = "by_ref_proposal", feature = "custom_proposal"))]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn propose_custom(
+ &mut self,
+ proposal: CustomProposal,
+ authenticated_data: Vec<u8>,
+ ) -> Result<MlsMessage, MlsError> {
+ self.propose(Proposal::Custom(proposal), authenticated_data)
+ .await
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn propose(
+ &mut self,
+ proposal: Proposal,
+ authenticated_data: Vec<u8>,
+ ) -> Result<MlsMessage, MlsError> {
+ let (signer, signing_identity) =
+ self.signing_data.as_ref().ok_or(MlsError::SignerNotFound)?;
+
+ let external_senders_ext = self
+ .state
+ .context
+ .extensions
+ .get_as::<ExternalSendersExt>()?
+ .ok_or(MlsError::ExternalProposalsDisabled)?;
+
+ let sender_index = external_senders_ext
+ .allowed_senders
+ .iter()
+ .position(|allowed_signer| signing_identity == allowed_signer)
+ .ok_or(MlsError::InvalidExternalSigningIdentity)?;
+
+ let sender = Sender::External(sender_index as u32);
+
+ let auth_content = AuthenticatedContent::new_signed(
+ &self.cipher_suite_provider,
+ &self.state.context,
+ sender,
+ Content::Proposal(Box::new(proposal.clone())),
+ signer,
+ WireFormat::PublicMessage,
+ authenticated_data,
+ )
+ .await?;
+
+ self.state.proposals.insert(
+ ProposalRef::from_content(&self.cipher_suite_provider, &auth_content).await?,
+ proposal,
+ sender,
+ );
+
+ let plaintext = PublicMessage {
+ content: auth_content.content,
+ auth: auth_content.auth,
+ membership_tag: None,
+ };
+
+ Ok(MlsMessage::new(
+ self.group_context().version(),
+ MlsMessagePayload::Plain(plaintext),
+ ))
+ }
+
+ /// Delete all sent and received proposals cached for commit.
+ #[cfg(feature = "by_ref_proposal")]
+ pub fn clear_proposal_cache(&mut self) {
+ self.state.proposals.clear()
+ }
+
+ #[inline(always)]
+ pub(crate) fn group_state(&self) -> &GroupState {
+ &self.state
+ }
+
+ /// Get the current group context summarizing various information about the group.
+ #[inline(always)]
+ pub fn group_context(&self) -> &GroupContext {
+ &self.group_state().context
+ }
+
+ /// Export the current ratchet tree used within the group.
+ pub fn export_tree(&self) -> Result<Vec<u8>, MlsError> {
+ self.group_state()
+ .public_tree
+ .nodes
+ .mls_encode_to_vec()
+ .map_err(Into::into)
+ }
+
+ /// Get the current roster of the group.
+ #[inline(always)]
+ pub fn roster(&self) -> Roster {
+ self.group_state().public_tree.roster()
+ }
+
+ /// Get the
+ /// [transcript hash](https://messaginglayersecurity.rocks/mls-protocol/draft-ietf-mls-protocol.html#name-transcript-hashes)
+ /// for the current epoch that the group is in.
+ #[inline(always)]
+ pub fn transcript_hash(&self) -> &Vec<u8> {
+ &self.group_state().context.confirmed_transcript_hash
+ }
+
+ /// Get the
+ /// [tree hash](https://www.rfc-editor.org/rfc/rfc9420.html#name-tree-hashes)
+ /// for the current epoch that the group is in.
+ #[inline(always)]
+ pub fn tree_hash(&self) -> &[u8] {
+ &self.group_state().context.tree_hash
+ }
+
+ /// Find a member based on their identity.
+ ///
+ /// Identities are matched based on the
+ /// [IdentityProvider](crate::IdentityProvider)
+ /// that this group was configured with.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn get_member_with_identity(
+ &self,
+ identity_id: &SigningIdentity,
+ ) -> Result<Member, MlsError> {
+ let identity = self
+ .identity_provider()
+ .identity(identity_id, self.group_context().extensions())
+ .await
+ .map_err(|error| MlsError::IdentityProviderError(error.into_any_error()))?;
+
+ let tree = &self.group_state().public_tree;
+
+ #[cfg(feature = "tree_index")]
+ let index = tree.get_leaf_node_with_identity(&identity);
+
+ #[cfg(not(feature = "tree_index"))]
+ let index = tree
+ .get_leaf_node_with_identity(
+ &identity,
+ &self.identity_provider(),
+ self.group_context().extensions(),
+ )
+ .await?;
+
+ let index = index.ok_or(MlsError::MemberNotFound)?;
+ let node = self.group_state().public_tree.get_leaf_node(index)?;
+
+ Ok(member_from_leaf_node(node, index))
+ }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
+#[cfg_attr(
+ all(not(target_arch = "wasm32"), mls_build_async),
+ maybe_async::must_be_async
+)]
+impl<C> MessageProcessor for ExternalGroup<C>
+where
+ C: ExternalClientConfig + Clone,
+{
+ type MlsRules = C::MlsRules;
+ type IdentityProvider = C::IdentityProvider;
+ type PreSharedKeyStorage = AlwaysFoundPskStorage;
+ type OutputType = ExternalReceivedMessage;
+ type CipherSuiteProvider = <C::CryptoProvider as CryptoProvider>::CipherSuiteProvider;
+
+ #[cfg(feature = "private_message")]
+ fn self_index(&self) -> Option<LeafIndex> {
+ None
+ }
+
+ fn mls_rules(&self) -> Self::MlsRules {
+ self.config.mls_rules()
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn verify_plaintext_authentication(
+ &self,
+ message: PublicMessage,
+ ) -> Result<EventOrContent<Self::OutputType>, MlsError> {
+ let auth_content = crate::group::message_verifier::verify_plaintext_authentication(
+ &self.cipher_suite_provider,
+ message,
+ None,
+ None,
+ &self.state,
+ )
+ .await?;
+
+ Ok(EventOrContent::Content(auth_content))
+ }
+
+ #[cfg(feature = "private_message")]
+ async fn process_ciphertext(
+ &mut self,
+ cipher_text: &PrivateMessage,
+ ) -> Result<EventOrContent<Self::OutputType>, MlsError> {
+ Ok(EventOrContent::Event(ExternalReceivedMessage::Ciphertext(
+ cipher_text.content_type,
+ )))
+ }
+
+ async fn update_key_schedule(
+ &mut self,
+ _secrets: Option<(TreeKemPrivate, PathSecret)>,
+ interim_transcript_hash: InterimTranscriptHash,
+ confirmation_tag: &ConfirmationTag,
+ provisional_public_state: ProvisionalState,
+ ) -> Result<(), MlsError> {
+ self.state.context = provisional_public_state.group_context;
+ #[cfg(feature = "by_ref_proposal")]
+ self.state.proposals.clear();
+ self.state.interim_transcript_hash = interim_transcript_hash;
+ self.state.public_tree = provisional_public_state.public_tree;
+ self.state.confirmation_tag = confirmation_tag.clone();
+
+ Ok(())
+ }
+
+ fn identity_provider(&self) -> Self::IdentityProvider {
+ self.config.identity_provider()
+ }
+
+ fn psk_storage(&self) -> Self::PreSharedKeyStorage {
+ AlwaysFoundPskStorage
+ }
+
+ fn group_state(&self) -> &GroupState {
+ &self.state
+ }
+
+ fn group_state_mut(&mut self) -> &mut GroupState {
+ &mut self.state
+ }
+
+ fn can_continue_processing(&self, _provisional_state: &ProvisionalState) -> bool {
+ true
+ }
+
+ #[cfg(feature = "private_message")]
+ fn min_epoch_available(&self) -> Option<u64> {
+ self.config
+ .max_epoch_jitter()
+ .map(|j| self.state.context.epoch - j)
+ }
+
+ fn cipher_suite_provider(&self) -> &Self::CipherSuiteProvider {
+ &self.cipher_suite_provider
+ }
+}
+
+/// Serializable snapshot of an [ExternalGroup](ExternalGroup) state.
+#[derive(Debug, MlsEncode, MlsSize, MlsDecode, PartialEq, Clone)]
+pub struct ExternalSnapshot {
+ version: u16,
+ state: RawGroupState,
+ signing_data: Option<(SignatureSecretKey, SigningIdentity)>,
+}
+
+impl ExternalSnapshot {
+ /// Serialize the snapshot
+ pub fn to_bytes(&self) -> Result<Vec<u8>, MlsError> {
+ Ok(self.mls_encode_to_vec()?)
+ }
+
+ /// Deserialize the snapshot
+ pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlsError> {
+ Ok(Self::mls_decode(&mut &*bytes)?)
+ }
+}
+
+impl<C> ExternalGroup<C>
+where
+ C: ExternalClientConfig + Clone,
+{
+ /// Create a snapshot of this group's current internal state.
+ pub fn snapshot(&self) -> ExternalSnapshot {
+ ExternalSnapshot {
+ state: RawGroupState::export(self.group_state()),
+ version: 1,
+ signing_data: self.signing_data.clone(),
+ }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn from_snapshot(
+ config: C,
+ snapshot: ExternalSnapshot,
+ ) -> Result<Self, MlsError> {
+ #[cfg(feature = "tree_index")]
+ let identity_provider = config.identity_provider();
+
+ let cipher_suite_provider = cipher_suite_provider(
+ config.crypto_provider(),
+ snapshot.state.context.cipher_suite,
+ )?;
+
+ Ok(ExternalGroup {
+ config,
+ signing_data: snapshot.signing_data,
+ state: snapshot
+ .state
+ .import(
+ #[cfg(feature = "tree_index")]
+ &identity_provider,
+ )
+ .await?,
+ cipher_suite_provider,
+ })
+ }
+}
+
+impl From<CommitMessageDescription> for ExternalReceivedMessage {
+ fn from(value: CommitMessageDescription) -> Self {
+ ExternalReceivedMessage::Commit(value)
+ }
+}
+
+impl TryFrom<ApplicationMessageDescription> for ExternalReceivedMessage {
+ type Error = MlsError;
+
+ fn try_from(_: ApplicationMessageDescription) -> Result<Self, Self::Error> {
+ Err(MlsError::UnencryptedApplicationMessage)
+ }
+}
+
+impl From<ProposalMessageDescription> for ExternalReceivedMessage {
+ fn from(value: ProposalMessageDescription) -> Self {
+ ExternalReceivedMessage::Proposal(value)
+ }
+}
+
+impl From<GroupInfo> for ExternalReceivedMessage {
+ fn from(value: GroupInfo) -> Self {
+ ExternalReceivedMessage::GroupInfo(value)
+ }
+}
+
+impl From<Welcome> for ExternalReceivedMessage {
+ fn from(_: Welcome) -> Self {
+ ExternalReceivedMessage::Welcome
+ }
+}
+
+impl From<KeyPackage> for ExternalReceivedMessage {
+ fn from(value: KeyPackage) -> Self {
+ ExternalReceivedMessage::KeyPackage(value)
+ }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+ use crate::{
+ external_client::tests_utils::{TestExternalClientBuilder, TestExternalClientConfig},
+ group::test_utils::TestGroup,
+ };
+
+ use super::ExternalGroup;
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn make_external_group(
+ group: &TestGroup,
+ ) -> ExternalGroup<TestExternalClientConfig> {
+ make_external_group_with_config(
+ group,
+ TestExternalClientBuilder::new_for_test().build_config(),
+ )
+ .await
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn make_external_group_with_config(
+ group: &TestGroup,
+ config: TestExternalClientConfig,
+ ) -> ExternalGroup<TestExternalClientConfig> {
+ ExternalGroup::join(
+ config,
+ None,
+ group
+ .group
+ .group_info_message_allowing_ext_commit(true)
+ .await
+ .unwrap(),
+ None,
+ )
+ .await
+ .unwrap()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::test_utils::make_external_group;
+ use crate::{
+ cipher_suite::CipherSuite,
+ client::{
+ test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+ MlsError,
+ },
+ crypto::{test_utils::TestCryptoProvider, SignatureSecretKey},
+ extension::ExternalSendersExt,
+ external_client::{
+ group::test_utils::make_external_group_with_config,
+ tests_utils::{TestExternalClientBuilder, TestExternalClientConfig},
+ ExternalGroup, ExternalReceivedMessage, ExternalSnapshot,
+ },
+ group::{
+ framing::{Content, MlsMessagePayload},
+ proposal::{AddProposal, Proposal, ProposalOrRef},
+ proposal_ref::ProposalRef,
+ test_utils::{test_group, TestGroup},
+ ProposalMessageDescription,
+ },
+ identity::{test_utils::get_test_signing_identity, SigningIdentity},
+ key_package::test_utils::{test_key_package, test_key_package_message},
+ protocol_version::ProtocolVersion,
+ ExtensionList, MlsMessage,
+ };
+ use assert_matches::assert_matches;
+ use mls_rs_codec::{MlsDecode, MlsEncode};
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_group_with_one_commit(v: ProtocolVersion, cs: CipherSuite) -> TestGroup {
+ let mut group = test_group(v, cs).await;
+ group.group.commit(Vec::new()).await.unwrap();
+ group.process_pending_commit().await.unwrap();
+ group
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_group_two_members(
+ v: ProtocolVersion,
+ cs: CipherSuite,
+ #[cfg(feature = "by_ref_proposal")] ext_identity: Option<SigningIdentity>,
+ ) -> TestGroup {
+ let mut group = test_group_with_one_commit(v, cs).await;
+
+ let bob_key_package = test_key_package_message(v, cs, "bob").await;
+
+ let mut commit_builder = group
+ .group
+ .commit_builder()
+ .add_member(bob_key_package)
+ .unwrap();
+
+ #[cfg(feature = "by_ref_proposal")]
+ if let Some(ext_signer) = ext_identity {
+ let mut ext_list = ExtensionList::new();
+
+ ext_list
+ .set_from(ExternalSendersExt {
+ allowed_senders: vec![ext_signer],
+ })
+ .unwrap();
+
+ commit_builder = commit_builder.set_group_context_ext(ext_list).unwrap();
+ }
+
+ commit_builder.build().await.unwrap();
+
+ group.process_pending_commit().await.unwrap();
+ group
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn external_group_can_be_created() {
+ for (v, cs) in ProtocolVersion::all().flat_map(|v| {
+ TestCryptoProvider::all_supported_cipher_suites()
+ .into_iter()
+ .map(move |cs| (v, cs))
+ }) {
+ make_external_group(&test_group_with_one_commit(v, cs).await).await;
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn external_group_can_process_commit() {
+ let mut alice = test_group_with_one_commit(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let mut server = make_external_group(&alice).await;
+ let commit_output = alice.group.commit(Vec::new()).await.unwrap();
+ alice.group.apply_pending_commit().await.unwrap();
+
+ server
+ .process_incoming_message(commit_output.commit_message)
+ .await
+ .unwrap();
+
+ assert_eq!(alice.group.state, server.state);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn external_group_can_process_proposals_by_reference() {
+ let mut alice = test_group_with_one_commit(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let mut server = make_external_group(&alice).await;
+
+ let bob_key_package =
+ test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
+
+ let add_proposal = Proposal::Add(Box::new(AddProposal {
+ key_package: bob_key_package,
+ }));
+
+ let packet = alice.propose(add_proposal.clone()).await;
+
+ let proposal_process = server.process_incoming_message(packet).await.unwrap();
+
+ assert_matches!(
+ proposal_process,
+ ExternalReceivedMessage::Proposal(ProposalMessageDescription { ref proposal, ..}) if proposal == &add_proposal
+ );
+
+ let commit_output = alice.group.commit(vec![]).await.unwrap();
+ alice.group.apply_pending_commit().await.unwrap();
+
+ let commit_result = server
+ .process_incoming_message(commit_output.commit_message)
+ .await
+ .unwrap();
+
+ #[cfg(feature = "state_update")]
+ assert_matches!(
+ commit_result,
+ ExternalReceivedMessage::Commit(commit_description)
+ if commit_description.state_update.roster_update.added().iter().any(|added| added.index == 1)
+ );
+
+ #[cfg(not(feature = "state_update"))]
+ assert_matches!(commit_result, ExternalReceivedMessage::Commit(_));
+
+ assert_eq!(alice.group.state, server.state);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn external_group_can_process_commit_adding_member() {
+ let mut alice = test_group_with_one_commit(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let mut server = make_external_group(&alice).await;
+ let (_, commit) = alice.join("bob").await;
+
+ let update = match server.process_incoming_message(commit).await.unwrap() {
+ ExternalReceivedMessage::Commit(update) => update.state_update,
+ _ => panic!("Expected processed commit"),
+ };
+
+ #[cfg(feature = "state_update")]
+ assert_eq!(update.roster_update.added().len(), 1);
+
+ assert_eq!(server.state.public_tree.get_leaf_nodes().len(), 2);
+
+ assert_eq!(alice.group.state, server.state);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn external_group_rejects_commit_not_for_current_epoch() {
+ let mut alice = test_group_with_one_commit(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let mut server = make_external_group(&alice).await;
+
+ let mut commit_output = alice.group.commit(vec![]).await.unwrap();
+
+ match commit_output.commit_message.payload {
+ MlsMessagePayload::Plain(ref mut plain) => plain.content.epoch = 0,
+ _ => panic!("Unexpected non-plaintext data"),
+ };
+
+ let res = server
+ .process_incoming_message(commit_output.commit_message)
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidEpoch));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn external_group_can_reject_message_with_invalid_signature() {
+ let mut alice = test_group_with_one_commit(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ let mut server = make_external_group_with_config(
+ &alice,
+ TestExternalClientBuilder::new_for_test().build_config(),
+ )
+ .await;
+
+ let mut commit_output = alice.group.commit(Vec::new()).await.unwrap();
+
+ match commit_output.commit_message.payload {
+ MlsMessagePayload::Plain(ref mut plain) => plain.auth.signature = Vec::new().into(),
+ _ => panic!("Unexpected non-plaintext data"),
+ };
+
+ let res = server
+ .process_incoming_message(commit_output.commit_message)
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidSignature));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn external_group_rejects_unencrypted_application_message() {
+ let mut alice = test_group_with_one_commit(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let mut server = make_external_group(&alice).await;
+
+ let plaintext = alice
+ .make_plaintext(Content::Application(b"hello".to_vec().into()))
+ .await;
+
+ let res = server.process_incoming_message(plaintext).await;
+
+ assert_matches!(res, Err(MlsError::UnencryptedApplicationMessage));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn external_group_will_reject_unsupported_cipher_suites() {
+ let alice = test_group_with_one_commit(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ let config =
+ TestExternalClientBuilder::new_for_test_disabling_cipher_suite(TEST_CIPHER_SUITE)
+ .build_config();
+
+ let res = ExternalGroup::join(
+ config,
+ None,
+ alice
+ .group
+ .group_info_message_allowing_ext_commit(true)
+ .await
+ .unwrap(),
+ None,
+ )
+ .await
+ .map(|_| ());
+
+ assert_matches!(
+ res,
+ Err(MlsError::UnsupportedCipherSuite(TEST_CIPHER_SUITE))
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn external_group_will_reject_unsupported_protocol_versions() {
+ let alice = test_group_with_one_commit(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ let config = TestExternalClientBuilder::new_for_test().build_config();
+
+ let mut group_info = alice
+ .group
+ .group_info_message_allowing_ext_commit(true)
+ .await
+ .unwrap();
+
+ group_info.version = ProtocolVersion::from(64);
+
+ let res = ExternalGroup::join(config, None, group_info, None)
+ .await
+ .map(|_| ());
+
+ assert_matches!(
+ res,
+ Err(MlsError::UnsupportedProtocolVersion(v)) if v ==
+ ProtocolVersion::from(64)
+ );
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn setup_extern_proposal_test(
+ extern_proposals_allowed: bool,
+ ) -> (SigningIdentity, SignatureSecretKey, TestGroup) {
+ let (server_identity, server_key) =
+ get_test_signing_identity(TEST_CIPHER_SUITE, b"server").await;
+
+ let alice = test_group_two_members(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ extern_proposals_allowed.then(|| server_identity.clone()),
+ )
+ .await;
+
+ (server_identity, server_key, alice)
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_external_proposal(
+ server: &mut ExternalGroup<TestExternalClientConfig>,
+ alice: &mut TestGroup,
+ external_proposal: MlsMessage,
+ ) {
+ let auth_content = external_proposal.clone().into_plaintext().unwrap().into();
+
+ let proposal_ref = ProposalRef::from_content(&server.cipher_suite_provider, &auth_content)
+ .await
+ .unwrap();
+
+ // Alice receives the proposal
+ alice.process_message(external_proposal).await.unwrap();
+
+ // Alice commits the proposal
+ let commit_output = alice.group.commit(vec![]).await.unwrap();
+
+ let commit = match commit_output
+ .commit_message
+ .clone()
+ .into_plaintext()
+ .unwrap()
+ .content
+ .content
+ {
+ Content::Commit(commit) => commit,
+ _ => panic!("not a commit"),
+ };
+
+ // The proposal should be in the resulting commit
+ assert!(commit
+ .proposals
+ .contains(&ProposalOrRef::Reference(proposal_ref)));
+
+ alice.process_pending_commit().await.unwrap();
+
+ server
+ .process_incoming_message(commit_output.commit_message)
+ .await
+ .unwrap();
+
+ assert_eq!(alice.group.state, server.state);
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn external_group_can_propose_add() {
+ let (server_identity, server_key, mut alice) = setup_extern_proposal_test(true).await;
+
+ let mut server = make_external_group(&alice).await;
+
+ server.signing_data = Some((server_key, server_identity));
+
+ let charlie_key_package =
+ test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "charlie").await;
+
+ let external_proposal = server
+ .propose_add(charlie_key_package, vec![])
+ .await
+ .unwrap();
+
+ test_external_proposal(&mut server, &mut alice, external_proposal).await
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn external_group_can_propose_remove() {
+ let (server_identity, server_key, mut alice) = setup_extern_proposal_test(true).await;
+
+ let mut server = make_external_group(&alice).await;
+
+ server.signing_data = Some((server_key, server_identity));
+
+ let external_proposal = server.propose_remove(1, vec![]).await.unwrap();
+
+ test_external_proposal(&mut server, &mut alice, external_proposal).await
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn external_group_external_proposal_not_allowed() {
+ let (signing_id, secret_key, alice) = setup_extern_proposal_test(false).await;
+ let mut server = make_external_group(&alice).await;
+
+ server.signing_data = Some((secret_key, signing_id));
+
+ let charlie_key_package =
+ test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "charlie").await;
+
+ let res = server.propose_add(charlie_key_package, vec![]).await;
+
+ assert_matches!(res, Err(MlsError::ExternalProposalsDisabled));
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn external_group_external_signing_identity_invalid() {
+ let (server_identity, server_key) =
+ get_test_signing_identity(TEST_CIPHER_SUITE, b"server").await;
+
+ let alice = test_group_two_members(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ Some(
+ get_test_signing_identity(TEST_CIPHER_SUITE, b"not server")
+ .await
+ .0,
+ ),
+ )
+ .await;
+
+ let mut server = make_external_group(&alice).await;
+
+ server.signing_data = Some((server_key, server_identity));
+
+ let res = server.propose_remove(1, vec![]).await;
+
+ assert_matches!(res, Err(MlsError::InvalidExternalSigningIdentity));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn external_group_errors_on_old_epoch() {
+ let mut alice = test_group_with_one_commit(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ let mut server = make_external_group_with_config(
+ &alice,
+ TestExternalClientBuilder::new_for_test()
+ .max_epoch_jitter(0)
+ .build_config(),
+ )
+ .await;
+
+ let old_application_msg = alice
+ .group
+ .encrypt_application_message(&[], vec![])
+ .await
+ .unwrap();
+
+ let commit_output = alice.group.commit(vec![]).await.unwrap();
+
+ server
+ .process_incoming_message(commit_output.commit_message)
+ .await
+ .unwrap();
+
+ let res = server.process_incoming_message(old_application_msg).await;
+
+ assert_matches!(res, Err(MlsError::InvalidEpoch));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn proposals_can_be_cached_externally() {
+ let mut alice = test_group_with_one_commit(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ let mut server = make_external_group_with_config(
+ &alice,
+ TestExternalClientBuilder::new_for_test()
+ .cache_proposals(false)
+ .build_config(),
+ )
+ .await;
+
+ let proposal = alice.group.propose_update(vec![]).await.unwrap();
+
+ let commit_output = alice.group.commit(vec![]).await.unwrap();
+
+ server
+ .process_incoming_message(proposal.clone())
+ .await
+ .unwrap();
+
+ server.insert_proposal_from_message(proposal).await.unwrap();
+
+ server
+ .process_incoming_message(commit_output.commit_message)
+ .await
+ .unwrap();
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn external_group_can_observe_since_creation() {
+ let mut alice = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ let info = alice
+ .group
+ .group_info_message_allowing_ext_commit(true)
+ .await
+ .unwrap();
+
+ let config = TestExternalClientBuilder::new_for_test().build_config();
+ let mut server = ExternalGroup::join(config, None, info, None).await.unwrap();
+
+ for _ in 0..2 {
+ let commit = alice.group.commit(vec![]).await.unwrap().commit_message;
+ alice.process_pending_commit().await.unwrap();
+ server.process_incoming_message(commit).await.unwrap();
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn external_group_can_be_serialized_to_tls_encoding() {
+ let server =
+ make_external_group(&test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await).await;
+
+ let snapshot = server.snapshot().mls_encode_to_vec().unwrap();
+ let snapshot_restored = ExternalSnapshot::mls_decode(&mut snapshot.as_slice()).unwrap();
+
+ let server_restored =
+ ExternalGroup::from_snapshot(server.config.clone(), snapshot_restored)
+ .await
+ .unwrap();
+
+ assert_eq!(server.group_state(), server_restored.group_state());
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn external_group_can_validate_info() {
+ let alice = test_group_with_one_commit(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let mut server = make_external_group(&alice).await;
+
+ let info = alice
+ .group
+ .group_info_message_allowing_ext_commit(false)
+ .await
+ .unwrap();
+
+ let update = server.process_incoming_message(info.clone()).await.unwrap();
+ let info = info.into_group_info().unwrap();
+
+ assert_matches!(update, ExternalReceivedMessage::GroupInfo(update_info) if update_info == info);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn external_group_can_validate_key_package() {
+ let alice = test_group_with_one_commit(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let mut server = make_external_group(&alice).await;
+
+ let kp = test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "john").await;
+
+ let update = server.process_incoming_message(kp.clone()).await.unwrap();
+ let kp = kp.into_key_package().unwrap();
+
+ assert_matches!(update, ExternalReceivedMessage::KeyPackage(update_kp) if update_kp == kp);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn external_group_can_validate_welcome() {
+ let mut alice = test_group_with_one_commit(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let mut server = make_external_group(&alice).await;
+
+ let [welcome] = alice
+ .group
+ .commit_builder()
+ .add_member(
+ test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "john").await,
+ )
+ .unwrap()
+ .build()
+ .await
+ .unwrap()
+ .welcome_messages
+ .try_into()
+ .unwrap();
+
+ let update = server.process_incoming_message(welcome).await.unwrap();
+
+ assert_matches!(update, ExternalReceivedMessage::Welcome);
+ }
+}
diff --git a/src/grease.rs b/src/grease.rs
new file mode 100644
index 0000000..cd4f208
--- /dev/null
+++ b/src/grease.rs
@@ -0,0 +1,227 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use mls_rs_core::{crypto::CipherSuiteProvider, extension::ExtensionList, group::Capabilities};
+
+use crate::{
+ client::MlsError,
+ group::{GroupInfo, NewMemberInfo},
+ key_package::KeyPackage,
+ tree_kem::leaf_node::LeafNode,
+};
+
+impl LeafNode {
+ pub fn ungreased_capabilities(&self) -> Capabilities {
+ let mut capabilitites = self.capabilities.clone();
+ grease_functions::ungrease(&mut capabilitites.cipher_suites);
+ grease_functions::ungrease(&mut capabilitites.extensions);
+ grease_functions::ungrease(&mut capabilitites.proposals);
+ grease_functions::ungrease(&mut capabilitites.credentials);
+ capabilitites
+ }
+
+ pub fn ungreased_extensions(&self) -> ExtensionList {
+ let mut extensions = self.extensions.clone();
+ grease_functions::ungrease_extensions(&mut extensions);
+ extensions
+ }
+
+ pub fn grease<P: CipherSuiteProvider>(&mut self, cs: &P) -> Result<(), MlsError> {
+ grease_functions::grease(&mut self.capabilities.cipher_suites, cs)?;
+ grease_functions::grease(&mut self.capabilities.proposals, cs)?;
+ grease_functions::grease(&mut self.capabilities.credentials, cs)?;
+
+ let mut new_extensions = grease_functions::grease_extensions(&mut self.extensions, cs)?;
+ self.capabilities.extensions.append(&mut new_extensions);
+
+ Ok(())
+ }
+}
+
+impl KeyPackage {
+ pub fn grease<P: CipherSuiteProvider>(&mut self, cs: &P) -> Result<(), MlsError> {
+ grease_functions::grease_extensions(&mut self.extensions, cs).map(|_| ())
+ }
+
+ pub fn ungreased_extensions(&self) -> ExtensionList {
+ let mut extensions = self.extensions.clone();
+ grease_functions::ungrease_extensions(&mut extensions);
+ extensions
+ }
+}
+
+impl GroupInfo {
+ pub fn grease<P: CipherSuiteProvider>(&mut self, cs: &P) -> Result<(), MlsError> {
+ grease_functions::grease_extensions(&mut self.extensions, cs).map(|_| ())
+ }
+}
+
+impl NewMemberInfo {
+ pub fn ungrease(&mut self) {
+ grease_functions::ungrease_extensions(&mut self.group_info_extensions)
+ }
+}
+
+#[cfg(feature = "grease")]
+mod grease_functions {
+ use core::ops::Deref;
+
+ use mls_rs_core::{
+ crypto::CipherSuiteProvider,
+ error::IntoAnyError,
+ extension::{Extension, ExtensionList, ExtensionType},
+ };
+
+ use super::MlsError;
+
+ pub const GREASE_VALUES: &[u16] = &[
+ 0x0A0A, 0x1A1A, 0x2A2A, 0x3A3A, 0x4A4A, 0x5A5A, 0x6A6A, 0x7A7A, 0x8A8A, 0x9A9A, 0xAAAA,
+ 0xBABA, 0xCACA, 0xDADA, 0xEAEA,
+ ];
+
+ pub fn grease<T: From<u16>, P: CipherSuiteProvider>(
+ array: &mut Vec<T>,
+ cs: &P,
+ ) -> Result<(), MlsError> {
+ array.push(random_grease_value(cs)?.into());
+ Ok(())
+ }
+
+ pub fn grease_extensions<P: CipherSuiteProvider>(
+ extensions: &mut ExtensionList,
+ cs: &P,
+ ) -> Result<Vec<ExtensionType>, MlsError> {
+ let grease_value = random_grease_value(cs)?;
+ extensions.set(Extension::new(grease_value.into(), vec![]));
+ Ok(vec![grease_value.into()])
+ }
+
+ fn random_grease_value<P: CipherSuiteProvider>(cs: &P) -> Result<u16, MlsError> {
+ let index = cs
+ .random_bytes_vec(1)
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?[0];
+
+ Ok(GREASE_VALUES[index as usize % GREASE_VALUES.len()])
+ }
+
+ pub fn ungrease<T: Deref<Target = u16>>(array: &mut Vec<T>) {
+ array.retain(|x| !GREASE_VALUES.contains(&**x));
+ }
+
+ pub fn ungrease_extensions(extensions: &mut ExtensionList) {
+ for e in GREASE_VALUES {
+ extensions.remove((*e).into())
+ }
+ }
+}
+
+#[cfg(not(feature = "grease"))]
+mod grease_functions {
+ use core::ops::Deref;
+
+ use alloc::vec::Vec;
+
+ use mls_rs_core::{
+ crypto::CipherSuiteProvider,
+ extension::{ExtensionList, ExtensionType},
+ };
+
+ use super::MlsError;
+
+ pub fn grease<T: From<u16>, P: CipherSuiteProvider>(
+ _array: &mut [T],
+ _cs: &P,
+ ) -> Result<(), MlsError> {
+ Ok(())
+ }
+
+ pub fn grease_extensions<P: CipherSuiteProvider>(
+ _extensions: &mut ExtensionList,
+ _cs: &P,
+ ) -> Result<Vec<ExtensionType>, MlsError> {
+ Ok(Vec::new())
+ }
+
+ pub fn ungrease<T: Deref<Target = u16>>(_array: &mut [T]) {}
+
+ pub fn ungrease_extensions(_extensions: &mut ExtensionList) {}
+}
+
+#[cfg(all(test, feature = "grease"))]
+mod tests {
+ #[cfg(target_arch = "wasm32")]
+ use wasm_bindgen_test::wasm_bindgen_test as test;
+
+ use std::ops::Deref;
+
+ use mls_rs_core::extension::ExtensionList;
+
+ use crate::{
+ client::test_utils::{test_client_with_key_pkg, TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+ group::test_utils::test_group,
+ };
+
+ use super::grease_functions::GREASE_VALUES;
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn key_package_is_greased() {
+ let key_pkg = test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice")
+ .await
+ .1
+ .into_key_package()
+ .unwrap();
+
+ assert!(is_ext_greased(&key_pkg.extensions));
+ assert!(is_ext_greased(&key_pkg.leaf_node.extensions));
+ assert!(is_greased(&key_pkg.leaf_node.capabilities.cipher_suites));
+ assert!(is_greased(&key_pkg.leaf_node.capabilities.extensions));
+ assert!(is_greased(&key_pkg.leaf_node.capabilities.proposals));
+ assert!(is_greased(&key_pkg.leaf_node.capabilities.credentials));
+
+ assert!(!is_greased(
+ &key_pkg.leaf_node.capabilities.protocol_versions
+ ));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn group_info_is_greased() {
+ let group_info = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE)
+ .await
+ .group
+ .group_info_message_allowing_ext_commit(false)
+ .await
+ .unwrap()
+ .into_group_info()
+ .unwrap();
+
+ assert!(is_ext_greased(&group_info.extensions));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn public_api_is_not_greased() {
+ let member = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE)
+ .await
+ .group
+ .roster()
+ .member_with_index(0)
+ .unwrap();
+
+ assert!(!is_ext_greased(member.extensions()));
+ assert!(!is_greased(member.capabilities().protocol_versions()));
+ assert!(!is_greased(member.capabilities().cipher_suites()));
+ assert!(!is_greased(member.capabilities().extensions()));
+ assert!(!is_greased(member.capabilities().proposals()));
+ assert!(!is_greased(member.capabilities().credentials()));
+ }
+
+ fn is_greased<T: Deref<Target = u16>>(list: &[T]) -> bool {
+ list.iter().any(|v| GREASE_VALUES.contains(v))
+ }
+
+ fn is_ext_greased(extensions: &ExtensionList) -> bool {
+ extensions
+ .iter()
+ .any(|ext| GREASE_VALUES.contains(&*ext.extension_type()))
+ }
+}
diff --git a/src/group/ciphertext_processor.rs b/src/group/ciphertext_processor.rs
new file mode 100644
index 0000000..bf70f5d
--- /dev/null
+++ b/src/group/ciphertext_processor.rs
@@ -0,0 +1,410 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use self::{
+ message_key::MessageKey,
+ reuse_guard::ReuseGuard,
+ sender_data_key::{SenderData, SenderDataAAD, SenderDataKey},
+};
+
+use super::{
+ epoch::EpochSecrets,
+ framing::{ContentType, FramedContent, Sender, WireFormat},
+ message_signature::AuthenticatedContent,
+ padding::PaddingMode,
+ secret_tree::{KeyType, MessageKeyData},
+ GroupContext,
+};
+use crate::{
+ client::MlsError,
+ tree_kem::node::{LeafIndex, NodeIndex},
+};
+use mls_rs_codec::MlsEncode;
+use mls_rs_core::{crypto::CipherSuiteProvider, error::IntoAnyError};
+use zeroize::Zeroizing;
+
+mod message_key;
+mod reuse_guard;
+mod sender_data_key;
+
+#[cfg(feature = "private_message")]
+use super::framing::{PrivateContentAAD, PrivateMessage, PrivateMessageContent};
+
+#[cfg(test)]
+pub use sender_data_key::test_utils::*;
+
+pub(crate) trait GroupStateProvider {
+ fn group_context(&self) -> &GroupContext;
+ fn self_index(&self) -> LeafIndex;
+ fn epoch_secrets_mut(&mut self) -> &mut EpochSecrets;
+ fn epoch_secrets(&self) -> &EpochSecrets;
+}
+
+pub(crate) struct CiphertextProcessor<'a, GS, CP>
+where
+ GS: GroupStateProvider,
+ CP: CipherSuiteProvider,
+{
+ group_state: &'a mut GS,
+ cipher_suite_provider: CP,
+}
+
+impl<'a, GS, CP> CiphertextProcessor<'a, GS, CP>
+where
+ GS: GroupStateProvider,
+ CP: CipherSuiteProvider,
+{
+ pub fn new(
+ group_state: &'a mut GS,
+ cipher_suite_provider: CP,
+ ) -> CiphertextProcessor<'a, GS, CP> {
+ Self {
+ group_state,
+ cipher_suite_provider,
+ }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn next_encryption_key(
+ &mut self,
+ key_type: KeyType,
+ ) -> Result<MessageKeyData, MlsError> {
+ let self_index = NodeIndex::from(self.group_state.self_index());
+
+ self.group_state
+ .epoch_secrets_mut()
+ .secret_tree
+ .next_message_key(&self.cipher_suite_provider, self_index, key_type)
+ .await
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn decryption_key(
+ &mut self,
+ sender: LeafIndex,
+ key_type: KeyType,
+ generation: u32,
+ ) -> Result<MessageKeyData, MlsError> {
+ let sender = NodeIndex::from(sender);
+
+ self.group_state
+ .epoch_secrets_mut()
+ .secret_tree
+ .message_key_generation(&self.cipher_suite_provider, sender, key_type, generation)
+ .await
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn seal(
+ &mut self,
+ auth_content: AuthenticatedContent,
+ padding: PaddingMode,
+ ) -> Result<PrivateMessage, MlsError> {
+ if Sender::Member(*self.group_state.self_index()) != auth_content.content.sender {
+ return Err(MlsError::InvalidSender);
+ }
+
+ let content_type = ContentType::from(&auth_content.content.content);
+ let authenticated_data = auth_content.content.authenticated_data;
+
+ // Build a ciphertext content using the plaintext content and signature
+ let private_content = PrivateMessageContent {
+ content: auth_content.content.content,
+ auth: auth_content.auth,
+ };
+
+ // Build ciphertext aad using the plaintext message
+ let aad = PrivateContentAAD {
+ group_id: auth_content.content.group_id,
+ epoch: auth_content.content.epoch,
+ content_type,
+ authenticated_data: authenticated_data.clone(),
+ };
+
+ // Generate a 4 byte reuse guard
+ let reuse_guard = ReuseGuard::random(&self.cipher_suite_provider)
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+ // Grab an encryption key from the current epoch's key schedule
+ let key_type = match &content_type {
+ ContentType::Application => KeyType::Application,
+ _ => KeyType::Handshake,
+ };
+
+ let mut serialized_private_content = private_content.mls_encode_to_vec()?;
+
+ // Apply padding to private content based on the current padding mode.
+ serialized_private_content.resize(padding.padded_size(serialized_private_content.len()), 0);
+
+ let serialized_private_content = Zeroizing::new(serialized_private_content);
+
+ // Encrypt the ciphertext content using the encryption key and a nonce that is
+ // reuse safe by xor the reuse guard with the first 4 bytes
+ let self_index = self.group_state.self_index();
+
+ let key_data = self.next_encryption_key(key_type).await?;
+ let generation = key_data.generation;
+
+ let ciphertext = MessageKey::new(key_data)
+ .encrypt(
+ &self.cipher_suite_provider,
+ &serialized_private_content,
+ &aad.mls_encode_to_vec()?,
+ &reuse_guard,
+ )
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+ // Construct an mls sender data struct using the plaintext sender info, the generation
+ // of the key schedule encryption key, and the reuse guard used to encrypt ciphertext
+ let sender_data = SenderData {
+ sender: self_index,
+ generation,
+ reuse_guard,
+ };
+
+ let sender_data_aad = SenderDataAAD {
+ group_id: self.group_state.group_context().group_id.clone(),
+ epoch: self.group_state.group_context().epoch,
+ content_type,
+ };
+
+ // Encrypt the sender data with the derived sender_key and sender_nonce from the current
+ // epoch's key schedule
+ let sender_data_key = SenderDataKey::new(
+ &self.group_state.epoch_secrets().sender_data_secret,
+ &ciphertext,
+ &self.cipher_suite_provider,
+ )
+ .await?;
+
+ let encrypted_sender_data = sender_data_key.seal(&sender_data, &sender_data_aad).await?;
+
+ Ok(PrivateMessage {
+ group_id: self.group_state.group_context().group_id.clone(),
+ epoch: self.group_state.group_context().epoch,
+ content_type,
+ authenticated_data,
+ encrypted_sender_data,
+ ciphertext,
+ })
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn open(
+ &mut self,
+ ciphertext: &PrivateMessage,
+ ) -> Result<AuthenticatedContent, MlsError> {
+ // Decrypt the sender data with the derived sender_key and sender_nonce from the message
+ // epoch's key schedule
+ let sender_data_aad = SenderDataAAD {
+ group_id: self.group_state.group_context().group_id.clone(),
+ epoch: self.group_state.group_context().epoch,
+ content_type: ciphertext.content_type,
+ };
+
+ let sender_data_key = SenderDataKey::new(
+ &self.group_state.epoch_secrets().sender_data_secret,
+ &ciphertext.ciphertext,
+ &self.cipher_suite_provider,
+ )
+ .await?;
+
+ let sender_data = sender_data_key
+ .open(&ciphertext.encrypted_sender_data, &sender_data_aad)
+ .await?;
+
+ if self.group_state.self_index() == sender_data.sender {
+ return Err(MlsError::CantProcessMessageFromSelf);
+ }
+
+ // Grab a decryption key from the message epoch's key schedule
+ let key_type = match &ciphertext.content_type {
+ ContentType::Application => KeyType::Application,
+ _ => KeyType::Handshake,
+ };
+
+ // Decrypt the content of the message using the grabbed key
+ let key = self
+ .decryption_key(sender_data.sender, key_type, sender_data.generation)
+ .await?;
+
+ let sender = Sender::Member(*sender_data.sender);
+
+ let decrypted_content = MessageKey::new(key)
+ .decrypt(
+ &self.cipher_suite_provider,
+ &ciphertext.ciphertext,
+ &PrivateContentAAD::from(ciphertext).mls_encode_to_vec()?,
+ &sender_data.reuse_guard,
+ )
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+ let ciphertext_content =
+ PrivateMessageContent::mls_decode(&mut &**decrypted_content, ciphertext.content_type)?;
+
+ // Build the MLS plaintext object and process it
+ let auth_content = AuthenticatedContent {
+ wire_format: WireFormat::PrivateMessage,
+ content: FramedContent {
+ group_id: ciphertext.group_id.clone(),
+ epoch: ciphertext.epoch,
+ sender,
+ authenticated_data: ciphertext.authenticated_data.clone(),
+ content: ciphertext_content.content,
+ },
+ auth: ciphertext_content.auth,
+ };
+
+ Ok(auth_content)
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use crate::{
+ cipher_suite::CipherSuite,
+ client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+ crypto::{
+ test_utils::{test_cipher_suite_provider, TestCryptoProvider},
+ CipherSuiteProvider,
+ },
+ group::{
+ framing::{ApplicationData, Content, Sender, WireFormat},
+ message_signature::AuthenticatedContent,
+ padding::PaddingMode,
+ test_utils::{random_bytes, test_group, TestGroup},
+ },
+ tree_kem::node::LeafIndex,
+ };
+
+ use super::{CiphertextProcessor, GroupStateProvider, MlsError};
+
+ use alloc::vec;
+ use assert_matches::assert_matches;
+
+ struct TestData {
+ group: TestGroup,
+ content: AuthenticatedContent,
+ }
+
+ fn test_processor(
+ group: &mut TestGroup,
+ cipher_suite: CipherSuite,
+ ) -> CiphertextProcessor<'_, impl GroupStateProvider, impl CipherSuiteProvider> {
+ CiphertextProcessor::new(&mut group.group, test_cipher_suite_provider(cipher_suite))
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_data(cipher_suite: CipherSuite) -> TestData {
+ let provider = test_cipher_suite_provider(cipher_suite);
+
+ let group = test_group(TEST_PROTOCOL_VERSION, cipher_suite).await;
+
+ let content = AuthenticatedContent::new_signed(
+ &provider,
+ group.group.context(),
+ Sender::Member(0),
+ Content::Application(ApplicationData::from(b"test".to_vec())),
+ &group.group.signer,
+ WireFormat::PrivateMessage,
+ vec![],
+ )
+ .await
+ .unwrap();
+
+ TestData { group, content }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_encrypt_decrypt() {
+ for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+ let mut test_data = test_data(cipher_suite).await;
+ let mut receiver_group = test_data.group.clone();
+
+ let mut ciphertext_processor = test_processor(&mut test_data.group, cipher_suite);
+
+ let ciphertext = ciphertext_processor
+ .seal(test_data.content.clone(), PaddingMode::StepFunction)
+ .await
+ .unwrap();
+
+ receiver_group.group.private_tree.self_index = LeafIndex::new(1);
+
+ let mut receiver_processor = test_processor(&mut receiver_group, cipher_suite);
+
+ let decrypted = receiver_processor.open(&ciphertext).await.unwrap();
+
+ assert_eq!(decrypted, test_data.content);
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_padding_use() {
+ let mut test_data = test_data(TEST_CIPHER_SUITE).await;
+ let mut ciphertext_processor = test_processor(&mut test_data.group, TEST_CIPHER_SUITE);
+
+ let ciphertext_step = ciphertext_processor
+ .seal(test_data.content.clone(), PaddingMode::StepFunction)
+ .await
+ .unwrap();
+
+ let ciphertext_no_pad = ciphertext_processor
+ .seal(test_data.content.clone(), PaddingMode::None)
+ .await
+ .unwrap();
+
+ assert!(ciphertext_step.ciphertext.len() > ciphertext_no_pad.ciphertext.len());
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_invalid_sender() {
+ let mut test_data = test_data(TEST_CIPHER_SUITE).await;
+ test_data.content.content.sender = Sender::Member(3);
+
+ let mut ciphertext_processor = test_processor(&mut test_data.group, TEST_CIPHER_SUITE);
+
+ let res = ciphertext_processor
+ .seal(test_data.content, PaddingMode::None)
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidSender))
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_cant_process_from_self() {
+ let mut test_data = test_data(TEST_CIPHER_SUITE).await;
+
+ let mut ciphertext_processor = test_processor(&mut test_data.group, TEST_CIPHER_SUITE);
+
+ let ciphertext = ciphertext_processor
+ .seal(test_data.content, PaddingMode::None)
+ .await
+ .unwrap();
+
+ let res = ciphertext_processor.open(&ciphertext).await;
+
+ assert_matches!(res, Err(MlsError::CantProcessMessageFromSelf))
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_decryption_error() {
+ let mut test_data = test_data(TEST_CIPHER_SUITE).await;
+ let mut receiver_group = test_data.group.clone();
+ let mut ciphertext_processor = test_processor(&mut test_data.group, TEST_CIPHER_SUITE);
+
+ let mut ciphertext = ciphertext_processor
+ .seal(test_data.content.clone(), PaddingMode::StepFunction)
+ .await
+ .unwrap();
+
+ ciphertext.ciphertext = random_bytes(ciphertext.ciphertext.len());
+ receiver_group.group.private_tree.self_index = LeafIndex::new(1);
+
+ let res = ciphertext_processor.open(&ciphertext).await;
+
+ assert!(res.is_err());
+ }
+}
diff --git a/src/group/ciphertext_processor/message_key.rs b/src/group/ciphertext_processor/message_key.rs
new file mode 100644
index 0000000..256db7d
--- /dev/null
+++ b/src/group/ciphertext_processor/message_key.rs
@@ -0,0 +1,57 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+use zeroize::Zeroizing;
+
+use crate::{crypto::CipherSuiteProvider, group::secret_tree::MessageKeyData};
+
+use super::reuse_guard::ReuseGuard;
+
+#[derive(Debug, PartialEq, Eq)]
+pub struct MessageKey(MessageKeyData);
+
+impl MessageKey {
+ pub(crate) fn new(key: MessageKeyData) -> MessageKey {
+ MessageKey(key)
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn encrypt<P: CipherSuiteProvider>(
+ &self,
+ provider: &P,
+ data: &[u8],
+ aad: &[u8],
+ reuse_guard: &ReuseGuard,
+ ) -> Result<Vec<u8>, P::Error> {
+ provider
+ .aead_seal(
+ &self.0.key,
+ data,
+ Some(aad),
+ &reuse_guard.apply(&self.0.nonce),
+ )
+ .await
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn decrypt<P: CipherSuiteProvider>(
+ &self,
+ provider: &P,
+ data: &[u8],
+ aad: &[u8],
+ reuse_guard: &ReuseGuard,
+ ) -> Result<Zeroizing<Vec<u8>>, P::Error> {
+ provider
+ .aead_open(
+ &self.0.key,
+ data,
+ Some(aad),
+ &reuse_guard.apply(&self.0.nonce),
+ )
+ .await
+ }
+}
+
+// TODO: Write test vectors
diff --git a/src/group/ciphertext_processor/reuse_guard.rs b/src/group/ciphertext_processor/reuse_guard.rs
new file mode 100644
index 0000000..10e1db1
--- /dev/null
+++ b/src/group/ciphertext_processor/reuse_guard.rs
@@ -0,0 +1,133 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+
+use crate::CipherSuiteProvider;
+
+const REUSE_GUARD_SIZE: usize = 4;
+
+#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+pub(crate) struct ReuseGuard([u8; REUSE_GUARD_SIZE]);
+
+impl From<[u8; REUSE_GUARD_SIZE]> for ReuseGuard {
+ fn from(value: [u8; REUSE_GUARD_SIZE]) -> Self {
+ ReuseGuard(value)
+ }
+}
+
+impl From<ReuseGuard> for [u8; REUSE_GUARD_SIZE] {
+ fn from(value: ReuseGuard) -> Self {
+ value.0
+ }
+}
+
+impl AsRef<[u8]> for ReuseGuard {
+ fn as_ref(&self) -> &[u8] {
+ &self.0
+ }
+}
+
+impl ReuseGuard {
+ pub(crate) fn random<P: CipherSuiteProvider>(provider: &P) -> Result<Self, P::Error> {
+ let mut data = [0u8; REUSE_GUARD_SIZE];
+ provider.random_bytes(&mut data).map(|_| ReuseGuard(data))
+ }
+
+ pub(crate) fn apply(&self, nonce: &[u8]) -> Vec<u8> {
+ let mut new_nonce = nonce.to_vec();
+
+ new_nonce
+ .iter_mut()
+ .zip(self.as_ref().iter())
+ .for_each(|(nonce_byte, guard_byte)| *nonce_byte ^= guard_byte);
+
+ new_nonce
+ }
+}
+
+#[cfg(test)]
+mod test_utils {
+ use alloc::vec::Vec;
+
+ use super::{ReuseGuard, REUSE_GUARD_SIZE};
+
+ impl ReuseGuard {
+ pub fn new(guard: Vec<u8>) -> Self {
+ let mut data = [0u8; REUSE_GUARD_SIZE];
+ data.copy_from_slice(&guard);
+ Self(data)
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use alloc::vec::Vec;
+ use mls_rs_core::crypto::CipherSuiteProvider;
+
+ use crate::{
+ client::test_utils::TEST_CIPHER_SUITE, crypto::test_utils::test_cipher_suite_provider,
+ };
+
+ use super::{ReuseGuard, REUSE_GUARD_SIZE};
+
+ #[test]
+ fn test_random_generation() {
+ let test_guard =
+ ReuseGuard::random(&test_cipher_suite_provider(TEST_CIPHER_SUITE)).unwrap();
+
+ (0..1000).for_each(|_| {
+ let next = ReuseGuard::random(&test_cipher_suite_provider(TEST_CIPHER_SUITE)).unwrap();
+ assert_ne!(next, test_guard);
+ })
+ }
+
+ #[derive(Debug, serde::Serialize, serde::Deserialize)]
+ struct TestCase {
+ nonce: Vec<u8>,
+ guard: [u8; REUSE_GUARD_SIZE],
+ result: Vec<u8>,
+ }
+
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn generate_reuse_guard_test_cases() -> Vec<TestCase> {
+ let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ [16, 32]
+ .into_iter()
+ .map(
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ |len| {
+ let nonce = provider.random_bytes_vec(len).unwrap();
+ let guard = ReuseGuard::random(&provider).unwrap();
+
+ let result = guard.apply(&nonce);
+
+ TestCase {
+ nonce,
+ guard: guard.into(),
+ result,
+ }
+ },
+ )
+ .collect()
+ }
+
+ fn load_test_cases() -> Vec<TestCase> {
+ load_test_case_json!(reuse_guard, generate_reuse_guard_test_cases())
+ }
+
+ #[test]
+ fn test_reuse_guard() {
+ let test_cases = load_test_cases();
+
+ for case in test_cases {
+ let guard = ReuseGuard::from(case.guard);
+ let result = guard.apply(&case.nonce);
+ assert_eq!(result, case.result);
+ }
+ }
+}
diff --git a/src/group/ciphertext_processor/sender_data_key.rs b/src/group/ciphertext_processor/sender_data_key.rs
new file mode 100644
index 0000000..983920a
--- /dev/null
+++ b/src/group/ciphertext_processor/sender_data_key.rs
@@ -0,0 +1,360 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::error::IntoAnyError;
+use zeroize::Zeroizing;
+
+use crate::{
+ client::MlsError,
+ crypto::CipherSuiteProvider,
+ group::{epoch::SenderDataSecret, framing::ContentType, key_schedule::kdf_expand_with_label},
+ tree_kem::node::LeafIndex,
+};
+
+use super::ReuseGuard;
+
+#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+pub(crate) struct SenderData {
+ pub sender: LeafIndex,
+ pub generation: u32,
+ pub reuse_guard: ReuseGuard,
+}
+
+#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+pub(crate) struct SenderDataAAD {
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ pub group_id: Vec<u8>,
+ pub epoch: u64,
+ pub content_type: ContentType,
+}
+
+impl Debug for SenderDataAAD {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("SenderDataAAD")
+ .field(
+ "group_id",
+ &mls_rs_core::debug::pretty_group_id(&self.group_id),
+ )
+ .field("epoch", &self.epoch)
+ .field("content_type", &self.content_type)
+ .finish()
+ }
+}
+
+pub(crate) struct SenderDataKey<'a, CP: CipherSuiteProvider> {
+ pub(crate) key: Zeroizing<Vec<u8>>,
+ pub(crate) nonce: Zeroizing<Vec<u8>>,
+ cipher_suite_provider: &'a CP,
+}
+
+impl<CP: CipherSuiteProvider + Debug> Debug for SenderDataKey<'_, CP> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("SenderDataKey")
+ .field("key", &mls_rs_core::debug::pretty_bytes(&self.key))
+ .field("nonce", &mls_rs_core::debug::pretty_bytes(&self.nonce))
+ .field("cipher_suite_provider", self.cipher_suite_provider)
+ .finish()
+ }
+}
+
+impl<'a, CP: CipherSuiteProvider> SenderDataKey<'a, CP> {
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(super) async fn new(
+ sender_data_secret: &SenderDataSecret,
+ ciphertext: &[u8],
+ cipher_suite_provider: &'a CP,
+ ) -> Result<SenderDataKey<'a, CP>, MlsError> {
+ // Sample the first extract_size bytes of the ciphertext, and if it is shorter, just use
+ // the ciphertext itself
+ let extract_size = cipher_suite_provider.kdf_extract_size();
+ let ciphertext_sample = ciphertext.get(0..extract_size).unwrap_or(ciphertext);
+
+ // Generate a sender data key and nonce using the sender_data_secret from the current
+ // epoch's key schedule
+ let key = kdf_expand_with_label(
+ cipher_suite_provider,
+ sender_data_secret,
+ b"key",
+ ciphertext_sample,
+ Some(cipher_suite_provider.aead_key_size()),
+ )
+ .await?;
+
+ let nonce = kdf_expand_with_label(
+ cipher_suite_provider,
+ sender_data_secret,
+ b"nonce",
+ ciphertext_sample,
+ Some(cipher_suite_provider.aead_nonce_size()),
+ )
+ .await?;
+
+ Ok(Self {
+ key,
+ nonce,
+ cipher_suite_provider,
+ })
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn seal(
+ &self,
+ sender_data: &SenderData,
+ aad: &SenderDataAAD,
+ ) -> Result<Vec<u8>, MlsError> {
+ self.cipher_suite_provider
+ .aead_seal(
+ &self.key,
+ &sender_data.mls_encode_to_vec()?,
+ Some(&aad.mls_encode_to_vec()?),
+ &self.nonce,
+ )
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn open(
+ &self,
+ sender_data: &[u8],
+ aad: &SenderDataAAD,
+ ) -> Result<SenderData, MlsError> {
+ self.cipher_suite_provider
+ .aead_open(
+ &self.key,
+ sender_data,
+ Some(&aad.mls_encode_to_vec()?),
+ &self.nonce,
+ )
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+ .and_then(|data| SenderData::mls_decode(&mut &**data).map_err(From::from))
+ }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+ use alloc::vec::Vec;
+ use mls_rs_core::crypto::CipherSuiteProvider;
+
+ use super::SenderDataKey;
+
+ #[derive(Debug, serde::Serialize, serde::Deserialize)]
+ pub struct InteropSenderData {
+ #[serde(with = "hex::serde")]
+ pub sender_data_secret: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub ciphertext: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub key: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub nonce: Vec<u8>,
+ }
+
+ impl InteropSenderData {
+ #[cfg(not(mls_build_async))]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ pub(crate) fn new<P: CipherSuiteProvider>(cs: &P) -> Self {
+ let secret = cs.random_bytes_vec(cs.kdf_extract_size()).unwrap().into();
+ let ciphertext = cs.random_bytes_vec(77).unwrap();
+ let key = SenderDataKey::new(&secret, &ciphertext, cs).unwrap();
+ let secret = (*secret).clone();
+
+ Self {
+ ciphertext,
+ key: key.key.to_vec(),
+ nonce: key.nonce.to_vec(),
+ sender_data_secret: secret,
+ }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn verify<P: CipherSuiteProvider>(&self, cs: &P) {
+ let secret = self.sender_data_secret.clone().into();
+
+ let key = SenderDataKey::new(&secret, &self.ciphertext, cs)
+ .await
+ .unwrap();
+
+ assert_eq!(key.key.to_vec(), self.key, "sender data key mismatch");
+ assert_eq!(key.nonce.to_vec(), self.nonce, "sender data nonce mismatch");
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+
+ use alloc::vec::Vec;
+ #[cfg(target_arch = "wasm32")]
+ use wasm_bindgen_test::wasm_bindgen_test as test;
+
+ use crate::{
+ crypto::test_utils::try_test_cipher_suite_provider,
+ group::{ciphertext_processor::reuse_guard::ReuseGuard, framing::ContentType},
+ tree_kem::node::LeafIndex,
+ };
+
+ use super::{SenderData, SenderDataAAD, SenderDataKey};
+
+ #[cfg(not(mls_build_async))]
+ use crate::{
+ cipher_suite::CipherSuite, crypto::test_utils::test_cipher_suite_provider,
+ group::test_utils::random_bytes, CipherSuiteProvider,
+ };
+
+ #[derive(serde::Deserialize, serde::Serialize)]
+ struct TestCase {
+ cipher_suite: u16,
+ #[serde(with = "hex::serde")]
+ secret: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ ciphertext_bytes: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ expected_key: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ expected_nonce: Vec<u8>,
+ sender_data: TestSenderData,
+ sender_data_aad: TestSenderDataAAD,
+ #[serde(with = "hex::serde")]
+ expected_ciphertext: Vec<u8>,
+ }
+
+ #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
+ struct TestSenderData {
+ sender: u32,
+ generation: u32,
+ #[serde(with = "hex::serde")]
+ reuse_guard: Vec<u8>,
+ }
+
+ impl From<TestSenderData> for SenderData {
+ fn from(value: TestSenderData) -> Self {
+ let reuse_guard = ReuseGuard::new(value.reuse_guard);
+
+ Self {
+ sender: LeafIndex(value.sender),
+ generation: value.generation,
+ reuse_guard,
+ }
+ }
+ }
+
+ #[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
+ struct TestSenderDataAAD {
+ epoch: u64,
+ #[serde(with = "hex::serde")]
+ group_id: Vec<u8>,
+ }
+
+ impl From<TestSenderDataAAD> for SenderDataAAD {
+ fn from(value: TestSenderDataAAD) -> Self {
+ Self {
+ epoch: value.epoch,
+ group_id: value.group_id,
+ content_type: ContentType::Application,
+ }
+ }
+ }
+
+ #[cfg(not(mls_build_async))]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn generate_test_vector() -> Vec<TestCase> {
+ let test_cases = CipherSuite::all().map(test_cipher_suite_provider).map(
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ |provider| {
+ let ext_size = provider.kdf_extract_size();
+ let secret = random_bytes(ext_size).into();
+ let ciphertext_sizes = [ext_size - 5, ext_size, ext_size + 5];
+
+ let sender_data = TestSenderData {
+ sender: 0,
+ generation: 13,
+ reuse_guard: random_bytes(4),
+ };
+
+ let sender_data_aad = TestSenderDataAAD {
+ group_id: b"group".to_vec(),
+ epoch: 42,
+ };
+
+ ciphertext_sizes.into_iter().map(
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ move |ciphertext_size| {
+ let ciphertext_bytes = random_bytes(ciphertext_size);
+
+ let sender_data_key =
+ SenderDataKey::new(&secret, &ciphertext_bytes, &provider).unwrap();
+
+ let expected_ciphertext = sender_data_key
+ .seal(&sender_data.clone().into(), &sender_data_aad.clone().into())
+ .unwrap();
+
+ TestCase {
+ cipher_suite: provider.cipher_suite().into(),
+ secret: secret.to_vec(),
+ ciphertext_bytes,
+ expected_key: sender_data_key.key.to_vec(),
+ expected_nonce: sender_data_key.nonce.to_vec(),
+ sender_data: sender_data.clone(),
+ sender_data_aad: sender_data_aad.clone(),
+ expected_ciphertext,
+ }
+ },
+ )
+ },
+ );
+
+ test_cases.flatten().collect()
+ }
+
+ #[cfg(mls_build_async)]
+ fn generate_test_vector() -> Vec<TestCase> {
+ panic!("Tests cannot be generated in async mode");
+ }
+
+ fn load_test_cases() -> Vec<TestCase> {
+ load_test_case_json!(sender_data_key_test_vector, generate_test_vector())
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sender_data_key_test_vector() {
+ for test_case in load_test_cases() {
+ let Some(provider) = try_test_cipher_suite_provider(test_case.cipher_suite) else {
+ continue;
+ };
+
+ let sender_data_key = SenderDataKey::new(
+ &test_case.secret.into(),
+ &test_case.ciphertext_bytes,
+ &provider,
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(sender_data_key.key.to_vec(), test_case.expected_key);
+ assert_eq!(sender_data_key.nonce.to_vec(), test_case.expected_nonce);
+
+ let sender_data = test_case.sender_data.into();
+ let sender_data_aad = test_case.sender_data_aad.into();
+
+ let ciphertext = sender_data_key
+ .seal(&sender_data, &sender_data_aad)
+ .await
+ .unwrap();
+
+ assert_eq!(ciphertext, test_case.expected_ciphertext);
+
+ let plaintext = sender_data_key
+ .open(&ciphertext, &sender_data_aad)
+ .await
+ .unwrap();
+
+ assert_eq!(plaintext, sender_data);
+ }
+ }
+}
diff --git a/src/group/commit.rs b/src/group/commit.rs
new file mode 100644
index 0000000..c201057
--- /dev/null
+++ b/src/group/commit.rs
@@ -0,0 +1,1601 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec;
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::{
+ crypto::{CipherSuiteProvider, SignatureSecretKey},
+ error::IntoAnyError,
+};
+
+use crate::{
+ cipher_suite::CipherSuite,
+ client::MlsError,
+ client_config::ClientConfig,
+ extension::RatchetTreeExt,
+ identity::SigningIdentity,
+ protocol_version::ProtocolVersion,
+ signer::Signable,
+ tree_kem::{
+ kem::TreeKem, node::LeafIndex, path_secret::PathSecret, TreeKemPrivate, UpdatePath,
+ },
+ ExtensionList, MlsRules,
+};
+
+#[cfg(all(not(mls_build_async), feature = "rayon"))]
+use {crate::iter::ParallelIteratorExt, rayon::prelude::*};
+
+use crate::tree_kem::leaf_node::LeafNode;
+
+#[cfg(not(feature = "private_message"))]
+use crate::WireFormat;
+
+#[cfg(feature = "psk")]
+use crate::{
+ group::{JustPreSharedKeyID, PskGroupId, ResumptionPSKUsage, ResumptionPsk},
+ psk::ExternalPskId,
+};
+
+use super::{
+ confirmation_tag::ConfirmationTag,
+ framing::{Content, MlsMessage, MlsMessagePayload, Sender},
+ key_schedule::{KeySchedule, WelcomeSecret},
+ message_processor::{path_update_required, MessageProcessor},
+ message_signature::AuthenticatedContent,
+ mls_rules::CommitDirection,
+ proposal::{Proposal, ProposalOrRef},
+ ConfirmedTranscriptHash, EncryptedGroupSecrets, ExportedTree, Group, GroupContext, GroupInfo,
+ Welcome,
+};
+
+#[cfg(not(feature = "by_ref_proposal"))]
+use super::proposal_cache::prepare_commit;
+
+#[cfg(feature = "custom_proposal")]
+use super::proposal::CustomProposal;
+
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(mls_rs_core::arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct Commit {
+ pub proposals: Vec<ProposalOrRef>,
+ pub path: Option<UpdatePath>,
+}
+
+#[derive(Clone, PartialEq, Debug, MlsEncode, MlsDecode, MlsSize)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(super) struct CommitGeneration {
+ pub content: AuthenticatedContent,
+ pub pending_private_tree: TreeKemPrivate,
+ pub pending_commit_secret: PathSecret,
+ pub commit_message_hash: CommitHash,
+}
+
+#[derive(Clone, PartialEq, MlsEncode, MlsDecode, MlsSize)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct CommitHash(
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ Vec<u8>,
+);
+
+impl Debug for CommitHash {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ mls_rs_core::debug::pretty_bytes(&self.0)
+ .named("CommitHash")
+ .fmt(f)
+ }
+}
+
+impl CommitHash {
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn compute<CS: CipherSuiteProvider>(
+ cs: &CS,
+ commit: &MlsMessage,
+ ) -> Result<Self, MlsError> {
+ cs.hash(&commit.mls_encode_to_vec()?)
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+ .map(Self)
+ }
+}
+
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Clone, Debug)]
+#[non_exhaustive]
+/// Result of MLS commit operation using
+/// [`Group::commit`](crate::group::Group::commit) or
+/// [`CommitBuilder::build`](CommitBuilder::build).
+pub struct CommitOutput {
+ /// Commit message to send to other group members.
+ pub commit_message: MlsMessage,
+ /// Welcome messages to send to new group members. If the commit does not add members,
+ /// this list is empty. Otherwise, if [`MlsRules::commit_options`] returns `single_welcome_message`
+ /// set to true, then this list contains a single message sent to all members. Else, the list
+ /// contains one message for each added member. Recipients of each message can be identified using
+ /// [`MlsMessage::key_package_reference`] of their key packages and
+ /// [`MlsMessage::welcome_key_package_references`].
+ pub welcome_messages: Vec<MlsMessage>,
+ /// Ratchet tree that can be sent out of band if
+ /// `ratchet_tree_extension` is not used according to
+ /// [`MlsRules::commit_options`].
+ pub ratchet_tree: Option<ExportedTree<'static>>,
+ /// A group info that can be provided to new members in order to enable external commit
+ /// functionality. This value is set if [`MlsRules::commit_options`] returns
+ /// `allow_external_commit` set to true.
+ pub external_commit_group_info: Option<MlsMessage>,
+ /// Proposals that were received in the prior epoch but not included in the following commit.
+ #[cfg(feature = "by_ref_proposal")]
+ pub unused_proposals: Vec<crate::mls_rules::ProposalInfo<Proposal>>,
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
+impl CommitOutput {
+ /// Commit message to send to other group members.
+ #[cfg(feature = "ffi")]
+ pub fn commit_message(&self) -> &MlsMessage {
+ &self.commit_message
+ }
+
+ /// Welcome message to send to new group members.
+ #[cfg(feature = "ffi")]
+ pub fn welcome_messages(&self) -> &[MlsMessage] {
+ &self.welcome_messages
+ }
+
+ /// Ratchet tree that can be sent out of band if
+ /// `ratchet_tree_extension` is not used according to
+ /// [`MlsRules::commit_options`].
+ #[cfg(feature = "ffi")]
+ pub fn ratchet_tree(&self) -> Option<&ExportedTree<'static>> {
+ self.ratchet_tree.as_ref()
+ }
+
+ /// A group info that can be provided to new members in order to enable external commit
+ /// functionality. This value is set if [`MlsRules::commit_options`] returns
+ /// `allow_external_commit` set to true.
+ #[cfg(feature = "ffi")]
+ pub fn external_commit_group_info(&self) -> Option<&MlsMessage> {
+ self.external_commit_group_info.as_ref()
+ }
+
+ /// Proposals that were received in the prior epoch but not included in the following commit.
+ #[cfg(all(feature = "ffi", feature = "by_ref_proposal"))]
+ pub fn unused_proposals(&self) -> &[crate::mls_rules::ProposalInfo<Proposal>] {
+ &self.unused_proposals
+ }
+}
+
+/// Build a commit with multiple proposals by-value.
+///
+/// Proposals within a commit can be by-value or by-reference.
+/// Proposals received during the current epoch will be added to the resulting
+/// commit by-reference automatically so long as they pass the rules defined
+/// in the current
+/// [proposal rules](crate::client_builder::ClientBuilder::mls_rules).
+pub struct CommitBuilder<'a, C>
+where
+ C: ClientConfig + Clone,
+{
+ group: &'a mut Group<C>,
+ pub(super) proposals: Vec<Proposal>,
+ authenticated_data: Vec<u8>,
+ group_info_extensions: ExtensionList,
+ new_signer: Option<SignatureSecretKey>,
+ new_signing_identity: Option<SigningIdentity>,
+}
+
+impl<'a, C> CommitBuilder<'a, C>
+where
+ C: ClientConfig + Clone,
+{
+ /// Insert an [`AddProposal`](crate::group::proposal::AddProposal) into
+ /// the current commit that is being built.
+ pub fn add_member(mut self, key_package: MlsMessage) -> Result<CommitBuilder<'a, C>, MlsError> {
+ let proposal = self.group.add_proposal(key_package)?;
+ self.proposals.push(proposal);
+ Ok(self)
+ }
+
+ /// Set group info extensions that will be inserted into the resulting
+ /// [welcome messages](CommitOutput::welcome_messages) for new members.
+ ///
+ /// Group info extensions that are transmitted as part of a welcome message
+ /// are encrypted along with other private values.
+ ///
+ /// These extensions can be retrieved as part of
+ /// [`NewMemberInfo`](crate::group::NewMemberInfo) that is returned
+ /// by joining the group via
+ /// [`Client::join_group`](crate::Client::join_group).
+ pub fn set_group_info_ext(self, extensions: ExtensionList) -> Self {
+ Self {
+ group_info_extensions: extensions,
+ ..self
+ }
+ }
+
+ /// Insert a [`RemoveProposal`](crate::group::proposal::RemoveProposal) into
+ /// the current commit that is being built.
+ pub fn remove_member(mut self, index: u32) -> Result<Self, MlsError> {
+ let proposal = self.group.remove_proposal(index)?;
+ self.proposals.push(proposal);
+ Ok(self)
+ }
+
+ /// Insert a
+ /// [`GroupContextExtensions`](crate::group::proposal::Proposal::GroupContextExtensions)
+ /// into the current commit that is being built.
+ pub fn set_group_context_ext(mut self, extensions: ExtensionList) -> Result<Self, MlsError> {
+ let proposal = self.group.group_context_extensions_proposal(extensions);
+ self.proposals.push(proposal);
+ Ok(self)
+ }
+
+ /// Insert a
+ /// [`PreSharedKeyProposal`](crate::group::proposal::PreSharedKeyProposal) with
+ /// an external PSK into the current commit that is being built.
+ #[cfg(feature = "psk")]
+ pub fn add_external_psk(mut self, psk_id: ExternalPskId) -> Result<Self, MlsError> {
+ let key_id = JustPreSharedKeyID::External(psk_id);
+ let proposal = self.group.psk_proposal(key_id)?;
+ self.proposals.push(proposal);
+ Ok(self)
+ }
+
+ /// Insert a
+ /// [`PreSharedKeyProposal`](crate::group::proposal::PreSharedKeyProposal) with
+ /// a resumption PSK into the current commit that is being built.
+ #[cfg(feature = "psk")]
+ pub fn add_resumption_psk(mut self, psk_epoch: u64) -> Result<Self, MlsError> {
+ let psk_id = ResumptionPsk {
+ psk_epoch,
+ usage: ResumptionPSKUsage::Application,
+ psk_group_id: PskGroupId(self.group.group_id().to_vec()),
+ };
+
+ let key_id = JustPreSharedKeyID::Resumption(psk_id);
+ let proposal = self.group.psk_proposal(key_id)?;
+ self.proposals.push(proposal);
+ Ok(self)
+ }
+
+ /// Insert a [`ReInitProposal`](crate::group::proposal::ReInitProposal) into
+ /// the current commit that is being built.
+ pub fn reinit(
+ mut self,
+ group_id: Option<Vec<u8>>,
+ version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+ extensions: ExtensionList,
+ ) -> Result<Self, MlsError> {
+ let proposal = self
+ .group
+ .reinit_proposal(group_id, version, cipher_suite, extensions)?;
+
+ self.proposals.push(proposal);
+ Ok(self)
+ }
+
+ /// Insert a [`CustomProposal`](crate::group::proposal::CustomProposal) into
+ /// the current commit that is being built.
+ #[cfg(feature = "custom_proposal")]
+ pub fn custom_proposal(mut self, proposal: CustomProposal) -> Self {
+ self.proposals.push(Proposal::Custom(proposal));
+ self
+ }
+
+ /// Insert a proposal that was previously constructed such as when a
+ /// proposal is returned from
+ /// [`StateUpdate::unused_proposals`](super::StateUpdate::unused_proposals).
+ pub fn raw_proposal(mut self, proposal: Proposal) -> Self {
+ self.proposals.push(proposal);
+ self
+ }
+
+ /// Insert proposals that were previously constructed such as when a
+ /// proposal is returned from
+ /// [`StateUpdate::unused_proposals`](super::StateUpdate::unused_proposals).
+ pub fn raw_proposals(mut self, mut proposals: Vec<Proposal>) -> Self {
+ self.proposals.append(&mut proposals);
+ self
+ }
+
+ /// Add additional authenticated data to the commit.
+ ///
+ /// # Warning
+ ///
+ /// The data provided here is always sent unencrypted.
+ pub fn authenticated_data(self, authenticated_data: Vec<u8>) -> Self {
+ Self {
+ authenticated_data,
+ ..self
+ }
+ }
+
+ /// Change the committer's signing identity as part of making this commit.
+ /// This will only succeed if the [`IdentityProvider`](crate::IdentityProvider)
+ /// in use by the group considers the credential inside this signing_identity
+ /// [valid](crate::IdentityProvider::validate_member)
+ /// and results in the same
+ /// [identity](crate::IdentityProvider::identity)
+ /// being used.
+ pub fn set_new_signing_identity(
+ self,
+ signer: SignatureSecretKey,
+ signing_identity: SigningIdentity,
+ ) -> Self {
+ Self {
+ new_signer: Some(signer),
+ new_signing_identity: Some(signing_identity),
+ ..self
+ }
+ }
+
+ /// Finalize the commit to send.
+ ///
+ /// # Errors
+ ///
+ /// This function will return an error if any of the proposals provided
+ /// are not contextually valid according to the rules defined by the
+ /// MLS RFC, or if they do not pass the custom rules defined by the current
+ /// [proposal rules](crate::client_builder::ClientBuilder::mls_rules).
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn build(self) -> Result<CommitOutput, MlsError> {
+ self.group
+ .commit_internal(
+ self.proposals,
+ None,
+ self.authenticated_data,
+ self.group_info_extensions,
+ self.new_signer,
+ self.new_signing_identity,
+ )
+ .await
+ }
+}
+
+impl<C> Group<C>
+where
+ C: ClientConfig + Clone,
+{
+ /// Perform a commit of received proposals.
+ ///
+ /// This function is the equivalent of [`Group::commit_builder`] immediately
+ /// followed by [`CommitBuilder::build`]. Any received proposals since the
+ /// last commit will be included in the resulting message by-reference.
+ ///
+ /// Data provided in the `authenticated_data` field will be placed into
+ /// the resulting commit message unencrypted.
+ ///
+ /// # Pending Commits
+ ///
+ /// When a commit is created, it is not applied immediately in order to
+ /// allow for the resolution of conflicts when multiple members of a group
+ /// attempt to make commits at the same time. For example, a central relay
+ /// can be used to decide which commit should be accepted by the group by
+ /// determining a consistent view of commit packet order for all clients.
+ ///
+ /// Pending commits are stored internally as part of the group's state
+ /// so they do not need to be tracked outside of this library. Any commit
+ /// message that is processed before calling [Group::apply_pending_commit]
+ /// will clear the currently pending commit.
+ ///
+ /// # Empty Commits
+ ///
+ /// Sending a commit that contains no proposals is a valid operation
+ /// within the MLS protocol. It is useful for providing stronger forward
+ /// secrecy and post-compromise security, especially for long running
+ /// groups when group membership does not change often.
+ ///
+ /// # Path Updates
+ ///
+ /// Path updates provide forward secrecy and post-compromise security
+ /// within the MLS protocol.
+ /// The `path_required` option returned by [`MlsRules::commit_options`](`crate::MlsRules::commit_options`)
+ /// controls the ability of a group to send a commit without a path update.
+ /// An update path will automatically be sent if there are no proposals
+ /// in the commit, or if any proposal other than
+ /// [`Add`](crate::group::proposal::Proposal::Add),
+ /// [`Psk`](crate::group::proposal::Proposal::Psk),
+ /// or [`ReInit`](crate::group::proposal::Proposal::ReInit) are part of the commit.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn commit(&mut self, authenticated_data: Vec<u8>) -> Result<CommitOutput, MlsError> {
+ self.commit_internal(
+ vec![],
+ None,
+ authenticated_data,
+ Default::default(),
+ None,
+ None,
+ )
+ .await
+ }
+
+ /// Create a new commit builder that can include proposals
+ /// by-value.
+ pub fn commit_builder(&mut self) -> CommitBuilder<C> {
+ CommitBuilder {
+ group: self,
+ proposals: Default::default(),
+ authenticated_data: Default::default(),
+ group_info_extensions: Default::default(),
+ new_signer: Default::default(),
+ new_signing_identity: Default::default(),
+ }
+ }
+
+ /// Returns commit and optional [`MlsMessage`] containing a welcome message
+ /// for newly added members.
+ #[allow(clippy::too_many_arguments)]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(super) async fn commit_internal(
+ &mut self,
+ proposals: Vec<Proposal>,
+ external_leaf: Option<&LeafNode>,
+ authenticated_data: Vec<u8>,
+ mut welcome_group_info_extensions: ExtensionList,
+ new_signer: Option<SignatureSecretKey>,
+ new_signing_identity: Option<SigningIdentity>,
+ ) -> Result<CommitOutput, MlsError> {
+ if self.pending_commit.is_some() {
+ return Err(MlsError::ExistingPendingCommit);
+ }
+
+ if self.state.pending_reinit.is_some() {
+ return Err(MlsError::GroupUsedAfterReInit);
+ }
+
+ let mls_rules = self.config.mls_rules();
+
+ let is_external = external_leaf.is_some();
+
+ // Construct an initial Commit object with the proposals field populated from Proposals
+ // received during the current epoch, and an empty path field. Add passed in proposals
+ // by value
+ let sender = if is_external {
+ Sender::NewMemberCommit
+ } else {
+ Sender::Member(*self.private_tree.self_index)
+ };
+
+ let new_signer_ref = new_signer.as_ref().unwrap_or(&self.signer);
+ let old_signer = &self.signer;
+
+ #[cfg(feature = "std")]
+ let time = Some(crate::time::MlsTime::now());
+
+ #[cfg(not(feature = "std"))]
+ let time = None;
+
+ #[cfg(feature = "by_ref_proposal")]
+ let proposals = self.state.proposals.prepare_commit(sender, proposals);
+
+ #[cfg(not(feature = "by_ref_proposal"))]
+ let proposals = prepare_commit(sender, proposals);
+
+ let mut provisional_state = self
+ .state
+ .apply_resolved(
+ sender,
+ proposals,
+ external_leaf,
+ &self.config.identity_provider(),
+ &self.cipher_suite_provider,
+ &self.config.secret_store(),
+ &mls_rules,
+ time,
+ CommitDirection::Send,
+ )
+ .await?;
+
+ let (mut provisional_private_tree, _) =
+ self.provisional_private_tree(&provisional_state)?;
+
+ if is_external {
+ provisional_private_tree.self_index = provisional_state
+ .external_init_index
+ .ok_or(MlsError::ExternalCommitMissingExternalInit)?;
+
+ self.private_tree.self_index = provisional_private_tree.self_index;
+ }
+
+ let mut provisional_group_context = provisional_state.group_context;
+
+ // Decide whether to populate the path field: If the path field is required based on the
+ // proposals that are in the commit (see above), then it MUST be populated. Otherwise, the
+ // sender MAY omit the path field at its discretion.
+ let commit_options = mls_rules
+ .commit_options(
+ &provisional_state.public_tree.roster(),
+ &provisional_group_context.extensions,
+ &provisional_state.applied_proposals,
+ )
+ .map_err(|e| MlsError::MlsRulesError(e.into_any_error()))?;
+
+ let perform_path_update = commit_options.path_required
+ || path_update_required(&provisional_state.applied_proposals);
+
+ let (update_path, path_secrets, commit_secret) = if perform_path_update {
+ // If populating the path field: Create an UpdatePath using the new tree. Any new
+ // member (from an add proposal) MUST be excluded from the resolution during the
+ // computation of the UpdatePath. The GroupContext for this operation uses the
+ // group_id, epoch, tree_hash, and confirmed_transcript_hash values in the initial
+ // GroupContext object. The leaf_key_package for this UpdatePath must have a
+ // parent_hash extension.
+ let encap_gen = TreeKem::new(
+ &mut provisional_state.public_tree,
+ &mut provisional_private_tree,
+ )
+ .encap(
+ &mut provisional_group_context,
+ &provisional_state.indexes_of_added_kpkgs,
+ new_signer_ref,
+ self.config.leaf_properties(),
+ new_signing_identity,
+ &self.cipher_suite_provider,
+ #[cfg(test)]
+ &self.commit_modifiers,
+ )
+ .await?;
+
+ (
+ Some(encap_gen.update_path),
+ Some(encap_gen.path_secrets),
+ encap_gen.commit_secret,
+ )
+ } else {
+ // Update the tree hash, since it was not updated by encap.
+ provisional_state
+ .public_tree
+ .update_hashes(
+ &[provisional_private_tree.self_index],
+ &self.cipher_suite_provider,
+ )
+ .await?;
+
+ provisional_group_context.tree_hash = provisional_state
+ .public_tree
+ .tree_hash(&self.cipher_suite_provider)
+ .await?;
+
+ (None, None, PathSecret::empty(&self.cipher_suite_provider))
+ };
+
+ #[cfg(feature = "psk")]
+ let (psk_secret, psks) = self
+ .get_psk(&provisional_state.applied_proposals.psks)
+ .await?;
+
+ #[cfg(not(feature = "psk"))]
+ let psk_secret = self.get_psk();
+
+ let added_key_pkgs: Vec<_> = provisional_state
+ .applied_proposals
+ .additions
+ .iter()
+ .map(|info| info.proposal.key_package.clone())
+ .collect();
+
+ let commit = Commit {
+ proposals: provisional_state.applied_proposals.into_proposals_or_refs(),
+ path: update_path,
+ };
+
+ let mut auth_content = AuthenticatedContent::new_signed(
+ &self.cipher_suite_provider,
+ self.context(),
+ sender,
+ Content::Commit(alloc::boxed::Box::new(commit)),
+ old_signer,
+ #[cfg(feature = "private_message")]
+ self.encryption_options()?.control_wire_format(sender),
+ #[cfg(not(feature = "private_message"))]
+ WireFormat::PublicMessage,
+ authenticated_data,
+ )
+ .await?;
+
+ // Use the signature, the commit_secret and the psk_secret to advance the key schedule and
+ // compute the confirmation_tag value in the MlsPlaintext.
+ let confirmed_transcript_hash = ConfirmedTranscriptHash::create(
+ self.cipher_suite_provider(),
+ &self.state.interim_transcript_hash,
+ &auth_content,
+ )
+ .await?;
+
+ provisional_group_context.confirmed_transcript_hash = confirmed_transcript_hash;
+
+ let key_schedule_result = KeySchedule::from_key_schedule(
+ &self.key_schedule,
+ &commit_secret,
+ &provisional_group_context,
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ self.state.public_tree.total_leaf_count(),
+ &psk_secret,
+ &self.cipher_suite_provider,
+ )
+ .await?;
+
+ let confirmation_tag = ConfirmationTag::create(
+ &key_schedule_result.confirmation_key,
+ &provisional_group_context.confirmed_transcript_hash,
+ &self.cipher_suite_provider,
+ )
+ .await?;
+
+ auth_content.auth.confirmation_tag = Some(confirmation_tag.clone());
+
+ let ratchet_tree_ext = commit_options
+ .ratchet_tree_extension
+ .then(|| RatchetTreeExt {
+ tree_data: ExportedTree::new(provisional_state.public_tree.nodes.clone()),
+ });
+
+ // Generate external commit group info if required by commit_options
+ let external_commit_group_info = match commit_options.allow_external_commit {
+ true => {
+ let mut extensions = ExtensionList::new();
+
+ extensions.set_from({
+ key_schedule_result
+ .key_schedule
+ .get_external_key_pair_ext(&self.cipher_suite_provider)
+ .await?
+ })?;
+
+ if let Some(ref ratchet_tree_ext) = ratchet_tree_ext {
+ extensions.set_from(ratchet_tree_ext.clone())?;
+ }
+
+ let info = self
+ .make_group_info(
+ &provisional_group_context,
+ extensions,
+ &confirmation_tag,
+ new_signer_ref,
+ )
+ .await?;
+
+ let msg =
+ MlsMessage::new(self.protocol_version(), MlsMessagePayload::GroupInfo(info));
+
+ Some(msg)
+ }
+ false => None,
+ };
+
+ // Build the group info that will be placed into the welcome messages.
+ // Add the ratchet tree extension if necessary
+ if let Some(ratchet_tree_ext) = ratchet_tree_ext {
+ welcome_group_info_extensions.set_from(ratchet_tree_ext)?;
+ }
+
+ let welcome_group_info = self
+ .make_group_info(
+ &provisional_group_context,
+ welcome_group_info_extensions,
+ &confirmation_tag,
+ new_signer_ref,
+ )
+ .await?;
+
+ // Encrypt the GroupInfo using the key and nonce derived from the joiner_secret for
+ // the new epoch
+ let welcome_secret = WelcomeSecret::from_joiner_secret(
+ &self.cipher_suite_provider,
+ &key_schedule_result.joiner_secret,
+ &psk_secret,
+ )
+ .await?;
+
+ let encrypted_group_info = welcome_secret
+ .encrypt(&welcome_group_info.mls_encode_to_vec()?)
+ .await?;
+
+ // Encrypt path secrets and joiner secret to new members
+ let path_secrets = path_secrets.as_ref();
+
+ #[cfg(not(any(mls_build_async, not(feature = "rayon"))))]
+ let encrypted_path_secrets: Vec<_> = added_key_pkgs
+ .into_par_iter()
+ .zip(provisional_state.indexes_of_added_kpkgs)
+ .map(|(key_package, leaf_index)| {
+ self.encrypt_group_secrets(
+ &key_package,
+ leaf_index,
+ &key_schedule_result.joiner_secret,
+ path_secrets,
+ #[cfg(feature = "psk")]
+ psks.clone(),
+ &encrypted_group_info,
+ )
+ })
+ .try_collect()?;
+
+ #[cfg(any(mls_build_async, not(feature = "rayon")))]
+ let encrypted_path_secrets = {
+ let mut secrets = Vec::new();
+
+ for (key_package, leaf_index) in added_key_pkgs
+ .into_iter()
+ .zip(provisional_state.indexes_of_added_kpkgs)
+ {
+ secrets.push(
+ self.encrypt_group_secrets(
+ &key_package,
+ leaf_index,
+ &key_schedule_result.joiner_secret,
+ path_secrets,
+ #[cfg(feature = "psk")]
+ psks.clone(),
+ &encrypted_group_info,
+ )
+ .await?,
+ );
+ }
+
+ secrets
+ };
+
+ let welcome_messages =
+ if commit_options.single_welcome_message && !encrypted_path_secrets.is_empty() {
+ vec![self.make_welcome_message(encrypted_path_secrets, encrypted_group_info)]
+ } else {
+ encrypted_path_secrets
+ .into_iter()
+ .map(|s| self.make_welcome_message(vec![s], encrypted_group_info.clone()))
+ .collect()
+ };
+
+ let commit_message = self.format_for_wire(auth_content.clone()).await?;
+
+ let pending_commit = CommitGeneration {
+ content: auth_content,
+ pending_private_tree: provisional_private_tree,
+ pending_commit_secret: commit_secret,
+ commit_message_hash: CommitHash::compute(&self.cipher_suite_provider, &commit_message)
+ .await?,
+ };
+
+ self.pending_commit = Some(pending_commit);
+
+ let ratchet_tree = (!commit_options.ratchet_tree_extension)
+ .then(|| ExportedTree::new(provisional_state.public_tree.nodes));
+
+ if let Some(signer) = new_signer {
+ self.signer = signer;
+ }
+
+ Ok(CommitOutput {
+ commit_message,
+ welcome_messages,
+ ratchet_tree,
+ external_commit_group_info,
+ #[cfg(feature = "by_ref_proposal")]
+ unused_proposals: provisional_state.unused_proposals,
+ })
+ }
+
+ // Construct a GroupInfo reflecting the new state
+ // Group ID, epoch, tree, and confirmed transcript hash from the new state
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn make_group_info(
+ &self,
+ group_context: &GroupContext,
+ extensions: ExtensionList,
+ confirmation_tag: &ConfirmationTag,
+ signer: &SignatureSecretKey,
+ ) -> Result<GroupInfo, MlsError> {
+ let mut group_info = GroupInfo {
+ group_context: group_context.clone(),
+ extensions,
+ confirmation_tag: confirmation_tag.clone(), // The confirmation_tag from the MlsPlaintext object
+ signer: LeafIndex(self.current_member_index()),
+ signature: vec![],
+ };
+
+ group_info.grease(self.cipher_suite_provider())?;
+
+ // Sign the GroupInfo using the member's private signing key
+ group_info
+ .sign(&self.cipher_suite_provider, signer, &())
+ .await?;
+
+ Ok(group_info)
+ }
+
+ fn make_welcome_message(
+ &self,
+ secrets: Vec<EncryptedGroupSecrets>,
+ encrypted_group_info: Vec<u8>,
+ ) -> MlsMessage {
+ MlsMessage::new(
+ self.context().protocol_version,
+ MlsMessagePayload::Welcome(Welcome {
+ cipher_suite: self.context().cipher_suite,
+ secrets,
+ encrypted_group_info,
+ }),
+ )
+ }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+ use alloc::vec::Vec;
+
+ use crate::{
+ crypto::SignatureSecretKey,
+ tree_kem::{leaf_node::LeafNode, TreeKemPublic, UpdatePathNode},
+ };
+
+ #[derive(Copy, Clone, Debug)]
+ pub struct CommitModifiers {
+ pub modify_leaf: fn(&mut LeafNode, &SignatureSecretKey) -> Option<SignatureSecretKey>,
+ pub modify_tree: fn(&mut TreeKemPublic),
+ pub modify_path: fn(Vec<UpdatePathNode>) -> Vec<UpdatePathNode>,
+ }
+
+ impl Default for CommitModifiers {
+ fn default() -> Self {
+ Self {
+ modify_leaf: |_, _| None,
+ modify_tree: |_| (),
+ modify_path: |a| a,
+ }
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use alloc::boxed::Box;
+
+ use mls_rs_core::{
+ error::IntoAnyError,
+ extension::ExtensionType,
+ identity::{CredentialType, IdentityProvider},
+ time::MlsTime,
+ };
+
+ use crate::{
+ crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider},
+ group::{mls_rules::DefaultMlsRules, test_utils::test_group_custom},
+ mls_rules::CommitOptions,
+ Client,
+ };
+
+ #[cfg(feature = "by_ref_proposal")]
+ use crate::extension::ExternalSendersExt;
+
+ use crate::{
+ client::test_utils::{test_client_with_key_pkg, TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+ client_builder::{
+ test_utils::TestClientConfig, BaseConfig, ClientBuilder, WithCryptoProvider,
+ WithIdentityProvider,
+ },
+ client_config::ClientConfig,
+ extension::test_utils::{TestExtension, TEST_EXTENSION_TYPE},
+ group::{
+ proposal::ProposalType,
+ test_utils::{test_group_custom_config, test_n_member_group},
+ },
+ identity::test_utils::get_test_signing_identity,
+ identity::{basic::BasicIdentityProvider, test_utils::get_test_basic_credential},
+ key_package::test_utils::test_key_package_message,
+ };
+
+ use crate::extension::RequiredCapabilitiesExt;
+
+ #[cfg(feature = "psk")]
+ use crate::{
+ group::proposal::PreSharedKeyProposal,
+ psk::{JustPreSharedKeyID, PreSharedKey, PreSharedKeyID},
+ };
+
+ use super::*;
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_commit_builder_group() -> Group<TestClientConfig> {
+ test_group_custom_config(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, |b| {
+ b.custom_proposal_type(ProposalType::from(42))
+ .extension_type(TEST_EXTENSION_TYPE.into())
+ })
+ .await
+ .group
+ }
+
+ fn assert_commit_builder_output<C: ClientConfig>(
+ group: Group<C>,
+ mut commit_output: CommitOutput,
+ expected: Vec<Proposal>,
+ welcome_count: usize,
+ ) {
+ let plaintext = commit_output.commit_message.into_plaintext().unwrap();
+
+ let commit_data = match plaintext.content.content {
+ Content::Commit(commit) => commit,
+ #[cfg(any(feature = "private_message", feature = "by_ref_proposal"))]
+ _ => panic!("Found non-commit data"),
+ };
+
+ assert_eq!(commit_data.proposals.len(), expected.len());
+
+ commit_data.proposals.into_iter().for_each(|proposal| {
+ let proposal = match proposal {
+ ProposalOrRef::Proposal(p) => p,
+ #[cfg(feature = "by_ref_proposal")]
+ ProposalOrRef::Reference(_) => panic!("found proposal reference"),
+ };
+
+ #[cfg(feature = "psk")]
+ if let Some(psk_id) = match proposal.as_ref() {
+ Proposal::Psk(PreSharedKeyProposal { psk: PreSharedKeyID { key_id: JustPreSharedKeyID::External(psk_id), .. },}) => Some(psk_id),
+ _ => None,
+ } {
+ let found = expected.iter().any(|item| matches!(item, Proposal::Psk(PreSharedKeyProposal { psk: PreSharedKeyID { key_id: JustPreSharedKeyID::External(id), .. }}) if id == psk_id));
+
+ assert!(found)
+ } else {
+ assert!(expected.contains(&proposal));
+ }
+
+ #[cfg(not(feature = "psk"))]
+ assert!(expected.contains(&proposal));
+ });
+
+ if welcome_count > 0 {
+ let welcome_msg = commit_output.welcome_messages.pop().unwrap();
+
+ assert_eq!(welcome_msg.version, group.state.context.protocol_version);
+
+ let welcome_msg = welcome_msg.into_welcome().unwrap();
+
+ assert_eq!(welcome_msg.cipher_suite, group.state.context.cipher_suite);
+ assert_eq!(welcome_msg.secrets.len(), welcome_count);
+ } else {
+ assert!(commit_output.welcome_messages.is_empty());
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_commit_builder_add() {
+ let mut group = test_commit_builder_group().await;
+
+ let test_key_package =
+ test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
+
+ let commit_output = group
+ .commit_builder()
+ .add_member(test_key_package.clone())
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ let expected_add = group.add_proposal(test_key_package).unwrap();
+
+ assert_commit_builder_output(group, commit_output, vec![expected_add], 1)
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_commit_builder_add_with_ext() {
+ let mut group = test_commit_builder_group().await;
+
+ let (bob_client, bob_key_package) =
+ test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
+
+ let ext = TestExtension { foo: 42 };
+ let mut extension_list = ExtensionList::default();
+ extension_list.set_from(ext.clone()).unwrap();
+
+ let welcome_message = group
+ .commit_builder()
+ .add_member(bob_key_package)
+ .unwrap()
+ .set_group_info_ext(extension_list)
+ .build()
+ .await
+ .unwrap()
+ .welcome_messages
+ .remove(0);
+
+ let (_, context) = bob_client.join_group(None, &welcome_message).await.unwrap();
+
+ assert_eq!(
+ context
+ .group_info_extensions
+ .get_as::<TestExtension>()
+ .unwrap()
+ .unwrap(),
+ ext
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_commit_builder_remove() {
+ let mut group = test_commit_builder_group().await;
+ let test_key_package =
+ test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
+
+ group
+ .commit_builder()
+ .add_member(test_key_package)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ group.apply_pending_commit().await.unwrap();
+
+ let commit_output = group
+ .commit_builder()
+ .remove_member(1)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ let expected_remove = group.remove_proposal(1).unwrap();
+
+ assert_commit_builder_output(group, commit_output, vec![expected_remove], 0);
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_commit_builder_psk() {
+ let mut group = test_commit_builder_group().await;
+ let test_psk = ExternalPskId::new(vec![1]);
+
+ group
+ .config
+ .secret_store()
+ .insert(test_psk.clone(), PreSharedKey::from(vec![1]));
+
+ let commit_output = group
+ .commit_builder()
+ .add_external_psk(test_psk.clone())
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ let key_id = JustPreSharedKeyID::External(test_psk);
+ let expected_psk = group.psk_proposal(key_id).unwrap();
+
+ assert_commit_builder_output(group, commit_output, vec![expected_psk], 0)
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_commit_builder_group_context_ext() {
+ let mut group = test_commit_builder_group().await;
+ let mut test_ext = ExtensionList::default();
+ test_ext
+ .set_from(RequiredCapabilitiesExt::default())
+ .unwrap();
+
+ let commit_output = group
+ .commit_builder()
+ .set_group_context_ext(test_ext.clone())
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ let expected_ext = group.group_context_extensions_proposal(test_ext);
+
+ assert_commit_builder_output(group, commit_output, vec![expected_ext], 0);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_commit_builder_reinit() {
+ let mut group = test_commit_builder_group().await;
+ let test_group_id = "foo".as_bytes().to_vec();
+ let test_cipher_suite = TEST_CIPHER_SUITE;
+ let test_protocol_version = TEST_PROTOCOL_VERSION;
+ let mut test_ext = ExtensionList::default();
+
+ test_ext
+ .set_from(RequiredCapabilitiesExt::default())
+ .unwrap();
+
+ let commit_output = group
+ .commit_builder()
+ .reinit(
+ Some(test_group_id.clone()),
+ test_protocol_version,
+ test_cipher_suite,
+ test_ext.clone(),
+ )
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ let expected_reinit = group
+ .reinit_proposal(
+ Some(test_group_id),
+ test_protocol_version,
+ test_cipher_suite,
+ test_ext,
+ )
+ .unwrap();
+
+ assert_commit_builder_output(group, commit_output, vec![expected_reinit], 0);
+ }
+
+ #[cfg(feature = "custom_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_commit_builder_custom_proposal() {
+ let mut group = test_commit_builder_group().await;
+
+ let proposal = CustomProposal::new(42.into(), vec![0, 1]);
+
+ let commit_output = group
+ .commit_builder()
+ .custom_proposal(proposal.clone())
+ .build()
+ .await
+ .unwrap();
+
+ assert_commit_builder_output(group, commit_output, vec![Proposal::Custom(proposal)], 0);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_commit_builder_chaining() {
+ let mut group = test_commit_builder_group().await;
+ let kp1 = test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice").await;
+ let kp2 = test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
+
+ let expected_adds = vec![
+ group.add_proposal(kp1.clone()).unwrap(),
+ group.add_proposal(kp2.clone()).unwrap(),
+ ];
+
+ let commit_output = group
+ .commit_builder()
+ .add_member(kp1)
+ .unwrap()
+ .add_member(kp2)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ assert_commit_builder_output(group, commit_output, expected_adds, 2);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_commit_builder_empty_commit() {
+ let mut group = test_commit_builder_group().await;
+
+ let commit_output = group.commit_builder().build().await.unwrap();
+
+ assert_commit_builder_output(group, commit_output, vec![], 0);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_commit_builder_authenticated_data() {
+ let mut group = test_commit_builder_group().await;
+ let test_data = "test".as_bytes().to_vec();
+
+ let commit_output = group
+ .commit_builder()
+ .authenticated_data(test_data.clone())
+ .build()
+ .await
+ .unwrap();
+
+ assert_eq!(
+ commit_output
+ .commit_message
+ .into_plaintext()
+ .unwrap()
+ .content
+ .authenticated_data,
+ test_data
+ );
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_commit_builder_multiple_welcome_messages() {
+ let mut group = test_group_custom_config(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, |b| {
+ let options = CommitOptions::new().with_single_welcome_message(false);
+ b.mls_rules(DefaultMlsRules::new().with_commit_options(options))
+ })
+ .await;
+
+ let (alice, alice_kp) =
+ test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "a").await;
+
+ let (bob, bob_kp) =
+ test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "b").await;
+
+ group
+ .group
+ .propose_add(alice_kp.clone(), vec![])
+ .await
+ .unwrap();
+
+ group
+ .group
+ .propose_add(bob_kp.clone(), vec![])
+ .await
+ .unwrap();
+
+ let output = group.group.commit(Vec::new()).await.unwrap();
+ let welcomes = output.welcome_messages;
+
+ let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ for (client, kp) in [(alice, alice_kp), (bob, bob_kp)] {
+ let kp_ref = kp.key_package_reference(&cs).await.unwrap().unwrap();
+
+ let welcome = welcomes
+ .iter()
+ .find(|w| w.welcome_key_package_references().contains(&&kp_ref))
+ .unwrap();
+
+ client.join_group(None, welcome).await.unwrap();
+
+ assert_eq!(welcome.clone().into_welcome().unwrap().secrets.len(), 1);
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_can_change_credential() {
+ let cs = TEST_CIPHER_SUITE;
+ let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, cs, 3).await;
+ let (identity, secret_key) = get_test_signing_identity(cs, b"member").await;
+
+ let commit_output = groups[0]
+ .group
+ .commit_builder()
+ .set_new_signing_identity(secret_key, identity.clone())
+ .build()
+ .await
+ .unwrap();
+
+ // Check that the credential was updated by in the committer's state.
+ groups[0].process_pending_commit().await.unwrap();
+ let new_member = groups[0].group.roster().member_with_index(0).unwrap();
+
+ assert_eq!(
+ new_member.signing_identity.credential,
+ get_test_basic_credential(b"member".to_vec())
+ );
+
+ assert_eq!(
+ new_member.signing_identity.signature_key,
+ identity.signature_key
+ );
+
+ // Check that the credential was updated in another member's state.
+ groups[1]
+ .process_message(commit_output.commit_message)
+ .await
+ .unwrap();
+
+ let new_member = groups[1].group.roster().member_with_index(0).unwrap();
+
+ assert_eq!(
+ new_member.signing_identity.credential,
+ get_test_basic_credential(b"member".to_vec())
+ );
+
+ assert_eq!(
+ new_member.signing_identity.signature_key,
+ identity.signature_key
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_includes_tree_if_no_ratchet_tree_ext() {
+ let mut group = test_group_custom(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ Default::default(),
+ None,
+ Some(CommitOptions::new().with_ratchet_tree_extension(false)),
+ )
+ .await
+ .group;
+
+ let commit = group.commit(vec![]).await.unwrap();
+
+ group.apply_pending_commit().await.unwrap();
+
+ let new_tree = group.export_tree();
+
+ assert_eq!(new_tree, commit.ratchet_tree.unwrap())
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_does_not_include_tree_if_ratchet_tree_ext() {
+ let mut group = test_group_custom(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ Default::default(),
+ None,
+ Some(CommitOptions::new().with_ratchet_tree_extension(true)),
+ )
+ .await
+ .group;
+
+ let commit = group.commit(vec![]).await.unwrap();
+
+ assert!(commit.ratchet_tree.is_none());
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_includes_external_commit_group_info_if_requested() {
+ let mut group = test_group_custom(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ Default::default(),
+ None,
+ Some(
+ CommitOptions::new()
+ .with_allow_external_commit(true)
+ .with_ratchet_tree_extension(false),
+ ),
+ )
+ .await
+ .group;
+
+ let commit = group.commit(vec![]).await.unwrap();
+
+ let info = commit
+ .external_commit_group_info
+ .unwrap()
+ .into_group_info()
+ .unwrap();
+
+ assert!(!info.extensions.has_extension(ExtensionType::RATCHET_TREE));
+ assert!(info.extensions.has_extension(ExtensionType::EXTERNAL_PUB));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_includes_external_commit_and_tree_if_requested() {
+ let mut group = test_group_custom(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ Default::default(),
+ None,
+ Some(
+ CommitOptions::new()
+ .with_allow_external_commit(true)
+ .with_ratchet_tree_extension(true),
+ ),
+ )
+ .await
+ .group;
+
+ let commit = group.commit(vec![]).await.unwrap();
+
+ let info = commit
+ .external_commit_group_info
+ .unwrap()
+ .into_group_info()
+ .unwrap();
+
+ assert!(info.extensions.has_extension(ExtensionType::RATCHET_TREE));
+ assert!(info.extensions.has_extension(ExtensionType::EXTERNAL_PUB));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_does_not_include_external_commit_group_info_if_not_requested() {
+ let mut group = test_group_custom(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ Default::default(),
+ None,
+ Some(CommitOptions::new().with_allow_external_commit(false)),
+ )
+ .await
+ .group;
+
+ let commit = group.commit(vec![]).await.unwrap();
+
+ assert!(commit.external_commit_group_info.is_none());
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn member_identity_is_validated_against_new_extensions() {
+ let alice = client_with_test_extension(b"alice").await;
+ let mut alice = alice.create_group(ExtensionList::new()).await.unwrap();
+
+ let bob = client_with_test_extension(b"bob").await;
+ let bob_kp = bob.generate_key_package_message().await.unwrap();
+
+ let mut extension_list = ExtensionList::new();
+ let extension = TestExtension { foo: b'a' };
+ extension_list.set_from(extension).unwrap();
+
+ let res = alice
+ .commit_builder()
+ .add_member(bob_kp)
+ .unwrap()
+ .set_group_context_ext(extension_list.clone())
+ .unwrap()
+ .build()
+ .await;
+
+ assert!(res.is_err());
+
+ let alex = client_with_test_extension(b"alex").await;
+
+ alice
+ .commit_builder()
+ .add_member(alex.generate_key_package_message().await.unwrap())
+ .unwrap()
+ .set_group_context_ext(extension_list.clone())
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn server_identity_is_validated_against_new_extensions() {
+ let alice = client_with_test_extension(b"alice").await;
+ let mut alice = alice.create_group(ExtensionList::new()).await.unwrap();
+
+ let mut extension_list = ExtensionList::new();
+ let extension = TestExtension { foo: b'a' };
+ extension_list.set_from(extension).unwrap();
+
+ let (alex_server, _) = get_test_signing_identity(TEST_CIPHER_SUITE, b"alex").await;
+
+ let mut alex_extensions = extension_list.clone();
+
+ alex_extensions
+ .set_from(ExternalSendersExt {
+ allowed_senders: vec![alex_server],
+ })
+ .unwrap();
+
+ let res = alice
+ .commit_builder()
+ .set_group_context_ext(alex_extensions)
+ .unwrap()
+ .build()
+ .await;
+
+ assert!(res.is_err());
+
+ let (bob_server, _) = get_test_signing_identity(TEST_CIPHER_SUITE, b"bob").await;
+
+ let mut bob_extensions = extension_list;
+
+ bob_extensions
+ .set_from(ExternalSendersExt {
+ allowed_senders: vec![bob_server],
+ })
+ .unwrap();
+
+ alice
+ .commit_builder()
+ .set_group_context_ext(bob_extensions)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+ }
+
+ #[derive(Debug, Clone)]
+ struct IdentityProviderWithExtension(BasicIdentityProvider);
+
+ #[derive(Clone, Debug)]
+ #[cfg_attr(feature = "std", derive(thiserror::Error))]
+ #[cfg_attr(feature = "std", error("test error"))]
+ struct IdentityProviderWithExtensionError {}
+
+ impl IntoAnyError for IdentityProviderWithExtensionError {
+ #[cfg(feature = "std")]
+ fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
+ Ok(self.into())
+ }
+ }
+
+ impl IdentityProviderWithExtension {
+ // True if the identity starts with the character `foo` from `TestExtension` or if `TestExtension`
+ // is not set.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn starts_with_foo(
+ &self,
+ identity: &SigningIdentity,
+ _timestamp: Option<MlsTime>,
+ extensions: Option<&ExtensionList>,
+ ) -> bool {
+ if let Some(extensions) = extensions {
+ if let Some(ext) = extensions.get_as::<TestExtension>().unwrap() {
+ self.identity(identity, extensions).await.unwrap()[0] == ext.foo
+ } else {
+ true
+ }
+ } else {
+ true
+ }
+ }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+ impl IdentityProvider for IdentityProviderWithExtension {
+ type Error = IdentityProviderWithExtensionError;
+
+ async fn validate_member(
+ &self,
+ identity: &SigningIdentity,
+ timestamp: Option<MlsTime>,
+ extensions: Option<&ExtensionList>,
+ ) -> Result<(), Self::Error> {
+ self.starts_with_foo(identity, timestamp, extensions)
+ .await
+ .then_some(())
+ .ok_or(IdentityProviderWithExtensionError {})
+ }
+
+ async fn validate_external_sender(
+ &self,
+ identity: &SigningIdentity,
+ timestamp: Option<MlsTime>,
+ extensions: Option<&ExtensionList>,
+ ) -> Result<(), Self::Error> {
+ (!self.starts_with_foo(identity, timestamp, extensions).await)
+ .then_some(())
+ .ok_or(IdentityProviderWithExtensionError {})
+ }
+
+ async fn identity(
+ &self,
+ signing_identity: &SigningIdentity,
+ extensions: &ExtensionList,
+ ) -> Result<Vec<u8>, Self::Error> {
+ self.0
+ .identity(signing_identity, extensions)
+ .await
+ .map_err(|_| IdentityProviderWithExtensionError {})
+ }
+
+ async fn valid_successor(
+ &self,
+ _predecessor: &SigningIdentity,
+ _successor: &SigningIdentity,
+ _extensions: &ExtensionList,
+ ) -> Result<bool, Self::Error> {
+ Ok(true)
+ }
+
+ fn supported_types(&self) -> Vec<CredentialType> {
+ self.0.supported_types()
+ }
+ }
+
+ type ExtensionClientConfig = WithIdentityProvider<
+ IdentityProviderWithExtension,
+ WithCryptoProvider<TestCryptoProvider, BaseConfig>,
+ >;
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn client_with_test_extension(name: &[u8]) -> Client<ExtensionClientConfig> {
+ let (identity, secret_key) = get_test_signing_identity(TEST_CIPHER_SUITE, name).await;
+
+ ClientBuilder::new()
+ .crypto_provider(TestCryptoProvider::new())
+ .extension_types(vec![TEST_EXTENSION_TYPE.into()])
+ .identity_provider(IdentityProviderWithExtension(BasicIdentityProvider::new()))
+ .signing_identity(identity, secret_key, TEST_CIPHER_SUITE)
+ .build()
+ }
+}
diff --git a/src/group/confirmation_tag.rs b/src/group/confirmation_tag.rs
new file mode 100644
index 0000000..409b382
--- /dev/null
+++ b/src/group/confirmation_tag.rs
@@ -0,0 +1,150 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::CipherSuiteProvider;
+use crate::{client::MlsError, group::transcript_hash::ConfirmedTranscriptHash};
+use alloc::vec::Vec;
+use core::{
+ fmt::{self, Debug},
+ ops::Deref,
+};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::error::IntoAnyError;
+
+#[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct ConfirmationTag(
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ Vec<u8>,
+);
+
+impl Debug for ConfirmationTag {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ mls_rs_core::debug::pretty_bytes(&self.0)
+ .named("ConfirmationTag")
+ .fmt(f)
+ }
+}
+
+impl Deref for ConfirmationTag {
+ type Target = Vec<u8>;
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl ConfirmationTag {
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn create<P: CipherSuiteProvider>(
+ confirmation_key: &[u8],
+ confirmed_transcript_hash: &ConfirmedTranscriptHash,
+ cipher_suite_provider: &P,
+ ) -> Result<Self, MlsError> {
+ cipher_suite_provider
+ .mac(confirmation_key, confirmed_transcript_hash)
+ .await
+ .map(ConfirmationTag)
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn matches<P: CipherSuiteProvider>(
+ &self,
+ confirmation_key: &[u8],
+ confirmed_transcript_hash: &ConfirmedTranscriptHash,
+ cipher_suite_provider: &P,
+ ) -> Result<bool, MlsError> {
+ let tag = ConfirmationTag::create(
+ confirmation_key,
+ confirmed_transcript_hash,
+ cipher_suite_provider,
+ )
+ .await?;
+
+ Ok(&tag == self)
+ }
+}
+
+#[cfg(test)]
+impl ConfirmationTag {
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn empty<P: CipherSuiteProvider>(cipher_suite_provider: &P) -> Self {
+ Self(
+ cipher_suite_provider
+ .mac(
+ &alloc::vec![0; cipher_suite_provider.kdf_extract_size()],
+ &[],
+ )
+ .await
+ .unwrap(),
+ )
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider};
+
+ #[cfg(target_arch = "wasm32")]
+ use wasm_bindgen_test::wasm_bindgen_test as test;
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_confirmation_tag_matching() {
+ for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+ let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);
+
+ let confirmed_hash_a = ConfirmedTranscriptHash::from(b"foo_a".to_vec());
+
+ let confirmation_key_a = b"bar_a".to_vec();
+
+ let confirmed_hash_b = ConfirmedTranscriptHash::from(b"foo_b".to_vec());
+
+ let confirmation_key_b = b"bar_b".to_vec();
+
+ let confirmation_tag = ConfirmationTag::create(
+ &confirmation_key_a,
+ &confirmed_hash_a,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap();
+
+ let matches = confirmation_tag
+ .matches(
+ &confirmation_key_a,
+ &confirmed_hash_a,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap();
+
+ assert!(matches);
+
+ let matches = confirmation_tag
+ .matches(
+ &confirmation_key_b,
+ &confirmed_hash_a,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap();
+
+ assert!(!matches);
+
+ let matches = confirmation_tag
+ .matches(
+ &confirmation_key_a,
+ &confirmed_hash_b,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap();
+
+ assert!(!matches);
+ }
+ }
+}
diff --git a/src/group/context.rs b/src/group/context.rs
new file mode 100644
index 0000000..4ec23a9
--- /dev/null
+++ b/src/group/context.rs
@@ -0,0 +1,98 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec;
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+
+use crate::{cipher_suite::CipherSuite, protocol_version::ProtocolVersion, ExtensionList};
+
+use super::ConfirmedTranscriptHash;
+
+#[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct GroupContext {
+ pub(crate) protocol_version: ProtocolVersion,
+ pub(crate) cipher_suite: CipherSuite,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ pub(crate) group_id: Vec<u8>,
+ pub(crate) epoch: u64,
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ pub(crate) tree_hash: Vec<u8>,
+ pub(crate) confirmed_transcript_hash: ConfirmedTranscriptHash,
+ pub(crate) extensions: ExtensionList,
+}
+
+impl Debug for GroupContext {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("GroupContext")
+ .field("protocol_version", &self.protocol_version)
+ .field("cipher_suite", &self.cipher_suite)
+ .field(
+ "group_id",
+ &mls_rs_core::debug::pretty_group_id(&self.group_id),
+ )
+ .field("epoch", &self.epoch)
+ .field(
+ "tree_hash",
+ &mls_rs_core::debug::pretty_bytes(&self.tree_hash),
+ )
+ .field("confirmed_transcript_hash", &self.confirmed_transcript_hash)
+ .field("extensions", &self.extensions)
+ .finish()
+ }
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
+impl GroupContext {
+ pub(crate) fn new_group(
+ protocol_version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+ group_id: Vec<u8>,
+ tree_hash: Vec<u8>,
+ extensions: ExtensionList,
+ ) -> Self {
+ GroupContext {
+ protocol_version,
+ cipher_suite,
+ group_id,
+ epoch: 0,
+ tree_hash,
+ confirmed_transcript_hash: ConfirmedTranscriptHash::from(vec![]),
+ extensions,
+ }
+ }
+
+ /// Get the current protocol version in use by the group.
+ pub fn version(&self) -> ProtocolVersion {
+ self.protocol_version
+ }
+
+ /// Get the current cipher suite in use by the group.
+ pub fn cipher_suite(&self) -> CipherSuite {
+ self.cipher_suite
+ }
+
+ /// Get the unique identifier of this group.
+ pub fn group_id(&self) -> &[u8] {
+ &self.group_id
+ }
+
+ /// Get the current epoch number of the group's state.
+ pub fn epoch(&self) -> u64 {
+ self.epoch
+ }
+
+ pub fn extensions(&self) -> &ExtensionList {
+ &self.extensions
+ }
+}
diff --git a/src/group/epoch.rs b/src/group/epoch.rs
new file mode 100644
index 0000000..58352d6
--- /dev/null
+++ b/src/group/epoch.rs
@@ -0,0 +1,165 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+#[cfg(feature = "psk")]
+use crate::psk::PreSharedKey;
+#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+use crate::tree_kem::node::NodeIndex;
+#[cfg(feature = "prior_epoch")]
+use crate::{crypto::SignaturePublicKey, group::GroupContext, tree_kem::node::LeafIndex};
+use alloc::vec::Vec;
+use core::{
+ fmt::{self, Debug},
+ ops::Deref,
+};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use zeroize::Zeroizing;
+
+#[cfg(all(feature = "prior_epoch", feature = "private_message"))]
+use super::ciphertext_processor::GroupStateProvider;
+
+#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+use crate::group::secret_tree::SecretTree;
+
+#[cfg(feature = "prior_epoch")]
+#[derive(Debug, Clone, MlsEncode, MlsDecode, MlsSize, PartialEq)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct PriorEpoch {
+ pub(crate) context: GroupContext,
+ pub(crate) self_index: LeafIndex,
+ pub(crate) secrets: EpochSecrets,
+ pub(crate) signature_public_keys: Vec<Option<SignaturePublicKey>>,
+}
+
+#[cfg(feature = "prior_epoch")]
+impl PriorEpoch {
+ #[inline(always)]
+ pub(crate) fn epoch_id(&self) -> u64 {
+ self.context.epoch
+ }
+
+ #[inline(always)]
+ pub(crate) fn group_id(&self) -> &[u8] {
+ &self.context.group_id
+ }
+}
+
+#[cfg(all(feature = "private_message", feature = "prior_epoch"))]
+impl GroupStateProvider for PriorEpoch {
+ fn group_context(&self) -> &GroupContext {
+ &self.context
+ }
+
+ fn self_index(&self) -> LeafIndex {
+ self.self_index
+ }
+
+ fn epoch_secrets_mut(&mut self) -> &mut EpochSecrets {
+ &mut self.secrets
+ }
+
+ fn epoch_secrets(&self) -> &EpochSecrets {
+ &self.secrets
+ }
+}
+
+#[derive(Debug, Clone, PartialEq, MlsEncode, MlsDecode, MlsSize)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct EpochSecrets {
+ #[cfg(feature = "psk")]
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ pub(crate) resumption_secret: PreSharedKey,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ pub(crate) sender_data_secret: SenderDataSecret,
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ pub(crate) secret_tree: SecretTree<NodeIndex>,
+}
+
+#[derive(Clone, PartialEq, MlsEncode, MlsDecode, MlsSize)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct SenderDataSecret(
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
+ Zeroizing<Vec<u8>>,
+);
+
+impl Debug for SenderDataSecret {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ mls_rs_core::debug::pretty_bytes(&self.0)
+ .named("SenderDataSecret")
+ .fmt(f)
+ }
+}
+
+impl AsRef<[u8]> for SenderDataSecret {
+ fn as_ref(&self) -> &[u8] {
+ &self.0
+ }
+}
+
+impl Deref for SenderDataSecret {
+ type Target = Vec<u8>;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl From<Vec<u8>> for SenderDataSecret {
+ fn from(bytes: Vec<u8>) -> Self {
+ Self(Zeroizing::new(bytes))
+ }
+}
+
+impl From<Zeroizing<Vec<u8>>> for SenderDataSecret {
+ fn from(bytes: Zeroizing<Vec<u8>>) -> Self {
+ Self(bytes)
+ }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+ use mls_rs_core::crypto::CipherSuiteProvider;
+
+ use super::*;
+ use crate::cipher_suite::CipherSuite;
+ use crate::crypto::test_utils::test_cipher_suite_provider;
+
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ use crate::group::secret_tree::test_utils::get_test_tree;
+
+ #[cfg(feature = "prior_epoch")]
+ use crate::group::test_utils::get_test_group_context_with_id;
+
+ use crate::group::test_utils::random_bytes;
+
+ pub(crate) fn get_test_epoch_secrets(cipher_suite: CipherSuite) -> EpochSecrets {
+ let cs_provider = test_cipher_suite_provider(cipher_suite);
+
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ let secret_tree = get_test_tree(random_bytes(cs_provider.kdf_extract_size()), 2);
+
+ EpochSecrets {
+ #[cfg(feature = "psk")]
+ resumption_secret: random_bytes(cs_provider.kdf_extract_size()).into(),
+ sender_data_secret: random_bytes(cs_provider.kdf_extract_size()).into(),
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ secret_tree,
+ }
+ }
+
+ #[cfg(feature = "prior_epoch")]
+ pub(crate) fn get_test_epoch_with_id(
+ group_id: Vec<u8>,
+ cipher_suite: CipherSuite,
+ id: u64,
+ ) -> PriorEpoch {
+ PriorEpoch {
+ context: get_test_group_context_with_id(group_id, id, cipher_suite),
+ self_index: LeafIndex(0),
+ secrets: get_test_epoch_secrets(cipher_suite),
+ signature_public_keys: Default::default(),
+ }
+ }
+}
diff --git a/src/group/exported_tree.rs b/src/group/exported_tree.rs
new file mode 100644
index 0000000..acf507f
--- /dev/null
+++ b/src/group/exported_tree.rs
@@ -0,0 +1,51 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::{borrow::Cow, vec::Vec};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+
+use crate::{client::MlsError, tree_kem::node::NodeVec};
+
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Debug, MlsSize, MlsEncode, MlsDecode, PartialEq, Clone)]
+pub struct ExportedTree<'a>(pub(crate) Cow<'a, NodeVec>);
+
+#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
+impl<'a> ExportedTree<'a> {
+ pub(crate) fn new(node_data: NodeVec) -> Self {
+ Self(Cow::Owned(node_data))
+ }
+
+ pub(crate) fn new_borrowed(node_data: &'a NodeVec) -> Self {
+ Self(Cow::Borrowed(node_data))
+ }
+
+ pub fn to_bytes(&self) -> Result<Vec<u8>, MlsError> {
+ self.mls_encode_to_vec().map_err(Into::into)
+ }
+
+ pub fn byte_size(&self) -> usize {
+ self.mls_encoded_len()
+ }
+
+ pub fn into_owned(self) -> ExportedTree<'static> {
+ ExportedTree(Cow::Owned(self.0.into_owned()))
+ }
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
+impl ExportedTree<'static> {
+ pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlsError> {
+ Self::mls_decode(&mut &*bytes).map_err(Into::into)
+ }
+}
+
+impl From<ExportedTree<'_>> for NodeVec {
+ fn from(value: ExportedTree) -> Self {
+ value.0.into_owned()
+ }
+}
diff --git a/src/group/external_commit.rs b/src/group/external_commit.rs
new file mode 100644
index 0000000..34b1042
--- /dev/null
+++ b/src/group/external_commit.rs
@@ -0,0 +1,266 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use mls_rs_core::{crypto::SignatureSecretKey, identity::SigningIdentity};
+
+use crate::{
+ client_config::ClientConfig,
+ group::{
+ cipher_suite_provider,
+ epoch::SenderDataSecret,
+ key_schedule::{InitSecret, KeySchedule},
+ proposal::{ExternalInit, Proposal, RemoveProposal},
+ EpochSecrets, ExternalPubExt, LeafIndex, LeafNode, MlsError, TreeKemPrivate,
+ },
+ Group, MlsMessage,
+};
+
+#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+use crate::group::secret_tree::SecretTree;
+
+#[cfg(feature = "custom_proposal")]
+use crate::group::{
+ framing::MlsMessagePayload,
+ message_processor::{EventOrContent, MessageProcessor},
+ message_signature::AuthenticatedContent,
+ message_verifier::verify_plaintext_authentication,
+ CustomProposal,
+};
+
+use alloc::vec;
+use alloc::vec::Vec;
+
+#[cfg(feature = "psk")]
+use mls_rs_core::psk::{ExternalPskId, PreSharedKey};
+
+#[cfg(feature = "psk")]
+use crate::group::{
+ PreSharedKeyProposal, {JustPreSharedKeyID, PreSharedKeyID},
+};
+
+use super::{validate_group_info_joiner, ExportedTree};
+
+/// A builder that aids with the construction of an external commit.
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type(opaque))]
+pub struct ExternalCommitBuilder<C: ClientConfig> {
+ signer: SignatureSecretKey,
+ signing_identity: SigningIdentity,
+ config: C,
+ tree_data: Option<ExportedTree<'static>>,
+ to_remove: Option<u32>,
+ #[cfg(feature = "psk")]
+ external_psks: Vec<ExternalPskId>,
+ authenticated_data: Vec<u8>,
+ #[cfg(feature = "custom_proposal")]
+ custom_proposals: Vec<Proposal>,
+ #[cfg(feature = "custom_proposal")]
+ received_custom_proposals: Vec<MlsMessage>,
+}
+
+impl<C: ClientConfig> ExternalCommitBuilder<C> {
+ pub(crate) fn new(
+ signer: SignatureSecretKey,
+ signing_identity: SigningIdentity,
+ config: C,
+ ) -> Self {
+ Self {
+ tree_data: None,
+ to_remove: None,
+ authenticated_data: Vec::new(),
+ signer,
+ signing_identity,
+ config,
+ #[cfg(feature = "psk")]
+ external_psks: Vec::new(),
+ #[cfg(feature = "custom_proposal")]
+ custom_proposals: Vec::new(),
+ #[cfg(feature = "custom_proposal")]
+ received_custom_proposals: Vec::new(),
+ }
+ }
+
+ #[must_use]
+ /// Use external tree data if the GroupInfo message does not contain a
+ /// [`RatchetTreeExt`](crate::extension::built_in::RatchetTreeExt)
+ pub fn with_tree_data(self, tree_data: ExportedTree<'static>) -> Self {
+ Self {
+ tree_data: Some(tree_data),
+ ..self
+ }
+ }
+
+ #[must_use]
+ /// Propose the removal of an old version of the client as part of the external commit.
+ /// Only one such proposal is allowed.
+ pub fn with_removal(self, to_remove: u32) -> Self {
+ Self {
+ to_remove: Some(to_remove),
+ ..self
+ }
+ }
+
+ #[must_use]
+ /// Add plaintext authenticated data to the resulting commit message.
+ pub fn with_authenticated_data(self, data: Vec<u8>) -> Self {
+ Self {
+ authenticated_data: data,
+ ..self
+ }
+ }
+
+ #[cfg(feature = "psk")]
+ #[must_use]
+ /// Add an external psk to the group as part of the external commit.
+ pub fn with_external_psk(mut self, psk: ExternalPskId) -> Self {
+ self.external_psks.push(psk);
+ self
+ }
+
+ #[cfg(feature = "custom_proposal")]
+ #[must_use]
+ /// Insert a [`CustomProposal`] into the current commit that is being built.
+ pub fn with_custom_proposal(mut self, proposal: CustomProposal) -> Self {
+ self.custom_proposals.push(Proposal::Custom(proposal));
+ self
+ }
+
+ #[cfg(all(feature = "custom_proposal", feature = "by_ref_proposal"))]
+ #[must_use]
+ /// Insert a [`CustomProposal`] received from a current group member into the current
+ /// commit that is being built.
+ ///
+ /// # Warning
+ ///
+ /// The authenticity of the proposal is NOT fully verified. It is only verified the
+ /// same way as by [`ExternalGroup`](`crate::external_client::ExternalGroup`).
+ /// The proposal MUST be an MlsPlaintext, else the [`Self::build`] function will fail.
+ pub fn with_received_custom_proposal(mut self, proposal: MlsMessage) -> Self {
+ self.received_custom_proposals.push(proposal);
+ self
+ }
+
+ /// Build the external commit using a GroupInfo message provided by an existing group member.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn build(self, group_info: MlsMessage) -> Result<(Group<C>, MlsMessage), MlsError> {
+ let protocol_version = group_info.version;
+
+ if !self.config.version_supported(protocol_version) {
+ return Err(MlsError::UnsupportedProtocolVersion(protocol_version));
+ }
+
+ let group_info = group_info
+ .into_group_info()
+ .ok_or(MlsError::UnexpectedMessageType)?;
+
+ let cipher_suite = cipher_suite_provider(
+ self.config.crypto_provider(),
+ group_info.group_context.cipher_suite,
+ )?;
+
+ let external_pub_ext = group_info
+ .extensions
+ .get_as::<ExternalPubExt>()?
+ .ok_or(MlsError::MissingExternalPubExtension)?;
+
+ let public_tree = validate_group_info_joiner(
+ protocol_version,
+ &group_info,
+ self.tree_data,
+ &self.config.identity_provider(),
+ &cipher_suite,
+ )
+ .await?;
+
+ let (leaf_node, _) = LeafNode::generate(
+ &cipher_suite,
+ self.config.leaf_properties(),
+ self.signing_identity,
+ &self.signer,
+ self.config.lifetime(),
+ )
+ .await?;
+
+ let (init_secret, kem_output) =
+ InitSecret::encode_for_external(&cipher_suite, &external_pub_ext.external_pub).await?;
+
+ let epoch_secrets = EpochSecrets {
+ #[cfg(feature = "psk")]
+ resumption_secret: PreSharedKey::new(vec![]),
+ sender_data_secret: SenderDataSecret::from(vec![]),
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ secret_tree: SecretTree::empty(),
+ };
+
+ let (mut group, _) = Group::join_with(
+ self.config,
+ group_info,
+ public_tree,
+ KeySchedule::new(init_secret),
+ epoch_secrets,
+ TreeKemPrivate::new_for_external(),
+ None,
+ self.signer,
+ )
+ .await?;
+
+ #[cfg(feature = "psk")]
+ let psk_ids = self
+ .external_psks
+ .into_iter()
+ .map(|psk_id| PreSharedKeyID::new(JustPreSharedKeyID::External(psk_id), &cipher_suite))
+ .collect::<Result<Vec<_>, MlsError>>()?;
+
+ let mut proposals = vec![Proposal::ExternalInit(ExternalInit { kem_output })];
+
+ #[cfg(feature = "psk")]
+ proposals.extend(
+ psk_ids
+ .into_iter()
+ .map(|psk| Proposal::Psk(PreSharedKeyProposal { psk })),
+ );
+
+ #[cfg(feature = "custom_proposal")]
+ {
+ let mut custom_proposals = self.custom_proposals;
+ proposals.append(&mut custom_proposals);
+ }
+
+ #[cfg(all(feature = "custom_proposal", feature = "by_ref_proposal"))]
+ for message in self.received_custom_proposals {
+ let MlsMessagePayload::Plain(plaintext) = message.payload else {
+ return Err(MlsError::UnexpectedMessageType);
+ };
+
+ let auth_content = AuthenticatedContent::from(plaintext.clone());
+
+ verify_plaintext_authentication(&cipher_suite, plaintext, None, None, &group.state)
+ .await?;
+
+ group
+ .process_event_or_content(EventOrContent::Content(auth_content), true, None)
+ .await?;
+ }
+
+ if let Some(r) = self.to_remove {
+ proposals.push(Proposal::Remove(RemoveProposal {
+ to_remove: LeafIndex(r),
+ }));
+ }
+
+ let commit_output = group
+ .commit_internal(
+ proposals,
+ Some(&leaf_node),
+ self.authenticated_data,
+ Default::default(),
+ None,
+ None,
+ )
+ .await?;
+
+ group.apply_pending_commit().await?;
+
+ Ok((group, commit_output.commit_message))
+ }
+}
diff --git a/src/group/framing.rs b/src/group/framing.rs
new file mode 100644
index 0000000..8663b96
--- /dev/null
+++ b/src/group/framing.rs
@@ -0,0 +1,741 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use core::ops::Deref;
+
+use crate::{client::MlsError, tree_kem::node::LeafIndex, KeyPackage, KeyPackageRef};
+
+use super::{Commit, FramedContentAuthData, GroupInfo, MembershipTag, Welcome};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::{group::Proposal, mls_rules::ProposalRef};
+
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::{
+ crypto::{CipherSuite, CipherSuiteProvider},
+ protocol_version::ProtocolVersion,
+};
+use zeroize::ZeroizeOnDrop;
+
+#[cfg(feature = "private_message")]
+use alloc::boxed::Box;
+
+#[cfg(feature = "custom_proposal")]
+use crate::group::proposal::{CustomProposal, ProposalOrRef};
+
+#[derive(Copy, Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[repr(u8)]
+pub enum ContentType {
+ #[cfg(feature = "private_message")]
+ Application = 1u8,
+ #[cfg(feature = "by_ref_proposal")]
+ Proposal = 2u8,
+ Commit = 3u8,
+}
+
+impl From<&Content> for ContentType {
+ fn from(content: &Content) -> Self {
+ match content {
+ #[cfg(feature = "private_message")]
+ Content::Application(_) => ContentType::Application,
+ #[cfg(feature = "by_ref_proposal")]
+ Content::Proposal(_) => ContentType::Proposal,
+ Content::Commit(_) => ContentType::Commit,
+ }
+ }
+}
+
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[repr(u8)]
+#[non_exhaustive]
+/// Description of a [`MlsMessage`] sender
+pub enum Sender {
+ /// Current group member index.
+ Member(u32) = 1u8,
+ /// An external entity sending a proposal proposal identified by an index
+ /// in the current
+ /// [`ExternalSendersExt`](crate::extension::ExternalSendersExt) stored in
+ /// group context extensions.
+ #[cfg(feature = "by_ref_proposal")]
+ External(u32) = 2u8,
+ /// A new member proposing their own addition to the group.
+ #[cfg(feature = "by_ref_proposal")]
+ NewMemberProposal = 3u8,
+ /// A member sending an external commit.
+ NewMemberCommit = 4u8,
+}
+
+impl From<LeafIndex> for Sender {
+ fn from(leaf_index: LeafIndex) -> Self {
+ Sender::Member(*leaf_index)
+ }
+}
+
+impl From<u32> for Sender {
+ fn from(leaf_index: u32) -> Self {
+ Sender::Member(leaf_index)
+ }
+}
+
+#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode, ZeroizeOnDrop)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct ApplicationData(
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ Vec<u8>,
+);
+
+impl Debug for ApplicationData {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ mls_rs_core::debug::pretty_bytes(&self.0)
+ .named("ApplicationData")
+ .fmt(f)
+ }
+}
+
+impl From<Vec<u8>> for ApplicationData {
+ fn from(data: Vec<u8>) -> Self {
+ Self(data)
+ }
+}
+
+impl Deref for ApplicationData {
+ type Target = [u8];
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl ApplicationData {
+ /// Underlying message content.
+ pub fn as_bytes(&self) -> &[u8] {
+ &self.0
+ }
+}
+
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[repr(u8)]
+pub(crate) enum Content {
+ #[cfg(feature = "private_message")]
+ Application(ApplicationData) = 1u8,
+ #[cfg(feature = "by_ref_proposal")]
+ Proposal(alloc::boxed::Box<Proposal>) = 2u8,
+ Commit(alloc::boxed::Box<Commit>) = 3u8,
+}
+
+impl Content {
+ pub fn content_type(&self) -> ContentType {
+ self.into()
+ }
+}
+
+#[derive(Clone, Debug, PartialEq)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+pub(crate) struct PublicMessage {
+ pub content: FramedContent,
+ pub auth: FramedContentAuthData,
+ pub membership_tag: Option<MembershipTag>,
+}
+
+impl MlsSize for PublicMessage {
+ fn mls_encoded_len(&self) -> usize {
+ self.content.mls_encoded_len()
+ + self.auth.mls_encoded_len()
+ + self
+ .membership_tag
+ .as_ref()
+ .map_or(0, |tag| tag.mls_encoded_len())
+ }
+}
+
+impl MlsEncode for PublicMessage {
+ fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
+ self.content.mls_encode(writer)?;
+ self.auth.mls_encode(writer)?;
+
+ self.membership_tag
+ .as_ref()
+ .map_or(Ok(()), |tag| tag.mls_encode(writer))
+ }
+}
+
+impl MlsDecode for PublicMessage {
+ fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
+ let content = FramedContent::mls_decode(reader)?;
+ let auth = FramedContentAuthData::mls_decode(reader, content.content_type())?;
+
+ let membership_tag = match content.sender {
+ Sender::Member(_) => Some(MembershipTag::mls_decode(reader)?),
+ _ => None,
+ };
+
+ Ok(Self {
+ content,
+ auth,
+ membership_tag,
+ })
+ }
+}
+
+#[cfg(feature = "private_message")]
+#[derive(Clone, Debug, PartialEq)]
+pub(crate) struct PrivateMessageContent {
+ pub content: Content,
+ pub auth: FramedContentAuthData,
+}
+
+#[cfg(feature = "private_message")]
+impl MlsSize for PrivateMessageContent {
+ fn mls_encoded_len(&self) -> usize {
+ let content_len_without_type = match &self.content {
+ Content::Application(c) => c.mls_encoded_len(),
+ #[cfg(feature = "by_ref_proposal")]
+ Content::Proposal(c) => c.mls_encoded_len(),
+ Content::Commit(c) => c.mls_encoded_len(),
+ };
+
+ content_len_without_type + self.auth.mls_encoded_len()
+ }
+}
+
+#[cfg(feature = "private_message")]
+impl MlsEncode for PrivateMessageContent {
+ fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
+ match &self.content {
+ Content::Application(c) => c.mls_encode(writer),
+ #[cfg(feature = "by_ref_proposal")]
+ Content::Proposal(c) => c.mls_encode(writer),
+ Content::Commit(c) => c.mls_encode(writer),
+ }?;
+
+ self.auth.mls_encode(writer)?;
+
+ Ok(())
+ }
+}
+
+#[cfg(feature = "private_message")]
+impl PrivateMessageContent {
+ pub(crate) fn mls_decode(
+ reader: &mut &[u8],
+ content_type: ContentType,
+ ) -> Result<Self, mls_rs_codec::Error> {
+ let content = match content_type {
+ ContentType::Application => Content::Application(ApplicationData::mls_decode(reader)?),
+ #[cfg(feature = "by_ref_proposal")]
+ ContentType::Proposal => Content::Proposal(Box::new(Proposal::mls_decode(reader)?)),
+ ContentType::Commit => {
+ Content::Commit(alloc::boxed::Box::new(Commit::mls_decode(reader)?))
+ }
+ };
+
+ let auth = FramedContentAuthData::mls_decode(reader, content.content_type())?;
+
+ if reader.iter().any(|&i| i != 0u8) {
+ // #[cfg(feature = "std")]
+ // return Err(mls_rs_codec::Error::Custom(
+ // "non-zero padding bytes discovered".to_string(),
+ // ));
+
+ // #[cfg(not(feature = "std"))]
+ return Err(mls_rs_codec::Error::Custom(5));
+ }
+
+ Ok(Self { content, auth })
+ }
+}
+
+#[cfg(feature = "private_message")]
+#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+pub struct PrivateContentAAD {
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ pub group_id: Vec<u8>,
+ pub epoch: u64,
+ pub content_type: ContentType,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ pub authenticated_data: Vec<u8>,
+}
+
+#[cfg(feature = "private_message")]
+impl Debug for PrivateContentAAD {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("PrivateContentAAD")
+ .field(
+ "group_id",
+ &mls_rs_core::debug::pretty_group_id(&self.group_id),
+ )
+ .field("epoch", &self.epoch)
+ .field("content_type", &self.content_type)
+ .field(
+ "authenticated_data",
+ &mls_rs_core::debug::pretty_bytes(&self.authenticated_data),
+ )
+ .finish()
+ }
+}
+
+#[cfg(feature = "private_message")]
+#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+pub struct PrivateMessage {
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ pub group_id: Vec<u8>,
+ pub epoch: u64,
+ pub content_type: ContentType,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ pub authenticated_data: Vec<u8>,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ pub encrypted_sender_data: Vec<u8>,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ pub ciphertext: Vec<u8>,
+}
+
+#[cfg(feature = "private_message")]
+impl Debug for PrivateMessage {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("PrivateMessage")
+ .field(
+ "group_id",
+ &mls_rs_core::debug::pretty_group_id(&self.group_id),
+ )
+ .field("epoch", &self.epoch)
+ .field("content_type", &self.content_type)
+ .field(
+ "authenticated_data",
+ &mls_rs_core::debug::pretty_bytes(&self.authenticated_data),
+ )
+ .field(
+ "encrypted_sender_data",
+ &mls_rs_core::debug::pretty_bytes(&self.encrypted_sender_data),
+ )
+ .field(
+ "ciphertext",
+ &mls_rs_core::debug::pretty_bytes(&self.ciphertext),
+ )
+ .finish()
+ }
+}
+
+#[cfg(feature = "private_message")]
+impl From<&PrivateMessage> for PrivateContentAAD {
+ fn from(ciphertext: &PrivateMessage) -> Self {
+ Self {
+ group_id: ciphertext.group_id.clone(),
+ epoch: ciphertext.epoch,
+ content_type: ciphertext.content_type,
+ authenticated_data: ciphertext.authenticated_data.clone(),
+ }
+ }
+}
+
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ ::safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+/// A MLS protocol message for sending data over the wire.
+pub struct MlsMessage {
+ pub(crate) version: ProtocolVersion,
+ pub(crate) payload: MlsMessagePayload,
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
+#[allow(dead_code)]
+impl MlsMessage {
+ pub(crate) fn new(version: ProtocolVersion, payload: MlsMessagePayload) -> MlsMessage {
+ Self { version, payload }
+ }
+
+ #[inline(always)]
+ pub(crate) fn into_plaintext(self) -> Option<PublicMessage> {
+ match self.payload {
+ MlsMessagePayload::Plain(plaintext) => Some(plaintext),
+ _ => None,
+ }
+ }
+
+ #[cfg(feature = "private_message")]
+ #[inline(always)]
+ pub(crate) fn into_ciphertext(self) -> Option<PrivateMessage> {
+ match self.payload {
+ MlsMessagePayload::Cipher(ciphertext) => Some(ciphertext),
+ _ => None,
+ }
+ }
+
+ #[inline(always)]
+ pub(crate) fn into_welcome(self) -> Option<Welcome> {
+ match self.payload {
+ MlsMessagePayload::Welcome(welcome) => Some(welcome),
+ _ => None,
+ }
+ }
+
+ #[inline(always)]
+ pub fn into_group_info(self) -> Option<GroupInfo> {
+ match self.payload {
+ MlsMessagePayload::GroupInfo(info) => Some(info),
+ _ => None,
+ }
+ }
+
+ #[inline(always)]
+ pub fn as_group_info(&self) -> Option<&GroupInfo> {
+ match &self.payload {
+ MlsMessagePayload::GroupInfo(info) => Some(info),
+ _ => None,
+ }
+ }
+
+ #[inline(always)]
+ pub fn into_key_package(self) -> Option<KeyPackage> {
+ match self.payload {
+ MlsMessagePayload::KeyPackage(kp) => Some(kp),
+ _ => None,
+ }
+ }
+
+ /// The wire format value describing the contents of this message.
+ pub fn wire_format(&self) -> WireFormat {
+ match self.payload {
+ MlsMessagePayload::Plain(_) => WireFormat::PublicMessage,
+ #[cfg(feature = "private_message")]
+ MlsMessagePayload::Cipher(_) => WireFormat::PrivateMessage,
+ MlsMessagePayload::Welcome(_) => WireFormat::Welcome,
+ MlsMessagePayload::GroupInfo(_) => WireFormat::GroupInfo,
+ MlsMessagePayload::KeyPackage(_) => WireFormat::KeyPackage,
+ }
+ }
+
+ /// The epoch that this message belongs to.
+ ///
+ /// Returns `None` if the message is [`WireFormat::KeyPackage`]
+ /// or [`WireFormat::Welcome`]
+ #[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen_ignore)]
+ pub fn epoch(&self) -> Option<u64> {
+ match &self.payload {
+ MlsMessagePayload::Plain(p) => Some(p.content.epoch),
+ #[cfg(feature = "private_message")]
+ MlsMessagePayload::Cipher(c) => Some(c.epoch),
+ MlsMessagePayload::GroupInfo(gi) => Some(gi.group_context.epoch),
+ _ => None,
+ }
+ }
+
+ #[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen_ignore)]
+ pub fn cipher_suite(&self) -> Option<CipherSuite> {
+ match &self.payload {
+ MlsMessagePayload::GroupInfo(i) => Some(i.group_context.cipher_suite),
+ MlsMessagePayload::Welcome(w) => Some(w.cipher_suite),
+ MlsMessagePayload::KeyPackage(k) => Some(k.cipher_suite),
+ _ => None,
+ }
+ }
+
+ pub fn group_id(&self) -> Option<&[u8]> {
+ match &self.payload {
+ MlsMessagePayload::Plain(p) => Some(&p.content.group_id),
+ #[cfg(feature = "private_message")]
+ MlsMessagePayload::Cipher(p) => Some(&p.group_id),
+ MlsMessagePayload::GroupInfo(p) => Some(&p.group_context.group_id),
+ MlsMessagePayload::KeyPackage(_) | MlsMessagePayload::Welcome(_) => None,
+ }
+ }
+
+ /// Deserialize a message from transport.
+ #[inline(never)]
+ pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlsError> {
+ Self::mls_decode(&mut &*bytes).map_err(Into::into)
+ }
+
+ /// Serialize a message for transport.
+ pub fn to_bytes(&self) -> Result<Vec<u8>, MlsError> {
+ self.mls_encode_to_vec().map_err(Into::into)
+ }
+
+ /// If this is a plaintext commit message, return all custom proposals committed by value.
+ /// If this is not a plaintext or not a commit, this returns an empty list.
+ #[cfg(feature = "custom_proposal")]
+ pub fn custom_proposals_by_value(&self) -> Vec<&CustomProposal> {
+ match &self.payload {
+ MlsMessagePayload::Plain(plaintext) => match &plaintext.content.content {
+ Content::Commit(commit) => Self::find_custom_proposals(commit),
+ _ => Vec::new(),
+ },
+ _ => Vec::new(),
+ }
+ }
+
+ /// If this is a welcome message, return key package references of all members who can
+ /// join using this message.
+ pub fn welcome_key_package_references(&self) -> Vec<&KeyPackageRef> {
+ let MlsMessagePayload::Welcome(welcome) = &self.payload else {
+ return Vec::new();
+ };
+
+ welcome.secrets.iter().map(|s| &s.new_member).collect()
+ }
+
+ /// If this is a key package, return its key package reference.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn key_package_reference<C: CipherSuiteProvider>(
+ &self,
+ cipher_suite: &C,
+ ) -> Result<Option<KeyPackageRef>, MlsError> {
+ let MlsMessagePayload::KeyPackage(kp) = &self.payload else {
+ return Ok(None);
+ };
+
+ kp.to_reference(cipher_suite).await.map(Some)
+ }
+
+ /// If this is a plaintext proposal, return the proposal reference that can be matched e.g. with
+ /// [`StateUpdate::unused_proposals`](super::StateUpdate::unused_proposals).
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn into_proposal_reference<C: CipherSuiteProvider>(
+ self,
+ cipher_suite: &C,
+ ) -> Result<Option<Vec<u8>>, MlsError> {
+ let MlsMessagePayload::Plain(public_message) = self.payload else {
+ return Ok(None);
+ };
+
+ ProposalRef::from_content(cipher_suite, &public_message.into())
+ .await
+ .map(|r| Some(r.to_vec()))
+ }
+}
+
+#[cfg(feature = "custom_proposal")]
+impl MlsMessage {
+ fn find_custom_proposals(commit: &Commit) -> Vec<&CustomProposal> {
+ commit
+ .proposals
+ .iter()
+ .filter_map(|p| match p {
+ ProposalOrRef::Proposal(p) => match p.as_ref() {
+ crate::group::Proposal::Custom(p) => Some(p),
+ _ => None,
+ },
+ _ => None,
+ })
+ .collect()
+ }
+}
+
+#[allow(clippy::large_enum_variant)]
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[repr(u16)]
+pub(crate) enum MlsMessagePayload {
+ Plain(PublicMessage) = 1u16,
+ #[cfg(feature = "private_message")]
+ Cipher(PrivateMessage) = 2u16,
+ Welcome(Welcome) = 3u16,
+ GroupInfo(GroupInfo) = 4u16,
+ KeyPackage(KeyPackage) = 5u16,
+}
+
+impl From<PublicMessage> for MlsMessagePayload {
+ fn from(m: PublicMessage) -> Self {
+ Self::Plain(m)
+ }
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type)]
+#[derive(
+ Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd, MlsSize, MlsEncode, MlsDecode,
+)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[repr(u16)]
+#[non_exhaustive]
+/// Content description of an [`MlsMessage`]
+pub enum WireFormat {
+ PublicMessage = 1u16,
+ PrivateMessage = 2u16,
+ Welcome = 3u16,
+ GroupInfo = 4u16,
+ KeyPackage = 5u16,
+}
+
+#[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct FramedContent {
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ pub group_id: Vec<u8>,
+ pub epoch: u64,
+ pub sender: Sender,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ pub authenticated_data: Vec<u8>,
+ pub content: Content,
+}
+
+impl Debug for FramedContent {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("FramedContent")
+ .field(
+ "group_id",
+ &mls_rs_core::debug::pretty_group_id(&self.group_id),
+ )
+ .field("epoch", &self.epoch)
+ .field("sender", &self.sender)
+ .field(
+ "authenticated_data",
+ &mls_rs_core::debug::pretty_bytes(&self.authenticated_data),
+ )
+ .field("content", &self.content)
+ .finish()
+ }
+}
+
+impl FramedContent {
+ pub fn content_type(&self) -> ContentType {
+ self.content.content_type()
+ }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+ #[cfg(feature = "private_message")]
+ use crate::group::test_utils::random_bytes;
+
+ use crate::group::{AuthenticatedContent, MessageSignature};
+
+ use super::*;
+
+ use alloc::boxed::Box;
+
+ pub(crate) fn get_test_auth_content() -> AuthenticatedContent {
+ // This is not a valid commit and should not be validated
+ let commit = Commit {
+ proposals: Default::default(),
+ path: None,
+ };
+
+ AuthenticatedContent {
+ wire_format: WireFormat::PublicMessage,
+ content: FramedContent {
+ group_id: Vec::new(),
+ epoch: 0,
+ sender: Sender::Member(1),
+ authenticated_data: Vec::new(),
+ content: Content::Commit(Box::new(commit)),
+ },
+ auth: FramedContentAuthData {
+ signature: MessageSignature::empty(),
+ confirmation_tag: None,
+ },
+ }
+ }
+
+ #[cfg(feature = "private_message")]
+ pub(crate) fn get_test_ciphertext_content() -> PrivateMessageContent {
+ PrivateMessageContent {
+ content: Content::Application(random_bytes(1024).into()),
+ auth: FramedContentAuthData {
+ signature: MessageSignature::from(random_bytes(128)),
+ confirmation_tag: None,
+ },
+ }
+ }
+
+ impl AsRef<[u8]> for ApplicationData {
+ fn as_ref(&self) -> &[u8] {
+ &self.0
+ }
+ }
+}
+
+#[cfg(feature = "private_message")]
+#[cfg(test)]
+mod tests {
+ use assert_matches::assert_matches;
+
+ use crate::{
+ client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+ crypto::test_utils::test_cipher_suite_provider,
+ group::{
+ framing::test_utils::get_test_ciphertext_content,
+ proposal_ref::test_utils::auth_content_from_proposal, RemoveProposal,
+ },
+ };
+
+ use super::*;
+
+ #[test]
+ fn test_mls_ciphertext_content_mls_encoding() {
+ let ciphertext_content = get_test_ciphertext_content();
+
+ let mut encoded = ciphertext_content.mls_encode_to_vec().unwrap();
+ encoded.extend_from_slice(&[0u8; 128]);
+
+ let decoded =
+ PrivateMessageContent::mls_decode(&mut &*encoded, (&ciphertext_content.content).into())
+ .unwrap();
+
+ assert_eq!(ciphertext_content, decoded);
+ }
+
+ #[test]
+ fn test_mls_ciphertext_content_non_zero_padding_error() {
+ let ciphertext_content = get_test_ciphertext_content();
+
+ let mut encoded = ciphertext_content.mls_encode_to_vec().unwrap();
+ encoded.extend_from_slice(&[1u8; 128]);
+
+ let decoded =
+ PrivateMessageContent::mls_decode(&mut &*encoded, (&ciphertext_content.content).into());
+
+ assert_matches!(decoded, Err(mls_rs_codec::Error::Custom(_)));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn proposal_ref() {
+ let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ let test_auth = auth_content_from_proposal(
+ Proposal::Remove(RemoveProposal {
+ to_remove: LeafIndex(0),
+ }),
+ Sender::External(0),
+ );
+
+ let expected_ref = ProposalRef::from_content(&cs, &test_auth).await.unwrap();
+
+ let test_message = MlsMessage {
+ version: TEST_PROTOCOL_VERSION,
+ payload: MlsMessagePayload::Plain(PublicMessage {
+ content: test_auth.content,
+ auth: test_auth.auth,
+ membership_tag: Some(cs.mac(&[1, 2, 3], &[1, 2, 3]).await.unwrap().into()),
+ }),
+ };
+
+ let computed_ref = test_message
+ .into_proposal_reference(&cs)
+ .await
+ .unwrap()
+ .unwrap();
+
+ assert_eq!(computed_ref, expected_ref.to_vec());
+ }
+}
diff --git a/src/group/group_info.rs b/src/group/group_info.rs
new file mode 100644
index 0000000..a5e7268
--- /dev/null
+++ b/src/group/group_info.rs
@@ -0,0 +1,95 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::extension::ExtensionList;
+
+use crate::{signer::Signable, tree_kem::node::LeafIndex};
+
+use super::{ConfirmationTag, GroupContext};
+
+#[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+pub struct GroupInfo {
+ pub(crate) group_context: GroupContext,
+ pub(crate) extensions: ExtensionList,
+ pub(crate) confirmation_tag: ConfirmationTag,
+ pub(crate) signer: LeafIndex,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ pub(crate) signature: Vec<u8>,
+}
+
+impl Debug for GroupInfo {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("GroupInfo")
+ .field("group_context", &self.group_context)
+ .field("extensions", &self.extensions)
+ .field("confirmation_tag", &self.confirmation_tag)
+ .field("signer", &self.signer)
+ .field(
+ "signature",
+ &mls_rs_core::debug::pretty_bytes(&self.signature),
+ )
+ .finish()
+ }
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
+impl GroupInfo {
+ /// Group context.
+ pub fn group_context(&self) -> &GroupContext {
+ &self.group_context
+ }
+
+ /// Group info extensions (not to be confused with group context extensions),
+ /// e.g. the ratchet tree.
+ pub fn extensions(&self) -> &ExtensionList {
+ &self.extensions
+ }
+
+ /// Leaf index of the sender who generated and signed this group info.
+ pub fn sender(&self) -> u32 {
+ *self.signer
+ }
+}
+
+#[derive(MlsEncode, MlsSize)]
+struct SignableGroupInfo<'a> {
+ group_context: &'a GroupContext,
+ extensions: &'a ExtensionList,
+ confirmation_tag: &'a ConfirmationTag,
+ signer: LeafIndex,
+}
+
+impl<'a> Signable<'a> for GroupInfo {
+ const SIGN_LABEL: &'static str = "GroupInfoTBS";
+ type SigningContext = ();
+
+ fn signature(&self) -> &[u8] {
+ &self.signature
+ }
+
+ fn signable_content(
+ &self,
+ _context: &Self::SigningContext,
+ ) -> Result<Vec<u8>, mls_rs_codec::Error> {
+ SignableGroupInfo {
+ group_context: &self.group_context,
+ extensions: &self.extensions,
+ confirmation_tag: &self.confirmation_tag,
+ signer: self.signer,
+ }
+ .mls_encode_to_vec()
+ }
+
+ fn write_signature(&mut self, signature: Vec<u8>) {
+ self.signature = signature
+ }
+}
diff --git a/src/group/interop_test_vectors.rs b/src/group/interop_test_vectors.rs
new file mode 100644
index 0000000..abd82fc
--- /dev/null
+++ b/src/group/interop_test_vectors.rs
@@ -0,0 +1,9 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+mod framing;
+mod passive_client;
+mod serialization;
+mod tree_kem;
+mod tree_modifications;
diff --git a/src/group/interop_test_vectors/framing.rs b/src/group/interop_test_vectors/framing.rs
new file mode 100644
index 0000000..30e4225
--- /dev/null
+++ b/src/group/interop_test_vectors/framing.rs
@@ -0,0 +1,461 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec;
+use alloc::vec::Vec;
+use mls_rs_codec::{MlsDecode, MlsEncode};
+use mls_rs_core::crypto::{CipherSuite, CipherSuiteProvider, SignaturePublicKey};
+
+use crate::{
+ client::test_utils::{TestClientConfig, TEST_PROTOCOL_VERSION},
+ crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider},
+ group::{
+ confirmation_tag::ConfirmationTag,
+ epoch::EpochSecrets,
+ framing::{Content, WireFormat},
+ message_processor::{EventOrContent, MessageProcessor},
+ mls_rules::EncryptionOptions,
+ padding::PaddingMode,
+ proposal::{Proposal, RemoveProposal},
+ secret_tree::test_utils::get_test_tree,
+ test_utils::{random_bytes, test_group_custom_config},
+ AuthenticatedContent, Commit, Group, GroupContext, MlsMessage, Sender,
+ },
+ mls_rules::DefaultMlsRules,
+ test_utils::is_edwards,
+ tree_kem::{leaf_node::test_utils::get_basic_test_node, node::LeafIndex},
+};
+
+const FRAMING_N_LEAVES: u32 = 2;
+
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+struct FramingTestCase {
+ #[serde(flatten)]
+ pub context: InteropGroupContext,
+
+ #[serde(with = "hex::serde")]
+ pub signature_priv: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub signature_pub: Vec<u8>,
+
+ #[serde(with = "hex::serde")]
+ pub encryption_secret: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub sender_data_secret: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub membership_key: Vec<u8>,
+
+ #[serde(with = "hex::serde")]
+ pub proposal: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub proposal_priv: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub proposal_pub: Vec<u8>,
+
+ #[serde(with = "hex::serde")]
+ pub commit: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub commit_priv: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub commit_pub: Vec<u8>,
+
+ #[serde(with = "hex::serde")]
+ pub application: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub application_priv: Vec<u8>,
+}
+
+impl FramingTestCase {
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ async fn random<P: CipherSuiteProvider>(cs: &P) -> Self {
+ let mut context = InteropGroupContext::random(cs);
+ context.cipher_suite = cs.cipher_suite().into();
+
+ let (mut signature_priv, signature_pub) = cs.signature_key_generate().await.unwrap();
+
+ if is_edwards(*cs.cipher_suite()) {
+ signature_priv = signature_priv[0..signature_priv.len() / 2].to_vec().into();
+ }
+
+ Self {
+ context,
+ signature_priv: signature_priv.to_vec(),
+ signature_pub: signature_pub.to_vec(),
+ encryption_secret: random_bytes(cs.kdf_extract_size()),
+ sender_data_secret: random_bytes(cs.kdf_extract_size()),
+ membership_key: random_bytes(cs.kdf_extract_size()),
+ ..Default::default()
+ }
+ }
+}
+
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+pub struct InteropGroupContext {
+ pub cipher_suite: u16,
+ #[serde(with = "hex::serde")]
+ pub group_id: Vec<u8>,
+ pub epoch: u64,
+ #[serde(with = "hex::serde")]
+ pub tree_hash: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub confirmed_transcript_hash: Vec<u8>,
+}
+
+impl InteropGroupContext {
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn random<P: CipherSuiteProvider>(cs: &P) -> Self {
+ Self {
+ cipher_suite: cs.cipher_suite().into(),
+ group_id: random_bytes(cs.kdf_extract_size()),
+ epoch: 0x121212,
+ tree_hash: random_bytes(cs.kdf_extract_size()),
+ confirmed_transcript_hash: random_bytes(cs.kdf_extract_size()),
+ }
+ }
+}
+
+impl From<InteropGroupContext> for GroupContext {
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn from(ctx: InteropGroupContext) -> Self {
+ Self {
+ cipher_suite: ctx.cipher_suite.into(),
+ protocol_version: TEST_PROTOCOL_VERSION,
+ group_id: ctx.group_id,
+ epoch: ctx.epoch,
+ tree_hash: ctx.tree_hash,
+ confirmed_transcript_hash: ctx.confirmed_transcript_hash.into(),
+ extensions: vec![].into(),
+ }
+ }
+}
+
+// The test vector can be found here:
+// https://github.com/mlswg/mls-implementations/blob/main/test-vectors/message-protection.json
+#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+async fn framing_proposal() {
+ #[cfg(not(mls_build_async))]
+ let test_cases: Vec<FramingTestCase> =
+ load_test_case_json!(framing, generate_framing_test_vector());
+
+ #[cfg(mls_build_async)]
+ let test_cases: Vec<FramingTestCase> =
+ load_test_case_json!(framing, generate_framing_test_vector().await);
+
+ for test_case in test_cases.into_iter() {
+ let Some(cs) = try_test_cipher_suite_provider(test_case.context.cipher_suite) else {
+ continue;
+ };
+
+ let to_check = vec![
+ test_case.proposal_priv.clone(),
+ test_case.proposal_pub.clone(),
+ ];
+
+ // Wasm uses incompatible signature secret key format
+ #[cfg(not(target_arch = "wasm32"))]
+ let mut to_check = to_check;
+
+ #[cfg(not(target_arch = "wasm32"))]
+ for enable_encryption in [true, false] {
+ let proposal = Proposal::mls_decode(&mut &*test_case.proposal).unwrap();
+
+ let built = make_group(&test_case, true, enable_encryption, &cs)
+ .await
+ .proposal_message(proposal, vec![])
+ .await
+ .unwrap()
+ .mls_encode_to_vec()
+ .unwrap();
+
+ to_check.push(built);
+ }
+
+ let proposal = Proposal::mls_decode(&mut &*test_case.proposal).unwrap();
+
+ for message in to_check {
+ match process_message(&test_case, &message, &cs).await {
+ Content::Proposal(p) => assert_eq!(p.as_ref(), &proposal),
+ _ => panic!("received value not proposal"),
+ };
+ }
+ }
+}
+
+// The test vector can be found here:
+// https://github.com/mlswg/mls-implementations/blob/main/test-vectors/message-protection.json
+// Wasm uses incompatible signature secret key format
+#[cfg(not(target_arch = "wasm32"))]
+#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+async fn framing_application() {
+ #[cfg(not(mls_build_async))]
+ let test_cases: Vec<FramingTestCase> =
+ load_test_case_json!(framing, generate_framing_test_vector());
+
+ #[cfg(mls_build_async)]
+ let test_cases: Vec<FramingTestCase> =
+ load_test_case_json!(framing, generate_framing_test_vector().await);
+
+ for test_case in test_cases.into_iter() {
+ let Some(cs) = try_test_cipher_suite_provider(test_case.context.cipher_suite) else {
+ continue;
+ };
+
+ let built_priv = make_group(&test_case, true, true, &cs)
+ .await
+ .encrypt_application_message(&test_case.application, vec![])
+ .await
+ .unwrap()
+ .mls_encode_to_vec()
+ .unwrap();
+
+ for message in [&test_case.application_priv, &built_priv] {
+ match process_message(&test_case, message, &cs).await {
+ Content::Application(data) => assert_eq!(data.as_ref(), &test_case.application),
+ _ => panic!("decrypted value not application data"),
+ };
+ }
+ }
+}
+
+// The test vector can be found here:
+// https://github.com/mlswg/mls-implementations/blob/main/test-vectors/message-protection.json
+#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+async fn framing_commit() {
+ #[cfg(not(mls_build_async))]
+ let test_cases: Vec<FramingTestCase> =
+ load_test_case_json!(framing, generate_framing_test_vector());
+
+ #[cfg(mls_build_async)]
+ let test_cases: Vec<FramingTestCase> =
+ load_test_case_json!(framing, generate_framing_test_vector().await);
+
+ for test_case in test_cases.into_iter() {
+ let Some(cs) = try_test_cipher_suite_provider(test_case.context.cipher_suite) else {
+ continue;
+ };
+
+ let commit = Commit::mls_decode(&mut &*test_case.commit).unwrap();
+
+ let to_check = vec![test_case.commit_priv.clone(), test_case.commit_pub.clone()];
+
+ // Wasm uses incompatible signature secret key format
+ #[cfg(not(target_arch = "wasm32"))]
+ let to_check = {
+ let mut to_check = to_check;
+
+ let mut signature_priv = test_case.signature_priv.clone();
+
+ if is_edwards(test_case.context.cipher_suite) {
+ signature_priv.extend(test_case.signature_pub.iter());
+ }
+
+ let mut auth_content = AuthenticatedContent::new_signed(
+ &cs,
+ &test_case.context.clone().into(),
+ Sender::Member(1),
+ Content::Commit(alloc::boxed::Box::new(commit.clone())),
+ &signature_priv.into(),
+ WireFormat::PublicMessage,
+ vec![],
+ )
+ .await
+ .unwrap();
+
+ auth_content.auth.confirmation_tag = Some(ConfirmationTag::empty(&cs).await);
+
+ for enable_encryption in [true, false] {
+ let built = make_group(&test_case, true, enable_encryption, &cs)
+ .await
+ .format_for_wire(auth_content.clone())
+ .await
+ .unwrap()
+ .mls_encode_to_vec()
+ .unwrap();
+
+ to_check.push(built);
+ }
+
+ to_check
+ };
+
+ for message in to_check {
+ match process_message(&test_case, &message, &cs).await {
+ Content::Commit(c) => assert_eq!(&*c, &commit),
+ _ => panic!("received value not commit"),
+ };
+ }
+ let commit = Commit::mls_decode(&mut &*test_case.commit).unwrap();
+
+ match process_message(&test_case, &test_case.commit_priv.clone(), &cs).await {
+ Content::Commit(c) => assert_eq!(&*c, &commit),
+ _ => panic!("received value not commit"),
+ };
+ }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+async fn generate_framing_test_vector() -> Vec<FramingTestCase> {
+ let mut test_vector = vec![];
+
+ for cs in CipherSuite::all() {
+ let cs = test_cipher_suite_provider(cs);
+
+ let mut test_case = FramingTestCase::random(&cs).await;
+
+ // Generate private application message
+ test_case.application = cs.random_bytes_vec(42).unwrap();
+
+ let application_priv = make_group(&test_case, true, true, &cs)
+ .await
+ .encrypt_application_message(&test_case.application, vec![])
+ .await
+ .unwrap();
+
+ test_case.application_priv = application_priv.mls_encode_to_vec().unwrap();
+
+ // Generate private and public proposal message
+ let proposal = Proposal::Remove(RemoveProposal {
+ to_remove: LeafIndex(2),
+ });
+
+ test_case.proposal = proposal.mls_encode_to_vec().unwrap();
+
+ let mut group = make_group(&test_case, true, false, &cs).await;
+ let proposal_pub = group.proposal_message(proposal.clone(), vec![]).await;
+ test_case.proposal_pub = proposal_pub.unwrap().mls_encode_to_vec().unwrap();
+
+ let mut group = make_group(&test_case, true, true, &cs).await;
+ let proposal_priv = group.proposal_message(proposal, vec![]).await.unwrap();
+ test_case.proposal_priv = proposal_priv.mls_encode_to_vec().unwrap();
+
+ // Generate private and public commit message
+ let commit = Commit {
+ proposals: vec![],
+ path: None,
+ };
+
+ test_case.commit = commit.mls_encode_to_vec().unwrap();
+
+ let mut auth_content = AuthenticatedContent::new_signed(
+ &cs,
+ group.context(),
+ Sender::Member(1),
+ Content::Commit(alloc::boxed::Box::new(commit.clone())),
+ &group.signer,
+ WireFormat::PublicMessage,
+ vec![],
+ )
+ .await
+ .unwrap();
+
+ auth_content.auth.confirmation_tag = Some(ConfirmationTag::empty(&cs).await);
+
+ let mut group = make_group(&test_case, true, false, &cs).await;
+ let commit_pub = group.format_for_wire(auth_content.clone()).await.unwrap();
+ test_case.commit_pub = commit_pub.mls_encode_to_vec().unwrap();
+
+ let mut auth_content = AuthenticatedContent::new_signed(
+ &cs,
+ group.context(),
+ Sender::Member(1),
+ Content::Commit(alloc::boxed::Box::new(commit)),
+ &group.signer,
+ WireFormat::PrivateMessage,
+ vec![],
+ )
+ .await
+ .unwrap();
+
+ auth_content.auth.confirmation_tag = Some(ConfirmationTag::empty(&cs).await);
+
+ let mut group = make_group(&test_case, true, true, &cs).await;
+ let commit_priv = group.format_for_wire(auth_content.clone()).await.unwrap();
+ test_case.commit_priv = commit_priv.mls_encode_to_vec().unwrap();
+
+ test_vector.push(test_case);
+ }
+
+ test_vector
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn make_group<P: CipherSuiteProvider>(
+ test_case: &FramingTestCase,
+ for_send: bool,
+ control_encryption_enabled: bool,
+ cs: &P,
+) -> Group<TestClientConfig> {
+ let mut group =
+ test_group_custom_config(
+ TEST_PROTOCOL_VERSION,
+ test_case.context.cipher_suite.into(),
+ |b| {
+ b.mls_rules(DefaultMlsRules::default().with_encryption_options(
+ EncryptionOptions::new(control_encryption_enabled, PaddingMode::None),
+ ))
+ },
+ )
+ .await
+ .group;
+
+ // Add a leaf for the sender. It will get index 1.
+ let mut leaf = get_basic_test_node(cs.cipher_suite(), "leaf").await;
+
+ leaf.signing_identity.signature_key = SignaturePublicKey::from(test_case.signature_pub.clone());
+
+ group
+ .state
+ .public_tree
+ .add_leaves(vec![leaf], &group.config.0.identity_provider, cs)
+ .await
+ .unwrap();
+
+ // Convince the group that their index is 1 if they send or 0 if they receive.
+ group.private_tree.self_index = LeafIndex(if for_send { 1 } else { 0 });
+
+ // Convince the group that their signing key is the one from the test case
+ let mut signature_priv = test_case.signature_priv.clone();
+
+ if is_edwards(test_case.context.cipher_suite) {
+ signature_priv.extend(test_case.signature_pub.iter());
+ }
+
+ group.signer = signature_priv.into();
+
+ // Set the group context and secrets
+ let context = GroupContext::from(test_case.context.clone());
+ let secret_tree = get_test_tree(test_case.encryption_secret.clone(), FRAMING_N_LEAVES);
+
+ let secrets = EpochSecrets {
+ secret_tree,
+ resumption_secret: vec![0_u8; cs.kdf_extract_size()].into(),
+ sender_data_secret: test_case.sender_data_secret.clone().into(),
+ };
+
+ group.epoch_secrets = secrets;
+ group.state.context = context;
+ let membership_key = test_case.membership_key.clone();
+ group.key_schedule.set_membership_key(membership_key);
+
+ group
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn process_message<P: CipherSuiteProvider>(
+ test_case: &FramingTestCase,
+ message: &[u8],
+ cs: &P,
+) -> Content {
+ // Enabling encryption doesn't matter for processing
+ let mut group = make_group(test_case, false, true, cs).await;
+ let message = MlsMessage::mls_decode(&mut &*message).unwrap();
+ let evt_or_cont = group.get_event_from_incoming_message(message);
+
+ match evt_or_cont.await.unwrap() {
+ EventOrContent::Content(content) => content.content.content,
+ EventOrContent::Event(_) => panic!("expected content, got event"),
+ }
+}
diff --git a/src/group/interop_test_vectors/passive_client.rs b/src/group/interop_test_vectors/passive_client.rs
new file mode 100644
index 0000000..29588ed
--- /dev/null
+++ b/src/group/interop_test_vectors/passive_client.rs
@@ -0,0 +1,732 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec;
+use alloc::vec::Vec;
+
+use itertools::Itertools;
+use mls_rs_core::{
+ crypto::{CipherSuite, CipherSuiteProvider, CryptoProvider},
+ identity::SigningIdentity,
+ protocol_version::ProtocolVersion,
+ psk::ExternalPskId,
+ time::MlsTime,
+};
+use rand::{seq::IteratorRandom, Rng, SeedableRng};
+
+use crate::{
+ client_builder::{ClientBuilder, MlsConfig},
+ crypto::test_utils::TestCryptoProvider,
+ group::{ClientConfig, CommitBuilder, ExportedTree},
+ identity::basic::BasicIdentityProvider,
+ key_package::KeyPackageGeneration,
+ mls_rules::CommitOptions,
+ storage_provider::in_memory::InMemoryKeyPackageStorage,
+ test_utils::{
+ all_process_message, generate_basic_client, get_test_basic_credential, get_test_groups,
+ make_test_ext_psk, TEST_EXT_PSK_ID,
+ },
+ tree_kem::Lifetime,
+ Client, Group, MlsMessage,
+};
+
+const VERSION: ProtocolVersion = ProtocolVersion::MLS_10;
+
+const ETERNAL_LIFETIME: Lifetime = Lifetime {
+ not_before: 0,
+ not_after: u64::MAX,
+};
+
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+pub struct TestCase {
+ pub cipher_suite: u16,
+
+ pub external_psks: Vec<TestExternalPsk>,
+ #[serde(with = "hex::serde")]
+ pub key_package: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub signature_priv: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub encryption_priv: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub init_priv: Vec<u8>,
+
+ #[serde(with = "hex::serde")]
+ pub welcome: Vec<u8>,
+ pub ratchet_tree: Option<TestRatchetTree>,
+ #[serde(with = "hex::serde")]
+ pub initial_epoch_authenticator: Vec<u8>,
+
+ pub epochs: Vec<TestEpoch>,
+}
+
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+pub struct TestExternalPsk {
+ #[serde(with = "hex::serde")]
+ pub psk_id: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub psk: Vec<u8>,
+}
+
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+pub struct TestEpoch {
+ pub proposals: Vec<TestMlsMessage>,
+ #[serde(with = "hex::serde")]
+ pub commit: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub epoch_authenticator: Vec<u8>,
+}
+
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+pub struct TestMlsMessage(#[serde(with = "hex::serde")] pub Vec<u8>);
+
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+pub struct TestRatchetTree(#[serde(with = "hex::serde")] pub Vec<u8>);
+
+impl TestEpoch {
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ pub fn new(
+ proposals: Vec<MlsMessage>,
+ commit: &MlsMessage,
+ epoch_authenticator: Vec<u8>,
+ ) -> Self {
+ let proposals = proposals
+ .into_iter()
+ .map(|p| TestMlsMessage(p.to_bytes().unwrap()))
+ .collect();
+
+ Self {
+ proposals,
+ commit: commit.to_bytes().unwrap(),
+ epoch_authenticator,
+ }
+ }
+}
+
+#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+async fn interop_passive_client() {
+ // Test vectors can be found here:
+ // * https://github.com/mlswg/mls-implementations/blob/main/test-vectors/passive-client-welcome.json
+ // * https://github.com/mlswg/mls-implementations/blob/main/test-vectors/passive-client-handle-commit.json
+ // * https://github.com/mlswg/mls-implementations/blob/main/test-vectors/passive-client-random.json
+
+ #[cfg(mls_build_async)]
+ let (test_cases_wel, test_cases_com, test_cases_rand) = {
+ let test_cases_wel: Vec<TestCase> = load_test_case_json!(
+ interop_passive_client_welcome,
+ generate_passive_client_welcome_tests().await
+ );
+
+ let test_cases_com: Vec<TestCase> = load_test_case_json!(
+ interop_passive_client_handle_commit,
+ generate_passive_client_proposal_tests().await
+ );
+
+ let test_cases_rand: Vec<TestCase> = load_test_case_json!(
+ interop_passive_client_random,
+ generate_passive_client_random_tests().await
+ );
+
+ (test_cases_wel, test_cases_com, test_cases_rand)
+ };
+
+ #[cfg(not(mls_build_async))]
+ let (test_cases_wel, test_cases_com, test_cases_rand) = {
+ let test_cases_wel: Vec<TestCase> = load_test_case_json!(
+ interop_passive_client_welcome,
+ generate_passive_client_welcome_tests()
+ );
+
+ let test_cases_com: Vec<TestCase> = load_test_case_json!(
+ interop_passive_client_handle_commit,
+ generate_passive_client_proposal_tests()
+ );
+
+ let test_cases_rand: Vec<TestCase> = load_test_case_json!(
+ interop_passive_client_random,
+ generate_passive_client_random_tests()
+ );
+
+ (test_cases_wel, test_cases_com, test_cases_rand)
+ };
+
+ for test_case in vec![]
+ .into_iter()
+ .chain(test_cases_com)
+ .chain(test_cases_wel)
+ .chain(test_cases_rand)
+ {
+ let crypto_provider = TestCryptoProvider::new();
+ let Some(cs) = crypto_provider.cipher_suite_provider(test_case.cipher_suite.into()) else {
+ continue;
+ };
+
+ let message = MlsMessage::from_bytes(&test_case.key_package).unwrap();
+ let key_package = message.into_key_package().unwrap();
+ let id = key_package.leaf_node.signing_identity.clone();
+ let key = test_case.signature_priv.clone().into();
+
+ let mut client_builder = ClientBuilder::new()
+ .crypto_provider(crypto_provider)
+ .identity_provider(BasicIdentityProvider::new());
+
+ for psk in test_case.external_psks {
+ client_builder = client_builder.psk(ExternalPskId::new(psk.psk_id), psk.psk.into());
+ }
+
+ let client = client_builder
+ .signing_identity(id, key, cs.cipher_suite())
+ .build();
+
+ let key_pckg_gen = KeyPackageGeneration {
+ reference: key_package.to_reference(&cs).await.unwrap(),
+ key_package,
+ init_secret_key: test_case.init_priv.into(),
+ leaf_node_secret_key: test_case.encryption_priv.into(),
+ };
+
+ let (id, pkg) = key_pckg_gen.to_storage().unwrap();
+ client.config.key_package_repo().insert(id, pkg);
+
+ let welcome = MlsMessage::from_bytes(&test_case.welcome).unwrap();
+
+ let tree = test_case
+ .ratchet_tree
+ .map(|t| ExportedTree::from_bytes(&t.0).unwrap());
+
+ let (mut group, _info) = client.join_group(tree, &welcome).await.unwrap();
+
+ assert_eq!(
+ group.epoch_authenticator().unwrap().to_vec(),
+ test_case.initial_epoch_authenticator
+ );
+
+ for epoch in test_case.epochs {
+ for proposal in epoch.proposals.iter() {
+ let message = MlsMessage::from_bytes(&proposal.0).unwrap();
+
+ group
+ .process_incoming_message_with_time(message, MlsTime::now())
+ .await
+ .unwrap();
+ }
+
+ let message = MlsMessage::from_bytes(&epoch.commit).unwrap();
+
+ group
+ .process_incoming_message_with_time(message, MlsTime::now())
+ .await
+ .unwrap();
+
+ assert_eq!(
+ epoch.epoch_authenticator,
+ group.epoch_authenticator().unwrap().to_vec()
+ );
+ }
+ }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+async fn invite_passive_client<P: CipherSuiteProvider>(
+ groups: &mut [Group<impl MlsConfig>],
+ with_psk: bool,
+ cs: &P,
+) -> TestCase {
+ let crypto_provider = TestCryptoProvider::new();
+
+ let (secret_key, public_key) = cs.signature_key_generate().await.unwrap();
+ let credential = get_test_basic_credential(b"Arnold".to_vec());
+ let identity = SigningIdentity::new(credential, public_key);
+ let key_package_repo = InMemoryKeyPackageStorage::new();
+
+ let client = ClientBuilder::new()
+ .crypto_provider(crypto_provider)
+ .identity_provider(BasicIdentityProvider::new())
+ .key_package_repo(key_package_repo.clone())
+ .key_package_lifetime(ETERNAL_LIFETIME.not_after - ETERNAL_LIFETIME.not_before)
+ .key_package_not_before(ETERNAL_LIFETIME.not_before)
+ .signing_identity(identity.clone(), secret_key.clone(), cs.cipher_suite())
+ .build();
+
+ let key_pckg = client.generate_key_package_message().await.unwrap();
+
+ let (_, key_pckg_secrets) = key_package_repo.key_packages()[0].clone();
+
+ let mut commit_builder = groups[0]
+ .commit_builder()
+ .add_member(key_pckg.clone())
+ .unwrap();
+
+ if with_psk {
+ commit_builder = commit_builder
+ .add_external_psk(ExternalPskId::new(TEST_EXT_PSK_ID.to_vec()))
+ .unwrap();
+ }
+
+ let commit = commit_builder.build().await.unwrap();
+
+ all_process_message(groups, &commit.commit_message, 0, true).await;
+
+ let external_psk = TestExternalPsk {
+ psk_id: TEST_EXT_PSK_ID.to_vec(),
+ psk: make_test_ext_psk(),
+ };
+
+ TestCase {
+ cipher_suite: cs.cipher_suite().into(),
+ key_package: key_pckg.to_bytes().unwrap(),
+ encryption_priv: key_pckg_secrets.leaf_node_key.to_vec(),
+ init_priv: key_pckg_secrets.init_key.to_vec(),
+ welcome: commit.welcome_messages[0].to_bytes().unwrap(),
+ initial_epoch_authenticator: groups[0].epoch_authenticator().unwrap().to_vec(),
+ epochs: vec![],
+ signature_priv: secret_key.to_vec(),
+ external_psks: if with_psk { vec![external_psk] } else { vec![] },
+ ratchet_tree: None,
+ }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+pub async fn generate_passive_client_proposal_tests() -> Vec<TestCase> {
+ let mut test_cases: Vec<TestCase> = vec![];
+
+ for cs in CipherSuite::all() {
+ let crypto_provider = TestCryptoProvider::new();
+ let Some(cs) = crypto_provider.cipher_suite_provider(cs) else {
+ continue;
+ };
+
+ let mut groups =
+ get_test_groups(VERSION, cs.cipher_suite(), 7, None, false, &crypto_provider).await;
+
+ let mut partial_test_case = invite_passive_client(&mut groups, true, &cs).await;
+
+ // Create a new epoch s.t. the passive member can process resumption PSK from the current one
+ let commit = groups[0].commit(vec![]).await.unwrap();
+ all_process_message(&mut groups, &commit.commit_message, 0, true).await;
+
+ partial_test_case.epochs.push(TestEpoch::new(
+ vec![],
+ &commit.commit_message,
+ groups[0].epoch_authenticator().unwrap().to_vec(),
+ ));
+
+ let psk = ExternalPskId::new(TEST_EXT_PSK_ID.to_vec());
+ let key_pckg = create_key_package(cs.cipher_suite()).await;
+
+ // Create by value proposals
+ let test_case = commit_by_value(
+ &mut groups[3].clone(),
+ |b| b.add_member(key_pckg.clone()).unwrap(),
+ partial_test_case.clone(),
+ )
+ .await;
+
+ test_cases.push(test_case);
+
+ let test_case = commit_by_value(
+ &mut groups[3].clone(),
+ |b| b.remove_member(5).unwrap(),
+ partial_test_case.clone(),
+ )
+ .await;
+
+ test_cases.push(test_case);
+
+ let test_case = commit_by_value(
+ &mut groups[1].clone(),
+ |b| b.add_external_psk(psk.clone()).unwrap(),
+ partial_test_case.clone(),
+ )
+ .await;
+
+ test_cases.push(test_case);
+
+ let test_case = commit_by_value(
+ &mut groups[5].clone(),
+ |b| b.add_resumption_psk(groups[1].current_epoch() - 1).unwrap(),
+ partial_test_case.clone(),
+ )
+ .await;
+
+ test_cases.push(test_case);
+
+ let test_case = commit_by_value(
+ &mut groups[2].clone(),
+ |b| b.set_group_context_ext(Default::default()).unwrap(),
+ partial_test_case.clone(),
+ )
+ .await;
+
+ test_cases.push(test_case);
+
+ let test_case = commit_by_value(
+ &mut groups[3].clone(),
+ |b| {
+ b.add_member(key_pckg)
+ .unwrap()
+ .remove_member(5)
+ .unwrap()
+ .add_external_psk(psk.clone())
+ .unwrap()
+ .add_resumption_psk(groups[4].current_epoch() - 1)
+ .unwrap()
+ .set_group_context_ext(Default::default())
+ .unwrap()
+ },
+ partial_test_case.clone(),
+ )
+ .await;
+
+ test_cases.push(test_case);
+
+ // Create by reference proposals
+ let add = groups[0]
+ .propose_add(create_key_package(cs.cipher_suite()).await, vec![])
+ .await
+ .unwrap();
+
+ let add = (add, 0);
+
+ let update = (groups[1].propose_update(vec![]).await.unwrap(), 1);
+ let remove = (groups[2].propose_remove(2, vec![]).await.unwrap(), 2);
+
+ let ext_psk = groups[3]
+ .propose_external_psk(psk.clone(), vec![])
+ .await
+ .unwrap();
+
+ let ext_psk = (ext_psk, 3);
+
+ let last_ep = groups[3].current_epoch() - 1;
+
+ let res_psk = groups[3]
+ .propose_resumption_psk(last_ep, vec![])
+ .await
+ .unwrap();
+
+ let res_psk = (res_psk, 3);
+
+ let grp_ext = groups[4]
+ .propose_group_context_extensions(Default::default(), vec![])
+ .await
+ .unwrap();
+
+ let grp_ext = (grp_ext, 4);
+
+ let proposals = [add, update, remove, ext_psk, res_psk, grp_ext];
+
+ for (p, sender) in &proposals {
+ let mut groups = groups.clone();
+
+ all_process_message(&mut groups, p, *sender, false).await;
+
+ let commit = groups[5].commit(vec![]).await.unwrap().commit_message;
+
+ groups[5].apply_pending_commit().await.unwrap();
+ let auth = groups[5].epoch_authenticator().unwrap().to_vec();
+
+ let mut test_case = partial_test_case.clone();
+ let epoch = TestEpoch::new(vec![p.clone()], &commit, auth);
+ test_case.epochs.push(epoch);
+
+ test_cases.push(test_case);
+ }
+
+ let mut group = groups[4].clone();
+
+ for (p, _) in proposals.iter().filter(|(_, i)| *i != 4) {
+ group.process_incoming_message(p.clone()).await.unwrap();
+ }
+
+ let commit = group.commit(vec![]).await.unwrap().commit_message;
+ group.apply_pending_commit().await.unwrap();
+ let auth = group.epoch_authenticator().unwrap().to_vec();
+ let mut test_case = partial_test_case.clone();
+ let proposals = proposals.into_iter().map(|(p, _)| p).collect();
+ let epoch = TestEpoch::new(proposals, &commit, auth);
+ test_case.epochs.push(epoch);
+ test_cases.push(test_case);
+ }
+
+ test_cases
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+async fn commit_by_value<F, C: MlsConfig>(
+ group: &mut Group<C>,
+ proposal_adder: F,
+ partial_test_case: TestCase,
+) -> TestCase
+where
+ F: FnOnce(CommitBuilder<C>) -> CommitBuilder<C>,
+{
+ let builder = proposal_adder(group.commit_builder());
+ let commit = builder.build().await.unwrap().commit_message;
+ group.apply_pending_commit().await.unwrap();
+ let auth = group.epoch_authenticator().unwrap().to_vec();
+ let epoch = TestEpoch::new(vec![], &commit, auth);
+ let mut test_case = partial_test_case;
+ test_case.epochs.push(epoch);
+ test_case
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+async fn create_key_package(cs: CipherSuite) -> MlsMessage {
+ let client = generate_basic_client(
+ cs,
+ VERSION,
+ 0xbeef,
+ None,
+ false,
+ &TestCryptoProvider::new(),
+ Some(ETERNAL_LIFETIME),
+ )
+ .await;
+
+ client.generate_key_package_message().await.unwrap()
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+pub async fn generate_passive_client_welcome_tests() -> Vec<TestCase> {
+ let mut test_cases: Vec<TestCase> = vec![];
+
+ for cs in CipherSuite::all() {
+ let crypto_provider = TestCryptoProvider::new();
+ let Some(cs) = crypto_provider.cipher_suite_provider(cs) else {
+ continue;
+ };
+
+ for with_tree_in_extension in [true, false] {
+ for (with_psk, with_path) in [false, true].into_iter().cartesian_product([true, false])
+ {
+ let options = CommitOptions::new()
+ .with_path_required(with_path)
+ .with_ratchet_tree_extension(with_tree_in_extension);
+
+ let mut groups = get_test_groups(
+ VERSION,
+ cs.cipher_suite(),
+ 16,
+ Some(options),
+ false,
+ &crypto_provider,
+ )
+ .await;
+
+ // Remove a member s.t. the passive member joins in their place
+ let proposal = groups[0].propose_remove(7, vec![]).await.unwrap();
+ all_process_message(&mut groups, &proposal, 0, false).await;
+
+ let mut test_case = invite_passive_client(&mut groups, with_psk, &cs).await;
+
+ if !with_tree_in_extension {
+ let tree = groups[0].export_tree().to_bytes().unwrap();
+ test_case.ratchet_tree = Some(TestRatchetTree(tree));
+ }
+
+ test_cases.push(test_case);
+ }
+ }
+ }
+
+ test_cases
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+pub async fn generate_passive_client_random_tests() -> Vec<TestCase> {
+ let mut test_cases: Vec<TestCase> = vec![];
+
+ for cs in CipherSuite::all() {
+ let crypto = TestCryptoProvider::new();
+ let Some(csp) = crypto.cipher_suite_provider(cs) else {
+ continue;
+ };
+
+ let creator =
+ generate_basic_client(cs, VERSION, 0, None, false, &crypto, Some(ETERNAL_LIFETIME))
+ .await;
+
+ let creator_group = creator.create_group(Default::default()).await.unwrap();
+
+ let mut groups = vec![creator_group];
+
+ let mut new_clients = Vec::new();
+
+ for i in 0..10 {
+ new_clients.push(
+ generate_basic_client(
+ cs,
+ VERSION,
+ i + 1,
+ None,
+ false,
+ &crypto,
+ Some(ETERNAL_LIFETIME),
+ )
+ .await,
+ )
+ }
+
+ add_random_members(0, &mut groups, new_clients, None).await;
+
+ let mut test_case = invite_passive_client(&mut groups, false, &csp).await;
+
+ let passive_client_index = 11;
+
+ let seed: <rand::rngs::StdRng as SeedableRng>::Seed = rand::random();
+ let mut rng = rand::rngs::StdRng::from_seed(seed);
+ #[cfg(feature = "std")]
+ println!("generating random commits for seed {}", hex::encode(seed));
+
+ let mut next_free_idx = 11;
+ for _ in 0..100 {
+ // We keep the passive client and another member to send
+ let num_removed = rng.gen_range(0..groups.len() - 2);
+ let num_added = rng.gen_range(1..30);
+
+ let mut members = (0..groups.len())
+ .filter(|i| groups[*i].current_member_index() != passive_client_index)
+ .choose_multiple(&mut rng, num_removed + 1);
+
+ let sender = members.pop().unwrap();
+
+ remove_members(members, sender, &mut groups, Some(&mut test_case)).await;
+
+ let sender = (0..groups.len())
+ .filter(|i| groups[*i].current_member_index() != passive_client_index)
+ .choose(&mut rng)
+ .unwrap();
+
+ let mut new_clients = Vec::new();
+
+ for i in 0..num_added {
+ new_clients.push(
+ generate_basic_client(
+ cs,
+ VERSION,
+ next_free_idx + i,
+ None,
+ false,
+ &crypto,
+ Some(ETERNAL_LIFETIME),
+ )
+ .await,
+ );
+ }
+
+ add_random_members(sender, &mut groups, new_clients, Some(&mut test_case)).await;
+
+ next_free_idx += num_added;
+ }
+
+ test_cases.push(test_case);
+ }
+
+ test_cases
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+pub async fn add_random_members<C: MlsConfig>(
+ committer: usize,
+ groups: &mut Vec<Group<C>>,
+ clients: Vec<Client<C>>,
+ test_case: Option<&mut TestCase>,
+) {
+ let committer_index = groups[committer].current_member_index() as usize;
+
+ let mut key_packages = Vec::new();
+
+ for client in &clients {
+ let key_package = client.generate_key_package_message().await.unwrap();
+ key_packages.push(key_package);
+ }
+
+ let mut add_proposals = Vec::new();
+
+ let committer_group = &mut groups[committer];
+
+ for key_package in key_packages {
+ add_proposals.push(
+ committer_group
+ .propose_add(key_package, vec![])
+ .await
+ .unwrap(),
+ );
+ }
+
+ for p in &add_proposals {
+ all_process_message(groups, p, committer_index, false).await;
+ }
+
+ let commit_output = groups[committer].commit(vec![]).await.unwrap();
+
+ all_process_message(groups, &commit_output.commit_message, committer_index, true).await;
+
+ let auth = groups[committer].epoch_authenticator().unwrap().to_vec();
+ let epoch = TestEpoch::new(add_proposals, &commit_output.commit_message, auth);
+
+ if let Some(tc) = test_case {
+ tc.epochs.push(epoch)
+ };
+
+ let tree_data = groups[committer].export_tree().into_owned();
+
+ for client in &clients {
+ let commit = commit_output.welcome_messages[0].clone();
+
+ let group = client
+ .join_group(Some(tree_data.clone()), &commit)
+ .await
+ .unwrap()
+ .0;
+
+ groups.push(group);
+ }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+pub async fn remove_members<C: MlsConfig>(
+ removed_members: Vec<usize>,
+ committer: usize,
+ groups: &mut Vec<Group<C>>,
+ test_case: Option<&mut TestCase>,
+) {
+ let remove_indexes = removed_members
+ .iter()
+ .map(|removed| groups[*removed].current_member_index())
+ .collect::<Vec<u32>>();
+
+ let mut commit_builder = groups[committer].commit_builder();
+
+ for index in remove_indexes {
+ commit_builder = commit_builder.remove_member(index).unwrap();
+ }
+
+ let commit = commit_builder.build().await.unwrap().commit_message;
+ let committer_index = groups[committer].current_member_index() as usize;
+ all_process_message(groups, &commit, committer_index, true).await;
+
+ let auth = groups[committer].epoch_authenticator().unwrap().to_vec();
+ let epoch = TestEpoch::new(vec![], &commit, auth);
+
+ if let Some(tc) = test_case {
+ tc.epochs.push(epoch)
+ };
+
+ let mut index = 0;
+
+ groups.retain(|_| {
+ index += 1;
+ !(removed_members.contains(&(index - 1)))
+ });
+}
diff --git a/src/group/interop_test_vectors/serialization.rs b/src/group/interop_test_vectors/serialization.rs
new file mode 100644
index 0000000..cbaf6fa
--- /dev/null
+++ b/src/group/interop_test_vectors/serialization.rs
@@ -0,0 +1,169 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+use mls_rs_codec::{MlsDecode, MlsEncode};
+
+use mls_rs_core::extension::ExtensionList;
+
+use crate::{
+ group::{
+ framing::ContentType,
+ proposal::{
+ AddProposal, ExternalInit, PreSharedKeyProposal, ReInitProposal, RemoveProposal,
+ UpdateProposal,
+ },
+ Commit, GroupSecrets, MlsMessage,
+ },
+ tree_kem::node::NodeVec,
+};
+
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+struct TestCase {
+ #[serde(with = "hex::serde")]
+ mls_welcome: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ mls_group_info: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ mls_key_package: Vec<u8>,
+
+ #[serde(with = "hex::serde")]
+ ratchet_tree: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ group_secrets: Vec<u8>,
+
+ #[serde(with = "hex::serde")]
+ add_proposal: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ update_proposal: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ remove_proposal: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pre_shared_key_proposal: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ re_init_proposal: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ external_init_proposal: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ group_context_extensions_proposal: Vec<u8>,
+
+ #[serde(with = "hex::serde")]
+ commit: Vec<u8>,
+
+ #[serde(with = "hex::serde")]
+ public_message_application: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ public_message_proposal: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ public_message_commit: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ private_message: Vec<u8>,
+}
+
+// The test vector can be found here:
+// https://github.com/mlswg/mls-implementations/blob/main/test-vectors/messages.json
+#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+async fn serialization() {
+ let test_cases: Vec<TestCase> = load_test_case_json!(serialization, Vec::<TestCase>::new());
+
+ for test_case in test_cases.into_iter() {
+ let message = MlsMessage::from_bytes(&test_case.mls_welcome).unwrap();
+ message.clone().into_welcome().unwrap();
+ assert_eq!(&message.to_bytes().unwrap(), &test_case.mls_welcome);
+
+ let message = MlsMessage::from_bytes(&test_case.mls_group_info).unwrap();
+ message.clone().into_group_info().unwrap();
+ assert_eq!(&message.to_bytes().unwrap(), &test_case.mls_group_info);
+
+ let message = MlsMessage::from_bytes(&test_case.mls_key_package).unwrap();
+ message.clone().into_key_package().unwrap();
+ assert_eq!(&message.to_bytes().unwrap(), &test_case.mls_key_package);
+
+ let tree = NodeVec::mls_decode(&mut &*test_case.ratchet_tree).unwrap();
+
+ assert_eq!(&tree.mls_encode_to_vec().unwrap(), &test_case.ratchet_tree);
+
+ let secs = GroupSecrets::mls_decode(&mut &*test_case.group_secrets).unwrap();
+
+ assert_eq!(&secs.mls_encode_to_vec().unwrap(), &test_case.group_secrets);
+
+ let proposal = AddProposal::mls_decode(&mut &*test_case.add_proposal).unwrap();
+
+ assert_eq!(
+ &proposal.mls_encode_to_vec().unwrap(),
+ &test_case.add_proposal
+ );
+
+ let proposal = UpdateProposal::mls_decode(&mut &*test_case.update_proposal).unwrap();
+
+ assert_eq!(
+ &proposal.mls_encode_to_vec().unwrap(),
+ &test_case.update_proposal
+ );
+
+ let proposal = RemoveProposal::mls_decode(&mut &*test_case.remove_proposal).unwrap();
+
+ assert_eq!(
+ &proposal.mls_encode_to_vec().unwrap(),
+ &test_case.remove_proposal
+ );
+
+ let proposal = ReInitProposal::mls_decode(&mut &*test_case.re_init_proposal).unwrap();
+
+ assert_eq!(
+ &proposal.mls_encode_to_vec().unwrap(),
+ &test_case.re_init_proposal
+ );
+
+ let proposal =
+ PreSharedKeyProposal::mls_decode(&mut &*test_case.pre_shared_key_proposal).unwrap();
+
+ assert_eq!(
+ &proposal.mls_encode_to_vec().unwrap(),
+ &test_case.pre_shared_key_proposal
+ );
+
+ let proposal = ExternalInit::mls_decode(&mut &*test_case.external_init_proposal).unwrap();
+
+ assert_eq!(
+ &proposal.mls_encode_to_vec().unwrap(),
+ &test_case.external_init_proposal
+ );
+
+ let proposal =
+ ExtensionList::mls_decode(&mut &*test_case.group_context_extensions_proposal).unwrap();
+
+ assert_eq!(
+ &proposal.mls_encode_to_vec().unwrap(),
+ &test_case.group_context_extensions_proposal
+ );
+
+ let commit = Commit::mls_decode(&mut &*test_case.commit).unwrap();
+
+ assert_eq!(&commit.mls_encode_to_vec().unwrap(), &test_case.commit);
+
+ let message = MlsMessage::from_bytes(&test_case.public_message_application).unwrap();
+ let serialized = message.mls_encode_to_vec().unwrap();
+ assert_eq!(&serialized, &test_case.public_message_application);
+ let content_type = message.into_plaintext().unwrap().content.content_type();
+ assert_eq!(content_type, ContentType::Application);
+
+ let message = MlsMessage::from_bytes(&test_case.public_message_proposal).unwrap();
+ let serialized = message.mls_encode_to_vec().unwrap();
+ assert_eq!(&serialized, &test_case.public_message_proposal);
+ let content_type = message.into_plaintext().unwrap().content.content_type();
+ assert_eq!(content_type, ContentType::Proposal);
+
+ let message = MlsMessage::from_bytes(&test_case.public_message_commit).unwrap();
+ let serialized = message.mls_encode_to_vec().unwrap();
+ assert_eq!(&serialized, &test_case.public_message_commit);
+ let content_type = message.into_plaintext().unwrap().content.content_type();
+ assert_eq!(content_type, ContentType::Commit);
+
+ let message = MlsMessage::from_bytes(&test_case.private_message).unwrap();
+ let serialized = message.mls_encode_to_vec().unwrap();
+ assert_eq!(&serialized, &test_case.private_message);
+ message.into_ciphertext().unwrap();
+ }
+}
diff --git a/src/group/interop_test_vectors/tree_kem.rs b/src/group/interop_test_vectors/tree_kem.rs
new file mode 100644
index 0000000..0a04312
--- /dev/null
+++ b/src/group/interop_test_vectors/tree_kem.rs
@@ -0,0 +1,185 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::{
+ client::test_utils::TEST_PROTOCOL_VERSION,
+ crypto::test_utils::try_test_cipher_suite_provider,
+ group::{
+ confirmation_tag::ConfirmationTag, framing::Content, message_processor::MessageProcessor,
+ message_signature::AuthenticatedContent, test_utils::GroupWithoutKeySchedule, Commit,
+ GroupContext, PathSecret, Sender,
+ },
+ identity::basic::BasicIdentityProvider,
+ tree_kem::{
+ node::{LeafIndex, NodeVec},
+ TreeKemPrivate, TreeKemPublic, UpdatePath,
+ },
+ WireFormat,
+};
+use alloc::vec;
+use alloc::vec::Vec;
+use mls_rs_codec::MlsDecode;
+use mls_rs_core::{crypto::CipherSuiteProvider, extension::ExtensionList};
+
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+struct TreeKemTestCase {
+ pub cipher_suite: u16,
+
+ #[serde(with = "hex::serde")]
+ pub group_id: Vec<u8>,
+ epoch: u64,
+ #[serde(with = "hex::serde")]
+ confirmed_transcript_hash: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ ratchet_tree: Vec<u8>,
+
+ leaves_private: Vec<TestLeafPrivate>,
+ update_paths: Vec<TestUpdatePath>,
+}
+
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+struct TestLeafPrivate {
+ index: u32,
+ #[serde(with = "hex::serde")]
+ encryption_priv: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ signature_priv: Vec<u8>,
+ path_secrets: Vec<TestPathSecretPrivate>,
+}
+
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+struct TestPathSecretPrivate {
+ node: u32,
+ #[serde(with = "hex::serde")]
+ path_secret: Vec<u8>,
+}
+
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+struct TestUpdatePath {
+ sender: u32,
+ #[serde(with = "hex::serde")]
+ update_path: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ tree_hash_after: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ commit_secret: Vec<u8>,
+}
+
+#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+async fn tree_kem() {
+ // The test vector can be found here https://github.com/mlswg/mls-implementations/blob/main/test-vectors/treekem.json
+
+ let test_cases: Vec<TreeKemTestCase> =
+ load_test_case_json!(interop_tree_kem, Vec::<TreeKemTestCase>::new());
+
+ for test_case in test_cases {
+ let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) else {
+ continue;
+ };
+
+ // Import the public ratchet tree
+ let nodes = NodeVec::mls_decode(&mut &*test_case.ratchet_tree).unwrap();
+
+ let mut tree =
+ TreeKemPublic::import_node_data(nodes, &BasicIdentityProvider, &Default::default())
+ .await
+ .unwrap();
+
+ // Construct GroupContext
+ let group_context = GroupContext {
+ protocol_version: TEST_PROTOCOL_VERSION,
+ cipher_suite: cs.cipher_suite(),
+ group_id: test_case.group_id,
+ epoch: test_case.epoch,
+ tree_hash: tree.tree_hash(&cs).await.unwrap(),
+ confirmed_transcript_hash: test_case.confirmed_transcript_hash.into(),
+ extensions: ExtensionList::new(),
+ };
+
+ for leaf in test_case.leaves_private.iter() {
+ // Construct the private ratchet tree
+ let mut tree_private = TreeKemPrivate::new(LeafIndex(leaf.index));
+
+ // Set and validate HPKE keys on direct path
+ let path = tree.nodes.direct_copath(tree_private.self_index);
+
+ tree_private.secret_keys = Vec::new();
+
+ for dp in path {
+ let dp = dp.path;
+
+ let secret = leaf
+ .path_secrets
+ .iter()
+ .find_map(|s| (s.node == dp).then_some(s.path_secret.clone()));
+
+ let private_key = if let Some(secret) = secret {
+ let (secret_key, public_key) = PathSecret::from(secret)
+ .to_hpke_key_pair(&cs)
+ .await
+ .unwrap();
+
+ let tree_public = &tree.nodes.borrow_as_parent(dp).unwrap().public_key;
+ assert_eq!(&public_key, tree_public);
+
+ Some(secret_key)
+ } else {
+ None
+ };
+
+ tree_private.secret_keys.push(private_key);
+ }
+
+ // Set HPKE key for leaf
+ tree_private
+ .secret_keys
+ .insert(0, Some(leaf.encryption_priv.clone().into()));
+
+ let paths = test_case
+ .update_paths
+ .iter()
+ .filter(|path| path.sender != leaf.index);
+
+ for update_path in paths {
+ let mut group = GroupWithoutKeySchedule::new(cs.cipher_suite()).await;
+ group.state.context = group_context.clone();
+ group.state.public_tree = tree.clone();
+ group.private_tree = tree_private.clone();
+
+ let path = UpdatePath::mls_decode(&mut &*update_path.update_path).unwrap();
+
+ let commit = Commit {
+ proposals: vec![],
+ path: Some(path),
+ };
+
+ let mut auth_content = AuthenticatedContent::new(
+ &group_context,
+ Sender::Member(update_path.sender),
+ Content::Commit(alloc::boxed::Box::new(commit)),
+ vec![],
+ WireFormat::PublicMessage,
+ );
+
+ auth_content.auth.confirmation_tag = Some(ConfirmationTag::empty(&cs).await);
+
+ // Hack not to increment epoch
+ group.state.context.epoch -= 1;
+
+ group.process_commit(auth_content, None).await.unwrap();
+
+ // Check that we got the expected commit secret and correctly merged the update path.
+ // This implies that we computed the path secrets correctly.
+ let commit_secret = group.secrets.unwrap().1;
+
+ assert_eq!(&*commit_secret, &update_path.commit_secret);
+
+ let new_tree = &mut group.provisional_public_state.unwrap().public_tree;
+ let new_tree_hash = new_tree.tree_hash(&cs).await.unwrap();
+
+ assert_eq!(&new_tree_hash, &update_path.tree_hash_after);
+ }
+ }
+ }
+}
diff --git a/src/group/interop_test_vectors/tree_modifications.rs b/src/group/interop_test_vectors/tree_modifications.rs
new file mode 100644
index 0000000..a172e0c
--- /dev/null
+++ b/src/group/interop_test_vectors/tree_modifications.rs
@@ -0,0 +1,177 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::boxed::Box;
+use alloc::vec;
+use alloc::vec::Vec;
+use mls_rs_codec::{MlsDecode, MlsEncode};
+use mls_rs_core::crypto::CipherSuite;
+
+use crate::{
+ client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+ crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider},
+ group::{
+ proposal::{AddProposal, Proposal, ProposalOrRef, RemoveProposal, UpdateProposal},
+ proposal_cache::test_utils::CommitReceiver,
+ proposal_ref::ProposalRef,
+ test_utils::TEST_GROUP,
+ LeafIndex, Sender, TreeKemPublic,
+ },
+ identity::basic::BasicIdentityProvider,
+ key_package::test_utils::test_key_package,
+ tree_kem::{
+ leaf_node::test_utils::default_properties, node::NodeVec, test_utils::TreeWithSigners,
+ },
+};
+
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+struct TreeModsTestCase {
+ #[serde(with = "hex::serde")]
+ pub tree_before: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub proposal: Vec<u8>,
+ pub proposal_sender: u32,
+ #[serde(with = "hex::serde")]
+ pub tree_after: Vec<u8>,
+}
+
+impl TreeModsTestCase {
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ async fn new(tree_before: TreeKemPublic, proposal: Proposal, proposal_sender: u32) -> Self {
+ let tree_after = apply_proposal(proposal.clone(), proposal_sender, &tree_before).await;
+
+ Self {
+ tree_before: tree_before.nodes.mls_encode_to_vec().unwrap(),
+ proposal: proposal.mls_encode_to_vec().unwrap(),
+ tree_after: tree_after.nodes.mls_encode_to_vec().unwrap(),
+ proposal_sender,
+ }
+ }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+async fn generate_tree_mods_tests() -> Vec<TreeModsTestCase> {
+ let mut test_vector = vec![];
+ let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ // Update
+ let tree_before = TreeWithSigners::make_full_tree(8, &cs).await;
+ let update = generate_update(6, &tree_before).await;
+ test_vector.push(TreeModsTestCase::new(tree_before.tree, update, 6).await);
+
+ // Add in the middle
+ let mut tree_before = TreeWithSigners::make_full_tree(6, &cs).await;
+ tree_before.remove_member(3);
+ test_vector.push(TreeModsTestCase::new(tree_before.tree, generate_add().await, 2).await);
+
+ // Add at the end
+ let tree_before = TreeWithSigners::make_full_tree(6, &cs).await;
+ test_vector.push(TreeModsTestCase::new(tree_before.tree, generate_add().await, 2).await);
+
+ // Add at the end, tree grows
+ let tree_before = TreeWithSigners::make_full_tree(8, &cs).await;
+ test_vector.push(TreeModsTestCase::new(tree_before.tree, generate_add().await, 2).await);
+
+ // Remove in the middle
+ let tree_before = TreeWithSigners::make_full_tree(8, &cs).await;
+ test_vector.push(TreeModsTestCase::new(tree_before.tree, generate_remove(2), 2).await);
+
+ // Remove at the end
+ let tree_before = TreeWithSigners::make_full_tree(8, &cs).await;
+ test_vector.push(TreeModsTestCase::new(tree_before.tree, generate_remove(7), 2).await);
+
+ // Remove at the end, tree shrinks
+ let tree_before = TreeWithSigners::make_full_tree(9, &cs).await;
+ test_vector.push(TreeModsTestCase::new(tree_before.tree, generate_remove(8), 2).await);
+
+ test_vector
+}
+
+#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+async fn tree_modifications_interop() {
+ // The test vector can be found here https://github.com/mlswg/mls-implementations/blob/main/test-vectors/tree-operations.json
+
+ // All test vectors use cipher suite 1
+ if try_test_cipher_suite_provider(*CipherSuite::CURVE25519_AES128).is_none() {
+ return;
+ }
+
+ #[cfg(not(mls_build_async))]
+ let test_cases: Vec<TreeModsTestCase> =
+ load_test_case_json!(tree_modifications_interop, generate_tree_mods_tests());
+
+ #[cfg(mls_build_async)]
+ let test_cases: Vec<TreeModsTestCase> =
+ load_test_case_json!(tree_modifications_interop, generate_tree_mods_tests().await);
+
+ for test_case in test_cases.into_iter() {
+ let nodes = NodeVec::mls_decode(&mut &*test_case.tree_before).unwrap();
+
+ let tree_before =
+ TreeKemPublic::import_node_data(nodes, &BasicIdentityProvider, &Default::default())
+ .await
+ .unwrap();
+
+ let proposal = Proposal::mls_decode(&mut &*test_case.proposal).unwrap();
+
+ let tree_after = apply_proposal(proposal, test_case.proposal_sender, &tree_before).await;
+
+ let tree_after = tree_after.nodes.mls_encode_to_vec().unwrap();
+
+ assert_eq!(tree_after, test_case.tree_after);
+ }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn apply_proposal(
+ proposal: Proposal,
+ sender: u32,
+ tree_before: &TreeKemPublic,
+) -> TreeKemPublic {
+ let cs = test_cipher_suite_provider(CipherSuite::CURVE25519_AES128);
+ let p_ref = ProposalRef::new_fake(b"fake ref".to_vec());
+
+ CommitReceiver::new(tree_before, Sender::Member(0), LeafIndex(1), cs)
+ .cache(p_ref.clone(), proposal, Sender::Member(sender))
+ .receive(vec![ProposalOrRef::Reference(p_ref)])
+ .await
+ .unwrap()
+ .public_tree
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+async fn generate_add() -> Proposal {
+ let key_package = test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "Roger").await;
+ Proposal::Add(Box::new(AddProposal { key_package }))
+}
+
+#[cfg_attr(coverage_nightly, coverage(off))]
+fn generate_remove(i: u32) -> Proposal {
+ let to_remove = LeafIndex(i);
+ Proposal::Remove(RemoveProposal { to_remove })
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+async fn generate_update(i: u32, tree: &TreeWithSigners) -> Proposal {
+ let signer = tree.signers[i as usize].as_ref().unwrap();
+ let mut leaf_node = tree.tree.get_leaf_node(LeafIndex(i)).unwrap().clone();
+
+ leaf_node
+ .update(
+ &test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ TEST_GROUP,
+ i,
+ default_properties(),
+ None,
+ signer,
+ )
+ .await
+ .unwrap();
+
+ Proposal::Update(UpdateProposal { leaf_node })
+}
diff --git a/src/group/key_schedule.rs b/src/group/key_schedule.rs
new file mode 100644
index 0000000..77c1d65
--- /dev/null
+++ b/src/group/key_schedule.rs
@@ -0,0 +1,988 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::client::MlsError;
+use crate::extension::ExternalPubExt;
+use crate::group::{GroupContext, MembershipTag};
+use crate::psk::secret::PskSecret;
+#[cfg(feature = "psk")]
+use crate::psk::PreSharedKey;
+use crate::tree_kem::path_secret::PathSecret;
+use crate::CipherSuiteProvider;
+
+#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+use crate::group::SecretTree;
+
+use alloc::vec;
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::error::IntoAnyError;
+use zeroize::Zeroizing;
+
+use crate::crypto::{HpkeContextR, HpkeContextS, HpkePublicKey, HpkeSecretKey};
+
+use super::epoch::{EpochSecrets, SenderDataSecret};
+use super::message_signature::AuthenticatedContent;
+
+#[derive(Clone, PartialEq, Eq, Default, MlsEncode, MlsDecode, MlsSize)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct KeySchedule {
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
+ exporter_secret: Zeroizing<Vec<u8>>,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
+ pub authentication_secret: Zeroizing<Vec<u8>>,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
+ external_secret: Zeroizing<Vec<u8>>,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
+ membership_key: Zeroizing<Vec<u8>>,
+ init_secret: InitSecret,
+}
+
+impl Debug for KeySchedule {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("KeySchedule")
+ .field(
+ "exporter_secret",
+ &mls_rs_core::debug::pretty_bytes(&self.exporter_secret),
+ )
+ .field(
+ "authentication_secret",
+ &mls_rs_core::debug::pretty_bytes(&self.authentication_secret),
+ )
+ .field(
+ "external_secret",
+ &mls_rs_core::debug::pretty_bytes(&self.external_secret),
+ )
+ .field(
+ "membership_key",
+ &mls_rs_core::debug::pretty_bytes(&self.membership_key),
+ )
+ .field("init_secret", &self.init_secret)
+ .finish()
+ }
+}
+
+pub(crate) struct KeyScheduleDerivationResult {
+ pub(crate) key_schedule: KeySchedule,
+ pub(crate) confirmation_key: Zeroizing<Vec<u8>>,
+ pub(crate) joiner_secret: JoinerSecret,
+ pub(crate) epoch_secrets: EpochSecrets,
+}
+
+impl KeySchedule {
+ pub fn new(init_secret: InitSecret) -> Self {
+ KeySchedule {
+ init_secret,
+ ..Default::default()
+ }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn derive_for_external<P: CipherSuiteProvider>(
+ &self,
+ kem_output: &[u8],
+ cipher_suite: &P,
+ ) -> Result<KeySchedule, MlsError> {
+ let (secret, public) = self.get_external_key_pair(cipher_suite).await?;
+
+ let init_secret =
+ InitSecret::decode_for_external(cipher_suite, kem_output, &secret, &public).await?;
+
+ Ok(KeySchedule::new(init_secret))
+ }
+
+ /// Returns the derived epoch as well as the joiner secret required for building welcome
+ /// messages
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn from_key_schedule<P: CipherSuiteProvider>(
+ last_key_schedule: &KeySchedule,
+ commit_secret: &PathSecret,
+ context: &GroupContext,
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ secret_tree_size: u32,
+ psk_secret: &PskSecret,
+ cipher_suite_provider: &P,
+ ) -> Result<KeyScheduleDerivationResult, MlsError> {
+ let joiner_seed = cipher_suite_provider
+ .kdf_extract(&last_key_schedule.init_secret.0, commit_secret)
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+ let joiner_secret = kdf_expand_with_label(
+ cipher_suite_provider,
+ &joiner_seed,
+ b"joiner",
+ &context.mls_encode_to_vec()?,
+ None,
+ )
+ .await?
+ .into();
+
+ let key_schedule_result = Self::from_joiner(
+ cipher_suite_provider,
+ &joiner_secret,
+ context,
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ secret_tree_size,
+ psk_secret,
+ )
+ .await?;
+
+ Ok(KeyScheduleDerivationResult {
+ key_schedule: key_schedule_result.key_schedule,
+ confirmation_key: key_schedule_result.confirmation_key,
+ joiner_secret,
+ epoch_secrets: key_schedule_result.epoch_secrets,
+ })
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn from_joiner<P: CipherSuiteProvider>(
+ cipher_suite_provider: &P,
+ joiner_secret: &JoinerSecret,
+ context: &GroupContext,
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ secret_tree_size: u32,
+ psk_secret: &PskSecret,
+ ) -> Result<KeyScheduleDerivationResult, MlsError> {
+ let epoch_seed =
+ get_pre_epoch_secret(cipher_suite_provider, psk_secret, joiner_secret).await?;
+ let context = context.mls_encode_to_vec()?;
+
+ let epoch_secret =
+ kdf_expand_with_label(cipher_suite_provider, &epoch_seed, b"epoch", &context, None)
+ .await?;
+
+ Self::from_epoch_secret(
+ cipher_suite_provider,
+ &epoch_secret,
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ secret_tree_size,
+ )
+ .await
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn from_random_epoch_secret<P: CipherSuiteProvider>(
+ cipher_suite_provider: &P,
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ secret_tree_size: u32,
+ ) -> Result<KeyScheduleDerivationResult, MlsError> {
+ let epoch_secret = cipher_suite_provider
+ .random_bytes_vec(cipher_suite_provider.kdf_extract_size())
+ .map(Zeroizing::new)
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+ Self::from_epoch_secret(
+ cipher_suite_provider,
+ &epoch_secret,
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ secret_tree_size,
+ )
+ .await
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn from_epoch_secret<P: CipherSuiteProvider>(
+ cipher_suite_provider: &P,
+ epoch_secret: &[u8],
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ secret_tree_size: u32,
+ ) -> Result<KeyScheduleDerivationResult, MlsError> {
+ let secrets_producer = SecretsProducer::new(cipher_suite_provider, epoch_secret);
+
+ let epoch_secrets = EpochSecrets {
+ #[cfg(feature = "psk")]
+ resumption_secret: PreSharedKey::from(secrets_producer.derive(b"resumption").await?),
+ sender_data_secret: SenderDataSecret::from(
+ secrets_producer.derive(b"sender data").await?,
+ ),
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ secret_tree: SecretTree::new(
+ secret_tree_size,
+ secrets_producer.derive(b"encryption").await?,
+ ),
+ };
+
+ let key_schedule = Self {
+ exporter_secret: secrets_producer.derive(b"exporter").await?,
+ authentication_secret: secrets_producer.derive(b"authentication").await?,
+ external_secret: secrets_producer.derive(b"external").await?,
+ membership_key: secrets_producer.derive(b"membership").await?,
+ init_secret: InitSecret(secrets_producer.derive(b"init").await?),
+ };
+
+ Ok(KeyScheduleDerivationResult {
+ key_schedule,
+ confirmation_key: secrets_producer.derive(b"confirm").await?,
+ joiner_secret: Zeroizing::new(vec![]).into(),
+ epoch_secrets,
+ })
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn export_secret<P: CipherSuiteProvider>(
+ &self,
+ label: &[u8],
+ context: &[u8],
+ len: usize,
+ cipher_suite: &P,
+ ) -> Result<Zeroizing<Vec<u8>>, MlsError> {
+ let secret = kdf_derive_secret(cipher_suite, &self.exporter_secret, label).await?;
+
+ let context_hash = cipher_suite
+ .hash(context)
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+ kdf_expand_with_label(cipher_suite, &secret, b"exported", &context_hash, Some(len)).await
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn get_membership_tag<P: CipherSuiteProvider>(
+ &self,
+ content: &AuthenticatedContent,
+ context: &GroupContext,
+ cipher_suite_provider: &P,
+ ) -> Result<MembershipTag, MlsError> {
+ MembershipTag::create(
+ content,
+ context,
+ &self.membership_key,
+ cipher_suite_provider,
+ )
+ .await
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn get_external_key_pair<P: CipherSuiteProvider>(
+ &self,
+ cipher_suite: &P,
+ ) -> Result<(HpkeSecretKey, HpkePublicKey), MlsError> {
+ cipher_suite
+ .kem_derive(&self.external_secret)
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn get_external_key_pair_ext<P: CipherSuiteProvider>(
+ &self,
+ cipher_suite: &P,
+ ) -> Result<ExternalPubExt, MlsError> {
+ let (_external_secret, external_pub) = self.get_external_key_pair(cipher_suite).await?;
+
+ Ok(ExternalPubExt { external_pub })
+ }
+}
+
+#[derive(MlsEncode, MlsSize)]
+struct Label<'a> {
+ length: u16,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ label: Vec<u8>,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ context: &'a [u8],
+}
+
+impl<'a> Label<'a> {
+ fn new(length: u16, label: &'a [u8], context: &'a [u8]) -> Self {
+ Self {
+ length,
+ label: [b"MLS 1.0 ", label].concat(),
+ context,
+ }
+ }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn kdf_expand_with_label<P: CipherSuiteProvider>(
+ cipher_suite_provider: &P,
+ secret: &[u8],
+ label: &[u8],
+ context: &[u8],
+ len: Option<usize>,
+) -> Result<Zeroizing<Vec<u8>>, MlsError> {
+ let extract_size = cipher_suite_provider.kdf_extract_size();
+ let len = len.unwrap_or(extract_size);
+ let label = Label::new(len as u16, label, context);
+
+ cipher_suite_provider
+ .kdf_expand(secret, &label.mls_encode_to_vec()?, len)
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn kdf_derive_secret<P: CipherSuiteProvider>(
+ cipher_suite_provider: &P,
+ secret: &[u8],
+ label: &[u8],
+) -> Result<Zeroizing<Vec<u8>>, MlsError> {
+ kdf_expand_with_label(cipher_suite_provider, secret, label, &[], None).await
+}
+
+#[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+pub(crate) struct JoinerSecret(#[mls_codec(with = "mls_rs_codec::byte_vec")] Zeroizing<Vec<u8>>);
+
+impl Debug for JoinerSecret {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ mls_rs_core::debug::pretty_bytes(&self.0)
+ .named("JoinerSecret")
+ .fmt(f)
+ }
+}
+
+impl From<Zeroizing<Vec<u8>>> for JoinerSecret {
+ fn from(bytes: Zeroizing<Vec<u8>>) -> Self {
+ Self(bytes)
+ }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn get_pre_epoch_secret<P: CipherSuiteProvider>(
+ cipher_suite_provider: &P,
+ psk_secret: &PskSecret,
+ joiner_secret: &JoinerSecret,
+) -> Result<Zeroizing<Vec<u8>>, MlsError> {
+ cipher_suite_provider
+ .kdf_extract(&joiner_secret.0, psk_secret)
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+}
+
+struct SecretsProducer<'a, P: CipherSuiteProvider> {
+ cipher_suite_provider: &'a P,
+ epoch_secret: &'a [u8],
+}
+
+impl<'a, P: CipherSuiteProvider> SecretsProducer<'a, P> {
+ fn new(cipher_suite_provider: &'a P, epoch_secret: &'a [u8]) -> Self {
+ Self {
+ cipher_suite_provider,
+ epoch_secret,
+ }
+ }
+
+ // TODO document somewhere in the crypto provider that the RFC defines the length of all secrets as
+ // KDF extract size but then inputs secrets as MAC keys etc, therefore, we require that these
+ // lengths match in the crypto provider
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn derive(&self, label: &[u8]) -> Result<Zeroizing<Vec<u8>>, MlsError> {
+ kdf_derive_secret(self.cipher_suite_provider, self.epoch_secret, label).await
+ }
+}
+
+const EXPORTER_CONTEXT: &[u8] = b"MLS 1.0 external init secret";
+
+#[derive(Clone, Eq, PartialEq, MlsEncode, MlsDecode, MlsSize, Default)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct InitSecret(
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
+ Zeroizing<Vec<u8>>,
+);
+
+impl Debug for InitSecret {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ mls_rs_core::debug::pretty_bytes(&self.0)
+ .named("InitSecret")
+ .fmt(f)
+ }
+}
+
+impl InitSecret {
+ /// Returns init secret and KEM output to be used when creating an external commit.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn encode_for_external<P: CipherSuiteProvider>(
+ cipher_suite: &P,
+ external_pub: &HpkePublicKey,
+ ) -> Result<(Self, Vec<u8>), MlsError> {
+ let (kem_output, context) = cipher_suite
+ .hpke_setup_s(external_pub, &[])
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+ let init_secret = context
+ .export(EXPORTER_CONTEXT, cipher_suite.kdf_extract_size())
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+ Ok((InitSecret(Zeroizing::new(init_secret)), kem_output))
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn decode_for_external<P: CipherSuiteProvider>(
+ cipher_suite: &P,
+ kem_output: &[u8],
+ external_secret: &HpkeSecretKey,
+ external_pub: &HpkePublicKey,
+ ) -> Result<Self, MlsError> {
+ let context = cipher_suite
+ .hpke_setup_r(kem_output, external_secret, external_pub, &[])
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+ context
+ .export(EXPORTER_CONTEXT, cipher_suite.kdf_extract_size())
+ .await
+ .map(Zeroizing::new)
+ .map(InitSecret)
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+ }
+}
+
+pub(crate) struct WelcomeSecret<'a, P: CipherSuiteProvider> {
+ cipher_suite: &'a P,
+ key: Zeroizing<Vec<u8>>,
+ nonce: Zeroizing<Vec<u8>>,
+}
+
+impl<'a, P: CipherSuiteProvider> WelcomeSecret<'a, P> {
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn from_joiner_secret(
+ cipher_suite: &'a P,
+ joiner_secret: &JoinerSecret,
+ psk_secret: &PskSecret,
+ ) -> Result<WelcomeSecret<'a, P>, MlsError> {
+ let welcome_secret = get_welcome_secret(cipher_suite, joiner_secret, psk_secret).await?;
+
+ let key_len = cipher_suite.aead_key_size();
+ let key = kdf_expand_with_label(cipher_suite, &welcome_secret, b"key", &[], Some(key_len))
+ .await?;
+
+ let nonce_len = cipher_suite.aead_nonce_size();
+
+ let nonce = kdf_expand_with_label(
+ cipher_suite,
+ &welcome_secret,
+ b"nonce",
+ &[],
+ Some(nonce_len),
+ )
+ .await?;
+
+ Ok(Self {
+ cipher_suite,
+ key,
+ nonce,
+ })
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>, MlsError> {
+ self.cipher_suite
+ .aead_seal(&self.key, plaintext, None, &self.nonce)
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn decrypt(&self, ciphertext: &[u8]) -> Result<Zeroizing<Vec<u8>>, MlsError> {
+ self.cipher_suite
+ .aead_open(&self.key, ciphertext, None, &self.nonce)
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+ }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn get_welcome_secret<P: CipherSuiteProvider>(
+ cipher_suite: &P,
+ joiner_secret: &JoinerSecret,
+ psk_secret: &PskSecret,
+) -> Result<Zeroizing<Vec<u8>>, MlsError> {
+ let epoch_seed = get_pre_epoch_secret(cipher_suite, psk_secret, joiner_secret).await?;
+ kdf_derive_secret(cipher_suite, &epoch_seed, b"welcome").await
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+ use alloc::vec;
+ use alloc::vec::Vec;
+ use mls_rs_core::crypto::CipherSuiteProvider;
+ use zeroize::Zeroizing;
+
+ use crate::{cipher_suite::CipherSuite, crypto::test_utils::test_cipher_suite_provider};
+
+ use super::{InitSecret, JoinerSecret, KeySchedule};
+
+ #[cfg(all(feature = "rfc_compliant", not(mls_build_async)))]
+ use mls_rs_core::error::IntoAnyError;
+
+ #[cfg(all(feature = "rfc_compliant", not(mls_build_async)))]
+ use super::MlsError;
+
+ impl From<JoinerSecret> for Vec<u8> {
+ fn from(mut value: JoinerSecret) -> Self {
+ core::mem::take(&mut value.0)
+ }
+ }
+
+ pub(crate) fn get_test_key_schedule(cipher_suite: CipherSuite) -> KeySchedule {
+ let key_size = test_cipher_suite_provider(cipher_suite).kdf_extract_size();
+ let fake_secret = Zeroizing::new(vec![1u8; key_size]);
+
+ KeySchedule {
+ exporter_secret: fake_secret.clone(),
+ authentication_secret: fake_secret.clone(),
+ external_secret: fake_secret.clone(),
+ membership_key: fake_secret,
+ init_secret: InitSecret::new(vec![0u8; key_size]),
+ }
+ }
+
+ impl InitSecret {
+ pub fn new(init_secret: Vec<u8>) -> Self {
+ InitSecret(Zeroizing::new(init_secret))
+ }
+
+ #[cfg(all(feature = "rfc_compliant", test, not(mls_build_async)))]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ pub fn random<P: CipherSuiteProvider>(cipher_suite: &P) -> Result<Self, MlsError> {
+ cipher_suite
+ .random_bytes_vec(cipher_suite.kdf_extract_size())
+ .map(Zeroizing::new)
+ .map(InitSecret)
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+ }
+ }
+
+ #[cfg(feature = "rfc_compliant")]
+ impl KeySchedule {
+ pub fn set_membership_key(&mut self, key: Vec<u8>) {
+ self.membership_key = Zeroizing::new(key)
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::client::test_utils::TEST_PROTOCOL_VERSION;
+ use crate::crypto::test_utils::try_test_cipher_suite_provider;
+ use crate::group::key_schedule::{
+ get_welcome_secret, kdf_derive_secret, kdf_expand_with_label,
+ };
+ use crate::group::GroupContext;
+ use alloc::string::String;
+ use alloc::vec::Vec;
+ use mls_rs_codec::MlsEncode;
+ use mls_rs_core::crypto::CipherSuiteProvider;
+ use mls_rs_core::extension::ExtensionList;
+
+ #[cfg(all(not(mls_build_async), feature = "rfc_compliant"))]
+ use crate::{
+ crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider},
+ group::{
+ key_schedule::KeyScheduleDerivationResult, test_utils::random_bytes, InitSecret,
+ PskSecret,
+ },
+ };
+
+ #[cfg(all(not(mls_build_async), feature = "rfc_compliant"))]
+ use alloc::{string::ToString, vec};
+
+ #[cfg(target_arch = "wasm32")]
+ use wasm_bindgen_test::wasm_bindgen_test as test;
+ use zeroize::Zeroizing;
+
+ use super::test_utils::get_test_key_schedule;
+ use super::KeySchedule;
+
+ #[derive(serde::Deserialize, serde::Serialize)]
+ struct TestCase {
+ cipher_suite: u16,
+ #[serde(with = "hex::serde")]
+ group_id: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ initial_init_secret: Vec<u8>,
+ epochs: Vec<KeyScheduleEpoch>,
+ }
+
+ #[derive(serde::Deserialize, serde::Serialize)]
+ struct KeyScheduleEpoch {
+ #[serde(with = "hex::serde")]
+ commit_secret: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ psk_secret: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ confirmed_transcript_hash: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ tree_hash: Vec<u8>,
+
+ #[serde(with = "hex::serde")]
+ group_context: Vec<u8>,
+
+ #[serde(with = "hex::serde")]
+ joiner_secret: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ welcome_secret: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ init_secret: Vec<u8>,
+
+ #[serde(with = "hex::serde")]
+ sender_data_secret: Vec<u8>,
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ #[serde(with = "hex::serde")]
+ encryption_secret: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ exporter_secret: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ epoch_authenticator: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ external_secret: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ confirmation_key: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ membership_key: Vec<u8>,
+ #[cfg(feature = "psk")]
+ #[serde(with = "hex::serde")]
+ resumption_psk: Vec<u8>,
+
+ #[serde(with = "hex::serde")]
+ external_pub: Vec<u8>,
+
+ exporter: KeyScheduleExporter,
+ }
+
+ #[derive(serde::Deserialize, serde::Serialize)]
+ struct KeyScheduleExporter {
+ label: String,
+ #[serde(with = "hex::serde")]
+ context: Vec<u8>,
+ length: usize,
+ #[serde(with = "hex::serde")]
+ secret: Vec<u8>,
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_key_schedule() {
+ let test_cases: Vec<TestCase> =
+ load_test_case_json!(key_schedule_test_vector, generate_test_vector());
+
+ for test_case in test_cases {
+ let Some(cs_provider) = try_test_cipher_suite_provider(test_case.cipher_suite) else {
+ continue;
+ };
+
+ let mut key_schedule = get_test_key_schedule(cs_provider.cipher_suite());
+ key_schedule.init_secret.0 = Zeroizing::new(test_case.initial_init_secret);
+
+ for (i, epoch) in test_case.epochs.into_iter().enumerate() {
+ let context = GroupContext {
+ protocol_version: TEST_PROTOCOL_VERSION,
+ cipher_suite: cs_provider.cipher_suite(),
+ group_id: test_case.group_id.clone(),
+ epoch: i as u64,
+ tree_hash: epoch.tree_hash,
+ confirmed_transcript_hash: epoch.confirmed_transcript_hash.into(),
+ extensions: ExtensionList::new(),
+ };
+
+ assert_eq!(context.mls_encode_to_vec().unwrap(), epoch.group_context);
+
+ let psk = epoch.psk_secret.into();
+ let commit = epoch.commit_secret.into();
+
+ let key_schedule_res = KeySchedule::from_key_schedule(
+ &key_schedule,
+ &commit,
+ &context,
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ 32,
+ &psk,
+ &cs_provider,
+ )
+ .await
+ .unwrap();
+
+ key_schedule = key_schedule_res.key_schedule;
+
+ let welcome =
+ get_welcome_secret(&cs_provider, &key_schedule_res.joiner_secret, &psk)
+ .await
+ .unwrap();
+
+ assert_eq!(*welcome, epoch.welcome_secret);
+
+ let expected: Vec<u8> = key_schedule_res.joiner_secret.into();
+ assert_eq!(epoch.joiner_secret, expected);
+
+ assert_eq!(&key_schedule.init_secret.0.to_vec(), &epoch.init_secret);
+
+ assert_eq!(
+ epoch.sender_data_secret,
+ *key_schedule_res.epoch_secrets.sender_data_secret.to_vec()
+ );
+
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ assert_eq!(
+ epoch.encryption_secret,
+ *key_schedule_res.epoch_secrets.secret_tree.get_root_secret()
+ );
+
+ assert_eq!(epoch.exporter_secret, key_schedule.exporter_secret.to_vec());
+
+ assert_eq!(
+ epoch.epoch_authenticator,
+ key_schedule.authentication_secret.to_vec()
+ );
+
+ assert_eq!(epoch.external_secret, key_schedule.external_secret.to_vec());
+
+ assert_eq!(
+ epoch.confirmation_key,
+ key_schedule_res.confirmation_key.to_vec()
+ );
+
+ assert_eq!(epoch.membership_key, key_schedule.membership_key.to_vec());
+
+ #[cfg(feature = "psk")]
+ {
+ let expected: Vec<u8> =
+ key_schedule_res.epoch_secrets.resumption_secret.to_vec();
+
+ assert_eq!(epoch.resumption_psk, expected);
+ }
+
+ let (_external_sec, external_pub) = key_schedule
+ .get_external_key_pair(&cs_provider)
+ .await
+ .unwrap();
+
+ assert_eq!(epoch.external_pub, *external_pub);
+
+ let exp = epoch.exporter;
+
+ let exported = key_schedule
+ .export_secret(exp.label.as_bytes(), &exp.context, exp.length, &cs_provider)
+ .await
+ .unwrap();
+
+ assert_eq!(exported.to_vec(), exp.secret);
+ }
+ }
+ }
+
+ #[cfg(all(not(mls_build_async), feature = "rfc_compliant"))]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn generate_test_vector() -> Vec<TestCase> {
+ let mut test_cases = vec![];
+
+ for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+ let cs_provider = test_cipher_suite_provider(cipher_suite);
+ let key_size = cs_provider.kdf_extract_size();
+
+ let mut group_context = GroupContext {
+ protocol_version: TEST_PROTOCOL_VERSION,
+ cipher_suite: cs_provider.cipher_suite(),
+ group_id: b"my group 5".to_vec(),
+ epoch: 0,
+ tree_hash: random_bytes(key_size),
+ confirmed_transcript_hash: random_bytes(key_size).into(),
+ extensions: Default::default(),
+ };
+
+ let initial_init_secret = InitSecret::random(&cs_provider).unwrap();
+ let mut key_schedule = get_test_key_schedule(cs_provider.cipher_suite());
+ key_schedule.init_secret = initial_init_secret.clone();
+
+ let commit_secret = random_bytes(key_size).into();
+ let psk_secret = PskSecret::new(&cs_provider);
+
+ let key_schedule_res = KeySchedule::from_key_schedule(
+ &key_schedule,
+ &commit_secret,
+ &group_context,
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ 32,
+ &psk_secret,
+ &cs_provider,
+ )
+ .unwrap();
+
+ key_schedule = key_schedule_res.key_schedule.clone();
+
+ let epoch1 = KeyScheduleEpoch::new(
+ key_schedule_res,
+ psk_secret,
+ commit_secret.to_vec(),
+ &group_context,
+ &cs_provider,
+ );
+
+ group_context.epoch += 1;
+ group_context.confirmed_transcript_hash = random_bytes(key_size).into();
+ group_context.tree_hash = random_bytes(key_size);
+
+ let commit_secret = random_bytes(key_size).into();
+ let psk_secret = PskSecret::new(&cs_provider);
+
+ let key_schedule_res = KeySchedule::from_key_schedule(
+ &key_schedule,
+ &commit_secret,
+ &group_context,
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ 32,
+ &psk_secret,
+ &cs_provider,
+ )
+ .unwrap();
+
+ let epoch2 = KeyScheduleEpoch::new(
+ key_schedule_res,
+ psk_secret,
+ commit_secret.to_vec(),
+ &group_context,
+ &cs_provider,
+ );
+
+ let test_case = TestCase {
+ cipher_suite: cs_provider.cipher_suite().into(),
+ group_id: group_context.group_id.clone(),
+ initial_init_secret: initial_init_secret.0.to_vec(),
+ epochs: vec![epoch1, epoch2],
+ };
+
+ test_cases.push(test_case);
+ }
+
+ test_cases
+ }
+
+ #[cfg(not(all(not(mls_build_async), feature = "rfc_compliant")))]
+ fn generate_test_vector() -> Vec<TestCase> {
+ panic!("Tests cannot be generated in async mode");
+ }
+
+ #[cfg(all(not(mls_build_async), feature = "rfc_compliant"))]
+ impl KeyScheduleEpoch {
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn new<P: CipherSuiteProvider>(
+ key_schedule_res: KeyScheduleDerivationResult,
+ psk_secret: PskSecret,
+ commit_secret: Vec<u8>,
+ group_context: &GroupContext,
+ cs: &P,
+ ) -> Self {
+ let (_external_sec, external_pub) = key_schedule_res
+ .key_schedule
+ .get_external_key_pair(cs)
+ .unwrap();
+
+ let mut exporter = KeyScheduleExporter {
+ label: "exporter label 15".to_string(),
+ context: b"exporter context".to_vec(),
+ length: 64,
+ secret: vec![],
+ };
+
+ exporter.secret = key_schedule_res
+ .key_schedule
+ .export_secret(
+ exporter.label.as_bytes(),
+ &exporter.context,
+ exporter.length,
+ cs,
+ )
+ .unwrap()
+ .to_vec();
+
+ let welcome_secret =
+ get_welcome_secret(cs, &key_schedule_res.joiner_secret, &psk_secret)
+ .unwrap()
+ .to_vec();
+
+ KeyScheduleEpoch {
+ commit_secret,
+ welcome_secret,
+ psk_secret: psk_secret.to_vec(),
+ group_context: group_context.mls_encode_to_vec().unwrap(),
+ joiner_secret: key_schedule_res.joiner_secret.into(),
+ init_secret: key_schedule_res.key_schedule.init_secret.0.to_vec(),
+ sender_data_secret: key_schedule_res.epoch_secrets.sender_data_secret.to_vec(),
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ encryption_secret: key_schedule_res.epoch_secrets.secret_tree.get_root_secret(),
+ exporter_secret: key_schedule_res.key_schedule.exporter_secret.to_vec(),
+ epoch_authenticator: key_schedule_res.key_schedule.authentication_secret.to_vec(),
+ external_secret: key_schedule_res.key_schedule.external_secret.to_vec(),
+ confirmation_key: key_schedule_res.confirmation_key.to_vec(),
+ membership_key: key_schedule_res.key_schedule.membership_key.to_vec(),
+ #[cfg(feature = "psk")]
+ resumption_psk: key_schedule_res.epoch_secrets.resumption_secret.to_vec(),
+ external_pub: external_pub.to_vec(),
+ exporter,
+ confirmed_transcript_hash: group_context.confirmed_transcript_hash.to_vec(),
+ tree_hash: group_context.tree_hash.clone(),
+ }
+ }
+ }
+
+ #[derive(Debug, serde::Serialize, serde::Deserialize)]
+ struct ExpandWithLabelTestCase {
+ #[serde(with = "hex::serde")]
+ secret: Vec<u8>,
+ label: String,
+ #[serde(with = "hex::serde")]
+ context: Vec<u8>,
+ length: usize,
+ #[serde(with = "hex::serde")]
+ out: Vec<u8>,
+ }
+
+ #[derive(Debug, serde::Serialize, serde::Deserialize)]
+ struct DeriveSecretTestCase {
+ #[serde(with = "hex::serde")]
+ secret: Vec<u8>,
+ label: String,
+ #[serde(with = "hex::serde")]
+ out: Vec<u8>,
+ }
+
+ #[derive(Debug, serde::Serialize, serde::Deserialize)]
+ pub struct InteropTestCase {
+ cipher_suite: u16,
+ expand_with_label: ExpandWithLabelTestCase,
+ derive_secret: DeriveSecretTestCase,
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_basic_crypto_test_vectors() {
+ // The test vector can be found here https://github.com/mlswg/mls-implementations/blob/main/test-vectors/crypto-basics.json
+ let test_cases: Vec<InteropTestCase> =
+ load_test_case_json!(basic_crypto, Vec::<InteropTestCase>::new());
+
+ for test_case in test_cases {
+ if let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) {
+ let test_exp = &test_case.expand_with_label;
+
+ let computed = kdf_expand_with_label(
+ &cs,
+ &test_exp.secret,
+ test_exp.label.as_bytes(),
+ &test_exp.context,
+ Some(test_exp.length),
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(&computed.to_vec(), &test_exp.out);
+
+ let test_derive = &test_case.derive_secret;
+
+ let computed =
+ kdf_derive_secret(&cs, &test_derive.secret, test_derive.label.as_bytes())
+ .await
+ .unwrap();
+
+ assert_eq!(&computed.to_vec(), &test_derive.out);
+ }
+ }
+ }
+}
diff --git a/src/group/membership_tag.rs b/src/group/membership_tag.rs
new file mode 100644
index 0000000..b28edea
--- /dev/null
+++ b/src/group/membership_tag.rs
@@ -0,0 +1,163 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::client::MlsError;
+use crate::crypto::CipherSuiteProvider;
+use crate::group::message_signature::{AuthenticatedContentTBS, FramedContentAuthData};
+use crate::group::GroupContext;
+use alloc::vec::Vec;
+use core::{
+ fmt::{self, Debug},
+ ops::Deref,
+};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::error::IntoAnyError;
+
+use super::message_signature::AuthenticatedContent;
+
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode)]
+struct AuthenticatedContentTBM<'a> {
+ content_tbs: AuthenticatedContentTBS<'a>,
+ auth: &'a FramedContentAuthData,
+}
+
+impl<'a> AuthenticatedContentTBM<'a> {
+ pub fn from_authenticated_content(
+ auth_content: &'a AuthenticatedContent,
+ group_context: &'a GroupContext,
+ ) -> AuthenticatedContentTBM<'a> {
+ AuthenticatedContentTBM {
+ content_tbs: AuthenticatedContentTBS::from_authenticated_content(
+ auth_content,
+ Some(group_context),
+ group_context.protocol_version,
+ ),
+ auth: &auth_content.auth,
+ }
+ }
+}
+
+#[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+pub struct MembershipTag(#[mls_codec(with = "mls_rs_codec::byte_vec")] Vec<u8>);
+
+impl Debug for MembershipTag {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ mls_rs_core::debug::pretty_bytes(&self.0)
+ .named("MembershipTag")
+ .fmt(f)
+ }
+}
+
+impl Deref for MembershipTag {
+ type Target = Vec<u8>;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl From<Vec<u8>> for MembershipTag {
+ fn from(m: Vec<u8>) -> Self {
+ Self(m)
+ }
+}
+
+impl MembershipTag {
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn create<P: CipherSuiteProvider>(
+ authenticated_content: &AuthenticatedContent,
+ group_context: &GroupContext,
+ membership_key: &[u8],
+ cipher_suite_provider: &P,
+ ) -> Result<Self, MlsError> {
+ let plaintext_tbm = AuthenticatedContentTBM::from_authenticated_content(
+ authenticated_content,
+ group_context,
+ );
+
+ let serialized_tbm = plaintext_tbm.mls_encode_to_vec()?;
+
+ let tag = cipher_suite_provider
+ .mac(membership_key, &serialized_tbm)
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+ Ok(MembershipTag(tag))
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider};
+ use crate::group::{
+ framing::test_utils::get_test_auth_content, test_utils::get_test_group_context,
+ };
+
+ #[cfg(not(mls_build_async))]
+ use crate::crypto::test_utils::TestCryptoProvider;
+
+ #[cfg(target_arch = "wasm32")]
+ use wasm_bindgen_test::wasm_bindgen_test as test;
+
+ #[derive(Debug, serde::Serialize, serde::Deserialize)]
+ struct TestCase {
+ cipher_suite: u16,
+ #[serde(with = "hex::serde")]
+ tag: Vec<u8>,
+ }
+
+ #[cfg(not(mls_build_async))]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn generate_test_cases() -> Vec<TestCase> {
+ let mut test_cases = Vec::new();
+
+ for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+ let tag = MembershipTag::create(
+ &get_test_auth_content(),
+ &get_test_group_context(1, cipher_suite),
+ b"membership_key".as_ref(),
+ &test_cipher_suite_provider(cipher_suite),
+ )
+ .unwrap();
+
+ test_cases.push(TestCase {
+ cipher_suite: cipher_suite.into(),
+ tag: tag.to_vec(),
+ });
+ }
+
+ test_cases
+ }
+
+ #[cfg(mls_build_async)]
+ fn generate_test_cases() -> Vec<TestCase> {
+ panic!("Tests cannot be generated in async mode");
+ }
+
+ fn load_test_cases() -> Vec<TestCase> {
+ load_test_case_json!(membership_tag, generate_test_cases())
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_membership_tag() {
+ for case in load_test_cases() {
+ let Some(cs_provider) = try_test_cipher_suite_provider(case.cipher_suite) else {
+ continue;
+ };
+
+ let tag = MembershipTag::create(
+ &get_test_auth_content(),
+ &get_test_group_context(1, cs_provider.cipher_suite()).await,
+ b"membership_key".as_ref(),
+ &test_cipher_suite_provider(cs_provider.cipher_suite()),
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(**tag, case.tag);
+ }
+ }
+}
diff --git a/src/group/message_processor.rs b/src/group/message_processor.rs
new file mode 100644
index 0000000..8084a58
--- /dev/null
+++ b/src/group/message_processor.rs
@@ -0,0 +1,1039 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use super::{
+ commit_sender,
+ confirmation_tag::ConfirmationTag,
+ framing::{
+ ApplicationData, Content, ContentType, MlsMessage, MlsMessagePayload, PublicMessage, Sender,
+ },
+ message_signature::AuthenticatedContent,
+ mls_rules::{CommitDirection, MlsRules},
+ proposal_filter::ProposalBundle,
+ state::GroupState,
+ transcript_hash::InterimTranscriptHash,
+ transcript_hashes, validate_group_info_member, GroupContext, GroupInfo, Welcome,
+};
+use crate::{
+ client::MlsError,
+ key_package::validate_key_package_properties,
+ time::MlsTime,
+ tree_kem::{
+ leaf_node_validator::{LeafNodeValidator, ValidationContext},
+ node::LeafIndex,
+ path_secret::PathSecret,
+ validate_update_path, TreeKemPrivate, TreeKemPublic, ValidatedUpdatePath,
+ },
+ CipherSuiteProvider, KeyPackage,
+};
+#[cfg(mls_build_async)]
+use alloc::boxed::Box;
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_core::{
+ identity::IdentityProvider, protocol_version::ProtocolVersion, psk::PreSharedKeyStorage,
+};
+
+#[cfg(feature = "by_ref_proposal")]
+use super::proposal_ref::ProposalRef;
+
+#[cfg(not(feature = "by_ref_proposal"))]
+use crate::group::proposal_cache::resolve_for_commit;
+
+#[cfg(feature = "by_ref_proposal")]
+use super::proposal::Proposal;
+
+#[cfg(feature = "custom_proposal")]
+use super::proposal_filter::ProposalInfo;
+
+#[cfg(feature = "state_update")]
+use mls_rs_core::{
+ crypto::CipherSuite,
+ group::{MemberUpdate, RosterUpdate},
+};
+
+#[cfg(all(feature = "state_update", feature = "psk"))]
+use mls_rs_core::psk::ExternalPskId;
+
+#[cfg(feature = "state_update")]
+use crate::tree_kem::UpdatePath;
+
+#[cfg(feature = "state_update")]
+use super::{member_from_key_package, member_from_leaf_node};
+
+#[cfg(all(feature = "state_update", feature = "custom_proposal"))]
+use super::proposal::CustomProposal;
+
+#[cfg(feature = "private_message")]
+use crate::group::framing::PrivateMessage;
+
+#[cfg(feature = "by_ref_proposal")]
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+
+#[derive(Debug)]
+pub(crate) struct ProvisionalState {
+ pub(crate) public_tree: TreeKemPublic,
+ pub(crate) applied_proposals: ProposalBundle,
+ pub(crate) group_context: GroupContext,
+ pub(crate) external_init_index: Option<LeafIndex>,
+ pub(crate) indexes_of_added_kpkgs: Vec<LeafIndex>,
+ #[cfg(feature = "by_ref_proposal")]
+ pub(crate) unused_proposals: Vec<crate::mls_rules::ProposalInfo<Proposal>>,
+}
+
+//By default, the path field of a Commit MUST be populated. The path field MAY be omitted if
+//(a) it covers at least one proposal and (b) none of the proposals covered by the Commit are
+//of "path required" types. A proposal type requires a path if it cannot change the group
+//membership in a way that requires the forward secrecy and post-compromise security guarantees
+//that an UpdatePath provides. The only proposal types defined in this document that do not
+//require a path are:
+
+// add
+// psk
+// reinit
+pub(crate) fn path_update_required(proposals: &ProposalBundle) -> bool {
+ let res = proposals.external_init_proposals().first().is_some();
+
+ #[cfg(feature = "by_ref_proposal")]
+ let res = res || !proposals.update_proposals().is_empty();
+
+ res || proposals.length() == 0
+ || proposals.group_context_extensions_proposal().is_some()
+ || !proposals.remove_proposals().is_empty()
+}
+
+/// Representation of changes made by a [commit](crate::Group::commit).
+#[cfg(feature = "state_update")]
+#[derive(Clone, Debug, PartialEq)]
+pub struct StateUpdate {
+ pub(crate) roster_update: RosterUpdate,
+ #[cfg(feature = "psk")]
+ pub(crate) added_psks: Vec<ExternalPskId>,
+ pub(crate) pending_reinit: Option<CipherSuite>,
+ pub(crate) active: bool,
+ pub(crate) epoch: u64,
+ #[cfg(feature = "custom_proposal")]
+ pub(crate) custom_proposals: Vec<ProposalInfo<CustomProposal>>,
+ #[cfg(feature = "by_ref_proposal")]
+ pub(crate) unused_proposals: Vec<crate::mls_rules::ProposalInfo<Proposal>>,
+}
+
+#[cfg(not(feature = "state_update"))]
+#[non_exhaustive]
+#[derive(Clone, Debug, PartialEq)]
+pub struct StateUpdate {}
+
+#[cfg(feature = "state_update")]
+impl StateUpdate {
+ /// Changes to the roster as a result of proposals.
+ pub fn roster_update(&self) -> &RosterUpdate {
+ &self.roster_update
+ }
+
+ #[cfg(feature = "psk")]
+ /// Pre-shared keys that have been added to the group.
+ pub fn added_psks(&self) -> &[ExternalPskId] {
+ &self.added_psks
+ }
+
+ /// Flag to indicate if the group is now pending reinitialization due to
+ /// receiving a [`ReInit`](crate::group::proposal::Proposal::ReInit)
+ /// proposal.
+ pub fn is_pending_reinit(&self) -> bool {
+ self.pending_reinit.is_some()
+ }
+
+ /// Flag to indicate the group is still active. This will be false if the
+ /// member processing the commit has been removed from the group.
+ pub fn is_active(&self) -> bool {
+ self.active
+ }
+
+ /// The new epoch of the group state.
+ pub fn new_epoch(&self) -> u64 {
+ self.epoch
+ }
+
+ /// Custom proposals that were committed to.
+ #[cfg(feature = "custom_proposal")]
+ pub fn custom_proposals(&self) -> &[ProposalInfo<CustomProposal>] {
+ &self.custom_proposals
+ }
+
+ /// Proposals that were received in the prior epoch but not committed to.
+ #[cfg(feature = "by_ref_proposal")]
+ pub fn unused_proposals(&self) -> &[crate::mls_rules::ProposalInfo<Proposal>] {
+ &self.unused_proposals
+ }
+
+ pub fn pending_reinit_ciphersuite(&self) -> Option<CipherSuite> {
+ self.pending_reinit
+ }
+}
+
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Debug, Clone)]
+#[allow(clippy::large_enum_variant)]
+/// An event generated as a result of processing a message for a group with
+/// [`Group::process_incoming_message`](crate::group::Group::process_incoming_message).
+pub enum ReceivedMessage {
+ /// An application message was decrypted.
+ ApplicationMessage(ApplicationMessageDescription),
+ /// A new commit was processed creating a new group state.
+ Commit(CommitMessageDescription),
+ /// A proposal was received.
+ Proposal(ProposalMessageDescription),
+ /// Validated GroupInfo object
+ GroupInfo(GroupInfo),
+ /// Validated welcome message
+ Welcome,
+ /// Validated key package
+ KeyPackage(KeyPackage),
+}
+
+impl TryFrom<ApplicationMessageDescription> for ReceivedMessage {
+ type Error = MlsError;
+
+ fn try_from(value: ApplicationMessageDescription) -> Result<Self, Self::Error> {
+ Ok(ReceivedMessage::ApplicationMessage(value))
+ }
+}
+
+impl From<CommitMessageDescription> for ReceivedMessage {
+ fn from(value: CommitMessageDescription) -> Self {
+ ReceivedMessage::Commit(value)
+ }
+}
+
+impl From<ProposalMessageDescription> for ReceivedMessage {
+ fn from(value: ProposalMessageDescription) -> Self {
+ ReceivedMessage::Proposal(value)
+ }
+}
+
+impl From<GroupInfo> for ReceivedMessage {
+ fn from(value: GroupInfo) -> Self {
+ ReceivedMessage::GroupInfo(value)
+ }
+}
+
+impl From<Welcome> for ReceivedMessage {
+ fn from(_: Welcome) -> Self {
+ ReceivedMessage::Welcome
+ }
+}
+
+impl From<KeyPackage> for ReceivedMessage {
+ fn from(value: KeyPackage) -> Self {
+ ReceivedMessage::KeyPackage(value)
+ }
+}
+
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Clone, PartialEq, Eq)]
+/// Description of a MLS application message.
+pub struct ApplicationMessageDescription {
+ /// Index of this user in the group state.
+ pub sender_index: u32,
+ /// Received application data.
+ data: ApplicationData,
+ /// Plaintext authenticated data in the received MLS packet.
+ pub authenticated_data: Vec<u8>,
+}
+
+impl Debug for ApplicationMessageDescription {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("ApplicationMessageDescription")
+ .field("sender_index", &self.sender_index)
+ .field("data", &self.data)
+ .field(
+ "authenticated_data",
+ &mls_rs_core::debug::pretty_bytes(&self.authenticated_data),
+ )
+ .finish()
+ }
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
+impl ApplicationMessageDescription {
+ pub fn data(&self) -> &[u8] {
+ self.data.as_bytes()
+ }
+}
+
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Clone, PartialEq)]
+#[non_exhaustive]
+/// Description of a processed MLS commit message.
+pub struct CommitMessageDescription {
+ /// True if this is the result of an external commit.
+ pub is_external: bool,
+ /// The index in the group state of the member who performed this commit.
+ pub committer: u32,
+ /// A full description of group state changes as a result of this commit.
+ pub state_update: StateUpdate,
+ /// Plaintext authenticated data in the received MLS packet.
+ pub authenticated_data: Vec<u8>,
+}
+
+impl Debug for CommitMessageDescription {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("CommitMessageDescription")
+ .field("is_external", &self.is_external)
+ .field("committer", &self.committer)
+ .field("state_update", &self.state_update)
+ .field(
+ "authenticated_data",
+ &mls_rs_core::debug::pretty_bytes(&self.authenticated_data),
+ )
+ .finish()
+ }
+}
+
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+/// Proposal sender type.
+pub enum ProposalSender {
+ /// A current member of the group by index in the group state.
+ Member(u32),
+ /// An external entity by index within an
+ /// [`ExternalSendersExt`](crate::extension::built_in::ExternalSendersExt).
+ External(u32),
+ /// A new member proposing their addition to the group.
+ NewMember,
+}
+
+impl TryFrom<Sender> for ProposalSender {
+ type Error = MlsError;
+
+ fn try_from(value: Sender) -> Result<Self, Self::Error> {
+ match value {
+ Sender::Member(index) => Ok(Self::Member(index)),
+ #[cfg(feature = "by_ref_proposal")]
+ Sender::External(index) => Ok(Self::External(index)),
+ #[cfg(feature = "by_ref_proposal")]
+ Sender::NewMemberProposal => Ok(Self::NewMember),
+ Sender::NewMemberCommit => Err(MlsError::InvalidSender),
+ }
+ }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Clone)]
+#[non_exhaustive]
+/// Description of a processed MLS proposal message.
+pub struct ProposalMessageDescription {
+ /// Sender of the proposal.
+ pub sender: ProposalSender,
+ /// Proposal content.
+ pub proposal: Proposal,
+ /// Plaintext authenticated data in the received MLS packet.
+ pub authenticated_data: Vec<u8>,
+ /// Proposal reference.
+ pub proposal_ref: ProposalRef,
+}
+
+#[cfg(feature = "by_ref_proposal")]
+impl Debug for ProposalMessageDescription {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("ProposalMessageDescription")
+ .field("sender", &self.sender)
+ .field("proposal", &self.proposal)
+ .field(
+ "authenticated_data",
+ &mls_rs_core::debug::pretty_bytes(&self.authenticated_data),
+ )
+ .field("proposal_ref", &self.proposal_ref)
+ .finish()
+ }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+#[derive(MlsSize, MlsEncode, MlsDecode)]
+pub struct CachedProposal {
+ pub(crate) proposal: Proposal,
+ pub(crate) proposal_ref: ProposalRef,
+ pub(crate) sender: Sender,
+}
+
+#[cfg(feature = "by_ref_proposal")]
+impl CachedProposal {
+ /// Deserialize the proposal
+ pub fn from_bytes(bytes: &[u8]) -> Result<Self, MlsError> {
+ Ok(Self::mls_decode(&mut &*bytes)?)
+ }
+
+ /// Serialize the proposal
+ pub fn to_bytes(&self) -> Result<Vec<u8>, MlsError> {
+ Ok(self.mls_encode_to_vec()?)
+ }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+impl ProposalMessageDescription {
+ pub fn cached_proposal(self) -> CachedProposal {
+ let sender = match self.sender {
+ ProposalSender::Member(i) => Sender::Member(i),
+ ProposalSender::External(i) => Sender::External(i),
+ ProposalSender::NewMember => Sender::NewMemberProposal,
+ };
+
+ CachedProposal {
+ proposal: self.proposal,
+ proposal_ref: self.proposal_ref,
+ sender,
+ }
+ }
+
+ pub fn proposal_ref(&self) -> Vec<u8> {
+ self.proposal_ref.to_vec()
+ }
+}
+
+#[cfg(not(feature = "by_ref_proposal"))]
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Debug, Clone)]
+/// Description of a processed MLS proposal message.
+pub struct ProposalMessageDescription {}
+
+#[allow(clippy::large_enum_variant)]
+pub(crate) enum EventOrContent<E> {
+ #[cfg_attr(
+ not(all(feature = "private_message", feature = "external_client")),
+ allow(dead_code)
+ )]
+ Event(E),
+ Content(AuthenticatedContent),
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
+#[cfg_attr(
+ all(not(target_arch = "wasm32"), mls_build_async),
+ maybe_async::must_be_async
+)]
+pub(crate) trait MessageProcessor: Send + Sync {
+ type OutputType: TryFrom<ApplicationMessageDescription, Error = MlsError>
+ + From<CommitMessageDescription>
+ + From<ProposalMessageDescription>
+ + From<GroupInfo>
+ + From<Welcome>
+ + From<KeyPackage>
+ + Send;
+
+ type MlsRules: MlsRules;
+ type IdentityProvider: IdentityProvider;
+ type CipherSuiteProvider: CipherSuiteProvider;
+ type PreSharedKeyStorage: PreSharedKeyStorage;
+
+ async fn process_incoming_message(
+ &mut self,
+ message: MlsMessage,
+ #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
+ ) -> Result<Self::OutputType, MlsError> {
+ self.process_incoming_message_with_time(
+ message,
+ #[cfg(feature = "by_ref_proposal")]
+ cache_proposal,
+ None,
+ )
+ .await
+ }
+
+ async fn process_incoming_message_with_time(
+ &mut self,
+ message: MlsMessage,
+ #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
+ time_sent: Option<MlsTime>,
+ ) -> Result<Self::OutputType, MlsError> {
+ let event_or_content = self.get_event_from_incoming_message(message).await?;
+
+ self.process_event_or_content(
+ event_or_content,
+ #[cfg(feature = "by_ref_proposal")]
+ cache_proposal,
+ time_sent,
+ )
+ .await
+ }
+
+ async fn get_event_from_incoming_message(
+ &mut self,
+ message: MlsMessage,
+ ) -> Result<EventOrContent<Self::OutputType>, MlsError> {
+ self.check_metadata(&message)?;
+
+ match message.payload {
+ MlsMessagePayload::Plain(plaintext) => {
+ self.verify_plaintext_authentication(plaintext).await
+ }
+ #[cfg(feature = "private_message")]
+ MlsMessagePayload::Cipher(cipher_text) => self.process_ciphertext(&cipher_text).await,
+ MlsMessagePayload::GroupInfo(group_info) => {
+ validate_group_info_member(
+ self.group_state(),
+ message.version,
+ &group_info,
+ self.cipher_suite_provider(),
+ )
+ .await?;
+
+ Ok(EventOrContent::Event(group_info.into()))
+ }
+ MlsMessagePayload::Welcome(welcome) => {
+ self.validate_welcome(&welcome, message.version)?;
+
+ Ok(EventOrContent::Event(welcome.into()))
+ }
+ MlsMessagePayload::KeyPackage(key_package) => {
+ self.validate_key_package(&key_package, message.version)
+ .await?;
+
+ Ok(EventOrContent::Event(key_package.into()))
+ }
+ }
+ }
+
+ async fn process_event_or_content(
+ &mut self,
+ event_or_content: EventOrContent<Self::OutputType>,
+ #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
+ time_sent: Option<MlsTime>,
+ ) -> Result<Self::OutputType, MlsError> {
+ let msg = match event_or_content {
+ EventOrContent::Event(event) => event,
+ EventOrContent::Content(content) => {
+ self.process_auth_content(
+ content,
+ #[cfg(feature = "by_ref_proposal")]
+ cache_proposal,
+ time_sent,
+ )
+ .await?
+ }
+ };
+
+ Ok(msg)
+ }
+
+ async fn process_auth_content(
+ &mut self,
+ auth_content: AuthenticatedContent,
+ #[cfg(feature = "by_ref_proposal")] cache_proposal: bool,
+ time_sent: Option<MlsTime>,
+ ) -> Result<Self::OutputType, MlsError> {
+ let event = match auth_content.content.content {
+ #[cfg(feature = "private_message")]
+ Content::Application(data) => {
+ let authenticated_data = auth_content.content.authenticated_data;
+ let sender = auth_content.content.sender;
+
+ self.process_application_message(data, sender, authenticated_data)
+ .and_then(Self::OutputType::try_from)
+ }
+ Content::Commit(_) => self
+ .process_commit(auth_content, time_sent)
+ .await
+ .map(Self::OutputType::from),
+ #[cfg(feature = "by_ref_proposal")]
+ Content::Proposal(ref proposal) => self
+ .process_proposal(&auth_content, proposal, cache_proposal)
+ .await
+ .map(Self::OutputType::from),
+ }?;
+
+ Ok(event)
+ }
+
+ #[cfg(feature = "private_message")]
+ fn process_application_message(
+ &self,
+ data: ApplicationData,
+ sender: Sender,
+ authenticated_data: Vec<u8>,
+ ) -> Result<ApplicationMessageDescription, MlsError> {
+ let Sender::Member(sender_index) = sender else {
+ return Err(MlsError::InvalidSender);
+ };
+
+ Ok(ApplicationMessageDescription {
+ authenticated_data,
+ sender_index,
+ data,
+ })
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn process_proposal(
+ &mut self,
+ auth_content: &AuthenticatedContent,
+ proposal: &Proposal,
+ cache_proposal: bool,
+ ) -> Result<ProposalMessageDescription, MlsError> {
+ let proposal_ref =
+ ProposalRef::from_content(self.cipher_suite_provider(), auth_content).await?;
+
+ let group_state = self.group_state_mut();
+
+ if cache_proposal {
+ let proposal_ref = proposal_ref.clone();
+
+ group_state.proposals.insert(
+ proposal_ref.clone(),
+ proposal.clone(),
+ auth_content.content.sender,
+ );
+ }
+
+ Ok(ProposalMessageDescription {
+ authenticated_data: auth_content.content.authenticated_data.clone(),
+ proposal: proposal.clone(),
+ sender: auth_content.content.sender.try_into()?,
+ proposal_ref,
+ })
+ }
+
+ #[cfg(feature = "state_update")]
+ async fn make_state_update(
+ &self,
+ provisional: &ProvisionalState,
+ path: Option<&UpdatePath>,
+ sender: LeafIndex,
+ ) -> Result<StateUpdate, MlsError> {
+ let added = provisional
+ .applied_proposals
+ .additions
+ .iter()
+ .zip(provisional.indexes_of_added_kpkgs.iter())
+ .map(|(p, index)| member_from_key_package(&p.proposal.key_package, *index))
+ .collect::<Vec<_>>();
+
+ let mut added = added;
+
+ let old_tree = &self.group_state().public_tree;
+
+ let removed = provisional
+ .applied_proposals
+ .removals
+ .iter()
+ .map(|p| {
+ let index = p.proposal.to_remove;
+ let node = old_tree.nodes.borrow_as_leaf(index)?;
+ Ok(member_from_leaf_node(node, index))
+ })
+ .collect::<Result<_, MlsError>>()?;
+
+ #[cfg(feature = "by_ref_proposal")]
+ let mut updated = provisional
+ .applied_proposals
+ .update_senders
+ .iter()
+ .map(|index| {
+ let prior = old_tree
+ .get_leaf_node(*index)
+ .map(|n| member_from_leaf_node(n, *index))?;
+
+ let new = provisional
+ .public_tree
+ .get_leaf_node(*index)
+ .map(|n| member_from_leaf_node(n, *index))?;
+
+ Ok::<_, MlsError>(MemberUpdate::new(prior, new))
+ })
+ .collect::<Result<Vec<_>, _>>()?;
+
+ #[cfg(not(feature = "by_ref_proposal"))]
+ let mut updated = Vec::new();
+
+ if let Some(path) = path {
+ if !provisional
+ .applied_proposals
+ .external_initializations
+ .is_empty()
+ {
+ added.push(member_from_leaf_node(&path.leaf_node, sender))
+ } else {
+ let prior = old_tree
+ .get_leaf_node(sender)
+ .map(|n| member_from_leaf_node(n, sender))?;
+
+ let new = member_from_leaf_node(&path.leaf_node, sender);
+
+ updated.push(MemberUpdate::new(prior, new))
+ }
+ }
+
+ #[cfg(feature = "psk")]
+ let psks = provisional
+ .applied_proposals
+ .psks
+ .iter()
+ .filter_map(|psk| psk.proposal.external_psk_id().cloned())
+ .collect::<Vec<_>>();
+
+ let roster_update = RosterUpdate::new(added, removed, updated);
+
+ let update = StateUpdate {
+ roster_update,
+ #[cfg(feature = "psk")]
+ added_psks: psks,
+ pending_reinit: provisional
+ .applied_proposals
+ .reinitializations
+ .first()
+ .map(|ri| ri.proposal.new_cipher_suite()),
+ active: true,
+ epoch: provisional.group_context.epoch,
+ #[cfg(feature = "custom_proposal")]
+ custom_proposals: provisional.applied_proposals.custom_proposals.clone(),
+ #[cfg(feature = "by_ref_proposal")]
+ unused_proposals: provisional.unused_proposals.clone(),
+ };
+
+ Ok(update)
+ }
+
+ async fn process_commit(
+ &mut self,
+ auth_content: AuthenticatedContent,
+ time_sent: Option<MlsTime>,
+ ) -> Result<CommitMessageDescription, MlsError> {
+ if self.group_state().pending_reinit.is_some() {
+ return Err(MlsError::GroupUsedAfterReInit);
+ }
+
+ // Update the new GroupContext's confirmed and interim transcript hashes using the new Commit.
+ let (interim_transcript_hash, confirmed_transcript_hash) = transcript_hashes(
+ self.cipher_suite_provider(),
+ &self.group_state().interim_transcript_hash,
+ &auth_content,
+ )
+ .await?;
+
+ #[cfg(any(feature = "private_message", feature = "by_ref_proposal"))]
+ let commit = match auth_content.content.content {
+ Content::Commit(commit) => Ok(commit),
+ _ => Err(MlsError::UnexpectedMessageType),
+ }?;
+
+ #[cfg(not(any(feature = "private_message", feature = "by_ref_proposal")))]
+ let Content::Commit(commit) = auth_content.content.content;
+
+ let group_state = self.group_state();
+ let id_provider = self.identity_provider();
+
+ #[cfg(feature = "by_ref_proposal")]
+ let proposals = group_state
+ .proposals
+ .resolve_for_commit(auth_content.content.sender, commit.proposals)?;
+
+ #[cfg(not(feature = "by_ref_proposal"))]
+ let proposals = resolve_for_commit(auth_content.content.sender, commit.proposals)?;
+
+ let mut provisional_state = group_state
+ .apply_resolved(
+ auth_content.content.sender,
+ proposals,
+ commit.path.as_ref().map(|path| &path.leaf_node),
+ &id_provider,
+ self.cipher_suite_provider(),
+ &self.psk_storage(),
+ &self.mls_rules(),
+ time_sent,
+ CommitDirection::Receive,
+ )
+ .await?;
+
+ let sender = commit_sender(&auth_content.content.sender, &provisional_state)?;
+
+ #[cfg(feature = "state_update")]
+ let mut state_update = self
+ .make_state_update(&provisional_state, commit.path.as_ref(), sender)
+ .await?;
+
+ #[cfg(not(feature = "state_update"))]
+ let state_update = StateUpdate {};
+
+ //Verify that the path value is populated if the proposals vector contains any Update
+ // or Remove proposals, or if it's empty. Otherwise, the path value MAY be omitted.
+ if path_update_required(&provisional_state.applied_proposals) && commit.path.is_none() {
+ return Err(MlsError::CommitMissingPath);
+ }
+
+ if !self.can_continue_processing(&provisional_state) {
+ #[cfg(feature = "state_update")]
+ {
+ state_update.active = false;
+ }
+
+ return Ok(CommitMessageDescription {
+ is_external: matches!(auth_content.content.sender, Sender::NewMemberCommit),
+ authenticated_data: auth_content.content.authenticated_data,
+ committer: *sender,
+ state_update,
+ });
+ }
+
+ let update_path = match commit.path {
+ Some(update_path) => Some(
+ validate_update_path(
+ &self.identity_provider(),
+ self.cipher_suite_provider(),
+ update_path,
+ &provisional_state,
+ sender,
+ time_sent,
+ )
+ .await?,
+ ),
+ None => None,
+ };
+
+ let new_secrets = match update_path {
+ Some(update_path) => {
+ self.apply_update_path(sender, &update_path, &mut provisional_state)
+ .await
+ }
+ None => Ok(None),
+ }?;
+
+ // Update the transcript hash to get the new context.
+ provisional_state.group_context.confirmed_transcript_hash = confirmed_transcript_hash;
+
+ // Update the parent hashes in the new context
+ provisional_state
+ .public_tree
+ .update_hashes(&[sender], self.cipher_suite_provider())
+ .await?;
+
+ // Update the tree hash in the new context
+ provisional_state.group_context.tree_hash = provisional_state
+ .public_tree
+ .tree_hash(self.cipher_suite_provider())
+ .await?;
+
+ if let Some(reinit) = provisional_state.applied_proposals.reinitializations.pop() {
+ self.group_state_mut().pending_reinit = Some(reinit.proposal);
+
+ #[cfg(feature = "state_update")]
+ {
+ state_update.active = false;
+ }
+ }
+
+ if let Some(confirmation_tag) = &auth_content.auth.confirmation_tag {
+ // Update the key schedule to calculate new private keys
+ self.update_key_schedule(
+ new_secrets,
+ interim_transcript_hash,
+ confirmation_tag,
+ provisional_state,
+ )
+ .await?;
+
+ Ok(CommitMessageDescription {
+ is_external: matches!(auth_content.content.sender, Sender::NewMemberCommit),
+ authenticated_data: auth_content.content.authenticated_data,
+ committer: *sender,
+ state_update,
+ })
+ } else {
+ Err(MlsError::InvalidConfirmationTag)
+ }
+ }
+
+ fn group_state(&self) -> &GroupState;
+ fn group_state_mut(&mut self) -> &mut GroupState;
+ #[cfg(feature = "private_message")]
+ fn self_index(&self) -> Option<LeafIndex>;
+ fn mls_rules(&self) -> Self::MlsRules;
+ fn identity_provider(&self) -> Self::IdentityProvider;
+ fn cipher_suite_provider(&self) -> &Self::CipherSuiteProvider;
+ fn psk_storage(&self) -> Self::PreSharedKeyStorage;
+ fn can_continue_processing(&self, provisional_state: &ProvisionalState) -> bool;
+
+ #[cfg(feature = "private_message")]
+ fn min_epoch_available(&self) -> Option<u64>;
+
+ fn check_metadata(&self, message: &MlsMessage) -> Result<(), MlsError> {
+ let context = &self.group_state().context;
+
+ if message.version != context.protocol_version {
+ return Err(MlsError::ProtocolVersionMismatch);
+ }
+
+ if let Some((group_id, epoch, content_type)) = match &message.payload {
+ MlsMessagePayload::Plain(plaintext) => Some((
+ &plaintext.content.group_id,
+ plaintext.content.epoch,
+ plaintext.content.content_type(),
+ )),
+ #[cfg(feature = "private_message")]
+ MlsMessagePayload::Cipher(ciphertext) => Some((
+ &ciphertext.group_id,
+ ciphertext.epoch,
+ ciphertext.content_type,
+ )),
+ _ => None,
+ } {
+ if group_id != &context.group_id {
+ return Err(MlsError::GroupIdMismatch);
+ }
+
+ match content_type {
+ ContentType::Commit => {
+ if context.epoch != epoch {
+ Err(MlsError::InvalidEpoch)
+ } else {
+ Ok(())
+ }
+ }
+ #[cfg(feature = "by_ref_proposal")]
+ ContentType::Proposal => {
+ if context.epoch != epoch {
+ Err(MlsError::InvalidEpoch)
+ } else {
+ Ok(())
+ }
+ }
+ #[cfg(feature = "private_message")]
+ ContentType::Application => {
+ if let Some(min) = self.min_epoch_available() {
+ if epoch < min {
+ Err(MlsError::InvalidEpoch)
+ } else {
+ Ok(())
+ }
+ } else {
+ Ok(())
+ }
+ }
+ }?;
+
+ // Proposal and commit messages must be sent in the current epoch
+ let check_epoch = content_type == ContentType::Commit;
+
+ #[cfg(feature = "by_ref_proposal")]
+ let check_epoch = check_epoch || content_type == ContentType::Proposal;
+
+ if check_epoch && epoch != context.epoch {
+ return Err(MlsError::InvalidEpoch);
+ }
+
+ // Unencrypted application messages are not allowed
+ #[cfg(feature = "private_message")]
+ if !matches!(&message.payload, MlsMessagePayload::Cipher(_))
+ && content_type == ContentType::Application
+ {
+ return Err(MlsError::UnencryptedApplicationMessage);
+ }
+ }
+
+ Ok(())
+ }
+
+ fn validate_welcome(
+ &self,
+ welcome: &Welcome,
+ version: ProtocolVersion,
+ ) -> Result<(), MlsError> {
+ let state = self.group_state();
+
+ (welcome.cipher_suite == state.context.cipher_suite
+ && version == state.context.protocol_version)
+ .then_some(())
+ .ok_or(MlsError::InvalidWelcomeMessage)
+ }
+
+ async fn validate_key_package(
+ &self,
+ key_package: &KeyPackage,
+ version: ProtocolVersion,
+ ) -> Result<(), MlsError> {
+ let cs = self.cipher_suite_provider();
+ let id = self.identity_provider();
+
+ validate_key_package(key_package, version, cs, &id).await
+ }
+
+ #[cfg(feature = "private_message")]
+ async fn process_ciphertext(
+ &mut self,
+ cipher_text: &PrivateMessage,
+ ) -> Result<EventOrContent<Self::OutputType>, MlsError>;
+
+ async fn verify_plaintext_authentication(
+ &self,
+ message: PublicMessage,
+ ) -> Result<EventOrContent<Self::OutputType>, MlsError>;
+
+ async fn apply_update_path(
+ &mut self,
+ sender: LeafIndex,
+ update_path: &ValidatedUpdatePath,
+ provisional_state: &mut ProvisionalState,
+ ) -> Result<Option<(TreeKemPrivate, PathSecret)>, MlsError> {
+ provisional_state
+ .public_tree
+ .apply_update_path(
+ sender,
+ update_path,
+ &provisional_state.group_context.extensions,
+ self.identity_provider(),
+ self.cipher_suite_provider(),
+ )
+ .await
+ .map(|_| None)
+ }
+
+ async fn update_key_schedule(
+ &mut self,
+ secrets: Option<(TreeKemPrivate, PathSecret)>,
+ interim_transcript_hash: InterimTranscriptHash,
+ confirmation_tag: &ConfirmationTag,
+ provisional_public_state: ProvisionalState,
+ ) -> Result<(), MlsError>;
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn validate_key_package<C: CipherSuiteProvider, I: IdentityProvider>(
+ key_package: &KeyPackage,
+ version: ProtocolVersion,
+ cs: &C,
+ id: &I,
+) -> Result<(), MlsError> {
+ let validator = LeafNodeValidator::new(cs, id, None);
+
+ #[cfg(feature = "std")]
+ let context = Some(MlsTime::now());
+
+ #[cfg(not(feature = "std"))]
+ let context = None;
+
+ let context = ValidationContext::Add(context);
+
+ validator
+ .check_if_valid(&key_package.leaf_node, context)
+ .await?;
+
+ validate_key_package_properties(key_package, version, cs).await?;
+
+ Ok(())
+}
diff --git a/src/group/message_signature.rs b/src/group/message_signature.rs
new file mode 100644
index 0000000..3c08935
--- /dev/null
+++ b/src/group/message_signature.rs
@@ -0,0 +1,274 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use super::framing::Content;
+use crate::client::MlsError;
+use crate::crypto::SignatureSecretKey;
+use crate::group::framing::{ContentType, FramedContent, PublicMessage, Sender, WireFormat};
+use crate::group::{ConfirmationTag, GroupContext};
+use crate::signer::Signable;
+use crate::CipherSuiteProvider;
+use alloc::vec;
+use alloc::vec::Vec;
+use core::{
+ fmt::{self, Debug},
+ ops::Deref,
+};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::protocol_version::ProtocolVersion;
+
+#[derive(Clone, Debug, PartialEq)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct FramedContentAuthData {
+ pub signature: MessageSignature,
+ pub confirmation_tag: Option<ConfirmationTag>,
+}
+
+impl MlsSize for FramedContentAuthData {
+ fn mls_encoded_len(&self) -> usize {
+ self.signature.mls_encoded_len()
+ + self
+ .confirmation_tag
+ .as_ref()
+ .map_or(0, |tag| tag.mls_encoded_len())
+ }
+}
+
+impl MlsEncode for FramedContentAuthData {
+ fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
+ self.signature.mls_encode(writer)?;
+
+ if let Some(ref tag) = self.confirmation_tag {
+ tag.mls_encode(writer)?;
+ }
+
+ Ok(())
+ }
+}
+
+impl FramedContentAuthData {
+ pub(crate) fn mls_decode(
+ reader: &mut &[u8],
+ content_type: ContentType,
+ ) -> Result<Self, mls_rs_codec::Error> {
+ Ok(FramedContentAuthData {
+ signature: MessageSignature::mls_decode(reader)?,
+ confirmation_tag: match content_type {
+ ContentType::Commit => Some(ConfirmationTag::mls_decode(reader)?),
+ #[cfg(feature = "private_message")]
+ ContentType::Application => None,
+ #[cfg(feature = "by_ref_proposal")]
+ ContentType::Proposal => None,
+ },
+ })
+ }
+}
+
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct AuthenticatedContent {
+ pub(crate) wire_format: WireFormat,
+ pub(crate) content: FramedContent,
+ pub(crate) auth: FramedContentAuthData,
+}
+
+impl From<PublicMessage> for AuthenticatedContent {
+ fn from(p: PublicMessage) -> Self {
+ Self {
+ wire_format: WireFormat::PublicMessage,
+ content: p.content,
+ auth: p.auth,
+ }
+ }
+}
+
+impl AuthenticatedContent {
+ pub(crate) fn new(
+ context: &GroupContext,
+ sender: Sender,
+ content: Content,
+ authenticated_data: Vec<u8>,
+ wire_format: WireFormat,
+ ) -> AuthenticatedContent {
+ AuthenticatedContent {
+ wire_format,
+ content: FramedContent {
+ group_id: context.group_id.clone(),
+ epoch: context.epoch,
+ sender,
+ authenticated_data,
+ content,
+ },
+ auth: FramedContentAuthData {
+ signature: MessageSignature::empty(),
+ confirmation_tag: None,
+ },
+ }
+ }
+
+ #[inline(never)]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn new_signed<P: CipherSuiteProvider>(
+ signature_provider: &P,
+ context: &GroupContext,
+ sender: Sender,
+ content: Content,
+ signer: &SignatureSecretKey,
+ wire_format: WireFormat,
+ authenticated_data: Vec<u8>,
+ ) -> Result<AuthenticatedContent, MlsError> {
+ // Construct an MlsPlaintext object containing the content
+ let mut plaintext =
+ AuthenticatedContent::new(context, sender, content, authenticated_data, wire_format);
+
+ let signing_context = MessageSigningContext {
+ group_context: Some(context),
+ protocol_version: context.protocol_version,
+ };
+
+ // Sign the MlsPlaintext using the current epoch's GroupContext as context.
+ plaintext
+ .sign(signature_provider, signer, &signing_context)
+ .await?;
+
+ Ok(plaintext)
+ }
+}
+
+impl MlsDecode for AuthenticatedContent {
+ fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
+ let wire_format = WireFormat::mls_decode(reader)?;
+ let content = FramedContent::mls_decode(reader)?;
+ let auth_data = FramedContentAuthData::mls_decode(reader, content.content_type())?;
+
+ Ok(AuthenticatedContent {
+ wire_format,
+ content,
+ auth: auth_data,
+ })
+ }
+}
+
+#[derive(Clone, Debug, PartialEq)]
+pub(crate) struct AuthenticatedContentTBS<'a> {
+ pub(crate) protocol_version: ProtocolVersion,
+ pub(crate) wire_format: WireFormat,
+ pub(crate) content: &'a FramedContent,
+ pub(crate) context: Option<&'a GroupContext>,
+}
+
+impl<'a> MlsSize for AuthenticatedContentTBS<'a> {
+ fn mls_encoded_len(&self) -> usize {
+ self.protocol_version.mls_encoded_len()
+ + self.wire_format.mls_encoded_len()
+ + self.content.mls_encoded_len()
+ + self.context.as_ref().map_or(0, |ctx| ctx.mls_encoded_len())
+ }
+}
+
+impl<'a> MlsEncode for AuthenticatedContentTBS<'a> {
+ fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
+ self.protocol_version.mls_encode(writer)?;
+ self.wire_format.mls_encode(writer)?;
+ self.content.mls_encode(writer)?;
+
+ if let Some(context) = self.context {
+ context.mls_encode(writer)?;
+ }
+
+ Ok(())
+ }
+}
+
+impl<'a> AuthenticatedContentTBS<'a> {
+ /// The group context must not be `None` when the sender is `Member` or `NewMember`.
+ pub(crate) fn from_authenticated_content(
+ auth_content: &'a AuthenticatedContent,
+ group_context: Option<&'a GroupContext>,
+ protocol_version: ProtocolVersion,
+ ) -> Self {
+ AuthenticatedContentTBS {
+ protocol_version,
+ wire_format: auth_content.wire_format,
+ content: &auth_content.content,
+ context: match auth_content.content.sender {
+ Sender::Member(_) | Sender::NewMemberCommit => group_context,
+ #[cfg(feature = "by_ref_proposal")]
+ Sender::External(_) => None,
+ #[cfg(feature = "by_ref_proposal")]
+ Sender::NewMemberProposal => None,
+ },
+ }
+ }
+}
+
+#[derive(Debug)]
+pub(crate) struct MessageSigningContext<'a> {
+ pub group_context: Option<&'a GroupContext>,
+ pub protocol_version: ProtocolVersion,
+}
+
+impl<'a> Signable<'a> for AuthenticatedContent {
+ const SIGN_LABEL: &'static str = "FramedContentTBS";
+
+ type SigningContext = MessageSigningContext<'a>;
+
+ fn signature(&self) -> &[u8] {
+ &self.auth.signature
+ }
+
+ fn signable_content(
+ &self,
+ context: &MessageSigningContext,
+ ) -> Result<Vec<u8>, mls_rs_codec::Error> {
+ AuthenticatedContentTBS::from_authenticated_content(
+ self,
+ context.group_context,
+ context.protocol_version,
+ )
+ .mls_encode_to_vec()
+ }
+
+ fn write_signature(&mut self, signature: Vec<u8>) {
+ self.auth.signature = MessageSignature::from(signature)
+ }
+}
+
+#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct MessageSignature(
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ Vec<u8>,
+);
+
+impl Debug for MessageSignature {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ mls_rs_core::debug::pretty_bytes(&self.0)
+ .named("MessageSignature")
+ .fmt(f)
+ }
+}
+
+impl MessageSignature {
+ pub(crate) fn empty() -> Self {
+ MessageSignature(vec![])
+ }
+}
+
+impl Deref for MessageSignature {
+ type Target = Vec<u8>;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl From<Vec<u8>> for MessageSignature {
+ fn from(v: Vec<u8>) -> Self {
+ MessageSignature(v)
+ }
+}
diff --git a/src/group/message_verifier.rs b/src/group/message_verifier.rs
new file mode 100644
index 0000000..7a2bc59
--- /dev/null
+++ b/src/group/message_verifier.rs
@@ -0,0 +1,680 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+#[cfg(feature = "by_ref_proposal")]
+use alloc::{vec, vec::Vec};
+
+use crate::{
+ client::MlsError,
+ crypto::SignaturePublicKey,
+ group::{GroupContext, PublicMessage, Sender},
+ signer::Signable,
+ tree_kem::{node::LeafIndex, TreeKemPublic},
+ CipherSuiteProvider,
+};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::{extension::ExternalSendersExt, identity::SigningIdentity};
+
+use super::{
+ key_schedule::KeySchedule,
+ message_signature::{AuthenticatedContent, MessageSigningContext},
+ state::GroupState,
+};
+
+#[cfg(feature = "by_ref_proposal")]
+use super::proposal::Proposal;
+
+#[derive(Debug)]
+pub(crate) enum SignaturePublicKeysContainer<'a> {
+ RatchetTree(&'a TreeKemPublic),
+ #[cfg(feature = "private_message")]
+ List(&'a [Option<SignaturePublicKey>]),
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn verify_plaintext_authentication<P: CipherSuiteProvider>(
+ cipher_suite_provider: &P,
+ plaintext: PublicMessage,
+ key_schedule: Option<&KeySchedule>,
+ self_index: Option<LeafIndex>,
+ state: &GroupState,
+) -> Result<AuthenticatedContent, MlsError> {
+ let tag = plaintext.membership_tag.clone();
+ let auth_content = AuthenticatedContent::from(plaintext);
+ let context = &state.context;
+
+ #[cfg(feature = "by_ref_proposal")]
+ let external_signers = external_signers(context);
+
+ let current_tree = &state.public_tree;
+
+ // Verify the membership tag if needed
+ match &auth_content.content.sender {
+ Sender::Member(index) => {
+ if let Some(key_schedule) = key_schedule {
+ let expected_tag = &key_schedule
+ .get_membership_tag(&auth_content, context, cipher_suite_provider)
+ .await?;
+
+ let plaintext_tag = tag.as_ref().ok_or(MlsError::InvalidMembershipTag)?;
+
+ if expected_tag != plaintext_tag {
+ return Err(MlsError::InvalidMembershipTag);
+ }
+ }
+
+ if self_index == Some(LeafIndex(*index)) {
+ return Err(MlsError::CantProcessMessageFromSelf);
+ }
+ }
+ _ => {
+ tag.is_none()
+ .then_some(())
+ .ok_or(MlsError::MembershipTagForNonMember)?;
+ }
+ }
+
+ // Verify that the signature on the MLSAuthenticatedContent verifies using the public key
+ // from the credential stored at the leaf in the tree indicated by the sender field.
+ verify_auth_content_signature(
+ cipher_suite_provider,
+ SignaturePublicKeysContainer::RatchetTree(current_tree),
+ context,
+ &auth_content,
+ #[cfg(feature = "by_ref_proposal")]
+ &external_signers,
+ )
+ .await?;
+
+ Ok(auth_content)
+}
+
+#[cfg(feature = "by_ref_proposal")]
+fn external_signers(context: &GroupContext) -> Vec<SigningIdentity> {
+ context
+ .extensions
+ .get_as::<ExternalSendersExt>()
+ .unwrap_or(None)
+ .map_or(vec![], |extern_senders_ext| {
+ extern_senders_ext.allowed_senders
+ })
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn verify_auth_content_signature<P: CipherSuiteProvider>(
+ cipher_suite_provider: &P,
+ signature_keys_container: SignaturePublicKeysContainer<'_>,
+ context: &GroupContext,
+ auth_content: &AuthenticatedContent,
+ #[cfg(feature = "by_ref_proposal")] external_signers: &[SigningIdentity],
+) -> Result<(), MlsError> {
+ let sender_public_key = signing_identity_for_sender(
+ signature_keys_container,
+ &auth_content.content.sender,
+ &auth_content.content.content,
+ #[cfg(feature = "by_ref_proposal")]
+ external_signers,
+ )?;
+
+ let context = MessageSigningContext {
+ group_context: Some(context),
+ protocol_version: context.protocol_version,
+ };
+
+ auth_content
+ .verify(cipher_suite_provider, &sender_public_key, &context)
+ .await?;
+
+ Ok(())
+}
+
+fn signing_identity_for_sender(
+ signature_keys_container: SignaturePublicKeysContainer,
+ sender: &Sender,
+ content: &super::framing::Content,
+ #[cfg(feature = "by_ref_proposal")] external_signers: &[SigningIdentity],
+) -> Result<SignaturePublicKey, MlsError> {
+ match sender {
+ Sender::Member(leaf_index) => {
+ signing_identity_for_member(signature_keys_container, LeafIndex(*leaf_index))
+ }
+ #[cfg(feature = "by_ref_proposal")]
+ Sender::External(external_key_index) => {
+ signing_identity_for_external(*external_key_index, external_signers)
+ }
+ Sender::NewMemberCommit => signing_identity_for_new_member_commit(content),
+ #[cfg(feature = "by_ref_proposal")]
+ Sender::NewMemberProposal => signing_identity_for_new_member_proposal(content),
+ }
+}
+
+fn signing_identity_for_member(
+ signature_keys_container: SignaturePublicKeysContainer,
+ leaf_index: LeafIndex,
+) -> Result<SignaturePublicKey, MlsError> {
+ match signature_keys_container {
+ SignaturePublicKeysContainer::RatchetTree(tree) => Ok(tree
+ .get_leaf_node(leaf_index)?
+ .signing_identity
+ .signature_key
+ .clone()), // TODO: We can probably get rid of this clone
+ #[cfg(feature = "private_message")]
+ SignaturePublicKeysContainer::List(list) => list
+ .get(leaf_index.0 as usize)
+ .cloned()
+ .flatten()
+ .ok_or(MlsError::LeafNotFound(*leaf_index)),
+ }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+fn signing_identity_for_external(
+ index: u32,
+ external_signers: &[SigningIdentity],
+) -> Result<SignaturePublicKey, MlsError> {
+ external_signers
+ .get(index as usize)
+ .map(|spk| spk.signature_key.clone())
+ .ok_or(MlsError::UnknownSigningIdentityForExternalSender)
+}
+
+fn signing_identity_for_new_member_commit(
+ content: &super::framing::Content,
+) -> Result<SignaturePublicKey, MlsError> {
+ match content {
+ super::framing::Content::Commit(commit) => {
+ if let Some(path) = &commit.path {
+ Ok(path.leaf_node.signing_identity.signature_key.clone())
+ } else {
+ Err(MlsError::CommitMissingPath)
+ }
+ }
+ #[cfg(any(feature = "private_message", feature = "by_ref_proposal"))]
+ _ => Err(MlsError::ExpectedCommitForNewMemberCommit),
+ }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+fn signing_identity_for_new_member_proposal(
+ content: &super::framing::Content,
+) -> Result<SignaturePublicKey, MlsError> {
+ match content {
+ super::framing::Content::Proposal(proposal) => {
+ if let Proposal::Add(p) = proposal.as_ref() {
+ Ok(p.key_package
+ .leaf_node
+ .signing_identity
+ .signature_key
+ .clone())
+ } else {
+ Err(MlsError::ExpectedAddProposalForNewMemberProposal)
+ }
+ }
+ _ => Err(MlsError::ExpectedAddProposalForNewMemberProposal),
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::{
+ client::{
+ test_utils::{test_client_with_key_pkg, TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+ MlsError,
+ },
+ client_builder::test_utils::TestClientConfig,
+ crypto::test_utils::test_cipher_suite_provider,
+ group::{
+ membership_tag::MembershipTag,
+ message_signature::{AuthenticatedContent, MessageSignature},
+ test_utils::{test_group_custom, TestGroup},
+ Group, PublicMessage,
+ },
+ tree_kem::node::LeafIndex,
+ };
+ use alloc::vec;
+ use assert_matches::assert_matches;
+
+ #[cfg(feature = "by_ref_proposal")]
+ use crate::{extension::ExternalSendersExt, ExtensionList};
+
+ #[cfg(feature = "by_ref_proposal")]
+ use crate::{
+ crypto::SignatureSecretKey,
+ group::{
+ message_signature::MessageSigningContext,
+ proposal::{AddProposal, Proposal, RemoveProposal},
+ Content,
+ },
+ key_package::KeyPackageGeneration,
+ signer::Signable,
+ WireFormat,
+ };
+
+ #[cfg(feature = "by_ref_proposal")]
+ use alloc::boxed::Box;
+
+ use crate::group::{
+ test_utils::{test_group, test_member},
+ Sender,
+ };
+
+ #[cfg(feature = "by_ref_proposal")]
+ use crate::identity::test_utils::get_test_signing_identity;
+
+ use super::{verify_auth_content_signature, verify_plaintext_authentication};
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn make_signed_plaintext(group: &mut Group<TestClientConfig>) -> PublicMessage {
+ group
+ .commit(vec![])
+ .await
+ .unwrap()
+ .commit_message
+ .into_plaintext()
+ .unwrap()
+ }
+
+ struct TestEnv {
+ alice: TestGroup,
+ bob: TestGroup,
+ }
+
+ impl TestEnv {
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn new() -> Self {
+ let mut alice = test_group_custom(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ Default::default(),
+ None,
+ None,
+ )
+ .await;
+
+ let (bob_client, bob_key_pkg) =
+ test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
+
+ let commit_output = alice
+ .group
+ .commit_builder()
+ .add_member(bob_key_pkg)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ alice.group.apply_pending_commit().await.unwrap();
+
+ let (bob, _) = Group::join(
+ &commit_output.welcome_messages[0],
+ None,
+ bob_client.config,
+ bob_client.signer.unwrap(),
+ )
+ .await
+ .unwrap();
+
+ TestEnv {
+ alice,
+ bob: TestGroup { group: bob },
+ }
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn valid_plaintext_is_verified() {
+ let mut env = TestEnv::new().await;
+
+ let message = make_signed_plaintext(&mut env.alice.group).await;
+
+ verify_plaintext_authentication(
+ &env.bob.group.cipher_suite_provider,
+ message,
+ Some(&env.bob.group.key_schedule),
+ None,
+ &env.bob.group.state,
+ )
+ .await
+ .unwrap();
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn valid_auth_content_is_verified() {
+ let mut env = TestEnv::new().await;
+
+ let message = AuthenticatedContent::from(make_signed_plaintext(&mut env.alice.group).await);
+
+ verify_auth_content_signature(
+ &env.bob.group.cipher_suite_provider,
+ super::SignaturePublicKeysContainer::RatchetTree(&env.bob.group.state.public_tree),
+ env.bob.group.context(),
+ &message,
+ #[cfg(feature = "by_ref_proposal")]
+ &[],
+ )
+ .await
+ .unwrap();
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn invalid_plaintext_is_not_verified() {
+ let mut env = TestEnv::new().await;
+ let mut message = make_signed_plaintext(&mut env.alice.group).await;
+ message.auth.signature = MessageSignature::from(b"test".to_vec());
+
+ message.membership_tag = env
+ .alice
+ .group
+ .key_schedule
+ .get_membership_tag(
+ &AuthenticatedContent::from(message.clone()),
+ env.alice.group.context(),
+ &test_cipher_suite_provider(env.alice.group.cipher_suite()),
+ )
+ .await
+ .unwrap()
+ .into();
+
+ let res = verify_plaintext_authentication(
+ &env.bob.group.cipher_suite_provider,
+ message,
+ Some(&env.bob.group.key_schedule),
+ None,
+ &env.bob.group.state,
+ )
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidSignature));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn plaintext_from_member_requires_membership_tag() {
+ let mut env = TestEnv::new().await;
+ let mut message = make_signed_plaintext(&mut env.alice.group).await;
+ message.membership_tag = None;
+
+ let res = verify_plaintext_authentication(
+ &env.bob.group.cipher_suite_provider,
+ message,
+ Some(&env.bob.group.key_schedule),
+ None,
+ &env.bob.group.state,
+ )
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidMembershipTag));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn plaintext_fails_with_invalid_membership_tag() {
+ let mut env = TestEnv::new().await;
+ let mut message = make_signed_plaintext(&mut env.alice.group).await;
+ message.membership_tag = Some(MembershipTag::from(b"test".to_vec()));
+
+ let res = verify_plaintext_authentication(
+ &env.bob.group.cipher_suite_provider,
+ message,
+ Some(&env.bob.group.key_schedule),
+ None,
+ &env.bob.group.state,
+ )
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidMembershipTag));
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_new_member_proposal<F>(
+ key_pkg_gen: KeyPackageGeneration,
+ signer: &SignatureSecretKey,
+ test_group: &TestGroup,
+ mut edit: F,
+ ) -> PublicMessage
+ where
+ F: FnMut(&mut AuthenticatedContent),
+ {
+ let mut content = AuthenticatedContent::new_signed(
+ &test_group.group.cipher_suite_provider,
+ test_group.group.context(),
+ Sender::NewMemberProposal,
+ Content::Proposal(Box::new(Proposal::Add(Box::new(AddProposal {
+ key_package: key_pkg_gen.key_package,
+ })))),
+ signer,
+ WireFormat::PublicMessage,
+ vec![],
+ )
+ .await
+ .unwrap();
+
+ edit(&mut content);
+
+ let signing_context = MessageSigningContext {
+ group_context: Some(test_group.group.context()),
+ protocol_version: test_group.group.protocol_version(),
+ };
+
+ content
+ .sign(
+ &test_group.group.cipher_suite_provider,
+ signer,
+ &signing_context,
+ )
+ .await
+ .unwrap();
+
+ PublicMessage {
+ content: content.content,
+ auth: content.auth,
+ membership_tag: None,
+ }
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn valid_proposal_from_new_member_is_verified() {
+ let test_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let (key_pkg_gen, signer) =
+ test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"bob").await;
+ let message = test_new_member_proposal(key_pkg_gen, &signer, &test_group, |_| {}).await;
+
+ verify_plaintext_authentication(
+ &test_group.group.cipher_suite_provider,
+ message,
+ Some(&test_group.group.key_schedule),
+ None,
+ &test_group.group.state,
+ )
+ .await
+ .unwrap();
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn proposal_from_new_member_must_not_have_membership_tag() {
+ let test_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let (key_pkg_gen, signer) =
+ test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"bob").await;
+
+ let mut message = test_new_member_proposal(key_pkg_gen, &signer, &test_group, |_| {}).await;
+ message.membership_tag = Some(MembershipTag::from(vec![]));
+
+ let res = verify_plaintext_authentication(
+ &test_group.group.cipher_suite_provider,
+ message,
+ Some(&test_group.group.key_schedule),
+ None,
+ &test_group.group.state,
+ )
+ .await;
+
+ assert_matches!(res, Err(MlsError::MembershipTagForNonMember));
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn new_member_proposal_sender_must_be_add_proposal() {
+ let test_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let (key_pkg_gen, signer) =
+ test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"bob").await;
+
+ let message = test_new_member_proposal(key_pkg_gen, &signer, &test_group, |msg| {
+ msg.content.content = Content::Proposal(Box::new(Proposal::Remove(RemoveProposal {
+ to_remove: LeafIndex(0),
+ })))
+ })
+ .await;
+
+ let res: Result<AuthenticatedContent, MlsError> = verify_plaintext_authentication(
+ &test_group.group.cipher_suite_provider,
+ message,
+ Some(&test_group.group.key_schedule),
+ None,
+ &test_group.group.state,
+ )
+ .await;
+
+ assert_matches!(res, Err(MlsError::ExpectedAddProposalForNewMemberProposal));
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn new_member_commit_must_be_external_commit() {
+ let test_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let (key_pkg_gen, signer) =
+ test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"bob").await;
+
+ let message = test_new_member_proposal(key_pkg_gen, &signer, &test_group, |msg| {
+ msg.content.sender = Sender::NewMemberCommit;
+ })
+ .await;
+
+ let res = verify_plaintext_authentication(
+ &test_group.group.cipher_suite_provider,
+ message,
+ Some(&test_group.group.key_schedule),
+ None,
+ &test_group.group.state,
+ )
+ .await;
+
+ assert_matches!(res, Err(MlsError::ExpectedCommitForNewMemberCommit));
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn valid_proposal_from_external_is_verified() {
+ let (bob_key_pkg_gen, _) =
+ test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"bob").await;
+
+ let (ted_signing, ted_secret) = get_test_signing_identity(TEST_CIPHER_SUITE, b"ted").await;
+
+ let mut test_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let mut extensions = ExtensionList::default();
+
+ extensions
+ .set_from(ExternalSendersExt {
+ allowed_senders: vec![ted_signing],
+ })
+ .unwrap();
+
+ test_group
+ .group
+ .commit_builder()
+ .set_group_context_ext(extensions)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ test_group.group.apply_pending_commit().await.unwrap();
+
+ let message = test_new_member_proposal(bob_key_pkg_gen, &ted_secret, &test_group, |msg| {
+ msg.content.sender = Sender::External(0)
+ })
+ .await;
+
+ verify_plaintext_authentication(
+ &test_group.group.cipher_suite_provider,
+ message,
+ Some(&test_group.group.key_schedule),
+ None,
+ &test_group.group.state,
+ )
+ .await
+ .unwrap();
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn external_proposal_must_be_from_valid_sender() {
+ let (bob_key_pkg_gen, _) =
+ test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"bob").await;
+ let (_, ted_secret) = get_test_signing_identity(TEST_CIPHER_SUITE, b"ted").await;
+ let test_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ let message = test_new_member_proposal(bob_key_pkg_gen, &ted_secret, &test_group, |msg| {
+ msg.content.sender = Sender::External(0)
+ })
+ .await;
+
+ let res = verify_plaintext_authentication(
+ &test_group.group.cipher_suite_provider,
+ message,
+ Some(&test_group.group.key_schedule),
+ None,
+ &test_group.group.state,
+ )
+ .await;
+
+ assert_matches!(res, Err(MlsError::UnknownSigningIdentityForExternalSender));
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn proposal_from_external_sender_must_not_have_membership_tag() {
+ let (bob_key_pkg_gen, _) =
+ test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"bob").await;
+
+ let (_, ted_secret) = get_test_signing_identity(TEST_CIPHER_SUITE, b"ted").await;
+
+ let test_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ let mut message =
+ test_new_member_proposal(bob_key_pkg_gen, &ted_secret, &test_group, |_| {}).await;
+
+ message.membership_tag = Some(MembershipTag::from(vec![]));
+
+ let res = verify_plaintext_authentication(
+ &test_group.group.cipher_suite_provider,
+ message,
+ Some(&test_group.group.key_schedule),
+ None,
+ &test_group.group.state,
+ )
+ .await;
+
+ assert_matches!(res, Err(MlsError::MembershipTagForNonMember));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn plaintext_from_self_fails_verification() {
+ let mut env = TestEnv::new().await;
+
+ let message = make_signed_plaintext(&mut env.alice.group).await;
+
+ let res = verify_plaintext_authentication(
+ &env.alice.group.cipher_suite_provider,
+ message,
+ Some(&env.alice.group.key_schedule),
+ Some(LeafIndex::new(env.alice.group.current_member_index())),
+ &env.alice.group.state,
+ )
+ .await;
+
+ assert_matches!(res, Err(MlsError::CantProcessMessageFromSelf))
+ }
+}
diff --git a/src/group/mls_rules.rs b/src/group/mls_rules.rs
new file mode 100644
index 0000000..98b1dac
--- /dev/null
+++ b/src/group/mls_rules.rs
@@ -0,0 +1,283 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::group::{proposal_filter::ProposalBundle, Roster};
+
+#[cfg(feature = "private_message")]
+use crate::{
+ group::{padding::PaddingMode, Sender},
+ WireFormat,
+};
+
+use alloc::boxed::Box;
+use core::convert::Infallible;
+use mls_rs_core::{
+ error::IntoAnyError, extension::ExtensionList, group::Member, identity::SigningIdentity,
+};
+
+#[derive(Copy, Clone, Debug, PartialEq, Eq)]
+pub enum CommitDirection {
+ Send,
+ Receive,
+}
+
+/// The source of the commit: either a current member or a new member joining
+/// via external commit.
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub enum CommitSource {
+ ExistingMember(Member),
+ NewMember(SigningIdentity),
+}
+
+/// Options controlling commit generation
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+#[non_exhaustive]
+pub struct CommitOptions {
+ pub path_required: bool,
+ pub ratchet_tree_extension: bool,
+ pub single_welcome_message: bool,
+ pub allow_external_commit: bool,
+}
+
+impl Default for CommitOptions {
+ fn default() -> Self {
+ CommitOptions {
+ path_required: false,
+ ratchet_tree_extension: true,
+ single_welcome_message: true,
+ allow_external_commit: false,
+ }
+ }
+}
+
+impl CommitOptions {
+ pub fn new() -> Self {
+ Self::default()
+ }
+
+ pub fn with_path_required(self, path_required: bool) -> Self {
+ Self {
+ path_required,
+ ..self
+ }
+ }
+
+ pub fn with_ratchet_tree_extension(self, ratchet_tree_extension: bool) -> Self {
+ Self {
+ ratchet_tree_extension,
+ ..self
+ }
+ }
+
+ pub fn with_single_welcome_message(self, single_welcome_message: bool) -> Self {
+ Self {
+ single_welcome_message,
+ ..self
+ }
+ }
+
+ pub fn with_allow_external_commit(self, allow_external_commit: bool) -> Self {
+ Self {
+ allow_external_commit,
+ ..self
+ }
+ }
+}
+
+/// Options controlling encryption of control and application messages
+#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
+#[non_exhaustive]
+pub struct EncryptionOptions {
+ #[cfg(feature = "private_message")]
+ pub encrypt_control_messages: bool,
+ #[cfg(feature = "private_message")]
+ pub padding_mode: PaddingMode,
+}
+
+#[cfg(feature = "private_message")]
+impl EncryptionOptions {
+ pub fn new(encrypt_control_messages: bool, padding_mode: PaddingMode) -> Self {
+ Self {
+ encrypt_control_messages,
+ padding_mode,
+ }
+ }
+
+ pub(crate) fn control_wire_format(&self, sender: Sender) -> WireFormat {
+ match sender {
+ Sender::Member(_) if self.encrypt_control_messages => WireFormat::PrivateMessage,
+ _ => WireFormat::PublicMessage,
+ }
+ }
+}
+
+/// A set of user controlled rules that customize the behavior of MLS.
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+pub trait MlsRules: Send + Sync {
+ type Error: IntoAnyError;
+
+ /// This is called when preparing or receiving a commit to pre-process the set of committed
+ /// proposals.
+ ///
+ /// Both proposals received during the current epoch and at the time of commit
+ /// will be presented for validation and filtering. Filter and validate will
+ /// present a raw list of proposals. Standard MLS rules are applied internally
+ /// on the result of these rules.
+ ///
+ /// Each member of a group MUST apply the same proposal rules in order to
+ /// maintain a working group.
+ ///
+ /// Typically, any invalid proposal should result in an error. The exception are invalid
+ /// by-reference proposals processed when _preparing_ a commit, which should be filtered
+ /// out instead. This is to avoid the deadlock situation when no commit can be generated
+ /// after receiving an invalid set of proposal messages.
+ ///
+ /// `ProposalBundle` can be arbitrarily modified. For example, a Remove proposal that
+ /// removes a moderator can result in adding a GroupContextExtensions proposal that updates
+ /// the moderator list in the group context. The resulting `ProposalBundle` is validated
+ /// by the library.
+ async fn filter_proposals(
+ &self,
+ direction: CommitDirection,
+ source: CommitSource,
+ current_roster: &Roster,
+ extension_list: &ExtensionList,
+ proposals: ProposalBundle,
+ ) -> Result<ProposalBundle, Self::Error>;
+
+ /// This is called when preparing a commit to determine various options: whether to enforce an update
+ /// path in case it is not mandated by MLS, whether to include the ratchet tree in the welcome
+ /// message (if the commit adds members) and whether to generate a single welcome message, or one
+ /// welcome message for each added member.
+ ///
+ /// The `new_roster` and `new_extension_list` describe the group state after the commit.
+ fn commit_options(
+ &self,
+ new_roster: &Roster,
+ new_extension_list: &ExtensionList,
+ proposals: &ProposalBundle,
+ ) -> Result<CommitOptions, Self::Error>;
+
+ /// This is called when sending any packet. For proposals and commits, this determines whether to
+ /// encrypt them. For any encrypted packet, this determines the padding mode used.
+ ///
+ /// Note that for commits, the `current_roster` and `current_extension_list` describe the group state
+ /// before the commit, unlike in [commit_options](MlsRules::commit_options).
+ fn encryption_options(
+ &self,
+ current_roster: &Roster,
+ current_extension_list: &ExtensionList,
+ ) -> Result<EncryptionOptions, Self::Error>;
+}
+
+macro_rules! delegate_mls_rules {
+ ($implementer:ty) => {
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+ impl<T: MlsRules + ?Sized> MlsRules for $implementer {
+ type Error = T::Error;
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn filter_proposals(
+ &self,
+ direction: CommitDirection,
+ source: CommitSource,
+ current_roster: &Roster,
+ extension_list: &ExtensionList,
+ proposals: ProposalBundle,
+ ) -> Result<ProposalBundle, Self::Error> {
+ (**self)
+ .filter_proposals(direction, source, current_roster, extension_list, proposals)
+ .await
+ }
+
+ fn commit_options(
+ &self,
+ roster: &Roster,
+ extension_list: &ExtensionList,
+ proposals: &ProposalBundle,
+ ) -> Result<CommitOptions, Self::Error> {
+ (**self).commit_options(roster, extension_list, proposals)
+ }
+
+ fn encryption_options(
+ &self,
+ roster: &Roster,
+ extension_list: &ExtensionList,
+ ) -> Result<EncryptionOptions, Self::Error> {
+ (**self).encryption_options(roster, extension_list)
+ }
+ }
+ };
+}
+
+delegate_mls_rules!(Box<T>);
+delegate_mls_rules!(&T);
+
+#[derive(Clone, Debug, Default)]
+#[non_exhaustive]
+/// Default MLS rules with pass-through proposal filter and customizable options.
+pub struct DefaultMlsRules {
+ pub commit_options: CommitOptions,
+ pub encryption_options: EncryptionOptions,
+}
+
+impl DefaultMlsRules {
+ /// Create new MLS rules with default settings: do not enforce path and do
+ /// put the ratchet tree in the extension.
+ pub fn new() -> Self {
+ Default::default()
+ }
+
+ /// Set commit options.
+ pub fn with_commit_options(self, commit_options: CommitOptions) -> Self {
+ Self {
+ commit_options,
+ encryption_options: self.encryption_options,
+ }
+ }
+
+ /// Set encryption options.
+ pub fn with_encryption_options(self, encryption_options: EncryptionOptions) -> Self {
+ Self {
+ commit_options: self.commit_options,
+ encryption_options,
+ }
+ }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+impl MlsRules for DefaultMlsRules {
+ type Error = Infallible;
+
+ async fn filter_proposals(
+ &self,
+ _direction: CommitDirection,
+ _source: CommitSource,
+ _current_roster: &Roster,
+ _extension_list: &ExtensionList,
+ proposals: ProposalBundle,
+ ) -> Result<ProposalBundle, Self::Error> {
+ Ok(proposals)
+ }
+
+ fn commit_options(
+ &self,
+ _: &Roster,
+ _: &ExtensionList,
+ _: &ProposalBundle,
+ ) -> Result<CommitOptions, Self::Error> {
+ Ok(self.commit_options)
+ }
+
+ fn encryption_options(
+ &self,
+ _: &Roster,
+ _: &ExtensionList,
+ ) -> Result<EncryptionOptions, Self::Error> {
+ Ok(self.encryption_options)
+ }
+}
diff --git a/src/group/mod.rs b/src/group/mod.rs
new file mode 100644
index 0000000..0d84a84
--- /dev/null
+++ b/src/group/mod.rs
@@ -0,0 +1,4236 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec;
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::error::IntoAnyError;
+use mls_rs_core::secret::Secret;
+use mls_rs_core::time::MlsTime;
+
+use crate::cipher_suite::CipherSuite;
+use crate::client::MlsError;
+use crate::client_config::ClientConfig;
+use crate::crypto::{HpkeCiphertext, SignatureSecretKey};
+use crate::extension::RatchetTreeExt;
+use crate::identity::SigningIdentity;
+use crate::key_package::{KeyPackage, KeyPackageRef};
+use crate::protocol_version::ProtocolVersion;
+use crate::psk::secret::PskSecret;
+use crate::psk::PreSharedKeyID;
+use crate::signer::Signable;
+use crate::tree_kem::hpke_encryption::HpkeEncryptable;
+use crate::tree_kem::kem::TreeKem;
+use crate::tree_kem::node::LeafIndex;
+use crate::tree_kem::path_secret::PathSecret;
+pub use crate::tree_kem::Capabilities;
+use crate::tree_kem::{
+ leaf_node::LeafNode,
+ leaf_node_validator::{LeafNodeValidator, ValidationContext},
+};
+use crate::tree_kem::{math as tree_math, ValidatedUpdatePath};
+use crate::tree_kem::{TreeKemPrivate, TreeKemPublic};
+use crate::{CipherSuiteProvider, CryptoProvider};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::crypto::{HpkePublicKey, HpkeSecretKey};
+
+use crate::extension::ExternalPubExt;
+
+#[cfg(feature = "private_message")]
+use self::mls_rules::{EncryptionOptions, MlsRules};
+
+#[cfg(feature = "psk")]
+pub use self::resumption::ReinitClient;
+
+#[cfg(feature = "psk")]
+use crate::psk::{
+ resolver::PskResolver, secret::PskSecretInput, ExternalPskId, JustPreSharedKeyID, PskGroupId,
+ ResumptionPSKUsage, ResumptionPsk,
+};
+
+#[cfg(all(feature = "std", feature = "by_ref_proposal"))]
+use std::collections::HashMap;
+
+#[cfg(feature = "private_message")]
+use ciphertext_processor::*;
+
+use confirmation_tag::*;
+use framing::*;
+use key_schedule::*;
+use membership_tag::*;
+use message_signature::*;
+use message_verifier::*;
+use proposal::*;
+#[cfg(feature = "by_ref_proposal")]
+use proposal_cache::*;
+use state::*;
+use transcript_hash::*;
+
+#[cfg(test)]
+pub(crate) use self::commit::test_utils::CommitModifiers;
+
+#[cfg(all(test, feature = "private_message"))]
+pub use self::framing::PrivateMessage;
+
+#[cfg(feature = "psk")]
+use self::proposal_filter::ProposalInfo;
+
+#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+use secret_tree::*;
+
+#[cfg(feature = "prior_epoch")]
+use self::epoch::PriorEpoch;
+
+use self::epoch::EpochSecrets;
+pub use self::message_processor::{
+ ApplicationMessageDescription, CommitMessageDescription, ProposalMessageDescription,
+ ProposalSender, ReceivedMessage, StateUpdate,
+};
+use self::message_processor::{EventOrContent, MessageProcessor, ProvisionalState};
+#[cfg(feature = "by_ref_proposal")]
+use self::proposal_ref::ProposalRef;
+use self::state_repo::GroupStateRepository;
+pub use group_info::GroupInfo;
+
+pub use self::framing::{ContentType, Sender};
+pub use commit::*;
+pub use context::GroupContext;
+pub use roster::*;
+
+pub(crate) use transcript_hash::ConfirmedTranscriptHash;
+pub(crate) use util::*;
+
+#[cfg(all(feature = "by_ref_proposal", feature = "external_client"))]
+pub use self::message_processor::CachedProposal;
+
+#[cfg(feature = "private_message")]
+mod ciphertext_processor;
+
+mod commit;
+pub(crate) mod confirmation_tag;
+mod context;
+pub(crate) mod epoch;
+pub(crate) mod framing;
+mod group_info;
+pub(crate) mod key_schedule;
+mod membership_tag;
+pub(crate) mod message_processor;
+pub(crate) mod message_signature;
+pub(crate) mod message_verifier;
+pub mod mls_rules;
+#[cfg(feature = "private_message")]
+pub(crate) mod padding;
+/// Proposals to evolve a MLS [`Group`]
+pub mod proposal;
+mod proposal_cache;
+pub(crate) mod proposal_filter;
+#[cfg(feature = "by_ref_proposal")]
+pub(crate) mod proposal_ref;
+#[cfg(feature = "psk")]
+mod resumption;
+mod roster;
+pub(crate) mod snapshot;
+pub(crate) mod state;
+
+#[cfg(feature = "prior_epoch")]
+pub(crate) mod state_repo;
+#[cfg(not(feature = "prior_epoch"))]
+pub(crate) mod state_repo_light;
+#[cfg(not(feature = "prior_epoch"))]
+pub(crate) use state_repo_light as state_repo;
+
+pub(crate) mod transcript_hash;
+mod util;
+
+/// External commit building.
+pub mod external_commit;
+
+#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+pub(crate) mod secret_tree;
+
+#[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+pub use secret_tree::MessageKeyData as MessageKey;
+
+#[cfg(all(test, feature = "rfc_compliant"))]
+mod interop_test_vectors;
+
+mod exported_tree;
+
+pub use exported_tree::ExportedTree;
+
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+struct GroupSecrets {
+ joiner_secret: JoinerSecret,
+ path_secret: Option<PathSecret>,
+ psks: Vec<PreSharedKeyID>,
+}
+
+impl HpkeEncryptable for GroupSecrets {
+ const ENCRYPT_LABEL: &'static str = "Welcome";
+
+ fn from_bytes(bytes: Vec<u8>) -> Result<Self, MlsError> {
+ Self::mls_decode(&mut bytes.as_slice()).map_err(Into::into)
+ }
+
+ fn get_bytes(&self) -> Result<Vec<u8>, MlsError> {
+ self.mls_encode_to_vec().map_err(Into::into)
+ }
+}
+
+#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+pub(crate) struct EncryptedGroupSecrets {
+ pub new_member: KeyPackageRef,
+ pub encrypted_group_secrets: HpkeCiphertext,
+}
+
+#[derive(Clone, Eq, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+pub(crate) struct Welcome {
+ pub cipher_suite: CipherSuite,
+ pub secrets: Vec<EncryptedGroupSecrets>,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ pub encrypted_group_info: Vec<u8>,
+}
+
+impl Debug for Welcome {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("Welcome")
+ .field("cipher_suite", &self.cipher_suite)
+ .field("secrets", &self.secrets)
+ .field(
+ "encrypted_group_info",
+ &mls_rs_core::debug::pretty_bytes(&self.encrypted_group_info),
+ )
+ .finish()
+ }
+}
+
+#[derive(Clone, Debug)]
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[non_exhaustive]
+/// Information provided to new members upon joining a group.
+pub struct NewMemberInfo {
+ /// Group info extensions found within the Welcome message used to join
+ /// the group.
+ pub group_info_extensions: ExtensionList,
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
+impl NewMemberInfo {
+ pub(crate) fn new(group_info_extensions: ExtensionList) -> Self {
+ let mut new_member_info = Self {
+ group_info_extensions,
+ };
+
+ new_member_info.ungrease();
+
+ new_member_info
+ }
+
+ /// Group info extensions found within the Welcome message used to join
+ /// the group.
+ #[cfg(feature = "ffi")]
+ pub fn group_info_extensions(&self) -> &ExtensionList {
+ &self.group_info_extensions
+ }
+}
+
+/// An MLS end-to-end encrypted group.
+///
+/// # Group Evolution
+///
+/// MLS Groups are evolved via a propose-then-commit system. Each group state
+/// produced by a commit is called an epoch and can produce and consume
+/// application, proposal, and commit messages. A [commit](Group::commit) is used
+/// to advance to the next epoch by applying existing proposals sent in
+/// the current epoch by-reference along with an optional set of proposals
+/// that are included by-value using a [`CommitBuilder`].
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type(opaque))]
+#[derive(Clone)]
+pub struct Group<C>
+where
+ C: ClientConfig,
+{
+ config: C,
+ cipher_suite_provider: <C::CryptoProvider as CryptoProvider>::CipherSuiteProvider,
+ state_repo: GroupStateRepository<C::GroupStateStorage, C::KeyPackageRepository>,
+ pub(crate) state: GroupState,
+ epoch_secrets: EpochSecrets,
+ private_tree: TreeKemPrivate,
+ key_schedule: KeySchedule,
+ #[cfg(all(feature = "std", feature = "by_ref_proposal"))]
+ pending_updates: HashMap<HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>)>, // Hash of leaf node hpke public key to secret key
+ #[cfg(all(not(feature = "std"), feature = "by_ref_proposal"))]
+ pending_updates: Vec<(HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>))>,
+ pending_commit: Option<CommitGeneration>,
+ #[cfg(feature = "psk")]
+ previous_psk: Option<PskSecretInput>,
+ #[cfg(test)]
+ pub(crate) commit_modifiers: CommitModifiers,
+ pub(crate) signer: SignatureSecretKey,
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
+impl<C> Group<C>
+where
+ C: ClientConfig + Clone,
+{
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn new(
+ config: C,
+ group_id: Option<Vec<u8>>,
+ cipher_suite: CipherSuite,
+ protocol_version: ProtocolVersion,
+ signing_identity: SigningIdentity,
+ group_context_extensions: ExtensionList,
+ signer: SignatureSecretKey,
+ ) -> Result<Self, MlsError> {
+ let cipher_suite_provider = cipher_suite_provider(config.crypto_provider(), cipher_suite)?;
+
+ let (leaf_node, leaf_node_secret) = LeafNode::generate(
+ &cipher_suite_provider,
+ config.leaf_properties(),
+ signing_identity,
+ &signer,
+ config.lifetime(),
+ )
+ .await?;
+
+ let identity_provider = config.identity_provider();
+
+ let leaf_node_validator = LeafNodeValidator::new(
+ &cipher_suite_provider,
+ &identity_provider,
+ Some(&group_context_extensions),
+ );
+
+ leaf_node_validator
+ .check_if_valid(&leaf_node, ValidationContext::Add(None))
+ .await?;
+
+ let (mut public_tree, private_tree) = TreeKemPublic::derive(
+ leaf_node,
+ leaf_node_secret,
+ &config.identity_provider(),
+ &group_context_extensions,
+ )
+ .await?;
+
+ let tree_hash = public_tree.tree_hash(&cipher_suite_provider).await?;
+
+ let group_id = group_id.map(Ok).unwrap_or_else(|| {
+ cipher_suite_provider
+ .random_bytes_vec(cipher_suite_provider.kdf_extract_size())
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+ })?;
+
+ let context = GroupContext::new_group(
+ protocol_version,
+ cipher_suite,
+ group_id,
+ tree_hash,
+ group_context_extensions,
+ );
+
+ let state_repo = GroupStateRepository::new(
+ #[cfg(feature = "prior_epoch")]
+ context.group_id.clone(),
+ config.group_state_storage(),
+ config.key_package_repo(),
+ None,
+ )?;
+
+ let key_schedule_result = KeySchedule::from_random_epoch_secret(
+ &cipher_suite_provider,
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ public_tree.total_leaf_count(),
+ )
+ .await?;
+
+ let confirmation_tag = ConfirmationTag::create(
+ &key_schedule_result.confirmation_key,
+ &vec![].into(),
+ &cipher_suite_provider,
+ )
+ .await?;
+
+ let interim_hash = InterimTranscriptHash::create(
+ &cipher_suite_provider,
+ &vec![].into(),
+ &confirmation_tag,
+ )
+ .await?;
+
+ Ok(Self {
+ config,
+ state: GroupState::new(context, public_tree, interim_hash, confirmation_tag),
+ private_tree,
+ key_schedule: key_schedule_result.key_schedule,
+ #[cfg(feature = "by_ref_proposal")]
+ pending_updates: Default::default(),
+ pending_commit: None,
+ #[cfg(test)]
+ commit_modifiers: Default::default(),
+ epoch_secrets: key_schedule_result.epoch_secrets,
+ state_repo,
+ cipher_suite_provider,
+ #[cfg(feature = "psk")]
+ previous_psk: None,
+ signer,
+ })
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn join(
+ welcome: &MlsMessage,
+ tree_data: Option<ExportedTree<'_>>,
+ config: C,
+ signer: SignatureSecretKey,
+ ) -> Result<(Self, NewMemberInfo), MlsError> {
+ Self::from_welcome_message(
+ welcome,
+ tree_data,
+ config,
+ signer,
+ #[cfg(feature = "psk")]
+ None,
+ )
+ .await
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn from_welcome_message(
+ welcome: &MlsMessage,
+ tree_data: Option<ExportedTree<'_>>,
+ config: C,
+ signer: SignatureSecretKey,
+ #[cfg(feature = "psk")] additional_psk: Option<PskSecretInput>,
+ ) -> Result<(Self, NewMemberInfo), MlsError> {
+ let protocol_version = welcome.version;
+
+ if !config.version_supported(protocol_version) {
+ return Err(MlsError::UnsupportedProtocolVersion(protocol_version));
+ }
+
+ let MlsMessagePayload::Welcome(welcome) = &welcome.payload else {
+ return Err(MlsError::UnexpectedMessageType);
+ };
+
+ let cipher_suite_provider =
+ cipher_suite_provider(config.crypto_provider(), welcome.cipher_suite)?;
+
+ let (encrypted_group_secrets, key_package_generation) =
+ find_key_package_generation(&config.key_package_repo(), &welcome.secrets).await?;
+
+ let key_package_version = key_package_generation.key_package.version;
+
+ if key_package_version != protocol_version {
+ return Err(MlsError::ProtocolVersionMismatch);
+ }
+
+ // Decrypt the encrypted_group_secrets using HPKE with the algorithms indicated by the
+ // cipher suite and the HPKE private key corresponding to the GroupSecrets. If a
+ // PreSharedKeyID is part of the GroupSecrets and the client is not in possession of
+ // the corresponding PSK, return an error
+ let group_secrets = GroupSecrets::decrypt(
+ &cipher_suite_provider,
+ &key_package_generation.init_secret_key,
+ &key_package_generation.key_package.hpke_init_key,
+ &welcome.encrypted_group_info,
+ &encrypted_group_secrets.encrypted_group_secrets,
+ )
+ .await?;
+
+ #[cfg(feature = "psk")]
+ let psk_secret = if let Some(psk) = additional_psk {
+ let psk_id = group_secrets
+ .psks
+ .first()
+ .ok_or(MlsError::UnexpectedPskId)?;
+
+ match &psk_id.key_id {
+ JustPreSharedKeyID::Resumption(r) if r.usage != ResumptionPSKUsage::Application => {
+ Ok(())
+ }
+ _ => Err(MlsError::UnexpectedPskId),
+ }?;
+
+ let mut psk = psk;
+ psk.id.psk_nonce = psk_id.psk_nonce.clone();
+ PskSecret::calculate(&[psk], &cipher_suite_provider).await?
+ } else {
+ PskResolver::<
+ <C as ClientConfig>::GroupStateStorage,
+ <C as ClientConfig>::KeyPackageRepository,
+ <C as ClientConfig>::PskStore,
+ > {
+ group_context: None,
+ current_epoch: None,
+ prior_epochs: None,
+ psk_store: &config.secret_store(),
+ }
+ .resolve_to_secret(&group_secrets.psks, &cipher_suite_provider)
+ .await?
+ };
+
+ #[cfg(not(feature = "psk"))]
+ let psk_secret = PskSecret::new(&cipher_suite_provider);
+
+ // From the joiner_secret in the decrypted GroupSecrets object and the PSKs specified in
+ // the GroupSecrets, derive the welcome_secret and using that the welcome_key and
+ // welcome_nonce.
+ let welcome_secret = WelcomeSecret::from_joiner_secret(
+ &cipher_suite_provider,
+ &group_secrets.joiner_secret,
+ &psk_secret,
+ )
+ .await?;
+
+ // Use the key and nonce to decrypt the encrypted_group_info field.
+ let decrypted_group_info = welcome_secret
+ .decrypt(&welcome.encrypted_group_info)
+ .await?;
+
+ let group_info = GroupInfo::mls_decode(&mut &**decrypted_group_info)?;
+
+ let public_tree = validate_group_info_joiner(
+ protocol_version,
+ &group_info,
+ tree_data,
+ &config.identity_provider(),
+ &cipher_suite_provider,
+ )
+ .await?;
+
+ // Identify a leaf in the tree array (any even-numbered node) whose leaf_node is identical
+ // to the leaf_node field of the KeyPackage. If no such field exists, return an error. Let
+ // index represent the index of this node among the leaves in the tree, namely the index of
+ // the node in the tree array divided by two.
+ let self_index = public_tree
+ .find_leaf_node(&key_package_generation.key_package.leaf_node)
+ .ok_or(MlsError::WelcomeKeyPackageNotFound)?;
+
+ let used_key_package_ref = key_package_generation.reference;
+
+ let mut private_tree =
+ TreeKemPrivate::new_self_leaf(self_index, key_package_generation.leaf_node_secret_key);
+
+ // If the path_secret value is set in the GroupSecrets object
+ if let Some(path_secret) = group_secrets.path_secret {
+ private_tree
+ .update_secrets(
+ &cipher_suite_provider,
+ group_info.signer,
+ path_secret,
+ &public_tree,
+ )
+ .await?;
+ }
+
+ // Use the joiner_secret from the GroupSecrets object to generate the epoch secret and
+ // other derived secrets for the current epoch.
+ let key_schedule_result = KeySchedule::from_joiner(
+ &cipher_suite_provider,
+ &group_secrets.joiner_secret,
+ &group_info.group_context,
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ public_tree.total_leaf_count(),
+ &psk_secret,
+ )
+ .await?;
+
+ // Verify the confirmation tag in the GroupInfo using the derived confirmation key and the
+ // confirmed_transcript_hash from the GroupInfo.
+ if !group_info
+ .confirmation_tag
+ .matches(
+ &key_schedule_result.confirmation_key,
+ &group_info.group_context.confirmed_transcript_hash,
+ &cipher_suite_provider,
+ )
+ .await?
+ {
+ return Err(MlsError::InvalidConfirmationTag);
+ }
+
+ Self::join_with(
+ config,
+ group_info,
+ public_tree,
+ key_schedule_result.key_schedule,
+ key_schedule_result.epoch_secrets,
+ private_tree,
+ Some(used_key_package_ref),
+ signer,
+ )
+ .await
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn join_with(
+ config: C,
+ group_info: GroupInfo,
+ public_tree: TreeKemPublic,
+ key_schedule: KeySchedule,
+ epoch_secrets: EpochSecrets,
+ private_tree: TreeKemPrivate,
+ used_key_package_ref: Option<KeyPackageRef>,
+ signer: SignatureSecretKey,
+ ) -> Result<(Self, NewMemberInfo), MlsError> {
+ let cs = group_info.group_context.cipher_suite;
+
+ let cs = config
+ .crypto_provider()
+ .cipher_suite_provider(cs)
+ .ok_or(MlsError::UnsupportedCipherSuite(cs))?;
+
+ // Use the confirmed transcript hash and confirmation tag to compute the interim transcript
+ // hash in the new state.
+ let interim_transcript_hash = InterimTranscriptHash::create(
+ &cs,
+ &group_info.group_context.confirmed_transcript_hash,
+ &group_info.confirmation_tag,
+ )
+ .await?;
+
+ let state_repo = GroupStateRepository::new(
+ #[cfg(feature = "prior_epoch")]
+ group_info.group_context.group_id.clone(),
+ config.group_state_storage(),
+ config.key_package_repo(),
+ used_key_package_ref,
+ )?;
+
+ let group = Group {
+ config,
+ state: GroupState::new(
+ group_info.group_context,
+ public_tree,
+ interim_transcript_hash,
+ group_info.confirmation_tag,
+ ),
+ private_tree,
+ key_schedule,
+ #[cfg(feature = "by_ref_proposal")]
+ pending_updates: Default::default(),
+ pending_commit: None,
+ #[cfg(test)]
+ commit_modifiers: Default::default(),
+ epoch_secrets,
+ state_repo,
+ cipher_suite_provider: cs,
+ #[cfg(feature = "psk")]
+ previous_psk: None,
+ signer,
+ };
+
+ Ok((group, NewMemberInfo::new(group_info.extensions)))
+ }
+
+ #[inline(always)]
+ pub(crate) fn current_epoch_tree(&self) -> &TreeKemPublic {
+ &self.state.public_tree
+ }
+
+ /// The current epoch of the group. This value is incremented each
+ /// time a [`Group::commit`] message is processed.
+ #[inline(always)]
+ pub fn current_epoch(&self) -> u64 {
+ self.context().epoch
+ }
+
+ /// Index within the group's state for the local group instance.
+ ///
+ /// This index corresponds to indexes in content descriptions within
+ /// [`ReceivedMessage`].
+ #[inline(always)]
+ pub fn current_member_index(&self) -> u32 {
+ self.private_tree.self_index.0
+ }
+
+ fn current_user_leaf_node(&self) -> Result<&LeafNode, MlsError> {
+ self.current_epoch_tree()
+ .get_leaf_node(self.private_tree.self_index)
+ }
+
+ /// Signing identity currently in use by the local group instance.
+ #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
+ pub fn current_member_signing_identity(&self) -> Result<&SigningIdentity, MlsError> {
+ self.current_user_leaf_node().map(|ln| &ln.signing_identity)
+ }
+
+ /// Member at a specific index in the group state.
+ ///
+ /// These indexes correspond to indexes in content descriptions within
+ /// [`ReceivedMessage`].
+ pub fn member_at_index(&self, index: u32) -> Option<Member> {
+ let leaf_index = LeafIndex(index);
+
+ self.current_epoch_tree()
+ .get_leaf_node(leaf_index)
+ .ok()
+ .map(|ln| member_from_leaf_node(ln, leaf_index))
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn proposal_message(
+ &mut self,
+ proposal: Proposal,
+ authenticated_data: Vec<u8>,
+ ) -> Result<MlsMessage, MlsError> {
+ let sender = Sender::Member(*self.private_tree.self_index);
+
+ let auth_content = AuthenticatedContent::new_signed(
+ &self.cipher_suite_provider,
+ self.context(),
+ sender,
+ Content::Proposal(alloc::boxed::Box::new(proposal.clone())),
+ &self.signer,
+ #[cfg(feature = "private_message")]
+ self.encryption_options()?.control_wire_format(sender),
+ #[cfg(not(feature = "private_message"))]
+ WireFormat::PublicMessage,
+ authenticated_data,
+ )
+ .await?;
+
+ let proposal_ref =
+ ProposalRef::from_content(&self.cipher_suite_provider, &auth_content).await?;
+
+ self.state
+ .proposals
+ .insert(proposal_ref, proposal, auth_content.content.sender);
+
+ self.format_for_wire(auth_content).await
+ }
+
+ /// Unique identifier for this group.
+ pub fn group_id(&self) -> &[u8] {
+ &self.context().group_id
+ }
+
+ fn provisional_private_tree(
+ &self,
+ provisional_state: &ProvisionalState,
+ ) -> Result<(TreeKemPrivate, Option<SignatureSecretKey>), MlsError> {
+ let mut provisional_private_tree = self.private_tree.clone();
+ let self_index = provisional_private_tree.self_index;
+
+ // Remove secret keys for blanked nodes
+ let path = provisional_state
+ .public_tree
+ .nodes
+ .direct_copath(self_index);
+
+ provisional_private_tree
+ .secret_keys
+ .resize(path.len() + 1, None);
+
+ for (i, n) in path.iter().enumerate() {
+ if provisional_state.public_tree.nodes.is_blank(n.path)? {
+ provisional_private_tree.secret_keys[i + 1] = None;
+ }
+ }
+
+ // Apply own update
+ let new_signer = None;
+
+ #[cfg(feature = "by_ref_proposal")]
+ let mut new_signer = new_signer;
+
+ #[cfg(feature = "by_ref_proposal")]
+ for p in &provisional_state.applied_proposals.updates {
+ if p.sender == Sender::Member(*self_index) {
+ let leaf_pk = &p.proposal.leaf_node.public_key;
+
+ // Update the leaf in the private tree if this is our update
+ #[cfg(feature = "std")]
+ let new_leaf_sk_and_signer = self.pending_updates.get(leaf_pk);
+
+ #[cfg(not(feature = "std"))]
+ let new_leaf_sk_and_signer = self
+ .pending_updates
+ .iter()
+ .find_map(|(pk, sk)| (pk == leaf_pk).then_some(sk));
+
+ let new_leaf_sk = new_leaf_sk_and_signer.map(|(sk, _)| sk.clone());
+ new_signer = new_leaf_sk_and_signer.and_then(|(_, sk)| sk.clone());
+
+ provisional_private_tree
+ .update_leaf(new_leaf_sk.ok_or(MlsError::UpdateErrorNoSecretKey)?);
+
+ break;
+ }
+ }
+
+ Ok((provisional_private_tree, new_signer))
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn encrypt_group_secrets(
+ &self,
+ key_package: &KeyPackage,
+ leaf_index: LeafIndex,
+ joiner_secret: &JoinerSecret,
+ path_secrets: Option<&Vec<Option<PathSecret>>>,
+ #[cfg(feature = "psk")] psks: Vec<PreSharedKeyID>,
+ encrypted_group_info: &[u8],
+ ) -> Result<EncryptedGroupSecrets, MlsError> {
+ let path_secret = path_secrets
+ .map(|secrets| {
+ secrets
+ .get(
+ tree_math::leaf_lca_level(*self.private_tree.self_index, *leaf_index)
+ as usize
+ - 1,
+ )
+ .cloned()
+ .flatten()
+ .ok_or(MlsError::InvalidTreeKemPrivateKey)
+ })
+ .transpose()?;
+
+ #[cfg(not(feature = "psk"))]
+ let psks = Vec::new();
+
+ let group_secrets = GroupSecrets {
+ joiner_secret: joiner_secret.clone(),
+ path_secret,
+ psks,
+ };
+
+ let encrypted_group_secrets = group_secrets
+ .encrypt(
+ &self.cipher_suite_provider,
+ &key_package.hpke_init_key,
+ encrypted_group_info,
+ )
+ .await?;
+
+ Ok(EncryptedGroupSecrets {
+ new_member: key_package
+ .to_reference(&self.cipher_suite_provider)
+ .await?,
+ encrypted_group_secrets,
+ })
+ }
+
+ /// Create a proposal message that adds a new member to the group.
+ ///
+ /// `authenticated_data` will be sent unencrypted along with the contents
+ /// of the proposal message.
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn propose_add(
+ &mut self,
+ key_package: MlsMessage,
+ authenticated_data: Vec<u8>,
+ ) -> Result<MlsMessage, MlsError> {
+ let proposal = self.add_proposal(key_package)?;
+ self.proposal_message(proposal, authenticated_data).await
+ }
+
+ fn add_proposal(&self, key_package: MlsMessage) -> Result<Proposal, MlsError> {
+ Ok(Proposal::Add(alloc::boxed::Box::new(AddProposal {
+ key_package: key_package
+ .into_key_package()
+ .ok_or(MlsError::UnexpectedMessageType)?,
+ })))
+ }
+
+ /// Create a proposal message that updates your own public keys.
+ ///
+ /// This proposal is useful for contributing additional forward secrecy
+ /// and post-compromise security to the group without having to perform
+ /// the necessary computation of a [`Group::commit`].
+ ///
+ /// `authenticated_data` will be sent unencrypted along with the contents
+ /// of the proposal message.
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn propose_update(
+ &mut self,
+ authenticated_data: Vec<u8>,
+ ) -> Result<MlsMessage, MlsError> {
+ let proposal = self.update_proposal(None, None).await?;
+ self.proposal_message(proposal, authenticated_data).await
+ }
+
+ /// Create a proposal message that updates your own public keys
+ /// as well as your credential.
+ ///
+ /// This proposal is useful for contributing additional forward secrecy
+ /// and post-compromise security to the group without having to perform
+ /// the necessary computation of a [`Group::commit`].
+ ///
+ /// Identity updates are allowed by the group by default assuming that the
+ /// new identity provided is considered
+ /// [valid](crate::IdentityProvider::validate_member)
+ /// by and matches the output of the
+ /// [identity](crate::IdentityProvider)
+ /// function of the current
+ /// [`IdentityProvider`](crate::IdentityProvider).
+ ///
+ /// `authenticated_data` will be sent unencrypted along with the contents
+ /// of the proposal message.
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn propose_update_with_identity(
+ &mut self,
+ signer: SignatureSecretKey,
+ signing_identity: SigningIdentity,
+ authenticated_data: Vec<u8>,
+ ) -> Result<MlsMessage, MlsError> {
+ let proposal = self
+ .update_proposal(Some(signer), Some(signing_identity))
+ .await?;
+
+ self.proposal_message(proposal, authenticated_data).await
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn update_proposal(
+ &mut self,
+ signer: Option<SignatureSecretKey>,
+ signing_identity: Option<SigningIdentity>,
+ ) -> Result<Proposal, MlsError> {
+ // Grab a copy of the current node and update it to have new key material
+ let mut new_leaf_node = self.current_user_leaf_node()?.clone();
+
+ let secret_key = new_leaf_node
+ .update(
+ &self.cipher_suite_provider,
+ self.group_id(),
+ self.current_member_index(),
+ self.config.leaf_properties(),
+ signing_identity,
+ signer.as_ref().unwrap_or(&self.signer),
+ )
+ .await?;
+
+ // Store the secret key in the pending updates storage for later
+ #[cfg(feature = "std")]
+ self.pending_updates
+ .insert(new_leaf_node.public_key.clone(), (secret_key, signer));
+
+ #[cfg(not(feature = "std"))]
+ self.pending_updates
+ .push((new_leaf_node.public_key.clone(), (secret_key, signer)));
+
+ Ok(Proposal::Update(UpdateProposal {
+ leaf_node: new_leaf_node,
+ }))
+ }
+
+ /// Create a proposal message that removes an existing member from the
+ /// group.
+ ///
+ /// `authenticated_data` will be sent unencrypted along with the contents
+ /// of the proposal message.
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn propose_remove(
+ &mut self,
+ index: u32,
+ authenticated_data: Vec<u8>,
+ ) -> Result<MlsMessage, MlsError> {
+ let proposal = self.remove_proposal(index)?;
+ self.proposal_message(proposal, authenticated_data).await
+ }
+
+ fn remove_proposal(&self, index: u32) -> Result<Proposal, MlsError> {
+ let leaf_index = LeafIndex(index);
+
+ // Verify that this leaf is actually in the tree
+ self.current_epoch_tree().get_leaf_node(leaf_index)?;
+
+ Ok(Proposal::Remove(RemoveProposal {
+ to_remove: leaf_index,
+ }))
+ }
+
+ /// Create a proposal message that adds an external pre shared key to the group.
+ ///
+ /// Each group member will need to have the PSK associated with
+ /// [`ExternalPskId`](mls_rs_core::psk::ExternalPskId) installed within
+ /// the [`PreSharedKeyStorage`](mls_rs_core::psk::PreSharedKeyStorage)
+ /// in use by this group upon processing a [commit](Group::commit) that
+ /// contains this proposal.
+ ///
+ /// `authenticated_data` will be sent unencrypted along with the contents
+ /// of the proposal message.
+ #[cfg(all(feature = "by_ref_proposal", feature = "psk"))]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn propose_external_psk(
+ &mut self,
+ psk: ExternalPskId,
+ authenticated_data: Vec<u8>,
+ ) -> Result<MlsMessage, MlsError> {
+ let proposal = self.psk_proposal(JustPreSharedKeyID::External(psk))?;
+ self.proposal_message(proposal, authenticated_data).await
+ }
+
+ #[cfg(feature = "psk")]
+ fn psk_proposal(&self, key_id: JustPreSharedKeyID) -> Result<Proposal, MlsError> {
+ Ok(Proposal::Psk(PreSharedKeyProposal {
+ psk: PreSharedKeyID::new(key_id, &self.cipher_suite_provider)?,
+ }))
+ }
+
+ /// Create a proposal message that adds a pre shared key from a previous
+ /// epoch to the current group state.
+ ///
+ /// Each group member will need to have the secret state from `psk_epoch`.
+ /// In particular, the members who joined between `psk_epoch` and the
+ /// current epoch cannot process a commit containing this proposal.
+ ///
+ /// `authenticated_data` will be sent unencrypted along with the contents
+ /// of the proposal message.
+ #[cfg(all(feature = "by_ref_proposal", feature = "psk"))]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn propose_resumption_psk(
+ &mut self,
+ psk_epoch: u64,
+ authenticated_data: Vec<u8>,
+ ) -> Result<MlsMessage, MlsError> {
+ let key_id = ResumptionPsk {
+ psk_epoch,
+ usage: ResumptionPSKUsage::Application,
+ psk_group_id: PskGroupId(self.group_id().to_vec()),
+ };
+
+ let proposal = self.psk_proposal(JustPreSharedKeyID::Resumption(key_id))?;
+ self.proposal_message(proposal, authenticated_data).await
+ }
+
+ /// Create a proposal message that requests for this group to be
+ /// reinitialized.
+ ///
+ /// Once a [`ReInitProposal`](proposal::ReInitProposal)
+ /// has been sent, another group member can complete reinitialization of
+ /// the group by calling [`Group::get_reinit_client`].
+ ///
+ /// `authenticated_data` will be sent unencrypted along with the contents
+ /// of the proposal message.
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn propose_reinit(
+ &mut self,
+ group_id: Option<Vec<u8>>,
+ version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+ extensions: ExtensionList,
+ authenticated_data: Vec<u8>,
+ ) -> Result<MlsMessage, MlsError> {
+ let proposal = self.reinit_proposal(group_id, version, cipher_suite, extensions)?;
+ self.proposal_message(proposal, authenticated_data).await
+ }
+
+ fn reinit_proposal(
+ &self,
+ group_id: Option<Vec<u8>>,
+ version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+ extensions: ExtensionList,
+ ) -> Result<Proposal, MlsError> {
+ let group_id = group_id.map(Ok).unwrap_or_else(|| {
+ self.cipher_suite_provider
+ .random_bytes_vec(self.cipher_suite_provider.kdf_extract_size())
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+ })?;
+
+ Ok(Proposal::ReInit(ReInitProposal {
+ group_id,
+ version,
+ cipher_suite,
+ extensions,
+ }))
+ }
+
+ /// Create a proposal message that sets extensions stored in the group
+ /// state.
+ ///
+ /// # Warning
+ ///
+ /// This function does not create a diff that will be applied to the
+ /// current set of extension that are in use. In order for an existing
+ /// extension to not be overwritten by this proposal, it must be included
+ /// in the new set of extensions being proposed.
+ ///
+ ///
+ /// `authenticated_data` will be sent unencrypted along with the contents
+ /// of the proposal message.
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn propose_group_context_extensions(
+ &mut self,
+ extensions: ExtensionList,
+ authenticated_data: Vec<u8>,
+ ) -> Result<MlsMessage, MlsError> {
+ let proposal = self.group_context_extensions_proposal(extensions);
+ self.proposal_message(proposal, authenticated_data).await
+ }
+
+ fn group_context_extensions_proposal(&self, extensions: ExtensionList) -> Proposal {
+ Proposal::GroupContextExtensions(extensions)
+ }
+
+ /// Create a custom proposal message.
+ ///
+ /// `authenticated_data` will be sent unencrypted along with the contents
+ /// of the proposal message.
+ #[cfg(all(feature = "custom_proposal", feature = "by_ref_proposal"))]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn propose_custom(
+ &mut self,
+ proposal: CustomProposal,
+ authenticated_data: Vec<u8>,
+ ) -> Result<MlsMessage, MlsError> {
+ self.proposal_message(Proposal::Custom(proposal), authenticated_data)
+ .await
+ }
+
+ /// Delete all sent and received proposals cached for commit.
+ #[cfg(feature = "by_ref_proposal")]
+ pub fn clear_proposal_cache(&mut self) {
+ self.state.proposals.clear()
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn format_for_wire(
+ &mut self,
+ content: AuthenticatedContent,
+ ) -> Result<MlsMessage, MlsError> {
+ #[cfg(feature = "private_message")]
+ let payload = if content.wire_format == WireFormat::PrivateMessage {
+ MlsMessagePayload::Cipher(self.create_ciphertext(content).await?)
+ } else {
+ MlsMessagePayload::Plain(self.create_plaintext(content).await?)
+ };
+ #[cfg(not(feature = "private_message"))]
+ let payload = MlsMessagePayload::Plain(self.create_plaintext(content).await?);
+
+ Ok(MlsMessage::new(self.protocol_version(), payload))
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn create_plaintext(
+ &self,
+ auth_content: AuthenticatedContent,
+ ) -> Result<PublicMessage, MlsError> {
+ let membership_tag = if matches!(auth_content.content.sender, Sender::Member(_)) {
+ let tag = self
+ .key_schedule
+ .get_membership_tag(&auth_content, self.context(), &self.cipher_suite_provider)
+ .await?;
+
+ Some(tag)
+ } else {
+ None
+ };
+
+ Ok(PublicMessage {
+ content: auth_content.content,
+ auth: auth_content.auth,
+ membership_tag,
+ })
+ }
+
+ #[cfg(feature = "private_message")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn create_ciphertext(
+ &mut self,
+ auth_content: AuthenticatedContent,
+ ) -> Result<PrivateMessage, MlsError> {
+ let padding_mode = self.encryption_options()?.padding_mode;
+
+ let mut encryptor = CiphertextProcessor::new(self, self.cipher_suite_provider.clone());
+
+ encryptor.seal(auth_content, padding_mode).await
+ }
+
+ /// Encrypt an application message using the current group state.
+ ///
+ /// `authenticated_data` will be sent unencrypted along with the contents
+ /// of the proposal message.
+ #[cfg(feature = "private_message")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn encrypt_application_message(
+ &mut self,
+ message: &[u8],
+ authenticated_data: Vec<u8>,
+ ) -> Result<MlsMessage, MlsError> {
+ // A group member that has observed one or more proposals within an epoch MUST send a Commit message
+ // before sending application data
+ #[cfg(feature = "by_ref_proposal")]
+ if !self.state.proposals.is_empty() {
+ return Err(MlsError::CommitRequired);
+ }
+
+ let auth_content = AuthenticatedContent::new_signed(
+ &self.cipher_suite_provider,
+ self.context(),
+ Sender::Member(*self.private_tree.self_index),
+ Content::Application(message.to_vec().into()),
+ &self.signer,
+ WireFormat::PrivateMessage,
+ authenticated_data,
+ )
+ .await?;
+
+ self.format_for_wire(auth_content).await
+ }
+
+ #[cfg(feature = "private_message")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn decrypt_incoming_ciphertext(
+ &mut self,
+ message: &PrivateMessage,
+ ) -> Result<AuthenticatedContent, MlsError> {
+ let epoch_id = message.epoch;
+
+ let auth_content = if epoch_id == self.context().epoch {
+ let content = CiphertextProcessor::new(self, self.cipher_suite_provider.clone())
+ .open(message)
+ .await?;
+
+ verify_auth_content_signature(
+ &self.cipher_suite_provider,
+ SignaturePublicKeysContainer::RatchetTree(&self.state.public_tree),
+ self.context(),
+ &content,
+ #[cfg(feature = "by_ref_proposal")]
+ &[],
+ )
+ .await?;
+
+ Ok::<_, MlsError>(content)
+ } else {
+ #[cfg(feature = "prior_epoch")]
+ {
+ let epoch = self
+ .state_repo
+ .get_epoch_mut(epoch_id)
+ .await?
+ .ok_or(MlsError::EpochNotFound)?;
+
+ let content = CiphertextProcessor::new(epoch, self.cipher_suite_provider.clone())
+ .open(message)
+ .await?;
+
+ verify_auth_content_signature(
+ &self.cipher_suite_provider,
+ SignaturePublicKeysContainer::List(&epoch.signature_public_keys),
+ &epoch.context,
+ &content,
+ #[cfg(feature = "by_ref_proposal")]
+ &[],
+ )
+ .await?;
+
+ Ok(content)
+ }
+
+ #[cfg(not(feature = "prior_epoch"))]
+ Err(MlsError::EpochNotFound)
+ }?;
+
+ Ok(auth_content)
+ }
+
+ /// Apply a pending commit that was created by [`Group::commit`] or
+ /// [`CommitBuilder::build`].
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn apply_pending_commit(&mut self) -> Result<CommitMessageDescription, MlsError> {
+ let pending_commit = self
+ .pending_commit
+ .clone()
+ .ok_or(MlsError::PendingCommitNotFound)?;
+
+ self.process_commit(pending_commit.content, None).await
+ }
+
+ /// Returns true if a commit has been created but not yet applied
+ /// with [`Group::apply_pending_commit`] or cleared with [`Group::clear_pending_commit`]
+ pub fn has_pending_commit(&self) -> bool {
+ self.pending_commit.is_some()
+ }
+
+ /// Clear the currently pending commit.
+ ///
+ /// This function will automatically be called in the event that a
+ /// commit message is processed using [`Group::process_incoming_message`]
+ /// before [`Group::apply_pending_commit`] is called.
+ pub fn clear_pending_commit(&mut self) {
+ self.pending_commit = None
+ }
+
+ /// Process an inbound message for this group.
+ ///
+ /// # Warning
+ ///
+ /// Changes to the group's state as a result of processing `message` will
+ /// not be persisted by the
+ /// [`GroupStateStorage`](crate::GroupStateStorage)
+ /// in use by this group until [`Group::write_to_storage`] is called.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[inline(never)]
+ pub async fn process_incoming_message(
+ &mut self,
+ message: MlsMessage,
+ ) -> Result<ReceivedMessage, MlsError> {
+ if let Some(pending) = &self.pending_commit {
+ let message_hash = CommitHash::compute(&self.cipher_suite_provider, &message).await?;
+
+ if message_hash == pending.commit_message_hash {
+ let message_description = self.apply_pending_commit().await?;
+
+ return Ok(ReceivedMessage::Commit(message_description));
+ }
+ }
+
+ MessageProcessor::process_incoming_message(
+ self,
+ message,
+ #[cfg(feature = "by_ref_proposal")]
+ true,
+ )
+ .await
+ }
+
+ /// Process an inbound message for this group, providing additional context
+ /// with a message timestamp.
+ ///
+ /// Providing a timestamp is useful when the
+ /// [`IdentityProvider`](crate::IdentityProvider)
+ /// in use by the group can determine validity based on a timestamp.
+ /// For example, this allows for checking X.509 certificate expiration
+ /// at the time when `message` was received by a server rather than when
+ /// a specific client asynchronously received `message`
+ ///
+ /// # Warning
+ ///
+ /// Changes to the group's state as a result of processing `message` will
+ /// not be persisted by the
+ /// [`GroupStateStorage`](crate::GroupStateStorage)
+ /// in use by this group until [`Group::write_to_storage`] is called.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn process_incoming_message_with_time(
+ &mut self,
+ message: MlsMessage,
+ time: MlsTime,
+ ) -> Result<ReceivedMessage, MlsError> {
+ MessageProcessor::process_incoming_message_with_time(
+ self,
+ message,
+ #[cfg(feature = "by_ref_proposal")]
+ true,
+ Some(time),
+ )
+ .await
+ }
+
+ /// Find a group member by
+ /// [identity](crate::IdentityProvider::identity)
+ ///
+ /// This function determines identity by calling the
+ /// [`IdentityProvider`](crate::IdentityProvider)
+ /// currently in use by the group.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn member_with_identity(&self, identity: &[u8]) -> Result<Member, MlsError> {
+ let tree = &self.state.public_tree;
+
+ #[cfg(feature = "tree_index")]
+ let index = tree.get_leaf_node_with_identity(identity);
+
+ #[cfg(not(feature = "tree_index"))]
+ let index = tree
+ .get_leaf_node_with_identity(
+ identity,
+ &self.identity_provider(),
+ &self.state.context.extensions,
+ )
+ .await?;
+
+ let index = index.ok_or(MlsError::MemberNotFound)?;
+ let node = self.state.public_tree.get_leaf_node(index)?;
+
+ Ok(member_from_leaf_node(node, index))
+ }
+
+ /// Create a group info message that can be used for external proposals and commits.
+ ///
+ /// The returned `GroupInfo` is suitable for one external commit for the current epoch.
+ /// If `with_tree_in_extension` is set to true, the returned `GroupInfo` contains the
+ /// ratchet tree and therefore contains all information needed to join the group. Otherwise,
+ /// the ratchet tree must be obtained separately, e.g. via
+ /// (ExternalClient::export_tree)[crate::external_client::ExternalGroup::export_tree].
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn group_info_message_allowing_ext_commit(
+ &self,
+ with_tree_in_extension: bool,
+ ) -> Result<MlsMessage, MlsError> {
+ let mut extensions = ExtensionList::new();
+
+ extensions.set_from({
+ self.key_schedule
+ .get_external_key_pair_ext(&self.cipher_suite_provider)
+ .await?
+ })?;
+
+ self.group_info_message_internal(extensions, with_tree_in_extension)
+ .await
+ }
+
+ /// Create a group info message that can be used for external proposals.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn group_info_message(
+ &self,
+ with_tree_in_extension: bool,
+ ) -> Result<MlsMessage, MlsError> {
+ self.group_info_message_internal(ExtensionList::new(), with_tree_in_extension)
+ .await
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn group_info_message_internal(
+ &self,
+ mut initial_extensions: ExtensionList,
+ with_tree_in_extension: bool,
+ ) -> Result<MlsMessage, MlsError> {
+ if with_tree_in_extension {
+ initial_extensions.set_from(RatchetTreeExt {
+ tree_data: ExportedTree::new(self.state.public_tree.nodes.clone()),
+ })?;
+ }
+
+ let mut info = GroupInfo {
+ group_context: self.context().clone(),
+ extensions: initial_extensions,
+ confirmation_tag: self.state.confirmation_tag.clone(),
+ signer: self.private_tree.self_index,
+ signature: Vec::new(),
+ };
+
+ info.grease(self.cipher_suite_provider())?;
+
+ info.sign(&self.cipher_suite_provider, &self.signer, &())
+ .await?;
+
+ Ok(MlsMessage::new(
+ self.protocol_version(),
+ MlsMessagePayload::GroupInfo(info),
+ ))
+ }
+
+ /// Get the current group context summarizing various information about the group.
+ #[inline(always)]
+ pub fn context(&self) -> &GroupContext {
+ &self.group_state().context
+ }
+
+ /// Get the
+ /// [epoch_authenticator](https://messaginglayersecurity.rocks/mls-protocol/draft-ietf-mls-protocol.html#name-key-schedule)
+ /// of the current epoch.
+ pub fn epoch_authenticator(&self) -> Result<Secret, MlsError> {
+ Ok(self.key_schedule.authentication_secret.clone().into())
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn export_secret(
+ &self,
+ label: &[u8],
+ context: &[u8],
+ len: usize,
+ ) -> Result<Secret, MlsError> {
+ self.key_schedule
+ .export_secret(label, context, len, &self.cipher_suite_provider)
+ .await
+ .map(Into::into)
+ }
+
+ /// Export the current epoch's ratchet tree in serialized format.
+ ///
+ /// This function is used to provide the current group tree to new members
+ /// when the `ratchet_tree_extension` is not used according to [`MlsRules::commit_options`].
+ pub fn export_tree(&self) -> ExportedTree<'_> {
+ ExportedTree::new_borrowed(&self.current_epoch_tree().nodes)
+ }
+
+ /// Current version of the MLS protocol in use by this group.
+ pub fn protocol_version(&self) -> ProtocolVersion {
+ self.context().protocol_version
+ }
+
+ /// Current cipher suite in use by this group.
+ pub fn cipher_suite(&self) -> CipherSuite {
+ self.context().cipher_suite
+ }
+
+ /// Current roster
+ pub fn roster(&self) -> Roster<'_> {
+ self.group_state().public_tree.roster()
+ }
+
+ /// Determines equality of two different groups internal states.
+ /// Useful for testing.
+ ///
+ pub fn equal_group_state(a: &Group<C>, b: &Group<C>) -> bool {
+ a.state == b.state && a.key_schedule == b.key_schedule && a.epoch_secrets == b.epoch_secrets
+ }
+
+ #[cfg(feature = "psk")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn get_psk(
+ &self,
+ psks: &[ProposalInfo<PreSharedKeyProposal>],
+ ) -> Result<(PskSecret, Vec<PreSharedKeyID>), MlsError> {
+ if let Some(psk) = self.previous_psk.clone() {
+ // TODO consider throwing error if psks not empty
+ let psk_id = vec![psk.id.clone()];
+ let psk = PskSecret::calculate(&[psk], self.cipher_suite_provider()).await?;
+
+ Ok((psk, psk_id))
+ } else {
+ let psks = psks
+ .iter()
+ .map(|psk| psk.proposal.psk.clone())
+ .collect::<Vec<_>>();
+
+ let psk = PskResolver {
+ group_context: Some(self.context()),
+ current_epoch: Some(&self.epoch_secrets),
+ prior_epochs: Some(&self.state_repo),
+ psk_store: &self.config.secret_store(),
+ }
+ .resolve_to_secret(&psks, self.cipher_suite_provider())
+ .await?;
+
+ Ok((psk, psks))
+ }
+ }
+
+ #[cfg(feature = "private_message")]
+ pub(crate) fn encryption_options(&self) -> Result<EncryptionOptions, MlsError> {
+ self.config
+ .mls_rules()
+ .encryption_options(&self.roster(), self.group_context().extensions())
+ .map_err(|e| MlsError::MlsRulesError(e.into_any_error()))
+ }
+
+ #[cfg(not(feature = "psk"))]
+ fn get_psk(&self) -> PskSecret {
+ PskSecret::new(self.cipher_suite_provider())
+ }
+
+ #[cfg(feature = "secret_tree_access")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[inline(never)]
+ pub async fn next_encryption_key(&mut self) -> Result<MessageKey, MlsError> {
+ self.epoch_secrets
+ .secret_tree
+ .next_message_key(
+ &self.cipher_suite_provider,
+ crate::tree_kem::node::NodeIndex::from(self.private_tree.self_index),
+ KeyType::Application,
+ )
+ .await
+ }
+
+ #[cfg(feature = "secret_tree_access")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn derive_decryption_key(
+ &mut self,
+ sender: u32,
+ generation: u32,
+ ) -> Result<MessageKey, MlsError> {
+ self.epoch_secrets
+ .secret_tree
+ .message_key_generation(
+ &self.cipher_suite_provider,
+ crate::tree_kem::node::NodeIndex::from(sender),
+ KeyType::Application,
+ generation,
+ )
+ .await
+ }
+}
+
+#[cfg(feature = "private_message")]
+impl<C> GroupStateProvider for Group<C>
+where
+ C: ClientConfig + Clone,
+{
+ fn group_context(&self) -> &GroupContext {
+ self.context()
+ }
+
+ fn self_index(&self) -> LeafIndex {
+ self.private_tree.self_index
+ }
+
+ fn epoch_secrets_mut(&mut self) -> &mut EpochSecrets {
+ &mut self.epoch_secrets
+ }
+
+ fn epoch_secrets(&self) -> &EpochSecrets {
+ &self.epoch_secrets
+ }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
+#[cfg_attr(
+ all(not(target_arch = "wasm32"), mls_build_async),
+ maybe_async::must_be_async
+)]
+impl<C> MessageProcessor for Group<C>
+where
+ C: ClientConfig + Clone,
+{
+ type MlsRules = C::MlsRules;
+ type IdentityProvider = C::IdentityProvider;
+ type PreSharedKeyStorage = C::PskStore;
+ type OutputType = ReceivedMessage;
+ type CipherSuiteProvider = <C::CryptoProvider as CryptoProvider>::CipherSuiteProvider;
+
+ #[cfg(feature = "private_message")]
+ fn self_index(&self) -> Option<LeafIndex> {
+ Some(self.private_tree.self_index)
+ }
+
+ #[cfg(feature = "private_message")]
+ async fn process_ciphertext(
+ &mut self,
+ cipher_text: &PrivateMessage,
+ ) -> Result<EventOrContent<Self::OutputType>, MlsError> {
+ self.decrypt_incoming_ciphertext(cipher_text)
+ .await
+ .map(EventOrContent::Content)
+ }
+
+ async fn verify_plaintext_authentication(
+ &self,
+ message: PublicMessage,
+ ) -> Result<EventOrContent<Self::OutputType>, MlsError> {
+ let auth_content = verify_plaintext_authentication(
+ &self.cipher_suite_provider,
+ message,
+ Some(&self.key_schedule),
+ Some(self.private_tree.self_index),
+ &self.state,
+ )
+ .await?;
+
+ Ok(EventOrContent::Content(auth_content))
+ }
+
+ async fn apply_update_path(
+ &mut self,
+ sender: LeafIndex,
+ update_path: &ValidatedUpdatePath,
+ provisional_state: &mut ProvisionalState,
+ ) -> Result<Option<(TreeKemPrivate, PathSecret)>, MlsError> {
+ // Update the private tree to create a provisional private tree
+ let (mut provisional_private_tree, new_signer) =
+ self.provisional_private_tree(provisional_state)?;
+
+ if let Some(signer) = new_signer {
+ self.signer = signer;
+ }
+
+ provisional_state
+ .public_tree
+ .apply_update_path(
+ sender,
+ update_path,
+ &provisional_state.group_context.extensions,
+ self.identity_provider(),
+ self.cipher_suite_provider(),
+ )
+ .await?;
+
+ if let Some(pending) = &self.pending_commit {
+ Ok(Some((
+ pending.pending_private_tree.clone(),
+ pending.pending_commit_secret.clone(),
+ )))
+ } else {
+ // Update the tree hash to get context for decryption
+ provisional_state.group_context.tree_hash = provisional_state
+ .public_tree
+ .tree_hash(&self.cipher_suite_provider)
+ .await?;
+
+ let context_bytes = provisional_state.group_context.mls_encode_to_vec()?;
+
+ TreeKem::new(
+ &mut provisional_state.public_tree,
+ &mut provisional_private_tree,
+ )
+ .decap(
+ sender,
+ update_path,
+ &provisional_state.indexes_of_added_kpkgs,
+ &context_bytes,
+ &self.cipher_suite_provider,
+ )
+ .await
+ .map(|root_secret| Some((provisional_private_tree, root_secret)))
+ }
+ }
+
+ async fn update_key_schedule(
+ &mut self,
+ secrets: Option<(TreeKemPrivate, PathSecret)>,
+ interim_transcript_hash: InterimTranscriptHash,
+ confirmation_tag: &ConfirmationTag,
+ provisional_state: ProvisionalState,
+ ) -> Result<(), MlsError> {
+ let commit_secret = if let Some(secrets) = secrets {
+ self.private_tree = secrets.0;
+ secrets.1
+ } else {
+ PathSecret::empty(&self.cipher_suite_provider)
+ };
+
+ // Use the commit_secret, the psk_secret, the provisional GroupContext, and the init secret
+ // from the previous epoch (or from the external init) to compute the epoch secret and
+ // derived secrets for the new epoch
+
+ let key_schedule = match provisional_state
+ .applied_proposals
+ .external_initializations
+ .first()
+ .cloned()
+ {
+ Some(ext_init) if self.pending_commit.is_none() => {
+ self.key_schedule
+ .derive_for_external(&ext_init.proposal.kem_output, &self.cipher_suite_provider)
+ .await?
+ }
+ _ => self.key_schedule.clone(),
+ };
+
+ #[cfg(feature = "psk")]
+ let (psk, _) = self
+ .get_psk(&provisional_state.applied_proposals.psks)
+ .await?;
+
+ #[cfg(not(feature = "psk"))]
+ let psk = self.get_psk();
+
+ let key_schedule_result = KeySchedule::from_key_schedule(
+ &key_schedule,
+ &commit_secret,
+ &provisional_state.group_context,
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ provisional_state.public_tree.total_leaf_count(),
+ &psk,
+ &self.cipher_suite_provider,
+ )
+ .await?;
+
+ // Use the confirmation_key for the new epoch to compute the confirmation tag for
+ // this message, as described below, and verify that it is the same as the
+ // confirmation_tag field in the MlsPlaintext object.
+ let new_confirmation_tag = ConfirmationTag::create(
+ &key_schedule_result.confirmation_key,
+ &provisional_state.group_context.confirmed_transcript_hash,
+ &self.cipher_suite_provider,
+ )
+ .await?;
+
+ if &new_confirmation_tag != confirmation_tag {
+ return Err(MlsError::InvalidConfirmationTag);
+ }
+
+ #[cfg(feature = "prior_epoch")]
+ let signature_public_keys = self
+ .state
+ .public_tree
+ .leaves()
+ .map(|l| l.map(|n| n.signing_identity.signature_key.clone()))
+ .collect();
+
+ #[cfg(feature = "prior_epoch")]
+ let past_epoch = PriorEpoch {
+ context: self.context().clone(),
+ self_index: self.private_tree.self_index,
+ secrets: self.epoch_secrets.clone(),
+ signature_public_keys,
+ };
+
+ #[cfg(feature = "prior_epoch")]
+ self.state_repo.insert(past_epoch).await?;
+
+ self.epoch_secrets = key_schedule_result.epoch_secrets;
+ self.state.context = provisional_state.group_context;
+ self.state.interim_transcript_hash = interim_transcript_hash;
+ self.key_schedule = key_schedule_result.key_schedule;
+ self.state.public_tree = provisional_state.public_tree;
+ self.state.confirmation_tag = new_confirmation_tag;
+
+ // Clear the proposals list
+ #[cfg(feature = "by_ref_proposal")]
+ self.state.proposals.clear();
+
+ // Clear the pending updates list
+ #[cfg(feature = "by_ref_proposal")]
+ {
+ self.pending_updates = Default::default();
+ }
+
+ self.pending_commit = None;
+
+ Ok(())
+ }
+
+ fn mls_rules(&self) -> Self::MlsRules {
+ self.config.mls_rules()
+ }
+
+ fn identity_provider(&self) -> Self::IdentityProvider {
+ self.config.identity_provider()
+ }
+
+ fn psk_storage(&self) -> Self::PreSharedKeyStorage {
+ self.config.secret_store()
+ }
+
+ fn group_state(&self) -> &GroupState {
+ &self.state
+ }
+
+ fn group_state_mut(&mut self) -> &mut GroupState {
+ &mut self.state
+ }
+
+ fn can_continue_processing(&self, provisional_state: &ProvisionalState) -> bool {
+ !(provisional_state
+ .applied_proposals
+ .removals
+ .iter()
+ .any(|p| p.proposal.to_remove == self.private_tree.self_index)
+ && self.pending_commit.is_none())
+ }
+
+ #[cfg(feature = "private_message")]
+ fn min_epoch_available(&self) -> Option<u64> {
+ None
+ }
+
+ fn cipher_suite_provider(&self) -> &Self::CipherSuiteProvider {
+ &self.cipher_suite_provider
+ }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils;
+
+#[cfg(test)]
+mod tests {
+ use crate::{
+ client::test_utils::{
+ test_client_with_key_pkg, TestClientBuilder, TEST_CIPHER_SUITE,
+ TEST_CUSTOM_PROPOSAL_TYPE, TEST_PROTOCOL_VERSION,
+ },
+ client_builder::{test_utils::TestClientConfig, ClientBuilder, MlsConfig},
+ crypto::test_utils::TestCryptoProvider,
+ group::{
+ mls_rules::{CommitDirection, CommitSource},
+ proposal_filter::ProposalBundle,
+ },
+ identity::{
+ basic::BasicIdentityProvider,
+ test_utils::{get_test_signing_identity, BasicWithCustomProvider},
+ },
+ key_package::test_utils::test_key_package_message,
+ mls_rules::CommitOptions,
+ tree_kem::{
+ leaf_node::{test_utils::get_test_capabilities, LeafNodeSource},
+ UpdatePathNode,
+ },
+ };
+
+ #[cfg(any(feature = "private_message", feature = "custom_proposal"))]
+ use crate::group::mls_rules::DefaultMlsRules;
+
+ #[cfg(feature = "prior_epoch")]
+ use crate::group::padding::PaddingMode;
+
+ use crate::{extension::RequiredCapabilitiesExt, key_package::test_utils::test_key_package};
+
+ #[cfg(all(feature = "by_ref_proposal", feature = "custom_proposal"))]
+ use super::test_utils::test_group_custom_config;
+
+ #[cfg(feature = "psk")]
+ use crate::{client::Client, psk::PreSharedKey};
+
+ #[cfg(any(feature = "by_ref_proposal", feature = "private_message"))]
+ use crate::group::test_utils::random_bytes;
+
+ #[cfg(feature = "by_ref_proposal")]
+ use crate::{
+ extension::test_utils::TestExtension, identity::test_utils::get_test_basic_credential,
+ time::MlsTime,
+ };
+
+ use super::{
+ test_utils::{
+ get_test_25519_key, get_test_groups_with_features, group_extensions, process_commit,
+ test_group, test_group_custom, test_n_member_group, TestGroup, TEST_GROUP,
+ },
+ *,
+ };
+
+ use assert_matches::assert_matches;
+
+ use mls_rs_core::extension::{Extension, ExtensionType};
+ use mls_rs_core::identity::{Credential, CredentialType, CustomCredential};
+
+ #[cfg(feature = "by_ref_proposal")]
+ use mls_rs_core::identity::CertificateChain;
+
+ #[cfg(feature = "state_update")]
+ use itertools::Itertools;
+
+ #[cfg(feature = "state_update")]
+ use alloc::format;
+
+ #[cfg(feature = "by_ref_proposal")]
+ use crate::{crypto::test_utils::test_cipher_suite_provider, extension::ExternalSendersExt};
+
+ #[cfg(any(feature = "private_message", feature = "state_update"))]
+ use super::test_utils::test_member;
+
+ use mls_rs_core::extension::MlsExtension;
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_create_group() {
+ for (protocol_version, cipher_suite) in ProtocolVersion::all().flat_map(|p| {
+ TestCryptoProvider::all_supported_cipher_suites()
+ .into_iter()
+ .map(move |cs| (p, cs))
+ }) {
+ let test_group = test_group(protocol_version, cipher_suite).await;
+ let group = test_group.group;
+
+ assert_eq!(group.cipher_suite(), cipher_suite);
+ assert_eq!(group.state.context.epoch, 0);
+ assert_eq!(group.state.context.group_id, TEST_GROUP.to_vec());
+ assert_eq!(group.state.context.extensions, group_extensions());
+
+ assert_eq!(
+ group.state.context.confirmed_transcript_hash,
+ ConfirmedTranscriptHash::from(vec![])
+ );
+
+ #[cfg(feature = "private_message")]
+ assert!(group.state.proposals.is_empty());
+
+ #[cfg(feature = "by_ref_proposal")]
+ assert!(group.pending_updates.is_empty());
+
+ assert!(!group.has_pending_commit());
+
+ assert_eq!(
+ group.private_tree.self_index.0,
+ group.current_member_index()
+ );
+ }
+ }
+
+ #[cfg(feature = "private_message")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_pending_proposals_application_data() {
+ let mut test_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ // Create a proposal
+ let (bob_key_package, _) =
+ test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"bob").await;
+
+ let proposal = test_group
+ .group
+ .add_proposal(bob_key_package.key_package_message())
+ .unwrap();
+
+ test_group
+ .group
+ .proposal_message(proposal, vec![])
+ .await
+ .unwrap();
+
+ // We should not be able to send application messages until a commit happens
+ let res = test_group
+ .group
+ .encrypt_application_message(b"test", vec![])
+ .await;
+
+ assert_matches!(res, Err(MlsError::CommitRequired));
+
+ // We should be able to send application messages after a commit
+ test_group.group.commit(vec![]).await.unwrap();
+
+ assert!(test_group.group.has_pending_commit());
+
+ test_group.group.apply_pending_commit().await.unwrap();
+
+ let res = test_group
+ .group
+ .encrypt_application_message(b"test", vec![])
+ .await;
+
+ assert!(res.is_ok());
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_update_proposals() {
+ let new_extension = TestExtension { foo: 10 };
+ let mut extension_list = ExtensionList::default();
+ extension_list.set_from(new_extension).unwrap();
+
+ let mut test_group = test_group_custom(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ vec![42.into()],
+ Some(extension_list.clone()),
+ None,
+ )
+ .await;
+
+ let existing_leaf = test_group.group.current_user_leaf_node().unwrap().clone();
+
+ // Create an update proposal
+ let proposal = test_group.update_proposal().await;
+
+ let update = match proposal {
+ Proposal::Update(update) => update,
+ _ => panic!("non update proposal found"),
+ };
+
+ assert_ne!(update.leaf_node.public_key, existing_leaf.public_key);
+
+ assert_eq!(
+ update.leaf_node.signing_identity,
+ existing_leaf.signing_identity
+ );
+
+ assert_eq!(update.leaf_node.ungreased_extensions(), extension_list);
+ assert_eq!(
+ update.leaf_node.ungreased_capabilities().sorted(),
+ Capabilities {
+ extensions: vec![42.into()],
+ ..get_test_capabilities()
+ }
+ .sorted()
+ );
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_invalid_commit_self_update() {
+ let mut test_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ // Create an update proposal
+ let proposal_msg = test_group.group.propose_update(vec![]).await.unwrap();
+
+ let proposal = match proposal_msg.into_plaintext().unwrap().content.content {
+ Content::Proposal(p) => p,
+ _ => panic!("found non-proposal message"),
+ };
+
+ let update_leaf = match *proposal {
+ Proposal::Update(u) => u.leaf_node,
+ _ => panic!("found proposal message that isn't an update"),
+ };
+
+ test_group.group.commit(vec![]).await.unwrap();
+ test_group.group.apply_pending_commit().await.unwrap();
+
+ // The leaf node should not be the one from the update, because the committer rejects it
+ assert_ne!(
+ &update_leaf,
+ test_group.group.current_user_leaf_node().unwrap()
+ );
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn update_proposal_with_bad_key_package_is_ignored_when_committing() {
+ let (mut alice_group, mut bob_group) =
+ test_two_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, true).await;
+
+ let mut proposal = alice_group.update_proposal().await;
+
+ if let Proposal::Update(ref mut update) = proposal {
+ update.leaf_node.signature = random_bytes(32);
+ } else {
+ panic!("Invalid update proposal")
+ }
+
+ let proposal_message = alice_group
+ .group
+ .proposal_message(proposal.clone(), vec![])
+ .await
+ .unwrap();
+
+ let proposal_plaintext = match proposal_message.payload {
+ MlsMessagePayload::Plain(p) => p,
+ _ => panic!("Unexpected non-plaintext message"),
+ };
+
+ let proposal_ref = ProposalRef::from_content(
+ &bob_group.group.cipher_suite_provider,
+ &proposal_plaintext.clone().into(),
+ )
+ .await
+ .unwrap();
+
+ // Hack bob's receipt of the proposal
+ bob_group.group.state.proposals.insert(
+ proposal_ref,
+ proposal,
+ proposal_plaintext.content.sender,
+ );
+
+ let commit_output = bob_group.group.commit(vec![]).await.unwrap();
+
+ assert_matches!(
+ commit_output.commit_message,
+ MlsMessage {
+ payload: MlsMessagePayload::Plain(
+ PublicMessage {
+ content: FramedContent {
+ content: Content::Commit(c),
+ ..
+ },
+ ..
+ }),
+ ..
+ } if c.proposals.is_empty()
+ );
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_two_member_group(
+ protocol_version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+ tree_ext: bool,
+ ) -> (TestGroup, TestGroup) {
+ let mut test_group = test_group_custom(
+ protocol_version,
+ cipher_suite,
+ Default::default(),
+ None,
+ Some(CommitOptions::new().with_ratchet_tree_extension(tree_ext)),
+ )
+ .await;
+
+ let (bob_test_group, _) = test_group.join("bob").await;
+
+ assert!(Group::equal_group_state(
+ &test_group.group,
+ &bob_test_group.group
+ ));
+
+ (test_group, bob_test_group)
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_welcome_processing_exported_tree() {
+ test_two_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, false).await;
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_welcome_processing_tree_extension() {
+ test_two_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, true).await;
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_welcome_processing_missing_tree() {
+ let mut test_group = test_group_custom(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ Default::default(),
+ None,
+ Some(CommitOptions::new().with_ratchet_tree_extension(false)),
+ )
+ .await;
+
+ let (bob_client, bob_key_package) =
+ test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
+
+ // Add bob to the group
+ let commit_output = test_group
+ .group
+ .commit_builder()
+ .add_member(bob_key_package)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ // Group from Bob's perspective
+ let bob_group = Group::join(
+ &commit_output.welcome_messages[0],
+ None,
+ bob_client.config,
+ bob_client.signer.unwrap(),
+ )
+ .await
+ .map(|_| ());
+
+ assert_matches!(bob_group, Err(MlsError::RatchetTreeNotFound));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_group_context_ext_proposal_create() {
+ let test_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ let mut extension_list = ExtensionList::new();
+ extension_list
+ .set_from(RequiredCapabilitiesExt {
+ extensions: vec![42.into()],
+ proposals: vec![],
+ credentials: vec![],
+ })
+ .unwrap();
+
+ let proposal = test_group
+ .group
+ .group_context_extensions_proposal(extension_list.clone());
+
+ assert_matches!(proposal, Proposal::GroupContextExtensions(ext) if ext == extension_list);
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn group_context_extension_proposal_test(
+ ext_list: ExtensionList,
+ ) -> (TestGroup, Result<MlsMessage, MlsError>) {
+ let protocol_version = TEST_PROTOCOL_VERSION;
+ let cipher_suite = TEST_CIPHER_SUITE;
+
+ let mut test_group =
+ test_group_custom(protocol_version, cipher_suite, vec![42.into()], None, None).await;
+
+ let commit = test_group
+ .group
+ .commit_builder()
+ .set_group_context_ext(ext_list)
+ .unwrap()
+ .build()
+ .await
+ .map(|commit_output| commit_output.commit_message);
+
+ (test_group, commit)
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_group_context_ext_proposal_commit() {
+ let mut extension_list = ExtensionList::new();
+
+ extension_list
+ .set_from(RequiredCapabilitiesExt {
+ extensions: vec![42.into()],
+ proposals: vec![],
+ credentials: vec![],
+ })
+ .unwrap();
+
+ let (mut test_group, _) =
+ group_context_extension_proposal_test(extension_list.clone()).await;
+
+ #[cfg(feature = "state_update")]
+ {
+ let update = test_group.group.apply_pending_commit().await.unwrap();
+ assert!(update.state_update.active);
+ }
+
+ #[cfg(not(feature = "state_update"))]
+ test_group.group.apply_pending_commit().await.unwrap();
+
+ assert_eq!(test_group.group.state.context.extensions, extension_list)
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_group_context_ext_proposal_invalid() {
+ let mut extension_list = ExtensionList::new();
+ extension_list
+ .set_from(RequiredCapabilitiesExt {
+ extensions: vec![999.into()],
+ proposals: vec![],
+ credentials: vec![],
+ })
+ .unwrap();
+
+ let (_, commit) = group_context_extension_proposal_test(extension_list.clone()).await;
+
+ assert_matches!(
+ commit,
+ Err(MlsError::RequiredExtensionNotFound(a)) if a == 999.into()
+ );
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn make_group_with_required_capabilities(
+ required_caps: RequiredCapabilitiesExt,
+ ) -> Result<Group<TestClientConfig>, MlsError> {
+ test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice")
+ .await
+ .0
+ .create_group(core::iter::once(required_caps.into_extension().unwrap()).collect())
+ .await
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn creating_group_with_member_not_supporting_required_credential_type_fails() {
+ let group_creation = make_group_with_required_capabilities(RequiredCapabilitiesExt {
+ credentials: vec![CredentialType::BASIC, CredentialType::X509],
+ ..Default::default()
+ })
+ .await
+ .map(|_| ());
+
+ assert_matches!(
+ group_creation,
+ Err(MlsError::RequiredCredentialNotFound(CredentialType::X509))
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn creating_group_with_member_not_supporting_required_extension_type_fails() {
+ const EXTENSION_TYPE: ExtensionType = ExtensionType::new(33);
+
+ let group_creation = make_group_with_required_capabilities(RequiredCapabilitiesExt {
+ extensions: vec![EXTENSION_TYPE],
+ ..Default::default()
+ })
+ .await
+ .map(|_| ());
+
+ assert_matches!(
+ group_creation,
+ Err(MlsError::RequiredExtensionNotFound(EXTENSION_TYPE))
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn creating_group_with_member_not_supporting_required_proposal_type_fails() {
+ const PROPOSAL_TYPE: ProposalType = ProposalType::new(33);
+
+ let group_creation = make_group_with_required_capabilities(RequiredCapabilitiesExt {
+ proposals: vec![PROPOSAL_TYPE],
+ ..Default::default()
+ })
+ .await
+ .map(|_| ());
+
+ assert_matches!(
+ group_creation,
+ Err(MlsError::RequiredProposalNotFound(PROPOSAL_TYPE))
+ );
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg(not(target_arch = "wasm32"))]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn creating_group_with_member_not_supporting_external_sender_credential_fails() {
+ let ext_senders = make_x509_external_senders_ext()
+ .await
+ .into_extension()
+ .unwrap();
+
+ let group_creation =
+ test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "alice")
+ .await
+ .0
+ .create_group(core::iter::once(ext_senders).collect())
+ .await
+ .map(|_| ());
+
+ assert_matches!(
+ group_creation,
+ Err(MlsError::RequiredCredentialNotFound(CredentialType::X509))
+ );
+ }
+
+ #[cfg(all(not(target_arch = "wasm32"), feature = "private_message"))]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_group_encrypt_plaintext_padding() {
+ let protocol_version = TEST_PROTOCOL_VERSION;
+ // This test requires a cipher suite whose signatures are not variable in length.
+ let cipher_suite = CipherSuite::CURVE25519_AES128;
+
+ let mut test_group = test_group_custom_config(protocol_version, cipher_suite, |b| {
+ b.mls_rules(
+ DefaultMlsRules::default()
+ .with_encryption_options(EncryptionOptions::new(true, PaddingMode::None)),
+ )
+ })
+ .await;
+
+ let without_padding = test_group
+ .group
+ .encrypt_application_message(&random_bytes(150), vec![])
+ .await
+ .unwrap();
+
+ let mut test_group =
+ test_group_custom_config(protocol_version, cipher_suite, |b| {
+ b.mls_rules(DefaultMlsRules::default().with_encryption_options(
+ EncryptionOptions::new(true, PaddingMode::StepFunction),
+ ))
+ })
+ .await;
+
+ let with_padding = test_group
+ .group
+ .encrypt_application_message(&random_bytes(150), vec![])
+ .await
+ .unwrap();
+
+ assert!(with_padding.mls_encoded_len() > without_padding.mls_encoded_len());
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn external_commit_requires_external_pub_extension() {
+ let protocol_version = TEST_PROTOCOL_VERSION;
+ let cipher_suite = TEST_CIPHER_SUITE;
+ let group = test_group(protocol_version, cipher_suite).await;
+
+ let info = group
+ .group
+ .group_info_message(false)
+ .await
+ .unwrap()
+ .into_group_info()
+ .unwrap();
+
+ let info_msg = MlsMessage::new(protocol_version, MlsMessagePayload::GroupInfo(info));
+
+ let signing_identity = group
+ .group
+ .current_member_signing_identity()
+ .unwrap()
+ .clone();
+
+ let res = external_commit::ExternalCommitBuilder::new(
+ group.group.signer,
+ signing_identity,
+ group.group.config,
+ )
+ .build(info_msg)
+ .await
+ .map(|_| {});
+
+ assert_matches!(res, Err(MlsError::MissingExternalPubExtension));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn external_commit_via_commit_options_round_trip() {
+ let mut group = test_group_custom(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ vec![],
+ None,
+ CommitOptions::default()
+ .with_allow_external_commit(true)
+ .into(),
+ )
+ .await;
+
+ let commit_output = group.group.commit(vec![]).await.unwrap();
+
+ let (test_client, _) =
+ test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
+
+ test_client
+ .external_commit_builder()
+ .unwrap()
+ .build(commit_output.external_commit_group_info.unwrap())
+ .await
+ .unwrap();
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_path_update_preference() {
+ let protocol_version = TEST_PROTOCOL_VERSION;
+ let cipher_suite = TEST_CIPHER_SUITE;
+
+ let mut test_group = test_group_custom(
+ protocol_version,
+ cipher_suite,
+ Default::default(),
+ None,
+ Some(CommitOptions::new()),
+ )
+ .await;
+
+ let test_key_package =
+ test_key_package_message(protocol_version, cipher_suite, "alice").await;
+
+ test_group
+ .group
+ .commit_builder()
+ .add_member(test_key_package.clone())
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ assert!(test_group
+ .group
+ .pending_commit
+ .unwrap()
+ .pending_commit_secret
+ .iter()
+ .all(|x| x == &0));
+
+ let mut test_group = test_group_custom(
+ protocol_version,
+ cipher_suite,
+ Default::default(),
+ None,
+ Some(CommitOptions::new().with_path_required(true)),
+ )
+ .await;
+
+ test_group
+ .group
+ .commit_builder()
+ .add_member(test_key_package)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ assert!(!test_group
+ .group
+ .pending_commit
+ .unwrap()
+ .pending_commit_secret
+ .iter()
+ .all(|x| x == &0));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_path_update_preference_override() {
+ let protocol_version = TEST_PROTOCOL_VERSION;
+ let cipher_suite = TEST_CIPHER_SUITE;
+
+ let mut test_group = test_group_custom(
+ protocol_version,
+ cipher_suite,
+ Default::default(),
+ None,
+ Some(CommitOptions::new()),
+ )
+ .await;
+
+ test_group.group.commit(vec![]).await.unwrap();
+
+ assert!(!test_group
+ .group
+ .pending_commit
+ .unwrap()
+ .pending_commit_secret
+ .iter()
+ .all(|x| x == &0));
+ }
+
+ #[cfg(feature = "private_message")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn group_rejects_unencrypted_application_message() {
+ let protocol_version = TEST_PROTOCOL_VERSION;
+ let cipher_suite = TEST_CIPHER_SUITE;
+
+ let mut alice = test_group(protocol_version, cipher_suite).await;
+ let (mut bob, _) = alice.join("bob").await;
+
+ let message = alice
+ .make_plaintext(Content::Application(b"hello".to_vec().into()))
+ .await;
+
+ let res = bob.group.process_incoming_message(message).await;
+
+ assert_matches!(res, Err(MlsError::UnencryptedApplicationMessage));
+ }
+
+ #[cfg(feature = "state_update")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_state_update() {
+ let protocol_version = TEST_PROTOCOL_VERSION;
+ let cipher_suite = TEST_CIPHER_SUITE;
+
+ // Create a group with 10 members
+ let mut alice = test_group(protocol_version, cipher_suite).await;
+ let (mut bob, _) = alice.join("bob").await;
+ let mut leaves = vec![];
+
+ for i in 0..8 {
+ let (group, commit) = alice.join(&format!("charlie{i}")).await;
+ leaves.push(group.group.current_user_leaf_node().unwrap().clone());
+ bob.process_message(commit).await.unwrap();
+ }
+
+ // Create many proposals, make Alice commit them
+
+ let update_message = bob.group.propose_update(vec![]).await.unwrap();
+
+ alice.process_message(update_message).await.unwrap();
+
+ let external_psk_ids: Vec<ExternalPskId> = (0..5)
+ .map(|i| {
+ let external_id = ExternalPskId::new(vec![i]);
+
+ alice
+ .group
+ .config
+ .secret_store()
+ .insert(ExternalPskId::new(vec![i]), PreSharedKey::from(vec![i]));
+
+ bob.group
+ .config
+ .secret_store()
+ .insert(ExternalPskId::new(vec![i]), PreSharedKey::from(vec![i]));
+
+ external_id
+ })
+ .collect();
+
+ let mut commit_builder = alice.group.commit_builder();
+
+ for external_psk in external_psk_ids {
+ commit_builder = commit_builder.add_external_psk(external_psk).unwrap();
+ }
+
+ for index in [2, 5, 6] {
+ commit_builder = commit_builder.remove_member(index).unwrap();
+ }
+
+ for i in 0..5 {
+ let (key_package, _) = test_member(
+ protocol_version,
+ cipher_suite,
+ format!("dave{i}").as_bytes(),
+ )
+ .await;
+
+ commit_builder = commit_builder
+ .add_member(key_package.key_package_message())
+ .unwrap()
+ }
+
+ let commit_output = commit_builder.build().await.unwrap();
+
+ let commit_description = alice.process_pending_commit().await.unwrap();
+
+ assert!(!commit_description.is_external);
+
+ assert_eq!(
+ commit_description.committer,
+ alice.group.current_member_index()
+ );
+
+ // Check that applying pending commit and processing commit yields correct update.
+ let state_update_alice = commit_description.state_update.clone();
+
+ assert_eq!(
+ state_update_alice
+ .roster_update
+ .added()
+ .iter()
+ .map(|m| m.index)
+ .collect::<Vec<_>>(),
+ vec![2, 5, 6, 10, 11]
+ );
+
+ assert_eq!(
+ state_update_alice.roster_update.removed(),
+ vec![2, 5, 6]
+ .into_iter()
+ .map(|i| member_from_leaf_node(&leaves[i as usize - 2], LeafIndex(i)))
+ .collect::<Vec<_>>()
+ );
+
+ assert_eq!(
+ state_update_alice
+ .roster_update
+ .updated()
+ .iter()
+ .map(|update| update.new.clone())
+ .collect_vec()
+ .as_slice(),
+ &alice.group.roster().members()[0..2]
+ );
+
+ assert_eq!(
+ state_update_alice.added_psks,
+ (0..5)
+ .map(|i| ExternalPskId::new(vec![i]))
+ .collect::<Vec<_>>()
+ );
+
+ let payload = bob
+ .process_message(commit_output.commit_message)
+ .await
+ .unwrap();
+
+ let ReceivedMessage::Commit(bob_commit_description) = payload else {
+ panic!("expected commit");
+ };
+
+ assert_eq!(commit_description, bob_commit_description);
+ }
+
+ #[cfg(feature = "state_update")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_description_external_commit() {
+ use crate::client::test_utils::TestClientBuilder;
+
+ let mut alice_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ let (bob_identity, secret_key) = get_test_signing_identity(TEST_CIPHER_SUITE, b"bob").await;
+
+ let bob = TestClientBuilder::new_for_test()
+ .signing_identity(bob_identity, secret_key, TEST_CIPHER_SUITE)
+ .build();
+
+ let (bob_group, commit) = bob
+ .external_commit_builder()
+ .unwrap()
+ .build(
+ alice_group
+ .group
+ .group_info_message_allowing_ext_commit(true)
+ .await
+ .unwrap(),
+ )
+ .await
+ .unwrap();
+
+ let event = alice_group.process_message(commit).await.unwrap();
+
+ let ReceivedMessage::Commit(commit_description) = event else {
+ panic!("expected commit");
+ };
+
+ assert!(commit_description.is_external);
+ assert_eq!(commit_description.committer, 1);
+
+ assert_eq!(
+ commit_description.state_update.roster_update.added(),
+ &bob_group.roster().members()[1..2]
+ );
+
+ itertools::assert_equal(
+ bob_group.roster().members_iter(),
+ alice_group.group.roster().members_iter(),
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn can_join_new_group_externally() {
+ use crate::client::test_utils::TestClientBuilder;
+
+ let mut alice_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ let (bob_identity, secret_key) = get_test_signing_identity(TEST_CIPHER_SUITE, b"bob").await;
+
+ let bob = TestClientBuilder::new_for_test()
+ .signing_identity(bob_identity, secret_key, TEST_CIPHER_SUITE)
+ .build();
+
+ let (_, commit) = bob
+ .external_commit_builder()
+ .unwrap()
+ .with_tree_data(alice_group.group.export_tree().into_owned())
+ .build(
+ alice_group
+ .group
+ .group_info_message_allowing_ext_commit(false)
+ .await
+ .unwrap(),
+ )
+ .await
+ .unwrap();
+
+ alice_group.process_message(commit).await.unwrap();
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_membership_tag_from_non_member() {
+ let (mut alice_group, mut bob_group) =
+ test_two_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, true).await;
+
+ let mut commit_output = alice_group.group.commit(vec![]).await.unwrap();
+
+ let plaintext = match commit_output.commit_message.payload {
+ MlsMessagePayload::Plain(ref mut plain) => plain,
+ _ => panic!("Non plaintext message"),
+ };
+
+ plaintext.content.sender = Sender::NewMemberCommit;
+
+ let res = bob_group
+ .process_message(commit_output.commit_message)
+ .await;
+
+ assert_matches!(res, Err(MlsError::MembershipTagForNonMember));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_partial_commits() {
+ let protocol_version = TEST_PROTOCOL_VERSION;
+ let cipher_suite = TEST_CIPHER_SUITE;
+
+ let mut alice = test_group(protocol_version, cipher_suite).await;
+ let (mut bob, _) = alice.join("bob").await;
+ let (mut charlie, commit) = alice.join("charlie").await;
+ bob.process_message(commit).await.unwrap();
+
+ let (_, commit) = charlie.join("dave").await;
+
+ alice.process_message(commit.clone()).await.unwrap();
+ bob.process_message(commit.clone()).await.unwrap();
+
+ let Content::Commit(commit) = commit.into_plaintext().unwrap().content.content else {
+ panic!("Expected commit")
+ };
+
+ assert!(commit.path.is_none());
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn group_with_path_required() -> TestGroup {
+ let mut alice = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ alice.group.config.0.mls_rules.commit_options.path_required = true;
+
+ alice
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn old_hpke_secrets_are_removed() {
+ let mut alice = group_with_path_required().await;
+ alice.join("bob").await;
+ alice.join("charlie").await;
+
+ alice
+ .group
+ .commit_builder()
+ .remove_member(1)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ assert!(alice.group.private_tree.secret_keys[1].is_some());
+ alice.process_pending_commit().await.unwrap();
+ assert!(alice.group.private_tree.secret_keys[1].is_none());
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn old_hpke_secrets_of_removed_are_removed() {
+ let mut alice = group_with_path_required().await;
+ alice.join("bob").await;
+ let (mut charlie, _) = alice.join("charlie").await;
+
+ let commit = charlie
+ .group
+ .commit_builder()
+ .remove_member(1)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ assert!(alice.group.private_tree.secret_keys[1].is_some());
+ alice.process_message(commit.commit_message).await.unwrap();
+ assert!(alice.group.private_tree.secret_keys[1].is_none());
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn old_hpke_secrets_of_updated_are_removed() {
+ let mut alice = group_with_path_required().await;
+ let (mut bob, _) = alice.join("bob").await;
+ let (mut charlie, commit) = alice.join("charlie").await;
+ bob.process_message(commit).await.unwrap();
+
+ let update = bob.group.propose_update(vec![]).await.unwrap();
+ charlie.process_message(update.clone()).await.unwrap();
+ alice.process_message(update).await.unwrap();
+
+ let commit = charlie.group.commit(vec![]).await.unwrap();
+
+ assert!(alice.group.private_tree.secret_keys[1].is_some());
+ alice.process_message(commit.commit_message).await.unwrap();
+ assert!(alice.group.private_tree.secret_keys[1].is_none());
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn only_selected_members_of_the_original_group_can_join_subgroup() {
+ let mut alice = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let (mut bob, _) = alice.join("bob").await;
+ let (carol, commit) = alice.join("carol").await;
+
+ // Apply the commit that adds carol
+ bob.group.process_incoming_message(commit).await.unwrap();
+
+ let bob_identity = bob.group.current_member_signing_identity().unwrap().clone();
+ let signer = bob.group.signer.clone();
+
+ let new_key_pkg = Client::new(
+ bob.group.config.clone(),
+ Some(signer),
+ Some((bob_identity, TEST_CIPHER_SUITE)),
+ TEST_PROTOCOL_VERSION,
+ )
+ .generate_key_package_message()
+ .await
+ .unwrap();
+
+ let (mut alice_sub_group, welcome) = alice
+ .group
+ .branch(b"subgroup".to_vec(), vec![new_key_pkg])
+ .await
+ .unwrap();
+
+ let welcome = &welcome[0];
+
+ let (mut bob_sub_group, _) = bob.group.join_subgroup(welcome, None).await.unwrap();
+
+ // Carol can't join
+ let res = carol.group.join_subgroup(welcome, None).await.map(|_| ());
+ assert_matches!(res, Err(_));
+
+ // Alice and Bob can still talk
+ let commit_output = alice_sub_group.commit(vec![]).await.unwrap();
+
+ bob_sub_group
+ .process_incoming_message(commit_output.commit_message)
+ .await
+ .unwrap();
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn joining_group_fails_if_unsupported<F>(
+ f: F,
+ ) -> Result<(TestGroup, MlsMessage), MlsError>
+ where
+ F: FnMut(&mut TestClientConfig),
+ {
+ let mut alice_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ alice_group.join_with_custom_config("alice", false, f).await
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn joining_group_fails_if_protocol_version_is_not_supported() {
+ let res = joining_group_fails_if_unsupported(|config| {
+ config.0.settings.protocol_versions.clear();
+ })
+ .await
+ .map(|_| ());
+
+ assert_matches!(
+ res,
+ Err(MlsError::UnsupportedProtocolVersion(v)) if v ==
+ TEST_PROTOCOL_VERSION
+ );
+ }
+
+ // WebCrypto does not support disabling ciphersuites
+ #[cfg(not(target_arch = "wasm32"))]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn joining_group_fails_if_cipher_suite_is_not_supported() {
+ let res = joining_group_fails_if_unsupported(|config| {
+ config
+ .0
+ .crypto_provider
+ .enabled_cipher_suites
+ .retain(|&x| x != TEST_CIPHER_SUITE);
+ })
+ .await
+ .map(|_| ());
+
+ assert_matches!(
+ res,
+ Err(MlsError::UnsupportedCipherSuite(TEST_CIPHER_SUITE))
+ );
+ }
+
+ #[cfg(feature = "private_message")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn member_can_see_sender_creds() {
+ let mut alice_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let (mut bob_group, _) = alice_group.join("bob").await;
+
+ let bob_msg = b"I'm Bob";
+
+ let msg = bob_group
+ .group
+ .encrypt_application_message(bob_msg, vec![])
+ .await
+ .unwrap();
+
+ let received_by_alice = alice_group
+ .group
+ .process_incoming_message(msg)
+ .await
+ .unwrap();
+
+ assert_matches!(
+ received_by_alice,
+ ReceivedMessage::ApplicationMessage(ApplicationMessageDescription { sender_index, .. })
+ if sender_index == bob_group.group.current_member_index()
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn members_of_a_group_have_identical_authentication_secrets() {
+ let mut alice_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let (bob_group, _) = alice_group.join("bob").await;
+
+ assert_eq!(
+ alice_group.group.epoch_authenticator().unwrap(),
+ bob_group.group.epoch_authenticator().unwrap()
+ );
+ }
+
+ #[cfg(feature = "private_message")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn member_cannot_decrypt_same_message_twice() {
+ let mut alice_group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let (mut bob_group, _) = alice_group.join("bob").await;
+
+ let message = alice_group
+ .group
+ .encrypt_application_message(b"foobar", Vec::new())
+ .await
+ .unwrap();
+
+ let received_message = bob_group
+ .group
+ .process_incoming_message(message.clone())
+ .await
+ .unwrap();
+
+ assert_matches!(
+ received_message,
+ ReceivedMessage::ApplicationMessage(m) if m.data() == b"foobar"
+ );
+
+ let res = bob_group.group.process_incoming_message(message).await;
+
+ assert_matches!(res, Err(MlsError::KeyMissing(0)));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn removing_requirements_allows_to_add() {
+ let mut alice_group = test_group_custom(
+ TEST_PROTOCOL_VERSION,
+ TEST_CIPHER_SUITE,
+ vec![17.into()],
+ None,
+ None,
+ )
+ .await;
+
+ alice_group
+ .group
+ .commit_builder()
+ .set_group_context_ext(
+ vec![RequiredCapabilitiesExt {
+ extensions: vec![17.into()],
+ ..Default::default()
+ }
+ .into_extension()
+ .unwrap()]
+ .try_into()
+ .unwrap(),
+ )
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ alice_group.process_pending_commit().await.unwrap();
+
+ let test_key_package =
+ test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
+
+ let test_key_package = MlsMessage::new(
+ TEST_PROTOCOL_VERSION,
+ MlsMessagePayload::KeyPackage(test_key_package),
+ );
+
+ alice_group
+ .group
+ .commit_builder()
+ .add_member(test_key_package)
+ .unwrap()
+ .set_group_context_ext(Default::default())
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ let state_update = alice_group
+ .process_pending_commit()
+ .await
+ .unwrap()
+ .state_update;
+
+ #[cfg(feature = "state_update")]
+ assert_eq!(
+ state_update
+ .roster_update
+ .added()
+ .iter()
+ .map(|m| m.index)
+ .collect::<Vec<_>>(),
+ vec![1]
+ );
+
+ #[cfg(not(feature = "state_update"))]
+ assert!(state_update == StateUpdate {});
+
+ assert_eq!(alice_group.group.roster().members_iter().count(), 2);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_leaf_wrong_source() {
+ // RFC, 13.4.2. "The leaf_node_source field MUST be set to commit."
+ let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, 3).await;
+
+ groups[0].group.commit_modifiers.modify_leaf = |leaf, sk| {
+ leaf.leaf_node_source = LeafNodeSource::Update;
+ Some(sk.clone())
+ };
+
+ let commit_output = groups[0].group.commit(vec![]).await.unwrap();
+
+ let res = groups[2]
+ .process_message(commit_output.commit_message)
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidLeafNodeSource));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_leaf_same_hpke_key() {
+ // RFC 13.4.2. "Verify that the encryption_key value in the LeafNode is different from the committer's current leaf node"
+
+ let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, 3).await;
+
+ // Group 0 starts using fixed key
+ groups[0].group.commit_modifiers.modify_leaf = |leaf, sk| {
+ leaf.public_key = get_test_25519_key(1u8);
+ Some(sk.clone())
+ };
+
+ let commit_output = groups[0].group.commit(vec![]).await.unwrap();
+ groups[0].process_pending_commit().await.unwrap();
+ groups[2]
+ .process_message(commit_output.commit_message)
+ .await
+ .unwrap();
+
+ // Group 0 tries to use the fixed key againd
+ let commit_output = groups[0].group.commit(vec![]).await.unwrap();
+
+ let res = groups[2]
+ .process_message(commit_output.commit_message)
+ .await;
+
+ assert_matches!(res, Err(MlsError::SameHpkeKey(0)));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_leaf_duplicate_hpke_key() {
+ // RFC 8.3 "Verify that the following fields are unique among the members of the group: `encryption_key`"
+
+ if TEST_CIPHER_SUITE != CipherSuite::CURVE25519_AES128
+ && TEST_CIPHER_SUITE != CipherSuite::CURVE25519_CHACHA
+ {
+ return;
+ }
+
+ let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, 10).await;
+
+ // Group 1 uses the fixed key
+ groups[1].group.commit_modifiers.modify_leaf = |leaf, sk| {
+ leaf.public_key = get_test_25519_key(1u8);
+ Some(sk.clone())
+ };
+
+ let commit_output = groups
+ .get_mut(1)
+ .unwrap()
+ .group
+ .commit(vec![])
+ .await
+ .unwrap();
+
+ process_commit(&mut groups, commit_output.commit_message, 1).await;
+
+ // Group 0 tries to use the fixed key too
+ groups[0].group.commit_modifiers.modify_leaf = |leaf, sk| {
+ leaf.public_key = get_test_25519_key(1u8);
+ Some(sk.clone())
+ };
+
+ let commit_output = groups[0].group.commit(vec![]).await.unwrap();
+
+ let res = groups[7]
+ .process_message(commit_output.commit_message)
+ .await;
+
+ assert_matches!(res, Err(MlsError::DuplicateLeafData(_)));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_leaf_duplicate_signature_key() {
+ // RFC 8.3 "Verify that the following fields are unique among the members of the group: `signature_key`"
+
+ if TEST_CIPHER_SUITE != CipherSuite::CURVE25519_AES128
+ && TEST_CIPHER_SUITE != CipherSuite::CURVE25519_CHACHA
+ {
+ return;
+ }
+
+ let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, 10).await;
+
+ // Group 1 uses the fixed key
+ groups[1].group.commit_modifiers.modify_leaf = |leaf, _| {
+ let sk = hex!(
+ "3468b4c890255c983e3d5cbf5cb64c1ef7f6433a518f2f3151d6672f839a06ebcad4fc381fe61822af45135c82921a348e6f46643d66ddefc70483565433714b"
+ )
+ .into();
+
+ leaf.signing_identity.signature_key =
+ hex!("cad4fc381fe61822af45135c82921a348e6f46643d66ddefc70483565433714b").into();
+
+ Some(sk)
+ };
+
+ let commit_output = groups
+ .get_mut(1)
+ .unwrap()
+ .group
+ .commit(vec![])
+ .await
+ .unwrap();
+
+ process_commit(&mut groups, commit_output.commit_message, 1).await;
+
+ // Group 0 tries to use the fixed key too
+ groups[0].group.commit_modifiers.modify_leaf = |leaf, _| {
+ let sk = hex!(
+ "3468b4c890255c983e3d5cbf5cb64c1ef7f6433a518f2f3151d6672f839a06ebcad4fc381fe61822af45135c82921a348e6f46643d66ddefc70483565433714b"
+ )
+ .into();
+
+ leaf.signing_identity.signature_key =
+ hex!("cad4fc381fe61822af45135c82921a348e6f46643d66ddefc70483565433714b").into();
+
+ Some(sk)
+ };
+
+ let commit_output = groups[0].group.commit(vec![]).await.unwrap();
+
+ let res = groups[7]
+ .process_message(commit_output.commit_message)
+ .await;
+
+ assert_matches!(res, Err(MlsError::DuplicateLeafData(_)));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_leaf_incorrect_signature() {
+ let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, 3).await;
+
+ groups[0].group.commit_modifiers.modify_leaf = |leaf, _| {
+ leaf.signature[0] ^= 1;
+ None
+ };
+
+ let commit_output = groups[0].group.commit(vec![]).await.unwrap();
+
+ let res = groups[2]
+ .process_message(commit_output.commit_message)
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidSignature));
+ }
+
+ #[cfg(not(target_arch = "wasm32"))]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_leaf_not_supporting_used_context_extension() {
+ const EXT_TYPE: ExtensionType = ExtensionType::new(999);
+
+ // The new leaf of the committer doesn't support an extension set in group context
+ let extension = Extension::new(EXT_TYPE, vec![]);
+
+ let mut groups =
+ get_test_groups_with_features(3, vec![extension].into(), Default::default()).await;
+
+ groups[0].commit_modifiers.modify_leaf = |leaf, sk| {
+ leaf.capabilities = get_test_capabilities();
+ Some(sk.clone())
+ };
+
+ let commit_output = groups[0].commit(vec![]).await.unwrap();
+
+ let res = groups[1]
+ .process_incoming_message(commit_output.commit_message)
+ .await;
+
+ assert_matches!(res, Err(MlsError::UnsupportedGroupExtension(EXT_TYPE)));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_leaf_not_supporting_required_extension() {
+ // The new leaf of the committer doesn't support an extension required by group context
+
+ let extension = RequiredCapabilitiesExt {
+ extensions: vec![999.into()],
+ proposals: vec![],
+ credentials: vec![],
+ };
+
+ let extensions = vec![extension.into_extension().unwrap()];
+ let mut groups =
+ get_test_groups_with_features(3, extensions.into(), Default::default()).await;
+
+ groups[0].commit_modifiers.modify_leaf = |leaf, sk| {
+ leaf.capabilities = Capabilities::default();
+ Some(sk.clone())
+ };
+
+ let commit_output = groups[0].commit(vec![]).await.unwrap();
+
+ let res = groups[2]
+ .process_incoming_message(commit_output.commit_message)
+ .await;
+
+ assert!(res.is_err());
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_leaf_has_unsupported_credential() {
+ // The new leaf of the committer has a credential unsupported by another leaf
+ let mut groups =
+ get_test_groups_with_features(3, Default::default(), Default::default()).await;
+
+ for group in groups.iter_mut() {
+ group.config.0.identity_provider.allow_any_custom = true;
+ }
+
+ groups[0].commit_modifiers.modify_leaf = |leaf, sk| {
+ leaf.signing_identity.credential = Credential::Custom(CustomCredential::new(
+ CredentialType::new(43),
+ leaf.signing_identity
+ .credential
+ .as_basic()
+ .unwrap()
+ .identifier
+ .to_vec(),
+ ));
+
+ Some(sk.clone())
+ };
+
+ let commit_output = groups[0].commit(vec![]).await.unwrap();
+
+ let res = groups[2]
+ .process_incoming_message(commit_output.commit_message)
+ .await;
+
+ assert_matches!(res, Err(MlsError::CredentialTypeOfNewLeafIsUnsupported));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_leaf_not_supporting_credential_used_in_another_leaf() {
+ // The new leaf of the committer doesn't support another leaf's credential
+
+ let mut groups =
+ get_test_groups_with_features(3, Default::default(), Default::default()).await;
+
+ groups[0].commit_modifiers.modify_leaf = |leaf, sk| {
+ leaf.capabilities.credentials = vec![2.into()];
+ Some(sk.clone())
+ };
+
+ let commit_output = groups[0].commit(vec![]).await.unwrap();
+
+ let res = groups[2]
+ .process_incoming_message(commit_output.commit_message)
+ .await;
+
+ assert_matches!(res, Err(MlsError::InUseCredentialTypeUnsupportedByNewLeaf));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_leaf_not_supporting_required_credential() {
+ // The new leaf of the committer doesn't support a credential required by group context
+
+ let extension = RequiredCapabilitiesExt {
+ extensions: vec![],
+ proposals: vec![],
+ credentials: vec![1.into()],
+ };
+
+ let extensions = vec![extension.into_extension().unwrap()];
+ let mut groups =
+ get_test_groups_with_features(3, extensions.into(), Default::default()).await;
+
+ groups[0].commit_modifiers.modify_leaf = |leaf, sk| {
+ leaf.capabilities.credentials = vec![2.into()];
+ Some(sk.clone())
+ };
+
+ let commit_output = groups[0].commit(vec![]).await.unwrap();
+
+ let res = groups[2]
+ .process_incoming_message(commit_output.commit_message)
+ .await;
+
+ assert_matches!(res, Err(MlsError::RequiredCredentialNotFound(_)));
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg(not(target_arch = "wasm32"))]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn make_x509_external_senders_ext() -> ExternalSendersExt {
+ let (_, ext_sender_pk) = test_cipher_suite_provider(TEST_CIPHER_SUITE)
+ .signature_key_generate()
+ .await
+ .unwrap();
+
+ let ext_sender_id = SigningIdentity {
+ signature_key: ext_sender_pk,
+ credential: Credential::X509(CertificateChain::from(vec![random_bytes(32)])),
+ };
+
+ ExternalSendersExt::new(vec![ext_sender_id])
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg(not(target_arch = "wasm32"))]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_leaf_not_supporting_external_sender_credential_leads_to_rejected_commit() {
+ let ext_senders = make_x509_external_senders_ext()
+ .await
+ .into_extension()
+ .unwrap();
+
+ let mut alice = ClientBuilder::new()
+ .crypto_provider(TestCryptoProvider::new())
+ .identity_provider(
+ BasicWithCustomProvider::default().with_credential_type(CredentialType::X509),
+ )
+ .with_random_signing_identity("alice", TEST_CIPHER_SUITE)
+ .await
+ .build()
+ .create_group(core::iter::once(ext_senders).collect())
+ .await
+ .unwrap();
+
+ // New leaf supports only basic credentials (used by the group) but not X509 used by external sender
+ alice.commit_modifiers.modify_leaf = |leaf, sk| {
+ leaf.capabilities.credentials = vec![CredentialType::BASIC];
+ Some(sk.clone())
+ };
+
+ alice.commit(vec![]).await.unwrap();
+ let res = alice.apply_pending_commit().await;
+
+ assert_matches!(
+ res,
+ Err(MlsError::RequiredCredentialNotFound(CredentialType::X509))
+ );
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg(not(target_arch = "wasm32"))]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn node_not_supporting_external_sender_credential_cannot_join_group() {
+ let ext_senders = make_x509_external_senders_ext()
+ .await
+ .into_extension()
+ .unwrap();
+
+ let mut alice = ClientBuilder::new()
+ .crypto_provider(TestCryptoProvider::new())
+ .identity_provider(
+ BasicWithCustomProvider::default().with_credential_type(CredentialType::X509),
+ )
+ .with_random_signing_identity("alice", TEST_CIPHER_SUITE)
+ .await
+ .build()
+ .create_group(core::iter::once(ext_senders).collect())
+ .await
+ .unwrap();
+
+ let (_, bob_key_pkg) =
+ test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
+
+ let commit = alice
+ .commit_builder()
+ .add_member(bob_key_pkg)
+ .unwrap()
+ .build()
+ .await;
+
+ assert_matches!(
+ commit,
+ Err(MlsError::RequiredCredentialNotFound(CredentialType::X509))
+ );
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg(not(target_arch = "wasm32"))]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn external_senders_extension_is_rejected_if_member_does_not_support_credential_type() {
+ let mut alice = ClientBuilder::new()
+ .crypto_provider(TestCryptoProvider::new())
+ .identity_provider(
+ BasicWithCustomProvider::default().with_credential_type(CredentialType::X509),
+ )
+ .with_random_signing_identity("alice", TEST_CIPHER_SUITE)
+ .await
+ .build()
+ .create_group(Default::default())
+ .await
+ .unwrap();
+
+ let (_, bob_key_pkg) =
+ test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
+
+ alice
+ .commit_builder()
+ .add_member(bob_key_pkg)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ alice.apply_pending_commit().await.unwrap();
+ assert_eq!(alice.roster().members_iter().count(), 2);
+
+ let ext_senders = make_x509_external_senders_ext()
+ .await
+ .into_extension()
+ .unwrap();
+
+ let res = alice
+ .commit_builder()
+ .set_group_context_ext(core::iter::once(ext_senders).collect())
+ .unwrap()
+ .build()
+ .await;
+
+ assert_matches!(
+ res,
+ Err(MlsError::RequiredCredentialNotFound(CredentialType::X509))
+ );
+ }
+
+ /*
+ * Edge case paths
+ */
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn committing_degenerate_path_succeeds() {
+ let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, 10).await;
+
+ groups[0].group.commit_modifiers.modify_tree = |tree: &mut TreeKemPublic| {
+ tree.update_node(get_test_25519_key(1u8), 1).unwrap();
+ tree.update_node(get_test_25519_key(1u8), 3).unwrap();
+ };
+
+ groups[0].group.commit_modifiers.modify_leaf = |leaf, sk| {
+ leaf.public_key = get_test_25519_key(1u8);
+ Some(sk.clone())
+ };
+
+ let commit_output = groups[0].group.commit(vec![]).await.unwrap();
+
+ let res = groups[7]
+ .process_message(commit_output.commit_message)
+ .await;
+
+ assert!(res.is_ok());
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn inserting_key_in_filtered_node_fails() {
+ let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, 10).await;
+
+ let commit_output = groups[0]
+ .group
+ .commit_builder()
+ .remove_member(1)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ groups[0].process_pending_commit().await.unwrap();
+
+ for group in groups.iter_mut().skip(2) {
+ group
+ .process_message(commit_output.commit_message.clone())
+ .await
+ .unwrap();
+ }
+
+ groups[0].group.commit_modifiers.modify_tree = |tree: &mut TreeKemPublic| {
+ tree.update_node(get_test_25519_key(1u8), 1).unwrap();
+ };
+
+ groups[0].group.commit_modifiers.modify_path = |path: Vec<UpdatePathNode>| {
+ let mut path = path;
+ let mut node = path[0].clone();
+ node.public_key = get_test_25519_key(1u8);
+ path.insert(0, node);
+ path
+ };
+
+ let commit_output = groups[0].group.commit(vec![]).await.unwrap();
+
+ let res = groups[7]
+ .process_message(commit_output.commit_message)
+ .await;
+
+ // We should get a path validation error, since the path is too long
+ assert_matches!(res, Err(MlsError::WrongPathLen));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn commit_with_too_short_path_fails() {
+ let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, 10).await;
+
+ let commit_output = groups[0]
+ .group
+ .commit_builder()
+ .remove_member(1)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ groups[0].process_pending_commit().await.unwrap();
+
+ for group in groups.iter_mut().skip(2) {
+ group
+ .process_message(commit_output.commit_message.clone())
+ .await
+ .unwrap();
+ }
+
+ groups[0].group.commit_modifiers.modify_path = |path: Vec<UpdatePathNode>| {
+ let mut path = path;
+ path.pop();
+ path
+ };
+
+ let commit_output = groups[0].group.commit(vec![]).await.unwrap();
+
+ let res = groups[7]
+ .process_message(commit_output.commit_message)
+ .await;
+
+ assert!(res.is_err());
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn update_proposal_can_change_credential() {
+ let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, 3).await;
+ let (identity, secret_key) = get_test_signing_identity(TEST_CIPHER_SUITE, b"member").await;
+
+ let update = groups[0]
+ .group
+ .propose_update_with_identity(secret_key, identity.clone(), vec![])
+ .await
+ .unwrap();
+
+ groups[1].process_message(update).await.unwrap();
+ let commit_output = groups[1].group.commit(vec![]).await.unwrap();
+
+ // Check that the credential was updated by in the committer's state.
+ groups[1].process_pending_commit().await.unwrap();
+ let new_member = groups[1].group.roster().member_with_index(0).unwrap();
+
+ assert_eq!(
+ new_member.signing_identity.credential,
+ get_test_basic_credential(b"member".to_vec())
+ );
+
+ assert_eq!(
+ new_member.signing_identity.signature_key,
+ identity.signature_key
+ );
+
+ // Check that the credential was updated in the updater's state.
+ groups[0]
+ .process_message(commit_output.commit_message)
+ .await
+ .unwrap();
+ let new_member = groups[0].group.roster().member_with_index(0).unwrap();
+
+ assert_eq!(
+ new_member.signing_identity.credential,
+ get_test_basic_credential(b"member".to_vec())
+ );
+
+ assert_eq!(
+ new_member.signing_identity.signature_key,
+ identity.signature_key
+ );
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_commit_with_old_adds_fails() {
+ let mut groups = test_n_member_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, 2).await;
+
+ let key_package =
+ test_key_package_message(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "foobar").await;
+
+ let proposal = groups[0]
+ .group
+ .propose_add(key_package, vec![])
+ .await
+ .unwrap();
+
+ let commit = groups[0].group.commit(vec![]).await.unwrap().commit_message;
+
+ // 10 years from now
+ let future_time = MlsTime::now().seconds_since_epoch() + 10 * 365 * 24 * 3600;
+
+ let future_time =
+ MlsTime::from_duration_since_epoch(core::time::Duration::from_secs(future_time));
+
+ groups[1]
+ .group
+ .process_incoming_message(proposal)
+ .await
+ .unwrap();
+ let res = groups[1]
+ .group
+ .process_incoming_message_with_time(commit, future_time)
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidLifetime));
+ }
+
+ #[cfg(feature = "custom_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn custom_proposal_setup() -> (TestGroup, TestGroup) {
+ let mut alice = test_group_custom_config(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, |b| {
+ b.custom_proposal_type(TEST_CUSTOM_PROPOSAL_TYPE)
+ })
+ .await;
+
+ let (bob, _) = alice
+ .join_with_custom_config("bob", true, |c| {
+ c.0.settings
+ .custom_proposal_types
+ .push(TEST_CUSTOM_PROPOSAL_TYPE)
+ })
+ .await
+ .unwrap();
+
+ (alice, bob)
+ }
+
+ #[cfg(feature = "custom_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn custom_proposal_by_value() {
+ let (mut alice, mut bob) = custom_proposal_setup().await;
+
+ let custom_proposal = CustomProposal::new(TEST_CUSTOM_PROPOSAL_TYPE, vec![0, 1, 2]);
+
+ let commit = alice
+ .group
+ .commit_builder()
+ .custom_proposal(custom_proposal.clone())
+ .build()
+ .await
+ .unwrap()
+ .commit_message;
+
+ let res = bob.group.process_incoming_message(commit).await.unwrap();
+
+ #[cfg(feature = "state_update")]
+ assert_matches!(res, ReceivedMessage::Commit(CommitMessageDescription { state_update: StateUpdate { custom_proposals, .. }, .. })
+ if custom_proposals.len() == 1 && custom_proposals[0].proposal == custom_proposal);
+
+ #[cfg(not(feature = "state_update"))]
+ assert_matches!(res, ReceivedMessage::Commit(_));
+ }
+
+ #[cfg(feature = "custom_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn custom_proposal_by_reference() {
+ let (mut alice, mut bob) = custom_proposal_setup().await;
+
+ let custom_proposal = CustomProposal::new(TEST_CUSTOM_PROPOSAL_TYPE, vec![0, 1, 2]);
+
+ let proposal = alice
+ .group
+ .propose_custom(custom_proposal.clone(), vec![])
+ .await
+ .unwrap();
+
+ let recv_prop = bob.group.process_incoming_message(proposal).await.unwrap();
+
+ assert_matches!(recv_prop, ReceivedMessage::Proposal(ProposalMessageDescription { proposal: Proposal::Custom(c), ..})
+ if c == custom_proposal);
+
+ let commit = bob.group.commit(vec![]).await.unwrap().commit_message;
+ let res = alice.group.process_incoming_message(commit).await.unwrap();
+
+ #[cfg(feature = "state_update")]
+ assert_matches!(res, ReceivedMessage::Commit(CommitMessageDescription { state_update: StateUpdate { custom_proposals, .. }, .. })
+ if custom_proposals.len() == 1 && custom_proposals[0].proposal == custom_proposal);
+
+ #[cfg(not(feature = "state_update"))]
+ assert_matches!(res, ReceivedMessage::Commit(_));
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn can_join_with_psk() {
+ let mut alice = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE)
+ .await
+ .group;
+
+ let (bob, key_pkg) =
+ test_client_with_key_pkg(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await;
+
+ let psk_id = ExternalPskId::new(vec![0]);
+ let psk = PreSharedKey::from(vec![0]);
+
+ alice
+ .config
+ .secret_store()
+ .insert(psk_id.clone(), psk.clone());
+
+ bob.config.secret_store().insert(psk_id.clone(), psk);
+
+ let commit = alice
+ .commit_builder()
+ .add_member(key_pkg)
+ .unwrap()
+ .add_external_psk(psk_id)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ bob.join_group(None, &commit.welcome_messages[0])
+ .await
+ .unwrap();
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn invalid_update_does_not_prevent_other_updates() {
+ const EXTENSION_TYPE: ExtensionType = ExtensionType::new(33);
+
+ let group_extensions = ExtensionList::from(vec![RequiredCapabilitiesExt {
+ extensions: vec![EXTENSION_TYPE],
+ ..Default::default()
+ }
+ .into_extension()
+ .unwrap()]);
+
+ // Alice creates a group requiring support for an extension
+ let mut alice = TestClientBuilder::new_for_test()
+ .with_random_signing_identity("alice", TEST_CIPHER_SUITE)
+ .await
+ .extension_type(EXTENSION_TYPE)
+ .build()
+ .create_group(group_extensions.clone())
+ .await
+ .unwrap();
+
+ let (bob_signing_identity, bob_secret_key) =
+ get_test_signing_identity(TEST_CIPHER_SUITE, b"bob").await;
+
+ let bob_client = TestClientBuilder::new_for_test()
+ .signing_identity(
+ bob_signing_identity.clone(),
+ bob_secret_key.clone(),
+ TEST_CIPHER_SUITE,
+ )
+ .extension_type(EXTENSION_TYPE)
+ .build();
+
+ let carol_client = TestClientBuilder::new_for_test()
+ .with_random_signing_identity("carol", TEST_CIPHER_SUITE)
+ .await
+ .extension_type(EXTENSION_TYPE)
+ .build();
+
+ let dave_client = TestClientBuilder::new_for_test()
+ .with_random_signing_identity("dave", TEST_CIPHER_SUITE)
+ .await
+ .extension_type(EXTENSION_TYPE)
+ .build();
+
+ // Alice adds Bob, Carol and Dave to the group. They all support the mandatory extension.
+ let commit = alice
+ .commit_builder()
+ .add_member(bob_client.generate_key_package_message().await.unwrap())
+ .unwrap()
+ .add_member(carol_client.generate_key_package_message().await.unwrap())
+ .unwrap()
+ .add_member(dave_client.generate_key_package_message().await.unwrap())
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ alice.apply_pending_commit().await.unwrap();
+
+ let mut bob = bob_client
+ .join_group(None, &commit.welcome_messages[0])
+ .await
+ .unwrap()
+ .0;
+
+ bob.write_to_storage().await.unwrap();
+
+ // Bob reloads his group data, but with parameters that will cause his generated leaves to
+ // not support the mandatory extension.
+ let mut bob = TestClientBuilder::new_for_test()
+ .signing_identity(bob_signing_identity, bob_secret_key, TEST_CIPHER_SUITE)
+ .key_package_repo(bob.config.key_package_repo())
+ .group_state_storage(bob.config.group_state_storage())
+ .build()
+ .load_group(alice.group_id())
+ .await
+ .unwrap();
+
+ let mut carol = carol_client
+ .join_group(None, &commit.welcome_messages[0])
+ .await
+ .unwrap()
+ .0;
+
+ let mut dave = dave_client
+ .join_group(None, &commit.welcome_messages[0])
+ .await
+ .unwrap()
+ .0;
+
+ // Bob's updated leaf does not support the mandatory extension.
+ let bob_update = bob.propose_update(Vec::new()).await.unwrap();
+ let carol_update = carol.propose_update(Vec::new()).await.unwrap();
+ let dave_update = dave.propose_update(Vec::new()).await.unwrap();
+
+ // Alice receives the update proposals to be committed.
+ alice.process_incoming_message(bob_update).await.unwrap();
+ alice.process_incoming_message(carol_update).await.unwrap();
+ alice.process_incoming_message(dave_update).await.unwrap();
+
+ // Alice commits the update proposals.
+ alice.commit(Vec::new()).await.unwrap();
+ let commit_desc = alice.apply_pending_commit().await.unwrap();
+
+ let find_update_for = |id: &str| {
+ commit_desc
+ .state_update
+ .roster_update
+ .updated()
+ .iter()
+ .filter_map(|u| u.prior.signing_identity.credential.as_basic())
+ .any(|c| c.identifier == id.as_bytes())
+ };
+
+ // Check that all updates preserve identities.
+ let identities_are_preserved = commit_desc
+ .state_update
+ .roster_update
+ .updated()
+ .iter()
+ .filter_map(|u| {
+ let before = &u.prior.signing_identity.credential.as_basic()?.identifier;
+ let after = &u.new.signing_identity.credential.as_basic()?.identifier;
+ Some((before, after))
+ })
+ .all(|(before, after)| before == after);
+
+ assert!(identities_are_preserved);
+
+ // Carol's and Dave's updates should be part of the commit.
+ assert!(find_update_for("carol"));
+ assert!(find_update_for("dave"));
+
+ // Bob's update should be rejected.
+ assert!(!find_update_for("bob"));
+
+ // Check that all members are still in the group.
+ let all_members_are_in = alice
+ .roster()
+ .members_iter()
+ .zip(["alice", "bob", "carol", "dave"])
+ .all(|(member, id)| {
+ member
+ .signing_identity
+ .credential
+ .as_basic()
+ .unwrap()
+ .identifier
+ == id.as_bytes()
+ });
+
+ assert!(all_members_are_in);
+ }
+
+ #[cfg(feature = "custom_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn custom_proposal_may_enforce_path() {
+ test_custom_proposal_mls_rules(true).await;
+ }
+
+ #[cfg(feature = "custom_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn custom_proposal_need_not_enforce_path() {
+ test_custom_proposal_mls_rules(false).await;
+ }
+
+ #[cfg(feature = "custom_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_custom_proposal_mls_rules(path_required_for_custom: bool) {
+ let mls_rules = CustomMlsRules {
+ path_required_for_custom,
+ external_joiner_can_send_custom: true,
+ };
+
+ let mut alice = client_with_custom_rules(b"alice", mls_rules.clone())
+ .await
+ .create_group(Default::default())
+ .await
+ .unwrap();
+
+ let alice_pub_before = alice.current_user_leaf_node().unwrap().public_key.clone();
+
+ let kp = client_with_custom_rules(b"bob", mls_rules)
+ .await
+ .generate_key_package_message()
+ .await
+ .unwrap();
+
+ alice
+ .commit_builder()
+ .custom_proposal(CustomProposal::new(TEST_CUSTOM_PROPOSAL_TYPE, vec![]))
+ .add_member(kp)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ alice.apply_pending_commit().await.unwrap();
+
+ let alice_pub_after = &alice.current_user_leaf_node().unwrap().public_key;
+
+ if path_required_for_custom {
+ assert_ne!(alice_pub_after, &alice_pub_before);
+ } else {
+ assert_eq!(alice_pub_after, &alice_pub_before);
+ }
+ }
+
+ #[cfg(feature = "custom_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn custom_proposal_by_value_in_external_join_may_be_allowed() {
+ test_custom_proposal_by_value_in_external_join(true).await
+ }
+
+ #[cfg(feature = "custom_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn custom_proposal_by_value_in_external_join_may_not_be_allowed() {
+ test_custom_proposal_by_value_in_external_join(false).await
+ }
+
+ #[cfg(feature = "custom_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_custom_proposal_by_value_in_external_join(external_joiner_can_send_custom: bool) {
+ let mls_rules = CustomMlsRules {
+ path_required_for_custom: true,
+ external_joiner_can_send_custom,
+ };
+
+ let mut alice = client_with_custom_rules(b"alice", mls_rules.clone())
+ .await
+ .create_group(Default::default())
+ .await
+ .unwrap();
+
+ let group_info = alice
+ .group_info_message_allowing_ext_commit(true)
+ .await
+ .unwrap();
+
+ let commit = client_with_custom_rules(b"bob", mls_rules)
+ .await
+ .external_commit_builder()
+ .unwrap()
+ .with_custom_proposal(CustomProposal::new(TEST_CUSTOM_PROPOSAL_TYPE, vec![]))
+ .build(group_info)
+ .await;
+
+ if external_joiner_can_send_custom {
+ let commit = commit.unwrap().1;
+ alice.process_incoming_message(commit).await.unwrap();
+ } else {
+ assert_matches!(commit.map(|_| ()), Err(MlsError::MlsRulesError(_)));
+ }
+ }
+
+ #[cfg(feature = "custom_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn custom_proposal_by_ref_in_external_join() {
+ let mls_rules = CustomMlsRules {
+ path_required_for_custom: true,
+ external_joiner_can_send_custom: true,
+ };
+
+ let mut alice = client_with_custom_rules(b"alice", mls_rules.clone())
+ .await
+ .create_group(Default::default())
+ .await
+ .unwrap();
+
+ let by_ref = CustomProposal::new(TEST_CUSTOM_PROPOSAL_TYPE, vec![]);
+ let by_ref = alice.propose_custom(by_ref, vec![]).await.unwrap();
+
+ let group_info = alice
+ .group_info_message_allowing_ext_commit(true)
+ .await
+ .unwrap();
+
+ let (_, commit) = client_with_custom_rules(b"bob", mls_rules)
+ .await
+ .external_commit_builder()
+ .unwrap()
+ .with_received_custom_proposal(by_ref)
+ .build(group_info)
+ .await
+ .unwrap();
+
+ alice.process_incoming_message(commit).await.unwrap();
+ }
+
+ #[cfg(feature = "custom_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn client_with_custom_rules(
+ name: &[u8],
+ mls_rules: CustomMlsRules,
+ ) -> Client<impl MlsConfig> {
+ let (signing_identity, signer) = get_test_signing_identity(TEST_CIPHER_SUITE, name).await;
+
+ ClientBuilder::new()
+ .crypto_provider(TestCryptoProvider::new())
+ .identity_provider(BasicWithCustomProvider::new(BasicIdentityProvider::new()))
+ .signing_identity(signing_identity, signer, TEST_CIPHER_SUITE)
+ .custom_proposal_type(TEST_CUSTOM_PROPOSAL_TYPE)
+ .mls_rules(mls_rules)
+ .build()
+ }
+
+ #[derive(Debug, Clone)]
+ struct CustomMlsRules {
+ path_required_for_custom: bool,
+ external_joiner_can_send_custom: bool,
+ }
+
+ #[cfg(feature = "custom_proposal")]
+ impl ProposalBundle {
+ fn has_test_custom_proposal(&self) -> bool {
+ self.custom_proposal_types()
+ .any(|t| t == TEST_CUSTOM_PROPOSAL_TYPE)
+ }
+ }
+
+ #[cfg(feature = "custom_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+ impl crate::MlsRules for CustomMlsRules {
+ type Error = MlsError;
+
+ fn commit_options(
+ &self,
+ _: &Roster,
+ _: &ExtensionList,
+ proposals: &ProposalBundle,
+ ) -> Result<CommitOptions, MlsError> {
+ Ok(CommitOptions::default().with_path_required(
+ !proposals.has_test_custom_proposal() || self.path_required_for_custom,
+ ))
+ }
+
+ fn encryption_options(
+ &self,
+ _: &Roster,
+ _: &ExtensionList,
+ ) -> Result<crate::mls_rules::EncryptionOptions, MlsError> {
+ Ok(Default::default())
+ }
+
+ async fn filter_proposals(
+ &self,
+ _: CommitDirection,
+ sender: CommitSource,
+ _: &Roster,
+ _: &ExtensionList,
+ proposals: ProposalBundle,
+ ) -> Result<ProposalBundle, MlsError> {
+ let is_external = matches!(sender, CommitSource::NewMember(_));
+ let has_custom = proposals.has_test_custom_proposal();
+ let allowed = !has_custom || !is_external || self.external_joiner_can_send_custom;
+
+ allowed.then_some(proposals).ok_or(MlsError::InvalidSender)
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn group_can_receive_commit_from_self() {
+ let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE)
+ .await
+ .group;
+
+ let commit = group.commit(vec![]).await.unwrap();
+
+ let update = group
+ .process_incoming_message(commit.commit_message)
+ .await
+ .unwrap();
+
+ let ReceivedMessage::Commit(update) = update else {
+ panic!("expected commit message")
+ };
+
+ assert_eq!(update.committer, *group.private_tree.self_index);
+ }
+}
diff --git a/src/group/padding.rs b/src/group/padding.rs
new file mode 100644
index 0000000..6320ccf
--- /dev/null
+++ b/src/group/padding.rs
@@ -0,0 +1,109 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+/// Padding used when sending an encrypted group message.
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type)]
+#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
+#[repr(u8)]
+pub enum PaddingMode {
+ /// Step function based on the size of the message being sent.
+ /// The amount of padding used will increase with the size of the original
+ /// message.
+ #[default]
+ StepFunction,
+ /// No padding.
+ None,
+}
+
+impl PaddingMode {
+ pub(super) fn padded_size(&self, content_size: usize) -> usize {
+ match self {
+ PaddingMode::StepFunction => {
+ // The padding hides all but 2 most significant bits of `length`. The hidden bits are replaced
+ // by zeros and then the next number is taken to make sure the message fits.
+ let blind = 1
+ << ((content_size + 1)
+ .next_power_of_two()
+ .max(256)
+ .trailing_zeros()
+ - 3);
+
+ (content_size | (blind - 1)) + 1
+ }
+ PaddingMode::None => content_size,
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::PaddingMode;
+
+ use alloc::vec;
+ use alloc::vec::Vec;
+ #[cfg(target_arch = "wasm32")]
+ use wasm_bindgen_test::wasm_bindgen_test as test;
+
+ #[derive(serde::Deserialize, serde::Serialize)]
+ struct TestCase {
+ input: usize,
+ output: usize,
+ }
+
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn generate_message_padding_test_vector() -> Vec<TestCase> {
+ let mut test_cases = vec![];
+ for x in 1..1024 {
+ test_cases.push(TestCase {
+ input: x,
+ output: PaddingMode::StepFunction.padded_size(x),
+ });
+ }
+ test_cases
+ }
+
+ fn load_test_cases() -> Vec<TestCase> {
+ load_test_case_json!(
+ message_padding_test_vector,
+ generate_message_padding_test_vector()
+ )
+ }
+
+ #[test]
+ fn test_no_padding() {
+ for i in [0, 100, 1000, 10000] {
+ assert_eq!(PaddingMode::None.padded_size(i), i)
+ }
+ }
+
+ #[test]
+ fn test_padding_length() {
+ assert_eq!(PaddingMode::StepFunction.padded_size(0), 32);
+
+ // Short
+ assert_eq!(PaddingMode::StepFunction.padded_size(63), 64);
+ assert_eq!(PaddingMode::StepFunction.padded_size(64), 96);
+ assert_eq!(PaddingMode::StepFunction.padded_size(65), 96);
+
+ // Almost long and almost short
+ assert_eq!(PaddingMode::StepFunction.padded_size(127), 128);
+ assert_eq!(PaddingMode::StepFunction.padded_size(128), 160);
+ assert_eq!(PaddingMode::StepFunction.padded_size(129), 160);
+
+ // One length from each of the 4 buckets between 256 and 512
+ assert_eq!(PaddingMode::StepFunction.padded_size(260), 320);
+ assert_eq!(PaddingMode::StepFunction.padded_size(330), 384);
+ assert_eq!(PaddingMode::StepFunction.padded_size(390), 448);
+ assert_eq!(PaddingMode::StepFunction.padded_size(490), 512);
+
+ // All test cases
+ let test_cases: Vec<TestCase> = load_test_cases();
+ for test_case in test_cases {
+ assert_eq!(
+ test_case.output,
+ PaddingMode::StepFunction.padded_size(test_case.input)
+ );
+ }
+ }
+}
diff --git a/src/group/proposal.rs b/src/group/proposal.rs
new file mode 100644
index 0000000..a31be29
--- /dev/null
+++ b/src/group/proposal.rs
@@ -0,0 +1,578 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::{boxed::Box, vec::Vec};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::tree_kem::leaf_node::LeafNode;
+
+use crate::{
+ client::MlsError, tree_kem::node::LeafIndex, CipherSuite, KeyPackage, MlsMessage,
+ ProtocolVersion,
+};
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::{group::Capabilities, identity::SigningIdentity};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::group::proposal_ref::ProposalRef;
+
+pub use mls_rs_core::extension::ExtensionList;
+pub use mls_rs_core::group::ProposalType;
+
+#[cfg(feature = "psk")]
+use crate::psk::{ExternalPskId, JustPreSharedKeyID, PreSharedKeyID};
+
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// A proposal that adds a member to a [`Group`](crate::group::Group).
+pub struct AddProposal {
+ pub(crate) key_package: KeyPackage,
+}
+
+impl AddProposal {
+ /// The [`SigningIdentity`] of the [`Member`](mls_rs_core::group::Member)
+ /// that will be added by this proposal.
+ pub fn signing_identity(&self) -> &SigningIdentity {
+ self.key_package.signing_identity()
+ }
+
+ /// Client [`Capabilities`] of the [`Member`](mls_rs_core::group::Member)
+ /// that will be added by this proposal.
+ pub fn capabilities(&self) -> Capabilities {
+ self.key_package.leaf_node.ungreased_capabilities()
+ }
+
+ /// Key package extensions that are assoiciated with the
+ /// [`Member`](mls_rs_core::group::Member) that will be added by this proposal.
+ pub fn key_package_extensions(&self) -> ExtensionList {
+ self.key_package.ungreased_extensions()
+ }
+
+ /// Leaf node extensions that will be entered into the group state for the
+ /// [`Member`](mls_rs_core::group::Member) that will be added.
+ pub fn leaf_node_extensions(&self) -> ExtensionList {
+ self.key_package.leaf_node.ungreased_extensions()
+ }
+}
+
+impl From<KeyPackage> for AddProposal {
+ fn from(key_package: KeyPackage) -> Self {
+ Self { key_package }
+ }
+}
+
+impl TryFrom<MlsMessage> for AddProposal {
+ type Error = MlsError;
+
+ fn try_from(value: MlsMessage) -> Result<Self, Self::Error> {
+ value
+ .into_key_package()
+ .ok_or(MlsError::UnexpectedMessageType)
+ .map(Into::into)
+ }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// A proposal that will update an existing [`Member`](mls_rs_core::group::Member) of a
+/// [`Group`](crate::group::Group).
+pub struct UpdateProposal {
+ pub(crate) leaf_node: LeafNode,
+}
+
+#[cfg(feature = "by_ref_proposal")]
+impl UpdateProposal {
+ /// The new [`SigningIdentity`] of the [`Member`](mls_rs_core::group::Member)
+ /// that is being updated by this proposal.
+ pub fn signing_identity(&self) -> &SigningIdentity {
+ &self.leaf_node.signing_identity
+ }
+
+ /// New Client [`Capabilities`] of the [`Member`](mls_rs_core::group::Member)
+ /// that will be updated by this proposal.
+ pub fn capabilities(&self) -> Capabilities {
+ self.leaf_node.ungreased_capabilities()
+ }
+
+ /// New Leaf node extensions that will be entered into the group state for the
+ /// [`Member`](mls_rs_core::group::Member) that is being updated by this proposal.
+ pub fn leaf_node_extensions(&self) -> ExtensionList {
+ self.leaf_node.ungreased_extensions()
+ }
+}
+
+#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// A proposal to remove an existing [`Member`](mls_rs_core::group::Member) of a
+/// [`Group`](crate::group::Group).
+pub struct RemoveProposal {
+ pub(crate) to_remove: LeafIndex,
+}
+
+impl RemoveProposal {
+ /// The index of the [`Member`](mls_rs_core::group::Member) that will be removed by
+ /// this proposal.
+ pub fn to_remove(&self) -> u32 {
+ *self.to_remove
+ }
+}
+
+impl From<u32> for RemoveProposal {
+ fn from(value: u32) -> Self {
+ RemoveProposal {
+ to_remove: LeafIndex(value),
+ }
+ }
+}
+
+#[cfg(feature = "psk")]
+#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// A proposal to add a pre-shared key to a group.
+pub struct PreSharedKeyProposal {
+ pub(crate) psk: PreSharedKeyID,
+}
+
+#[cfg(feature = "psk")]
+impl PreSharedKeyProposal {
+ /// The external pre-shared key id of this proposal.
+ ///
+ /// MLS requires the pre-shared key type for PreSharedKeyProposal to be of
+ /// type `External`.
+ ///
+ /// Returns `None` in the condition that the underlying psk is not external.
+ pub fn external_psk_id(&self) -> Option<&ExternalPskId> {
+ match self.psk.key_id {
+ JustPreSharedKeyID::External(ref ext) => Some(ext),
+ JustPreSharedKeyID::Resumption(_) => None,
+ }
+ }
+}
+
+#[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// A proposal to reinitialize a group using new parameters.
+pub struct ReInitProposal {
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ pub(crate) group_id: Vec<u8>,
+ pub(crate) version: ProtocolVersion,
+ pub(crate) cipher_suite: CipherSuite,
+ pub(crate) extensions: ExtensionList,
+}
+
+impl Debug for ReInitProposal {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("ReInitProposal")
+ .field(
+ "group_id",
+ &mls_rs_core::debug::pretty_group_id(&self.group_id),
+ )
+ .field("version", &self.version)
+ .field("cipher_suite", &self.cipher_suite)
+ .field("extensions", &self.extensions)
+ .finish()
+ }
+}
+
+impl ReInitProposal {
+ /// The unique id of the new group post reinitialization.
+ pub fn group_id(&self) -> &[u8] {
+ &self.group_id
+ }
+
+ /// The new protocol version to use post reinitialization.
+ pub fn new_version(&self) -> ProtocolVersion {
+ self.version
+ }
+
+ /// The new ciphersuite to use post reinitialization.
+ pub fn new_cipher_suite(&self) -> CipherSuite {
+ self.cipher_suite
+ }
+
+ /// Group context extensions to set in the new group post reinitialization.
+ pub fn new_group_context_extensions(&self) -> &ExtensionList {
+ &self.extensions
+ }
+}
+
+#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// A proposal used for external commits.
+pub struct ExternalInit {
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ pub(crate) kem_output: Vec<u8>,
+}
+
+impl Debug for ExternalInit {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("ExternalInit")
+ .field(
+ "kem_output",
+ &mls_rs_core::debug::pretty_bytes(&self.kem_output),
+ )
+ .finish()
+ }
+}
+
+#[cfg(feature = "custom_proposal")]
+#[derive(Clone, PartialEq, Eq)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// A user defined custom proposal.
+///
+/// User defined proposals are passed through the protocol as an opaque value.
+pub struct CustomProposal {
+ proposal_type: ProposalType,
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ data: Vec<u8>,
+}
+
+#[cfg(feature = "custom_proposal")]
+impl Debug for CustomProposal {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("CustomProposal")
+ .field("proposal_type", &self.proposal_type)
+ .field("data", &mls_rs_core::debug::pretty_bytes(&self.data))
+ .finish()
+ }
+}
+
+#[cfg(feature = "custom_proposal")]
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
+impl CustomProposal {
+ /// Create a custom proposal.
+ ///
+ /// # Warning
+ ///
+ /// Avoid using the [`ProposalType`] values that have constants already
+ /// defined by this crate. Using existing constants in a custom proposal
+ /// has unspecified behavior.
+ pub fn new(proposal_type: ProposalType, data: Vec<u8>) -> Self {
+ Self {
+ proposal_type,
+ data,
+ }
+ }
+
+ /// The proposal type used for this custom proposal.
+ pub fn proposal_type(&self) -> ProposalType {
+ self.proposal_type
+ }
+
+ /// The opaque data communicated by this custom proposal.
+ pub fn data(&self) -> &[u8] {
+ &self.data
+ }
+}
+
+/// Trait to simplify creating custom proposals that are serialized with MLS
+/// encoding.
+#[cfg(feature = "custom_proposal")]
+pub trait MlsCustomProposal: MlsSize + MlsEncode + MlsDecode + Sized {
+ fn proposal_type() -> ProposalType;
+
+ fn to_custom_proposal(&self) -> Result<CustomProposal, mls_rs_codec::Error> {
+ Ok(CustomProposal::new(
+ Self::proposal_type(),
+ self.mls_encode_to_vec()?,
+ ))
+ }
+
+ fn from_custom_proposal(proposal: &CustomProposal) -> Result<Self, mls_rs_codec::Error> {
+ if proposal.proposal_type() != Self::proposal_type() {
+ // #[cfg(feature = "std")]
+ // return Err(mls_rs_codec::Error::Custom(
+ // "invalid proposal type".to_string(),
+ // ));
+
+ //#[cfg(not(feature = "std"))]
+ return Err(mls_rs_codec::Error::Custom(4));
+ }
+
+ Self::mls_decode(&mut proposal.data())
+ }
+}
+
+#[allow(clippy::large_enum_variant)]
+#[derive(Clone, Debug, PartialEq)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[repr(u16)]
+#[non_exhaustive]
+/// An enum that represents all possible types of proposals.
+pub enum Proposal {
+ Add(alloc::boxed::Box<AddProposal>),
+ #[cfg(feature = "by_ref_proposal")]
+ Update(UpdateProposal),
+ Remove(RemoveProposal),
+ #[cfg(feature = "psk")]
+ Psk(PreSharedKeyProposal),
+ ReInit(ReInitProposal),
+ ExternalInit(ExternalInit),
+ GroupContextExtensions(ExtensionList),
+ #[cfg(feature = "custom_proposal")]
+ Custom(CustomProposal),
+}
+
+impl MlsSize for Proposal {
+ fn mls_encoded_len(&self) -> usize {
+ let inner_len = match self {
+ Proposal::Add(p) => p.mls_encoded_len(),
+ #[cfg(feature = "by_ref_proposal")]
+ Proposal::Update(p) => p.mls_encoded_len(),
+ Proposal::Remove(p) => p.mls_encoded_len(),
+ #[cfg(feature = "psk")]
+ Proposal::Psk(p) => p.mls_encoded_len(),
+ Proposal::ReInit(p) => p.mls_encoded_len(),
+ Proposal::ExternalInit(p) => p.mls_encoded_len(),
+ Proposal::GroupContextExtensions(p) => p.mls_encoded_len(),
+ #[cfg(feature = "custom_proposal")]
+ Proposal::Custom(p) => mls_rs_codec::byte_vec::mls_encoded_len(&p.data),
+ };
+
+ self.proposal_type().mls_encoded_len() + inner_len
+ }
+}
+
+impl MlsEncode for Proposal {
+ fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
+ self.proposal_type().mls_encode(writer)?;
+
+ match self {
+ Proposal::Add(p) => p.mls_encode(writer),
+ #[cfg(feature = "by_ref_proposal")]
+ Proposal::Update(p) => p.mls_encode(writer),
+ Proposal::Remove(p) => p.mls_encode(writer),
+ #[cfg(feature = "psk")]
+ Proposal::Psk(p) => p.mls_encode(writer),
+ Proposal::ReInit(p) => p.mls_encode(writer),
+ Proposal::ExternalInit(p) => p.mls_encode(writer),
+ Proposal::GroupContextExtensions(p) => p.mls_encode(writer),
+ #[cfg(feature = "custom_proposal")]
+ Proposal::Custom(p) => {
+ if p.proposal_type.raw_value() <= 7 {
+ // #[cfg(feature = "std")]
+ // return Err(mls_rs_codec::Error::Custom(
+ // "custom proposal types can not be set to defined values of 0-7".to_string(),
+ // ));
+
+ // #[cfg(not(feature = "std"))]
+ return Err(mls_rs_codec::Error::Custom(2));
+ }
+ mls_rs_codec::byte_vec::mls_encode(&p.data, writer)
+ }
+ }
+ }
+}
+
+impl MlsDecode for Proposal {
+ fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
+ let proposal_type = ProposalType::mls_decode(reader)?;
+
+ Ok(match proposal_type {
+ ProposalType::ADD => {
+ Proposal::Add(alloc::boxed::Box::new(AddProposal::mls_decode(reader)?))
+ }
+ #[cfg(feature = "by_ref_proposal")]
+ ProposalType::UPDATE => Proposal::Update(UpdateProposal::mls_decode(reader)?),
+ ProposalType::REMOVE => Proposal::Remove(RemoveProposal::mls_decode(reader)?),
+ #[cfg(feature = "psk")]
+ ProposalType::PSK => Proposal::Psk(PreSharedKeyProposal::mls_decode(reader)?),
+ ProposalType::RE_INIT => Proposal::ReInit(ReInitProposal::mls_decode(reader)?),
+ ProposalType::EXTERNAL_INIT => {
+ Proposal::ExternalInit(ExternalInit::mls_decode(reader)?)
+ }
+ ProposalType::GROUP_CONTEXT_EXTENSIONS => {
+ Proposal::GroupContextExtensions(ExtensionList::mls_decode(reader)?)
+ }
+ #[cfg(feature = "custom_proposal")]
+ custom => Proposal::Custom(CustomProposal {
+ proposal_type: custom,
+ data: mls_rs_codec::byte_vec::mls_decode(reader)?,
+ }),
+ // TODO fix test dependency on openssl loading codec with default features
+ #[cfg(not(feature = "custom_proposal"))]
+ _ => return Err(mls_rs_codec::Error::Custom(3)),
+ })
+ }
+}
+
+impl Proposal {
+ pub fn proposal_type(&self) -> ProposalType {
+ match self {
+ Proposal::Add(_) => ProposalType::ADD,
+ #[cfg(feature = "by_ref_proposal")]
+ Proposal::Update(_) => ProposalType::UPDATE,
+ Proposal::Remove(_) => ProposalType::REMOVE,
+ #[cfg(feature = "psk")]
+ Proposal::Psk(_) => ProposalType::PSK,
+ Proposal::ReInit(_) => ProposalType::RE_INIT,
+ Proposal::ExternalInit(_) => ProposalType::EXTERNAL_INIT,
+ Proposal::GroupContextExtensions(_) => ProposalType::GROUP_CONTEXT_EXTENSIONS,
+ #[cfg(feature = "custom_proposal")]
+ Proposal::Custom(c) => c.proposal_type,
+ }
+ }
+}
+
+#[derive(Clone, Debug, PartialEq)]
+/// An enum that represents a borrowed version of [`Proposal`].
+pub enum BorrowedProposal<'a> {
+ Add(&'a AddProposal),
+ #[cfg(feature = "by_ref_proposal")]
+ Update(&'a UpdateProposal),
+ Remove(&'a RemoveProposal),
+ #[cfg(feature = "psk")]
+ Psk(&'a PreSharedKeyProposal),
+ ReInit(&'a ReInitProposal),
+ ExternalInit(&'a ExternalInit),
+ GroupContextExtensions(&'a ExtensionList),
+ #[cfg(feature = "custom_proposal")]
+ Custom(&'a CustomProposal),
+}
+
+impl<'a> From<BorrowedProposal<'a>> for Proposal {
+ fn from(value: BorrowedProposal<'a>) -> Self {
+ match value {
+ BorrowedProposal::Add(add) => Proposal::Add(alloc::boxed::Box::new(add.clone())),
+ #[cfg(feature = "by_ref_proposal")]
+ BorrowedProposal::Update(update) => Proposal::Update(update.clone()),
+ BorrowedProposal::Remove(remove) => Proposal::Remove(remove.clone()),
+ #[cfg(feature = "psk")]
+ BorrowedProposal::Psk(psk) => Proposal::Psk(psk.clone()),
+ BorrowedProposal::ReInit(reinit) => Proposal::ReInit(reinit.clone()),
+ BorrowedProposal::ExternalInit(external) => Proposal::ExternalInit(external.clone()),
+ BorrowedProposal::GroupContextExtensions(ext) => {
+ Proposal::GroupContextExtensions(ext.clone())
+ }
+ #[cfg(feature = "custom_proposal")]
+ BorrowedProposal::Custom(custom) => Proposal::Custom(custom.clone()),
+ }
+ }
+}
+
+impl BorrowedProposal<'_> {
+ pub fn proposal_type(&self) -> ProposalType {
+ match self {
+ BorrowedProposal::Add(_) => ProposalType::ADD,
+ #[cfg(feature = "by_ref_proposal")]
+ BorrowedProposal::Update(_) => ProposalType::UPDATE,
+ BorrowedProposal::Remove(_) => ProposalType::REMOVE,
+ #[cfg(feature = "psk")]
+ BorrowedProposal::Psk(_) => ProposalType::PSK,
+ BorrowedProposal::ReInit(_) => ProposalType::RE_INIT,
+ BorrowedProposal::ExternalInit(_) => ProposalType::EXTERNAL_INIT,
+ BorrowedProposal::GroupContextExtensions(_) => ProposalType::GROUP_CONTEXT_EXTENSIONS,
+ #[cfg(feature = "custom_proposal")]
+ BorrowedProposal::Custom(c) => c.proposal_type,
+ }
+ }
+}
+
+impl<'a> From<&'a Proposal> for BorrowedProposal<'a> {
+ fn from(p: &'a Proposal) -> Self {
+ match p {
+ Proposal::Add(p) => BorrowedProposal::Add(p),
+ #[cfg(feature = "by_ref_proposal")]
+ Proposal::Update(p) => BorrowedProposal::Update(p),
+ Proposal::Remove(p) => BorrowedProposal::Remove(p),
+ #[cfg(feature = "psk")]
+ Proposal::Psk(p) => BorrowedProposal::Psk(p),
+ Proposal::ReInit(p) => BorrowedProposal::ReInit(p),
+ Proposal::ExternalInit(p) => BorrowedProposal::ExternalInit(p),
+ Proposal::GroupContextExtensions(p) => BorrowedProposal::GroupContextExtensions(p),
+ #[cfg(feature = "custom_proposal")]
+ Proposal::Custom(p) => BorrowedProposal::Custom(p),
+ }
+ }
+}
+
+impl<'a> From<&'a AddProposal> for BorrowedProposal<'a> {
+ fn from(p: &'a AddProposal) -> Self {
+ Self::Add(p)
+ }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+impl<'a> From<&'a UpdateProposal> for BorrowedProposal<'a> {
+ fn from(p: &'a UpdateProposal) -> Self {
+ Self::Update(p)
+ }
+}
+
+impl<'a> From<&'a RemoveProposal> for BorrowedProposal<'a> {
+ fn from(p: &'a RemoveProposal) -> Self {
+ Self::Remove(p)
+ }
+}
+
+#[cfg(feature = "psk")]
+impl<'a> From<&'a PreSharedKeyProposal> for BorrowedProposal<'a> {
+ fn from(p: &'a PreSharedKeyProposal) -> Self {
+ Self::Psk(p)
+ }
+}
+
+impl<'a> From<&'a ReInitProposal> for BorrowedProposal<'a> {
+ fn from(p: &'a ReInitProposal) -> Self {
+ Self::ReInit(p)
+ }
+}
+
+impl<'a> From<&'a ExternalInit> for BorrowedProposal<'a> {
+ fn from(p: &'a ExternalInit) -> Self {
+ Self::ExternalInit(p)
+ }
+}
+
+impl<'a> From<&'a ExtensionList> for BorrowedProposal<'a> {
+ fn from(p: &'a ExtensionList) -> Self {
+ Self::GroupContextExtensions(p)
+ }
+}
+
+#[cfg(feature = "custom_proposal")]
+impl<'a> From<&'a CustomProposal> for BorrowedProposal<'a> {
+ fn from(p: &'a CustomProposal) -> Self {
+ Self::Custom(p)
+ }
+}
+
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[repr(u8)]
+pub(crate) enum ProposalOrRef {
+ Proposal(Box<Proposal>) = 1u8,
+ #[cfg(feature = "by_ref_proposal")]
+ Reference(ProposalRef) = 2u8,
+}
+
+impl From<Proposal> for ProposalOrRef {
+ fn from(proposal: Proposal) -> Self {
+ Self::Proposal(Box::new(proposal))
+ }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+impl From<ProposalRef> for ProposalOrRef {
+ fn from(r: ProposalRef) -> Self {
+ Self::Reference(r)
+ }
+}
diff --git a/src/group/proposal_cache.rs b/src/group/proposal_cache.rs
new file mode 100644
index 0000000..17acf79
--- /dev/null
+++ b/src/group/proposal_cache.rs
@@ -0,0 +1,4216 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+
+use super::{
+ message_processor::ProvisionalState,
+ mls_rules::{CommitDirection, CommitSource, MlsRules},
+ GroupState, ProposalOrRef,
+};
+use crate::{
+ client::MlsError,
+ group::{
+ proposal_filter::{ProposalApplier, ProposalBundle, ProposalSource},
+ Proposal, Sender,
+ },
+ time::MlsTime,
+};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::group::{proposal_filter::FilterStrategy, ProposalRef, ProtocolVersion};
+
+use crate::tree_kem::leaf_node::LeafNode;
+
+#[cfg(all(feature = "std", feature = "by_ref_proposal"))]
+use std::collections::HashMap;
+
+#[cfg(feature = "by_ref_proposal")]
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+
+use mls_rs_core::{
+ crypto::CipherSuiteProvider, error::IntoAnyError, identity::IdentityProvider,
+ psk::PreSharedKeyStorage,
+};
+
+#[cfg(feature = "by_ref_proposal")]
+use core::fmt::{self, Debug};
+
+#[cfg(feature = "by_ref_proposal")]
+#[derive(Debug, Clone, MlsSize, MlsEncode, MlsDecode, PartialEq)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct CachedProposal {
+ pub(crate) proposal: Proposal,
+ pub(crate) sender: Sender,
+}
+
+#[cfg(feature = "by_ref_proposal")]
+#[derive(Clone, PartialEq)]
+pub(crate) struct ProposalCache {
+ protocol_version: ProtocolVersion,
+ group_id: Vec<u8>,
+ #[cfg(feature = "std")]
+ pub(crate) proposals: HashMap<ProposalRef, CachedProposal>,
+ #[cfg(not(feature = "std"))]
+ pub(crate) proposals: Vec<(ProposalRef, CachedProposal)>,
+}
+
+#[cfg(feature = "by_ref_proposal")]
+impl Debug for ProposalCache {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("ProposalCache")
+ .field("protocol_version", &self.protocol_version)
+ .field(
+ "group_id",
+ &mls_rs_core::debug::pretty_group_id(&self.group_id),
+ )
+ .field("proposals", &self.proposals)
+ .finish()
+ }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+impl ProposalCache {
+ pub fn new(protocol_version: ProtocolVersion, group_id: Vec<u8>) -> Self {
+ Self {
+ protocol_version,
+ group_id,
+ proposals: Default::default(),
+ }
+ }
+
+ pub fn import(
+ protocol_version: ProtocolVersion,
+ group_id: Vec<u8>,
+ #[cfg(feature = "std")] proposals: HashMap<ProposalRef, CachedProposal>,
+ #[cfg(not(feature = "std"))] proposals: Vec<(ProposalRef, CachedProposal)>,
+ ) -> Self {
+ Self {
+ protocol_version,
+ group_id,
+ proposals,
+ }
+ }
+
+ #[inline]
+ pub fn clear(&mut self) {
+ self.proposals.clear();
+ }
+
+ #[cfg(feature = "private_message")]
+ #[inline]
+ pub fn is_empty(&self) -> bool {
+ self.proposals.is_empty()
+ }
+
+ pub fn insert(&mut self, proposal_ref: ProposalRef, proposal: Proposal, sender: Sender) {
+ let cached_proposal = CachedProposal { proposal, sender };
+
+ #[cfg(feature = "std")]
+ self.proposals.insert(proposal_ref, cached_proposal);
+
+ #[cfg(not(feature = "std"))]
+ // This may result in dups but it does not matter
+ self.proposals.push((proposal_ref, cached_proposal));
+ }
+
+ pub fn prepare_commit(
+ &self,
+ sender: Sender,
+ additional_proposals: Vec<Proposal>,
+ ) -> ProposalBundle {
+ self.proposals
+ .iter()
+ .map(|(r, p)| {
+ (
+ p.proposal.clone(),
+ p.sender,
+ ProposalSource::ByReference(r.clone()),
+ )
+ })
+ .chain(
+ additional_proposals
+ .into_iter()
+ .map(|p| (p, sender, ProposalSource::ByValue)),
+ )
+ .collect()
+ }
+
+ pub fn resolve_for_commit(
+ &self,
+ sender: Sender,
+ proposal_list: Vec<ProposalOrRef>,
+ ) -> Result<ProposalBundle, MlsError> {
+ let mut proposals = ProposalBundle::default();
+
+ for p in proposal_list {
+ match p {
+ ProposalOrRef::Proposal(p) => proposals.add(*p, sender, ProposalSource::ByValue),
+ ProposalOrRef::Reference(r) => {
+ #[cfg(feature = "std")]
+ let p = self
+ .proposals
+ .get(&r)
+ .ok_or(MlsError::ProposalNotFound)?
+ .clone();
+ #[cfg(not(feature = "std"))]
+ let p = self
+ .proposals
+ .iter()
+ .find_map(|(rr, p)| (rr == &r).then_some(p))
+ .ok_or(MlsError::ProposalNotFound)?
+ .clone();
+
+ proposals.add(p.proposal, p.sender, ProposalSource::ByReference(r));
+ }
+ };
+ }
+
+ Ok(proposals)
+ }
+}
+
+#[cfg(not(feature = "by_ref_proposal"))]
+pub(crate) fn prepare_commit(
+ sender: Sender,
+ additional_proposals: Vec<Proposal>,
+) -> ProposalBundle {
+ let mut proposals = ProposalBundle::default();
+
+ for p in additional_proposals.into_iter() {
+ proposals.add(p, sender, ProposalSource::ByValue);
+ }
+
+ proposals
+}
+
+#[cfg(not(feature = "by_ref_proposal"))]
+pub(crate) fn resolve_for_commit(
+ sender: Sender,
+ proposal_list: Vec<ProposalOrRef>,
+) -> Result<ProposalBundle, MlsError> {
+ let mut proposals = ProposalBundle::default();
+
+ for p in proposal_list {
+ let ProposalOrRef::Proposal(p) = p;
+ proposals.add(*p, sender, ProposalSource::ByValue);
+ }
+
+ Ok(proposals)
+}
+
+impl GroupState {
+ #[inline(never)]
+ #[allow(clippy::too_many_arguments)]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn apply_resolved<C, F, P, CSP>(
+ &self,
+ sender: Sender,
+ mut proposals: ProposalBundle,
+ external_leaf: Option<&LeafNode>,
+ identity_provider: &C,
+ cipher_suite_provider: &CSP,
+ psk_storage: &P,
+ user_rules: &F,
+ commit_time: Option<MlsTime>,
+ direction: CommitDirection,
+ ) -> Result<ProvisionalState, MlsError>
+ where
+ C: IdentityProvider,
+ F: MlsRules,
+ P: PreSharedKeyStorage,
+ CSP: CipherSuiteProvider,
+ {
+ let roster = self.public_tree.roster();
+ let group_extensions = &self.context.extensions;
+
+ #[cfg(feature = "by_ref_proposal")]
+ let all_proposals = proposals.clone();
+
+ let origin = match sender {
+ Sender::Member(index) => Ok::<_, MlsError>(CommitSource::ExistingMember(
+ roster.member_with_index(index)?,
+ )),
+ #[cfg(feature = "by_ref_proposal")]
+ Sender::NewMemberProposal => Err(MlsError::InvalidSender),
+ #[cfg(feature = "by_ref_proposal")]
+ Sender::External(_) => Err(MlsError::InvalidSender),
+ Sender::NewMemberCommit => Ok(CommitSource::NewMember(
+ external_leaf
+ .map(|l| l.signing_identity.clone())
+ .ok_or(MlsError::ExternalCommitMustHaveNewLeaf)?,
+ )),
+ }?;
+
+ proposals = user_rules
+ .filter_proposals(direction, origin, &roster, group_extensions, proposals)
+ .await
+ .map_err(|e| MlsError::MlsRulesError(e.into_any_error()))?;
+
+ let applier = ProposalApplier::new(
+ &self.public_tree,
+ self.context.protocol_version,
+ cipher_suite_provider,
+ group_extensions,
+ external_leaf,
+ identity_provider,
+ psk_storage,
+ #[cfg(feature = "by_ref_proposal")]
+ &self.context.group_id,
+ );
+
+ #[cfg(feature = "by_ref_proposal")]
+ let applier_output = match direction {
+ CommitDirection::Send => {
+ applier
+ .apply_proposals(FilterStrategy::IgnoreByRef, &sender, proposals, commit_time)
+ .await?
+ }
+ CommitDirection::Receive => {
+ applier
+ .apply_proposals(FilterStrategy::IgnoreNone, &sender, proposals, commit_time)
+ .await?
+ }
+ };
+
+ #[cfg(not(feature = "by_ref_proposal"))]
+ let applier_output = applier
+ .apply_proposals(&sender, &proposals, commit_time)
+ .await?;
+
+ #[cfg(feature = "by_ref_proposal")]
+ let unused_proposals = unused_proposals(
+ match direction {
+ CommitDirection::Send => all_proposals,
+ CommitDirection::Receive => self.proposals.proposals.iter().collect(),
+ },
+ &applier_output.applied_proposals,
+ );
+
+ let mut group_context = self.context.clone();
+ group_context.epoch += 1;
+
+ if let Some(ext) = applier_output.new_context_extensions {
+ group_context.extensions = ext;
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ let proposals = applier_output.applied_proposals;
+
+ Ok(ProvisionalState {
+ public_tree: applier_output.new_tree,
+ group_context,
+ applied_proposals: proposals,
+ external_init_index: applier_output.external_init_index,
+ indexes_of_added_kpkgs: applier_output.indexes_of_added_kpkgs,
+ #[cfg(feature = "by_ref_proposal")]
+ unused_proposals,
+ })
+ }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+impl Extend<(ProposalRef, CachedProposal)> for ProposalCache {
+ fn extend<T>(&mut self, iter: T)
+ where
+ T: IntoIterator<Item = (ProposalRef, CachedProposal)>,
+ {
+ self.proposals.extend(iter);
+ }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+fn has_ref(proposals: &ProposalBundle, reference: &ProposalRef) -> bool {
+ proposals
+ .iter_proposals()
+ .any(|p| matches!(&p.source, ProposalSource::ByReference(r) if r == reference))
+}
+
+#[cfg(feature = "by_ref_proposal")]
+fn unused_proposals(
+ all_proposals: ProposalBundle,
+ accepted_proposals: &ProposalBundle,
+) -> Vec<crate::mls_rules::ProposalInfo<Proposal>> {
+ all_proposals
+ .into_proposals()
+ .filter(|p| {
+ matches!(p.source, ProposalSource::ByReference(ref r) if !has_ref(accepted_proposals, r)
+ )
+ })
+ .collect()
+}
+
+// TODO add tests for lite version of filtering
+#[cfg(all(feature = "by_ref_proposal", test))]
+pub(crate) mod test_utils {
+ use mls_rs_core::{
+ crypto::CipherSuiteProvider, extension::ExtensionList, identity::IdentityProvider,
+ psk::PreSharedKeyStorage,
+ };
+
+ use crate::{
+ client::test_utils::TEST_PROTOCOL_VERSION,
+ group::{
+ confirmation_tag::ConfirmationTag,
+ mls_rules::{CommitDirection, DefaultMlsRules, MlsRules},
+ proposal::{Proposal, ProposalOrRef},
+ proposal_ref::ProposalRef,
+ state::GroupState,
+ test_utils::{get_test_group_context, TEST_GROUP},
+ GroupContext, LeafIndex, LeafNode, ProvisionalState, Sender, TreeKemPublic,
+ },
+ identity::{basic::BasicIdentityProvider, test_utils::BasicWithCustomProvider},
+ psk::AlwaysFoundPskStorage,
+ };
+
+ use super::{CachedProposal, MlsError, ProposalCache};
+
+ use alloc::vec::Vec;
+
+ impl CachedProposal {
+ pub fn new(proposal: Proposal, sender: Sender) -> Self {
+ Self { proposal, sender }
+ }
+ }
+
+ #[derive(Debug)]
+ pub(crate) struct CommitReceiver<'a, C, F, P, CSP> {
+ tree: &'a TreeKemPublic,
+ sender: Sender,
+ receiver: LeafIndex,
+ cache: ProposalCache,
+ identity_provider: C,
+ cipher_suite_provider: CSP,
+ group_context_extensions: ExtensionList,
+ user_rules: F,
+ with_psk_storage: P,
+ }
+
+ impl<'a, CSP>
+ CommitReceiver<'a, BasicWithCustomProvider, DefaultMlsRules, AlwaysFoundPskStorage, CSP>
+ {
+ pub fn new<S>(
+ tree: &'a TreeKemPublic,
+ sender: S,
+ receiver: LeafIndex,
+ cipher_suite_provider: CSP,
+ ) -> Self
+ where
+ S: Into<Sender>,
+ {
+ Self {
+ tree,
+ sender: sender.into(),
+ receiver,
+ cache: make_proposal_cache(),
+ identity_provider: BasicWithCustomProvider::new(BasicIdentityProvider),
+ group_context_extensions: Default::default(),
+ user_rules: pass_through_rules(),
+ with_psk_storage: AlwaysFoundPskStorage,
+ cipher_suite_provider,
+ }
+ }
+ }
+
+ impl<'a, C, F, P, CSP> CommitReceiver<'a, C, F, P, CSP>
+ where
+ C: IdentityProvider,
+ F: MlsRules,
+ P: PreSharedKeyStorage,
+ CSP: CipherSuiteProvider,
+ {
+ #[cfg(feature = "by_ref_proposal")]
+ pub fn with_identity_provider<V>(self, validator: V) -> CommitReceiver<'a, V, F, P, CSP>
+ where
+ V: IdentityProvider,
+ {
+ CommitReceiver {
+ tree: self.tree,
+ sender: self.sender,
+ receiver: self.receiver,
+ cache: self.cache,
+ identity_provider: validator,
+ group_context_extensions: self.group_context_extensions,
+ user_rules: self.user_rules,
+ with_psk_storage: self.with_psk_storage,
+ cipher_suite_provider: self.cipher_suite_provider,
+ }
+ }
+
+ pub fn with_user_rules<G>(self, f: G) -> CommitReceiver<'a, C, G, P, CSP>
+ where
+ G: MlsRules,
+ {
+ CommitReceiver {
+ tree: self.tree,
+ sender: self.sender,
+ receiver: self.receiver,
+ cache: self.cache,
+ identity_provider: self.identity_provider,
+ group_context_extensions: self.group_context_extensions,
+ user_rules: f,
+ with_psk_storage: self.with_psk_storage,
+ cipher_suite_provider: self.cipher_suite_provider,
+ }
+ }
+
+ pub fn with_psk_storage<V>(self, v: V) -> CommitReceiver<'a, C, F, V, CSP>
+ where
+ V: PreSharedKeyStorage,
+ {
+ CommitReceiver {
+ tree: self.tree,
+ sender: self.sender,
+ receiver: self.receiver,
+ cache: self.cache,
+ identity_provider: self.identity_provider,
+ group_context_extensions: self.group_context_extensions,
+ user_rules: self.user_rules,
+ with_psk_storage: v,
+ cipher_suite_provider: self.cipher_suite_provider,
+ }
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ pub fn with_extensions(self, extensions: ExtensionList) -> Self {
+ Self {
+ group_context_extensions: extensions,
+ ..self
+ }
+ }
+
+ pub fn cache<S>(mut self, r: ProposalRef, p: Proposal, proposer: S) -> Self
+ where
+ S: Into<Sender>,
+ {
+ self.cache.insert(r, p, proposer.into());
+ self
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn receive<I>(&self, proposals: I) -> Result<ProvisionalState, MlsError>
+ where
+ I: IntoIterator,
+ I::Item: Into<ProposalOrRef>,
+ {
+ self.cache
+ .resolve_for_commit_default(
+ self.sender,
+ proposals.into_iter().map(Into::into).collect(),
+ None,
+ &self.group_context_extensions,
+ &self.identity_provider,
+ &self.cipher_suite_provider,
+ self.tree,
+ &self.with_psk_storage,
+ &self.user_rules,
+ )
+ .await
+ }
+ }
+
+ pub(crate) fn make_proposal_cache() -> ProposalCache {
+ ProposalCache::new(TEST_PROTOCOL_VERSION, TEST_GROUP.to_vec())
+ }
+
+ pub fn pass_through_rules() -> DefaultMlsRules {
+ DefaultMlsRules::new()
+ }
+
+ impl ProposalCache {
+ #[allow(clippy::too_many_arguments)]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn resolve_for_commit_default<C, F, P, CSP>(
+ &self,
+ sender: Sender,
+ proposal_list: Vec<ProposalOrRef>,
+ external_leaf: Option<&LeafNode>,
+ group_extensions: &ExtensionList,
+ identity_provider: &C,
+ cipher_suite_provider: &CSP,
+ public_tree: &TreeKemPublic,
+ psk_storage: &P,
+ user_rules: F,
+ ) -> Result<ProvisionalState, MlsError>
+ where
+ C: IdentityProvider,
+ F: MlsRules,
+ P: PreSharedKeyStorage,
+ CSP: CipherSuiteProvider,
+ {
+ let mut context =
+ get_test_group_context(123, cipher_suite_provider.cipher_suite()).await;
+
+ context.extensions = group_extensions.clone();
+
+ let mut state = GroupState::new(
+ context,
+ public_tree.clone(),
+ Vec::new().into(),
+ ConfirmationTag::empty(cipher_suite_provider).await,
+ );
+
+ state.proposals.proposals = self.proposals.clone();
+ let proposals = self.resolve_for_commit(sender, proposal_list)?;
+
+ state
+ .apply_resolved(
+ sender,
+ proposals,
+ external_leaf,
+ identity_provider,
+ cipher_suite_provider,
+ psk_storage,
+ &user_rules,
+ None,
+ CommitDirection::Receive,
+ )
+ .await
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn prepare_commit_default<C, F, P, CSP>(
+ &self,
+ sender: Sender,
+ additional_proposals: Vec<Proposal>,
+ context: &GroupContext,
+ identity_provider: &C,
+ cipher_suite_provider: &CSP,
+ public_tree: &TreeKemPublic,
+ external_leaf: Option<&LeafNode>,
+ psk_storage: &P,
+ user_rules: F,
+ ) -> Result<ProvisionalState, MlsError>
+ where
+ C: IdentityProvider,
+ F: MlsRules,
+ P: PreSharedKeyStorage,
+ CSP: CipherSuiteProvider,
+ {
+ let state = GroupState::new(
+ context.clone(),
+ public_tree.clone(),
+ Vec::new().into(),
+ ConfirmationTag::empty(cipher_suite_provider).await,
+ );
+
+ let proposals = self.prepare_commit(sender, additional_proposals);
+
+ state
+ .apply_resolved(
+ sender,
+ proposals,
+ external_leaf,
+ identity_provider,
+ cipher_suite_provider,
+ psk_storage,
+ &user_rules,
+ None,
+ CommitDirection::Send,
+ )
+ .await
+ }
+ }
+}
+
+// TODO add tests for lite version of filtering
+#[cfg(all(feature = "by_ref_proposal", test))]
+mod tests {
+ use alloc::{boxed::Box, vec, vec::Vec};
+
+ use super::test_utils::{make_proposal_cache, pass_through_rules, CommitReceiver};
+ use super::{CachedProposal, ProposalCache};
+ use crate::client::MlsError;
+ use crate::group::message_processor::ProvisionalState;
+ use crate::group::mls_rules::{CommitDirection, CommitSource, EncryptionOptions};
+ use crate::group::proposal_filter::{ProposalBundle, ProposalInfo, ProposalSource};
+ use crate::group::proposal_ref::test_utils::auth_content_from_proposal;
+ use crate::group::proposal_ref::ProposalRef;
+ use crate::group::{
+ AddProposal, AuthenticatedContent, Content, ExternalInit, Proposal, ProposalOrRef,
+ ReInitProposal, RemoveProposal, Roster, Sender, UpdateProposal,
+ };
+ use crate::key_package::test_utils::test_key_package_with_signer;
+ use crate::signer::Signable;
+ use crate::tree_kem::leaf_node::LeafNode;
+ use crate::tree_kem::node::LeafIndex;
+ use crate::tree_kem::TreeKemPublic;
+ use crate::{
+ client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+ crypto::{self, test_utils::test_cipher_suite_provider},
+ extension::test_utils::TestExtension,
+ group::{
+ message_processor::path_update_required,
+ proposal_filter::proposer_can_propose,
+ test_utils::{get_test_group_context, random_bytes, test_group, TEST_GROUP},
+ },
+ identity::basic::BasicIdentityProvider,
+ identity::test_utils::{get_test_signing_identity, BasicWithCustomProvider},
+ key_package::{test_utils::test_key_package, KeyPackageGenerator},
+ mls_rules::{CommitOptions, DefaultMlsRules},
+ psk::AlwaysFoundPskStorage,
+ tree_kem::{
+ leaf_node::{
+ test_utils::{
+ default_properties, get_basic_test_node, get_basic_test_node_capabilities,
+ get_basic_test_node_sig_key, get_test_capabilities,
+ },
+ ConfigProperties, LeafNodeSigningContext, LeafNodeSource,
+ },
+ Lifetime,
+ },
+ };
+ use crate::{KeyPackage, MlsRules};
+
+ use crate::extension::RequiredCapabilitiesExt;
+
+ #[cfg(feature = "by_ref_proposal")]
+ use crate::{
+ extension::ExternalSendersExt,
+ tree_kem::leaf_node_validator::test_utils::FailureIdentityProvider,
+ };
+
+ #[cfg(feature = "psk")]
+ use crate::{
+ group::proposal::PreSharedKeyProposal,
+ psk::{
+ ExternalPskId, JustPreSharedKeyID, PreSharedKeyID, PskGroupId, PskNonce,
+ ResumptionPSKUsage, ResumptionPsk,
+ },
+ };
+
+ #[cfg(feature = "custom_proposal")]
+ use crate::group::proposal::CustomProposal;
+
+ use assert_matches::assert_matches;
+ use core::convert::Infallible;
+ use itertools::Itertools;
+ use mls_rs_core::crypto::{CipherSuite, CipherSuiteProvider};
+ use mls_rs_core::extension::ExtensionList;
+ use mls_rs_core::group::{Capabilities, ProposalType};
+ use mls_rs_core::identity::IdentityProvider;
+ use mls_rs_core::protocol_version::ProtocolVersion;
+ use mls_rs_core::psk::{PreSharedKey, PreSharedKeyStorage};
+ use mls_rs_core::{
+ extension::MlsExtension,
+ identity::{Credential, CredentialType, CustomCredential},
+ };
+
+ fn test_sender() -> u32 {
+ 1
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn new_tree_custom_proposals(
+ name: &str,
+ proposal_types: Vec<ProposalType>,
+ ) -> (LeafIndex, TreeKemPublic) {
+ let (leaf, secret, _) = get_basic_test_node_capabilities(
+ TEST_CIPHER_SUITE,
+ name,
+ Capabilities {
+ proposals: proposal_types,
+ ..get_test_capabilities()
+ },
+ )
+ .await;
+
+ let (pub_tree, priv_tree) =
+ TreeKemPublic::derive(leaf, secret, &BasicIdentityProvider, &Default::default())
+ .await
+ .unwrap();
+
+ (priv_tree.self_index, pub_tree)
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn new_tree(name: &str) -> (LeafIndex, TreeKemPublic) {
+ new_tree_custom_proposals(name, vec![]).await
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn add_member(tree: &mut TreeKemPublic, name: &str) -> LeafIndex {
+ let test_node = get_basic_test_node(TEST_CIPHER_SUITE, name).await;
+
+ tree.add_leaves(
+ vec![test_node],
+ &BasicIdentityProvider,
+ &test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .await
+ .unwrap()[0]
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn update_leaf_node(name: &str, leaf_index: u32) -> LeafNode {
+ let (mut leaf, _, signer) = get_basic_test_node_sig_key(TEST_CIPHER_SUITE, name).await;
+
+ leaf.update(
+ &test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ TEST_GROUP,
+ leaf_index,
+ default_properties(),
+ None,
+ &signer,
+ )
+ .await
+ .unwrap();
+
+ leaf
+ }
+
+ struct TestProposals {
+ test_sender: u32,
+ test_proposals: Vec<AuthenticatedContent>,
+ expected_effects: ProvisionalState,
+ tree: TreeKemPublic,
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_proposals(
+ protocol_version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+ ) -> TestProposals {
+ let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);
+
+ let (sender_leaf, sender_leaf_secret, _) =
+ get_basic_test_node_sig_key(cipher_suite, "alice").await;
+
+ let sender = LeafIndex(0);
+
+ let (mut tree, _) = TreeKemPublic::derive(
+ sender_leaf,
+ sender_leaf_secret,
+ &BasicIdentityProvider,
+ &Default::default(),
+ )
+ .await
+ .unwrap();
+
+ let add_package = test_key_package(protocol_version, cipher_suite, "dave").await;
+
+ let remove_leaf_index = add_member(&mut tree, "carol").await;
+
+ let add = Proposal::Add(Box::new(AddProposal {
+ key_package: add_package.clone(),
+ }));
+
+ let remove = Proposal::Remove(RemoveProposal {
+ to_remove: remove_leaf_index,
+ });
+
+ let extensions = Proposal::GroupContextExtensions(ExtensionList::new());
+
+ let proposals = vec![add, remove, extensions];
+
+ let test_node = get_basic_test_node(cipher_suite, "charlie").await;
+
+ let test_sender = *tree
+ .add_leaves(
+ vec![test_node],
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap()[0];
+
+ let mut expected_tree = tree.clone();
+
+ let mut bundle = ProposalBundle::default();
+
+ let plaintext = proposals
+ .iter()
+ .cloned()
+ .map(|p| auth_content_from_proposal(p, sender))
+ .collect_vec();
+
+ for i in 0..proposals.len() {
+ let pref = ProposalRef::from_content(&cipher_suite_provider, &plaintext[i])
+ .await
+ .unwrap();
+
+ bundle.add(
+ proposals[i].clone(),
+ Sender::Member(test_sender),
+ ProposalSource::ByReference(pref),
+ )
+ }
+
+ expected_tree
+ .batch_edit(
+ &mut bundle,
+ &Default::default(),
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ true,
+ )
+ .await
+ .unwrap();
+
+ let expected_effects = ProvisionalState {
+ public_tree: expected_tree,
+ group_context: get_test_group_context(1, cipher_suite).await,
+ external_init_index: None,
+ indexes_of_added_kpkgs: vec![LeafIndex(1)],
+ #[cfg(feature = "state_update")]
+ unused_proposals: vec![],
+ applied_proposals: bundle,
+ };
+
+ TestProposals {
+ test_sender,
+ test_proposals: plaintext,
+ expected_effects,
+ tree,
+ }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn filter_proposals(
+ cipher_suite: CipherSuite,
+ proposals: Vec<AuthenticatedContent>,
+ ) -> Vec<(ProposalRef, CachedProposal)> {
+ let mut contents = Vec::new();
+
+ for p in proposals {
+ if let Content::Proposal(proposal) = &p.content.content {
+ let proposal_ref =
+ ProposalRef::from_content(&test_cipher_suite_provider(cipher_suite), &p)
+ .await
+ .unwrap();
+ contents.push((
+ proposal_ref,
+ CachedProposal::new(proposal.as_ref().clone(), p.content.sender),
+ ));
+ }
+ }
+
+ contents
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn make_proposal_ref<S>(p: &Proposal, sender: S) -> ProposalRef
+ where
+ S: Into<Sender>,
+ {
+ ProposalRef::from_content(
+ &test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ &auth_content_from_proposal(p.clone(), sender),
+ )
+ .await
+ .unwrap()
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn make_proposal_info<S>(p: &Proposal, sender: S) -> ProposalInfo<Proposal>
+ where
+ S: Into<Sender> + Clone,
+ {
+ ProposalInfo {
+ proposal: p.clone(),
+ sender: sender.clone().into(),
+ source: ProposalSource::ByReference(make_proposal_ref(p, sender).await),
+ }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_proposal_cache_setup(proposals: Vec<AuthenticatedContent>) -> ProposalCache {
+ let mut cache = make_proposal_cache();
+ cache.extend(filter_proposals(TEST_CIPHER_SUITE, proposals).await);
+ cache
+ }
+
+ fn assert_matches(mut expected_state: ProvisionalState, state: ProvisionalState) {
+ let expected_proposals = expected_state.applied_proposals.into_proposals_or_refs();
+ let proposals = state.applied_proposals.into_proposals_or_refs();
+
+ assert_eq!(proposals.len(), expected_proposals.len());
+
+ // Determine there are no duplicates in the proposals returned
+ assert!(!proposals.iter().enumerate().any(|(i, p1)| proposals
+ .iter()
+ .enumerate()
+ .any(|(j, p2)| p1 == p2 && i != j)),);
+
+ // Proposal order may change so we just compare the length and contents are the same
+ expected_proposals
+ .iter()
+ .for_each(|p| assert!(proposals.contains(p)));
+
+ assert_eq!(
+ expected_state.external_init_index,
+ state.external_init_index
+ );
+
+ // We don't compare the epoch in this test.
+ expected_state.group_context.epoch = state.group_context.epoch;
+ assert_eq!(expected_state.group_context, state.group_context);
+
+ assert_eq!(
+ expected_state.indexes_of_added_kpkgs,
+ state.indexes_of_added_kpkgs
+ );
+
+ assert_eq!(expected_state.public_tree, state.public_tree);
+
+ #[cfg(feature = "state_update")]
+ assert_eq!(expected_state.unused_proposals, state.unused_proposals);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_proposal_cache_commit_all_cached() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ let TestProposals {
+ test_sender,
+ test_proposals,
+ expected_effects,
+ tree,
+ ..
+ } = test_proposals(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ let cache = test_proposal_cache_setup(test_proposals.clone()).await;
+
+ let provisional_state = cache
+ .prepare_commit_default(
+ Sender::Member(test_sender),
+ vec![],
+ &get_test_group_context(0, TEST_CIPHER_SUITE).await,
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ &tree,
+ None,
+ &AlwaysFoundPskStorage,
+ pass_through_rules(),
+ )
+ .await
+ .unwrap();
+
+ assert_matches(expected_effects, provisional_state)
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_proposal_cache_commit_additional() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ let TestProposals {
+ test_sender,
+ test_proposals,
+ mut expected_effects,
+ tree,
+ ..
+ } = test_proposals(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ let additional_key_package =
+ test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "frank").await;
+
+ let additional = AddProposal {
+ key_package: additional_key_package.clone(),
+ };
+
+ let cache = test_proposal_cache_setup(test_proposals.clone()).await;
+
+ let provisional_state = cache
+ .prepare_commit_default(
+ Sender::Member(test_sender),
+ vec![Proposal::Add(Box::new(additional.clone()))],
+ &get_test_group_context(0, TEST_CIPHER_SUITE).await,
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ &tree,
+ None,
+ &AlwaysFoundPskStorage,
+ pass_through_rules(),
+ )
+ .await
+ .unwrap();
+
+ expected_effects.applied_proposals.add(
+ Proposal::Add(Box::new(additional.clone())),
+ Sender::Member(test_sender),
+ ProposalSource::ByValue,
+ );
+
+ let leaf = vec![additional_key_package.leaf_node.clone()];
+
+ expected_effects
+ .public_tree
+ .add_leaves(leaf, &BasicIdentityProvider, &cipher_suite_provider)
+ .await
+ .unwrap();
+
+ expected_effects.indexes_of_added_kpkgs.push(LeafIndex(3));
+
+ assert_matches(expected_effects, provisional_state);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_proposal_cache_update_filter() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ let TestProposals {
+ test_proposals,
+ tree,
+ ..
+ } = test_proposals(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ let update_proposal = make_update_proposal("foo").await;
+
+ let additional = vec![Proposal::Update(update_proposal)];
+
+ let cache = test_proposal_cache_setup(test_proposals).await;
+
+ let res = cache
+ .prepare_commit_default(
+ Sender::Member(test_sender()),
+ additional,
+ &get_test_group_context(0, TEST_CIPHER_SUITE).await,
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ &tree,
+ None,
+ &AlwaysFoundPskStorage,
+ pass_through_rules(),
+ )
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidProposalTypeForSender));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_proposal_cache_removal_override_update() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ let TestProposals {
+ test_sender,
+ test_proposals,
+ tree,
+ ..
+ } = test_proposals(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ let update = Proposal::Update(make_update_proposal("foo").await);
+ let update_proposal_ref = make_proposal_ref(&update, LeafIndex(1)).await;
+ let mut cache = test_proposal_cache_setup(test_proposals).await;
+
+ cache.insert(update_proposal_ref.clone(), update, Sender::Member(1));
+
+ let provisional_state = cache
+ .prepare_commit_default(
+ Sender::Member(test_sender),
+ vec![],
+ &get_test_group_context(0, TEST_CIPHER_SUITE).await,
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ &tree,
+ None,
+ &AlwaysFoundPskStorage,
+ pass_through_rules(),
+ )
+ .await
+ .unwrap();
+
+ assert!(provisional_state
+ .applied_proposals
+ .removals
+ .iter()
+ .any(|p| *p.proposal.to_remove == 1));
+
+ assert!(!provisional_state
+ .applied_proposals
+ .into_proposals_or_refs()
+ .contains(&ProposalOrRef::Reference(update_proposal_ref)))
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_proposal_cache_filter_duplicates_insert() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ let TestProposals {
+ test_sender,
+ test_proposals,
+ expected_effects,
+ tree,
+ ..
+ } = test_proposals(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ let mut cache = test_proposal_cache_setup(test_proposals.clone()).await;
+ cache.extend(filter_proposals(TEST_CIPHER_SUITE, test_proposals.clone()).await);
+
+ let provisional_state = cache
+ .prepare_commit_default(
+ Sender::Member(test_sender),
+ vec![],
+ &get_test_group_context(0, TEST_CIPHER_SUITE).await,
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ &tree,
+ None,
+ &AlwaysFoundPskStorage,
+ pass_through_rules(),
+ )
+ .await
+ .unwrap();
+
+ assert_matches(expected_effects, provisional_state)
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_proposal_cache_filter_duplicates_additional() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ let TestProposals {
+ test_proposals,
+ expected_effects,
+ tree,
+ ..
+ } = test_proposals(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ let mut cache = test_proposal_cache_setup(test_proposals.clone()).await;
+
+ // Updates from different senders will be allowed so we test duplicates for add / remove
+ let additional = test_proposals
+ .clone()
+ .into_iter()
+ .filter_map(|plaintext| match plaintext.content.content {
+ Content::Proposal(p) if p.proposal_type() == ProposalType::UPDATE => None,
+ Content::Proposal(_) => Some(plaintext),
+ _ => None,
+ })
+ .collect::<Vec<_>>();
+
+ cache.extend(filter_proposals(TEST_CIPHER_SUITE, additional).await);
+
+ let provisional_state = cache
+ .prepare_commit_default(
+ Sender::Member(2),
+ Vec::new(),
+ &get_test_group_context(0, TEST_CIPHER_SUITE).await,
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ &tree,
+ None,
+ &AlwaysFoundPskStorage,
+ pass_through_rules(),
+ )
+ .await
+ .unwrap();
+
+ assert_matches(expected_effects, provisional_state)
+ }
+
+ #[cfg(feature = "private_message")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_proposal_cache_is_empty() {
+ let mut cache = make_proposal_cache();
+ assert!(cache.is_empty());
+
+ let test_proposal = Proposal::Remove(RemoveProposal {
+ to_remove: LeafIndex(test_sender()),
+ });
+
+ let proposer = test_sender();
+ let test_proposal_ref = make_proposal_ref(&test_proposal, LeafIndex(proposer)).await;
+ cache.insert(test_proposal_ref, test_proposal, Sender::Member(proposer));
+
+ assert!(!cache.is_empty())
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_proposal_cache_resolve() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ let TestProposals {
+ test_sender,
+ test_proposals,
+ tree,
+ ..
+ } = test_proposals(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ let cache = test_proposal_cache_setup(test_proposals).await;
+
+ let proposal = Proposal::Add(Box::new(AddProposal {
+ key_package: test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "frank").await,
+ }));
+
+ let additional = vec![proposal];
+
+ let expected_effects = cache
+ .prepare_commit_default(
+ Sender::Member(test_sender),
+ additional,
+ &get_test_group_context(0, TEST_CIPHER_SUITE).await,
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ &tree,
+ None,
+ &AlwaysFoundPskStorage,
+ pass_through_rules(),
+ )
+ .await
+ .unwrap();
+
+ let proposals = expected_effects
+ .applied_proposals
+ .clone()
+ .into_proposals_or_refs();
+
+ let resolution = cache
+ .resolve_for_commit_default(
+ Sender::Member(test_sender),
+ proposals,
+ None,
+ &ExtensionList::new(),
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ &tree,
+ &AlwaysFoundPskStorage,
+ pass_through_rules(),
+ )
+ .await
+ .unwrap();
+
+ assert_matches(expected_effects, resolution);
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn proposal_cache_filters_duplicate_psk_ids() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ let (alice, tree) = new_tree("alice").await;
+ let cache = make_proposal_cache();
+
+ let proposal = Proposal::Psk(make_external_psk(
+ b"ted",
+ crate::psk::PskNonce::random(&test_cipher_suite_provider(TEST_CIPHER_SUITE)).unwrap(),
+ ));
+
+ let res = cache
+ .prepare_commit_default(
+ Sender::Member(*alice),
+ vec![proposal.clone(), proposal],
+ &get_test_group_context(0, TEST_CIPHER_SUITE).await,
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ &tree,
+ None,
+ &AlwaysFoundPskStorage,
+ pass_through_rules(),
+ )
+ .await;
+
+ assert_matches!(res, Err(MlsError::DuplicatePskIds));
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_node() -> LeafNode {
+ let (mut leaf_node, _, signer) =
+ get_basic_test_node_sig_key(TEST_CIPHER_SUITE, "foo").await;
+
+ leaf_node
+ .commit(
+ &test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ TEST_GROUP,
+ 0,
+ default_properties(),
+ None,
+ &signer,
+ )
+ .await
+ .unwrap();
+
+ leaf_node
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn external_commit_must_have_new_leaf() {
+ let cache = make_proposal_cache();
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ let kem_output = vec![0; cipher_suite_provider.kdf_extract_size()];
+ let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let public_tree = &group.group.state.public_tree;
+
+ let res = cache
+ .resolve_for_commit_default(
+ Sender::NewMemberCommit,
+ vec![ProposalOrRef::Proposal(Box::new(Proposal::ExternalInit(
+ ExternalInit { kem_output },
+ )))],
+ None,
+ &group.group.context().extensions,
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ public_tree,
+ &AlwaysFoundPskStorage,
+ pass_through_rules(),
+ )
+ .await;
+
+ assert_matches!(res, Err(MlsError::ExternalCommitMustHaveNewLeaf));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn proposal_cache_rejects_proposals_by_ref_for_new_member() {
+ let mut cache = make_proposal_cache();
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ let proposal = {
+ let kem_output = vec![0; cipher_suite_provider.kdf_extract_size()];
+ Proposal::ExternalInit(ExternalInit { kem_output })
+ };
+
+ let proposal_ref = make_proposal_ref(&proposal, test_sender()).await;
+
+ cache.insert(
+ proposal_ref.clone(),
+ proposal,
+ Sender::Member(test_sender()),
+ );
+
+ let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let public_tree = &group.group.state.public_tree;
+
+ let res = cache
+ .resolve_for_commit_default(
+ Sender::NewMemberCommit,
+ vec![ProposalOrRef::Reference(proposal_ref)],
+ Some(&test_node().await),
+ &group.group.context().extensions,
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ public_tree,
+ &AlwaysFoundPskStorage,
+ pass_through_rules(),
+ )
+ .await;
+
+ assert_matches!(res, Err(MlsError::OnlyMembersCanCommitProposalsByRef));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn proposal_cache_rejects_multiple_external_init_proposals_in_commit() {
+ let cache = make_proposal_cache();
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+ let kem_output = vec![0; cipher_suite_provider.kdf_extract_size()];
+ let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let public_tree = &group.group.state.public_tree;
+
+ let res = cache
+ .resolve_for_commit_default(
+ Sender::NewMemberCommit,
+ [
+ Proposal::ExternalInit(ExternalInit {
+ kem_output: kem_output.clone(),
+ }),
+ Proposal::ExternalInit(ExternalInit { kem_output }),
+ ]
+ .into_iter()
+ .map(|p| ProposalOrRef::Proposal(Box::new(p)))
+ .collect(),
+ Some(&test_node().await),
+ &group.group.context().extensions,
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ public_tree,
+ &AlwaysFoundPskStorage,
+ pass_through_rules(),
+ )
+ .await;
+
+ assert_matches!(
+ res,
+ Err(MlsError::ExternalCommitMustHaveExactlyOneExternalInit)
+ );
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn new_member_commits_proposal(proposal: Proposal) -> Result<ProvisionalState, MlsError> {
+ let cache = make_proposal_cache();
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+ let kem_output = vec![0; cipher_suite_provider.kdf_extract_size()];
+ let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let public_tree = &group.group.state.public_tree;
+
+ cache
+ .resolve_for_commit_default(
+ Sender::NewMemberCommit,
+ [
+ Proposal::ExternalInit(ExternalInit { kem_output }),
+ proposal,
+ ]
+ .into_iter()
+ .map(|p| ProposalOrRef::Proposal(Box::new(p)))
+ .collect(),
+ Some(&test_node().await),
+ &group.group.context().extensions,
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ public_tree,
+ &AlwaysFoundPskStorage,
+ pass_through_rules(),
+ )
+ .await
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn new_member_cannot_commit_add_proposal() {
+ let res = new_member_commits_proposal(Proposal::Add(Box::new(AddProposal {
+ key_package: test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "frank").await,
+ })))
+ .await;
+
+ assert_matches!(
+ res,
+ Err(MlsError::InvalidProposalTypeInExternalCommit(
+ ProposalType::ADD
+ ))
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn new_member_cannot_commit_more_than_one_remove_proposal() {
+ let cache = make_proposal_cache();
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+ let kem_output = vec![0; cipher_suite_provider.kdf_extract_size()];
+ let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let group_extensions = group.group.context().extensions.clone();
+ let mut public_tree = group.group.state.public_tree;
+
+ let foo = get_basic_test_node(TEST_CIPHER_SUITE, "foo").await;
+
+ let bar = get_basic_test_node(TEST_CIPHER_SUITE, "bar").await;
+
+ let test_leaf_nodes = vec![foo, bar];
+
+ let test_leaf_node_indexes = public_tree
+ .add_leaves(
+ test_leaf_nodes,
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap();
+
+ let proposals = vec![
+ Proposal::ExternalInit(ExternalInit { kem_output }),
+ Proposal::Remove(RemoveProposal {
+ to_remove: test_leaf_node_indexes[0],
+ }),
+ Proposal::Remove(RemoveProposal {
+ to_remove: test_leaf_node_indexes[1],
+ }),
+ ];
+
+ let res = cache
+ .resolve_for_commit_default(
+ Sender::NewMemberCommit,
+ proposals
+ .into_iter()
+ .map(|p| ProposalOrRef::Proposal(Box::new(p)))
+ .collect(),
+ Some(&test_node().await),
+ &group_extensions,
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ &public_tree,
+ &AlwaysFoundPskStorage,
+ pass_through_rules(),
+ )
+ .await;
+
+ assert_matches!(res, Err(MlsError::ExternalCommitWithMoreThanOneRemove));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn new_member_remove_proposal_invalid_credential() {
+ let cache = make_proposal_cache();
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+ let kem_output = vec![0; cipher_suite_provider.kdf_extract_size()];
+ let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let group_extensions = group.group.context().extensions.clone();
+ let mut public_tree = group.group.state.public_tree;
+
+ let node = get_basic_test_node(TEST_CIPHER_SUITE, "bar").await;
+
+ let test_leaf_nodes = vec![node];
+
+ let test_leaf_node_indexes = public_tree
+ .add_leaves(
+ test_leaf_nodes,
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap();
+
+ let proposals = vec![
+ Proposal::ExternalInit(ExternalInit { kem_output }),
+ Proposal::Remove(RemoveProposal {
+ to_remove: test_leaf_node_indexes[0],
+ }),
+ ];
+
+ let res = cache
+ .resolve_for_commit_default(
+ Sender::NewMemberCommit,
+ proposals
+ .into_iter()
+ .map(|p| ProposalOrRef::Proposal(Box::new(p)))
+ .collect(),
+ Some(&test_node().await),
+ &group_extensions,
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ &public_tree,
+ &AlwaysFoundPskStorage,
+ pass_through_rules(),
+ )
+ .await;
+
+ assert_matches!(res, Err(MlsError::ExternalCommitRemovesOtherIdentity));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn new_member_remove_proposal_valid_credential() {
+ let cache = make_proposal_cache();
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+ let kem_output = vec![0; cipher_suite_provider.kdf_extract_size()];
+ let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let group_extensions = group.group.context().extensions.clone();
+ let mut public_tree = group.group.state.public_tree;
+
+ let node = get_basic_test_node(TEST_CIPHER_SUITE, "foo").await;
+
+ let test_leaf_nodes = vec![node];
+
+ let test_leaf_node_indexes = public_tree
+ .add_leaves(
+ test_leaf_nodes,
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap();
+
+ let proposals = vec![
+ Proposal::ExternalInit(ExternalInit { kem_output }),
+ Proposal::Remove(RemoveProposal {
+ to_remove: test_leaf_node_indexes[0],
+ }),
+ ];
+
+ let res = cache
+ .resolve_for_commit_default(
+ Sender::NewMemberCommit,
+ proposals
+ .into_iter()
+ .map(|p| ProposalOrRef::Proposal(Box::new(p)))
+ .collect(),
+ Some(&test_node().await),
+ &group_extensions,
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ &public_tree,
+ &AlwaysFoundPskStorage,
+ pass_through_rules(),
+ )
+ .await;
+
+ assert_matches!(res, Ok(_));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn new_member_cannot_commit_update_proposal() {
+ let res = new_member_commits_proposal(Proposal::Update(UpdateProposal {
+ leaf_node: get_basic_test_node(TEST_CIPHER_SUITE, "foo").await,
+ }))
+ .await;
+
+ assert_matches!(
+ res,
+ Err(MlsError::InvalidProposalTypeInExternalCommit(
+ ProposalType::UPDATE
+ ))
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn new_member_cannot_commit_group_extensions_proposal() {
+ let res =
+ new_member_commits_proposal(Proposal::GroupContextExtensions(ExtensionList::new()))
+ .await;
+
+ assert_matches!(
+ res,
+ Err(MlsError::InvalidProposalTypeInExternalCommit(
+ ProposalType::GROUP_CONTEXT_EXTENSIONS,
+ ))
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn new_member_cannot_commit_reinit_proposal() {
+ let res = new_member_commits_proposal(Proposal::ReInit(ReInitProposal {
+ group_id: b"foo".to_vec(),
+ version: TEST_PROTOCOL_VERSION,
+ cipher_suite: TEST_CIPHER_SUITE,
+ extensions: ExtensionList::new(),
+ }))
+ .await;
+
+ assert_matches!(
+ res,
+ Err(MlsError::InvalidProposalTypeInExternalCommit(
+ ProposalType::RE_INIT
+ ))
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn new_member_commit_must_contain_an_external_init_proposal() {
+ let cache = make_proposal_cache();
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+ let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ let public_tree = &group.group.state.public_tree;
+
+ let res = cache
+ .resolve_for_commit_default(
+ Sender::NewMemberCommit,
+ Vec::new(),
+ Some(&test_node().await),
+ &group.group.context().extensions,
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ public_tree,
+ &AlwaysFoundPskStorage,
+ pass_through_rules(),
+ )
+ .await;
+
+ assert_matches!(
+ res,
+ Err(MlsError::ExternalCommitMustHaveExactlyOneExternalInit)
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_path_update_required_empty() {
+ let cache = make_proposal_cache();
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ let mut tree = TreeKemPublic::new();
+ add_member(&mut tree, "alice").await;
+ add_member(&mut tree, "bob").await;
+
+ let effects = cache
+ .prepare_commit_default(
+ Sender::Member(test_sender()),
+ vec![],
+ &get_test_group_context(1, TEST_CIPHER_SUITE).await,
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ &tree,
+ None,
+ &AlwaysFoundPskStorage,
+ pass_through_rules(),
+ )
+ .await
+ .unwrap();
+
+ assert!(path_update_required(&effects.applied_proposals))
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_path_update_required_updates() {
+ let mut cache = make_proposal_cache();
+ let update = Proposal::Update(make_update_proposal("bar").await);
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ cache.insert(
+ make_proposal_ref(&update, LeafIndex(2)).await,
+ update,
+ Sender::Member(2),
+ );
+
+ let mut tree = TreeKemPublic::new();
+ add_member(&mut tree, "alice").await;
+ add_member(&mut tree, "bob").await;
+ add_member(&mut tree, "carol").await;
+
+ let effects = cache
+ .prepare_commit_default(
+ Sender::Member(test_sender()),
+ Vec::new(),
+ &get_test_group_context(1, TEST_CIPHER_SUITE).await,
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ &tree,
+ None,
+ &AlwaysFoundPskStorage,
+ pass_through_rules(),
+ )
+ .await
+ .unwrap();
+
+ assert!(path_update_required(&effects.applied_proposals))
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_path_update_required_removes() {
+ let cache = make_proposal_cache();
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ let (alice_leaf, alice_secret, _) =
+ get_basic_test_node_sig_key(TEST_CIPHER_SUITE, "alice").await;
+ let alice = 0;
+
+ let (mut tree, _) = TreeKemPublic::derive(
+ alice_leaf,
+ alice_secret,
+ &BasicIdentityProvider,
+ &Default::default(),
+ )
+ .await
+ .unwrap();
+
+ let bob_node = get_basic_test_node(TEST_CIPHER_SUITE, "bob").await;
+
+ let bob = tree
+ .add_leaves(
+ vec![bob_node],
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap()[0];
+
+ let remove = Proposal::Remove(RemoveProposal { to_remove: bob });
+
+ let effects = cache
+ .prepare_commit_default(
+ Sender::Member(alice),
+ vec![remove],
+ &get_test_group_context(1, TEST_CIPHER_SUITE).await,
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ &tree,
+ None,
+ &AlwaysFoundPskStorage,
+ pass_through_rules(),
+ )
+ .await
+ .unwrap();
+
+ assert!(path_update_required(&effects.applied_proposals))
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_path_update_not_required() {
+ let (alice, tree) = new_tree("alice").await;
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+ let cache = make_proposal_cache();
+
+ let psk = Proposal::Psk(PreSharedKeyProposal {
+ psk: PreSharedKeyID::new(
+ JustPreSharedKeyID::External(ExternalPskId::new(vec![])),
+ &test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .unwrap(),
+ });
+
+ let add = Proposal::Add(Box::new(AddProposal {
+ key_package: test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await,
+ }));
+
+ let effects = cache
+ .prepare_commit_default(
+ Sender::Member(*alice),
+ vec![psk, add],
+ &get_test_group_context(1, TEST_CIPHER_SUITE).await,
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ &tree,
+ None,
+ &AlwaysFoundPskStorage,
+ pass_through_rules(),
+ )
+ .await
+ .unwrap();
+
+ assert!(!path_update_required(&effects.applied_proposals))
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn path_update_is_not_required_for_re_init() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+ let (alice, tree) = new_tree("alice").await;
+ let cache = make_proposal_cache();
+
+ let reinit = Proposal::ReInit(ReInitProposal {
+ group_id: vec![],
+ version: TEST_PROTOCOL_VERSION,
+ cipher_suite: TEST_CIPHER_SUITE,
+ extensions: Default::default(),
+ });
+
+ let effects = cache
+ .prepare_commit_default(
+ Sender::Member(*alice),
+ vec![reinit],
+ &get_test_group_context(1, TEST_CIPHER_SUITE).await,
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ &tree,
+ None,
+ &AlwaysFoundPskStorage,
+ pass_through_rules(),
+ )
+ .await
+ .unwrap();
+
+ assert!(!path_update_required(&effects.applied_proposals))
+ }
+
+ #[derive(Debug)]
+ struct CommitSender<'a, C, F, P, CSP> {
+ cipher_suite_provider: CSP,
+ tree: &'a TreeKemPublic,
+ sender: LeafIndex,
+ cache: ProposalCache,
+ additional_proposals: Vec<Proposal>,
+ identity_provider: C,
+ user_rules: F,
+ psk_storage: P,
+ }
+
+ impl<'a, CSP>
+ CommitSender<'a, BasicWithCustomProvider, DefaultMlsRules, AlwaysFoundPskStorage, CSP>
+ {
+ fn new(tree: &'a TreeKemPublic, sender: LeafIndex, cipher_suite_provider: CSP) -> Self {
+ Self {
+ tree,
+ sender,
+ cache: make_proposal_cache(),
+ additional_proposals: Vec::new(),
+ identity_provider: BasicWithCustomProvider::new(BasicIdentityProvider::new()),
+ user_rules: pass_through_rules(),
+ psk_storage: AlwaysFoundPskStorage,
+ cipher_suite_provider,
+ }
+ }
+ }
+
+ impl<'a, C, F, P, CSP> CommitSender<'a, C, F, P, CSP>
+ where
+ C: IdentityProvider,
+ F: MlsRules,
+ P: PreSharedKeyStorage,
+ CSP: CipherSuiteProvider,
+ {
+ #[cfg(feature = "by_ref_proposal")]
+ fn with_identity_provider<V>(self, identity_provider: V) -> CommitSender<'a, V, F, P, CSP>
+ where
+ V: IdentityProvider,
+ {
+ CommitSender {
+ identity_provider,
+ cipher_suite_provider: self.cipher_suite_provider,
+ tree: self.tree,
+ sender: self.sender,
+ cache: self.cache,
+ additional_proposals: self.additional_proposals,
+ user_rules: self.user_rules,
+ psk_storage: self.psk_storage,
+ }
+ }
+
+ fn cache<S>(mut self, r: ProposalRef, p: Proposal, proposer: S) -> Self
+ where
+ S: Into<Sender>,
+ {
+ self.cache.insert(r, p, proposer.into());
+ self
+ }
+
+ fn with_additional<I>(mut self, proposals: I) -> Self
+ where
+ I: IntoIterator<Item = Proposal>,
+ {
+ self.additional_proposals.extend(proposals);
+ self
+ }
+
+ fn with_user_rules<G>(self, f: G) -> CommitSender<'a, C, G, P, CSP>
+ where
+ G: MlsRules,
+ {
+ CommitSender {
+ tree: self.tree,
+ sender: self.sender,
+ cache: self.cache,
+ additional_proposals: self.additional_proposals,
+ identity_provider: self.identity_provider,
+ user_rules: f,
+ psk_storage: self.psk_storage,
+ cipher_suite_provider: self.cipher_suite_provider,
+ }
+ }
+
+ fn with_psk_storage<V>(self, v: V) -> CommitSender<'a, C, F, V, CSP>
+ where
+ V: PreSharedKeyStorage,
+ {
+ CommitSender {
+ tree: self.tree,
+ sender: self.sender,
+ cache: self.cache,
+ additional_proposals: self.additional_proposals,
+ identity_provider: self.identity_provider,
+ user_rules: self.user_rules,
+ psk_storage: v,
+ cipher_suite_provider: self.cipher_suite_provider,
+ }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn send(&self) -> Result<(Vec<ProposalOrRef>, ProvisionalState), MlsError> {
+ let state = self
+ .cache
+ .prepare_commit_default(
+ Sender::Member(*self.sender),
+ self.additional_proposals.clone(),
+ &get_test_group_context(1, TEST_CIPHER_SUITE).await,
+ &self.identity_provider,
+ &self.cipher_suite_provider,
+ self.tree,
+ None,
+ &self.psk_storage,
+ &self.user_rules,
+ )
+ .await?;
+
+ let proposals = state.applied_proposals.clone().into_proposals_or_refs();
+
+ Ok((proposals, state))
+ }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn key_package_with_invalid_signature() -> KeyPackage {
+ let mut kp = test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "mallory").await;
+ kp.signature.clear();
+ kp
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn key_package_with_public_key(key: crypto::HpkePublicKey) -> KeyPackage {
+ let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ let (mut key_package, signer) =
+ test_key_package_with_signer(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "test").await;
+
+ key_package.leaf_node.public_key = key;
+
+ key_package
+ .leaf_node
+ .sign(
+ &cs,
+ &signer,
+ &LeafNodeSigningContext {
+ group_id: None,
+ leaf_index: None,
+ },
+ )
+ .await
+ .unwrap();
+
+ key_package.sign(&cs, &signer, &()).await.unwrap();
+
+ key_package
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_add_with_invalid_key_package_fails() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .receive([Proposal::Add(Box::new(AddProposal {
+ key_package: key_package_with_invalid_signature().await,
+ }))])
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidSignature));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_additional_add_with_invalid_key_package_fails() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_additional([Proposal::Add(Box::new(AddProposal {
+ key_package: key_package_with_invalid_signature().await,
+ }))])
+ .send()
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidSignature));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_add_with_invalid_key_package_filters_it_out() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let proposal = Proposal::Add(Box::new(AddProposal {
+ key_package: key_package_with_invalid_signature().await,
+ }));
+
+ let proposal_info = make_proposal_info(&proposal, alice).await;
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .cache(
+ proposal_info.proposal_ref().unwrap().clone(),
+ proposal.clone(),
+ alice,
+ )
+ .send()
+ .await
+ .unwrap();
+
+ assert_eq!(processed_proposals.0, Vec::new());
+
+ #[cfg(feature = "state_update")]
+ assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_add_with_hpke_key_of_another_member_fails() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_additional([Proposal::Add(Box::new(AddProposal {
+ key_package: key_package_with_public_key(
+ tree.get_leaf_node(alice).unwrap().public_key.clone(),
+ )
+ .await,
+ }))])
+ .send()
+ .await;
+
+ assert_matches!(res, Err(MlsError::DuplicateLeafData(_)));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_add_with_hpke_key_of_another_member_filters_it_out() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let proposal = Proposal::Add(Box::new(AddProposal {
+ key_package: key_package_with_public_key(
+ tree.get_leaf_node(alice).unwrap().public_key.clone(),
+ )
+ .await,
+ }));
+
+ let proposal_info = make_proposal_info(&proposal, alice).await;
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .cache(
+ proposal_info.proposal_ref().unwrap().clone(),
+ proposal.clone(),
+ alice,
+ )
+ .send()
+ .await
+ .unwrap();
+
+ assert_eq!(processed_proposals.0, Vec::new());
+
+ #[cfg(feature = "state_update")]
+ assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_update_with_invalid_leaf_node_fails() {
+ let (alice, mut tree) = new_tree("alice").await;
+ let bob = add_member(&mut tree, "bob").await;
+
+ let proposal = Proposal::Update(UpdateProposal {
+ leaf_node: get_basic_test_node(TEST_CIPHER_SUITE, "alice").await,
+ });
+
+ let proposal_ref = make_proposal_ref(&proposal, bob).await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ bob,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .cache(proposal_ref.clone(), proposal, bob)
+ .receive([proposal_ref])
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidLeafNodeSource));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_update_with_invalid_leaf_node_filters_it_out() {
+ let (alice, mut tree) = new_tree("alice").await;
+ let bob = add_member(&mut tree, "bob").await;
+
+ let proposal = Proposal::Update(UpdateProposal {
+ leaf_node: get_basic_test_node(TEST_CIPHER_SUITE, "alice").await,
+ });
+
+ let proposal_info = make_proposal_info(&proposal, bob).await;
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .cache(proposal_info.proposal_ref().unwrap().clone(), proposal, bob)
+ .send()
+ .await
+ .unwrap();
+
+ assert_eq!(processed_proposals.0, Vec::new());
+
+ #[cfg(feature = "state_update")]
+ assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_remove_with_invalid_index_fails() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .receive([Proposal::Remove(RemoveProposal {
+ to_remove: LeafIndex(10),
+ })])
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidNodeIndex(20)));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_additional_remove_with_invalid_index_fails() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_additional([Proposal::Remove(RemoveProposal {
+ to_remove: LeafIndex(10),
+ })])
+ .send()
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidNodeIndex(20)));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_remove_with_invalid_index_filters_it_out() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let proposal = Proposal::Remove(RemoveProposal {
+ to_remove: LeafIndex(10),
+ });
+
+ let proposal_info = make_proposal_info(&proposal, alice).await;
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .cache(
+ proposal_info.proposal_ref().unwrap().clone(),
+ proposal.clone(),
+ alice,
+ )
+ .send()
+ .await
+ .unwrap();
+
+ assert_eq!(processed_proposals.0, Vec::new());
+
+ #[cfg(feature = "state_update")]
+ assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
+ }
+
+ #[cfg(feature = "psk")]
+ fn make_external_psk(id: &[u8], nonce: PskNonce) -> PreSharedKeyProposal {
+ PreSharedKeyProposal {
+ psk: PreSharedKeyID {
+ key_id: JustPreSharedKeyID::External(ExternalPskId::new(id.to_vec())),
+ psk_nonce: nonce,
+ },
+ }
+ }
+
+ #[cfg(feature = "psk")]
+ fn new_external_psk(id: &[u8]) -> PreSharedKeyProposal {
+ make_external_psk(
+ id,
+ PskNonce::random(&test_cipher_suite_provider(TEST_CIPHER_SUITE)).unwrap(),
+ )
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_psk_with_invalid_nonce_fails() {
+ let invalid_nonce = PskNonce(vec![0, 1, 2]);
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .receive([Proposal::Psk(make_external_psk(
+ b"foo",
+ invalid_nonce.clone(),
+ ))])
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidPskNonceLength,));
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_additional_psk_with_invalid_nonce_fails() {
+ let invalid_nonce = PskNonce(vec![0, 1, 2]);
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_additional([Proposal::Psk(make_external_psk(
+ b"foo",
+ invalid_nonce.clone(),
+ ))])
+ .send()
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidPskNonceLength));
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_psk_with_invalid_nonce_filters_it_out() {
+ let invalid_nonce = PskNonce(vec![0, 1, 2]);
+ let (alice, tree) = new_tree("alice").await;
+ let proposal = Proposal::Psk(make_external_psk(b"foo", invalid_nonce));
+
+ let proposal_info = make_proposal_info(&proposal, alice).await;
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .cache(
+ proposal_info.proposal_ref().unwrap().clone(),
+ proposal.clone(),
+ alice,
+ )
+ .send()
+ .await
+ .unwrap();
+
+ assert_eq!(processed_proposals.0, Vec::new());
+
+ #[cfg(feature = "state_update")]
+ assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
+ }
+
+ #[cfg(feature = "psk")]
+ fn make_resumption_psk(usage: ResumptionPSKUsage) -> PreSharedKeyProposal {
+ PreSharedKeyProposal {
+ psk: PreSharedKeyID {
+ key_id: JustPreSharedKeyID::Resumption(ResumptionPsk {
+ usage,
+ psk_group_id: PskGroupId(TEST_GROUP.to_vec()),
+ psk_epoch: 1,
+ }),
+ psk_nonce: PskNonce::random(&test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .unwrap(),
+ },
+ }
+ }
+
+ #[cfg(feature = "psk")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn receiving_resumption_psk_with_bad_usage_fails(usage: ResumptionPSKUsage) {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .receive([Proposal::Psk(make_resumption_psk(usage))])
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidTypeOrUsageInPreSharedKeyProposal));
+ }
+
+ #[cfg(feature = "psk")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn sending_additional_resumption_psk_with_bad_usage_fails(usage: ResumptionPSKUsage) {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_additional([Proposal::Psk(make_resumption_psk(usage))])
+ .send()
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidTypeOrUsageInPreSharedKeyProposal));
+ }
+
+ #[cfg(feature = "psk")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn sending_resumption_psk_with_bad_usage_filters_it_out(usage: ResumptionPSKUsage) {
+ let (alice, tree) = new_tree("alice").await;
+ let proposal = Proposal::Psk(make_resumption_psk(usage));
+ let proposal_info = make_proposal_info(&proposal, alice).await;
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .cache(
+ proposal_info.proposal_ref().unwrap().clone(),
+ proposal.clone(),
+ alice,
+ )
+ .send()
+ .await
+ .unwrap();
+
+ assert_eq!(processed_proposals.0, Vec::new());
+
+ #[cfg(feature = "state_update")]
+ assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_resumption_psk_with_reinit_usage_fails() {
+ receiving_resumption_psk_with_bad_usage_fails(ResumptionPSKUsage::Reinit).await;
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_additional_resumption_psk_with_reinit_usage_fails() {
+ sending_additional_resumption_psk_with_bad_usage_fails(ResumptionPSKUsage::Reinit).await;
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_resumption_psk_with_reinit_usage_filters_it_out() {
+ sending_resumption_psk_with_bad_usage_filters_it_out(ResumptionPSKUsage::Reinit).await;
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_resumption_psk_with_branch_usage_fails() {
+ receiving_resumption_psk_with_bad_usage_fails(ResumptionPSKUsage::Branch).await;
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_additional_resumption_psk_with_branch_usage_fails() {
+ sending_additional_resumption_psk_with_bad_usage_fails(ResumptionPSKUsage::Branch).await;
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_resumption_psk_with_branch_usage_filters_it_out() {
+ sending_resumption_psk_with_bad_usage_filters_it_out(ResumptionPSKUsage::Branch).await;
+ }
+
+ fn make_reinit(version: ProtocolVersion) -> ReInitProposal {
+ ReInitProposal {
+ group_id: TEST_GROUP.to_vec(),
+ version,
+ cipher_suite: TEST_CIPHER_SUITE,
+ extensions: ExtensionList::new(),
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_reinit_downgrading_version_fails() {
+ let smaller_protocol_version = ProtocolVersion::from(0);
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .receive([Proposal::ReInit(make_reinit(smaller_protocol_version))])
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidProtocolVersionInReInit));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_additional_reinit_downgrading_version_fails() {
+ let smaller_protocol_version = ProtocolVersion::from(0);
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_additional([Proposal::ReInit(make_reinit(smaller_protocol_version))])
+ .send()
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidProtocolVersionInReInit));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_reinit_downgrading_version_filters_it_out() {
+ let smaller_protocol_version = ProtocolVersion::from(0);
+ let (alice, tree) = new_tree("alice").await;
+ let proposal = Proposal::ReInit(make_reinit(smaller_protocol_version));
+ let proposal_info = make_proposal_info(&proposal, alice).await;
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .cache(
+ proposal_info.proposal_ref().unwrap().clone(),
+ proposal.clone(),
+ alice,
+ )
+ .send()
+ .await
+ .unwrap();
+
+ assert_eq!(processed_proposals.0, Vec::new());
+
+ #[cfg(feature = "state_update")]
+ assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_update_for_committer_fails() {
+ let (alice, tree) = new_tree("alice").await;
+ let update = Proposal::Update(make_update_proposal("alice").await);
+ let update_ref = make_proposal_ref(&update, alice).await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .cache(update_ref.clone(), update, alice)
+ .receive([update_ref])
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidCommitSelfUpdate));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_additional_update_for_committer_fails() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_additional([Proposal::Update(make_update_proposal("alice").await)])
+ .send()
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidProposalTypeForSender));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_update_for_committer_filters_it_out() {
+ let (alice, tree) = new_tree("alice").await;
+ let proposal = Proposal::Update(make_update_proposal("alice").await);
+ let proposal_info = make_proposal_info(&proposal, alice).await;
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .cache(
+ proposal_info.proposal_ref().unwrap().clone(),
+ proposal.clone(),
+ alice,
+ )
+ .send()
+ .await
+ .unwrap();
+
+ assert_eq!(processed_proposals.0, Vec::new());
+
+ #[cfg(feature = "state_update")]
+ assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_remove_for_committer_fails() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .receive([Proposal::Remove(RemoveProposal { to_remove: alice })])
+ .await;
+
+ assert_matches!(res, Err(MlsError::CommitterSelfRemoval));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_additional_remove_for_committer_fails() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_additional([Proposal::Remove(RemoveProposal { to_remove: alice })])
+ .send()
+ .await;
+
+ assert_matches!(res, Err(MlsError::CommitterSelfRemoval));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_remove_for_committer_filters_it_out() {
+ let (alice, tree) = new_tree("alice").await;
+ let proposal = Proposal::Remove(RemoveProposal { to_remove: alice });
+ let proposal_info = make_proposal_info(&proposal, alice).await;
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .cache(
+ proposal_info.proposal_ref().unwrap().clone(),
+ proposal.clone(),
+ alice,
+ )
+ .send()
+ .await
+ .unwrap();
+
+ assert_eq!(processed_proposals.0, Vec::new());
+
+ #[cfg(feature = "state_update")]
+ assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_update_and_remove_for_same_leaf_fails() {
+ let (alice, mut tree) = new_tree("alice").await;
+ let bob = add_member(&mut tree, "bob").await;
+
+ let update = Proposal::Update(make_update_proposal("bob").await);
+ let update_ref = make_proposal_ref(&update, bob).await;
+
+ let remove = Proposal::Remove(RemoveProposal { to_remove: bob });
+ let remove_ref = make_proposal_ref(&remove, bob).await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .cache(update_ref.clone(), update, bob)
+ .cache(remove_ref.clone(), remove, bob)
+ .receive([update_ref, remove_ref])
+ .await;
+
+ assert_matches!(res, Err(MlsError::UpdatingNonExistingMember));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_update_and_remove_for_same_leaf_filters_update_out() {
+ let (alice, mut tree) = new_tree("alice").await;
+ let bob = add_member(&mut tree, "bob").await;
+
+ let update = Proposal::Update(make_update_proposal("bob").await);
+ let update_info = make_proposal_info(&update, alice).await;
+
+ let remove = Proposal::Remove(RemoveProposal { to_remove: bob });
+ let remove_ref = make_proposal_ref(&remove, alice).await;
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .cache(
+ update_info.proposal_ref().unwrap().clone(),
+ update.clone(),
+ alice,
+ )
+ .cache(remove_ref.clone(), remove, alice)
+ .send()
+ .await
+ .unwrap();
+
+ assert_eq!(processed_proposals.0, vec![remove_ref.into()]);
+
+ #[cfg(feature = "state_update")]
+ assert_eq!(processed_proposals.1.unused_proposals, vec![update_info]);
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn make_add_proposal() -> Box<AddProposal> {
+ Box::new(AddProposal {
+ key_package: test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "frank").await,
+ })
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_add_proposals_for_same_client_fails() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .receive([
+ Proposal::Add(make_add_proposal().await),
+ Proposal::Add(make_add_proposal().await),
+ ])
+ .await;
+
+ assert_matches!(res, Err(MlsError::DuplicateLeafData(1)));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_additional_add_proposals_for_same_client_fails() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_additional([
+ Proposal::Add(make_add_proposal().await),
+ Proposal::Add(make_add_proposal().await),
+ ])
+ .send()
+ .await;
+
+ assert_matches!(res, Err(MlsError::DuplicateLeafData(1)));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_add_proposals_for_same_client_keeps_only_one() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let add_one = Proposal::Add(make_add_proposal().await);
+ let add_two = Proposal::Add(make_add_proposal().await);
+ let add_ref_one = make_proposal_ref(&add_one, alice).await;
+ let add_ref_two = make_proposal_ref(&add_two, alice).await;
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .cache(add_ref_one.clone(), add_one.clone(), alice)
+ .cache(add_ref_two.clone(), add_two.clone(), alice)
+ .send()
+ .await
+ .unwrap();
+
+ let committed_add_ref = match &*processed_proposals.0 {
+ [ProposalOrRef::Reference(add_ref)] => add_ref,
+ _ => panic!("committed proposals list does not contain exactly one reference"),
+ };
+
+ let add_refs = [add_ref_one, add_ref_two];
+ assert!(add_refs.contains(committed_add_ref));
+
+ #[cfg(feature = "state_update")]
+ assert_matches!(
+ &*processed_proposals.1.unused_proposals,
+ [rejected_add_info] if committed_add_ref != rejected_add_info.proposal_ref().unwrap() && add_refs.contains(rejected_add_info.proposal_ref().unwrap())
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_update_for_different_identity_fails() {
+ let (alice, mut tree) = new_tree("alice").await;
+ let bob = add_member(&mut tree, "bob").await;
+
+ let update = Proposal::Update(make_update_proposal_custom("carol", 1).await);
+ let update_ref = make_proposal_ref(&update, bob).await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .cache(update_ref.clone(), update, bob)
+ .receive([update_ref])
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidSuccessor));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_update_for_different_identity_filters_it_out() {
+ let (alice, mut tree) = new_tree("alice").await;
+ let bob = add_member(&mut tree, "bob").await;
+
+ let update = Proposal::Update(make_update_proposal("carol").await);
+ let update_info = make_proposal_info(&update, bob).await;
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .cache(update_info.proposal_ref().unwrap().clone(), update, bob)
+ .send()
+ .await
+ .unwrap();
+
+ assert_eq!(processed_proposals.0, Vec::new());
+
+ // Bob proposed the update, so it is not listed as rejected when Alice commits it because
+ // she didn't propose it.
+ #[cfg(feature = "state_update")]
+ assert_eq!(processed_proposals.1.unused_proposals, vec![update_info]);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_add_for_same_client_as_existing_member_fails() {
+ let (alice, public_tree) = new_tree("alice").await;
+ let add = Proposal::Add(make_add_proposal().await);
+
+ let ProvisionalState { public_tree, .. } = CommitReceiver::new(
+ &public_tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .receive([add.clone()])
+ .await
+ .unwrap();
+
+ let res = CommitReceiver::new(
+ &public_tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .receive([add])
+ .await;
+
+ assert_matches!(res, Err(MlsError::DuplicateLeafData(1)));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_additional_add_for_same_client_as_existing_member_fails() {
+ let (alice, public_tree) = new_tree("alice").await;
+ let add = Proposal::Add(make_add_proposal().await);
+
+ let ProvisionalState { public_tree, .. } = CommitReceiver::new(
+ &public_tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .receive([add.clone()])
+ .await
+ .unwrap();
+
+ let res = CommitSender::new(
+ &public_tree,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .with_additional([add])
+ .send()
+ .await;
+
+ assert_matches!(res, Err(MlsError::DuplicateLeafData(1)));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_add_for_same_client_as_existing_member_filters_it_out() {
+ let (alice, public_tree) = new_tree("alice").await;
+ let add = Proposal::Add(make_add_proposal().await);
+
+ let ProvisionalState { public_tree, .. } = CommitReceiver::new(
+ &public_tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .receive([add.clone()])
+ .await
+ .unwrap();
+
+ let proposal_info = make_proposal_info(&add, alice).await;
+
+ let processed_proposals = CommitSender::new(
+ &public_tree,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .cache(
+ proposal_info.proposal_ref().unwrap().clone(),
+ add.clone(),
+ alice,
+ )
+ .send()
+ .await
+ .unwrap();
+
+ assert_eq!(processed_proposals.0, Vec::new());
+
+ #[cfg(feature = "state_update")]
+ assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_psk_proposals_with_same_psk_id_fails() {
+ let (alice, tree) = new_tree("alice").await;
+ let psk_proposal = Proposal::Psk(new_external_psk(b"foo"));
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .receive([psk_proposal.clone(), psk_proposal])
+ .await;
+
+ assert_matches!(res, Err(MlsError::DuplicatePskIds));
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_additional_psk_proposals_with_same_psk_id_fails() {
+ let (alice, tree) = new_tree("alice").await;
+ let psk_proposal = Proposal::Psk(new_external_psk(b"foo"));
+
+ let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_additional([psk_proposal.clone(), psk_proposal])
+ .send()
+ .await;
+
+ assert_matches!(res, Err(MlsError::DuplicatePskIds));
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_psk_proposals_with_same_psk_id_keeps_only_one() {
+ let (alice, mut tree) = new_tree("alice").await;
+ let bob = add_member(&mut tree, "bob").await;
+
+ let proposal = Proposal::Psk(new_external_psk(b"foo"));
+
+ let proposal_info = [
+ make_proposal_info(&proposal, alice).await,
+ make_proposal_info(&proposal, bob).await,
+ ];
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .cache(
+ proposal_info[0].proposal_ref().unwrap().clone(),
+ proposal.clone(),
+ alice,
+ )
+ .cache(
+ proposal_info[1].proposal_ref().unwrap().clone(),
+ proposal,
+ bob,
+ )
+ .send()
+ .await
+ .unwrap();
+
+ let committed_info = match processed_proposals
+ .1
+ .applied_proposals
+ .clone()
+ .into_proposals()
+ .collect_vec()
+ .as_slice()
+ {
+ [r] => r.clone(),
+ _ => panic!("Expected single proposal reference in {processed_proposals:?}"),
+ };
+
+ assert!(proposal_info.contains(&committed_info));
+
+ #[cfg(feature = "state_update")]
+ match &*processed_proposals.1.unused_proposals {
+ [r] => {
+ assert_ne!(*r, committed_info);
+ assert!(proposal_info.contains(r));
+ }
+ _ => panic!(
+ "Expected one proposal reference in {:?}",
+ processed_proposals.1.unused_proposals
+ ),
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_multiple_group_context_extensions_fails() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .receive([
+ Proposal::GroupContextExtensions(ExtensionList::new()),
+ Proposal::GroupContextExtensions(ExtensionList::new()),
+ ])
+ .await;
+
+ assert_matches!(
+ res,
+ Err(MlsError::MoreThanOneGroupContextExtensionsProposal)
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_multiple_additional_group_context_extensions_fails() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_additional([
+ Proposal::GroupContextExtensions(ExtensionList::new()),
+ Proposal::GroupContextExtensions(ExtensionList::new()),
+ ])
+ .send()
+ .await;
+
+ assert_matches!(
+ res,
+ Err(MlsError::MoreThanOneGroupContextExtensionsProposal)
+ );
+ }
+
+ fn make_extension_list(foo: u8) -> ExtensionList {
+ vec![TestExtension { foo }.into_extension().unwrap()].into()
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_multiple_group_context_extensions_keeps_only_one() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ let (alice, tree) = {
+ let (signing_identity, signature_key) =
+ get_test_signing_identity(TEST_CIPHER_SUITE, b"alice").await;
+
+ let properties = ConfigProperties {
+ capabilities: Capabilities {
+ extensions: vec![42.into()],
+ ..Capabilities::default()
+ },
+ extensions: Default::default(),
+ };
+
+ let (leaf, secret) = LeafNode::generate(
+ &cipher_suite_provider,
+ properties,
+ signing_identity,
+ &signature_key,
+ Lifetime::years(1).unwrap(),
+ )
+ .await
+ .unwrap();
+
+ let (pub_tree, priv_tree) =
+ TreeKemPublic::derive(leaf, secret, &BasicIdentityProvider, &Default::default())
+ .await
+ .unwrap();
+
+ (priv_tree.self_index, pub_tree)
+ };
+
+ let proposals = [
+ Proposal::GroupContextExtensions(make_extension_list(0)),
+ Proposal::GroupContextExtensions(make_extension_list(1)),
+ ];
+
+ let gce_info = [
+ make_proposal_info(&proposals[0], alice).await,
+ make_proposal_info(&proposals[1], alice).await,
+ ];
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .cache(
+ gce_info[0].proposal_ref().unwrap().clone(),
+ proposals[0].clone(),
+ alice,
+ )
+ .cache(
+ gce_info[1].proposal_ref().unwrap().clone(),
+ proposals[1].clone(),
+ alice,
+ )
+ .send()
+ .await
+ .unwrap();
+
+ let committed_gce_info = match processed_proposals
+ .1
+ .applied_proposals
+ .clone()
+ .into_proposals()
+ .collect_vec()
+ .as_slice()
+ {
+ [gce_info] => gce_info.clone(),
+ _ => panic!("committed proposals list does not contain exactly one reference"),
+ };
+
+ assert!(gce_info.contains(&committed_gce_info));
+
+ #[cfg(feature = "state_update")]
+ assert_matches!(
+ &*processed_proposals.1.unused_proposals,
+ [rejected_gce_info] if committed_gce_info != *rejected_gce_info && gce_info.contains(rejected_gce_info)
+ );
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn make_external_senders_extension() -> ExtensionList {
+ let identity = get_test_signing_identity(TEST_CIPHER_SUITE, b"alice")
+ .await
+ .0;
+
+ vec![ExternalSendersExt::new(vec![identity])
+ .into_extension()
+ .unwrap()]
+ .into()
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_invalid_external_senders_extension_fails() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .with_identity_provider(FailureIdentityProvider::new())
+ .receive([Proposal::GroupContextExtensions(
+ make_external_senders_extension().await,
+ )])
+ .await;
+
+ assert_matches!(res, Err(MlsError::IdentityProviderError(_)));
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_additional_invalid_external_senders_extension_fails() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_identity_provider(FailureIdentityProvider::new())
+ .with_additional([Proposal::GroupContextExtensions(
+ make_external_senders_extension().await,
+ )])
+ .send()
+ .await;
+
+ assert_matches!(res, Err(MlsError::IdentityProviderError(_)));
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_invalid_external_senders_extension_filters_it_out() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let proposal = Proposal::GroupContextExtensions(make_external_senders_extension().await);
+
+ let proposal_info = make_proposal_info(&proposal, alice).await;
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_identity_provider(FailureIdentityProvider::new())
+ .cache(
+ proposal_info.proposal_ref().unwrap().clone(),
+ proposal.clone(),
+ alice,
+ )
+ .send()
+ .await
+ .unwrap();
+
+ assert_eq!(processed_proposals.0, Vec::new());
+
+ #[cfg(feature = "state_update")]
+ assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_reinit_with_other_proposals_fails() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .receive([
+ Proposal::ReInit(make_reinit(TEST_PROTOCOL_VERSION)),
+ Proposal::Add(make_add_proposal().await),
+ ])
+ .await;
+
+ assert_matches!(res, Err(MlsError::OtherProposalWithReInit));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_additional_reinit_with_other_proposals_fails() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_additional([
+ Proposal::ReInit(make_reinit(TEST_PROTOCOL_VERSION)),
+ Proposal::Add(make_add_proposal().await),
+ ])
+ .send()
+ .await;
+
+ assert_matches!(res, Err(MlsError::OtherProposalWithReInit));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_reinit_with_other_proposals_filters_it_out() {
+ let (alice, tree) = new_tree("alice").await;
+ let reinit = Proposal::ReInit(make_reinit(TEST_PROTOCOL_VERSION));
+ let reinit_info = make_proposal_info(&reinit, alice).await;
+ let add = Proposal::Add(make_add_proposal().await);
+ let add_ref = make_proposal_ref(&add, alice).await;
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .cache(
+ reinit_info.proposal_ref().unwrap().clone(),
+ reinit.clone(),
+ alice,
+ )
+ .cache(add_ref.clone(), add, alice)
+ .send()
+ .await
+ .unwrap();
+
+ assert_eq!(processed_proposals.0, vec![add_ref.into()]);
+
+ #[cfg(feature = "state_update")]
+ assert_eq!(processed_proposals.1.unused_proposals, vec![reinit_info]);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_multiple_reinits_fails() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .receive([
+ Proposal::ReInit(make_reinit(TEST_PROTOCOL_VERSION)),
+ Proposal::ReInit(make_reinit(TEST_PROTOCOL_VERSION)),
+ ])
+ .await;
+
+ assert_matches!(res, Err(MlsError::OtherProposalWithReInit));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_additional_multiple_reinits_fails() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_additional([
+ Proposal::ReInit(make_reinit(TEST_PROTOCOL_VERSION)),
+ Proposal::ReInit(make_reinit(TEST_PROTOCOL_VERSION)),
+ ])
+ .send()
+ .await;
+
+ assert_matches!(res, Err(MlsError::OtherProposalWithReInit));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_multiple_reinits_keeps_only_one() {
+ let (alice, tree) = new_tree("alice").await;
+ let reinit = Proposal::ReInit(make_reinit(TEST_PROTOCOL_VERSION));
+ let reinit_ref = make_proposal_ref(&reinit, alice).await;
+ let other_reinit = Proposal::ReInit(ReInitProposal {
+ group_id: b"other_group".to_vec(),
+ ..make_reinit(TEST_PROTOCOL_VERSION)
+ });
+ let other_reinit_ref = make_proposal_ref(&other_reinit, alice).await;
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .cache(reinit_ref.clone(), reinit.clone(), alice)
+ .cache(other_reinit_ref.clone(), other_reinit.clone(), alice)
+ .send()
+ .await
+ .unwrap();
+
+ let processed_ref = match &*processed_proposals.0 {
+ [ProposalOrRef::Reference(r)] => r,
+ p => panic!("Expected single proposal reference but found {p:?}"),
+ };
+
+ assert!(*processed_ref == reinit_ref || *processed_ref == other_reinit_ref);
+
+ #[cfg(feature = "state_update")]
+ {
+ let (rejected_ref, unused_proposal) = match &*processed_proposals.1.unused_proposals {
+ [r] => (r.proposal_ref().unwrap().clone(), r.proposal.clone()),
+ p => panic!("Expected single proposal but found {p:?}"),
+ };
+
+ assert_ne!(rejected_ref, *processed_ref);
+ assert!(rejected_ref == reinit_ref || rejected_ref == other_reinit_ref);
+ assert!(unused_proposal == reinit || unused_proposal == other_reinit);
+ }
+ }
+
+ fn make_external_init() -> ExternalInit {
+ ExternalInit {
+ kem_output: vec![33; test_cipher_suite_provider(TEST_CIPHER_SUITE).kdf_extract_size()],
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_external_init_from_member_fails() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .receive([Proposal::ExternalInit(make_external_init())])
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidProposalTypeForSender));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_additional_external_init_from_member_fails() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_additional([Proposal::ExternalInit(make_external_init())])
+ .send()
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidProposalTypeForSender));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_external_init_from_member_filters_it_out() {
+ let (alice, tree) = new_tree("alice").await;
+ let external_init = Proposal::ExternalInit(make_external_init());
+ let external_init_info = make_proposal_info(&external_init, alice).await;
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .cache(
+ external_init_info.proposal_ref().unwrap().clone(),
+ external_init.clone(),
+ alice,
+ )
+ .send()
+ .await
+ .unwrap();
+
+ assert_eq!(processed_proposals.0, Vec::new());
+
+ #[cfg(feature = "state_update")]
+ assert_eq!(
+ processed_proposals.1.unused_proposals,
+ vec![external_init_info]
+ );
+ }
+
+ fn required_capabilities_proposal(extension: u16) -> Proposal {
+ let required_capabilities = RequiredCapabilitiesExt {
+ extensions: vec![extension.into()],
+ ..Default::default()
+ };
+
+ let ext = vec![required_capabilities.into_extension().unwrap()];
+
+ Proposal::GroupContextExtensions(ext.into())
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_required_capabilities_not_supported_by_member_fails() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .receive([required_capabilities_proposal(33)])
+ .await;
+
+ assert_matches!(
+ res,
+ Err(MlsError::RequiredExtensionNotFound(v)) if v == 33.into()
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_required_capabilities_not_supported_by_member_fails() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_additional([required_capabilities_proposal(33)])
+ .send()
+ .await;
+
+ assert_matches!(
+ res,
+ Err(MlsError::RequiredExtensionNotFound(v)) if v == 33.into()
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_additional_required_capabilities_not_supported_by_member_filters_it_out() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let proposal = required_capabilities_proposal(33);
+ let proposal_info = make_proposal_info(&proposal, alice).await;
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .cache(
+ proposal_info.proposal_ref().unwrap().clone(),
+ proposal.clone(),
+ alice,
+ )
+ .send()
+ .await
+ .unwrap();
+
+ assert_eq!(processed_proposals.0, Vec::new());
+
+ #[cfg(feature = "state_update")]
+ assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn committing_update_from_pk1_to_pk2_and_update_from_pk2_to_pk3_works() {
+ let (alice_leaf, alice_secret, alice_signer) =
+ get_basic_test_node_sig_key(TEST_CIPHER_SUITE, "alice").await;
+
+ let (mut tree, priv_tree) = TreeKemPublic::derive(
+ alice_leaf.clone(),
+ alice_secret,
+ &BasicIdentityProvider,
+ &Default::default(),
+ )
+ .await
+ .unwrap();
+
+ let alice = priv_tree.self_index;
+
+ let bob = add_member(&mut tree, "bob").await;
+ let carol = add_member(&mut tree, "carol").await;
+
+ let bob_current_leaf = tree.get_leaf_node(bob).unwrap();
+
+ let mut alice_new_leaf = LeafNode {
+ public_key: bob_current_leaf.public_key.clone(),
+ leaf_node_source: LeafNodeSource::Update,
+ ..alice_leaf
+ };
+
+ alice_new_leaf
+ .sign(
+ &test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ &alice_signer,
+ &(TEST_GROUP, 0).into(),
+ )
+ .await
+ .unwrap();
+
+ let bob_new_leaf = update_leaf_node("bob", 1).await;
+
+ let pk1_to_pk2 = Proposal::Update(UpdateProposal {
+ leaf_node: alice_new_leaf.clone(),
+ });
+
+ let pk1_to_pk2_ref = make_proposal_ref(&pk1_to_pk2, alice).await;
+
+ let pk2_to_pk3 = Proposal::Update(UpdateProposal {
+ leaf_node: bob_new_leaf.clone(),
+ });
+
+ let pk2_to_pk3_ref = make_proposal_ref(&pk2_to_pk3, bob).await;
+
+ let effects = CommitReceiver::new(
+ &tree,
+ carol,
+ carol,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .cache(pk1_to_pk2_ref.clone(), pk1_to_pk2, alice)
+ .cache(pk2_to_pk3_ref.clone(), pk2_to_pk3, bob)
+ .receive([pk1_to_pk2_ref, pk2_to_pk3_ref])
+ .await
+ .unwrap();
+
+ assert_eq!(effects.applied_proposals.update_senders, vec![alice, bob]);
+
+ assert_eq!(
+ effects
+ .applied_proposals
+ .updates
+ .into_iter()
+ .map(|p| p.proposal.leaf_node)
+ .collect_vec(),
+ vec![alice_new_leaf, bob_new_leaf]
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn committing_update_from_pk1_to_pk2_and_removal_of_pk2_works() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ let (alice_leaf, alice_secret, alice_signer) =
+ get_basic_test_node_sig_key(TEST_CIPHER_SUITE, "alice").await;
+
+ let (mut tree, priv_tree) = TreeKemPublic::derive(
+ alice_leaf.clone(),
+ alice_secret,
+ &BasicIdentityProvider,
+ &Default::default(),
+ )
+ .await
+ .unwrap();
+
+ let alice = priv_tree.self_index;
+
+ let bob = add_member(&mut tree, "bob").await;
+ let carol = add_member(&mut tree, "carol").await;
+
+ let bob_current_leaf = tree.get_leaf_node(bob).unwrap();
+
+ let mut alice_new_leaf = LeafNode {
+ public_key: bob_current_leaf.public_key.clone(),
+ leaf_node_source: LeafNodeSource::Update,
+ ..alice_leaf
+ };
+
+ alice_new_leaf
+ .sign(
+ &cipher_suite_provider,
+ &alice_signer,
+ &(TEST_GROUP, 0).into(),
+ )
+ .await
+ .unwrap();
+
+ let pk1_to_pk2 = Proposal::Update(UpdateProposal {
+ leaf_node: alice_new_leaf.clone(),
+ });
+
+ let pk1_to_pk2_ref = make_proposal_ref(&pk1_to_pk2, alice).await;
+
+ let remove_pk2 = Proposal::Remove(RemoveProposal { to_remove: bob });
+
+ let remove_pk2_ref = make_proposal_ref(&remove_pk2, bob).await;
+
+ let effects = CommitReceiver::new(
+ &tree,
+ carol,
+ carol,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .cache(pk1_to_pk2_ref.clone(), pk1_to_pk2, alice)
+ .cache(remove_pk2_ref.clone(), remove_pk2, bob)
+ .receive([pk1_to_pk2_ref, remove_pk2_ref])
+ .await
+ .unwrap();
+
+ assert_eq!(effects.applied_proposals.update_senders, vec![alice]);
+
+ assert_eq!(
+ effects
+ .applied_proposals
+ .updates
+ .into_iter()
+ .map(|p| p.proposal.leaf_node)
+ .collect_vec(),
+ vec![alice_new_leaf]
+ );
+
+ assert_eq!(
+ effects
+ .applied_proposals
+ .removals
+ .into_iter()
+ .map(|p| p.proposal.to_remove)
+ .collect_vec(),
+ vec![bob]
+ );
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn unsupported_credential_key_package(name: &str) -> KeyPackage {
+ let (mut signing_identity, secret_key) =
+ get_test_signing_identity(TEST_CIPHER_SUITE, name.as_bytes()).await;
+
+ signing_identity.credential = Credential::Custom(CustomCredential::new(
+ CredentialType::new(BasicWithCustomProvider::CUSTOM_CREDENTIAL_TYPE),
+ random_bytes(32),
+ ));
+
+ let generator = KeyPackageGenerator {
+ protocol_version: TEST_PROTOCOL_VERSION,
+ cipher_suite_provider: &test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ signing_identity: &signing_identity,
+ signing_key: &secret_key,
+ identity_provider: &BasicWithCustomProvider::new(BasicIdentityProvider::new()),
+ };
+
+ generator
+ .generate(
+ Lifetime::years(1).unwrap(),
+ Capabilities {
+ credentials: vec![42.into()],
+ ..Default::default()
+ },
+ Default::default(),
+ Default::default(),
+ )
+ .await
+ .unwrap()
+ .key_package
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_add_with_leaf_not_supporting_credential_type_of_other_leaf_fails() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .receive([Proposal::Add(Box::new(AddProposal {
+ key_package: unsupported_credential_key_package("bob").await,
+ }))])
+ .await;
+
+ assert_matches!(res, Err(MlsError::InUseCredentialTypeUnsupportedByNewLeaf));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_additional_add_with_leaf_not_supporting_credential_type_of_other_leaf_fails() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_additional([Proposal::Add(Box::new(AddProposal {
+ key_package: unsupported_credential_key_package("bob").await,
+ }))])
+ .send()
+ .await;
+
+ assert_matches!(res, Err(MlsError::InUseCredentialTypeUnsupportedByNewLeaf));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_add_with_leaf_not_supporting_credential_type_of_other_leaf_filters_it_out() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let add = Proposal::Add(Box::new(AddProposal {
+ key_package: unsupported_credential_key_package("bob").await,
+ }));
+
+ let add_info = make_proposal_info(&add, alice).await;
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .cache(add_info.proposal_ref().unwrap().clone(), add.clone(), alice)
+ .send()
+ .await
+ .unwrap();
+
+ assert_eq!(processed_proposals.0, Vec::new());
+
+ #[cfg(feature = "state_update")]
+ assert_eq!(processed_proposals.1.unused_proposals, vec![add_info]);
+ }
+
+ #[cfg(feature = "custom_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_custom_proposal_with_member_not_supporting_proposal_type_fails() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let custom_proposal = Proposal::Custom(CustomProposal::new(ProposalType::new(42), vec![]));
+
+ let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_additional([custom_proposal.clone()])
+ .send()
+ .await;
+
+ assert_matches!(
+ res,
+ Err(
+ MlsError::UnsupportedCustomProposal(c)
+ ) if c == custom_proposal.proposal_type()
+ );
+ }
+
+ #[cfg(feature = "custom_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_custom_proposal_with_member_not_supporting_filters_it_out() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let custom_proposal = Proposal::Custom(CustomProposal::new(ProposalType::new(42), vec![]));
+
+ let custom_info = make_proposal_info(&custom_proposal, alice).await;
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .cache(
+ custom_info.proposal_ref().unwrap().clone(),
+ custom_proposal.clone(),
+ alice,
+ )
+ .send()
+ .await
+ .unwrap();
+
+ assert_eq!(processed_proposals.0, Vec::new());
+
+ #[cfg(feature = "state_update")]
+ assert_eq!(processed_proposals.1.unused_proposals, vec![custom_info]);
+ }
+
+ #[cfg(feature = "custom_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_custom_proposal_with_member_not_supporting_fails() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let custom_proposal = Proposal::Custom(CustomProposal::new(ProposalType::new(42), vec![]));
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .receive([custom_proposal.clone()])
+ .await;
+
+ assert_matches!(
+ res,
+ Err(MlsError::UnsupportedCustomProposal(c)) if c == custom_proposal.proposal_type()
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_group_extension_unsupported_by_leaf_fails() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .receive([Proposal::GroupContextExtensions(make_extension_list(0))])
+ .await;
+
+ assert_matches!(
+ res,
+ Err(
+ MlsError::UnsupportedGroupExtension(v)
+ ) if v == 42.into()
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_additional_group_extension_unsupported_by_leaf_fails() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_additional([Proposal::GroupContextExtensions(make_extension_list(0))])
+ .send()
+ .await;
+
+ assert_matches!(
+ res,
+ Err(
+ MlsError::UnsupportedGroupExtension(v)
+ ) if v == 42.into()
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_group_extension_unsupported_by_leaf_filters_it_out() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let proposal = Proposal::GroupContextExtensions(make_extension_list(0));
+ let proposal_info = make_proposal_info(&proposal, alice).await;
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .cache(
+ proposal_info.proposal_ref().unwrap().clone(),
+ proposal.clone(),
+ alice,
+ )
+ .send()
+ .await
+ .unwrap();
+
+ assert_eq!(processed_proposals.0, Vec::new());
+
+ #[cfg(feature = "state_update")]
+ assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
+ }
+
+ #[cfg(feature = "psk")]
+ #[derive(Debug)]
+ struct AlwaysNotFoundPskStorage;
+
+ #[cfg(feature = "psk")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+ impl PreSharedKeyStorage for AlwaysNotFoundPskStorage {
+ type Error = Infallible;
+
+ async fn get(&self, _: &ExternalPskId) -> Result<Option<PreSharedKey>, Self::Error> {
+ Ok(None)
+ }
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn receiving_external_psk_with_unknown_id_fails() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .with_psk_storage(AlwaysNotFoundPskStorage)
+ .receive([Proposal::Psk(new_external_psk(b"abc"))])
+ .await;
+
+ assert_matches!(res, Err(MlsError::MissingRequiredPsk));
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_additional_external_psk_with_unknown_id_fails() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_psk_storage(AlwaysNotFoundPskStorage)
+ .with_additional([Proposal::Psk(new_external_psk(b"abc"))])
+ .send()
+ .await;
+
+ assert_matches!(res, Err(MlsError::MissingRequiredPsk));
+ }
+
+ #[cfg(feature = "psk")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn sending_external_psk_with_unknown_id_filters_it_out() {
+ let (alice, tree) = new_tree("alice").await;
+ let proposal = Proposal::Psk(new_external_psk(b"abc"));
+ let proposal_info = make_proposal_info(&proposal, alice).await;
+
+ let processed_proposals =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_psk_storage(AlwaysNotFoundPskStorage)
+ .cache(
+ proposal_info.proposal_ref().unwrap().clone(),
+ proposal.clone(),
+ alice,
+ )
+ .send()
+ .await
+ .unwrap();
+
+ assert_eq!(processed_proposals.0, Vec::new());
+
+ #[cfg(feature = "state_update")]
+ assert_eq!(processed_proposals.1.unused_proposals, vec![proposal_info]);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn user_defined_filter_can_remove_proposals() {
+ struct RemoveGroupContextExtensions;
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+ impl MlsRules for RemoveGroupContextExtensions {
+ type Error = Infallible;
+
+ async fn filter_proposals(
+ &self,
+ _: CommitDirection,
+ _: CommitSource,
+ _: &Roster,
+ _: &ExtensionList,
+ mut proposals: ProposalBundle,
+ ) -> Result<ProposalBundle, Self::Error> {
+ proposals.group_context_extensions.clear();
+ Ok(proposals)
+ }
+
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn commit_options(
+ &self,
+ _: &Roster,
+ _: &ExtensionList,
+ _: &ProposalBundle,
+ ) -> Result<CommitOptions, Self::Error> {
+ Ok(Default::default())
+ }
+
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn encryption_options(
+ &self,
+ _: &Roster,
+ _: &ExtensionList,
+ ) -> Result<EncryptionOptions, Self::Error> {
+ Ok(Default::default())
+ }
+ }
+
+ let (alice, tree) = new_tree("alice").await;
+
+ let (committed, _) =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_additional([Proposal::GroupContextExtensions(Default::default())])
+ .with_user_rules(RemoveGroupContextExtensions)
+ .send()
+ .await
+ .unwrap();
+
+ assert_eq!(committed, Vec::new());
+ }
+
+ struct FailureMlsRules;
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+ impl MlsRules for FailureMlsRules {
+ type Error = MlsError;
+
+ async fn filter_proposals(
+ &self,
+ _: CommitDirection,
+ _: CommitSource,
+ _: &Roster,
+ _: &ExtensionList,
+ _: ProposalBundle,
+ ) -> Result<ProposalBundle, Self::Error> {
+ Err(MlsError::InvalidSignature)
+ }
+
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn commit_options(
+ &self,
+ _: &Roster,
+ _: &ExtensionList,
+ _: &ProposalBundle,
+ ) -> Result<CommitOptions, Self::Error> {
+ Ok(Default::default())
+ }
+
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn encryption_options(
+ &self,
+ _: &Roster,
+ _: &ExtensionList,
+ ) -> Result<EncryptionOptions, Self::Error> {
+ Ok(Default::default())
+ }
+ }
+
+ struct InjectMlsRules {
+ to_inject: Proposal,
+ source: ProposalSource,
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+ impl MlsRules for InjectMlsRules {
+ type Error = MlsError;
+
+ async fn filter_proposals(
+ &self,
+ _: CommitDirection,
+ _: CommitSource,
+ _: &Roster,
+ _: &ExtensionList,
+ mut proposals: ProposalBundle,
+ ) -> Result<ProposalBundle, Self::Error> {
+ proposals.add(
+ self.to_inject.clone(),
+ Sender::Member(0),
+ self.source.clone(),
+ );
+ Ok(proposals)
+ }
+
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn commit_options(
+ &self,
+ _: &Roster,
+ _: &ExtensionList,
+ _: &ProposalBundle,
+ ) -> Result<CommitOptions, Self::Error> {
+ Ok(Default::default())
+ }
+
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn encryption_options(
+ &self,
+ _: &Roster,
+ _: &ExtensionList,
+ ) -> Result<EncryptionOptions, Self::Error> {
+ Ok(Default::default())
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn user_defined_filter_can_inject_proposals() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let test_proposal = Proposal::GroupContextExtensions(Default::default());
+
+ let (committed, _) =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_user_rules(InjectMlsRules {
+ to_inject: test_proposal.clone(),
+ source: ProposalSource::ByValue,
+ })
+ .send()
+ .await
+ .unwrap();
+
+ assert_eq!(
+ committed,
+ vec![ProposalOrRef::Proposal(test_proposal.into())]
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn user_defined_filter_can_inject_local_only_proposals() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let test_proposal = Proposal::GroupContextExtensions(Default::default());
+
+ let (committed, _) =
+ CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_user_rules(InjectMlsRules {
+ to_inject: test_proposal.clone(),
+ source: ProposalSource::Local,
+ })
+ .send()
+ .await
+ .unwrap();
+
+ assert_eq!(committed, vec![]);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn user_defined_filter_cant_break_base_rules() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let test_proposal = Proposal::Update(UpdateProposal {
+ leaf_node: get_basic_test_node(TEST_CIPHER_SUITE, "leaf").await,
+ });
+
+ let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_user_rules(InjectMlsRules {
+ to_inject: test_proposal.clone(),
+ source: ProposalSource::ByValue,
+ })
+ .send()
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidProposalTypeForSender { .. }))
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn user_defined_filter_can_refuse_to_send_commit() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitSender::new(&tree, alice, test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .with_additional([Proposal::GroupContextExtensions(Default::default())])
+ .with_user_rules(FailureMlsRules)
+ .send()
+ .await;
+
+ assert_matches!(res, Err(MlsError::MlsRulesError(_)));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn user_defined_filter_can_reject_incoming_commit() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let res = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .with_user_rules(FailureMlsRules)
+ .receive([Proposal::GroupContextExtensions(Default::default())])
+ .await;
+
+ assert_matches!(res, Err(MlsError::MlsRulesError(_)));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn proposers_are_verified() {
+ let (alice, mut tree) = new_tree("alice").await;
+ let bob = add_member(&mut tree, "bob").await;
+
+ #[cfg(feature = "by_ref_proposal")]
+ let identity = get_test_signing_identity(TEST_CIPHER_SUITE, b"carol")
+ .await
+ .0;
+
+ #[cfg(feature = "by_ref_proposal")]
+ let external_senders = ExternalSendersExt::new(vec![identity]);
+
+ let proposals: &[Proposal] = &[
+ Proposal::Add(make_add_proposal().await),
+ Proposal::Update(make_update_proposal("alice").await),
+ Proposal::Remove(RemoveProposal { to_remove: bob }),
+ #[cfg(feature = "psk")]
+ Proposal::Psk(make_external_psk(
+ b"ted",
+ PskNonce::random(&test_cipher_suite_provider(TEST_CIPHER_SUITE)).unwrap(),
+ )),
+ Proposal::ReInit(make_reinit(TEST_PROTOCOL_VERSION)),
+ Proposal::ExternalInit(make_external_init()),
+ Proposal::GroupContextExtensions(Default::default()),
+ ];
+
+ let proposers = [
+ Sender::Member(*alice),
+ #[cfg(feature = "by_ref_proposal")]
+ Sender::External(0),
+ Sender::NewMemberCommit,
+ Sender::NewMemberProposal,
+ ];
+
+ for ((proposer, proposal), by_ref) in proposers
+ .into_iter()
+ .cartesian_product(proposals)
+ .cartesian_product([true])
+ {
+ let committer = Sender::Member(*alice);
+
+ let receiver = CommitReceiver::new(
+ &tree,
+ committer,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ );
+
+ #[cfg(feature = "by_ref_proposal")]
+ let extensions: ExtensionList =
+ vec![external_senders.clone().into_extension().unwrap()].into();
+
+ #[cfg(feature = "by_ref_proposal")]
+ let receiver = receiver.with_extensions(extensions);
+
+ let (receiver, proposals, proposer) = if by_ref {
+ let proposal_ref = make_proposal_ref(proposal, proposer).await;
+ let receiver = receiver.cache(proposal_ref.clone(), proposal.clone(), proposer);
+ (receiver, vec![ProposalOrRef::from(proposal_ref)], proposer)
+ } else {
+ (receiver, vec![proposal.clone().into()], committer)
+ };
+
+ let res = receiver.receive(proposals).await;
+
+ if proposer_can_propose(proposer, proposal.proposal_type(), by_ref).is_err() {
+ assert_matches!(res, Err(MlsError::InvalidProposalTypeForSender));
+ } else {
+ let is_self_update = proposal.proposal_type() == ProposalType::UPDATE
+ && by_ref
+ && matches!(proposer, Sender::Member(_));
+
+ if !is_self_update {
+ res.unwrap();
+ }
+ }
+ }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn make_update_proposal(name: &str) -> UpdateProposal {
+ UpdateProposal {
+ leaf_node: update_leaf_node(name, 1).await,
+ }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn make_update_proposal_custom(name: &str, leaf_index: u32) -> UpdateProposal {
+ UpdateProposal {
+ leaf_node: update_leaf_node(name, leaf_index).await,
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn when_receiving_commit_unused_proposals_are_proposals_in_cache_but_not_in_commit() {
+ let (alice, tree) = new_tree("alice").await;
+
+ let proposal = Proposal::GroupContextExtensions(Default::default());
+ let proposal_ref = make_proposal_ref(&proposal, alice).await;
+
+ let state = CommitReceiver::new(
+ &tree,
+ alice,
+ alice,
+ test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .cache(proposal_ref.clone(), proposal, alice)
+ .receive([Proposal::Add(Box::new(AddProposal {
+ key_package: test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "bob").await,
+ }))])
+ .await
+ .unwrap();
+
+ let [p] = &state.unused_proposals[..] else {
+ panic!(
+ "Expected single unused proposal but got {:?}",
+ state.unused_proposals
+ );
+ };
+
+ assert_eq!(p.proposal_ref(), Some(&proposal_ref));
+ }
+}
diff --git a/src/group/proposal_filter.rs b/src/group/proposal_filter.rs
new file mode 100644
index 0000000..5ef6b20
--- /dev/null
+++ b/src/group/proposal_filter.rs
@@ -0,0 +1,23 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+mod bundle;
+mod filtering_common;
+
+#[cfg(feature = "by_ref_proposal")]
+mod filtering;
+#[cfg(not(feature = "by_ref_proposal"))]
+pub mod filtering_lite;
+#[cfg(all(feature = "custom_proposal", not(feature = "by_ref_proposal")))]
+use filtering_lite as filtering;
+
+pub use bundle::{ProposalBundle, ProposalInfo, ProposalSource};
+
+#[cfg(feature = "by_ref_proposal")]
+pub(crate) use filtering::FilterStrategy;
+
+pub(crate) use filtering_common::ProposalApplier;
+
+#[cfg(all(feature = "by_ref_proposal", test))]
+pub(crate) use filtering::proposer_can_propose;
diff --git a/src/group/proposal_filter/bundle.rs b/src/group/proposal_filter/bundle.rs
new file mode 100644
index 0000000..f18a75b
--- /dev/null
+++ b/src/group/proposal_filter/bundle.rs
@@ -0,0 +1,633 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::boxed::Box;
+use alloc::vec::Vec;
+
+#[cfg(feature = "custom_proposal")]
+use itertools::Itertools;
+
+use crate::{
+ group::{
+ AddProposal, BorrowedProposal, Proposal, ProposalOrRef, ProposalType, ReInitProposal,
+ RemoveProposal, Sender,
+ },
+ ExtensionList,
+};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::group::{proposal_cache::CachedProposal, LeafIndex, ProposalRef, UpdateProposal};
+
+#[cfg(feature = "psk")]
+use crate::group::PreSharedKeyProposal;
+
+#[cfg(feature = "custom_proposal")]
+use crate::group::proposal::CustomProposal;
+
+use crate::group::ExternalInit;
+
+use core::iter::empty;
+
+#[derive(Clone, Debug, Default)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// A collection of proposals.
+pub struct ProposalBundle {
+ pub(crate) additions: Vec<ProposalInfo<AddProposal>>,
+ #[cfg(feature = "by_ref_proposal")]
+ pub(crate) updates: Vec<ProposalInfo<UpdateProposal>>,
+ #[cfg(feature = "by_ref_proposal")]
+ pub(crate) update_senders: Vec<LeafIndex>,
+ pub(crate) removals: Vec<ProposalInfo<RemoveProposal>>,
+ #[cfg(feature = "psk")]
+ pub(crate) psks: Vec<ProposalInfo<PreSharedKeyProposal>>,
+ pub(crate) reinitializations: Vec<ProposalInfo<ReInitProposal>>,
+ pub(crate) external_initializations: Vec<ProposalInfo<ExternalInit>>,
+ pub(crate) group_context_extensions: Vec<ProposalInfo<ExtensionList>>,
+ #[cfg(feature = "custom_proposal")]
+ pub(crate) custom_proposals: Vec<ProposalInfo<CustomProposal>>,
+}
+
+impl ProposalBundle {
+ pub fn add(&mut self, proposal: Proposal, sender: Sender, source: ProposalSource) {
+ match proposal {
+ Proposal::Add(proposal) => self.additions.push(ProposalInfo {
+ proposal: *proposal,
+ sender,
+ source,
+ }),
+ #[cfg(feature = "by_ref_proposal")]
+ Proposal::Update(proposal) => self.updates.push(ProposalInfo {
+ proposal,
+ sender,
+ source,
+ }),
+ Proposal::Remove(proposal) => self.removals.push(ProposalInfo {
+ proposal,
+ sender,
+ source,
+ }),
+ #[cfg(feature = "psk")]
+ Proposal::Psk(proposal) => self.psks.push(ProposalInfo {
+ proposal,
+ sender,
+ source,
+ }),
+ Proposal::ReInit(proposal) => self.reinitializations.push(ProposalInfo {
+ proposal,
+ sender,
+ source,
+ }),
+ Proposal::ExternalInit(proposal) => self.external_initializations.push(ProposalInfo {
+ proposal,
+ sender,
+ source,
+ }),
+ Proposal::GroupContextExtensions(proposal) => {
+ self.group_context_extensions.push(ProposalInfo {
+ proposal,
+ sender,
+ source,
+ })
+ }
+ #[cfg(feature = "custom_proposal")]
+ Proposal::Custom(proposal) => self.custom_proposals.push(ProposalInfo {
+ proposal,
+ sender,
+ source,
+ }),
+ }
+ }
+
+ /// Remove the proposal of type `T` at `index`
+ ///
+ /// Type `T` can be any of the standard MLS proposal types defined in the
+ /// [`proposal`](crate::group::proposal) module.
+ ///
+ /// `index` is consistent with the index returned by any of the proposal
+ /// type specific functions in this module.
+ pub fn remove<T: Proposable>(&mut self, index: usize) {
+ T::remove(self, index);
+ }
+
+ /// Iterate over proposals, filtered by type.
+ ///
+ /// Type `T` can be any of the standard MLS proposal types defined in the
+ /// [`proposal`](crate::group::proposal) module.
+ pub fn by_type<'a, T: Proposable + 'a>(&'a self) -> impl Iterator<Item = &'a ProposalInfo<T>> {
+ T::filter(self).iter()
+ }
+
+ /// Retain proposals, filtered by type.
+ ///
+ /// Type `T` can be any of the standard MLS proposal types defined in the
+ /// [`proposal`](crate::group::proposal) module.
+ pub fn retain_by_type<T, F, E>(&mut self, mut f: F) -> Result<(), E>
+ where
+ T: Proposable,
+ F: FnMut(&ProposalInfo<T>) -> Result<bool, E>,
+ {
+ let mut res = Ok(());
+
+ T::retain(self, |p| match f(p) {
+ Ok(keep) => keep,
+ Err(e) => {
+ if res.is_ok() {
+ res = Err(e);
+ }
+ false
+ }
+ });
+
+ res
+ }
+
+ /// Retain custom proposals in the bundle.
+ #[cfg(feature = "custom_proposal")]
+ pub fn retain_custom<F, E>(&mut self, mut f: F) -> Result<(), E>
+ where
+ F: FnMut(&ProposalInfo<CustomProposal>) -> Result<bool, E>,
+ {
+ let mut res = Ok(());
+
+ self.custom_proposals.retain(|p| match f(p) {
+ Ok(keep) => keep,
+ Err(e) => {
+ if res.is_ok() {
+ res = Err(e);
+ }
+ false
+ }
+ });
+
+ res
+ }
+
+ /// Retain MLS standard proposals in the bundle.
+ pub fn retain<F, E>(&mut self, mut f: F) -> Result<(), E>
+ where
+ F: FnMut(&ProposalInfo<BorrowedProposal<'_>>) -> Result<bool, E>,
+ {
+ self.retain_by_type::<AddProposal, _, _>(|proposal| {
+ f(&proposal.as_ref().map(BorrowedProposal::from))
+ })?;
+
+ #[cfg(feature = "by_ref_proposal")]
+ self.retain_by_type::<UpdateProposal, _, _>(|proposal| {
+ f(&proposal.as_ref().map(BorrowedProposal::from))
+ })?;
+
+ self.retain_by_type::<RemoveProposal, _, _>(|proposal| {
+ f(&proposal.as_ref().map(BorrowedProposal::from))
+ })?;
+
+ #[cfg(feature = "psk")]
+ self.retain_by_type::<PreSharedKeyProposal, _, _>(|proposal| {
+ f(&proposal.as_ref().map(BorrowedProposal::from))
+ })?;
+
+ self.retain_by_type::<ReInitProposal, _, _>(|proposal| {
+ f(&proposal.as_ref().map(BorrowedProposal::from))
+ })?;
+
+ self.retain_by_type::<ExternalInit, _, _>(|proposal| {
+ f(&proposal.as_ref().map(BorrowedProposal::from))
+ })?;
+
+ self.retain_by_type::<ExtensionList, _, _>(|proposal| {
+ f(&proposal.as_ref().map(BorrowedProposal::from))
+ })?;
+
+ Ok(())
+ }
+
+ /// The number of proposals in the bundle
+ pub fn length(&self) -> usize {
+ let len = 0;
+
+ #[cfg(feature = "psk")]
+ let len = len + self.psks.len();
+
+ let len = len + self.external_initializations.len();
+
+ #[cfg(feature = "custom_proposal")]
+ let len = len + self.custom_proposals.len();
+
+ #[cfg(feature = "by_ref_proposal")]
+ let len = len + self.updates.len();
+
+ len + self.additions.len()
+ + self.removals.len()
+ + self.reinitializations.len()
+ + self.group_context_extensions.len()
+ }
+
+ /// Iterate over all proposals inside the bundle.
+ pub fn iter_proposals(&self) -> impl Iterator<Item = ProposalInfo<BorrowedProposal<'_>>> {
+ let res = self
+ .additions
+ .iter()
+ .map(|p| p.as_ref().map(BorrowedProposal::Add))
+ .chain(
+ self.removals
+ .iter()
+ .map(|p| p.as_ref().map(BorrowedProposal::Remove)),
+ )
+ .chain(
+ self.reinitializations
+ .iter()
+ .map(|p| p.as_ref().map(BorrowedProposal::ReInit)),
+ );
+
+ #[cfg(feature = "by_ref_proposal")]
+ let res = res.chain(
+ self.updates
+ .iter()
+ .map(|p| p.as_ref().map(BorrowedProposal::Update)),
+ );
+
+ #[cfg(feature = "psk")]
+ let res = res.chain(
+ self.psks
+ .iter()
+ .map(|p| p.as_ref().map(BorrowedProposal::Psk)),
+ );
+
+ let res = res.chain(
+ self.external_initializations
+ .iter()
+ .map(|p| p.as_ref().map(BorrowedProposal::ExternalInit)),
+ );
+
+ let res = res.chain(
+ self.group_context_extensions
+ .iter()
+ .map(|p| p.as_ref().map(BorrowedProposal::GroupContextExtensions)),
+ );
+
+ #[cfg(feature = "custom_proposal")]
+ let res = res.chain(
+ self.custom_proposals
+ .iter()
+ .map(|p| p.as_ref().map(BorrowedProposal::Custom)),
+ );
+
+ res
+ }
+
+ /// Iterate over proposal in the bundle, consuming the bundle.
+ pub fn into_proposals(self) -> impl Iterator<Item = ProposalInfo<Proposal>> {
+ let res = empty();
+
+ #[cfg(feature = "custom_proposal")]
+ let res = res.chain(
+ self.custom_proposals
+ .into_iter()
+ .map(|p| p.map(Proposal::Custom)),
+ );
+
+ let res = res.chain(
+ self.external_initializations
+ .into_iter()
+ .map(|p| p.map(Proposal::ExternalInit)),
+ );
+
+ #[cfg(feature = "psk")]
+ let res = res.chain(self.psks.into_iter().map(|p| p.map(Proposal::Psk)));
+
+ #[cfg(feature = "by_ref_proposal")]
+ let res = res.chain(self.updates.into_iter().map(|p| p.map(Proposal::Update)));
+
+ res.chain(
+ self.additions
+ .into_iter()
+ .map(|p| p.map(|p| Proposal::Add(alloc::boxed::Box::new(p)))),
+ )
+ .chain(self.removals.into_iter().map(|p| p.map(Proposal::Remove)))
+ .chain(
+ self.reinitializations
+ .into_iter()
+ .map(|p| p.map(Proposal::ReInit)),
+ )
+ .chain(
+ self.group_context_extensions
+ .into_iter()
+ .map(|p| p.map(Proposal::GroupContextExtensions)),
+ )
+ }
+
+ pub(crate) fn into_proposals_or_refs(self) -> Vec<ProposalOrRef> {
+ self.into_proposals()
+ .filter_map(|p| match p.source {
+ ProposalSource::ByValue => Some(ProposalOrRef::Proposal(Box::new(p.proposal))),
+ #[cfg(feature = "by_ref_proposal")]
+ ProposalSource::ByReference(reference) => Some(ProposalOrRef::Reference(reference)),
+ _ => None,
+ })
+ .collect()
+ }
+
+ /// Add proposals in the bundle.
+ pub fn add_proposals(&self) -> &[ProposalInfo<AddProposal>] {
+ &self.additions
+ }
+
+ /// Update proposals in the bundle.
+ #[cfg(feature = "by_ref_proposal")]
+ pub fn update_proposals(&self) -> &[ProposalInfo<UpdateProposal>] {
+ &self.updates
+ }
+
+ /// Senders of update proposals in the bundle.
+ #[cfg(feature = "by_ref_proposal")]
+ pub fn update_proposal_senders(&self) -> &[LeafIndex] {
+ &self.update_senders
+ }
+
+ /// Remove proposals in the bundle.
+ pub fn remove_proposals(&self) -> &[ProposalInfo<RemoveProposal>] {
+ &self.removals
+ }
+
+ /// Pre-shared key proposals in the bundle.
+ #[cfg(feature = "psk")]
+ pub fn psk_proposals(&self) -> &[ProposalInfo<PreSharedKeyProposal>] {
+ &self.psks
+ }
+
+ /// Reinit proposals in the bundle.
+ pub fn reinit_proposals(&self) -> &[ProposalInfo<ReInitProposal>] {
+ &self.reinitializations
+ }
+
+ /// External init proposals in the bundle.
+ pub fn external_init_proposals(&self) -> &[ProposalInfo<ExternalInit>] {
+ &self.external_initializations
+ }
+
+ /// Group context extension proposals in the bundle.
+ pub fn group_context_ext_proposals(&self) -> &[ProposalInfo<ExtensionList>] {
+ &self.group_context_extensions
+ }
+
+ /// Custom proposals in the bundle.
+ #[cfg(feature = "custom_proposal")]
+ pub fn custom_proposals(&self) -> &[ProposalInfo<CustomProposal>] {
+ &self.custom_proposals
+ }
+
+ pub(crate) fn group_context_extensions_proposal(&self) -> Option<&ProposalInfo<ExtensionList>> {
+ self.group_context_extensions.first()
+ }
+
+ /// Custom proposal types that are in use within this bundle.
+ #[cfg(feature = "custom_proposal")]
+ pub fn custom_proposal_types(&self) -> impl Iterator<Item = ProposalType> + '_ {
+ #[cfg(feature = "std")]
+ let res = self
+ .custom_proposals
+ .iter()
+ .map(|v| v.proposal.proposal_type())
+ .unique();
+
+ #[cfg(not(feature = "std"))]
+ let res = self
+ .custom_proposals
+ .iter()
+ .map(|v| v.proposal.proposal_type())
+ .collect::<alloc::collections::BTreeSet<_>>()
+ .into_iter();
+
+ res
+ }
+
+ /// Standard proposal types that are in use within this bundle.
+ pub fn proposal_types(&self) -> impl Iterator<Item = ProposalType> + '_ {
+ let res = (!self.additions.is_empty())
+ .then_some(ProposalType::ADD)
+ .into_iter()
+ .chain((!self.removals.is_empty()).then_some(ProposalType::REMOVE))
+ .chain((!self.reinitializations.is_empty()).then_some(ProposalType::RE_INIT));
+
+ #[cfg(feature = "by_ref_proposal")]
+ let res = res.chain((!self.updates.is_empty()).then_some(ProposalType::UPDATE));
+
+ #[cfg(feature = "psk")]
+ let res = res.chain((!self.psks.is_empty()).then_some(ProposalType::PSK));
+
+ let res = res.chain(
+ (!self.external_initializations.is_empty()).then_some(ProposalType::EXTERNAL_INIT),
+ );
+
+ #[cfg(not(feature = "custom_proposal"))]
+ return res.chain(
+ (!self.group_context_extensions.is_empty())
+ .then_some(ProposalType::GROUP_CONTEXT_EXTENSIONS),
+ );
+
+ #[cfg(feature = "custom_proposal")]
+ return res
+ .chain(
+ (!self.group_context_extensions.is_empty())
+ .then_some(ProposalType::GROUP_CONTEXT_EXTENSIONS),
+ )
+ .chain(self.custom_proposal_types());
+ }
+}
+
+impl FromIterator<(Proposal, Sender, ProposalSource)> for ProposalBundle {
+ fn from_iter<I>(iter: I) -> Self
+ where
+ I: IntoIterator<Item = (Proposal, Sender, ProposalSource)>,
+ {
+ let mut bundle = ProposalBundle::default();
+ for (proposal, sender, source) in iter {
+ bundle.add(proposal, sender, source);
+ }
+ bundle
+ }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+impl<'a> FromIterator<(&'a ProposalRef, &'a CachedProposal)> for ProposalBundle {
+ fn from_iter<I>(iter: I) -> Self
+ where
+ I: IntoIterator<Item = (&'a ProposalRef, &'a CachedProposal)>,
+ {
+ iter.into_iter()
+ .map(|(r, p)| {
+ (
+ p.proposal.clone(),
+ p.sender,
+ ProposalSource::ByReference(r.clone()),
+ )
+ })
+ .collect()
+ }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+impl<'a> FromIterator<&'a (ProposalRef, CachedProposal)> for ProposalBundle {
+ fn from_iter<I>(iter: I) -> Self
+ where
+ I: IntoIterator<Item = &'a (ProposalRef, CachedProposal)>,
+ {
+ iter.into_iter().map(|pair| (&pair.0, &pair.1)).collect()
+ }
+}
+
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[derive(Clone, Debug, PartialEq)]
+pub enum ProposalSource {
+ ByValue,
+ #[cfg(feature = "by_ref_proposal")]
+ ByReference(ProposalRef),
+ Local,
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type(opaque))]
+#[derive(Clone, Debug, PartialEq)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[non_exhaustive]
+/// Proposal description used as input to a
+/// [`MlsRules`](crate::MlsRules).
+pub struct ProposalInfo<T> {
+ /// The underlying proposal value.
+ pub proposal: T,
+ /// The sender of this proposal.
+ pub sender: Sender,
+ /// The source of the proposal.
+ pub source: ProposalSource,
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
+impl<T> ProposalInfo<T> {
+ /// Create a new ProposalInfo.
+ ///
+ /// The resulting value will be either transmitted with a commit or
+ /// locally injected into a commit resolution depending on the
+ /// `can_transmit` flag.
+ ///
+ /// This function is useful when implementing custom
+ /// [`MlsRules`](crate::MlsRules).
+ #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
+ pub fn new(proposal: T, sender: Sender, can_transmit: bool) -> Self {
+ let source = if can_transmit {
+ ProposalSource::ByValue
+ } else {
+ ProposalSource::Local
+ };
+
+ ProposalInfo {
+ proposal,
+ sender,
+ source,
+ }
+ }
+
+ #[cfg(all(feature = "ffi", not(test)))]
+ pub fn sender(&self) -> &Sender {
+ &self.sender
+ }
+
+ #[cfg(all(feature = "ffi", not(test)))]
+ pub fn source(&self) -> &ProposalSource {
+ &self.source
+ }
+
+ #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
+ pub fn map<U, F>(self, f: F) -> ProposalInfo<U>
+ where
+ F: FnOnce(T) -> U,
+ {
+ ProposalInfo {
+ proposal: f(self.proposal),
+ sender: self.sender,
+ source: self.source,
+ }
+ }
+
+ #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
+ pub fn as_ref(&self) -> ProposalInfo<&T> {
+ ProposalInfo {
+ proposal: &self.proposal,
+ sender: self.sender,
+ source: self.source.clone(),
+ }
+ }
+
+ #[inline(always)]
+ pub fn is_by_value(&self) -> bool {
+ self.source == ProposalSource::ByValue
+ }
+
+ #[inline(always)]
+ pub fn is_by_reference(&self) -> bool {
+ !self.is_by_value()
+ }
+
+ /// The [`ProposalRef`] of this proposal if its source is [`ProposalSource::ByReference`]
+ #[cfg(feature = "by_ref_proposal")]
+ pub fn proposal_ref(&self) -> Option<&ProposalRef> {
+ match self.source {
+ ProposalSource::ByReference(ref reference) => Some(reference),
+ _ => None,
+ }
+ }
+}
+
+#[cfg(all(feature = "ffi", not(test)))]
+safer_ffi_gen::specialize!(ProposalInfoFfi = ProposalInfo<Proposal>);
+
+pub trait Proposable: Sized {
+ const TYPE: ProposalType;
+
+ fn filter(bundle: &ProposalBundle) -> &[ProposalInfo<Self>];
+ fn remove(bundle: &mut ProposalBundle, index: usize);
+ fn retain<F>(bundle: &mut ProposalBundle, keep: F)
+ where
+ F: FnMut(&ProposalInfo<Self>) -> bool;
+}
+
+macro_rules! impl_proposable {
+ ($ty:ty, $proposal_type:ident, $field:ident) => {
+ impl Proposable for $ty {
+ const TYPE: ProposalType = ProposalType::$proposal_type;
+
+ fn filter(bundle: &ProposalBundle) -> &[ProposalInfo<Self>] {
+ &bundle.$field
+ }
+
+ fn remove(bundle: &mut ProposalBundle, index: usize) {
+ if index < bundle.$field.len() {
+ bundle.$field.remove(index);
+ }
+ }
+
+ fn retain<F>(bundle: &mut ProposalBundle, keep: F)
+ where
+ F: FnMut(&ProposalInfo<Self>) -> bool,
+ {
+ bundle.$field.retain(keep);
+ }
+ }
+ };
+}
+
+impl_proposable!(AddProposal, ADD, additions);
+#[cfg(feature = "by_ref_proposal")]
+impl_proposable!(UpdateProposal, UPDATE, updates);
+impl_proposable!(RemoveProposal, REMOVE, removals);
+#[cfg(feature = "psk")]
+impl_proposable!(PreSharedKeyProposal, PSK, psks);
+impl_proposable!(ReInitProposal, RE_INIT, reinitializations);
+impl_proposable!(ExternalInit, EXTERNAL_INIT, external_initializations);
+impl_proposable!(
+ ExtensionList,
+ GROUP_CONTEXT_EXTENSIONS,
+ group_context_extensions
+);
diff --git a/src/group/proposal_filter/filtering.rs b/src/group/proposal_filter/filtering.rs
new file mode 100644
index 0000000..8e67ff5
--- /dev/null
+++ b/src/group/proposal_filter/filtering.rs
@@ -0,0 +1,580 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::{
+ client::MlsError,
+ group::{
+ proposal::ReInitProposal,
+ proposal_filter::{ProposalBundle, ProposalInfo},
+ AddProposal, ProposalType, RemoveProposal, Sender, UpdateProposal,
+ },
+ iter::wrap_iter,
+ protocol_version::ProtocolVersion,
+ time::MlsTime,
+ tree_kem::{
+ leaf_node_validator::{LeafNodeValidator, ValidationContext},
+ node::LeafIndex,
+ TreeKemPublic,
+ },
+ CipherSuiteProvider, ExtensionList,
+};
+
+use super::filtering_common::{filter_out_invalid_psks, ApplyProposalsOutput, ProposalApplier};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::extension::ExternalSendersExt;
+
+use alloc::vec::Vec;
+use mls_rs_core::{error::IntoAnyError, identity::IdentityProvider, psk::PreSharedKeyStorage};
+
+#[cfg(any(
+ feature = "custom_proposal",
+ not(any(mls_build_async, feature = "rayon"))
+))]
+use itertools::Itertools;
+
+use crate::group::ExternalInit;
+
+#[cfg(feature = "psk")]
+use crate::group::proposal::PreSharedKeyProposal;
+
+#[cfg(all(not(mls_build_async), feature = "rayon"))]
+use {crate::iter::ParallelIteratorExt, rayon::prelude::*};
+
+#[cfg(mls_build_async)]
+use futures::{StreamExt, TryStreamExt};
+
+impl<'a, C, P, CSP> ProposalApplier<'a, C, P, CSP>
+where
+ C: IdentityProvider,
+ P: PreSharedKeyStorage,
+ CSP: CipherSuiteProvider,
+{
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(super) async fn apply_proposals_from_member(
+ &self,
+ strategy: FilterStrategy,
+ commit_sender: LeafIndex,
+ proposals: ProposalBundle,
+ commit_time: Option<MlsTime>,
+ ) -> Result<ApplyProposalsOutput, MlsError> {
+ let proposals = filter_out_invalid_proposers(strategy, proposals)?;
+
+ let mut proposals: ProposalBundle =
+ filter_out_update_for_committer(strategy, commit_sender, proposals)?;
+
+ // We ignore the strategy here because the check above ensures all updates are from members
+ proposals.update_senders = proposals
+ .updates
+ .iter()
+ .map(leaf_index_of_update_sender)
+ .collect::<Result<_, _>>()?;
+
+ let mut proposals = filter_out_removal_of_committer(strategy, commit_sender, proposals)?;
+
+ filter_out_invalid_psks(
+ strategy,
+ self.cipher_suite_provider,
+ &mut proposals,
+ self.psk_storage,
+ )
+ .await?;
+
+ #[cfg(feature = "by_ref_proposal")]
+ let proposals = filter_out_invalid_group_extensions(
+ strategy,
+ proposals,
+ self.identity_provider,
+ commit_time,
+ )
+ .await?;
+
+ let proposals = filter_out_extra_group_context_extensions(strategy, proposals)?;
+ let proposals = filter_out_invalid_reinit(strategy, proposals, self.protocol_version)?;
+ let proposals = filter_out_reinit_if_other_proposals(strategy.is_ignore(), proposals)?;
+
+ let proposals = filter_out_external_init(strategy, proposals)?;
+
+ self.apply_proposal_changes(strategy, proposals, commit_time)
+ .await
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(super) async fn apply_proposal_changes(
+ &self,
+ strategy: FilterStrategy,
+ proposals: ProposalBundle,
+ commit_time: Option<MlsTime>,
+ ) -> Result<ApplyProposalsOutput, MlsError> {
+ match proposals.group_context_extensions_proposal().cloned() {
+ Some(p) => {
+ self.apply_proposals_with_new_capabilities(strategy, proposals, p, commit_time)
+ .await
+ }
+ None => {
+ self.apply_tree_changes(
+ strategy,
+ proposals,
+ self.original_group_extensions,
+ commit_time,
+ )
+ .await
+ }
+ }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(super) async fn apply_tree_changes(
+ &self,
+ strategy: FilterStrategy,
+ proposals: ProposalBundle,
+ group_extensions_in_use: &ExtensionList,
+ commit_time: Option<MlsTime>,
+ ) -> Result<ApplyProposalsOutput, MlsError> {
+ let mut applied_proposals = self
+ .validate_new_nodes(strategy, proposals, group_extensions_in_use, commit_time)
+ .await?;
+
+ let mut new_tree = self.original_tree.clone();
+
+ let added = new_tree
+ .batch_edit(
+ &mut applied_proposals,
+ group_extensions_in_use,
+ self.identity_provider,
+ self.cipher_suite_provider,
+ strategy.is_ignore(),
+ )
+ .await?;
+
+ let new_context_extensions = applied_proposals
+ .group_context_extensions_proposal()
+ .map(|gce| gce.proposal.clone());
+
+ Ok(ApplyProposalsOutput {
+ applied_proposals,
+ new_tree,
+ indexes_of_added_kpkgs: added,
+ external_init_index: None,
+ new_context_extensions,
+ })
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn validate_new_nodes(
+ &self,
+ strategy: FilterStrategy,
+ mut proposals: ProposalBundle,
+ group_extensions_in_use: &ExtensionList,
+ commit_time: Option<MlsTime>,
+ ) -> Result<ProposalBundle, MlsError> {
+ let leaf_node_validator = &LeafNodeValidator::new(
+ self.cipher_suite_provider,
+ self.identity_provider,
+ Some(group_extensions_in_use),
+ );
+
+ let bad_indices: Vec<_> = wrap_iter(proposals.update_proposals())
+ .zip(wrap_iter(proposals.update_proposal_senders()))
+ .enumerate()
+ .filter_map(|(i, (p, &sender_index))| async move {
+ let res = {
+ let leaf = &p.proposal.leaf_node;
+
+ let res = leaf_node_validator
+ .check_if_valid(
+ leaf,
+ ValidationContext::Update((self.group_id, *sender_index, commit_time)),
+ )
+ .await;
+
+ let old_leaf = match self.original_tree.get_leaf_node(sender_index) {
+ Ok(leaf) => leaf,
+ Err(e) => return Some(Err(e)),
+ };
+
+ let valid_successor = self
+ .identity_provider
+ .valid_successor(
+ &old_leaf.signing_identity,
+ &leaf.signing_identity,
+ group_extensions_in_use,
+ )
+ .await
+ .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))
+ .and_then(|valid| valid.then_some(()).ok_or(MlsError::InvalidSuccessor));
+
+ res.and(valid_successor)
+ };
+
+ apply_strategy(strategy, p.is_by_reference(), res)
+ .map(|b| (!b).then_some(i))
+ .transpose()
+ })
+ .try_collect()
+ .await?;
+
+ bad_indices.into_iter().rev().for_each(|i| {
+ proposals.remove::<UpdateProposal>(i);
+ proposals.update_senders.remove(i);
+ });
+
+ let bad_indices: Vec<_> = wrap_iter(proposals.add_proposals())
+ .enumerate()
+ .filter_map(|(i, p)| async move {
+ let res = self
+ .validate_new_node(leaf_node_validator, &p.proposal.key_package, commit_time)
+ .await;
+
+ apply_strategy(strategy, p.is_by_reference(), res)
+ .map(|b| (!b).then_some(i))
+ .transpose()
+ })
+ .try_collect()
+ .await?;
+
+ bad_indices
+ .into_iter()
+ .rev()
+ .for_each(|i| proposals.remove::<AddProposal>(i));
+
+ Ok(proposals)
+ }
+}
+
+#[derive(Clone, Copy, Debug)]
+pub enum FilterStrategy {
+ IgnoreByRef,
+ IgnoreNone,
+}
+
+impl FilterStrategy {
+ pub(super) fn ignore(self, by_ref: bool) -> bool {
+ match self {
+ FilterStrategy::IgnoreByRef => by_ref,
+ FilterStrategy::IgnoreNone => false,
+ }
+ }
+
+ fn is_ignore(self) -> bool {
+ match self {
+ FilterStrategy::IgnoreByRef => true,
+ FilterStrategy::IgnoreNone => false,
+ }
+ }
+}
+
+pub(crate) fn apply_strategy(
+ strategy: FilterStrategy,
+ by_ref: bool,
+ r: Result<(), MlsError>,
+) -> Result<bool, MlsError> {
+ r.map(|_| true)
+ .or_else(|error| strategy.ignore(by_ref).then_some(false).ok_or(error))
+}
+
+fn filter_out_update_for_committer(
+ strategy: FilterStrategy,
+ commit_sender: LeafIndex,
+ mut proposals: ProposalBundle,
+) -> Result<ProposalBundle, MlsError> {
+ proposals.retain_by_type::<UpdateProposal, _, _>(|p| {
+ apply_strategy(
+ strategy,
+ p.is_by_reference(),
+ (p.sender != Sender::Member(*commit_sender))
+ .then_some(())
+ .ok_or(MlsError::InvalidCommitSelfUpdate),
+ )
+ })?;
+ Ok(proposals)
+}
+
+fn filter_out_removal_of_committer(
+ strategy: FilterStrategy,
+ commit_sender: LeafIndex,
+ mut proposals: ProposalBundle,
+) -> Result<ProposalBundle, MlsError> {
+ proposals.retain_by_type::<RemoveProposal, _, _>(|p| {
+ apply_strategy(
+ strategy,
+ p.is_by_reference(),
+ (p.proposal.to_remove != commit_sender)
+ .then_some(())
+ .ok_or(MlsError::CommitterSelfRemoval),
+ )
+ })?;
+ Ok(proposals)
+}
+
+#[cfg(feature = "by_ref_proposal")]
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn filter_out_invalid_group_extensions<C>(
+ strategy: FilterStrategy,
+ mut proposals: ProposalBundle,
+ identity_provider: &C,
+ commit_time: Option<MlsTime>,
+) -> Result<ProposalBundle, MlsError>
+where
+ C: IdentityProvider,
+{
+ let mut bad_indices = Vec::new();
+
+ for (i, p) in proposals.by_type::<ExtensionList>().enumerate() {
+ let ext = p.proposal.get_as::<ExternalSendersExt>();
+
+ let res = match ext {
+ Ok(None) => Ok(()),
+ Ok(Some(extension)) => extension
+ .verify_all(identity_provider, commit_time, &p.proposal)
+ .await
+ .map_err(|e| MlsError::IdentityProviderError(e.into_any_error())),
+ Err(e) => Err(MlsError::from(e)),
+ };
+
+ if !apply_strategy(strategy, p.is_by_reference(), res)? {
+ bad_indices.push(i);
+ }
+ }
+
+ bad_indices
+ .into_iter()
+ .rev()
+ .for_each(|i| proposals.remove::<ExtensionList>(i));
+
+ Ok(proposals)
+}
+
+fn filter_out_extra_group_context_extensions(
+ strategy: FilterStrategy,
+ mut proposals: ProposalBundle,
+) -> Result<ProposalBundle, MlsError> {
+ let mut found = false;
+
+ proposals.retain_by_type::<ExtensionList, _, _>(|p| {
+ apply_strategy(
+ strategy,
+ p.is_by_reference(),
+ (!core::mem::replace(&mut found, true))
+ .then_some(())
+ .ok_or(MlsError::MoreThanOneGroupContextExtensionsProposal),
+ )
+ })?;
+
+ Ok(proposals)
+}
+
+fn filter_out_invalid_reinit(
+ strategy: FilterStrategy,
+ mut proposals: ProposalBundle,
+ protocol_version: ProtocolVersion,
+) -> Result<ProposalBundle, MlsError> {
+ proposals.retain_by_type::<ReInitProposal, _, _>(|p| {
+ apply_strategy(
+ strategy,
+ p.is_by_reference(),
+ (p.proposal.version >= protocol_version)
+ .then_some(())
+ .ok_or(MlsError::InvalidProtocolVersionInReInit),
+ )
+ })?;
+
+ Ok(proposals)
+}
+
+fn filter_out_reinit_if_other_proposals(
+ filter: bool,
+ mut proposals: ProposalBundle,
+) -> Result<ProposalBundle, MlsError> {
+ let proposal_count = proposals.length();
+
+ let has_reinit_and_other_proposal =
+ !proposals.reinit_proposals().is_empty() && proposal_count != 1;
+
+ if has_reinit_and_other_proposal {
+ let any_by_val = proposals.reinit_proposals().iter().any(|p| p.is_by_value());
+
+ if any_by_val || !filter {
+ return Err(MlsError::OtherProposalWithReInit);
+ }
+
+ let has_other_proposal_type = proposal_count > proposals.reinit_proposals().len();
+
+ if has_other_proposal_type {
+ proposals.reinitializations = Vec::new();
+ } else {
+ proposals.reinitializations.truncate(1);
+ }
+ }
+
+ Ok(proposals)
+}
+
+fn filter_out_external_init(
+ strategy: FilterStrategy,
+ mut proposals: ProposalBundle,
+) -> Result<ProposalBundle, MlsError> {
+ proposals.retain_by_type::<ExternalInit, _, _>(|p| {
+ apply_strategy(
+ strategy,
+ p.is_by_reference(),
+ Err(MlsError::InvalidProposalTypeForSender),
+ )
+ })?;
+
+ Ok(proposals)
+}
+
+pub(crate) fn proposer_can_propose(
+ proposer: Sender,
+ proposal_type: ProposalType,
+ by_ref: bool,
+) -> Result<(), MlsError> {
+ let can_propose = match (proposer, by_ref) {
+ (Sender::Member(_), false) => matches!(
+ proposal_type,
+ ProposalType::ADD
+ | ProposalType::REMOVE
+ | ProposalType::PSK
+ | ProposalType::RE_INIT
+ | ProposalType::GROUP_CONTEXT_EXTENSIONS
+ ),
+ (Sender::Member(_), true) => matches!(
+ proposal_type,
+ ProposalType::ADD
+ | ProposalType::UPDATE
+ | ProposalType::REMOVE
+ | ProposalType::PSK
+ | ProposalType::RE_INIT
+ | ProposalType::GROUP_CONTEXT_EXTENSIONS
+ ),
+ #[cfg(feature = "by_ref_proposal")]
+ (Sender::External(_), false) => false,
+ #[cfg(feature = "by_ref_proposal")]
+ (Sender::External(_), true) => matches!(
+ proposal_type,
+ ProposalType::ADD
+ | ProposalType::REMOVE
+ | ProposalType::RE_INIT
+ | ProposalType::PSK
+ | ProposalType::GROUP_CONTEXT_EXTENSIONS
+ ),
+ (Sender::NewMemberCommit, false) => matches!(
+ proposal_type,
+ ProposalType::REMOVE | ProposalType::PSK | ProposalType::EXTERNAL_INIT
+ ),
+ (Sender::NewMemberCommit, true) => false,
+ (Sender::NewMemberProposal, false) => false,
+ (Sender::NewMemberProposal, true) => matches!(proposal_type, ProposalType::ADD),
+ };
+
+ can_propose
+ .then_some(())
+ .ok_or(MlsError::InvalidProposalTypeForSender)
+}
+
+pub(crate) fn filter_out_invalid_proposers(
+ strategy: FilterStrategy,
+ mut proposals: ProposalBundle,
+) -> Result<ProposalBundle, MlsError> {
+ for i in (0..proposals.add_proposals().len()).rev() {
+ let p = &proposals.add_proposals()[i];
+ let res = proposer_can_propose(p.sender, ProposalType::ADD, p.is_by_reference());
+
+ if !apply_strategy(strategy, p.is_by_reference(), res)? {
+ proposals.remove::<AddProposal>(i);
+ }
+ }
+
+ for i in (0..proposals.update_proposals().len()).rev() {
+ let p = &proposals.update_proposals()[i];
+ let res = proposer_can_propose(p.sender, ProposalType::UPDATE, p.is_by_reference());
+
+ if !apply_strategy(strategy, p.is_by_reference(), res)? {
+ proposals.remove::<UpdateProposal>(i);
+ proposals.update_senders.remove(i);
+ }
+ }
+
+ for i in (0..proposals.remove_proposals().len()).rev() {
+ let p = &proposals.remove_proposals()[i];
+ let res = proposer_can_propose(p.sender, ProposalType::REMOVE, p.is_by_reference());
+
+ if !apply_strategy(strategy, p.is_by_reference(), res)? {
+ proposals.remove::<RemoveProposal>(i);
+ }
+ }
+
+ #[cfg(feature = "psk")]
+ for i in (0..proposals.psk_proposals().len()).rev() {
+ let p = &proposals.psk_proposals()[i];
+ let res = proposer_can_propose(p.sender, ProposalType::PSK, p.is_by_reference());
+
+ if !apply_strategy(strategy, p.is_by_reference(), res)? {
+ proposals.remove::<PreSharedKeyProposal>(i);
+ }
+ }
+
+ for i in (0..proposals.reinit_proposals().len()).rev() {
+ let p = &proposals.reinit_proposals()[i];
+ let res = proposer_can_propose(p.sender, ProposalType::RE_INIT, p.is_by_reference());
+
+ if !apply_strategy(strategy, p.is_by_reference(), res)? {
+ proposals.remove::<ReInitProposal>(i);
+ }
+ }
+
+ for i in (0..proposals.external_init_proposals().len()).rev() {
+ let p = &proposals.external_init_proposals()[i];
+ let res = proposer_can_propose(p.sender, ProposalType::EXTERNAL_INIT, p.is_by_reference());
+
+ if !apply_strategy(strategy, p.is_by_reference(), res)? {
+ proposals.remove::<ExternalInit>(i);
+ }
+ }
+
+ for i in (0..proposals.group_context_ext_proposals().len()).rev() {
+ let p = &proposals.group_context_ext_proposals()[i];
+ let gce_type = ProposalType::GROUP_CONTEXT_EXTENSIONS;
+ let res = proposer_can_propose(p.sender, gce_type, p.is_by_reference());
+
+ if !apply_strategy(strategy, p.is_by_reference(), res)? {
+ proposals.remove::<ExtensionList>(i);
+ }
+ }
+
+ Ok(proposals)
+}
+
+fn leaf_index_of_update_sender(p: &ProposalInfo<UpdateProposal>) -> Result<LeafIndex, MlsError> {
+ match p.sender {
+ Sender::Member(i) => Ok(LeafIndex(i)),
+ _ => Err(MlsError::InvalidProposalTypeForSender),
+ }
+}
+
+#[cfg(feature = "custom_proposal")]
+pub(super) fn filter_out_unsupported_custom_proposals(
+ proposals: &mut ProposalBundle,
+ tree: &TreeKemPublic,
+ strategy: FilterStrategy,
+) -> Result<(), MlsError> {
+ let supported_types = proposals
+ .custom_proposal_types()
+ .filter(|t| tree.can_support_proposal(*t))
+ .collect_vec();
+
+ proposals.retain_custom(|p| {
+ let proposal_type = p.proposal.proposal_type();
+
+ apply_strategy(
+ strategy,
+ p.is_by_reference(),
+ supported_types
+ .contains(&proposal_type)
+ .then_some(())
+ .ok_or(MlsError::UnsupportedCustomProposal(proposal_type)),
+ )
+ })
+}
diff --git a/src/group/proposal_filter/filtering_common.rs b/src/group/proposal_filter/filtering_common.rs
new file mode 100644
index 0000000..278c0de
--- /dev/null
+++ b/src/group/proposal_filter/filtering_common.rs
@@ -0,0 +1,579 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::{
+ client::MlsError,
+ group::{proposal_filter::ProposalBundle, Sender},
+ key_package::{validate_key_package_properties, KeyPackage},
+ protocol_version::ProtocolVersion,
+ time::MlsTime,
+ tree_kem::{
+ leaf_node_validator::{LeafNodeValidator, ValidationContext},
+ node::LeafIndex,
+ TreeKemPublic,
+ },
+ CipherSuiteProvider, ExtensionList,
+};
+
+use crate::tree_kem::leaf_node::LeafNode;
+
+use super::ProposalInfo;
+
+use crate::extension::{MlsExtension, RequiredCapabilitiesExt};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::extension::ExternalSendersExt;
+
+use mls_rs_core::error::IntoAnyError;
+
+use alloc::vec::Vec;
+use mls_rs_core::{identity::IdentityProvider, psk::PreSharedKeyStorage};
+
+use crate::group::{ExternalInit, ProposalType, RemoveProposal};
+
+#[cfg(all(feature = "by_ref_proposal", feature = "psk"))]
+use crate::group::proposal::PreSharedKeyProposal;
+
+#[cfg(feature = "psk")]
+use crate::group::{JustPreSharedKeyID, ResumptionPSKUsage, ResumptionPsk};
+
+#[cfg(all(feature = "std", feature = "psk"))]
+use std::collections::HashSet;
+
+#[cfg(feature = "by_ref_proposal")]
+use super::filtering::{apply_strategy, filter_out_invalid_proposers, FilterStrategy};
+
+#[cfg(feature = "custom_proposal")]
+use super::filtering::filter_out_unsupported_custom_proposals;
+
+#[derive(Debug)]
+pub(crate) struct ProposalApplier<'a, C, P, CSP> {
+ pub original_tree: &'a TreeKemPublic,
+ pub protocol_version: ProtocolVersion,
+ pub cipher_suite_provider: &'a CSP,
+ pub original_group_extensions: &'a ExtensionList,
+ pub external_leaf: Option<&'a LeafNode>,
+ pub identity_provider: &'a C,
+ pub psk_storage: &'a P,
+ #[cfg(feature = "by_ref_proposal")]
+ pub group_id: &'a [u8],
+}
+
+#[derive(Debug)]
+pub(crate) struct ApplyProposalsOutput {
+ pub(crate) new_tree: TreeKemPublic,
+ pub(crate) indexes_of_added_kpkgs: Vec<LeafIndex>,
+ pub(crate) external_init_index: Option<LeafIndex>,
+ #[cfg(feature = "by_ref_proposal")]
+ pub(crate) applied_proposals: ProposalBundle,
+ pub(crate) new_context_extensions: Option<ExtensionList>,
+}
+
+impl<'a, C, P, CSP> ProposalApplier<'a, C, P, CSP>
+where
+ C: IdentityProvider,
+ P: PreSharedKeyStorage,
+ CSP: CipherSuiteProvider,
+{
+ #[allow(clippy::too_many_arguments)]
+ pub(crate) fn new(
+ original_tree: &'a TreeKemPublic,
+ protocol_version: ProtocolVersion,
+ cipher_suite_provider: &'a CSP,
+ original_group_extensions: &'a ExtensionList,
+ external_leaf: Option<&'a LeafNode>,
+ identity_provider: &'a C,
+ psk_storage: &'a P,
+ #[cfg(feature = "by_ref_proposal")] group_id: &'a [u8],
+ ) -> Self {
+ Self {
+ original_tree,
+ protocol_version,
+ cipher_suite_provider,
+ original_group_extensions,
+ external_leaf,
+ identity_provider,
+ psk_storage,
+ #[cfg(feature = "by_ref_proposal")]
+ group_id,
+ }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn apply_proposals(
+ &self,
+ #[cfg(feature = "by_ref_proposal")] strategy: FilterStrategy,
+ commit_sender: &Sender,
+ #[cfg(not(feature = "by_ref_proposal"))] proposals: &ProposalBundle,
+ #[cfg(feature = "by_ref_proposal")] proposals: ProposalBundle,
+ commit_time: Option<MlsTime>,
+ ) -> Result<ApplyProposalsOutput, MlsError> {
+ let output = match commit_sender {
+ Sender::Member(sender) => {
+ self.apply_proposals_from_member(
+ #[cfg(feature = "by_ref_proposal")]
+ strategy,
+ LeafIndex(*sender),
+ proposals,
+ commit_time,
+ )
+ .await
+ }
+ Sender::NewMemberCommit => {
+ self.apply_proposals_from_new_member(proposals, commit_time)
+ .await
+ }
+ #[cfg(feature = "by_ref_proposal")]
+ Sender::External(_) => Err(MlsError::ExternalSenderCannotCommit),
+ #[cfg(feature = "by_ref_proposal")]
+ Sender::NewMemberProposal => Err(MlsError::ExternalSenderCannotCommit),
+ }?;
+
+ #[cfg(all(feature = "by_ref_proposal", feature = "custom_proposal"))]
+ let mut output = output;
+
+ #[cfg(all(feature = "by_ref_proposal", feature = "custom_proposal"))]
+ filter_out_unsupported_custom_proposals(
+ &mut output.applied_proposals,
+ &output.new_tree,
+ strategy,
+ )?;
+
+ #[cfg(all(not(feature = "by_ref_proposal"), feature = "custom_proposal"))]
+ filter_out_unsupported_custom_proposals(proposals, &output.new_tree)?;
+
+ Ok(output)
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ // The lint below is triggered by the `proposals` parameter which may or may not be a borrow.
+ #[allow(clippy::needless_borrow)]
+ async fn apply_proposals_from_new_member(
+ &self,
+ #[cfg(not(feature = "by_ref_proposal"))] proposals: &ProposalBundle,
+ #[cfg(feature = "by_ref_proposal")] proposals: ProposalBundle,
+ commit_time: Option<MlsTime>,
+ ) -> Result<ApplyProposalsOutput, MlsError> {
+ let external_leaf = self
+ .external_leaf
+ .ok_or(MlsError::ExternalCommitMustHaveNewLeaf)?;
+
+ ensure_exactly_one_external_init(&proposals)?;
+
+ ensure_at_most_one_removal_for_self(
+ &proposals,
+ external_leaf,
+ self.original_tree,
+ self.identity_provider,
+ self.original_group_extensions,
+ )
+ .await?;
+
+ ensure_proposals_in_external_commit_are_allowed(&proposals)?;
+ ensure_no_proposal_by_ref(&proposals)?;
+
+ #[cfg(feature = "by_ref_proposal")]
+ let mut proposals = filter_out_invalid_proposers(FilterStrategy::IgnoreNone, proposals)?;
+
+ filter_out_invalid_psks(
+ #[cfg(feature = "by_ref_proposal")]
+ FilterStrategy::IgnoreNone,
+ self.cipher_suite_provider,
+ #[cfg(feature = "by_ref_proposal")]
+ &mut proposals,
+ #[cfg(not(feature = "by_ref_proposal"))]
+ proposals,
+ self.psk_storage,
+ )
+ .await?;
+
+ let mut output = self
+ .apply_proposal_changes(
+ #[cfg(feature = "by_ref_proposal")]
+ FilterStrategy::IgnoreNone,
+ proposals,
+ commit_time,
+ )
+ .await?;
+
+ output.external_init_index = Some(
+ insert_external_leaf(
+ &mut output.new_tree,
+ external_leaf.clone(),
+ self.identity_provider,
+ self.original_group_extensions,
+ )
+ .await?,
+ );
+
+ Ok(output)
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(super) async fn apply_proposals_with_new_capabilities(
+ &self,
+ #[cfg(feature = "by_ref_proposal")] strategy: FilterStrategy,
+ #[cfg(not(feature = "by_ref_proposal"))] proposals: &ProposalBundle,
+ #[cfg(feature = "by_ref_proposal")] proposals: ProposalBundle,
+ group_context_extensions_proposal: ProposalInfo<ExtensionList>,
+ commit_time: Option<MlsTime>,
+ ) -> Result<ApplyProposalsOutput, MlsError>
+ where
+ C: IdentityProvider,
+ {
+ #[cfg(feature = "by_ref_proposal")]
+ let mut proposals_clone = proposals.clone();
+
+ // Apply adds, updates etc. in the context of new extensions
+ let output = self
+ .apply_tree_changes(
+ #[cfg(feature = "by_ref_proposal")]
+ strategy,
+ proposals,
+ &group_context_extensions_proposal.proposal,
+ commit_time,
+ )
+ .await?;
+
+ // Verify that capabilities and extensions are supported after modifications.
+ // TODO: The newly inserted nodes have already been validated by `apply_tree_changes`
+ // above. We should investigate if there is an easy way to avoid the double check.
+ let must_check = group_context_extensions_proposal
+ .proposal
+ .has_extension(RequiredCapabilitiesExt::extension_type());
+
+ #[cfg(feature = "by_ref_proposal")]
+ let must_check = must_check
+ || group_context_extensions_proposal
+ .proposal
+ .has_extension(ExternalSendersExt::extension_type());
+
+ let new_capabilities_supported = if must_check {
+ let leaf_validator = LeafNodeValidator::new(
+ self.cipher_suite_provider,
+ self.identity_provider,
+ Some(&group_context_extensions_proposal.proposal),
+ );
+
+ output
+ .new_tree
+ .non_empty_leaves()
+ .try_for_each(|(_, leaf)| {
+ leaf_validator.validate_required_capabilities(leaf)?;
+
+ #[cfg(feature = "by_ref_proposal")]
+ leaf_validator.validate_external_senders_ext_credentials(leaf)?;
+
+ Ok(())
+ })
+ } else {
+ Ok(())
+ };
+
+ let new_extensions_supported = group_context_extensions_proposal
+ .proposal
+ .iter()
+ .map(|extension| extension.extension_type)
+ .filter(|&ext_type| !ext_type.is_default())
+ .find(|ext_type| {
+ !output
+ .new_tree
+ .non_empty_leaves()
+ .all(|(_, leaf)| leaf.capabilities.extensions.contains(ext_type))
+ })
+ .map_or(Ok(()), |ext| Err(MlsError::UnsupportedGroupExtension(ext)));
+
+ #[cfg(not(feature = "by_ref_proposal"))]
+ {
+ new_capabilities_supported.and(new_extensions_supported)?;
+ Ok(output)
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ // If extensions are good, return `Ok`. If not and the strategy is to filter, remove the group
+ // context extensions proposal and try applying all proposals again in the context of the old
+ // extensions. Else, return an error.
+ match new_capabilities_supported.and(new_extensions_supported) {
+ Ok(()) => Ok(output),
+ Err(e) => {
+ if strategy.ignore(group_context_extensions_proposal.is_by_reference()) {
+ proposals_clone.group_context_extensions.clear();
+
+ self.apply_tree_changes(
+ strategy,
+ proposals_clone,
+ self.original_group_extensions,
+ commit_time,
+ )
+ .await
+ } else {
+ Err(e)
+ }
+ }
+ }
+ }
+
+ #[cfg(any(mls_build_async, not(feature = "rayon")))]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn validate_new_node<Ip: IdentityProvider, Cp: CipherSuiteProvider>(
+ &self,
+ leaf_node_validator: &LeafNodeValidator<'_, Ip, Cp>,
+ key_package: &KeyPackage,
+ commit_time: Option<MlsTime>,
+ ) -> Result<(), MlsError> {
+ leaf_node_validator
+ .check_if_valid(&key_package.leaf_node, ValidationContext::Add(commit_time))
+ .await?;
+
+ validate_key_package_properties(
+ key_package,
+ self.protocol_version,
+ self.cipher_suite_provider,
+ )
+ .await
+ }
+
+ #[cfg(all(not(mls_build_async), feature = "rayon"))]
+ pub fn validate_new_node<Ip: IdentityProvider, Cp: CipherSuiteProvider>(
+ &self,
+ leaf_node_validator: &LeafNodeValidator<'_, Ip, Cp>,
+ key_package: &KeyPackage,
+ commit_time: Option<MlsTime>,
+ ) -> Result<(), MlsError> {
+ let (a, b) = rayon::join(
+ || {
+ leaf_node_validator
+ .check_if_valid(&key_package.leaf_node, ValidationContext::Add(commit_time))
+ },
+ || {
+ validate_key_package_properties(
+ key_package,
+ self.protocol_version,
+ self.cipher_suite_provider,
+ )
+ },
+ );
+ a?;
+ b
+ }
+}
+
+#[cfg(feature = "psk")]
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn filter_out_invalid_psks<P, CP>(
+ #[cfg(feature = "by_ref_proposal")] strategy: FilterStrategy,
+ cipher_suite_provider: &CP,
+ #[cfg(not(feature = "by_ref_proposal"))] proposals: &ProposalBundle,
+ #[cfg(feature = "by_ref_proposal")] proposals: &mut ProposalBundle,
+ psk_storage: &P,
+) -> Result<(), MlsError>
+where
+ P: PreSharedKeyStorage,
+ CP: CipherSuiteProvider,
+{
+ let kdf_extract_size = cipher_suite_provider.kdf_extract_size();
+
+ #[cfg(feature = "std")]
+ let mut ids_seen = HashSet::new();
+
+ #[cfg(not(feature = "std"))]
+ let mut ids_seen = Vec::new();
+
+ #[cfg(feature = "by_ref_proposal")]
+ let mut bad_indices = Vec::new();
+
+ for i in 0..proposals.psk_proposals().len() {
+ let p = &proposals.psks[i];
+
+ let valid = matches!(
+ p.proposal.psk.key_id,
+ JustPreSharedKeyID::External(_)
+ | JustPreSharedKeyID::Resumption(ResumptionPsk {
+ usage: ResumptionPSKUsage::Application,
+ ..
+ })
+ );
+
+ let nonce_length = p.proposal.psk.psk_nonce.0.len();
+ let nonce_valid = nonce_length == kdf_extract_size;
+
+ #[cfg(feature = "std")]
+ let is_new_id = ids_seen.insert(p.proposal.psk.clone());
+
+ #[cfg(not(feature = "std"))]
+ let is_new_id = !ids_seen.contains(&p.proposal.psk);
+
+ let external_id_is_valid = match &p.proposal.psk.key_id {
+ JustPreSharedKeyID::External(id) => psk_storage
+ .contains(id)
+ .await
+ .map_err(|e| MlsError::PskStoreError(e.into_any_error()))
+ .and_then(|found| {
+ if found {
+ Ok(())
+ } else {
+ Err(MlsError::MissingRequiredPsk)
+ }
+ }),
+ JustPreSharedKeyID::Resumption(_) => Ok(()),
+ };
+
+ #[cfg(not(feature = "by_ref_proposal"))]
+ if !valid {
+ return Err(MlsError::InvalidTypeOrUsageInPreSharedKeyProposal);
+ } else if !nonce_valid {
+ return Err(MlsError::InvalidPskNonceLength);
+ } else if !is_new_id {
+ return Err(MlsError::DuplicatePskIds);
+ } else if external_id_is_valid.is_err() {
+ return external_id_is_valid;
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ {
+ let res = if !valid {
+ Err(MlsError::InvalidTypeOrUsageInPreSharedKeyProposal)
+ } else if !nonce_valid {
+ Err(MlsError::InvalidPskNonceLength)
+ } else if !is_new_id {
+ Err(MlsError::DuplicatePskIds)
+ } else {
+ external_id_is_valid
+ };
+
+ if !apply_strategy(strategy, p.is_by_reference(), res)? {
+ bad_indices.push(i)
+ }
+ }
+
+ #[cfg(not(feature = "std"))]
+ ids_seen.push(p.proposal.psk.clone());
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ bad_indices
+ .into_iter()
+ .rev()
+ .for_each(|i| proposals.remove::<PreSharedKeyProposal>(i));
+
+ Ok(())
+}
+
+#[cfg(not(feature = "psk"))]
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn filter_out_invalid_psks<P, CP>(
+ #[cfg(feature = "by_ref_proposal")] _: FilterStrategy,
+ _: &CP,
+ #[cfg(not(feature = "by_ref_proposal"))] _: &ProposalBundle,
+ #[cfg(feature = "by_ref_proposal")] _: &mut ProposalBundle,
+ _: &P,
+) -> Result<(), MlsError>
+where
+ P: PreSharedKeyStorage,
+ CP: CipherSuiteProvider,
+{
+ Ok(())
+}
+
+fn ensure_exactly_one_external_init(proposals: &ProposalBundle) -> Result<(), MlsError> {
+ (proposals.by_type::<ExternalInit>().count() == 1)
+ .then_some(())
+ .ok_or(MlsError::ExternalCommitMustHaveExactlyOneExternalInit)
+}
+
+/// Non-default proposal types are by default allowed. Custom MlsRules may disallow
+/// specific custom proposals in external commits
+fn ensure_proposals_in_external_commit_are_allowed(
+ proposals: &ProposalBundle,
+) -> Result<(), MlsError> {
+ let supported_default_types = [
+ ProposalType::EXTERNAL_INIT,
+ ProposalType::REMOVE,
+ ProposalType::PSK,
+ ];
+
+ let unsupported_type = proposals
+ .proposal_types()
+ .find(|ty| !supported_default_types.contains(ty) && ProposalType::DEFAULT.contains(ty));
+
+ match unsupported_type {
+ Some(kind) => Err(MlsError::InvalidProposalTypeInExternalCommit(kind)),
+ None => Ok(()),
+ }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn ensure_at_most_one_removal_for_self<C>(
+ proposals: &ProposalBundle,
+ external_leaf: &LeafNode,
+ tree: &TreeKemPublic,
+ identity_provider: &C,
+ extensions: &ExtensionList,
+) -> Result<(), MlsError>
+where
+ C: IdentityProvider,
+{
+ let mut removals = proposals.by_type::<RemoveProposal>();
+
+ match (removals.next(), removals.next()) {
+ (Some(removal), None) => {
+ ensure_removal_is_for_self(
+ &removal.proposal,
+ external_leaf,
+ tree,
+ identity_provider,
+ extensions,
+ )
+ .await
+ }
+ (Some(_), Some(_)) => Err(MlsError::ExternalCommitWithMoreThanOneRemove),
+ (None, _) => Ok(()),
+ }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn ensure_removal_is_for_self<C>(
+ removal: &RemoveProposal,
+ external_leaf: &LeafNode,
+ tree: &TreeKemPublic,
+ identity_provider: &C,
+ extensions: &ExtensionList,
+) -> Result<(), MlsError>
+where
+ C: IdentityProvider,
+{
+ let existing_signing_id = &tree.get_leaf_node(removal.to_remove)?.signing_identity;
+
+ identity_provider
+ .valid_successor(
+ existing_signing_id,
+ &external_leaf.signing_identity,
+ extensions,
+ )
+ .await
+ .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?
+ .then_some(())
+ .ok_or(MlsError::ExternalCommitRemovesOtherIdentity)
+}
+
+/// Non-default by-ref proposal types are by default allowed. Custom MlsRules may disallow
+/// specific custom by-ref proposals.
+fn ensure_no_proposal_by_ref(proposals: &ProposalBundle) -> Result<(), MlsError> {
+ proposals
+ .iter_proposals()
+ .all(|p| !ProposalType::DEFAULT.contains(&p.proposal.proposal_type()) || p.is_by_value())
+ .then_some(())
+ .ok_or(MlsError::OnlyMembersCanCommitProposalsByRef)
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn insert_external_leaf<I: IdentityProvider>(
+ tree: &mut TreeKemPublic,
+ leaf_node: LeafNode,
+ identity_provider: &I,
+ extensions: &ExtensionList,
+) -> Result<LeafIndex, MlsError> {
+ tree.add_leaf(leaf_node, identity_provider, extensions, None)
+ .await
+}
diff --git a/src/group/proposal_filter/filtering_lite.rs b/src/group/proposal_filter/filtering_lite.rs
new file mode 100644
index 0000000..09ca389
--- /dev/null
+++ b/src/group/proposal_filter/filtering_lite.rs
@@ -0,0 +1,225 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::{
+ client::MlsError,
+ group::proposal_filter::ProposalBundle,
+ iter::wrap_iter,
+ protocol_version::ProtocolVersion,
+ time::MlsTime,
+ tree_kem::{leaf_node_validator::LeafNodeValidator, node::LeafIndex},
+ CipherSuiteProvider, ExtensionList,
+};
+
+use super::filtering_common::{filter_out_invalid_psks, ApplyProposalsOutput, ProposalApplier};
+
+#[cfg(feature = "by_ref_proposal")]
+use {crate::extension::ExternalSendersExt, mls_rs_core::error::IntoAnyError};
+
+use mls_rs_core::{identity::IdentityProvider, psk::PreSharedKeyStorage};
+
+#[cfg(feature = "custom_proposal")]
+use itertools::Itertools;
+
+#[cfg(all(not(mls_build_async), feature = "rayon"))]
+use rayon::prelude::*;
+
+#[cfg(mls_build_async)]
+use futures::{StreamExt, TryStreamExt};
+
+#[cfg(feature = "custom_proposal")]
+use crate::tree_kem::TreeKemPublic;
+
+#[cfg(feature = "psk")]
+use crate::group::{
+ proposal::PreSharedKeyProposal, JustPreSharedKeyID, ResumptionPSKUsage, ResumptionPsk,
+};
+
+#[cfg(all(feature = "std", feature = "psk"))]
+use std::collections::HashSet;
+
+impl<'a, C, P, CSP> ProposalApplier<'a, C, P, CSP>
+where
+ C: IdentityProvider,
+ P: PreSharedKeyStorage,
+ CSP: CipherSuiteProvider,
+{
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(super) async fn apply_proposals_from_member(
+ &self,
+ commit_sender: LeafIndex,
+ proposals: &ProposalBundle,
+ commit_time: Option<MlsTime>,
+ ) -> Result<ApplyProposalsOutput, MlsError> {
+ filter_out_removal_of_committer(commit_sender, proposals)?;
+ filter_out_invalid_psks(self.cipher_suite_provider, proposals, self.psk_storage).await?;
+
+ #[cfg(feature = "by_ref_proposal")]
+ filter_out_invalid_group_extensions(proposals, self.identity_provider, commit_time).await?;
+
+ filter_out_extra_group_context_extensions(proposals)?;
+ filter_out_invalid_reinit(proposals, self.protocol_version)?;
+ filter_out_reinit_if_other_proposals(proposals)?;
+
+ self.apply_proposal_changes(proposals, commit_time).await
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(super) async fn apply_proposal_changes(
+ &self,
+ proposals: &ProposalBundle,
+ commit_time: Option<MlsTime>,
+ ) -> Result<ApplyProposalsOutput, MlsError> {
+ match proposals.group_context_extensions_proposal().cloned() {
+ Some(p) => {
+ self.apply_proposals_with_new_capabilities(proposals, p, commit_time)
+ .await
+ }
+ None => {
+ self.apply_tree_changes(proposals, self.original_group_extensions, commit_time)
+ .await
+ }
+ }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(super) async fn apply_tree_changes(
+ &self,
+ proposals: &ProposalBundle,
+ group_extensions_in_use: &ExtensionList,
+ commit_time: Option<MlsTime>,
+ ) -> Result<ApplyProposalsOutput, MlsError> {
+ self.validate_new_nodes(proposals, group_extensions_in_use, commit_time)
+ .await?;
+
+ let mut new_tree = self.original_tree.clone();
+
+ let added = new_tree
+ .batch_edit_lite(
+ proposals,
+ group_extensions_in_use,
+ self.identity_provider,
+ self.cipher_suite_provider,
+ )
+ .await?;
+
+ let new_context_extensions = proposals
+ .group_context_extensions
+ .first()
+ .map(|gce| gce.proposal.clone());
+
+ Ok(ApplyProposalsOutput {
+ new_tree,
+ indexes_of_added_kpkgs: added,
+ external_init_index: None,
+ new_context_extensions,
+ })
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn validate_new_nodes(
+ &self,
+ proposals: &ProposalBundle,
+ group_extensions_in_use: &ExtensionList,
+ commit_time: Option<MlsTime>,
+ ) -> Result<(), MlsError> {
+ let leaf_node_validator = &LeafNodeValidator::new(
+ self.cipher_suite_provider,
+ self.identity_provider,
+ Some(group_extensions_in_use),
+ );
+
+ let adds = wrap_iter(proposals.add_proposals());
+
+ #[cfg(mls_build_async)]
+ let adds = adds.map(Ok);
+
+ { adds }
+ .try_for_each(|p| {
+ self.validate_new_node(leaf_node_validator, &p.proposal.key_package, commit_time)
+ })
+ .await
+ }
+}
+
+fn filter_out_removal_of_committer(
+ commit_sender: LeafIndex,
+ proposals: &ProposalBundle,
+) -> Result<(), MlsError> {
+ for p in &proposals.removals {
+ (p.proposal.to_remove != commit_sender)
+ .then_some(())
+ .ok_or(MlsError::CommitterSelfRemoval)?;
+ }
+
+ Ok(())
+}
+
+#[cfg(feature = "by_ref_proposal")]
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn filter_out_invalid_group_extensions<C>(
+ proposals: &ProposalBundle,
+ identity_provider: &C,
+ commit_time: Option<MlsTime>,
+) -> Result<(), MlsError>
+where
+ C: IdentityProvider,
+{
+ if let Some(p) = proposals.group_context_extensions.first() {
+ if let Some(ext) = p.proposal.get_as::<ExternalSendersExt>()? {
+ ext.verify_all(identity_provider, commit_time, p.proposal())
+ .await
+ .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;
+ }
+ }
+
+ Ok(())
+}
+
+fn filter_out_extra_group_context_extensions(proposals: &ProposalBundle) -> Result<(), MlsError> {
+ (proposals.group_context_extensions.len() < 2)
+ .then_some(())
+ .ok_or(MlsError::MoreThanOneGroupContextExtensionsProposal)
+}
+
+fn filter_out_invalid_reinit(
+ proposals: &ProposalBundle,
+ protocol_version: ProtocolVersion,
+) -> Result<(), MlsError> {
+ if let Some(p) = proposals.reinitializations.first() {
+ (p.proposal.version >= protocol_version)
+ .then_some(())
+ .ok_or(MlsError::InvalidProtocolVersionInReInit)?;
+ }
+
+ Ok(())
+}
+
+fn filter_out_reinit_if_other_proposals(proposals: &ProposalBundle) -> Result<(), MlsError> {
+ (proposals.reinitializations.is_empty() || proposals.length() == 1)
+ .then_some(())
+ .ok_or(MlsError::OtherProposalWithReInit)
+}
+
+#[cfg(feature = "custom_proposal")]
+pub(super) fn filter_out_unsupported_custom_proposals(
+ proposals: &ProposalBundle,
+ tree: &TreeKemPublic,
+) -> Result<(), MlsError> {
+ let supported_types = proposals
+ .custom_proposal_types()
+ .filter(|t| tree.can_support_proposal(*t))
+ .collect_vec();
+
+ for p in &proposals.custom_proposals {
+ let proposal_type = p.proposal.proposal_type();
+
+ supported_types
+ .contains(&proposal_type)
+ .then_some(())
+ .ok_or(MlsError::UnsupportedCustomProposal(proposal_type))?;
+ }
+
+ Ok(())
+}
diff --git a/src/group/proposal_ref.rs b/src/group/proposal_ref.rs
new file mode 100644
index 0000000..c97c9a1
--- /dev/null
+++ b/src/group/proposal_ref.rs
@@ -0,0 +1,226 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use core::ops::Deref;
+
+use super::*;
+use crate::hash_reference::HashReference;
+
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// Unique identifier for a proposal message.
+pub struct ProposalRef(HashReference);
+
+impl Deref for ProposalRef {
+ type Target = [u8];
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), ::safer_ffi_gen::safer_ffi_gen)]
+impl ProposalRef {
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn from_content<CS: CipherSuiteProvider>(
+ cipher_suite_provider: &CS,
+ content: &AuthenticatedContent,
+ ) -> Result<Self, MlsError> {
+ let bytes = &content.mls_encode_to_vec()?;
+
+ Ok(ProposalRef(
+ HashReference::compute(bytes, b"MLS 1.0 Proposal Reference", cipher_suite_provider)
+ .await?,
+ ))
+ }
+
+ pub fn as_slice(&self) -> &[u8] {
+ &self.0
+ }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+ use super::*;
+ use crate::group::test_utils::{random_bytes, TEST_GROUP};
+ use alloc::boxed::Box;
+
+ impl ProposalRef {
+ pub fn new_fake(bytes: Vec<u8>) -> Self {
+ Self(bytes.into())
+ }
+ }
+
+ pub fn auth_content_from_proposal<S>(proposal: Proposal, sender: S) -> AuthenticatedContent
+ where
+ S: Into<Sender>,
+ {
+ AuthenticatedContent {
+ wire_format: WireFormat::PublicMessage,
+ content: FramedContent {
+ group_id: TEST_GROUP.to_vec(),
+ epoch: 0,
+ sender: sender.into(),
+ authenticated_data: vec![],
+ content: Content::Proposal(Box::new(proposal)),
+ },
+ auth: FramedContentAuthData {
+ signature: MessageSignature::from(random_bytes(128)),
+ confirmation_tag: None,
+ },
+ }
+ }
+}
+
+#[cfg(test)]
+mod test {
+ use super::test_utils::auth_content_from_proposal;
+ use super::*;
+ use crate::{
+ crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider},
+ key_package::test_utils::test_key_package,
+ tree_kem::leaf_node::test_utils::get_basic_test_node,
+ };
+ use alloc::boxed::Box;
+
+ use crate::extension::RequiredCapabilitiesExt;
+
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn get_test_extension_list() -> ExtensionList {
+ let test_extension = RequiredCapabilitiesExt {
+ extensions: vec![42.into()],
+ proposals: Default::default(),
+ credentials: vec![],
+ };
+
+ let mut extension_list = ExtensionList::new();
+ extension_list.set_from(test_extension).unwrap();
+
+ extension_list
+ }
+
+ #[derive(serde::Serialize, serde::Deserialize)]
+ struct TestCase {
+ cipher_suite: u16,
+ #[serde(with = "hex::serde")]
+ input: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ output: Vec<u8>,
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ async fn generate_proposal_test_cases() -> Vec<TestCase> {
+ let mut test_cases = Vec::new();
+
+ for (protocol_version, cipher_suite) in
+ ProtocolVersion::all().flat_map(|p| CipherSuite::all().map(move |cs| (p, cs)))
+ {
+ let sender = LeafIndex(0);
+
+ let add = auth_content_from_proposal(
+ Proposal::Add(Box::new(AddProposal {
+ key_package: test_key_package(protocol_version, cipher_suite, "alice").await,
+ })),
+ sender,
+ );
+
+ let update = auth_content_from_proposal(
+ Proposal::Update(UpdateProposal {
+ leaf_node: get_basic_test_node(cipher_suite, "foo").await,
+ }),
+ sender,
+ );
+
+ let remove = auth_content_from_proposal(
+ Proposal::Remove(RemoveProposal {
+ to_remove: LeafIndex(1),
+ }),
+ sender,
+ );
+
+ let group_context_ext = auth_content_from_proposal(
+ Proposal::GroupContextExtensions(get_test_extension_list()),
+ sender,
+ );
+
+ let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);
+
+ test_cases.push(TestCase {
+ cipher_suite: cipher_suite.into(),
+ input: add.mls_encode_to_vec().unwrap(),
+ output: ProposalRef::from_content(&cipher_suite_provider, &add)
+ .await
+ .unwrap()
+ .to_vec(),
+ });
+
+ test_cases.push(TestCase {
+ cipher_suite: cipher_suite.into(),
+ input: update.mls_encode_to_vec().unwrap(),
+ output: ProposalRef::from_content(&cipher_suite_provider, &update)
+ .await
+ .unwrap()
+ .to_vec(),
+ });
+
+ test_cases.push(TestCase {
+ cipher_suite: cipher_suite.into(),
+ input: remove.mls_encode_to_vec().unwrap(),
+ output: ProposalRef::from_content(&cipher_suite_provider, &remove)
+ .await
+ .unwrap()
+ .to_vec(),
+ });
+
+ test_cases.push(TestCase {
+ cipher_suite: cipher_suite.into(),
+ input: group_context_ext.mls_encode_to_vec().unwrap(),
+ output: ProposalRef::from_content(&cipher_suite_provider, &group_context_ext)
+ .await
+ .unwrap()
+ .to_vec(),
+ });
+ }
+
+ test_cases
+ }
+
+ #[cfg(mls_build_async)]
+ async fn load_test_cases() -> Vec<TestCase> {
+ load_test_case_json!(proposal_ref, generate_proposal_test_cases().await)
+ }
+
+ #[cfg(not(mls_build_async))]
+ fn load_test_cases() -> Vec<TestCase> {
+ load_test_case_json!(proposal_ref, generate_proposal_test_cases())
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_proposal_ref() {
+ let test_cases = load_test_cases().await;
+
+ for one_case in test_cases {
+ let Some(cs_provider) = try_test_cipher_suite_provider(one_case.cipher_suite) else {
+ continue;
+ };
+
+ let proposal_content =
+ AuthenticatedContent::mls_decode(&mut one_case.input.as_slice()).unwrap();
+
+ let proposal_ref = ProposalRef::from_content(&cs_provider, &proposal_content)
+ .await
+ .unwrap();
+
+ let expected_out = ProposalRef(HashReference::from(one_case.output));
+
+ assert_eq!(expected_out, proposal_ref);
+ }
+ }
+}
diff --git a/src/group/resumption.rs b/src/group/resumption.rs
new file mode 100644
index 0000000..3478ef3
--- /dev/null
+++ b/src/group/resumption.rs
@@ -0,0 +1,299 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+
+use mls_rs_core::{
+ crypto::{CipherSuite, SignatureSecretKey},
+ extension::ExtensionList,
+ identity::SigningIdentity,
+ protocol_version::ProtocolVersion,
+};
+
+use crate::{client::MlsError, Client, Group, MlsMessage};
+
+use super::{
+ proposal::ReInitProposal, ClientConfig, ExportedTree, JustPreSharedKeyID, MessageProcessor,
+ NewMemberInfo, PreSharedKeyID, PskGroupId, PskSecretInput, ResumptionPSKUsage, ResumptionPsk,
+};
+
+struct ResumptionGroupParameters<'a> {
+ group_id: &'a [u8],
+ cipher_suite: CipherSuite,
+ version: ProtocolVersion,
+ extensions: &'a ExtensionList,
+}
+
+pub struct ReinitClient<C: ClientConfig + Clone> {
+ client: Client<C>,
+ reinit: ReInitProposal,
+ psk_input: PskSecretInput,
+}
+
+impl<C> Group<C>
+where
+ C: ClientConfig + Clone,
+{
+ /// Create a sub-group from a subset of the current group members.
+ ///
+ /// Membership within the resulting sub-group is indicated by providing a
+ /// key package that produces the same
+ /// [identity](crate::IdentityProvider::identity) value
+ /// as an existing group member. The identity value of each key package
+ /// is determined using the
+ /// [`IdentityProvider`](crate::IdentityProvider)
+ /// that is currently in use by this group instance.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn branch(
+ &self,
+ sub_group_id: Vec<u8>,
+ new_key_packages: Vec<MlsMessage>,
+ ) -> Result<(Group<C>, Vec<MlsMessage>), MlsError> {
+ let new_group_params = ResumptionGroupParameters {
+ group_id: &sub_group_id,
+ cipher_suite: self.cipher_suite(),
+ version: self.protocol_version(),
+ extensions: &self.group_state().context.extensions,
+ };
+
+ resumption_create_group(
+ self.config.clone(),
+ new_key_packages,
+ &new_group_params,
+ // TODO investigate if it's worth updating your own signing identity here
+ self.current_member_signing_identity()?.clone(),
+ self.signer.clone(),
+ #[cfg(any(feature = "private_message", feature = "psk"))]
+ self.resumption_psk_input(ResumptionPSKUsage::Branch)?,
+ )
+ .await
+ }
+
+ /// Join a subgroup that was created by [`Group::branch`].
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn join_subgroup(
+ &self,
+ welcome: &MlsMessage,
+ tree_data: Option<ExportedTree<'_>>,
+ ) -> Result<(Group<C>, NewMemberInfo), MlsError> {
+ let expected_new_group_prams = ResumptionGroupParameters {
+ group_id: &[],
+ cipher_suite: self.cipher_suite(),
+ version: self.protocol_version(),
+ extensions: &self.group_state().context.extensions,
+ };
+
+ resumption_join_group(
+ self.config.clone(),
+ self.signer.clone(),
+ welcome,
+ tree_data,
+ expected_new_group_prams,
+ false,
+ self.resumption_psk_input(ResumptionPSKUsage::Branch)?,
+ )
+ .await
+ }
+
+ /// Generate a [`ReinitClient`] that can be used to create or join a new group
+ /// that is based on properties defined by a [`ReInitProposal`]
+ /// committed in a previously accepted commit. This is the only action available
+ /// after accepting such a commit. The old group can no longer be used according to the RFC.
+ ///
+ /// If the [`ReInitProposal`] changes the ciphersuite, then `new_signer`
+ /// and `new_signer_identity` must be set and match the new ciphersuite, as indicated by
+ /// [`pending_reinit_ciphersuite`](crate::group::StateUpdate::pending_reinit_ciphersuite)
+ /// of the [`StateUpdate`](crate::group::StateUpdate) outputted after processing the
+ /// commit to the reinit proposal. The value of [identity](crate::IdentityProvider::identity)
+ /// must be the same for `new_signing_identity` and the current identity in use by this
+ /// group instance.
+ pub fn get_reinit_client(
+ self,
+ new_signer: Option<SignatureSecretKey>,
+ new_signing_identity: Option<SigningIdentity>,
+ ) -> Result<ReinitClient<C>, MlsError> {
+ let psk_input = self.resumption_psk_input(ResumptionPSKUsage::Reinit)?;
+
+ let new_signing_identity = new_signing_identity
+ .map(Ok)
+ .unwrap_or_else(|| self.current_member_signing_identity().cloned())?;
+
+ let reinit = self
+ .state
+ .pending_reinit
+ .ok_or(MlsError::PendingReInitNotFound)?;
+
+ let new_signer = match new_signer {
+ Some(signer) => signer,
+ None => self.signer,
+ };
+
+ let client = Client::new(
+ self.config,
+ Some(new_signer),
+ Some((new_signing_identity, reinit.new_cipher_suite())),
+ reinit.new_version(),
+ );
+
+ Ok(ReinitClient {
+ client,
+ reinit,
+ psk_input,
+ })
+ }
+
+ fn resumption_psk_input(&self, usage: ResumptionPSKUsage) -> Result<PskSecretInput, MlsError> {
+ let psk = self.epoch_secrets.resumption_secret.clone();
+
+ let id = JustPreSharedKeyID::Resumption(ResumptionPsk {
+ usage,
+ psk_group_id: PskGroupId(self.group_id().to_vec()),
+ psk_epoch: self.current_epoch(),
+ });
+
+ let id = PreSharedKeyID::new(id, self.cipher_suite_provider())?;
+ Ok(PskSecretInput { id, psk })
+ }
+}
+
+/// A [`Client`] that can be used to create or join a new group
+/// that is based on properties defined by a [`ReInitProposal`]
+/// committed in a previously accepted commit.
+impl<C: ClientConfig + Clone> ReinitClient<C> {
+ /// Generate a key package for the new group. The key package can
+ /// be used in [`ReinitClient::commit`].
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn generate_key_package(&self) -> Result<MlsMessage, MlsError> {
+ self.client.generate_key_package_message().await
+ }
+
+ /// Create the new group using new key packages of all group members, possibly
+ /// generated by [`ReinitClient::generate_key_package`].
+ ///
+ /// # Warning
+ ///
+ /// This function will fail if the number of members in the reinitialized
+ /// group is not the same as the prior group roster.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn commit(
+ self,
+ new_key_packages: Vec<MlsMessage>,
+ ) -> Result<(Group<C>, Vec<MlsMessage>), MlsError> {
+ let new_group_params = ResumptionGroupParameters {
+ group_id: self.reinit.group_id(),
+ cipher_suite: self.reinit.new_cipher_suite(),
+ version: self.reinit.new_version(),
+ extensions: self.reinit.new_group_context_extensions(),
+ };
+
+ resumption_create_group(
+ self.client.config.clone(),
+ new_key_packages,
+ &new_group_params,
+ // These private fields are created with `Some(x)` by `get_reinit_client`
+ self.client.signing_identity.unwrap().0,
+ self.client.signer.unwrap(),
+ #[cfg(any(feature = "private_message", feature = "psk"))]
+ self.psk_input,
+ )
+ .await
+ }
+
+ /// Join a reinitialized group that was created by [`ReinitClient::commit`].
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn join(
+ self,
+ welcome: &MlsMessage,
+ tree_data: Option<ExportedTree<'_>>,
+ ) -> Result<(Group<C>, NewMemberInfo), MlsError> {
+ let reinit = self.reinit;
+
+ let expected_group_params = ResumptionGroupParameters {
+ group_id: reinit.group_id(),
+ cipher_suite: reinit.new_cipher_suite(),
+ version: reinit.new_version(),
+ extensions: reinit.new_group_context_extensions(),
+ };
+
+ resumption_join_group(
+ self.client.config,
+ // This private field is created with `Some(x)` by `get_reinit_client`
+ self.client.signer.unwrap(),
+ welcome,
+ tree_data,
+ expected_group_params,
+ true,
+ self.psk_input,
+ )
+ .await
+ }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn resumption_create_group<C: ClientConfig + Clone>(
+ config: C,
+ new_key_packages: Vec<MlsMessage>,
+ new_group_params: &ResumptionGroupParameters<'_>,
+ signing_identity: SigningIdentity,
+ signer: SignatureSecretKey,
+ psk_input: PskSecretInput,
+) -> Result<(Group<C>, Vec<MlsMessage>), MlsError> {
+ // Create a new group with new parameters
+ let mut group = Group::new(
+ config,
+ Some(new_group_params.group_id.to_vec()),
+ new_group_params.cipher_suite,
+ new_group_params.version,
+ signing_identity,
+ new_group_params.extensions.clone(),
+ signer,
+ )
+ .await?;
+
+ // Install the resumption psk in the new group
+ group.previous_psk = Some(psk_input);
+
+ // Create a commit that adds new key packages and uses the resumption PSK
+ let mut commit = group.commit_builder();
+
+ for kp in new_key_packages.into_iter() {
+ commit = commit.add_member(kp)?;
+ }
+
+ let commit = commit.build().await?;
+ group.apply_pending_commit().await?;
+
+ // Uninstall the resumption psk on success (in case of failure, the new group is discarded anyway)
+ group.previous_psk = None;
+
+ Ok((group, commit.welcome_messages))
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn resumption_join_group<C: ClientConfig + Clone>(
+ config: C,
+ signer: SignatureSecretKey,
+ welcome: &MlsMessage,
+ tree_data: Option<ExportedTree<'_>>,
+ expected_new_group_params: ResumptionGroupParameters<'_>,
+ verify_group_id: bool,
+ psk_input: PskSecretInput,
+) -> Result<(Group<C>, NewMemberInfo), MlsError> {
+ let psk_input = Some(psk_input);
+
+ let (group, new_member_info) =
+ Group::<C>::from_welcome_message(welcome, tree_data, config, signer, psk_input).await?;
+
+ if group.protocol_version() != expected_new_group_params.version {
+ Err(MlsError::ProtocolVersionMismatch)
+ } else if group.cipher_suite() != expected_new_group_params.cipher_suite {
+ Err(MlsError::CipherSuiteMismatch)
+ } else if verify_group_id && group.group_id() != expected_new_group_params.group_id {
+ Err(MlsError::GroupIdMismatch)
+ } else if &group.group_state().context.extensions != expected_new_group_params.extensions {
+ Err(MlsError::ReInitExtensionsMismatch)
+ } else {
+ Ok((group, new_member_info))
+ }
+}
diff --git a/src/group/roster.rs b/src/group/roster.rs
new file mode 100644
index 0000000..dd0a9f0
--- /dev/null
+++ b/src/group/roster.rs
@@ -0,0 +1,91 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use super::*;
+
+pub use mls_rs_core::group::Member;
+
+#[cfg(feature = "state_update")]
+pub(crate) fn member_from_key_package(key_package: &KeyPackage, index: LeafIndex) -> Member {
+ member_from_leaf_node(&key_package.leaf_node, index)
+}
+
+pub(crate) fn member_from_leaf_node(leaf_node: &LeafNode, leaf_index: LeafIndex) -> Member {
+ Member::new(
+ *leaf_index,
+ leaf_node.signing_identity.clone(),
+ leaf_node.ungreased_capabilities(),
+ leaf_node.ungreased_extensions(),
+ )
+}
+
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Clone, Debug)]
+pub struct Roster<'a> {
+ pub(crate) public_tree: &'a TreeKemPublic,
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
+impl<'a> Roster<'a> {
+ /// Iterator over the current roster that lazily copies data out of the
+ /// internal group state.
+ ///
+ /// # Warning
+ ///
+ /// The indexes within this iterator do not correlate with indexes of users
+ /// within [`ReceivedMessage`] content descriptions due to the layout of
+ /// member information within a MLS group state.
+ #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
+ pub fn members_iter(&self) -> impl Iterator<Item = Member> + 'a {
+ self.public_tree
+ .non_empty_leaves()
+ .map(|(index, node)| member_from_leaf_node(node, index))
+ }
+
+ /// The current set of group members. This function makes a clone of
+ /// member information from the internal group state.
+ ///
+ /// # Warning
+ ///
+ /// The indexes within this roster do not correlate with indexes of users
+ /// within [`ReceivedMessage`] content descriptions due to the layout of
+ /// member information within a MLS group state.
+ pub fn members(&self) -> Vec<Member> {
+ self.members_iter().collect()
+ }
+
+ /// Retrieve the member with given `index` within the group in time `O(1)`.
+ /// This index does correlate with indexes of users within [`ReceivedMessage`]
+ /// content descriptions.
+ pub fn member_with_index(&self, index: u32) -> Result<Member, MlsError> {
+ let index = LeafIndex(index);
+
+ self.public_tree
+ .get_leaf_node(index)
+ .map(|l| member_from_leaf_node(l, index))
+ }
+
+ /// Iterator over member's signing identities.
+ ///
+ /// # Warning
+ ///
+ /// The indexes within this iterator do not correlate with indexes of users
+ /// within [`ReceivedMessage`] content descriptions due to the layout of
+ /// member information within a MLS group state.
+ #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
+ pub fn member_identities_iter(&self) -> impl Iterator<Item = &SigningIdentity> + '_ {
+ self.public_tree
+ .non_empty_leaves()
+ .map(|(_, node)| &node.signing_identity)
+ }
+}
+
+impl TreeKemPublic {
+ pub(crate) fn roster(&self) -> Roster {
+ Roster { public_tree: self }
+ }
+}
diff --git a/src/group/secret_tree.rs b/src/group/secret_tree.rs
new file mode 100644
index 0000000..df0c30f
--- /dev/null
+++ b/src/group/secret_tree.rs
@@ -0,0 +1,1115 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+use core::{
+ fmt::{self, Debug},
+ ops::{Deref, DerefMut},
+};
+
+use zeroize::Zeroizing;
+
+use crate::{client::MlsError, tree_kem::math::TreeIndex, CipherSuiteProvider};
+
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::error::IntoAnyError;
+
+#[cfg(feature = "std")]
+use std::collections::HashMap;
+
+#[cfg(not(feature = "std"))]
+use alloc::collections::BTreeMap;
+
+use super::key_schedule::kdf_expand_with_label;
+
+pub(crate) const MAX_RATCHET_BACK_HISTORY: u32 = 1024;
+
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[repr(u8)]
+enum SecretTreeNode {
+ Secret(TreeSecret) = 0u8,
+ Ratchet(SecretRatchets) = 1u8,
+}
+
+impl SecretTreeNode {
+ fn into_secret(self) -> Option<TreeSecret> {
+ if let SecretTreeNode::Secret(secret) = self {
+ Some(secret)
+ } else {
+ None
+ }
+ }
+}
+
+#[derive(Clone, PartialEq, MlsEncode, MlsDecode, MlsSize)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+struct TreeSecret(
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
+ Zeroizing<Vec<u8>>,
+);
+
+impl Debug for TreeSecret {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ mls_rs_core::debug::pretty_bytes(&self.0)
+ .named("TreeSecret")
+ .fmt(f)
+ }
+}
+
+impl Deref for TreeSecret {
+ type Target = Vec<u8>;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl DerefMut for TreeSecret {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut self.0
+ }
+}
+
+impl AsRef<[u8]> for TreeSecret {
+ fn as_ref(&self) -> &[u8] {
+ &self.0
+ }
+}
+
+impl From<Vec<u8>> for TreeSecret {
+ fn from(vec: Vec<u8>) -> Self {
+ TreeSecret(Zeroizing::new(vec))
+ }
+}
+
+impl From<Zeroizing<Vec<u8>>> for TreeSecret {
+ fn from(vec: Zeroizing<Vec<u8>>) -> Self {
+ TreeSecret(vec)
+ }
+}
+
+#[derive(Clone, Debug, PartialEq, MlsEncode, MlsDecode, MlsSize, Default)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+struct TreeSecretsVec<T: TreeIndex> {
+ #[cfg(feature = "std")]
+ inner: HashMap<T, SecretTreeNode>,
+ #[cfg(not(feature = "std"))]
+ inner: Vec<(T, SecretTreeNode)>,
+}
+
+#[cfg(feature = "std")]
+impl<T: TreeIndex> TreeSecretsVec<T> {
+ fn set_node(&mut self, index: T, value: SecretTreeNode) {
+ self.inner.insert(index, value);
+ }
+
+ fn take_node(&mut self, index: &T) -> Option<SecretTreeNode> {
+ self.inner.remove(index)
+ }
+}
+
+#[cfg(not(feature = "std"))]
+impl<T: TreeIndex> TreeSecretsVec<T> {
+ fn set_node(&mut self, index: T, value: SecretTreeNode) {
+ if let Some(i) = self.find_node(&index) {
+ self.inner[i] = (index, value)
+ } else {
+ self.inner.push((index, value))
+ }
+ }
+
+ fn take_node(&mut self, index: &T) -> Option<SecretTreeNode> {
+ self.find_node(index).map(|i| self.inner.remove(i).1)
+ }
+
+ fn find_node(&self, index: &T) -> Option<usize> {
+ use itertools::Itertools;
+
+ self.inner
+ .iter()
+ .find_position(|(i, _)| i == index)
+ .map(|(i, _)| i)
+ }
+}
+
+#[derive(Clone, Debug, PartialEq, MlsEncode, MlsDecode, MlsSize)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct SecretTree<T: TreeIndex> {
+ known_secrets: TreeSecretsVec<T>,
+ leaf_count: T,
+}
+
+impl<T: TreeIndex> SecretTree<T> {
+ pub(crate) fn empty() -> SecretTree<T> {
+ SecretTree {
+ known_secrets: Default::default(),
+ leaf_count: T::zero(),
+ }
+ }
+}
+
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct SecretRatchets {
+ pub application: SecretKeyRatchet,
+ pub handshake: SecretKeyRatchet,
+}
+
+impl SecretRatchets {
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn message_key_generation<P: CipherSuiteProvider>(
+ &mut self,
+ cipher_suite_provider: &P,
+ generation: u32,
+ key_type: KeyType,
+ ) -> Result<MessageKeyData, MlsError> {
+ match key_type {
+ KeyType::Handshake => {
+ self.handshake
+ .get_message_key(cipher_suite_provider, generation)
+ .await
+ }
+ KeyType::Application => {
+ self.application
+ .get_message_key(cipher_suite_provider, generation)
+ .await
+ }
+ }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn next_message_key<P: CipherSuiteProvider>(
+ &mut self,
+ cipher_suite: &P,
+ key_type: KeyType,
+ ) -> Result<MessageKeyData, MlsError> {
+ match key_type {
+ KeyType::Handshake => self.handshake.next_message_key(cipher_suite).await,
+ KeyType::Application => self.application.next_message_key(cipher_suite).await,
+ }
+ }
+}
+
+impl<T: TreeIndex> SecretTree<T> {
+ pub fn new(leaf_count: T, encryption_secret: Zeroizing<Vec<u8>>) -> SecretTree<T> {
+ let mut known_secrets = TreeSecretsVec::default();
+
+ let root_secret = SecretTreeNode::Secret(TreeSecret::from(encryption_secret));
+ known_secrets.set_node(leaf_count.root(), root_secret);
+
+ Self {
+ known_secrets,
+ leaf_count,
+ }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn consume_node<P: CipherSuiteProvider>(
+ &mut self,
+ cipher_suite_provider: &P,
+ index: &T,
+ ) -> Result<(), MlsError> {
+ let node = self.known_secrets.take_node(index);
+
+ if let Some(secret) = node.and_then(|n| n.into_secret()) {
+ let left_index = index.left().ok_or(MlsError::LeafNodeNoChildren)?;
+ let right_index = index.right().ok_or(MlsError::LeafNodeNoChildren)?;
+
+ let left_secret =
+ kdf_expand_with_label(cipher_suite_provider, &secret, b"tree", b"left", None)
+ .await?;
+
+ let right_secret =
+ kdf_expand_with_label(cipher_suite_provider, &secret, b"tree", b"right", None)
+ .await?;
+
+ self.known_secrets
+ .set_node(left_index, SecretTreeNode::Secret(left_secret.into()));
+
+ self.known_secrets
+ .set_node(right_index, SecretTreeNode::Secret(right_secret.into()));
+ }
+
+ Ok(())
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn take_leaf_ratchet<P: CipherSuiteProvider>(
+ &mut self,
+ cipher_suite: &P,
+ leaf_index: &T,
+ ) -> Result<SecretRatchets, MlsError> {
+ let node_index = leaf_index;
+
+ let node = match self.known_secrets.take_node(node_index) {
+ Some(node) => node,
+ None => {
+ // Start at the root node and work your way down consuming any intermediates needed
+ for i in node_index.direct_copath(&self.leaf_count).into_iter().rev() {
+ self.consume_node(cipher_suite, &i.path).await?;
+ }
+
+ self.known_secrets
+ .take_node(node_index)
+ .ok_or(MlsError::InvalidLeafConsumption)?
+ }
+ };
+
+ Ok(match node {
+ SecretTreeNode::Ratchet(ratchet) => ratchet,
+ SecretTreeNode::Secret(secret) => SecretRatchets {
+ application: SecretKeyRatchet::new(cipher_suite, &secret, KeyType::Application)
+ .await?,
+ handshake: SecretKeyRatchet::new(cipher_suite, &secret, KeyType::Handshake).await?,
+ },
+ })
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn next_message_key<P: CipherSuiteProvider>(
+ &mut self,
+ cipher_suite: &P,
+ leaf_index: T,
+ key_type: KeyType,
+ ) -> Result<MessageKeyData, MlsError> {
+ let mut ratchet = self.take_leaf_ratchet(cipher_suite, &leaf_index).await?;
+ let res = ratchet.next_message_key(cipher_suite, key_type).await?;
+
+ self.known_secrets
+ .set_node(leaf_index, SecretTreeNode::Ratchet(ratchet));
+
+ Ok(res)
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn message_key_generation<P: CipherSuiteProvider>(
+ &mut self,
+ cipher_suite: &P,
+ leaf_index: T,
+ key_type: KeyType,
+ generation: u32,
+ ) -> Result<MessageKeyData, MlsError> {
+ let mut ratchet = self.take_leaf_ratchet(cipher_suite, &leaf_index).await?;
+
+ let res = ratchet
+ .message_key_generation(cipher_suite, generation, key_type)
+ .await?;
+
+ self.known_secrets
+ .set_node(leaf_index, SecretTreeNode::Ratchet(ratchet));
+
+ Ok(res)
+ }
+}
+
+#[derive(Clone, Copy)]
+pub enum KeyType {
+ Handshake,
+ Application,
+}
+
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[derive(Clone, PartialEq, Eq, MlsEncode, MlsDecode, MlsSize)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+/// AEAD key derived by the MLS secret tree.
+pub struct MessageKeyData {
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
+ pub(crate) nonce: Zeroizing<Vec<u8>>,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
+ pub(crate) key: Zeroizing<Vec<u8>>,
+ pub(crate) generation: u32,
+}
+
+impl Debug for MessageKeyData {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("MessageKeyData")
+ .field("nonce", &mls_rs_core::debug::pretty_bytes(&self.nonce))
+ .field("key", &mls_rs_core::debug::pretty_bytes(&self.key))
+ .field("generation", &self.generation)
+ .finish()
+ }
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
+impl MessageKeyData {
+ /// AEAD nonce.
+ #[cfg_attr(not(feature = "secret_tree_access"), allow(dead_code))]
+ pub fn nonce(&self) -> &[u8] {
+ &self.nonce
+ }
+
+ /// AEAD key.
+ #[cfg_attr(not(feature = "secret_tree_access"), allow(dead_code))]
+ pub fn key(&self) -> &[u8] {
+ &self.key
+ }
+
+ /// Generation of this key within the key schedule.
+ #[cfg_attr(not(feature = "secret_tree_access"), allow(dead_code))]
+ pub fn generation(&self) -> u32 {
+ self.generation
+ }
+}
+
+#[derive(Debug, Clone, PartialEq)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct SecretKeyRatchet {
+ secret: TreeSecret,
+ generation: u32,
+ #[cfg(all(feature = "out_of_order", feature = "std"))]
+ history: HashMap<u32, MessageKeyData>,
+ #[cfg(all(feature = "out_of_order", not(feature = "std")))]
+ history: BTreeMap<u32, MessageKeyData>,
+}
+
+impl MlsSize for SecretKeyRatchet {
+ fn mls_encoded_len(&self) -> usize {
+ let len = mls_rs_codec::byte_vec::mls_encoded_len(&self.secret)
+ + self.generation.mls_encoded_len();
+
+ #[cfg(feature = "out_of_order")]
+ return len + mls_rs_codec::iter::mls_encoded_len(self.history.values());
+ #[cfg(not(feature = "out_of_order"))]
+ return len;
+ }
+}
+
+#[cfg(feature = "out_of_order")]
+impl MlsEncode for SecretKeyRatchet {
+ fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
+ mls_rs_codec::byte_vec::mls_encode(&self.secret, writer)?;
+ self.generation.mls_encode(writer)?;
+ mls_rs_codec::iter::mls_encode(self.history.values(), writer)
+ }
+}
+
+#[cfg(not(feature = "out_of_order"))]
+impl MlsEncode for SecretKeyRatchet {
+ fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
+ mls_rs_codec::byte_vec::mls_encode(&self.secret, writer)?;
+ self.generation.mls_encode(writer)
+ }
+}
+
+impl MlsDecode for SecretKeyRatchet {
+ fn mls_decode(reader: &mut &[u8]) -> Result<Self, mls_rs_codec::Error> {
+ Ok(Self {
+ secret: mls_rs_codec::byte_vec::mls_decode(reader)?,
+ generation: u32::mls_decode(reader)?,
+ #[cfg(all(feature = "std", feature = "out_of_order"))]
+ history: mls_rs_codec::iter::mls_decode_collection(reader, |data| {
+ let mut items = HashMap::default();
+
+ while !data.is_empty() {
+ let item = MessageKeyData::mls_decode(data)?;
+ items.insert(item.generation, item);
+ }
+
+ Ok(items)
+ })?,
+ #[cfg(all(not(feature = "std"), feature = "out_of_order"))]
+ history: mls_rs_codec::iter::mls_decode_collection(reader, |data| {
+ let mut items = alloc::collections::BTreeMap::default();
+
+ while !data.is_empty() {
+ let item = MessageKeyData::mls_decode(data)?;
+ items.insert(item.generation, item);
+ }
+
+ Ok(items)
+ })?,
+ })
+ }
+}
+
+impl SecretKeyRatchet {
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn new<P: CipherSuiteProvider>(
+ cipher_suite_provider: &P,
+ secret: &[u8],
+ key_type: KeyType,
+ ) -> Result<Self, MlsError> {
+ let label = match key_type {
+ KeyType::Handshake => b"handshake".as_slice(),
+ KeyType::Application => b"application".as_slice(),
+ };
+
+ let secret = kdf_expand_with_label(cipher_suite_provider, secret, label, &[], None)
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+ Ok(Self {
+ secret: TreeSecret::from(secret),
+ generation: 0,
+ #[cfg(feature = "out_of_order")]
+ history: Default::default(),
+ })
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn get_message_key<P: CipherSuiteProvider>(
+ &mut self,
+ cipher_suite_provider: &P,
+ generation: u32,
+ ) -> Result<MessageKeyData, MlsError> {
+ #[cfg(feature = "out_of_order")]
+ if generation < self.generation {
+ return self
+ .history
+ .remove_entry(&generation)
+ .map(|(_, mk)| mk)
+ .ok_or(MlsError::KeyMissing(generation));
+ }
+
+ #[cfg(not(feature = "out_of_order"))]
+ if generation < self.generation {
+ return Err(MlsError::KeyMissing(generation));
+ }
+
+ let max_generation_allowed = self.generation + MAX_RATCHET_BACK_HISTORY;
+
+ if generation > max_generation_allowed {
+ return Err(MlsError::InvalidFutureGeneration(generation));
+ }
+
+ #[cfg(not(feature = "out_of_order"))]
+ while self.generation < generation {
+ self.next_message_key(cipher_suite_provider)?;
+ }
+
+ #[cfg(feature = "out_of_order")]
+ while self.generation < generation {
+ let key_data = self.next_message_key(cipher_suite_provider).await?;
+ self.history.insert(key_data.generation, key_data);
+ }
+
+ self.next_message_key(cipher_suite_provider).await
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn next_message_key<P: CipherSuiteProvider>(
+ &mut self,
+ cipher_suite_provider: &P,
+ ) -> Result<MessageKeyData, MlsError> {
+ let generation = self.generation;
+
+ let key = MessageKeyData {
+ nonce: self
+ .derive_secret(
+ cipher_suite_provider,
+ b"nonce",
+ cipher_suite_provider.aead_nonce_size(),
+ )
+ .await?,
+ key: self
+ .derive_secret(
+ cipher_suite_provider,
+ b"key",
+ cipher_suite_provider.aead_key_size(),
+ )
+ .await?,
+ generation,
+ };
+
+ self.secret = self
+ .derive_secret(
+ cipher_suite_provider,
+ b"secret",
+ cipher_suite_provider.kdf_extract_size(),
+ )
+ .await?
+ .into();
+
+ self.generation = generation + 1;
+
+ Ok(key)
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn derive_secret<P: CipherSuiteProvider>(
+ &self,
+ cipher_suite_provider: &P,
+ label: &[u8],
+ len: usize,
+ ) -> Result<Zeroizing<Vec<u8>>, MlsError> {
+ kdf_expand_with_label(
+ cipher_suite_provider,
+ self.secret.as_ref(),
+ label,
+ &self.generation.to_be_bytes(),
+ Some(len),
+ )
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+ }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+ use alloc::{string::String, vec::Vec};
+ use mls_rs_core::crypto::CipherSuiteProvider;
+ use zeroize::Zeroizing;
+
+ use crate::{crypto::test_utils::try_test_cipher_suite_provider, tree_kem::math::TreeIndex};
+
+ use super::{KeyType, SecretKeyRatchet, SecretTree};
+
+ pub(crate) fn get_test_tree<T: TreeIndex>(secret: Vec<u8>, leaf_count: T) -> SecretTree<T> {
+ SecretTree::new(leaf_count, Zeroizing::new(secret))
+ }
+
+ impl SecretTree<u32> {
+ pub(crate) fn get_root_secret(&self) -> Vec<u8> {
+ self.known_secrets
+ .clone()
+ .take_node(&self.leaf_count.root())
+ .unwrap()
+ .into_secret()
+ .unwrap()
+ .to_vec()
+ }
+ }
+
+ #[derive(Debug, serde::Serialize, serde::Deserialize)]
+ pub struct RatchetInteropTestCase {
+ #[serde(with = "hex::serde")]
+ secret: Vec<u8>,
+ label: String,
+ generation: u32,
+ length: usize,
+ #[serde(with = "hex::serde")]
+ out: Vec<u8>,
+ }
+
+ #[derive(Debug, serde::Serialize, serde::Deserialize)]
+ pub struct InteropTestCase {
+ cipher_suite: u16,
+ derive_tree_secret: RatchetInteropTestCase,
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_basic_crypto_test_vectors() {
+ let test_cases: Vec<InteropTestCase> =
+ load_test_case_json!(basic_crypto, Vec::<InteropTestCase>::new());
+
+ for test_case in test_cases {
+ if let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) {
+ test_case.derive_tree_secret.verify(&cs).await
+ }
+ }
+ }
+
+ impl RatchetInteropTestCase {
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn verify<P: CipherSuiteProvider>(&self, cs: &P) {
+ let mut ratchet = SecretKeyRatchet::new(cs, &self.secret, KeyType::Application)
+ .await
+ .unwrap();
+
+ ratchet.secret = self.secret.clone().into();
+ ratchet.generation = self.generation;
+
+ let computed = ratchet
+ .derive_secret(cs, self.label.as_bytes(), self.length)
+ .await
+ .unwrap();
+
+ assert_eq!(&computed.to_vec(), &self.out);
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use alloc::vec;
+
+ use crate::{
+ cipher_suite::CipherSuite,
+ client::test_utils::TEST_CIPHER_SUITE,
+ crypto::test_utils::{
+ test_cipher_suite_provider, try_test_cipher_suite_provider, TestCryptoProvider,
+ },
+ tree_kem::node::NodeIndex,
+ };
+
+ #[cfg(not(mls_build_async))]
+ use crate::group::test_utils::random_bytes;
+
+ use super::{test_utils::get_test_tree, *};
+
+ use assert_matches::assert_matches;
+
+ #[cfg(target_arch = "wasm32")]
+ use wasm_bindgen_test::wasm_bindgen_test as test;
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_secret_tree() {
+ test_secret_tree_custom(16u32, (0..16).map(|i| 2 * i).collect(), true).await;
+ test_secret_tree_custom(1u64 << 62, (1..62).map(|i| 1u64 << i).collect(), false).await;
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_secret_tree_custom<T: TreeIndex>(
+ leaf_count: T,
+ leaves_to_check: Vec<T>,
+ all_deleted: bool,
+ ) {
+ for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+ let cs_provider = test_cipher_suite_provider(cipher_suite);
+
+ let test_secret = vec![0u8; cs_provider.kdf_extract_size()];
+ let mut test_tree = get_test_tree(test_secret, leaf_count.clone());
+
+ let mut secrets = Vec::<SecretRatchets>::new();
+
+ for i in &leaves_to_check {
+ let secret = test_tree
+ .take_leaf_ratchet(&test_cipher_suite_provider(cipher_suite), i)
+ .await
+ .unwrap();
+
+ secrets.push(secret);
+ }
+
+ // Verify the tree is now completely empty
+ assert!(!all_deleted || test_tree.known_secrets.inner.is_empty());
+
+ // Verify that all the secrets are unique
+ let count = secrets.len();
+ secrets.dedup();
+ assert_eq!(count, secrets.len());
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_secret_key_ratchet() {
+ for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+ let provider = test_cipher_suite_provider(cipher_suite);
+
+ let mut app_ratchet = SecretKeyRatchet::new(
+ &provider,
+ &vec![0u8; provider.kdf_extract_size()],
+ KeyType::Application,
+ )
+ .await
+ .unwrap();
+
+ let mut handshake_ratchet = SecretKeyRatchet::new(
+ &provider,
+ &vec![0u8; provider.kdf_extract_size()],
+ KeyType::Handshake,
+ )
+ .await
+ .unwrap();
+
+ let app_key_one = app_ratchet.next_message_key(&provider).await.unwrap();
+ let app_key_two = app_ratchet.next_message_key(&provider).await.unwrap();
+ let app_keys = vec![app_key_one, app_key_two];
+
+ let handshake_key_one = handshake_ratchet.next_message_key(&provider).await.unwrap();
+ let handshake_key_two = handshake_ratchet.next_message_key(&provider).await.unwrap();
+ let handshake_keys = vec![handshake_key_one, handshake_key_two];
+
+ // Verify that the keys have different outcomes due to their different labels
+ assert_ne!(app_keys, handshake_keys);
+
+ // Verify that the keys at each generation are different
+ assert_ne!(handshake_keys[0], handshake_keys[1]);
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_get_key() {
+ for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+ let provider = test_cipher_suite_provider(cipher_suite);
+
+ let mut ratchet = SecretKeyRatchet::new(
+ &test_cipher_suite_provider(cipher_suite),
+ &vec![0u8; provider.kdf_extract_size()],
+ KeyType::Application,
+ )
+ .await
+ .unwrap();
+
+ let mut ratchet_clone = ratchet.clone();
+
+ // This will generate keys 0 and 1 in ratchet_clone
+ let _ = ratchet_clone.next_message_key(&provider).await.unwrap();
+ let clone_2 = ratchet_clone.next_message_key(&provider).await.unwrap();
+
+ // Going back in time should result in an error
+ let res = ratchet_clone.get_message_key(&provider, 0).await;
+ assert!(res.is_err());
+
+ // Calling get key should be the same as calling next until hitting the desired generation
+ let second_key = ratchet
+ .get_message_key(&provider, ratchet_clone.generation - 1)
+ .await
+ .unwrap();
+
+ assert_eq!(clone_2, second_key)
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_secret_ratchet() {
+ for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+ let provider = test_cipher_suite_provider(cipher_suite);
+
+ let mut ratchet = SecretKeyRatchet::new(
+ &provider,
+ &vec![0u8; provider.kdf_extract_size()],
+ KeyType::Application,
+ )
+ .await
+ .unwrap();
+
+ let original_secret = ratchet.secret.clone();
+ let _ = ratchet.next_message_key(&provider).await.unwrap();
+ let new_secret = ratchet.secret;
+ assert_ne!(original_secret, new_secret)
+ }
+ }
+
+ #[cfg(feature = "out_of_order")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_out_of_order_keys() {
+ let cipher_suite = TEST_CIPHER_SUITE;
+ let provider = test_cipher_suite_provider(cipher_suite);
+
+ let mut ratchet = SecretKeyRatchet::new(&provider, &[0u8; 32], KeyType::Handshake)
+ .await
+ .unwrap();
+ let mut ratchet_clone = ratchet.clone();
+
+ // Ask for all the keys in order from the original ratchet
+ let mut ordered_keys = Vec::<MessageKeyData>::new();
+
+ for i in 0..=MAX_RATCHET_BACK_HISTORY {
+ ordered_keys.push(ratchet.get_message_key(&provider, i).await.unwrap());
+ }
+
+ // Ask for a key at index MAX_RATCHET_BACK_HISTORY in the clone
+ let last_key = ratchet_clone
+ .get_message_key(&provider, MAX_RATCHET_BACK_HISTORY)
+ .await
+ .unwrap();
+
+ assert_eq!(last_key, ordered_keys[ordered_keys.len() - 1]);
+
+ // Get all the other keys
+ let mut back_history_keys = Vec::<MessageKeyData>::new();
+
+ for i in 0..MAX_RATCHET_BACK_HISTORY - 1 {
+ back_history_keys.push(ratchet_clone.get_message_key(&provider, i).await.unwrap());
+ }
+
+ assert_eq!(
+ back_history_keys,
+ ordered_keys[..(MAX_RATCHET_BACK_HISTORY as usize) - 1]
+ );
+ }
+
+ #[cfg(not(feature = "out_of_order"))]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn out_of_order_keys_should_throw_error() {
+ let cipher_suite = TEST_CIPHER_SUITE;
+ let provider = test_cipher_suite_provider(cipher_suite);
+
+ let mut ratchet = SecretKeyRatchet::new(&provider, &[0u8; 32], KeyType::Handshake)
+ .await
+ .unwrap();
+
+ ratchet.get_message_key(&provider, 10).await.unwrap();
+ let res = ratchet.get_message_key(&provider, 9).await;
+ assert_matches!(res, Err(MlsError::KeyMissing(9)))
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_too_out_of_order() {
+ let cipher_suite = TEST_CIPHER_SUITE;
+ let provider = test_cipher_suite_provider(cipher_suite);
+
+ let mut ratchet = SecretKeyRatchet::new(&provider, &[0u8; 32], KeyType::Handshake)
+ .await
+ .unwrap();
+
+ let res = ratchet
+ .get_message_key(&provider, MAX_RATCHET_BACK_HISTORY + 1)
+ .await;
+
+ let invalid_generation = MAX_RATCHET_BACK_HISTORY + 1;
+
+ assert_matches!(
+ res,
+ Err(MlsError::InvalidFutureGeneration(invalid))
+ if invalid == invalid_generation
+ )
+ }
+
+ #[derive(Debug, PartialEq, serde::Serialize, serde::Deserialize)]
+ struct Ratchet {
+ application_keys: Vec<Vec<u8>>,
+ handshake_keys: Vec<Vec<u8>>,
+ }
+
+ #[derive(Debug, serde::Serialize, serde::Deserialize)]
+ struct TestCase {
+ cipher_suite: u16,
+ #[serde(with = "hex::serde")]
+ encryption_secret: Vec<u8>,
+ ratchets: Vec<Ratchet>,
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn get_ratchet_data(
+ secret_tree: &mut SecretTree<NodeIndex>,
+ cipher_suite: CipherSuite,
+ ) -> Vec<Ratchet> {
+ let provider = test_cipher_suite_provider(cipher_suite);
+ let mut ratchet_data = Vec::new();
+
+ for index in 0..16 {
+ let mut ratchets = secret_tree
+ .take_leaf_ratchet(&provider, &(index * 2))
+ .await
+ .unwrap();
+
+ let mut application_keys = Vec::new();
+
+ for _ in 0..20 {
+ let key = ratchets
+ .handshake
+ .next_message_key(&provider)
+ .await
+ .unwrap()
+ .mls_encode_to_vec()
+ .unwrap();
+
+ application_keys.push(key);
+ }
+
+ let mut handshake_keys = Vec::new();
+
+ for _ in 0..20 {
+ let key = ratchets
+ .handshake
+ .next_message_key(&provider)
+ .await
+ .unwrap()
+ .mls_encode_to_vec()
+ .unwrap();
+
+ handshake_keys.push(key);
+ }
+
+ ratchet_data.push(Ratchet {
+ application_keys,
+ handshake_keys,
+ });
+ }
+
+ ratchet_data
+ }
+
+ #[cfg(not(mls_build_async))]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn generate_test_vector() -> Vec<TestCase> {
+ CipherSuite::all()
+ .map(|cipher_suite| {
+ let provider = test_cipher_suite_provider(cipher_suite);
+ let encryption_secret = random_bytes(provider.kdf_extract_size());
+
+ let mut secret_tree =
+ SecretTree::new(16, Zeroizing::new(encryption_secret.clone()));
+
+ TestCase {
+ cipher_suite: cipher_suite.into(),
+ encryption_secret,
+ ratchets: get_ratchet_data(&mut secret_tree, cipher_suite),
+ }
+ })
+ .collect()
+ }
+
+ #[cfg(mls_build_async)]
+ fn generate_test_vector() -> Vec<TestCase> {
+ panic!("Tests cannot be generated in async mode");
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_secret_tree_test_vectors() {
+ let test_cases: Vec<TestCase> = load_test_case_json!(secret_tree, generate_test_vector());
+
+ for case in test_cases {
+ let Some(cs_provider) = try_test_cipher_suite_provider(case.cipher_suite) else {
+ continue;
+ };
+
+ let mut secret_tree = SecretTree::new(16, Zeroizing::new(case.encryption_secret));
+ let ratchet_data = get_ratchet_data(&mut secret_tree, cs_provider.cipher_suite()).await;
+
+ assert_eq!(ratchet_data, case.ratchets);
+ }
+ }
+}
+
+#[cfg(all(test, feature = "rfc_compliant", feature = "std"))]
+mod interop_tests {
+ #[cfg(not(mls_build_async))]
+ use mls_rs_core::crypto::{CipherSuite, CipherSuiteProvider};
+ use zeroize::Zeroizing;
+
+ use crate::{
+ crypto::test_utils::try_test_cipher_suite_provider,
+ group::{ciphertext_processor::InteropSenderData, secret_tree::KeyType},
+ };
+
+ use super::SecretTree;
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn interop_test_vector() {
+ // The test vector can be found here https://github.com/mlswg/mls-implementations/blob/main/test-vectors/secret-tree.json
+ let test_cases = load_interop_test_cases();
+
+ for case in test_cases {
+ let Some(cs) = try_test_cipher_suite_provider(case.cipher_suite) else {
+ continue;
+ };
+
+ case.sender_data.verify(&cs).await;
+
+ let mut tree = SecretTree::new(
+ case.leaves.len() as u32,
+ Zeroizing::new(case.encryption_secret),
+ );
+
+ for (index, leaves) in case.leaves.iter().enumerate() {
+ for leaf in leaves.iter() {
+ let key = tree
+ .message_key_generation(
+ &cs,
+ (index as u32) * 2,
+ KeyType::Application,
+ leaf.generation,
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(key.key.to_vec(), leaf.application_key);
+ assert_eq!(key.nonce.to_vec(), leaf.application_nonce);
+
+ let key = tree
+ .message_key_generation(
+ &cs,
+ (index as u32) * 2,
+ KeyType::Handshake,
+ leaf.generation,
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(key.key.to_vec(), leaf.handshake_key);
+ assert_eq!(key.nonce.to_vec(), leaf.handshake_nonce);
+ }
+ }
+ }
+ }
+
+ #[derive(Debug, serde::Serialize, serde::Deserialize)]
+ struct InteropTestCase {
+ cipher_suite: u16,
+ #[serde(with = "hex::serde")]
+ encryption_secret: Vec<u8>,
+ sender_data: InteropSenderData,
+ leaves: Vec<Vec<InteropLeaf>>,
+ }
+
+ #[derive(Debug, serde::Serialize, serde::Deserialize)]
+ struct InteropLeaf {
+ generation: u32,
+ #[serde(with = "hex::serde")]
+ application_key: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ application_nonce: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ handshake_key: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ handshake_nonce: Vec<u8>,
+ }
+
+ fn load_interop_test_cases() -> Vec<InteropTestCase> {
+ load_test_case_json!(secret_tree_interop, generate_test_vector())
+ }
+
+ #[cfg(not(mls_build_async))]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn generate_test_vector() -> Vec<InteropTestCase> {
+ let mut test_cases = vec![];
+
+ for cs in CipherSuite::all() {
+ let Some(cs) = try_test_cipher_suite_provider(*cs) else {
+ continue;
+ };
+
+ let gens = [0, 15];
+ let tree_sizes = [1, 8, 32];
+
+ for n_leaves in tree_sizes {
+ let encryption_secret = cs.random_bytes_vec(cs.kdf_extract_size()).unwrap();
+
+ let mut tree = SecretTree::new(n_leaves, Zeroizing::new(encryption_secret.clone()));
+
+ let leaves = (0..n_leaves)
+ .map(|leaf| {
+ gens.into_iter()
+ .map(|gen| {
+ let index = leaf * 2u32;
+
+ let handshake_key = tree
+ .message_key_generation(&cs, index, KeyType::Handshake, gen)
+ .unwrap();
+
+ let app_key = tree
+ .message_key_generation(&cs, index, KeyType::Application, gen)
+ .unwrap();
+
+ InteropLeaf {
+ generation: gen,
+ application_key: app_key.key.to_vec(),
+ application_nonce: app_key.nonce.to_vec(),
+ handshake_key: handshake_key.key.to_vec(),
+ handshake_nonce: handshake_key.nonce.to_vec(),
+ }
+ })
+ .collect()
+ })
+ .collect();
+
+ let case = InteropTestCase {
+ cipher_suite: *cs.cipher_suite(),
+ encryption_secret,
+ sender_data: InteropSenderData::new(&cs),
+ leaves,
+ };
+
+ test_cases.push(case);
+ }
+ }
+
+ test_cases
+ }
+
+ #[cfg(mls_build_async)]
+ fn generate_test_vector() -> Vec<InteropTestCase> {
+ panic!("Tests cannot be generated in async mode");
+ }
+}
diff --git a/src/group/snapshot.rs b/src/group/snapshot.rs
new file mode 100644
index 0000000..dca64f8
--- /dev/null
+++ b/src/group/snapshot.rs
@@ -0,0 +1,325 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::{
+ client::MlsError,
+ client_config::ClientConfig,
+ group::{
+ key_schedule::KeySchedule, CommitGeneration, ConfirmationTag, Group, GroupContext,
+ GroupState, InterimTranscriptHash, ReInitProposal, TreeKemPublic,
+ },
+ tree_kem::TreeKemPrivate,
+};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::{
+ crypto::{HpkePublicKey, HpkeSecretKey},
+ group::ProposalRef,
+};
+
+#[cfg(feature = "by_ref_proposal")]
+use super::proposal_cache::{CachedProposal, ProposalCache};
+
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+
+use mls_rs_core::crypto::SignatureSecretKey;
+#[cfg(feature = "tree_index")]
+use mls_rs_core::identity::IdentityProvider;
+
+#[cfg(all(feature = "std", feature = "by_ref_proposal"))]
+use std::collections::HashMap;
+
+#[cfg(all(feature = "by_ref_proposal", not(feature = "std")))]
+use alloc::vec::Vec;
+
+use super::{cipher_suite_provider, epoch::EpochSecrets, state_repo::GroupStateRepository};
+
+#[derive(Debug, PartialEq, Clone, MlsEncode, MlsDecode, MlsSize)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct Snapshot {
+ version: u16,
+ pub(crate) state: RawGroupState,
+ private_tree: TreeKemPrivate,
+ epoch_secrets: EpochSecrets,
+ key_schedule: KeySchedule,
+ #[cfg(all(feature = "std", feature = "by_ref_proposal"))]
+ pending_updates: HashMap<HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>)>,
+ #[cfg(all(not(feature = "std"), feature = "by_ref_proposal"))]
+ pending_updates: Vec<(HpkePublicKey, (HpkeSecretKey, Option<SignatureSecretKey>))>,
+ pending_commit: Option<CommitGeneration>,
+ signer: SignatureSecretKey,
+}
+
+#[derive(Debug, MlsEncode, MlsDecode, MlsSize, PartialEq, Clone)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct RawGroupState {
+ pub(crate) context: GroupContext,
+ #[cfg(all(feature = "std", feature = "by_ref_proposal"))]
+ pub(crate) proposals: HashMap<ProposalRef, CachedProposal>,
+ #[cfg(all(not(feature = "std"), feature = "by_ref_proposal"))]
+ pub(crate) proposals: Vec<(ProposalRef, CachedProposal)>,
+ pub(crate) public_tree: TreeKemPublic,
+ pub(crate) interim_transcript_hash: InterimTranscriptHash,
+ pub(crate) pending_reinit: Option<ReInitProposal>,
+ pub(crate) confirmation_tag: ConfirmationTag,
+}
+
+impl RawGroupState {
+ pub(crate) fn export(state: &GroupState) -> Self {
+ #[cfg(feature = "tree_index")]
+ let public_tree = state.public_tree.clone();
+
+ #[cfg(not(feature = "tree_index"))]
+ let public_tree = {
+ let mut tree = TreeKemPublic::new();
+ tree.nodes = state.public_tree.nodes.clone();
+ tree
+ };
+
+ Self {
+ context: state.context.clone(),
+ #[cfg(feature = "by_ref_proposal")]
+ proposals: state.proposals.proposals.clone(),
+ public_tree,
+ interim_transcript_hash: state.interim_transcript_hash.clone(),
+ pending_reinit: state.pending_reinit.clone(),
+ confirmation_tag: state.confirmation_tag.clone(),
+ }
+ }
+
+ #[cfg(feature = "tree_index")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn import<C>(self, identity_provider: &C) -> Result<GroupState, MlsError>
+ where
+ C: IdentityProvider,
+ {
+ let context = self.context;
+
+ #[cfg(feature = "by_ref_proposal")]
+ let proposals = ProposalCache::import(
+ context.protocol_version,
+ context.group_id.clone(),
+ self.proposals,
+ );
+
+ let mut public_tree = self.public_tree;
+
+ public_tree
+ .initialize_index_if_necessary(identity_provider, &context.extensions)
+ .await?;
+
+ Ok(GroupState {
+ #[cfg(feature = "by_ref_proposal")]
+ proposals,
+ context,
+ public_tree,
+ interim_transcript_hash: self.interim_transcript_hash,
+ pending_reinit: self.pending_reinit,
+ confirmation_tag: self.confirmation_tag,
+ })
+ }
+
+ #[cfg(not(feature = "tree_index"))]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn import(self) -> Result<GroupState, MlsError> {
+ let context = self.context;
+
+ #[cfg(feature = "by_ref_proposal")]
+ let proposals = ProposalCache::import(
+ context.protocol_version,
+ context.group_id.clone(),
+ self.proposals,
+ );
+
+ Ok(GroupState {
+ #[cfg(feature = "by_ref_proposal")]
+ proposals,
+ context,
+ public_tree: self.public_tree,
+ interim_transcript_hash: self.interim_transcript_hash,
+ pending_reinit: self.pending_reinit,
+ confirmation_tag: self.confirmation_tag,
+ })
+ }
+}
+
+impl<C> Group<C>
+where
+ C: ClientConfig + Clone,
+{
+ /// Write the current state of the group to the
+ /// [`GroupStorageProvider`](crate::GroupStateStorage)
+ /// that is currently in use by the group.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn write_to_storage(&mut self) -> Result<(), MlsError> {
+ self.state_repo.write_to_storage(self.snapshot()).await
+ }
+
+ pub(crate) fn snapshot(&self) -> Snapshot {
+ Snapshot {
+ state: RawGroupState::export(&self.state),
+ private_tree: self.private_tree.clone(),
+ key_schedule: self.key_schedule.clone(),
+ #[cfg(feature = "by_ref_proposal")]
+ pending_updates: self.pending_updates.clone(),
+ pending_commit: self.pending_commit.clone(),
+ epoch_secrets: self.epoch_secrets.clone(),
+ version: 1,
+ signer: self.signer.clone(),
+ }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn from_snapshot(config: C, snapshot: Snapshot) -> Result<Self, MlsError> {
+ let cipher_suite_provider = cipher_suite_provider(
+ config.crypto_provider(),
+ snapshot.state.context.cipher_suite,
+ )?;
+
+ #[cfg(feature = "tree_index")]
+ let identity_provider = config.identity_provider();
+
+ let state_repo = GroupStateRepository::new(
+ #[cfg(feature = "prior_epoch")]
+ snapshot.state.context.group_id.clone(),
+ config.group_state_storage(),
+ config.key_package_repo(),
+ None,
+ )?;
+
+ Ok(Group {
+ config,
+ state: snapshot
+ .state
+ .import(
+ #[cfg(feature = "tree_index")]
+ &identity_provider,
+ )
+ .await?,
+ private_tree: snapshot.private_tree,
+ key_schedule: snapshot.key_schedule,
+ #[cfg(feature = "by_ref_proposal")]
+ pending_updates: snapshot.pending_updates,
+ pending_commit: snapshot.pending_commit,
+ #[cfg(test)]
+ commit_modifiers: Default::default(),
+ epoch_secrets: snapshot.epoch_secrets,
+ state_repo,
+ cipher_suite_provider,
+ #[cfg(feature = "psk")]
+ previous_psk: None,
+ signer: snapshot.signer,
+ })
+ }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+ use alloc::vec;
+
+ use crate::{
+ cipher_suite::CipherSuite,
+ crypto::test_utils::test_cipher_suite_provider,
+ group::{
+ confirmation_tag::ConfirmationTag, epoch::test_utils::get_test_epoch_secrets,
+ key_schedule::test_utils::get_test_key_schedule, test_utils::get_test_group_context,
+ transcript_hash::InterimTranscriptHash,
+ },
+ tree_kem::{node::LeafIndex, TreeKemPrivate},
+ };
+
+ use super::{RawGroupState, Snapshot};
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn get_test_snapshot(cipher_suite: CipherSuite, epoch_id: u64) -> Snapshot {
+ Snapshot {
+ state: RawGroupState {
+ context: get_test_group_context(epoch_id, cipher_suite).await,
+ #[cfg(feature = "by_ref_proposal")]
+ proposals: Default::default(),
+ public_tree: Default::default(),
+ interim_transcript_hash: InterimTranscriptHash::from(vec![]),
+ pending_reinit: None,
+ confirmation_tag: ConfirmationTag::empty(&test_cipher_suite_provider(cipher_suite))
+ .await,
+ },
+ private_tree: TreeKemPrivate::new(LeafIndex(0)),
+ epoch_secrets: get_test_epoch_secrets(cipher_suite),
+ key_schedule: get_test_key_schedule(cipher_suite),
+ #[cfg(feature = "by_ref_proposal")]
+ pending_updates: Default::default(),
+ pending_commit: None,
+ version: 1,
+ signer: vec![].into(),
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use alloc::vec;
+
+ use crate::{
+ client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+ group::{
+ test_utils::{test_group, TestGroup},
+ Group,
+ },
+ };
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn snapshot_restore(group: TestGroup) {
+ let snapshot = group.group.snapshot();
+
+ let group_restored = Group::from_snapshot(group.group.config.clone(), snapshot)
+ .await
+ .unwrap();
+
+ assert!(Group::equal_group_state(&group.group, &group_restored));
+
+ #[cfg(feature = "tree_index")]
+ assert!(group_restored
+ .state
+ .public_tree
+ .equal_internals(&group.group.state.public_tree))
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn snapshot_with_pending_commit_can_be_serialized_to_json() {
+ let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+ group.group.commit(vec![]).await.unwrap();
+
+ snapshot_restore(group).await
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn snapshot_with_pending_updates_can_be_serialized_to_json() {
+ let mut group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ // Creating the update proposal will add it to pending updates
+ let update_proposal = group.update_proposal().await;
+
+ // This will insert the proposal into the internal proposal cache
+ let _ = group.group.proposal_message(update_proposal, vec![]).await;
+
+ snapshot_restore(group).await
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn snapshot_can_be_serialized_to_json_with_internals() {
+ let group = test_group(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE).await;
+
+ snapshot_restore(group).await
+ }
+
+ #[cfg(feature = "serde")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn serde() {
+ let snapshot = super::test_utils::get_test_snapshot(TEST_CIPHER_SUITE, 5).await;
+ let json = serde_json::to_string_pretty(&snapshot).unwrap();
+ let recovered = serde_json::from_str(&json).unwrap();
+ assert_eq!(snapshot, recovered);
+ }
+}
diff --git a/src/group/state.rs b/src/group/state.rs
new file mode 100644
index 0000000..4d97a04
--- /dev/null
+++ b/src/group/state.rs
@@ -0,0 +1,43 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use super::{
+ confirmation_tag::ConfirmationTag, proposal::ReInitProposal,
+ transcript_hash::InterimTranscriptHash,
+};
+use crate::group::{GroupContext, TreeKemPublic};
+
+#[derive(Clone, Debug, PartialEq)]
+#[non_exhaustive]
+pub struct GroupState {
+ #[cfg(feature = "by_ref_proposal")]
+ pub(crate) proposals: crate::group::ProposalCache,
+ pub(crate) context: GroupContext,
+ pub(crate) public_tree: TreeKemPublic,
+ pub(crate) interim_transcript_hash: InterimTranscriptHash,
+ pub(crate) pending_reinit: Option<ReInitProposal>,
+ pub(crate) confirmation_tag: ConfirmationTag,
+}
+
+impl GroupState {
+ pub(crate) fn new(
+ context: GroupContext,
+ current_tree: TreeKemPublic,
+ interim_transcript_hash: InterimTranscriptHash,
+ confirmation_tag: ConfirmationTag,
+ ) -> Self {
+ Self {
+ #[cfg(feature = "by_ref_proposal")]
+ proposals: crate::group::ProposalCache::new(
+ context.protocol_version,
+ context.group_id.clone(),
+ ),
+ context,
+ public_tree: current_tree,
+ interim_transcript_hash,
+ pending_reinit: None,
+ confirmation_tag,
+ }
+ }
+}
diff --git a/src/group/state_repo.rs b/src/group/state_repo.rs
new file mode 100644
index 0000000..6e33b0a
--- /dev/null
+++ b/src/group/state_repo.rs
@@ -0,0 +1,573 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::client::MlsError;
+use crate::{group::PriorEpoch, key_package::KeyPackageRef};
+
+use alloc::collections::VecDeque;
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsDecode, MlsEncode};
+use mls_rs_core::group::{EpochRecord, GroupState};
+use mls_rs_core::{error::IntoAnyError, group::GroupStateStorage, key_package::KeyPackageStorage};
+
+use super::snapshot::Snapshot;
+
+#[cfg(feature = "psk")]
+use crate::group::ResumptionPsk;
+
+#[cfg(feature = "psk")]
+use mls_rs_core::psk::PreSharedKey;
+
+/// A set of changes to apply to a GroupStateStorage implementation. These changes MUST
+/// be made in a single transaction to avoid creating invalid states.
+#[derive(Default, Clone, Debug)]
+struct EpochStorageCommit {
+ pub(crate) inserts: VecDeque<PriorEpoch>,
+ pub(crate) updates: Vec<PriorEpoch>,
+}
+
+#[derive(Clone)]
+pub(crate) struct GroupStateRepository<S, K>
+where
+ S: GroupStateStorage,
+ K: KeyPackageStorage,
+{
+ pending_commit: EpochStorageCommit,
+ pending_key_package_removal: Option<KeyPackageRef>,
+ group_id: Vec<u8>,
+ storage: S,
+ key_package_repo: K,
+}
+
+impl<S, K> Debug for GroupStateRepository<S, K>
+where
+ S: GroupStateStorage + Debug,
+ K: KeyPackageStorage + Debug,
+{
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("GroupStateRepository")
+ .field("pending_commit", &self.pending_commit)
+ .field(
+ "pending_key_package_removal",
+ &self.pending_key_package_removal,
+ )
+ .field(
+ "group_id",
+ &mls_rs_core::debug::pretty_group_id(&self.group_id),
+ )
+ .field("storage", &self.storage)
+ .field("key_package_repo", &self.key_package_repo)
+ .finish()
+ }
+}
+
+impl<S, K> GroupStateRepository<S, K>
+where
+ S: GroupStateStorage,
+ K: KeyPackageStorage,
+{
+ pub fn new(
+ group_id: Vec<u8>,
+ storage: S,
+ key_package_repo: K,
+ // Set to `None` if restoring from snapshot; set to `Some` when joining a group.
+ key_package_to_remove: Option<KeyPackageRef>,
+ ) -> Result<GroupStateRepository<S, K>, MlsError> {
+ Ok(GroupStateRepository {
+ group_id,
+ storage,
+ pending_key_package_removal: key_package_to_remove,
+ pending_commit: Default::default(),
+ key_package_repo,
+ })
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn find_max_id(&self) -> Result<Option<u64>, MlsError> {
+ if let Some(max) = self.pending_commit.inserts.back().map(|e| e.epoch_id()) {
+ Ok(Some(max))
+ } else {
+ self.storage
+ .max_epoch_id(&self.group_id)
+ .await
+ .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))
+ }
+ }
+
+ #[cfg(feature = "psk")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn resumption_secret(
+ &self,
+ psk_id: &ResumptionPsk,
+ ) -> Result<Option<PreSharedKey>, MlsError> {
+ // Search the local inserts cache
+ if let Some(min) = self.pending_commit.inserts.front().map(|e| e.epoch_id()) {
+ if psk_id.psk_epoch >= min {
+ return Ok(self
+ .pending_commit
+ .inserts
+ .get((psk_id.psk_epoch - min) as usize)
+ .map(|e| e.secrets.resumption_secret.clone()));
+ }
+ }
+
+ // Search the local updates cache
+ let maybe_pending = self.find_pending(psk_id.psk_epoch);
+
+ if let Some(pending) = maybe_pending {
+ return Ok(Some(
+ self.pending_commit.updates[pending]
+ .secrets
+ .resumption_secret
+ .clone(),
+ ));
+ }
+
+ // Search the stored cache
+ self.storage
+ .epoch(&psk_id.psk_group_id.0, psk_id.psk_epoch)
+ .await
+ .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?
+ .map(|e| Ok(PriorEpoch::mls_decode(&mut &*e)?.secrets.resumption_secret))
+ .transpose()
+ }
+
+ #[cfg(feature = "private_message")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn get_epoch_mut(
+ &mut self,
+ epoch_id: u64,
+ ) -> Result<Option<&mut PriorEpoch>, MlsError> {
+ // Search the local inserts cache
+ if let Some(min) = self.pending_commit.inserts.front().map(|e| e.epoch_id()) {
+ if epoch_id >= min {
+ return Ok(self
+ .pending_commit
+ .inserts
+ .get_mut((epoch_id - min) as usize));
+ }
+ }
+
+ // Look in the cached updates map, and if not found look in disk storage
+ // and insert into the updates map for future caching
+ match self.find_pending(epoch_id) {
+ Some(i) => self.pending_commit.updates.get_mut(i).map(Ok),
+ None => self
+ .storage
+ .epoch(&self.group_id, epoch_id)
+ .await
+ .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?
+ .and_then(|epoch| {
+ PriorEpoch::mls_decode(&mut &*epoch)
+ .map(|epoch| {
+ self.pending_commit.updates.push(epoch);
+ self.pending_commit.updates.last_mut()
+ })
+ .transpose()
+ }),
+ }
+ .transpose()
+ .map_err(Into::into)
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn insert(&mut self, epoch: PriorEpoch) -> Result<(), MlsError> {
+ if epoch.group_id() != self.group_id {
+ return Err(MlsError::GroupIdMismatch);
+ }
+
+ let epoch_id = epoch.epoch_id();
+
+ if let Some(expected_id) = self.find_max_id().await?.map(|id| id + 1) {
+ if epoch_id != expected_id {
+ return Err(MlsError::InvalidEpoch);
+ }
+ }
+
+ self.pending_commit.inserts.push_back(epoch);
+
+ Ok(())
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn write_to_storage(&mut self, group_snapshot: Snapshot) -> Result<(), MlsError> {
+ let inserts = self
+ .pending_commit
+ .inserts
+ .iter()
+ .map(|e| Ok(EpochRecord::new(e.epoch_id(), e.mls_encode_to_vec()?)))
+ .collect::<Result<_, MlsError>>()?;
+
+ let updates = self
+ .pending_commit
+ .updates
+ .iter()
+ .map(|e| Ok(EpochRecord::new(e.epoch_id(), e.mls_encode_to_vec()?)))
+ .collect::<Result<_, MlsError>>()?;
+
+ let group_state = GroupState {
+ data: group_snapshot.mls_encode_to_vec()?,
+ id: group_snapshot.state.context.group_id,
+ };
+
+ self.storage
+ .write(group_state, inserts, updates)
+ .await
+ .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?;
+
+ if let Some(ref key_package_ref) = self.pending_key_package_removal {
+ self.key_package_repo
+ .delete(key_package_ref)
+ .await
+ .map_err(|e| MlsError::KeyPackageRepoError(e.into_any_error()))?;
+ }
+
+ self.pending_commit.inserts.clear();
+ self.pending_commit.updates.clear();
+
+ Ok(())
+ }
+
+ #[cfg(any(feature = "psk", feature = "private_message"))]
+ fn find_pending(&self, epoch_id: u64) -> Option<usize> {
+ self.pending_commit
+ .updates
+ .iter()
+ .position(|ep| ep.context.epoch == epoch_id)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use alloc::vec;
+ use mls_rs_codec::MlsEncode;
+
+ use crate::{
+ client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+ group::{
+ epoch::{test_utils::get_test_epoch_with_id, SenderDataSecret},
+ test_utils::{random_bytes, test_member, TEST_GROUP},
+ PskGroupId, ResumptionPSKUsage,
+ },
+ storage_provider::in_memory::{InMemoryGroupStateStorage, InMemoryKeyPackageStorage},
+ };
+
+ use super::*;
+
+ fn test_group_state_repo(
+ retention_limit: usize,
+ ) -> GroupStateRepository<InMemoryGroupStateStorage, InMemoryKeyPackageStorage> {
+ GroupStateRepository::new(
+ TEST_GROUP.to_vec(),
+ InMemoryGroupStateStorage::new()
+ .with_max_epoch_retention(retention_limit)
+ .unwrap(),
+ InMemoryKeyPackageStorage::default(),
+ None,
+ )
+ .unwrap()
+ }
+
+ fn test_epoch(epoch_id: u64) -> PriorEpoch {
+ get_test_epoch_with_id(TEST_GROUP.to_vec(), TEST_CIPHER_SUITE, epoch_id)
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_snapshot(epoch_id: u64) -> Snapshot {
+ crate::group::snapshot::test_utils::get_test_snapshot(TEST_CIPHER_SUITE, epoch_id).await
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_epoch_inserts() {
+ let mut test_repo = test_group_state_repo(1);
+ let test_epoch = test_epoch(0);
+
+ test_repo.insert(test_epoch.clone()).await.unwrap();
+
+ // Check the in-memory state
+ assert_eq!(
+ test_repo.pending_commit.inserts.back().unwrap(),
+ &test_epoch
+ );
+
+ assert!(test_repo.pending_commit.updates.is_empty());
+
+ #[cfg(feature = "std")]
+ assert!(test_repo.storage.inner.lock().unwrap().is_empty());
+ #[cfg(not(feature = "std"))]
+ assert!(test_repo.storage.inner.lock().is_empty());
+
+ let psk_id = ResumptionPsk {
+ psk_epoch: 0,
+ psk_group_id: PskGroupId(test_repo.group_id.clone()),
+ usage: ResumptionPSKUsage::Application,
+ };
+
+ // Make sure you can recall an epoch sitting as a pending insert
+ let resumption = test_repo.resumption_secret(&psk_id).await.unwrap();
+ let prior_epoch = test_repo.get_epoch_mut(0).await.unwrap().cloned();
+
+ assert_eq!(
+ prior_epoch.clone().unwrap().secrets.resumption_secret,
+ resumption.unwrap()
+ );
+
+ assert_eq!(prior_epoch.unwrap(), test_epoch);
+
+ // Write to the storage
+ let snapshot = test_snapshot(test_epoch.epoch_id()).await;
+ test_repo.write_to_storage(snapshot.clone()).await.unwrap();
+
+ // Make sure the memory cache cleared
+ assert!(test_repo.pending_commit.inserts.is_empty());
+ assert!(test_repo.pending_commit.updates.is_empty());
+
+ // Make sure the storage was written
+ #[cfg(feature = "std")]
+ let storage = test_repo.storage.inner.lock().unwrap();
+ #[cfg(not(feature = "std"))]
+ let storage = test_repo.storage.inner.lock();
+
+ assert_eq!(storage.len(), 1);
+
+ let stored = storage.get(TEST_GROUP).unwrap();
+
+ assert_eq!(stored.state_data, snapshot.mls_encode_to_vec().unwrap());
+
+ assert_eq!(stored.epoch_data.len(), 1);
+
+ assert_eq!(
+ stored.epoch_data.back().unwrap(),
+ &EpochRecord::new(
+ test_epoch.epoch_id(),
+ test_epoch.mls_encode_to_vec().unwrap()
+ )
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_updates() {
+ let mut test_repo = test_group_state_repo(2);
+ let test_epoch_0 = test_epoch(0);
+
+ test_repo.insert(test_epoch_0.clone()).await.unwrap();
+
+ test_repo
+ .write_to_storage(test_snapshot(0).await)
+ .await
+ .unwrap();
+
+ // Update the stored epoch
+ let to_update = test_repo.get_epoch_mut(0).await.unwrap().unwrap();
+ assert_eq!(to_update, &test_epoch_0);
+
+ let new_sender_secret = random_bytes(32);
+ to_update.secrets.sender_data_secret = SenderDataSecret::from(new_sender_secret);
+ let to_update = to_update.clone();
+
+ assert_eq!(test_repo.pending_commit.updates.len(), 1);
+ assert!(test_repo.pending_commit.inserts.is_empty());
+
+ assert_eq!(
+ test_repo.pending_commit.updates.first().unwrap(),
+ &to_update
+ );
+
+ // Make sure you can access an epoch pending update
+ let psk_id = ResumptionPsk {
+ psk_epoch: 0,
+ psk_group_id: PskGroupId(test_repo.group_id.clone()),
+ usage: ResumptionPSKUsage::Application,
+ };
+
+ let owned = test_repo.resumption_secret(&psk_id).await.unwrap();
+ assert_eq!(owned.as_ref(), Some(&to_update.secrets.resumption_secret));
+
+ // Write the update to storage
+ let snapshot = test_snapshot(1).await;
+ test_repo.write_to_storage(snapshot.clone()).await.unwrap();
+
+ assert!(test_repo.pending_commit.updates.is_empty());
+ assert!(test_repo.pending_commit.inserts.is_empty());
+
+ // Make sure the storage was written
+ #[cfg(feature = "std")]
+ let storage = test_repo.storage.inner.lock().unwrap();
+ #[cfg(not(feature = "std"))]
+ let storage = test_repo.storage.inner.lock();
+
+ assert_eq!(storage.len(), 1);
+
+ let stored = storage.get(TEST_GROUP).unwrap();
+
+ assert_eq!(stored.state_data, snapshot.mls_encode_to_vec().unwrap());
+
+ assert_eq!(stored.epoch_data.len(), 1);
+
+ assert_eq!(
+ stored.epoch_data.back().unwrap(),
+ &EpochRecord::new(to_update.epoch_id(), to_update.mls_encode_to_vec().unwrap())
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_insert_and_update() {
+ let mut test_repo = test_group_state_repo(2);
+ let test_epoch_0 = test_epoch(0);
+
+ test_repo.insert(test_epoch_0).await.unwrap();
+
+ test_repo
+ .write_to_storage(test_snapshot(0).await)
+ .await
+ .unwrap();
+
+ // Update the stored epoch
+ let to_update = test_repo.get_epoch_mut(0).await.unwrap().unwrap();
+ let new_sender_secret = random_bytes(32);
+ to_update.secrets.sender_data_secret = SenderDataSecret::from(new_sender_secret);
+ let to_update = to_update.clone();
+
+ // Insert another epoch
+ let test_epoch_1 = test_epoch(1);
+ test_repo.insert(test_epoch_1.clone()).await.unwrap();
+
+ test_repo
+ .write_to_storage(test_snapshot(1).await)
+ .await
+ .unwrap();
+
+ assert!(test_repo.pending_commit.inserts.is_empty());
+ assert!(test_repo.pending_commit.updates.is_empty());
+
+ // Make sure the storage was written
+ #[cfg(feature = "std")]
+ let storage = test_repo.storage.inner.lock().unwrap();
+ #[cfg(not(feature = "std"))]
+ let storage = test_repo.storage.inner.lock();
+
+ assert_eq!(storage.len(), 1);
+
+ let stored = storage.get(TEST_GROUP).unwrap();
+
+ assert_eq!(stored.epoch_data.len(), 2);
+
+ assert_eq!(
+ stored.epoch_data.front().unwrap(),
+ &EpochRecord::new(to_update.epoch_id(), to_update.mls_encode_to_vec().unwrap())
+ );
+
+ assert_eq!(
+ stored.epoch_data.back().unwrap(),
+ &EpochRecord::new(
+ test_epoch_1.epoch_id(),
+ test_epoch_1.mls_encode_to_vec().unwrap()
+ )
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_many_epochs_in_storage() {
+ let epochs = (0..10).map(test_epoch).collect::<Vec<_>>();
+
+ let mut test_repo = test_group_state_repo(10);
+
+ for epoch in epochs.iter().cloned() {
+ test_repo.insert(epoch).await.unwrap()
+ }
+
+ test_repo
+ .write_to_storage(test_snapshot(9).await)
+ .await
+ .unwrap();
+
+ for mut epoch in epochs {
+ let res = test_repo.get_epoch_mut(epoch.epoch_id()).await.unwrap();
+
+ assert_eq!(res, Some(&mut epoch));
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_stored_groups_list() {
+ let mut test_repo = test_group_state_repo(2);
+ let test_epoch_0 = test_epoch(0);
+
+ test_repo.insert(test_epoch_0.clone()).await.unwrap();
+
+ test_repo
+ .write_to_storage(test_snapshot(0).await)
+ .await
+ .unwrap();
+
+ assert_eq!(
+ test_repo.storage.stored_groups(),
+ vec![test_epoch_0.context.group_id]
+ )
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn reducing_retention_limit_takes_effect_on_epoch_access() {
+ let mut repo = test_group_state_repo(1);
+
+ repo.insert(test_epoch(0)).await.unwrap();
+ repo.insert(test_epoch(1)).await.unwrap();
+
+ repo.write_to_storage(test_snapshot(0).await).await.unwrap();
+
+ let mut repo = GroupStateRepository {
+ storage: repo.storage,
+ ..test_group_state_repo(1)
+ };
+
+ let res = repo.get_epoch_mut(0).await.unwrap();
+
+ assert!(res.is_none());
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn in_memory_storage_obeys_retention_limit_after_saving() {
+ let mut repo = test_group_state_repo(1);
+
+ repo.insert(test_epoch(0)).await.unwrap();
+ repo.write_to_storage(test_snapshot(0).await).await.unwrap();
+ repo.insert(test_epoch(1)).await.unwrap();
+ repo.write_to_storage(test_snapshot(1).await).await.unwrap();
+
+ #[cfg(feature = "std")]
+ let lock = repo.storage.inner.lock().unwrap();
+ #[cfg(not(feature = "std"))]
+ let lock = repo.storage.inner.lock();
+
+ assert_eq!(lock.get(TEST_GROUP).unwrap().epoch_data.len(), 1);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn used_key_package_is_deleted() {
+ let key_package_repo = InMemoryKeyPackageStorage::default();
+
+ let key_package = test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"member")
+ .await
+ .0;
+
+ let (id, data) = key_package.to_storage().unwrap();
+
+ key_package_repo.insert(id, data);
+
+ let mut repo = GroupStateRepository::new(
+ TEST_GROUP.to_vec(),
+ InMemoryGroupStateStorage::new(),
+ key_package_repo,
+ Some(key_package.reference.clone()),
+ )
+ .unwrap();
+
+ repo.key_package_repo.get(&key_package.reference).unwrap();
+
+ repo.write_to_storage(test_snapshot(4).await).await.unwrap();
+
+ assert!(repo.key_package_repo.get(&key_package.reference).is_none());
+ }
+}
diff --git a/src/group/state_repo_light.rs b/src/group/state_repo_light.rs
new file mode 100644
index 0000000..76d1fb6
--- /dev/null
+++ b/src/group/state_repo_light.rs
@@ -0,0 +1,132 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::client::MlsError;
+use crate::key_package::KeyPackageRef;
+
+use alloc::vec::Vec;
+use mls_rs_codec::MlsEncode;
+use mls_rs_core::{
+ error::IntoAnyError,
+ group::{GroupState, GroupStateStorage},
+ key_package::KeyPackageStorage,
+};
+
+use super::snapshot::Snapshot;
+
+#[derive(Debug, Clone)]
+pub(crate) struct GroupStateRepository<S, K>
+where
+ S: GroupStateStorage,
+ K: KeyPackageStorage,
+{
+ pending_key_package_removal: Option<KeyPackageRef>,
+ storage: S,
+ key_package_repo: K,
+}
+
+impl<S, K> GroupStateRepository<S, K>
+where
+ S: GroupStateStorage,
+ K: KeyPackageStorage,
+{
+ pub fn new(
+ storage: S,
+ key_package_repo: K,
+ // Set to `None` if restoring from snapshot; set to `Some` when joining a group.
+ key_package_to_remove: Option<KeyPackageRef>,
+ ) -> Result<GroupStateRepository<S, K>, MlsError> {
+ Ok(GroupStateRepository {
+ storage,
+ pending_key_package_removal: key_package_to_remove,
+ key_package_repo,
+ })
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn write_to_storage(&mut self, group_snapshot: Snapshot) -> Result<(), MlsError> {
+ let group_state = GroupState {
+ data: group_snapshot.mls_encode_to_vec()?,
+ id: group_snapshot.state.context.group_id,
+ };
+
+ self.storage
+ .write(group_state, Vec::new(), Vec::new())
+ .await
+ .map_err(|e| MlsError::GroupStorageError(e.into_any_error()))?;
+
+ if let Some(ref key_package_ref) = self.pending_key_package_removal {
+ self.key_package_repo
+ .delete(key_package_ref)
+ .await
+ .map_err(|e| MlsError::KeyPackageRepoError(e.into_any_error()))?;
+ }
+
+ Ok(())
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::{
+ client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+ group::{
+ snapshot::{test_utils::get_test_snapshot, Snapshot},
+ test_utils::{test_member, TEST_GROUP},
+ },
+ storage_provider::in_memory::{InMemoryGroupStateStorage, InMemoryKeyPackageStorage},
+ };
+
+ use alloc::vec;
+
+ use super::GroupStateRepository;
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_snapshot(epoch_id: u64) -> Snapshot {
+ get_test_snapshot(TEST_CIPHER_SUITE, epoch_id).await
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_stored_groups_list() {
+ let mut test_repo = GroupStateRepository::new(
+ InMemoryGroupStateStorage::default(),
+ InMemoryKeyPackageStorage::default(),
+ None,
+ )
+ .unwrap();
+
+ test_repo
+ .write_to_storage(test_snapshot(0).await)
+ .await
+ .unwrap();
+
+ assert_eq!(test_repo.storage.stored_groups(), vec![TEST_GROUP])
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn used_key_package_is_deleted() {
+ let key_package_repo = InMemoryKeyPackageStorage::default();
+
+ let key_package = test_member(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, b"member")
+ .await
+ .0;
+
+ let (id, data) = key_package.to_storage().unwrap();
+
+ key_package_repo.insert(id, data);
+
+ let mut repo = GroupStateRepository::new(
+ InMemoryGroupStateStorage::default(),
+ key_package_repo,
+ Some(key_package.reference.clone()),
+ )
+ .unwrap();
+
+ repo.key_package_repo.get(&key_package.reference).unwrap();
+
+ repo.write_to_storage(test_snapshot(4).await).await.unwrap();
+
+ assert!(repo.key_package_repo.get(&key_package.reference).is_none());
+ }
+}
diff --git a/src/group/test_utils.rs b/src/group/test_utils.rs
new file mode 100644
index 0000000..764d5e6
--- /dev/null
+++ b/src/group/test_utils.rs
@@ -0,0 +1,521 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use core::ops::{Deref, DerefMut};
+
+use alloc::format;
+use rand::RngCore;
+
+use super::*;
+use crate::{
+ client::{
+ test_utils::{
+ test_client_with_key_pkg, test_client_with_key_pkg_custom, TEST_CIPHER_SUITE,
+ TEST_PROTOCOL_VERSION,
+ },
+ MlsError,
+ },
+ client_builder::test_utils::{TestClientBuilder, TestClientConfig},
+ crypto::test_utils::test_cipher_suite_provider,
+ extension::ExtensionType,
+ identity::basic::BasicIdentityProvider,
+ identity::test_utils::get_test_signing_identity,
+ key_package::{KeyPackageGeneration, KeyPackageGenerator},
+ mls_rules::{CommitOptions, DefaultMlsRules},
+ tree_kem::{leaf_node::test_utils::get_test_capabilities, Lifetime},
+};
+
+use crate::extension::RequiredCapabilitiesExt;
+
+#[cfg(not(feature = "by_ref_proposal"))]
+use crate::crypto::HpkePublicKey;
+
+pub const TEST_GROUP: &[u8] = b"group";
+
+#[derive(Clone)]
+pub(crate) struct TestGroup {
+ pub group: Group<TestClientConfig>,
+}
+
+impl TestGroup {
+ #[cfg(feature = "external_client")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn propose(&mut self, proposal: Proposal) -> MlsMessage {
+ self.group.proposal_message(proposal, vec![]).await.unwrap()
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn update_proposal(&mut self) -> Proposal {
+ self.group.update_proposal(None, None).await.unwrap()
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn join_with_custom_config<F>(
+ &mut self,
+ name: &str,
+ custom_kp: bool,
+ mut config: F,
+ ) -> Result<(TestGroup, MlsMessage), MlsError>
+ where
+ F: FnMut(&mut TestClientConfig),
+ {
+ let (mut new_client, new_key_package) = if custom_kp {
+ test_client_with_key_pkg_custom(
+ self.group.protocol_version(),
+ self.group.cipher_suite(),
+ name,
+ &mut config,
+ )
+ .await
+ } else {
+ test_client_with_key_pkg(
+ self.group.protocol_version(),
+ self.group.cipher_suite(),
+ name,
+ )
+ .await
+ };
+
+ // Add new member to the group
+ let CommitOutput {
+ welcome_messages,
+ ratchet_tree,
+ commit_message,
+ ..
+ } = self
+ .group
+ .commit_builder()
+ .add_member(new_key_package)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ // Apply the commit to the original group
+ self.group.apply_pending_commit().await.unwrap();
+
+ config(&mut new_client.config);
+
+ // Group from new member's perspective
+ let (new_group, _) = Group::join(
+ &welcome_messages[0],
+ ratchet_tree,
+ new_client.config.clone(),
+ new_client.signer.clone().unwrap(),
+ )
+ .await?;
+
+ let new_test_group = TestGroup { group: new_group };
+
+ Ok((new_test_group, commit_message))
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn join(&mut self, name: &str) -> (TestGroup, MlsMessage) {
+ self.join_with_custom_config(name, false, |_| ())
+ .await
+ .unwrap()
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn process_pending_commit(
+ &mut self,
+ ) -> Result<CommitMessageDescription, MlsError> {
+ self.group.apply_pending_commit().await
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn process_message(
+ &mut self,
+ message: MlsMessage,
+ ) -> Result<ReceivedMessage, MlsError> {
+ self.group.process_incoming_message(message).await
+ }
+
+ #[cfg(feature = "private_message")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn make_plaintext(&mut self, content: Content) -> MlsMessage {
+ let auth_content = AuthenticatedContent::new_signed(
+ &self.group.cipher_suite_provider,
+ &self.group.state.context,
+ Sender::Member(*self.group.private_tree.self_index),
+ content,
+ &self.group.signer,
+ WireFormat::PublicMessage,
+ Vec::new(),
+ )
+ .await
+ .unwrap();
+
+ self.group.format_for_wire(auth_content).await.unwrap()
+ }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn get_test_group_context(epoch: u64, cipher_suite: CipherSuite) -> GroupContext {
+ let cs = test_cipher_suite_provider(cipher_suite);
+
+ GroupContext {
+ protocol_version: TEST_PROTOCOL_VERSION,
+ cipher_suite,
+ group_id: TEST_GROUP.to_vec(),
+ epoch,
+ tree_hash: cs.hash(&[1, 2, 3]).await.unwrap(),
+ confirmed_transcript_hash: cs.hash(&[3, 2, 1]).await.unwrap().into(),
+ extensions: ExtensionList::from(vec![]),
+ }
+}
+
+#[cfg(feature = "prior_epoch")]
+pub(crate) fn get_test_group_context_with_id(
+ group_id: Vec<u8>,
+ epoch: u64,
+ cipher_suite: CipherSuite,
+) -> GroupContext {
+ GroupContext {
+ protocol_version: TEST_PROTOCOL_VERSION,
+ cipher_suite,
+ group_id,
+ epoch,
+ tree_hash: vec![],
+ confirmed_transcript_hash: ConfirmedTranscriptHash::from(vec![]),
+ extensions: ExtensionList::from(vec![]),
+ }
+}
+
+pub(crate) fn group_extensions() -> ExtensionList {
+ let required_capabilities = RequiredCapabilitiesExt::default();
+
+ let mut extensions = ExtensionList::new();
+ extensions.set_from(required_capabilities).unwrap();
+ extensions
+}
+
+pub(crate) fn lifetime() -> Lifetime {
+ Lifetime::years(1).unwrap()
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn test_member(
+ protocol_version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+ identifier: &[u8],
+) -> (KeyPackageGeneration, SignatureSecretKey) {
+ let (signing_identity, signing_key) = get_test_signing_identity(cipher_suite, identifier).await;
+
+ let key_package_generator = KeyPackageGenerator {
+ protocol_version,
+ cipher_suite_provider: &test_cipher_suite_provider(cipher_suite),
+ signing_identity: &signing_identity,
+ signing_key: &signing_key,
+ identity_provider: &BasicIdentityProvider,
+ };
+
+ let key_package = key_package_generator
+ .generate(
+ lifetime(),
+ get_test_capabilities(),
+ ExtensionList::default(),
+ ExtensionList::default(),
+ )
+ .await
+ .unwrap();
+
+ (key_package, signing_key)
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn test_group_custom(
+ protocol_version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+ extension_types: Vec<ExtensionType>,
+ leaf_extensions: Option<ExtensionList>,
+ commit_options: Option<CommitOptions>,
+) -> TestGroup {
+ let leaf_extensions = leaf_extensions.unwrap_or_default();
+ let commit_options = commit_options.unwrap_or_default();
+
+ let (signing_identity, secret_key) = get_test_signing_identity(cipher_suite, b"member").await;
+
+ let group = TestClientBuilder::new_for_test()
+ .leaf_node_extensions(leaf_extensions)
+ .mls_rules(DefaultMlsRules::default().with_commit_options(commit_options))
+ .extension_types(extension_types)
+ .protocol_versions(ProtocolVersion::all())
+ .used_protocol_version(protocol_version)
+ .signing_identity(signing_identity.clone(), secret_key, cipher_suite)
+ .build()
+ .create_group_with_id(TEST_GROUP.to_vec(), group_extensions())
+ .await
+ .unwrap();
+
+ TestGroup { group }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn test_group(
+ protocol_version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+) -> TestGroup {
+ test_group_custom(
+ protocol_version,
+ cipher_suite,
+ Default::default(),
+ None,
+ None,
+ )
+ .await
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn test_group_custom_config<F>(
+ protocol_version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+ custom: F,
+) -> TestGroup
+where
+ F: FnOnce(TestClientBuilder) -> TestClientBuilder,
+{
+ let (signing_identity, secret_key) = get_test_signing_identity(cipher_suite, b"member").await;
+
+ let client_builder = TestClientBuilder::new_for_test().used_protocol_version(protocol_version);
+
+ let group = custom(client_builder)
+ .signing_identity(signing_identity.clone(), secret_key, cipher_suite)
+ .build()
+ .create_group_with_id(TEST_GROUP.to_vec(), group_extensions())
+ .await
+ .unwrap();
+
+ TestGroup { group }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn test_n_member_group(
+ protocol_version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+ num_members: usize,
+) -> Vec<TestGroup> {
+ let group = test_group(protocol_version, cipher_suite).await;
+
+ let mut groups = vec![group];
+
+ for i in 1..num_members {
+ let (new_group, commit) = groups.get_mut(0).unwrap().join(&format!("name {i}")).await;
+ process_commit(&mut groups, commit, 0).await;
+ groups.push(new_group);
+ }
+
+ groups
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn process_commit(groups: &mut [TestGroup], commit: MlsMessage, excluded: u32) {
+ for g in groups
+ .iter_mut()
+ .filter(|g| g.group.current_member_index() != excluded)
+ {
+ g.process_message(commit.clone()).await.unwrap();
+ }
+}
+
+pub(crate) fn get_test_25519_key(key_byte: u8) -> HpkePublicKey {
+ vec![key_byte; 32].into()
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn get_test_groups_with_features(
+ n: usize,
+ extensions: ExtensionList,
+ leaf_extensions: ExtensionList,
+) -> Vec<Group<TestClientConfig>> {
+ let mut clients = Vec::new();
+
+ for i in 0..n {
+ let (identity, secret_key) =
+ get_test_signing_identity(TEST_CIPHER_SUITE, format!("member{i}").as_bytes()).await;
+
+ clients.push(
+ TestClientBuilder::new_for_test()
+ .extension_type(999.into())
+ .leaf_node_extensions(leaf_extensions.clone())
+ .signing_identity(identity, secret_key, TEST_CIPHER_SUITE)
+ .build(),
+ );
+ }
+
+ let group = clients[0]
+ .create_group_with_id(b"TEST GROUP".to_vec(), extensions)
+ .await
+ .unwrap();
+
+ let mut groups = vec![group];
+
+ for client in clients.iter().skip(1) {
+ let key_package = client.generate_key_package_message().await.unwrap();
+
+ let commit_output = groups[0]
+ .commit_builder()
+ .add_member(key_package)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ groups[0].apply_pending_commit().await.unwrap();
+
+ for group in groups.iter_mut().skip(1) {
+ group
+ .process_incoming_message(commit_output.commit_message.clone())
+ .await
+ .unwrap();
+ }
+
+ groups.push(
+ client
+ .join_group(None, &commit_output.welcome_messages[0])
+ .await
+ .unwrap()
+ .0,
+ );
+ }
+
+ groups
+}
+
+pub fn random_bytes(count: usize) -> Vec<u8> {
+ let mut buf = vec![0; count];
+ rand::thread_rng().fill_bytes(&mut buf);
+ buf
+}
+
+pub(crate) struct GroupWithoutKeySchedule {
+ inner: Group<TestClientConfig>,
+ pub secrets: Option<(TreeKemPrivate, PathSecret)>,
+ pub provisional_public_state: Option<ProvisionalState>,
+}
+
+impl Deref for GroupWithoutKeySchedule {
+ type Target = Group<TestClientConfig>;
+
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn deref(&self) -> &Self::Target {
+ &self.inner
+ }
+}
+
+impl DerefMut for GroupWithoutKeySchedule {
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut self.inner
+ }
+}
+
+#[cfg(feature = "rfc_compliant")]
+impl GroupWithoutKeySchedule {
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn new(cs: CipherSuite) -> Self {
+ Self {
+ inner: test_group(TEST_PROTOCOL_VERSION, cs).await.group,
+ secrets: None,
+ provisional_public_state: None,
+ }
+ }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
+#[cfg_attr(
+ all(not(target_arch = "wasm32"), mls_build_async),
+ maybe_async::must_be_async
+)]
+impl MessageProcessor for GroupWithoutKeySchedule {
+ type CipherSuiteProvider = <Group<TestClientConfig> as MessageProcessor>::CipherSuiteProvider;
+ type OutputType = <Group<TestClientConfig> as MessageProcessor>::OutputType;
+ type PreSharedKeyStorage = <Group<TestClientConfig> as MessageProcessor>::PreSharedKeyStorage;
+ type IdentityProvider = <Group<TestClientConfig> as MessageProcessor>::IdentityProvider;
+ type MlsRules = <Group<TestClientConfig> as MessageProcessor>::MlsRules;
+
+ fn group_state(&self) -> &GroupState {
+ self.inner.group_state()
+ }
+
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn group_state_mut(&mut self) -> &mut GroupState {
+ self.inner.group_state_mut()
+ }
+
+ fn mls_rules(&self) -> Self::MlsRules {
+ self.inner.mls_rules()
+ }
+
+ fn identity_provider(&self) -> Self::IdentityProvider {
+ self.inner.identity_provider()
+ }
+
+ fn cipher_suite_provider(&self) -> &Self::CipherSuiteProvider {
+ self.inner.cipher_suite_provider()
+ }
+
+ fn psk_storage(&self) -> Self::PreSharedKeyStorage {
+ self.inner.psk_storage()
+ }
+
+ fn can_continue_processing(&self, provisional_state: &ProvisionalState) -> bool {
+ self.inner.can_continue_processing(provisional_state)
+ }
+
+ #[cfg(feature = "private_message")]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn min_epoch_available(&self) -> Option<u64> {
+ self.inner.min_epoch_available()
+ }
+
+ async fn apply_update_path(
+ &mut self,
+ sender: LeafIndex,
+ update_path: &ValidatedUpdatePath,
+ provisional_state: &mut ProvisionalState,
+ ) -> Result<Option<(TreeKemPrivate, PathSecret)>, MlsError> {
+ self.inner
+ .apply_update_path(sender, update_path, provisional_state)
+ .await
+ }
+
+ #[cfg(feature = "private_message")]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ async fn process_ciphertext(
+ &mut self,
+ cipher_text: &PrivateMessage,
+ ) -> Result<EventOrContent<Self::OutputType>, MlsError> {
+ self.inner.process_ciphertext(cipher_text).await
+ }
+
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ async fn verify_plaintext_authentication(
+ &self,
+ message: PublicMessage,
+ ) -> Result<EventOrContent<Self::OutputType>, MlsError> {
+ self.inner.verify_plaintext_authentication(message).await
+ }
+
+ async fn update_key_schedule(
+ &mut self,
+ secrets: Option<(TreeKemPrivate, PathSecret)>,
+ _interim_transcript_hash: InterimTranscriptHash,
+ _confirmation_tag: &ConfirmationTag,
+ provisional_public_state: ProvisionalState,
+ ) -> Result<(), MlsError> {
+ self.provisional_public_state = Some(provisional_public_state);
+ self.secrets = secrets;
+ Ok(())
+ }
+
+ #[cfg(feature = "private_message")]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn self_index(&self) -> Option<LeafIndex> {
+ <Group<TestClientConfig> as MessageProcessor>::self_index(&self.inner)
+ }
+}
diff --git a/src/group/transcript_hash.rs b/src/group/transcript_hash.rs
new file mode 100644
index 0000000..c336dfa
--- /dev/null
+++ b/src/group/transcript_hash.rs
@@ -0,0 +1,293 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+use core::{
+ fmt::{self, Debug},
+ ops::Deref,
+};
+
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::{crypto::CipherSuiteProvider, error::IntoAnyError};
+
+use crate::{
+ client::MlsError,
+ group::{framing::FramedContent, MessageSignature},
+ WireFormat,
+};
+
+use super::{AuthenticatedContent, ConfirmationTag};
+
+#[derive(Clone, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct ConfirmedTranscriptHash(
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ Vec<u8>,
+);
+
+impl Debug for ConfirmedTranscriptHash {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ mls_rs_core::debug::pretty_bytes(&self.0)
+ .named("ConfirmedTranscriptHash")
+ .fmt(f)
+ }
+}
+
+impl Deref for ConfirmedTranscriptHash {
+ type Target = Vec<u8>;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl From<Vec<u8>> for ConfirmedTranscriptHash {
+ fn from(value: Vec<u8>) -> Self {
+ Self(value)
+ }
+}
+
+impl ConfirmedTranscriptHash {
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn create<P: CipherSuiteProvider>(
+ cipher_suite_provider: &P,
+ interim_transcript_hash: &InterimTranscriptHash,
+ content: &AuthenticatedContent,
+ ) -> Result<Self, MlsError> {
+ #[derive(Debug, MlsSize, MlsEncode)]
+ struct ConfirmedTranscriptHashInput<'a> {
+ wire_format: WireFormat,
+ content: &'a FramedContent,
+ signature: &'a MessageSignature,
+ }
+
+ let input = ConfirmedTranscriptHashInput {
+ wire_format: content.wire_format,
+ content: &content.content,
+ signature: &content.auth.signature,
+ };
+
+ let hash_input = [
+ interim_transcript_hash.deref(),
+ input.mls_encode_to_vec()?.deref(),
+ ]
+ .concat();
+
+ cipher_suite_provider
+ .hash(&hash_input)
+ .await
+ .map(Into::into)
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+ }
+}
+
+#[derive(Clone, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct InterimTranscriptHash(
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ Vec<u8>,
+);
+
+impl Debug for InterimTranscriptHash {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ mls_rs_core::debug::pretty_bytes(&self.0)
+ .named("InterimTranscriptHash")
+ .fmt(f)
+ }
+}
+
+impl Deref for InterimTranscriptHash {
+ type Target = Vec<u8>;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl From<Vec<u8>> for InterimTranscriptHash {
+ fn from(value: Vec<u8>) -> Self {
+ Self(value)
+ }
+}
+
+impl InterimTranscriptHash {
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn create<P: CipherSuiteProvider>(
+ cipher_suite_provider: &P,
+ confirmed: &ConfirmedTranscriptHash,
+ confirmation_tag: &ConfirmationTag,
+ ) -> Result<Self, MlsError> {
+ #[derive(Debug, MlsSize, MlsEncode)]
+ struct InterimTranscriptHashInput<'a> {
+ confirmation_tag: &'a ConfirmationTag,
+ }
+
+ let input = InterimTranscriptHashInput { confirmation_tag }.mls_encode_to_vec()?;
+
+ cipher_suite_provider
+ .hash(&[confirmed.0.deref(), &input].concat())
+ .await
+ .map(Into::into)
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+ }
+}
+
+// Test vectors come from the MLS interop repository and contain a proposal by reference.
+#[cfg(feature = "by_ref_proposal")]
+#[cfg(test)]
+mod tests {
+ use alloc::vec::Vec;
+
+ use mls_rs_codec::MlsDecode;
+
+ use crate::{
+ crypto::test_utils::try_test_cipher_suite_provider,
+ group::{framing::ContentType, message_signature::AuthenticatedContent, transcript_hashes},
+ };
+
+ #[cfg(not(mls_build_async))]
+ use alloc::{boxed::Box, vec};
+
+ #[cfg(not(mls_build_async))]
+ use crate::{
+ crypto::test_utils::test_cipher_suite_provider,
+ group::{
+ confirmation_tag::ConfirmationTag,
+ framing::Content,
+ proposal::{Proposal, ProposalOrRef, RemoveProposal},
+ test_utils::get_test_group_context,
+ Commit, LeafIndex, Sender,
+ },
+ mls_rs_codec::MlsEncode,
+ CipherSuite, CipherSuiteProvider, WireFormat,
+ };
+
+ #[cfg(not(mls_build_async))]
+ use super::{ConfirmedTranscriptHash, InterimTranscriptHash};
+
+ #[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+ struct TestCase {
+ pub cipher_suite: u16,
+
+ #[serde(with = "hex::serde")]
+ pub confirmation_key: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub authenticated_content: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub interim_transcript_hash_before: Vec<u8>,
+
+ #[serde(with = "hex::serde")]
+ pub confirmed_transcript_hash_after: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub interim_transcript_hash_after: Vec<u8>,
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn transcript_hash() {
+ let test_cases: Vec<TestCase> =
+ load_test_case_json!(interop_transcript_hashes, generate_test_vector());
+
+ for test_case in test_cases.into_iter() {
+ let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) else {
+ continue;
+ };
+
+ let auth_content =
+ AuthenticatedContent::mls_decode(&mut &*test_case.authenticated_content).unwrap();
+
+ assert!(auth_content.content.content_type() == ContentType::Commit);
+
+ let conf_key = &test_case.confirmation_key;
+ let conf_hash_after = test_case.confirmed_transcript_hash_after.into();
+ let conf_tag = auth_content.auth.confirmation_tag.clone().unwrap();
+
+ let matches = conf_tag
+ .matches(conf_key, &conf_hash_after, &cs)
+ .await
+ .unwrap();
+
+ assert!(matches);
+
+ let (expected_interim, expected_conf) = transcript_hashes(
+ &cs,
+ &test_case.interim_transcript_hash_before.into(),
+ &auth_content,
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(*expected_interim, test_case.interim_transcript_hash_after);
+ assert_eq!(expected_conf, conf_hash_after);
+ }
+ }
+
+ #[cfg(not(mls_build_async))]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn generate_test_vector() -> Vec<TestCase> {
+ CipherSuite::all().fold(vec![], |mut test_cases, cs| {
+ let cs = test_cipher_suite_provider(cs);
+
+ let context = get_test_group_context(0x3456, cs.cipher_suite());
+
+ let proposal = Proposal::Remove(RemoveProposal {
+ to_remove: LeafIndex(1),
+ });
+
+ let proposal = ProposalOrRef::Proposal(Box::new(proposal));
+
+ let commit = Commit {
+ proposals: vec![proposal],
+ path: None,
+ };
+
+ let signer = cs.signature_key_generate().unwrap().0;
+
+ let mut auth_content = AuthenticatedContent::new_signed(
+ &cs,
+ &context,
+ Sender::Member(0),
+ Content::Commit(alloc::boxed::Box::new(commit)),
+ &signer,
+ WireFormat::PublicMessage,
+ vec![],
+ )
+ .unwrap();
+
+ let interim_hash_before = cs.random_bytes_vec(cs.kdf_extract_size()).unwrap().into();
+
+ let conf_hash_after =
+ ConfirmedTranscriptHash::create(&cs, &interim_hash_before, &auth_content).unwrap();
+
+ let conf_key = cs.random_bytes_vec(cs.kdf_extract_size()).unwrap();
+ let conf_tag = ConfirmationTag::create(&conf_key, &conf_hash_after, &cs).unwrap();
+
+ let interim_hash_after =
+ InterimTranscriptHash::create(&cs, &conf_hash_after, &conf_tag).unwrap();
+
+ auth_content.auth.confirmation_tag = Some(conf_tag);
+
+ let test_case = TestCase {
+ cipher_suite: cs.cipher_suite().into(),
+
+ confirmation_key: conf_key,
+ authenticated_content: auth_content.mls_encode_to_vec().unwrap(),
+ interim_transcript_hash_before: interim_hash_before.0,
+
+ confirmed_transcript_hash_after: conf_hash_after.0,
+ interim_transcript_hash_after: interim_hash_after.0,
+ };
+
+ test_cases.push(test_case);
+ test_cases
+ })
+ }
+
+ #[cfg(mls_build_async)]
+ fn generate_test_vector() -> Vec<TestCase> {
+ panic!("Tests cannot be generated in async mode");
+ }
+}
diff --git a/src/group/util.rs b/src/group/util.rs
new file mode 100644
index 0000000..dadfafa
--- /dev/null
+++ b/src/group/util.rs
@@ -0,0 +1,202 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use mls_rs_core::{
+ error::IntoAnyError, identity::IdentityProvider, key_package::KeyPackageStorage,
+};
+
+use crate::{
+ cipher_suite::CipherSuite,
+ client::MlsError,
+ extension::RatchetTreeExt,
+ key_package::KeyPackageGeneration,
+ protocol_version::ProtocolVersion,
+ signer::Signable,
+ tree_kem::{node::LeafIndex, tree_validator::TreeValidator, TreeKemPublic},
+ CipherSuiteProvider, CryptoProvider,
+};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::extension::ExternalSendersExt;
+
+use super::{
+ framing::Sender, message_signature::AuthenticatedContent,
+ transcript_hash::InterimTranscriptHash, ConfirmedTranscriptHash, EncryptedGroupSecrets,
+ ExportedTree, GroupInfo, GroupState,
+};
+
+use super::message_processor::ProvisionalState;
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn validate_group_info_common<C: CipherSuiteProvider>(
+ msg_version: ProtocolVersion,
+ group_info: &GroupInfo,
+ tree: &TreeKemPublic,
+ cs: &C,
+) -> Result<(), MlsError> {
+ if msg_version != group_info.group_context.protocol_version {
+ return Err(MlsError::ProtocolVersionMismatch);
+ }
+
+ if group_info.group_context.cipher_suite != cs.cipher_suite() {
+ return Err(MlsError::CipherSuiteMismatch);
+ }
+
+ let sender_leaf = &tree.get_leaf_node(group_info.signer)?;
+
+ group_info
+ .verify(cs, &sender_leaf.signing_identity.signature_key, &())
+ .await?;
+
+ Ok(())
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn validate_group_info_member<C: CipherSuiteProvider>(
+ self_state: &GroupState,
+ msg_version: ProtocolVersion,
+ group_info: &GroupInfo,
+ cs: &C,
+) -> Result<(), MlsError> {
+ validate_group_info_common(msg_version, group_info, &self_state.public_tree, cs).await?;
+
+ let self_tree = ExportedTree::new_borrowed(&self_state.public_tree.nodes);
+
+ if let Some(tree) = group_info.extensions.get_as::<RatchetTreeExt>()? {
+ (tree.tree_data == self_tree)
+ .then_some(())
+ .ok_or(MlsError::InvalidGroupInfo)?;
+ }
+
+ (group_info.group_context == self_state.context
+ && group_info.confirmation_tag == self_state.confirmation_tag)
+ .then_some(())
+ .ok_or(MlsError::InvalidGroupInfo)?;
+
+ Ok(())
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn validate_group_info_joiner<C, I>(
+ msg_version: ProtocolVersion,
+ group_info: &GroupInfo,
+ tree: Option<ExportedTree<'_>>,
+ id_provider: &I,
+ cs: &C,
+) -> Result<TreeKemPublic, MlsError>
+where
+ C: CipherSuiteProvider,
+ I: IdentityProvider,
+{
+ let tree = match group_info.extensions.get_as::<RatchetTreeExt>()? {
+ Some(ext) => ext.tree_data,
+ None => tree.ok_or(MlsError::RatchetTreeNotFound)?,
+ };
+
+ let context = &group_info.group_context;
+
+ let mut tree =
+ TreeKemPublic::import_node_data(tree.into(), id_provider, &context.extensions).await?;
+
+ // Verify the integrity of the ratchet tree
+ TreeValidator::new(cs, context, id_provider)
+ .validate(&mut tree)
+ .await?;
+
+ #[cfg(feature = "by_ref_proposal")]
+ if let Some(ext_senders) = context.extensions.get_as::<ExternalSendersExt>()? {
+ // TODO do joiners verify group against current time??
+ ext_senders
+ .verify_all(id_provider, None, &context.extensions)
+ .await
+ .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;
+ }
+
+ validate_group_info_common(msg_version, group_info, &tree, cs).await?;
+
+ Ok(tree)
+}
+
+pub(crate) fn commit_sender(
+ sender: &Sender,
+ provisional_state: &ProvisionalState,
+) -> Result<LeafIndex, MlsError> {
+ match sender {
+ Sender::Member(index) => Ok(LeafIndex(*index)),
+ #[cfg(feature = "by_ref_proposal")]
+ Sender::External(_) => Err(MlsError::ExternalSenderCannotCommit),
+ #[cfg(feature = "by_ref_proposal")]
+ Sender::NewMemberProposal => Err(MlsError::ExpectedAddProposalForNewMemberProposal),
+ Sender::NewMemberCommit => provisional_state
+ .external_init_index
+ .ok_or(MlsError::ExternalCommitMissingExternalInit),
+ }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(super) async fn transcript_hashes<P: CipherSuiteProvider>(
+ cipher_suite_provider: &P,
+ prev_interim_transcript_hash: &InterimTranscriptHash,
+ content: &AuthenticatedContent,
+) -> Result<(InterimTranscriptHash, ConfirmedTranscriptHash), MlsError> {
+ let confirmed_transcript_hash = ConfirmedTranscriptHash::create(
+ cipher_suite_provider,
+ prev_interim_transcript_hash,
+ content,
+ )
+ .await?;
+
+ let confirmation_tag = content
+ .auth
+ .confirmation_tag
+ .as_ref()
+ .ok_or(MlsError::InvalidConfirmationTag)?;
+
+ let interim_transcript_hash = InterimTranscriptHash::create(
+ cipher_suite_provider,
+ &confirmed_transcript_hash,
+ confirmation_tag,
+ )
+ .await?;
+
+ Ok((interim_transcript_hash, confirmed_transcript_hash))
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn find_key_package_generation<'a, K: KeyPackageStorage>(
+ key_package_repo: &K,
+ secrets: &'a [EncryptedGroupSecrets],
+) -> Result<(&'a EncryptedGroupSecrets, KeyPackageGeneration), MlsError> {
+ for secret in secrets {
+ if let Some(val) = key_package_repo
+ .get(&secret.new_member)
+ .await
+ .map_err(|e| MlsError::KeyPackageRepoError(e.into_any_error()))
+ .and_then(|maybe_data| {
+ if let Some(data) = maybe_data {
+ KeyPackageGeneration::from_storage(secret.new_member.to_vec(), data)
+ .map(|kpg| Some((secret, kpg)))
+ } else {
+ Ok::<_, MlsError>(None)
+ }
+ })?
+ {
+ return Ok(val);
+ }
+ }
+
+ Err(MlsError::WelcomeKeyPackageNotFound)
+}
+
+pub(crate) fn cipher_suite_provider<P>(
+ crypto: P,
+ cipher_suite: CipherSuite,
+) -> Result<P::CipherSuiteProvider, MlsError>
+where
+ P: CryptoProvider,
+{
+ crypto
+ .cipher_suite_provider(cipher_suite)
+ .ok_or(MlsError::UnsupportedCipherSuite(cipher_suite))
+}
diff --git a/src/hash_reference.rs b/src/hash_reference.rs
new file mode 100644
index 0000000..41cb156
--- /dev/null
+++ b/src/hash_reference.rs
@@ -0,0 +1,166 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use core::{
+ fmt::{self, Debug},
+ ops::Deref,
+};
+
+use crate::client::MlsError;
+use crate::CipherSuiteProvider;
+use alloc::vec::Vec;
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::error::IntoAnyError;
+
+#[derive(MlsSize, MlsEncode)]
+struct RefHashInput<'a> {
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ pub label: &'a [u8],
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ pub value: &'a [u8],
+}
+
+impl Debug for RefHashInput<'_> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("RefHashInput")
+ .field("label", &mls_rs_core::debug::pretty_bytes(self.label))
+ .field("value", &mls_rs_core::debug::pretty_bytes(self.value))
+ .finish()
+ }
+}
+
+#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Clone, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct HashReference(
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ Vec<u8>,
+);
+
+impl Debug for HashReference {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ mls_rs_core::debug::pretty_bytes(&self.0)
+ .named("HashReference")
+ .fmt(f)
+ }
+}
+
+impl Deref for HashReference {
+ type Target = [u8];
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl AsRef<[u8]> for HashReference {
+ fn as_ref(&self) -> &[u8] {
+ &self.0
+ }
+}
+
+impl From<Vec<u8>> for HashReference {
+ fn from(val: Vec<u8>) -> Self {
+ Self(val)
+ }
+}
+
+impl HashReference {
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn compute<P: CipherSuiteProvider>(
+ value: &[u8],
+ label: &[u8],
+ cipher_suite: &P,
+ ) -> Result<HashReference, MlsError> {
+ let input = RefHashInput { label, value };
+ let input_bytes = input.mls_encode_to_vec()?;
+
+ cipher_suite
+ .hash(&input_bytes)
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+ .map(HashReference)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::crypto::test_utils::try_test_cipher_suite_provider;
+
+ #[cfg(not(mls_build_async))]
+ use crate::{cipher_suite::CipherSuite, crypto::test_utils::test_cipher_suite_provider};
+
+ use super::*;
+ use alloc::string::String;
+ use serde::{Deserialize, Serialize};
+
+ #[cfg(not(mls_build_async))]
+ use alloc::string::ToString;
+
+ #[cfg(target_arch = "wasm32")]
+ use wasm_bindgen_test::wasm_bindgen_test as test;
+
+ #[derive(Debug, Deserialize, Serialize)]
+ struct HashRefTestCase {
+ label: String,
+ #[serde(with = "hex::serde")]
+ value: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ out: Vec<u8>,
+ }
+
+ #[derive(Debug, serde::Serialize, serde::Deserialize)]
+ pub struct InteropTestCase {
+ cipher_suite: u16,
+ ref_hash: HashRefTestCase,
+ }
+
+ #[cfg(not(mls_build_async))]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn generate_test_vector() -> Vec<InteropTestCase> {
+ CipherSuite::all()
+ .map(|cipher_suite| {
+ let provider = test_cipher_suite_provider(cipher_suite);
+
+ let input = b"test input";
+ let label = "test label";
+
+ let output = HashReference::compute(input, label.as_bytes(), &provider).unwrap();
+
+ let ref_hash = HashRefTestCase {
+ label: label.to_string(),
+ value: input.to_vec(),
+ out: output.to_vec(),
+ };
+
+ InteropTestCase {
+ cipher_suite: cipher_suite.into(),
+ ref_hash,
+ }
+ })
+ .collect()
+ }
+
+ #[cfg(mls_build_async)]
+ fn generate_test_vector() -> Vec<InteropTestCase> {
+ panic!("Tests cannot be generated in async mode");
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_basic_crypto_test_vectors() {
+ // The test vector can be found here https://github.com/mlswg/mls-implementations/blob/main/test-vectors/crypto-basics.json
+ let test_cases: Vec<InteropTestCase> =
+ load_test_case_json!(basic_crypto, generate_test_vector());
+
+ for test_case in test_cases {
+ if let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) {
+ let label = test_case.ref_hash.label.as_bytes();
+ let value = &test_case.ref_hash.value;
+ let computed = HashReference::compute(value, label, &cs).await.unwrap();
+ assert_eq!(&*computed, &test_case.ref_hash.out);
+ }
+ }
+ }
+}
diff --git a/src/identity.rs b/src/identity.rs
new file mode 100644
index 0000000..5de7a11
--- /dev/null
+++ b/src/identity.rs
@@ -0,0 +1,182 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+/// Basic credential identity provider.
+pub mod basic;
+
+/// X.509 certificate identity provider.
+#[cfg(feature = "x509")]
+pub mod x509 {
+ pub use mls_rs_identity_x509::*;
+}
+
+pub use mls_rs_core::identity::{
+ Credential, CredentialType, CustomCredential, MlsCredential, SigningIdentity,
+};
+
+pub use mls_rs_core::group::RosterUpdate;
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+ use alloc::boxed::Box;
+ use alloc::vec;
+ use alloc::vec::Vec;
+ use mls_rs_core::{
+ crypto::{CipherSuite, CipherSuiteProvider, SignatureSecretKey},
+ error::IntoAnyError,
+ extension::ExtensionList,
+ identity::{Credential, CredentialType, IdentityProvider, SigningIdentity},
+ time::MlsTime,
+ };
+
+ use crate::crypto::test_utils::test_cipher_suite_provider;
+
+ use super::basic::{BasicCredential, BasicIdentityProvider};
+
+ #[derive(Debug)]
+ #[cfg_attr(feature = "std", derive(thiserror::Error))]
+ #[cfg_attr(
+ feature = "std",
+ error("expected basic or custom credential type 42 found: {0:?}")
+ )]
+ pub struct BasicWithCustomProviderError(CredentialType);
+
+ impl IntoAnyError for BasicWithCustomProviderError {
+ #[cfg(feature = "std")]
+ fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
+ Ok(self.into())
+ }
+ }
+
+ #[derive(Debug, Clone)]
+ pub struct BasicWithCustomProvider {
+ pub(crate) basic: BasicIdentityProvider,
+ pub(crate) allow_any_custom: bool,
+ supported_cred_types: Vec<CredentialType>,
+ }
+
+ impl BasicWithCustomProvider {
+ pub const CUSTOM_CREDENTIAL_TYPE: u16 = 42;
+
+ pub fn new(basic: BasicIdentityProvider) -> BasicWithCustomProvider {
+ BasicWithCustomProvider {
+ basic,
+ allow_any_custom: false,
+ supported_cred_types: vec![
+ CredentialType::BASIC,
+ Self::CUSTOM_CREDENTIAL_TYPE.into(),
+ ],
+ }
+ }
+
+ pub fn with_credential_type(mut self, cred_type: CredentialType) -> Self {
+ self.supported_cred_types.push(cred_type);
+ self
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn resolve_custom_identity(
+ &self,
+ signing_id: &SigningIdentity,
+ ) -> Result<Vec<u8>, BasicWithCustomProviderError> {
+ self.basic
+ .identity(signing_id, &Default::default())
+ .await
+ .or_else(|_| {
+ signing_id
+ .credential
+ .as_custom()
+ .map(|c| {
+ if c.credential_type
+ == CredentialType::from(Self::CUSTOM_CREDENTIAL_TYPE)
+ || self.allow_any_custom
+ {
+ Ok(c.data.to_vec())
+ } else {
+ Err(BasicWithCustomProviderError(c.credential_type))
+ }
+ })
+ .transpose()?
+ .ok_or_else(|| {
+ BasicWithCustomProviderError(signing_id.credential.credential_type())
+ })
+ })
+ }
+ }
+
+ impl Default for BasicWithCustomProvider {
+ fn default() -> Self {
+ Self::new(BasicIdentityProvider::new())
+ }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+ impl IdentityProvider for BasicWithCustomProvider {
+ type Error = BasicWithCustomProviderError;
+
+ async fn validate_member(
+ &self,
+ _signing_identity: &SigningIdentity,
+ _timestamp: Option<MlsTime>,
+ _extensions: Option<&ExtensionList>,
+ ) -> Result<(), Self::Error> {
+ //TODO: Is it actually beneficial to check the key, or does that already happen elsewhere before
+ //this point?
+ Ok(())
+ }
+
+ async fn validate_external_sender(
+ &self,
+ _signing_identity: &SigningIdentity,
+ _timestamp: Option<MlsTime>,
+ _extensions: Option<&ExtensionList>,
+ ) -> Result<(), Self::Error> {
+ //TODO: Is it actually beneficial to check the key, or does that already happen elsewhere before
+ //this point?
+ Ok(())
+ }
+
+ async fn identity(
+ &self,
+ signing_id: &SigningIdentity,
+ _extensions: &ExtensionList,
+ ) -> Result<Vec<u8>, Self::Error> {
+ self.resolve_custom_identity(signing_id).await
+ }
+
+ async fn valid_successor(
+ &self,
+ predecessor: &SigningIdentity,
+ successor: &SigningIdentity,
+ _extensions: &ExtensionList,
+ ) -> Result<bool, Self::Error> {
+ let predecessor = self.resolve_custom_identity(predecessor).await?;
+ let successor = self.resolve_custom_identity(successor).await?;
+
+ Ok(predecessor == successor)
+ }
+
+ fn supported_types(&self) -> Vec<CredentialType> {
+ self.supported_cred_types.clone()
+ }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn get_test_signing_identity(
+ cipher_suite: CipherSuite,
+ identity: &[u8],
+ ) -> (SigningIdentity, SignatureSecretKey) {
+ let provider = test_cipher_suite_provider(cipher_suite);
+ let (secret_key, public_key) = provider.signature_key_generate().await.unwrap();
+
+ let basic = get_test_basic_credential(identity.to_vec());
+
+ (SigningIdentity::new(basic, public_key), secret_key)
+ }
+
+ pub fn get_test_basic_credential(identity: Vec<u8>) -> Credential {
+ BasicCredential::new(identity).into_credential()
+ }
+}
diff --git a/src/identity/basic.rs b/src/identity/basic.rs
new file mode 100644
index 0000000..b93ab6a
--- /dev/null
+++ b/src/identity/basic.rs
@@ -0,0 +1,99 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::{identity::CredentialType, identity::SigningIdentity, time::MlsTime};
+use alloc::vec;
+use alloc::vec::Vec;
+pub use mls_rs_core::identity::BasicCredential;
+use mls_rs_core::{error::IntoAnyError, extension::ExtensionList, identity::IdentityProvider};
+
+#[derive(Debug)]
+#[cfg_attr(feature = "std", derive(thiserror::Error))]
+#[cfg_attr(feature = "std", error("unsupported credential type found: {0:?}"))]
+/// Error returned in the event that a non-basic
+/// credential is passed to a [`BasicIdentityProvider`].
+pub struct BasicIdentityProviderError(CredentialType);
+
+impl IntoAnyError for BasicIdentityProviderError {
+ #[cfg(feature = "std")]
+ fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
+ Ok(self.into())
+ }
+}
+
+impl BasicIdentityProviderError {
+ pub fn credential_type(&self) -> CredentialType {
+ self.0
+ }
+}
+
+#[derive(Clone, Debug, Default)]
+/// An always-valid identity provider that works with [`BasicCredential`].
+///
+/// # Warning
+///
+/// This provider always returns `true` for `validate` as long as the
+/// [`SigningIdentity`] used contains a [`BasicCredential`]. It is only
+/// recommended to use this provider for testing purposes.
+pub struct BasicIdentityProvider;
+
+impl BasicIdentityProvider {
+ pub fn new() -> Self {
+ Self
+ }
+}
+
+fn resolve_basic_identity(
+ signing_id: &SigningIdentity,
+) -> Result<&BasicCredential, BasicIdentityProviderError> {
+ signing_id
+ .credential
+ .as_basic()
+ .ok_or_else(|| BasicIdentityProviderError(signing_id.credential.credential_type()))
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+impl IdentityProvider for BasicIdentityProvider {
+ type Error = BasicIdentityProviderError;
+
+ async fn validate_member(
+ &self,
+ signing_identity: &SigningIdentity,
+ _timestamp: Option<MlsTime>,
+ _extensions: Option<&ExtensionList>,
+ ) -> Result<(), Self::Error> {
+ resolve_basic_identity(signing_identity).map(|_| ())
+ }
+
+ async fn validate_external_sender(
+ &self,
+ signing_identity: &SigningIdentity,
+ _timestamp: Option<MlsTime>,
+ _extensions: Option<&ExtensionList>,
+ ) -> Result<(), Self::Error> {
+ resolve_basic_identity(signing_identity).map(|_| ())
+ }
+
+ async fn identity(
+ &self,
+ signing_identity: &SigningIdentity,
+ _extensions: &ExtensionList,
+ ) -> Result<Vec<u8>, Self::Error> {
+ resolve_basic_identity(signing_identity).map(|b| b.identifier.to_vec())
+ }
+
+ async fn valid_successor(
+ &self,
+ predecessor: &SigningIdentity,
+ successor: &SigningIdentity,
+ _extensions: &ExtensionList,
+ ) -> Result<bool, Self::Error> {
+ Ok(resolve_basic_identity(predecessor)? == resolve_basic_identity(successor)?)
+ }
+
+ fn supported_types(&self) -> Vec<CredentialType> {
+ vec![BasicCredential::credential_type()]
+ }
+}
diff --git a/src/iter.rs b/src/iter.rs
new file mode 100644
index 0000000..e37f162
--- /dev/null
+++ b/src/iter.rs
@@ -0,0 +1,96 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+#[cfg(all(not(mls_build_async), feature = "rayon"))]
+mod sync_rayon {
+ use rayon::{
+ iter::IterBridge,
+ prelude::{FromParallelIterator, IntoParallelIterator, ParallelBridge, ParallelIterator},
+ };
+
+ pub fn wrap_iter<I>(it: I) -> I::Iter
+ where
+ I: IntoParallelIterator,
+ {
+ it.into_par_iter()
+ }
+
+ pub fn wrap_impl_iter<I>(it: I) -> IterBridge<I::IntoIter>
+ where
+ I: IntoIterator,
+ I::IntoIter: Send,
+ I::Item: Send,
+ {
+ it.into_iter().par_bridge()
+ }
+
+ pub trait ParallelIteratorExt {
+ type Ok: Send;
+ type Error: Send;
+
+ fn try_collect<A>(self) -> Result<A, Self::Error>
+ where
+ A: FromParallelIterator<Self::Ok>;
+ }
+
+ impl<I, T, E> ParallelIteratorExt for I
+ where
+ I: ParallelIterator<Item = Result<T, E>>,
+ T: Send,
+ E: Send,
+ {
+ type Ok = T;
+ type Error = E;
+
+ fn try_collect<A>(self) -> Result<A, Self::Error>
+ where
+ A: FromParallelIterator<Self::Ok>,
+ {
+ self.collect()
+ }
+ }
+}
+
+#[cfg(all(not(mls_build_async), feature = "rayon"))]
+pub use sync_rayon::{wrap_impl_iter, wrap_iter, ParallelIteratorExt};
+
+#[cfg(not(any(mls_build_async, feature = "rayon")))]
+mod sync {
+ pub fn wrap_iter<I>(it: I) -> I::IntoIter
+ where
+ I: IntoIterator,
+ {
+ it.into_iter()
+ }
+
+ pub fn wrap_impl_iter<I>(it: I) -> I::IntoIter
+ where
+ I: IntoIterator,
+ {
+ it.into_iter()
+ }
+}
+
+#[cfg(not(any(mls_build_async, feature = "rayon")))]
+pub use sync::{wrap_impl_iter, wrap_iter};
+
+#[cfg(mls_build_async)]
+mod async_ {
+ pub fn wrap_iter<I>(it: I) -> futures::stream::Iter<I::IntoIter>
+ where
+ I: IntoIterator,
+ {
+ futures::stream::iter(it)
+ }
+
+ pub fn wrap_impl_iter<I>(it: I) -> futures::stream::Iter<I::IntoIter>
+ where
+ I: IntoIterator,
+ {
+ futures::stream::iter(it)
+ }
+}
+
+#[cfg(mls_build_async)]
+pub use async_::{wrap_impl_iter, wrap_iter};
diff --git a/src/key_package/generator.rs b/src/key_package/generator.rs
new file mode 100644
index 0000000..4d71094
--- /dev/null
+++ b/src/key_package/generator.rs
@@ -0,0 +1,339 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec;
+use alloc::vec::Vec;
+use mls_rs_codec::{MlsDecode, MlsEncode};
+use mls_rs_core::{error::IntoAnyError, identity::IdentityProvider, key_package::KeyPackageData};
+
+use crate::client::MlsError;
+use crate::{
+ crypto::{HpkeSecretKey, SignatureSecretKey},
+ group::framing::MlsMessagePayload,
+ identity::SigningIdentity,
+ protocol_version::ProtocolVersion,
+ signer::Signable,
+ tree_kem::{
+ leaf_node::{ConfigProperties, LeafNode},
+ Capabilities, Lifetime,
+ },
+ CipherSuiteProvider, ExtensionList, MlsMessage,
+};
+
+use super::{KeyPackage, KeyPackageRef};
+
+#[derive(Clone, Debug)]
+pub struct KeyPackageGenerator<'a, IP, CP>
+where
+ IP: IdentityProvider,
+ CP: CipherSuiteProvider,
+{
+ pub protocol_version: ProtocolVersion,
+ pub cipher_suite_provider: &'a CP,
+ pub signing_identity: &'a SigningIdentity,
+ pub signing_key: &'a SignatureSecretKey,
+ pub identity_provider: &'a IP,
+}
+
+#[derive(Clone, Debug)]
+pub struct KeyPackageGeneration {
+ pub(crate) reference: KeyPackageRef,
+ pub(crate) key_package: KeyPackage,
+ pub(crate) init_secret_key: HpkeSecretKey,
+ pub(crate) leaf_node_secret_key: HpkeSecretKey,
+}
+
+impl KeyPackageGeneration {
+ pub fn to_storage(&self) -> Result<(Vec<u8>, KeyPackageData), MlsError> {
+ let id = self.reference.to_vec();
+
+ let data = KeyPackageData::new(
+ self.key_package.mls_encode_to_vec()?,
+ self.init_secret_key.clone(),
+ self.leaf_node_secret_key.clone(),
+ self.key_package.expiration()?,
+ );
+
+ Ok((id, data))
+ }
+
+ pub fn from_storage(id: Vec<u8>, data: KeyPackageData) -> Result<Self, MlsError> {
+ Ok(KeyPackageGeneration {
+ reference: KeyPackageRef::from(id),
+ key_package: KeyPackage::mls_decode(&mut &*data.key_package_bytes)?,
+ init_secret_key: data.init_key,
+ leaf_node_secret_key: data.leaf_node_key,
+ })
+ }
+
+ pub fn key_package_message(&self) -> MlsMessage {
+ MlsMessage::new(
+ self.key_package.version,
+ MlsMessagePayload::KeyPackage(self.key_package.clone()),
+ )
+ }
+}
+
+impl<'a, IP, CP> KeyPackageGenerator<'a, IP, CP>
+where
+ IP: IdentityProvider,
+ CP: CipherSuiteProvider,
+{
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(super) async fn sign(&self, package: &mut KeyPackage) -> Result<(), MlsError> {
+ package
+ .sign(self.cipher_suite_provider, self.signing_key, &())
+ .await
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn generate(
+ &self,
+ lifetime: Lifetime,
+ capabilities: Capabilities,
+ key_package_extensions: ExtensionList,
+ leaf_node_extensions: ExtensionList,
+ ) -> Result<KeyPackageGeneration, MlsError> {
+ let (init_secret_key, public_init) = self
+ .cipher_suite_provider
+ .kem_generate()
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+ let properties = ConfigProperties {
+ capabilities,
+ extensions: leaf_node_extensions,
+ };
+
+ let (leaf_node, leaf_node_secret) = LeafNode::generate(
+ self.cipher_suite_provider,
+ properties,
+ self.signing_identity.clone(),
+ self.signing_key,
+ lifetime,
+ )
+ .await?;
+
+ let mut package = KeyPackage {
+ version: self.protocol_version,
+ cipher_suite: self.cipher_suite_provider.cipher_suite(),
+ hpke_init_key: public_init,
+ leaf_node,
+ extensions: key_package_extensions,
+ signature: vec![],
+ };
+
+ package.grease(self.cipher_suite_provider)?;
+
+ self.sign(&mut package).await?;
+
+ let reference = package.to_reference(self.cipher_suite_provider).await?;
+
+ Ok(KeyPackageGeneration {
+ key_package: package,
+ init_secret_key,
+ leaf_node_secret_key: leaf_node_secret,
+ reference,
+ })
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use assert_matches::assert_matches;
+ use mls_rs_core::crypto::CipherSuiteProvider;
+
+ use crate::{
+ crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider},
+ extension::test_utils::TestExtension,
+ group::test_utils::random_bytes,
+ identity::basic::BasicIdentityProvider,
+ identity::test_utils::get_test_signing_identity,
+ key_package::validate_key_package_properties,
+ protocol_version::ProtocolVersion,
+ tree_kem::{
+ leaf_node::{test_utils::get_test_capabilities, LeafNodeSource},
+ leaf_node_validator::{LeafNodeValidator, ValidationContext},
+ Lifetime,
+ },
+ ExtensionList,
+ };
+
+ use super::KeyPackageGenerator;
+
+ fn test_key_package_ext(val: u8) -> ExtensionList {
+ let mut ext_list = ExtensionList::new();
+ ext_list.set_from(TestExtension::from(val)).unwrap();
+ ext_list
+ }
+
+ fn test_leaf_node_ext(val: u8) -> ExtensionList {
+ let mut ext_list = ExtensionList::new();
+ ext_list.set_from(TestExtension::from(val)).unwrap();
+ ext_list
+ }
+
+ fn test_lifetime() -> Lifetime {
+ Lifetime::years(1).unwrap()
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_key_generation() {
+ for (protocol_version, cipher_suite) in ProtocolVersion::all().flat_map(|p| {
+ TestCryptoProvider::all_supported_cipher_suites()
+ .into_iter()
+ .map(move |cs| (p, cs))
+ }) {
+ let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);
+
+ let (signing_identity, signing_key) =
+ get_test_signing_identity(cipher_suite, b"foo").await;
+
+ let key_package_ext = test_key_package_ext(32);
+ let leaf_node_ext = test_leaf_node_ext(42);
+ let lifetime = test_lifetime();
+
+ let test_generator = KeyPackageGenerator {
+ protocol_version,
+ cipher_suite_provider: &cipher_suite_provider,
+ signing_identity: &signing_identity,
+ signing_key: &signing_key,
+ identity_provider: &BasicIdentityProvider,
+ };
+
+ let mut capabilities = get_test_capabilities();
+ capabilities.extensions.push(42.into());
+ capabilities.extensions.push(43.into());
+ capabilities.extensions.push(32.into());
+
+ let generated = test_generator
+ .generate(
+ lifetime.clone(),
+ capabilities.clone(),
+ key_package_ext.clone(),
+ leaf_node_ext.clone(),
+ )
+ .await
+ .unwrap();
+
+ assert_matches!(generated.key_package.leaf_node.leaf_node_source,
+ LeafNodeSource::KeyPackage(ref lt) if lt == &lifetime);
+
+ assert_eq!(
+ generated.key_package.leaf_node.ungreased_capabilities(),
+ capabilities
+ );
+
+ assert_eq!(
+ generated.key_package.leaf_node.ungreased_extensions(),
+ leaf_node_ext
+ );
+
+ assert_eq!(
+ generated.key_package.ungreased_extensions(),
+ key_package_ext
+ );
+
+ assert_ne!(
+ generated.key_package.hpke_init_key.as_ref(),
+ generated.key_package.leaf_node.public_key.as_ref()
+ );
+
+ assert_eq!(generated.key_package.cipher_suite, cipher_suite);
+ assert_eq!(generated.key_package.version, protocol_version);
+
+ // Verify that the hpke key pair generated will work
+ let test_data = random_bytes(32);
+
+ let sealed = cipher_suite_provider
+ .hpke_seal(&generated.key_package.hpke_init_key, &[], None, &test_data)
+ .await
+ .unwrap();
+
+ let opened = cipher_suite_provider
+ .hpke_open(
+ &sealed,
+ &generated.init_secret_key,
+ &generated.key_package.hpke_init_key,
+ &[],
+ None,
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(opened, test_data);
+
+ let validator =
+ LeafNodeValidator::new(&cipher_suite_provider, &BasicIdentityProvider, None);
+
+ validator
+ .check_if_valid(
+ &generated.key_package.leaf_node,
+ ValidationContext::Add(None),
+ )
+ .await
+ .unwrap();
+
+ validate_key_package_properties(
+ &generated.key_package,
+ protocol_version,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap();
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_randomness() {
+ for (protocol_version, cipher_suite) in ProtocolVersion::all().flat_map(|p| {
+ TestCryptoProvider::all_supported_cipher_suites()
+ .into_iter()
+ .map(move |cs| (p, cs))
+ }) {
+ let (signing_identity, signing_key) =
+ get_test_signing_identity(cipher_suite, b"foo").await;
+
+ let test_generator = KeyPackageGenerator {
+ protocol_version,
+ cipher_suite_provider: &test_cipher_suite_provider(cipher_suite),
+ signing_identity: &signing_identity,
+ signing_key: &signing_key,
+ identity_provider: &BasicIdentityProvider,
+ };
+
+ let first_key_package = test_generator
+ .generate(
+ test_lifetime(),
+ get_test_capabilities(),
+ ExtensionList::default(),
+ ExtensionList::default(),
+ )
+ .await
+ .unwrap();
+
+ for _ in 0..100 {
+ let next_key_package = test_generator
+ .generate(
+ test_lifetime(),
+ get_test_capabilities(),
+ ExtensionList::default(),
+ ExtensionList::default(),
+ )
+ .await
+ .unwrap();
+
+ assert_ne!(
+ first_key_package.key_package.hpke_init_key,
+ next_key_package.key_package.hpke_init_key
+ );
+
+ assert_ne!(
+ first_key_package.key_package.leaf_node.public_key,
+ next_key_package.key_package.leaf_node.public_key
+ );
+ }
+ }
+ }
+}
diff --git a/src/key_package/mod.rs b/src/key_package/mod.rs
new file mode 100644
index 0000000..b3ef83b
--- /dev/null
+++ b/src/key_package/mod.rs
@@ -0,0 +1,332 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::cipher_suite::CipherSuite;
+use crate::client::MlsError;
+use crate::crypto::HpkePublicKey;
+use crate::hash_reference::HashReference;
+use crate::identity::SigningIdentity;
+use crate::protocol_version::ProtocolVersion;
+use crate::signer::Signable;
+use crate::tree_kem::leaf_node::{LeafNode, LeafNodeSource};
+use crate::CipherSuiteProvider;
+use alloc::vec::Vec;
+use core::{
+ fmt::{self, Debug},
+ ops::Deref,
+};
+use mls_rs_codec::MlsDecode;
+use mls_rs_codec::MlsEncode;
+use mls_rs_codec::MlsSize;
+use mls_rs_core::extension::ExtensionList;
+
+mod validator;
+pub(crate) use validator::*;
+
+pub(crate) mod generator;
+pub(crate) use generator::*;
+
+#[non_exhaustive]
+#[derive(Clone, MlsSize, MlsEncode, MlsDecode, PartialEq)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct KeyPackage {
+ pub version: ProtocolVersion,
+ pub cipher_suite: CipherSuite,
+ pub hpke_init_key: HpkePublicKey,
+ pub(crate) leaf_node: LeafNode,
+ pub extensions: ExtensionList,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ pub signature: Vec<u8>,
+}
+
+impl Debug for KeyPackage {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("KeyPackage")
+ .field("version", &self.version)
+ .field("cipher_suite", &self.cipher_suite)
+ .field("hpke_init_key", &self.hpke_init_key)
+ .field("leaf_node", &self.leaf_node)
+ .field("extensions", &self.extensions)
+ .field(
+ "signature",
+ &mls_rs_core::debug::pretty_bytes(&self.signature),
+ )
+ .finish()
+ }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(
+ all(feature = "ffi", not(test)),
+ safer_ffi_gen::ffi_type(clone, opaque)
+)]
+pub struct KeyPackageRef(HashReference);
+
+impl Deref for KeyPackageRef {
+ type Target = [u8];
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl From<Vec<u8>> for KeyPackageRef {
+ fn from(v: Vec<u8>) -> Self {
+ Self(HashReference::from(v))
+ }
+}
+
+#[derive(MlsSize, MlsEncode)]
+struct KeyPackageData<'a> {
+ pub version: ProtocolVersion,
+ pub cipher_suite: CipherSuite,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ pub hpke_init_key: &'a HpkePublicKey,
+ pub leaf_node: &'a LeafNode,
+ pub extensions: &'a ExtensionList,
+}
+
+#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen)]
+impl KeyPackage {
+ #[cfg(feature = "ffi")]
+ pub fn version(&self) -> ProtocolVersion {
+ self.version
+ }
+
+ #[cfg(feature = "ffi")]
+ pub fn cipher_suite(&self) -> CipherSuite {
+ self.cipher_suite
+ }
+
+ pub fn signing_identity(&self) -> &SigningIdentity {
+ &self.leaf_node.signing_identity
+ }
+
+ #[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::safer_ffi_gen_ignore)]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn to_reference<CP: CipherSuiteProvider>(
+ &self,
+ cipher_suite_provider: &CP,
+ ) -> Result<KeyPackageRef, MlsError> {
+ if cipher_suite_provider.cipher_suite() != self.cipher_suite {
+ return Err(MlsError::CipherSuiteMismatch);
+ }
+
+ Ok(KeyPackageRef(
+ HashReference::compute(
+ &self.mls_encode_to_vec()?,
+ b"MLS 1.0 KeyPackage Reference",
+ cipher_suite_provider,
+ )
+ .await?,
+ ))
+ }
+
+ pub fn expiration(&self) -> Result<u64, MlsError> {
+ if let LeafNodeSource::KeyPackage(lifetime) = &self.leaf_node.leaf_node_source {
+ Ok(lifetime.not_after)
+ } else {
+ Err(MlsError::InvalidLeafNodeSource)
+ }
+ }
+}
+
+impl<'a> Signable<'a> for KeyPackage {
+ const SIGN_LABEL: &'static str = "KeyPackageTBS";
+
+ type SigningContext = ();
+
+ fn signature(&self) -> &[u8] {
+ &self.signature
+ }
+
+ fn signable_content(
+ &self,
+ _context: &Self::SigningContext,
+ ) -> Result<Vec<u8>, mls_rs_codec::Error> {
+ KeyPackageData {
+ version: self.version,
+ cipher_suite: self.cipher_suite,
+ hpke_init_key: &self.hpke_init_key,
+ leaf_node: &self.leaf_node,
+ extensions: &self.extensions,
+ }
+ .mls_encode_to_vec()
+ }
+
+ fn write_signature(&mut self, signature: Vec<u8>) {
+ self.signature = signature
+ }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+ use super::*;
+ use crate::{
+ crypto::test_utils::test_cipher_suite_provider,
+ group::framing::MlsMessagePayload,
+ identity::basic::BasicIdentityProvider,
+ identity::test_utils::get_test_signing_identity,
+ tree_kem::{leaf_node::test_utils::get_test_capabilities, Lifetime},
+ MlsMessage,
+ };
+
+ use mls_rs_core::crypto::SignatureSecretKey;
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn test_key_package(
+ protocol_version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+ id: &str,
+ ) -> KeyPackage {
+ test_key_package_with_signer(protocol_version, cipher_suite, id)
+ .await
+ .0
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn test_key_package_with_signer(
+ protocol_version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+ id: &str,
+ ) -> (KeyPackage, SignatureSecretKey) {
+ let (signing_identity, secret_key) =
+ get_test_signing_identity(cipher_suite, id.as_bytes()).await;
+
+ let generator = KeyPackageGenerator {
+ protocol_version,
+ cipher_suite_provider: &test_cipher_suite_provider(cipher_suite),
+ signing_identity: &signing_identity,
+ signing_key: &secret_key,
+ identity_provider: &BasicIdentityProvider,
+ };
+
+ let key_package = generator
+ .generate(
+ Lifetime::years(1).unwrap(),
+ get_test_capabilities(),
+ ExtensionList::default(),
+ ExtensionList::default(),
+ )
+ .await
+ .unwrap()
+ .key_package;
+
+ (key_package, secret_key)
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn test_key_package_message(
+ protocol_version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+ id: &str,
+ ) -> MlsMessage {
+ MlsMessage::new(
+ protocol_version,
+ MlsMessagePayload::KeyPackage(
+ test_key_package(protocol_version, cipher_suite, id).await,
+ ),
+ )
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::{
+ client::test_utils::{TEST_CIPHER_SUITE, TEST_PROTOCOL_VERSION},
+ crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider},
+ };
+
+ use super::{test_utils::test_key_package, *};
+ use alloc::format;
+ use assert_matches::assert_matches;
+
+ #[derive(serde::Deserialize, serde::Serialize)]
+ struct TestCase {
+ cipher_suite: u16,
+ #[serde(with = "hex::serde")]
+ input: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ output: Vec<u8>,
+ }
+
+ impl TestCase {
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ async fn generate() -> Vec<TestCase> {
+ let mut test_cases = Vec::new();
+
+ for (i, (protocol_version, cipher_suite)) in ProtocolVersion::all()
+ .flat_map(|p| CipherSuite::all().map(move |cs| (p, cs)))
+ .enumerate()
+ {
+ let pkg =
+ test_key_package(protocol_version, cipher_suite, &format!("alice{i}")).await;
+
+ let pkg_ref = pkg
+ .to_reference(&test_cipher_suite_provider(cipher_suite))
+ .await
+ .unwrap();
+
+ let case = TestCase {
+ cipher_suite: cipher_suite.into(),
+ input: pkg.mls_encode_to_vec().unwrap(),
+ output: pkg_ref.to_vec(),
+ };
+
+ test_cases.push(case);
+ }
+
+ test_cases
+ }
+ }
+
+ #[cfg(mls_build_async)]
+ async fn load_test_cases() -> Vec<TestCase> {
+ load_test_case_json!(key_package_ref, TestCase::generate().await)
+ }
+
+ #[cfg(not(mls_build_async))]
+ fn load_test_cases() -> Vec<TestCase> {
+ load_test_case_json!(key_package_ref, TestCase::generate())
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_key_package_ref() {
+ let cases = load_test_cases().await;
+
+ for one_case in cases {
+ let Some(provider) = try_test_cipher_suite_provider(one_case.cipher_suite) else {
+ continue;
+ };
+
+ let key_package = KeyPackage::mls_decode(&mut one_case.input.as_slice()).unwrap();
+
+ let key_package_ref = key_package.to_reference(&provider).await.unwrap();
+
+ let expected_out = KeyPackageRef::from(one_case.output);
+ assert_eq!(expected_out, key_package_ref);
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn key_package_ref_fails_invalid_cipher_suite() {
+ let key_package = test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "test").await;
+
+ for another_cipher_suite in CipherSuite::all().filter(|cs| cs != &TEST_CIPHER_SUITE) {
+ if let Some(cs) = try_test_cipher_suite_provider(*another_cipher_suite) {
+ let res = key_package.to_reference(&cs).await;
+
+ assert_matches!(res, Err(MlsError::CipherSuiteMismatch));
+ }
+ }
+ }
+}
diff --git a/src/key_package/validator.rs b/src/key_package/validator.rs
new file mode 100644
index 0000000..9cf1dae
--- /dev/null
+++ b/src/key_package/validator.rs
@@ -0,0 +1,39 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use mls_rs_core::{crypto::CipherSuiteProvider, protocol_version::ProtocolVersion};
+
+use crate::{client::MlsError, signer::Signable, KeyPackage};
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn validate_key_package_properties<CSP: CipherSuiteProvider>(
+ package: &KeyPackage,
+ version: ProtocolVersion,
+ cs: &CSP,
+) -> Result<(), MlsError> {
+ package
+ .verify(cs, &package.leaf_node.signing_identity.signature_key, &())
+ .await?;
+
+ // Verify that the protocol version matches
+ if package.version != version {
+ return Err(MlsError::ProtocolVersionMismatch);
+ }
+
+ // Verify that the cipher suite matches
+ if package.cipher_suite != cs.cipher_suite() {
+ return Err(MlsError::CipherSuiteMismatch);
+ }
+
+ // Verify that the public init key is a valid format for this cipher suite
+ cs.kem_public_key_validate(&package.hpke_init_key)
+ .map_err(|_| MlsError::InvalidInitKey)?;
+
+ // Verify that the init key and the leaf node public key are different
+ if package.hpke_init_key.as_ref() == package.leaf_node.public_key.as_ref() {
+ return Err(MlsError::InitLeafKeyEquality);
+ }
+
+ Ok(())
+}
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..115b3f8
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,218 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+//! An implementation of the [IETF Messaging Layer Security](https://messaginglayersecurity.rocks)
+//! end-to-end encryption (E2EE) protocol.
+//!
+//! ## What is MLS?
+//!
+//! MLS is a new IETF end-to-end encryption standard that is designed to
+//! provide transport agnostic, asynchronous, and highly performant
+//! communication between a group of clients.
+//!
+//! ## MLS Protocol Features
+//!
+//! - Multi-party E2EE [group evolution](https://www.rfc-editor.org/rfc/rfc9420.html#name-cryptographic-state-and-evo)
+//! via a propose-then-commit mechanism.
+//! - Asynchronous by design with pre-computed [key packages](https://www.rfc-editor.org/rfc/rfc9420.html#name-key-packages),
+//! allowing members to be added to a group while offline.
+//! - Customizable credential system with built in support for X.509 certificates.
+//! - [Extension system](https://www.rfc-editor.org/rfc/rfc9420.html#name-extensions)
+//! allowing for application specific data to be negotiated via the protocol.
+//! - Strong forward secrecy and post compromise security.
+//! - Crypto agility via support for multiple [cipher suites](https://www.rfc-editor.org/rfc/rfc9420.html#name-cipher-suites).
+//! - Pre-shared key support.
+//! - Subgroup branching.
+//! - Group reinitialization for breaking changes such as protocol upgrades.
+//!
+//! ## Features
+//!
+//! - Easy to use client interface that can manage multiple MLS identities and groups.
+//! - 100% RFC 9420 conformance with support for all default credential, proposal,
+//! and extension types.
+//! - Support for WASM builds.
+//! - Configurable storage for key packages, secrets and group state
+//! via traits along with provided "in memory" and SQLite implementations.
+//! - Support for custom user proposal and extension types.
+//! - Ability to create user defined credentials with custom validation
+//! routines that can bridge to existing credential schemes.
+//! - OpenSSL and Rust Crypto based cipher suite implementations.
+//! - Crypto agility with support for user defined cipher suite.
+//! - Extensive test suite including security and interop focused tests against
+//! pre-computed test vectors.
+//!
+//! ## Crypto Providers
+//!
+//! For cipher suite descriptions see the RFC documentation [here](https://www.rfc-editor.org/rfc/rfc9420.html#name-mls-cipher-suites)
+//!
+//! | Name | Cipher Suites | X509 Support |
+//! |------|---------------|--------------|
+//! | OpenSSL | 1-7 | Stable |
+//! | AWS-LC | 1,2,3,5,7 | Stable |
+//! | Rust Crypto | 1,2,3 | ⚠️ Experimental |
+//!
+//! ## Security Notice
+//!
+//! This library has been validated for conformance to the RFC 9420 specification but has not yet received a full security audit by a 3rd party.
+
+#![allow(clippy::enum_variant_names)]
+#![allow(clippy::result_large_err)]
+#![allow(clippy::nonstandard_macro_braces)]
+#![cfg_attr(not(feature = "std"), no_std)]
+#![cfg_attr(docsrs, feature(doc_cfg))]
+#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
+extern crate alloc;
+
+#[cfg(all(test, target_arch = "wasm32"))]
+wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);
+
+#[cfg(all(test, target_arch = "wasm32"))]
+use wasm_bindgen_test::wasm_bindgen_test as futures_test;
+
+#[cfg(all(test, mls_build_async, not(target_arch = "wasm32")))]
+use futures_test::test as futures_test;
+
+#[cfg(test)]
+macro_rules! hex {
+ ($input:literal) => {
+ hex::decode($input).expect("invalid hex value")
+ };
+}
+
+#[cfg(test)]
+macro_rules! load_test_case_json {
+ ($name:ident, $generate:expr) => {
+ load_test_case_json!($name, $generate, to_vec_pretty)
+ };
+ ($name:ident, $generate:expr, $to_json:ident) => {{
+ #[cfg(any(target_arch = "wasm32", not(feature = "std")))]
+ {
+ // Do not remove `async`! (The goal of this line is to remove warnings
+ // about `$generate` not being used. Actually calling it will make tests fail.)
+ let _ = async { $generate };
+ serde_json::from_slice(include_bytes!(concat!(
+ env!("CARGO_MANIFEST_DIR"),
+ "/test_data/",
+ stringify!($name),
+ ".json"
+ )))
+ .unwrap()
+ }
+
+ #[cfg(all(not(target_arch = "wasm32"), feature = "std"))]
+ {
+ let path = concat!(
+ env!("CARGO_MANIFEST_DIR"),
+ "/test_data/",
+ stringify!($name),
+ ".json"
+ );
+ if !std::path::Path::new(path).exists() {
+ std::fs::write(path, serde_json::$to_json(&$generate).unwrap()).unwrap();
+ }
+ serde_json::from_slice(&std::fs::read(path).unwrap()).unwrap()
+ }
+ }};
+}
+
+mod cipher_suite {
+ pub use mls_rs_core::crypto::CipherSuite;
+}
+
+pub use cipher_suite::CipherSuite;
+
+mod protocol_version {
+ pub use mls_rs_core::protocol_version::ProtocolVersion;
+}
+
+pub use protocol_version::ProtocolVersion;
+
+pub mod client;
+pub mod client_builder;
+mod client_config;
+/// Dependencies of [`CryptoProvider`] and [`CipherSuiteProvider`]
+pub mod crypto;
+/// Extension utilities and built-in extension types.
+pub mod extension;
+/// Tools to observe groups without being a member, useful
+/// for server implementations.
+#[cfg(feature = "external_client")]
+#[cfg_attr(docsrs, doc(cfg(feature = "external_client")))]
+pub mod external_client;
+mod grease;
+/// E2EE group created by a [`Client`].
+pub mod group;
+mod hash_reference;
+/// Identity providers to use with [`ClientBuilder`](client_builder::ClientBuilder).
+pub mod identity;
+mod iter;
+mod key_package;
+/// Pre-shared key support.
+pub mod psk;
+mod signer;
+/// Storage providers to use with
+/// [`ClientBuilder`](client_builder::ClientBuilder).
+pub mod storage_provider;
+
+pub use mls_rs_core::{
+ crypto::{CipherSuiteProvider, CryptoProvider},
+ group::GroupStateStorage,
+ identity::IdentityProvider,
+ key_package::KeyPackageStorage,
+ psk::PreSharedKeyStorage,
+};
+
+/// Dependencies of [`MlsRules`].
+pub mod mls_rules {
+ pub use crate::group::{
+ mls_rules::{
+ CommitDirection, CommitOptions, CommitSource, DefaultMlsRules, EncryptionOptions,
+ },
+ proposal_filter::{ProposalBundle, ProposalInfo, ProposalSource},
+ };
+
+ #[cfg(feature = "by_ref_proposal")]
+ pub use crate::group::proposal_ref::ProposalRef;
+}
+
+pub use mls_rs_core::extension::{Extension, ExtensionList};
+
+pub use crate::{
+ client::Client,
+ group::{
+ framing::{MlsMessage, WireFormat},
+ mls_rules::MlsRules,
+ Group,
+ },
+ key_package::{KeyPackage, KeyPackageRef},
+};
+
+/// Error types.
+pub mod error {
+ pub use crate::client::MlsError;
+ pub use mls_rs_core::error::{AnyError, IntoAnyError};
+ pub use mls_rs_core::extension::ExtensionError;
+}
+
+/// WASM compatible timestamp.
+pub mod time {
+ pub use mls_rs_core::time::*;
+}
+
+mod tree_kem;
+
+pub use mls_rs_codec;
+
+mod private {
+ pub trait Sealed {}
+}
+
+use private::Sealed;
+
+#[cfg(any(test, feature = "test_util"))]
+#[doc(hidden)]
+pub mod test_utils;
+
+#[cfg(feature = "ffi")]
+pub use safer_ffi_gen;
diff --git a/src/message.rs b/src/message.rs
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/message.rs
diff --git a/src/psk.rs b/src/psk.rs
new file mode 100644
index 0000000..5bf95c3
--- /dev/null
+++ b/src/psk.rs
@@ -0,0 +1,200 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+
+#[cfg(any(test, feature = "external_client"))]
+use alloc::vec;
+
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+
+#[cfg(any(test, feature = "external_client"))]
+use mls_rs_core::psk::PreSharedKeyStorage;
+
+#[cfg(any(test, feature = "external_client"))]
+use core::convert::Infallible;
+use core::fmt::{self, Debug};
+
+#[cfg(feature = "psk")]
+use crate::{client::MlsError, CipherSuiteProvider};
+
+#[cfg(feature = "psk")]
+use mls_rs_core::error::IntoAnyError;
+
+#[cfg(feature = "psk")]
+pub(crate) mod resolver;
+pub(crate) mod secret;
+
+pub use mls_rs_core::psk::{ExternalPskId, PreSharedKey};
+
+#[derive(Clone, Debug, Eq, Hash, PartialEq, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct PreSharedKeyID {
+ pub key_id: JustPreSharedKeyID,
+ pub psk_nonce: PskNonce,
+}
+
+impl PreSharedKeyID {
+ #[cfg(feature = "psk")]
+ pub(crate) fn new<P: CipherSuiteProvider>(
+ key_id: JustPreSharedKeyID,
+ cs: &P,
+ ) -> Result<Self, MlsError> {
+ Ok(Self {
+ key_id,
+ psk_nonce: PskNonce::random(cs)
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?,
+ })
+ }
+}
+
+#[derive(Clone, Debug, Eq, Hash, Ord, PartialOrd, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[repr(u8)]
+pub(crate) enum JustPreSharedKeyID {
+ External(ExternalPskId) = 1u8,
+ Resumption(ResumptionPsk) = 2u8,
+}
+
+#[derive(Clone, Eq, Hash, Ord, PartialOrd, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct PskGroupId(
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ pub Vec<u8>,
+);
+
+impl Debug for PskGroupId {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ mls_rs_core::debug::pretty_bytes(&self.0)
+ .named("PskGroupId")
+ .fmt(f)
+ }
+}
+
+#[derive(Clone, Eq, Hash, PartialEq, PartialOrd, Ord, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct PskNonce(
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ pub Vec<u8>,
+);
+
+impl Debug for PskNonce {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ mls_rs_core::debug::pretty_bytes(&self.0)
+ .named("PskNonce")
+ .fmt(f)
+ }
+}
+
+#[cfg(feature = "psk")]
+impl PskNonce {
+ pub fn random<P: CipherSuiteProvider>(
+ cipher_suite_provider: &P,
+ ) -> Result<Self, <P as CipherSuiteProvider>::Error> {
+ Ok(Self(cipher_suite_provider.random_bytes_vec(
+ cipher_suite_provider.kdf_extract_size(),
+ )?))
+ }
+}
+
+#[derive(Clone, Debug, Eq, Hash, Ord, PartialOrd, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct ResumptionPsk {
+ pub usage: ResumptionPSKUsage,
+ pub psk_group_id: PskGroupId,
+ pub psk_epoch: u64,
+}
+
+#[derive(Clone, Debug, Eq, Hash, PartialEq, Ord, PartialOrd, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[repr(u8)]
+pub(crate) enum ResumptionPSKUsage {
+ Application = 1u8,
+ Reinit = 2u8,
+ Branch = 3u8,
+}
+
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode)]
+struct PSKLabel<'a> {
+ id: &'a PreSharedKeyID,
+ index: u16,
+ count: u16,
+}
+
+#[cfg(any(test, feature = "external_client"))]
+#[derive(Clone, Copy, Debug)]
+pub(crate) struct AlwaysFoundPskStorage;
+
+#[cfg(any(test, feature = "external_client"))]
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+impl PreSharedKeyStorage for AlwaysFoundPskStorage {
+ type Error = Infallible;
+
+ async fn get(&self, _: &ExternalPskId) -> Result<Option<PreSharedKey>, Self::Error> {
+ Ok(Some(vec![].into()))
+ }
+}
+
+#[cfg(feature = "psk")]
+#[cfg(test)]
+pub(crate) mod test_utils {
+ use crate::crypto::test_utils::test_cipher_suite_provider;
+
+ use super::PskNonce;
+ use mls_rs_core::crypto::CipherSuite;
+
+ #[cfg(not(mls_build_async))]
+ use mls_rs_core::{crypto::CipherSuiteProvider, psk::ExternalPskId};
+
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ #[cfg(not(mls_build_async))]
+ pub(crate) fn make_external_psk_id<P: CipherSuiteProvider>(
+ cipher_suite_provider: &P,
+ ) -> ExternalPskId {
+ ExternalPskId::new(
+ cipher_suite_provider
+ .random_bytes_vec(cipher_suite_provider.kdf_extract_size())
+ .unwrap(),
+ )
+ }
+
+ pub(crate) fn make_nonce(cipher_suite: CipherSuite) -> PskNonce {
+ PskNonce::random(&test_cipher_suite_provider(cipher_suite)).unwrap()
+ }
+}
+
+#[cfg(feature = "psk")]
+#[cfg(test)]
+mod tests {
+ use crate::crypto::test_utils::TestCryptoProvider;
+ use core::iter;
+
+ #[cfg(target_arch = "wasm32")]
+ use wasm_bindgen_test::wasm_bindgen_test as test;
+
+ use super::test_utils::make_nonce;
+
+ #[test]
+ fn random_generation_of_nonces_is_random() {
+ let good = TestCryptoProvider::all_supported_cipher_suites()
+ .into_iter()
+ .all(|cipher_suite| {
+ let nonce = make_nonce(cipher_suite);
+ iter::repeat_with(|| make_nonce(cipher_suite))
+ .take(1000)
+ .all(|other| other != nonce)
+ });
+
+ assert!(good);
+ }
+}
diff --git a/src/psk/resolver.rs b/src/psk/resolver.rs
new file mode 100644
index 0000000..0e3b7c9
--- /dev/null
+++ b/src/psk/resolver.rs
@@ -0,0 +1,95 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+use mls_rs_core::{
+ crypto::CipherSuiteProvider,
+ error::IntoAnyError,
+ group::GroupStateStorage,
+ key_package::KeyPackageStorage,
+ psk::{ExternalPskId, PreSharedKey, PreSharedKeyStorage},
+};
+
+use crate::{
+ client::MlsError,
+ group::{epoch::EpochSecrets, state_repo::GroupStateRepository, GroupContext},
+ psk::secret::PskSecret,
+};
+
+use super::{secret::PskSecretInput, JustPreSharedKeyID, PreSharedKeyID, ResumptionPsk};
+
+pub(crate) struct PskResolver<'a, GS, K, PS>
+where
+ GS: GroupStateStorage,
+ PS: PreSharedKeyStorage,
+ K: KeyPackageStorage,
+{
+ pub group_context: Option<&'a GroupContext>,
+ pub current_epoch: Option<&'a EpochSecrets>,
+ pub prior_epochs: Option<&'a GroupStateRepository<GS, K>>,
+ pub psk_store: &'a PS,
+}
+
+impl<GS: GroupStateStorage, K: KeyPackageStorage, PS: PreSharedKeyStorage>
+ PskResolver<'_, GS, K, PS>
+{
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn resolve_resumption(&self, psk_id: &ResumptionPsk) -> Result<PreSharedKey, MlsError> {
+ if let Some(ctx) = self.group_context {
+ if ctx.epoch == psk_id.psk_epoch && ctx.group_id == psk_id.psk_group_id.0 {
+ let epoch = self.current_epoch.ok_or(MlsError::OldGroupStateNotFound)?;
+ return Ok(epoch.resumption_secret.clone());
+ }
+ }
+
+ #[cfg(feature = "prior_epoch")]
+ if let Some(eps) = self.prior_epochs {
+ if let Some(psk) = eps.resumption_secret(psk_id).await? {
+ return Ok(psk);
+ }
+ }
+
+ Err(MlsError::OldGroupStateNotFound)
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn resolve_external(&self, psk_id: &ExternalPskId) -> Result<PreSharedKey, MlsError> {
+ self.psk_store
+ .get(psk_id)
+ .await
+ .map_err(|e| MlsError::PskStoreError(e.into_any_error()))?
+ .ok_or(MlsError::MissingRequiredPsk)
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn resolve(&self, id: &[PreSharedKeyID]) -> Result<Vec<PskSecretInput>, MlsError> {
+ let mut secret_inputs = Vec::new();
+
+ for id in id {
+ let psk = match &id.key_id {
+ JustPreSharedKeyID::External(external) => self.resolve_external(external).await,
+ JustPreSharedKeyID::Resumption(resumption) => {
+ self.resolve_resumption(resumption).await
+ }
+ }?;
+
+ secret_inputs.push(PskSecretInput {
+ id: id.clone(),
+ psk,
+ })
+ }
+
+ Ok(secret_inputs)
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn resolve_to_secret<P: CipherSuiteProvider>(
+ &self,
+ id: &[PreSharedKeyID],
+ cipher_suite_provider: &P,
+ ) -> Result<PskSecret, MlsError> {
+ let psk = self.resolve(id).await?;
+ PskSecret::calculate(&psk, cipher_suite_provider).await
+ }
+}
diff --git a/src/psk/secret.rs b/src/psk/secret.rs
new file mode 100644
index 0000000..4fe9cc8
--- /dev/null
+++ b/src/psk/secret.rs
@@ -0,0 +1,239 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec;
+use alloc::vec::Vec;
+use core::{
+ fmt::{self, Debug},
+ ops::Deref,
+};
+use mls_rs_core::crypto::CipherSuiteProvider;
+use zeroize::Zeroizing;
+
+#[cfg(feature = "psk")]
+use mls_rs_codec::MlsEncode;
+
+#[cfg(feature = "psk")]
+use mls_rs_core::{error::IntoAnyError, psk::PreSharedKey};
+
+#[cfg(feature = "psk")]
+use crate::{
+ client::MlsError,
+ group::key_schedule::kdf_expand_with_label,
+ psk::{PSKLabel, PreSharedKeyID},
+};
+
+#[cfg(feature = "psk")]
+#[derive(Clone)]
+pub(crate) struct PskSecretInput {
+ pub id: PreSharedKeyID,
+ pub psk: PreSharedKey,
+}
+
+#[derive(PartialEq, Eq, Clone)]
+pub(crate) struct PskSecret(Zeroizing<Vec<u8>>);
+
+impl Debug for PskSecret {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ mls_rs_core::debug::pretty_bytes(&self.0)
+ .named("PskSecret")
+ .fmt(f)
+ }
+}
+
+#[cfg(test)]
+impl From<Vec<u8>> for PskSecret {
+ fn from(value: Vec<u8>) -> Self {
+ PskSecret(Zeroizing::new(value))
+ }
+}
+
+impl Deref for PskSecret {
+ type Target = [u8];
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl PskSecret {
+ pub(crate) fn new<P: CipherSuiteProvider>(provider: &P) -> PskSecret {
+ PskSecret(Zeroizing::new(vec![0u8; provider.kdf_extract_size()]))
+ }
+
+ #[cfg(feature = "psk")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn calculate<P: CipherSuiteProvider>(
+ input: &[PskSecretInput],
+ cipher_suite_provider: &P,
+ ) -> Result<PskSecret, MlsError> {
+ let len = u16::try_from(input.len()).map_err(|_| MlsError::TooManyPskIds)?;
+ let mut psk_secret = PskSecret::new(cipher_suite_provider);
+
+ for (index, psk_secret_input) in input.iter().enumerate() {
+ let index = index as u16;
+
+ let label = PSKLabel {
+ id: &psk_secret_input.id,
+ index,
+ count: len,
+ };
+
+ let psk_extracted = cipher_suite_provider
+ .kdf_extract(
+ &vec![0; cipher_suite_provider.kdf_extract_size()],
+ &psk_secret_input.psk,
+ )
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+ let psk_input = kdf_expand_with_label(
+ cipher_suite_provider,
+ &psk_extracted,
+ b"derived psk",
+ &label.mls_encode_to_vec()?,
+ None,
+ )
+ .await?;
+
+ psk_secret = cipher_suite_provider
+ .kdf_extract(&psk_input, &psk_secret)
+ .await
+ .map(PskSecret)
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+ }
+
+ Ok(psk_secret)
+ }
+}
+
+#[cfg(feature = "psk")]
+#[cfg(test)]
+mod tests {
+ use alloc::vec::Vec;
+ #[cfg(not(mls_build_async))]
+ use core::iter;
+ use serde::{Deserialize, Serialize};
+
+ use crate::{
+ crypto::test_utils::try_test_cipher_suite_provider,
+ psk::ExternalPskId,
+ psk::{JustPreSharedKeyID, PreSharedKeyID, PskNonce},
+ CipherSuiteProvider,
+ };
+
+ #[cfg(not(mls_build_async))]
+ use crate::{
+ crypto::test_utils::test_cipher_suite_provider, psk::test_utils::make_external_psk_id,
+ CipherSuite,
+ };
+
+ use super::{PskSecret, PskSecretInput};
+
+ #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
+ struct PskInfo {
+ #[serde(with = "hex::serde")]
+ id: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ psk: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ nonce: Vec<u8>,
+ }
+
+ impl From<PskInfo> for PskSecretInput {
+ fn from(info: PskInfo) -> Self {
+ let id = PreSharedKeyID {
+ key_id: JustPreSharedKeyID::External(ExternalPskId::new(info.id)),
+ psk_nonce: PskNonce(info.nonce),
+ };
+
+ PskSecretInput {
+ id,
+ psk: info.psk.into(),
+ }
+ }
+ }
+
+ #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
+ struct TestScenario {
+ cipher_suite: u16,
+ psks: Vec<PskInfo>,
+ #[serde(with = "hex::serde")]
+ psk_secret: Vec<u8>,
+ }
+
+ impl TestScenario {
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ #[cfg(not(mls_build_async))]
+ fn make_psk_list<CS: CipherSuiteProvider>(cs: &CS, n: usize) -> Vec<PskInfo> {
+ iter::repeat_with(
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ || PskInfo {
+ id: make_external_psk_id(cs).to_vec(),
+ psk: cs.random_bytes_vec(cs.kdf_extract_size()).unwrap(),
+ nonce: crate::psk::test_utils::make_nonce(cs.cipher_suite()).0,
+ },
+ )
+ .take(n)
+ .collect::<Vec<_>>()
+ }
+
+ #[cfg(not(mls_build_async))]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn generate() -> Vec<TestScenario> {
+ CipherSuite::all()
+ .flat_map(
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ |cs| (1..=10).map(move |n| (cs, n)),
+ )
+ .map(
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ |(cs, n)| {
+ let provider = test_cipher_suite_provider(cs);
+ let psks = Self::make_psk_list(&provider, n);
+ let psk_secret = Self::compute_psk_secret(&provider, psks.clone());
+ TestScenario {
+ cipher_suite: cs.into(),
+ psks: psks.to_vec(),
+ psk_secret: psk_secret.to_vec(),
+ }
+ },
+ )
+ .collect()
+ }
+
+ #[cfg(mls_build_async)]
+ fn generate() -> Vec<TestScenario> {
+ panic!("Tests cannot be generated in async mode");
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn compute_psk_secret<P: CipherSuiteProvider>(
+ provider: &P,
+ psks: Vec<PskInfo>,
+ ) -> PskSecret {
+ let input = psks
+ .into_iter()
+ .map(PskSecretInput::from)
+ .collect::<Vec<_>>();
+
+ PskSecret::calculate(&input, provider).await.unwrap()
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn expected_psk_secret_is_produced() {
+ let scenarios: Vec<TestScenario> =
+ load_test_case_json!(psk_secret, TestScenario::generate());
+
+ for scenario in scenarios {
+ if let Some(provider) = try_test_cipher_suite_provider(scenario.cipher_suite) {
+ let computed =
+ TestScenario::compute_psk_secret(&provider, scenario.psks.clone()).await;
+
+ assert_eq!(scenario.psk_secret, computed.to_vec());
+ }
+ }
+ }
+}
diff --git a/src/signer.rs b/src/signer.rs
new file mode 100644
index 0000000..12970ec
--- /dev/null
+++ b/src/signer.rs
@@ -0,0 +1,357 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsEncode, MlsSize};
+use mls_rs_core::error::IntoAnyError;
+
+use crate::client::MlsError;
+use crate::crypto::{CipherSuiteProvider, SignaturePublicKey, SignatureSecretKey};
+
+#[derive(Clone, MlsSize, MlsEncode)]
+struct SignContent {
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ label: Vec<u8>,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ content: Vec<u8>,
+}
+
+impl Debug for SignContent {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("SignContent")
+ .field("label", &mls_rs_core::debug::pretty_bytes(&self.label))
+ .field("content", &mls_rs_core::debug::pretty_bytes(&self.content))
+ .finish()
+ }
+}
+
+impl SignContent {
+ pub fn new(label: &str, content: Vec<u8>) -> Self {
+ Self {
+ label: [b"MLS 1.0 ", label.as_bytes()].concat(),
+ content,
+ }
+ }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
+#[cfg_attr(
+ all(not(target_arch = "wasm32"), mls_build_async),
+ maybe_async::must_be_async
+)]
+pub(crate) trait Signable<'a> {
+ const SIGN_LABEL: &'static str;
+
+ type SigningContext: Send + Sync;
+
+ fn signature(&self) -> &[u8];
+
+ fn signable_content(
+ &self,
+ context: &Self::SigningContext,
+ ) -> Result<Vec<u8>, mls_rs_codec::Error>;
+
+ fn write_signature(&mut self, signature: Vec<u8>);
+
+ async fn sign<P: CipherSuiteProvider>(
+ &mut self,
+ signature_provider: &P,
+ signer: &SignatureSecretKey,
+ context: &Self::SigningContext,
+ ) -> Result<(), MlsError> {
+ let sign_content = SignContent::new(Self::SIGN_LABEL, self.signable_content(context)?);
+
+ let signature = signature_provider
+ .sign(signer, &sign_content.mls_encode_to_vec()?)
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+ self.write_signature(signature);
+
+ Ok(())
+ }
+
+ async fn verify<P: CipherSuiteProvider>(
+ &self,
+ signature_provider: &P,
+ public_key: &SignaturePublicKey,
+ context: &Self::SigningContext,
+ ) -> Result<(), MlsError> {
+ let sign_content = SignContent::new(Self::SIGN_LABEL, self.signable_content(context)?);
+
+ signature_provider
+ .verify(
+ public_key,
+ self.signature(),
+ &sign_content.mls_encode_to_vec()?,
+ )
+ .await
+ .map_err(|_| MlsError::InvalidSignature)
+ }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+ use alloc::vec;
+ use alloc::{string::String, vec::Vec};
+ use mls_rs_core::crypto::CipherSuiteProvider;
+
+ use crate::crypto::test_utils::try_test_cipher_suite_provider;
+
+ use super::Signable;
+
+ #[derive(Debug, serde::Serialize, serde::Deserialize)]
+ pub struct SignatureInteropTestCase {
+ #[serde(with = "hex::serde", rename = "priv")]
+ secret: Vec<u8>,
+ #[serde(with = "hex::serde", rename = "pub")]
+ public: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ content: Vec<u8>,
+ label: String,
+ #[serde(with = "hex::serde")]
+ signature: Vec<u8>,
+ }
+
+ #[derive(Debug, serde::Serialize, serde::Deserialize)]
+ pub struct InteropTestCase {
+ cipher_suite: u16,
+ sign_with_label: SignatureInteropTestCase,
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_basic_crypto_test_vectors() {
+ let test_cases: Vec<InteropTestCase> =
+ load_test_case_json!(basic_crypto, Vec::<InteropTestCase>::new());
+
+ for test_case in test_cases {
+ if let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) {
+ test_case.sign_with_label.verify(&cs).await;
+ }
+ }
+ }
+
+ pub struct TestSignable {
+ pub content: Vec<u8>,
+ pub signature: Vec<u8>,
+ }
+
+ impl<'a> Signable<'a> for TestSignable {
+ const SIGN_LABEL: &'static str = "SignWithLabel";
+
+ type SigningContext = Vec<u8>;
+
+ fn signature(&self) -> &[u8] {
+ &self.signature
+ }
+
+ fn signable_content(
+ &self,
+ context: &Self::SigningContext,
+ ) -> Result<Vec<u8>, mls_rs_codec::Error> {
+ Ok([context.as_slice(), self.content.as_slice()].concat())
+ }
+
+ fn write_signature(&mut self, signature: Vec<u8>) {
+ self.signature = signature
+ }
+ }
+
+ impl SignatureInteropTestCase {
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn verify<P: CipherSuiteProvider>(&self, cs: &P) {
+ let public = self.public.clone().into();
+
+ let signable = TestSignable {
+ content: self.content.clone(),
+ signature: self.signature.clone(),
+ };
+
+ signable.verify(cs, &public, &vec![]).await.unwrap();
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::{test_utils::TestSignable, *};
+ use crate::{
+ client::test_utils::TEST_CIPHER_SUITE,
+ crypto::test_utils::{
+ test_cipher_suite_provider, try_test_cipher_suite_provider, TestCryptoProvider,
+ },
+ group::test_utils::random_bytes,
+ };
+ use alloc::vec;
+ use assert_matches::assert_matches;
+
+ #[derive(Debug, serde::Serialize, serde::Deserialize)]
+ struct TestCase {
+ cipher_suite: u16,
+ #[serde(with = "hex::serde")]
+ content: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ context: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ signature: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ signer: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ public: Vec<u8>,
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ async fn generate_test_cases() -> Vec<TestCase> {
+ let mut test_cases = Vec::new();
+
+ for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+ let provider = test_cipher_suite_provider(cipher_suite);
+
+ let (signer, public) = provider.signature_key_generate().await.unwrap();
+
+ let content = random_bytes(32);
+ let context = random_bytes(32);
+
+ let mut test_signable = TestSignable {
+ content: content.clone(),
+ signature: Vec::new(),
+ };
+
+ test_signable
+ .sign(&provider, &signer, &context)
+ .await
+ .unwrap();
+
+ test_cases.push(TestCase {
+ cipher_suite: cipher_suite.into(),
+ content,
+ context,
+ signature: test_signable.signature,
+ signer: signer.to_vec(),
+ public: public.to_vec(),
+ });
+ }
+
+ test_cases
+ }
+
+ #[cfg(mls_build_async)]
+ async fn load_test_cases() -> Vec<TestCase> {
+ load_test_case_json!(signatures, generate_test_cases().await)
+ }
+
+ #[cfg(not(mls_build_async))]
+ fn load_test_cases() -> Vec<TestCase> {
+ load_test_case_json!(signatures, generate_test_cases())
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_signatures() {
+ let cases = load_test_cases().await;
+
+ for one_case in cases {
+ let Some(cipher_suite_provider) = try_test_cipher_suite_provider(one_case.cipher_suite)
+ else {
+ continue;
+ };
+
+ let public_key = SignaturePublicKey::from(one_case.public);
+
+ // Wasm uses incompatible signature secret key format
+ #[cfg(not(target_arch = "wasm32"))]
+ {
+ // Test signature generation
+ let mut test_signable = TestSignable {
+ content: one_case.content.clone(),
+ signature: Vec::new(),
+ };
+
+ let signature_key = SignatureSecretKey::from(one_case.signer);
+
+ test_signable
+ .sign(&cipher_suite_provider, &signature_key, &one_case.context)
+ .await
+ .unwrap();
+
+ test_signable
+ .verify(&cipher_suite_provider, &public_key, &one_case.context)
+ .await
+ .unwrap();
+ }
+
+ // Test verifying an existing signature
+ let test_signable = TestSignable {
+ content: one_case.content,
+ signature: one_case.signature,
+ };
+
+ test_signable
+ .verify(&cipher_suite_provider, &public_key, &one_case.context)
+ .await
+ .unwrap();
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_invalid_signature() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ let (correct_secret, _) = cipher_suite_provider
+ .signature_key_generate()
+ .await
+ .unwrap();
+ let (_, incorrect_public) = cipher_suite_provider
+ .signature_key_generate()
+ .await
+ .unwrap();
+
+ let mut test_signable = TestSignable {
+ content: random_bytes(32),
+ signature: vec![],
+ };
+
+ test_signable
+ .sign(&cipher_suite_provider, &correct_secret, &vec![])
+ .await
+ .unwrap();
+
+ let res = test_signable
+ .verify(&cipher_suite_provider, &incorrect_public, &vec![])
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidSignature));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_invalid_context() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ let (secret, public) = cipher_suite_provider
+ .signature_key_generate()
+ .await
+ .unwrap();
+
+ let correct_context = random_bytes(32);
+ let incorrect_context = random_bytes(32);
+
+ let mut test_signable = TestSignable {
+ content: random_bytes(32),
+ signature: vec![],
+ };
+
+ test_signable
+ .sign(&cipher_suite_provider, &secret, &correct_context)
+ .await
+ .unwrap();
+
+ let res = test_signable
+ .verify(&cipher_suite_provider, &public, &incorrect_context)
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidSignature));
+ }
+}
diff --git a/src/storage_provider.rs b/src/storage_provider.rs
new file mode 100644
index 0000000..ffe8cd9
--- /dev/null
+++ b/src/storage_provider.rs
@@ -0,0 +1,14 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+/// Storage providers that operate completely in memory.
+pub mod in_memory;
+pub(crate) mod key_package;
+
+pub use key_package::*;
+
+#[cfg(feature = "sqlite")]
+#[cfg_attr(docsrs, doc(cfg(feature = "sqlite")))]
+/// SQLite based storage providers.
+pub mod sqlite;
diff --git a/src/storage_provider/group_state.rs b/src/storage_provider/group_state.rs
new file mode 100644
index 0000000..b6c854d
--- /dev/null
+++ b/src/storage_provider/group_state.rs
@@ -0,0 +1,43 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+use mls_rs_codec::MlsEncode;
+pub use mls_rs_core::group::{EpochRecord, GroupState};
+
+use crate::group::snapshot::Snapshot;
+
+#[cfg(feature = "prior_epoch")]
+use crate::group::epoch::PriorEpoch;
+
+#[cfg(feature = "prior_epoch")]
+impl EpochRecord for PriorEpoch {
+ fn id(&self) -> u64 {
+ self.epoch_id()
+ }
+}
+
+impl GroupState for Snapshot {
+ fn id(&self) -> Vec<u8> {
+ self.group_id().to_vec()
+ }
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+pub(crate) struct EpochData {
+ pub(crate) id: u64,
+ pub(crate) data: Vec<u8>,
+}
+
+impl EpochData {
+ pub(crate) fn new<T>(value: T) -> Result<Self, mls_rs_codec::Error>
+ where
+ T: MlsEncode + EpochRecord,
+ {
+ Ok(Self {
+ id: value.id(),
+ data: value.mls_encode_to_vec()?,
+ })
+ }
+}
diff --git a/src/storage_provider/in_memory.rs b/src/storage_provider/in_memory.rs
new file mode 100644
index 0000000..cb8f5d7
--- /dev/null
+++ b/src/storage_provider/in_memory.rs
@@ -0,0 +1,11 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+mod group_state_storage;
+mod key_package_storage;
+mod psk_storage;
+
+pub use group_state_storage::*;
+pub use key_package_storage::*;
+pub use psk_storage::*;
diff --git a/src/storage_provider/in_memory/group_state_storage.rs b/src/storage_provider/in_memory/group_state_storage.rs
new file mode 100644
index 0000000..5999ed0
--- /dev/null
+++ b/src/storage_provider/in_memory/group_state_storage.rs
@@ -0,0 +1,354 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::collections::VecDeque;
+
+#[cfg(target_has_atomic = "ptr")]
+use alloc::sync::Arc;
+
+#[cfg(mls_build_async)]
+use alloc::boxed::Box;
+use alloc::vec::Vec;
+use core::{
+ convert::Infallible,
+ fmt::{self, Debug},
+};
+use mls_rs_core::group::{EpochRecord, GroupState, GroupStateStorage};
+#[cfg(not(target_has_atomic = "ptr"))]
+use portable_atomic_util::Arc;
+
+use crate::client::MlsError;
+
+#[cfg(feature = "std")]
+use std::collections::{hash_map::Entry, HashMap};
+
+#[cfg(not(feature = "std"))]
+use alloc::collections::{btree_map::Entry, BTreeMap};
+
+#[cfg(feature = "std")]
+use std::sync::Mutex;
+
+#[cfg(not(feature = "std"))]
+use spin::Mutex;
+
+pub(crate) const DEFAULT_EPOCH_RETENTION_LIMIT: usize = 3;
+
+#[derive(Clone)]
+pub(crate) struct InMemoryGroupData {
+ pub(crate) state_data: Vec<u8>,
+ pub(crate) epoch_data: VecDeque<EpochRecord>,
+}
+
+impl Debug for InMemoryGroupData {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("InMemoryGroupData")
+ .field(
+ "state_data",
+ &mls_rs_core::debug::pretty_bytes(&self.state_data),
+ )
+ .field("epoch_data", &self.epoch_data)
+ .finish()
+ }
+}
+
+impl InMemoryGroupData {
+ pub fn new(state_data: Vec<u8>) -> InMemoryGroupData {
+ InMemoryGroupData {
+ state_data,
+ epoch_data: Default::default(),
+ }
+ }
+
+ fn get_epoch_data_index(&self, epoch_id: u64) -> Option<u64> {
+ self.epoch_data
+ .front()
+ .and_then(|e| epoch_id.checked_sub(e.id))
+ }
+
+ pub fn get_epoch(&self, epoch_id: u64) -> Option<&EpochRecord> {
+ self.get_epoch_data_index(epoch_id)
+ .and_then(|i| self.epoch_data.get(i as usize))
+ }
+
+ pub fn get_mut_epoch(&mut self, epoch_id: u64) -> Option<&mut EpochRecord> {
+ self.get_epoch_data_index(epoch_id)
+ .and_then(|i| self.epoch_data.get_mut(i as usize))
+ }
+
+ pub fn insert_epoch(&mut self, epoch: EpochRecord) {
+ self.epoch_data.push_back(epoch)
+ }
+
+ // This function does not fail if an update can't be made. If the epoch
+ // is not in the store, then it can no longer be accessed by future
+ // get_epoch calls and is no longer relevant.
+ pub fn update_epoch(&mut self, epoch: EpochRecord) {
+ if let Some(existing_epoch) = self.get_mut_epoch(epoch.id) {
+ *existing_epoch = epoch
+ }
+ }
+
+ pub fn trim_epochs(&mut self, max_epoch_retention: usize) {
+ while self.epoch_data.len() > max_epoch_retention {
+ self.epoch_data.pop_front();
+ }
+ }
+}
+
+#[derive(Clone)]
+/// In memory group state storage backed by a HashMap.
+///
+/// All clones of an instance of this type share the same underlying HashMap.
+pub struct InMemoryGroupStateStorage {
+ #[cfg(feature = "std")]
+ pub(crate) inner: Arc<Mutex<HashMap<Vec<u8>, InMemoryGroupData>>>,
+ #[cfg(not(feature = "std"))]
+ pub(crate) inner: Arc<Mutex<BTreeMap<Vec<u8>, InMemoryGroupData>>>,
+ pub(crate) max_epoch_retention: usize,
+}
+
+impl Debug for InMemoryGroupStateStorage {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("InMemoryGroupStateStorage")
+ .field(
+ "inner",
+ &mls_rs_core::debug::pretty_with(|f| {
+ f.debug_map()
+ .entries(
+ self.lock()
+ .iter()
+ .map(|(k, v)| (mls_rs_core::debug::pretty_bytes(k), v)),
+ )
+ .finish()
+ }),
+ )
+ .field("max_epoch_retention", &self.max_epoch_retention)
+ .finish()
+ }
+}
+
+impl InMemoryGroupStateStorage {
+ /// Create an empty group state storage.
+ pub fn new() -> Self {
+ Self {
+ inner: Default::default(),
+ max_epoch_retention: DEFAULT_EPOCH_RETENTION_LIMIT,
+ }
+ }
+
+ pub fn with_max_epoch_retention(self, max_epoch_retention: usize) -> Result<Self, MlsError> {
+ (max_epoch_retention > 0)
+ .then_some(())
+ .ok_or(MlsError::NonZeroRetentionRequired)?;
+
+ Ok(Self {
+ inner: self.inner,
+ max_epoch_retention,
+ })
+ }
+
+ /// Get the set of unique group ids that have data stored.
+ pub fn stored_groups(&self) -> Vec<Vec<u8>> {
+ self.lock().keys().cloned().collect()
+ }
+
+ /// Delete all data corresponding to `group_id`.
+ pub fn delete_group(&self, group_id: &[u8]) {
+ self.lock().remove(group_id);
+ }
+
+ #[cfg(feature = "std")]
+ fn lock(&self) -> std::sync::MutexGuard<'_, HashMap<Vec<u8>, InMemoryGroupData>> {
+ self.inner.lock().unwrap()
+ }
+
+ #[cfg(not(feature = "std"))]
+ fn lock(&self) -> spin::mutex::MutexGuard<'_, BTreeMap<Vec<u8>, InMemoryGroupData>> {
+ self.inner.lock()
+ }
+}
+
+impl Default for InMemoryGroupStateStorage {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+impl GroupStateStorage for InMemoryGroupStateStorage {
+ type Error = Infallible;
+
+ async fn max_epoch_id(&self, group_id: &[u8]) -> Result<Option<u64>, Self::Error> {
+ Ok(self
+ .lock()
+ .get(group_id)
+ .and_then(|group_data| group_data.epoch_data.back().map(|e| e.id)))
+ }
+
+ async fn state(&self, group_id: &[u8]) -> Result<Option<Vec<u8>>, Self::Error> {
+ Ok(self
+ .lock()
+ .get(group_id)
+ .map(|data| data.state_data.clone()))
+ }
+
+ async fn epoch(&self, group_id: &[u8], epoch_id: u64) -> Result<Option<Vec<u8>>, Self::Error> {
+ Ok(self
+ .lock()
+ .get(group_id)
+ .and_then(|data| data.get_epoch(epoch_id).map(|ep| ep.data.clone())))
+ }
+
+ async fn write(
+ &mut self,
+ state: GroupState,
+ epoch_inserts: Vec<EpochRecord>,
+ epoch_updates: Vec<EpochRecord>,
+ ) -> Result<(), Self::Error> {
+ let mut group_map = self.lock();
+
+ let group_data = match group_map.entry(state.id) {
+ Entry::Occupied(entry) => {
+ let data = entry.into_mut();
+ data.state_data = state.data;
+ data
+ }
+ Entry::Vacant(entry) => entry.insert(InMemoryGroupData::new(state.data)),
+ };
+
+ epoch_inserts
+ .into_iter()
+ .for_each(|e| group_data.insert_epoch(e));
+
+ epoch_updates
+ .into_iter()
+ .for_each(|e| group_data.update_epoch(e));
+
+ group_data.trim_epochs(self.max_epoch_retention);
+
+ Ok(())
+ }
+}
+
+#[cfg(all(test, feature = "prior_epoch"))]
+mod tests {
+ use alloc::{format, vec, vec::Vec};
+ use assert_matches::assert_matches;
+
+ use super::{InMemoryGroupData, InMemoryGroupStateStorage};
+ use crate::{client::MlsError, group::test_utils::TEST_GROUP};
+
+ use mls_rs_core::group::{EpochRecord, GroupState, GroupStateStorage};
+
+ impl InMemoryGroupStateStorage {
+ fn test_data(&self) -> InMemoryGroupData {
+ self.lock().get(TEST_GROUP).unwrap().clone()
+ }
+ }
+
+ fn test_storage(retention_limit: usize) -> Result<InMemoryGroupStateStorage, MlsError> {
+ InMemoryGroupStateStorage::new().with_max_epoch_retention(retention_limit)
+ }
+
+ fn test_epoch(epoch_id: u64) -> EpochRecord {
+ EpochRecord::new(epoch_id, format!("epoch {epoch_id}").as_bytes().to_vec())
+ }
+
+ fn test_snapshot(epoch_id: u64) -> GroupState {
+ GroupState {
+ id: TEST_GROUP.into(),
+ data: format!("snapshot {epoch_id}").as_bytes().to_vec(),
+ }
+ }
+
+ #[test]
+ fn test_zero_max_retention() {
+ assert_matches!(test_storage(0), Err(MlsError::NonZeroRetentionRequired))
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn existing_storage_can_have_larger_epoch_count() {
+ let mut storage = test_storage(2).unwrap();
+
+ let epoch_inserts = vec![test_epoch(0), test_epoch(1)];
+
+ storage
+ .write(test_snapshot(0), epoch_inserts, Vec::new())
+ .await
+ .unwrap();
+
+ assert_eq!(storage.test_data().epoch_data.len(), 2);
+
+ storage.max_epoch_retention = 4;
+
+ let epoch_inserts = vec![test_epoch(3), test_epoch(4)];
+
+ storage
+ .write(test_snapshot(1), epoch_inserts, Vec::new())
+ .await
+ .unwrap();
+
+ assert_eq!(storage.test_data().epoch_data.len(), 4);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn existing_storage_can_have_smaller_epoch_count() {
+ let mut storage = test_storage(4).unwrap();
+
+ let epoch_inserts = vec![test_epoch(0), test_epoch(1), test_epoch(3), test_epoch(4)];
+
+ storage
+ .write(test_snapshot(1), epoch_inserts, Vec::new())
+ .await
+ .unwrap();
+
+ assert_eq!(storage.test_data().epoch_data.len(), 4);
+
+ storage.max_epoch_retention = 2;
+
+ let epoch_inserts = vec![test_epoch(5)];
+
+ storage
+ .write(test_snapshot(1), epoch_inserts, Vec::new())
+ .await
+ .unwrap();
+
+ assert_eq!(storage.test_data().epoch_data.len(), 2);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn epoch_insert_over_limit() {
+ test_epoch_insert_over_limit(false).await
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn epoch_insert_over_limit_with_update() {
+ test_epoch_insert_over_limit(true).await
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_epoch_insert_over_limit(with_update: bool) {
+ let mut storage = test_storage(1).unwrap();
+
+ let mut epoch_inserts = vec![test_epoch(0), test_epoch(1)];
+ let updates = with_update
+ .then_some(vec![test_epoch(0)])
+ .unwrap_or_default();
+ let snapshot = test_snapshot(1);
+
+ storage
+ .write(snapshot.clone(), epoch_inserts.clone(), updates)
+ .await
+ .unwrap();
+
+ let stored = storage.test_data();
+
+ assert_eq!(stored.state_data, snapshot.data);
+ assert_eq!(stored.epoch_data.len(), 1);
+
+ let expected = epoch_inserts.pop().unwrap();
+ assert_eq!(stored.epoch_data[0], expected);
+ }
+}
diff --git a/src/storage_provider/in_memory/key_package_storage.rs b/src/storage_provider/in_memory/key_package_storage.rs
new file mode 100644
index 0000000..427a8a4
--- /dev/null
+++ b/src/storage_provider/in_memory/key_package_storage.rs
@@ -0,0 +1,120 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+#[cfg(target_has_atomic = "ptr")]
+use alloc::sync::Arc;
+
+#[cfg(not(target_has_atomic = "ptr"))]
+use portable_atomic_util::Arc;
+
+use core::{
+ convert::Infallible,
+ fmt::{self, Debug},
+};
+
+#[cfg(feature = "std")]
+use std::collections::HashMap;
+
+#[cfg(not(feature = "std"))]
+use alloc::collections::BTreeMap;
+use alloc::vec::Vec;
+use mls_rs_core::key_package::{KeyPackageData, KeyPackageStorage};
+
+#[cfg(feature = "std")]
+use std::sync::Mutex;
+
+#[cfg(mls_build_async)]
+use alloc::boxed::Box;
+#[cfg(not(feature = "std"))]
+use spin::Mutex;
+
+#[derive(Clone, Default)]
+/// In memory key package storage backed by a HashMap.
+///
+/// All clones of an instance of this type share the same underlying HashMap.
+pub struct InMemoryKeyPackageStorage {
+ #[cfg(feature = "std")]
+ inner: Arc<Mutex<HashMap<Vec<u8>, KeyPackageData>>>,
+ #[cfg(not(feature = "std"))]
+ inner: Arc<Mutex<BTreeMap<Vec<u8>, KeyPackageData>>>,
+}
+
+impl Debug for InMemoryKeyPackageStorage {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("InMemoryKeyPackageStorage")
+ .field(
+ "inner",
+ &mls_rs_core::debug::pretty_with(|f| {
+ f.debug_map()
+ .entries(
+ self.lock()
+ .iter()
+ .map(|(k, v)| (mls_rs_core::debug::pretty_bytes(k), v)),
+ )
+ .finish()
+ }),
+ )
+ .finish()
+ }
+}
+
+impl InMemoryKeyPackageStorage {
+ /// Create an empty key package storage.
+ pub fn new() -> Self {
+ Default::default()
+ }
+
+ /// Insert key package data.
+ pub fn insert(&self, id: Vec<u8>, pkg: KeyPackageData) {
+ self.lock().insert(id, pkg);
+ }
+
+ /// Get a key package data by `id`.
+ pub fn get(&self, id: &[u8]) -> Option<KeyPackageData> {
+ self.lock().get(id).cloned()
+ }
+
+ /// Delete key package data by `id`.
+ pub fn delete(&self, id: &[u8]) {
+ self.lock().remove(id);
+ }
+
+ /// Get all key packages that are currently stored.
+ pub fn key_packages(&self) -> Vec<(Vec<u8>, KeyPackageData)> {
+ self.lock()
+ .iter()
+ .map(|(k, v)| (k.clone(), v.clone()))
+ .collect()
+ }
+
+ #[cfg(feature = "std")]
+ fn lock(&self) -> std::sync::MutexGuard<'_, HashMap<Vec<u8>, KeyPackageData>> {
+ self.inner.lock().unwrap()
+ }
+
+ #[cfg(not(feature = "std"))]
+ fn lock(&self) -> spin::mutex::MutexGuard<'_, BTreeMap<Vec<u8>, KeyPackageData>> {
+ self.inner.lock()
+ }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+impl KeyPackageStorage for InMemoryKeyPackageStorage {
+ type Error = Infallible;
+
+ async fn delete(&mut self, id: &[u8]) -> Result<(), Self::Error> {
+ (*self).delete(id);
+ Ok(())
+ }
+
+ async fn insert(&mut self, id: Vec<u8>, pkg: KeyPackageData) -> Result<(), Self::Error> {
+ (*self).insert(id, pkg);
+ Ok(())
+ }
+
+ async fn get(&self, id: &[u8]) -> Result<Option<KeyPackageData>, Self::Error> {
+ Ok(self.get(id))
+ }
+}
diff --git a/src/storage_provider/in_memory/psk_storage.rs b/src/storage_provider/in_memory/psk_storage.rs
new file mode 100644
index 0000000..e1b0b75
--- /dev/null
+++ b/src/storage_provider/in_memory/psk_storage.rs
@@ -0,0 +1,83 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+#[cfg(target_has_atomic = "ptr")]
+use alloc::sync::Arc;
+
+#[cfg(not(target_has_atomic = "ptr"))]
+use portable_atomic_util::Arc;
+
+use core::convert::Infallible;
+
+#[cfg(feature = "std")]
+use std::collections::HashMap;
+
+#[cfg(not(feature = "std"))]
+use alloc::collections::BTreeMap;
+
+use mls_rs_core::psk::{ExternalPskId, PreSharedKey, PreSharedKeyStorage};
+
+#[cfg(mls_build_async)]
+use alloc::boxed::Box;
+#[cfg(feature = "std")]
+use std::sync::Mutex;
+
+#[cfg(not(feature = "std"))]
+use spin::Mutex;
+
+#[derive(Clone, Debug, Default)]
+/// In memory pre-shared key storage backed by a HashMap.
+///
+/// All clones of an instance of this type share the same underlying HashMap.
+pub struct InMemoryPreSharedKeyStorage {
+ #[cfg(feature = "std")]
+ inner: Arc<Mutex<HashMap<ExternalPskId, PreSharedKey>>>,
+ #[cfg(not(feature = "std"))]
+ inner: Arc<Mutex<BTreeMap<ExternalPskId, PreSharedKey>>>,
+}
+
+impl InMemoryPreSharedKeyStorage {
+ /// Insert a pre-shared key into storage.
+ pub fn insert(&mut self, id: ExternalPskId, psk: PreSharedKey) {
+ #[cfg(feature = "std")]
+ let mut lock = self.inner.lock().unwrap();
+
+ #[cfg(not(feature = "std"))]
+ let mut lock = self.inner.lock();
+
+ lock.insert(id, psk);
+ }
+
+ /// Get a pre-shared key by `id`.
+ pub fn get(&self, id: &ExternalPskId) -> Option<PreSharedKey> {
+ #[cfg(feature = "std")]
+ let lock = self.inner.lock().unwrap();
+
+ #[cfg(not(feature = "std"))]
+ let lock = self.inner.lock();
+
+ lock.get(id).cloned()
+ }
+
+ /// Delete a pre-shared key from storage.
+ pub fn delete(&mut self, id: &ExternalPskId) {
+ #[cfg(feature = "std")]
+ let mut lock = self.inner.lock().unwrap();
+
+ #[cfg(not(feature = "std"))]
+ let mut lock = self.inner.lock();
+
+ lock.remove(id);
+ }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+impl PreSharedKeyStorage for InMemoryPreSharedKeyStorage {
+ type Error = Infallible;
+
+ async fn get(&self, id: &ExternalPskId) -> Result<Option<PreSharedKey>, Self::Error> {
+ Ok(self.get(id))
+ }
+}
diff --git a/src/storage_provider/key_package.rs b/src/storage_provider/key_package.rs
new file mode 100644
index 0000000..1e209fb
--- /dev/null
+++ b/src/storage_provider/key_package.rs
@@ -0,0 +1,5 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+pub use mls_rs_core::key_package::KeyPackageData;
diff --git a/src/storage_provider/sqlite.rs b/src/storage_provider/sqlite.rs
new file mode 100644
index 0000000..f4e4f1f
--- /dev/null
+++ b/src/storage_provider/sqlite.rs
@@ -0,0 +1,5 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+pub use mls_rs_provider_sqlite::*;
diff --git a/src/test_utils/benchmarks.rs b/src/test_utils/benchmarks.rs
new file mode 100644
index 0000000..93d8964
--- /dev/null
+++ b/src/test_utils/benchmarks.rs
@@ -0,0 +1,140 @@
+use mls_rs_codec::MlsEncode;
+use mls_rs_core::protocol_version::ProtocolVersion;
+
+use crate::{
+ cipher_suite::CipherSuite,
+ client_builder::{BaseConfig, MlsConfig, WithCryptoProvider, WithIdentityProvider},
+ group::{framing::MlsMessage, Group},
+ identity::basic::BasicIdentityProvider,
+ test_utils::{generate_basic_client, get_test_groups},
+};
+
+pub use mls_rs_crypto_openssl::OpensslCryptoProvider as MlsCryptoProvider;
+
+pub type TestClientConfig =
+ WithIdentityProvider<BasicIdentityProvider, WithCryptoProvider<MlsCryptoProvider, BaseConfig>>;
+
+macro_rules! load_test_case_mls {
+ ($name:ident, $generate:expr) => {
+ load_test_case_mls!($name, $generate, to_vec_pretty)
+ };
+ ($name:ident, $generate:expr, $to_json:ident) => {{
+ #[cfg(any(target_arch = "wasm32", not(feature = "std")))]
+ {
+ // Do not remove `async`! (The goal of this line is to remove warnings
+ // about `$generate` not being used. Actually calling it will make tests fail.)
+ let _ = async { $generate };
+
+ mls_rs_codec::MlsDecode::mls_decode(&mut &include_bytes!(concat!(
+ env!("CARGO_MANIFEST_DIR"),
+ "/test_data/",
+ stringify!($name),
+ ".mls"
+ )))
+ .unwrap()
+ }
+
+ #[cfg(all(not(target_arch = "wasm32"), feature = "std"))]
+ {
+ let path = concat!(
+ env!("CARGO_MANIFEST_DIR"),
+ "/test_data/",
+ stringify!($name),
+ ".mls"
+ );
+
+ if !std::path::Path::new(path).exists() {
+ std::fs::write(path, $generate.mls_encode_to_vec().unwrap()).unwrap();
+ }
+
+ mls_rs_codec::MlsDecode::mls_decode(&mut std::fs::read(path).unwrap().as_slice())
+ .unwrap()
+ }
+ }};
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn generate_test_cases(cs: CipherSuite) -> Vec<MlsMessage> {
+ let mut cases = Vec::new();
+
+ for size in [16, 64, 128] {
+ let group = get_test_groups(
+ ProtocolVersion::MLS_10,
+ cs,
+ size,
+ None,
+ false,
+ &MlsCryptoProvider::new(),
+ )
+ .await
+ .pop()
+ .unwrap();
+
+ let group_info = group
+ .group_info_message_allowing_ext_commit(true)
+ .await
+ .unwrap();
+
+ cases.push(group_info)
+ }
+
+ cases
+}
+
+#[derive(Clone)]
+pub struct GroupStates<C: MlsConfig> {
+ pub sender: Group<C>,
+ pub receiver: Group<C>,
+}
+
+#[cfg(mls_build_async)]
+pub fn load_group_states(cs: CipherSuite) -> Vec<GroupStates<impl MlsConfig>> {
+ let group_info = load_test_case_mls!(group_state, block_on(generate_test_cases(cs)), to_vec);
+ join_group(cs, group_info)
+}
+
+#[cfg(not(mls_build_async))]
+pub fn load_group_states(cs: CipherSuite) -> Vec<GroupStates<impl MlsConfig>> {
+ let group_infos: Vec<MlsMessage> =
+ load_test_case_mls!(group_state, generate_test_cases(cs), to_vec);
+
+ group_infos
+ .into_iter()
+ .map(|info| join_group(cs, info))
+ .collect()
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub async fn join_group(cs: CipherSuite, group_info: MlsMessage) -> GroupStates<impl MlsConfig> {
+ let client = generate_basic_client(
+ cs,
+ ProtocolVersion::MLS_10,
+ 99999999999,
+ None,
+ false,
+ &MlsCryptoProvider::new(),
+ None,
+ );
+
+ let mut sender = client.commit_external(group_info).await.unwrap().0;
+
+ let client = generate_basic_client(
+ cs,
+ ProtocolVersion::MLS_10,
+ 99999999998,
+ None,
+ false,
+ &MlsCryptoProvider::new(),
+ None,
+ );
+
+ let group_info = sender
+ .group_info_message_allowing_ext_commit(true)
+ .await
+ .unwrap();
+
+ let (receiver, commit) = client.commit_external(group_info).await.unwrap();
+ sender.process_incoming_message(commit).await.unwrap();
+
+ GroupStates { sender, receiver }
+}
diff --git a/src/test_utils/fuzz_tests.rs b/src/test_utils/fuzz_tests.rs
new file mode 100644
index 0000000..9ec143e
--- /dev/null
+++ b/src/test_utils/fuzz_tests.rs
@@ -0,0 +1,109 @@
+use std::sync::Mutex;
+
+use mls_rs_core::{
+ crypto::{CipherSuiteProvider, CryptoProvider, SignatureSecretKey},
+ identity::BasicCredential,
+};
+
+use once_cell::sync::Lazy;
+
+use crate::{
+ cipher_suite::CipherSuite,
+ client::MlsError,
+ client_builder::{BaseConfig, WithCryptoProvider, WithIdentityProvider},
+ group::{
+ framing::{Content, MlsMessage, Sender, WireFormat},
+ message_processor::MessageProcessor,
+ message_signature::AuthenticatedContent,
+ Commit, Group,
+ },
+ identity::{basic::BasicIdentityProvider, SigningIdentity},
+ Client, ExtensionList,
+};
+
+#[cfg(awslc)]
+pub use mls_rs_crypto_awslc::AwsLcCryptoProvider as MlsCryptoProvider;
+#[cfg(not(any(awslc, rustcrypto)))]
+pub use mls_rs_crypto_openssl::OpensslCryptoProvider as MlsCryptoProvider;
+#[cfg(rustcrypto)]
+pub use mls_rs_crypto_rustcrypto::RustCryptoProvider as MlsCryptoProvider;
+
+pub type TestClientConfig =
+ WithIdentityProvider<BasicIdentityProvider, WithCryptoProvider<MlsCryptoProvider, BaseConfig>>;
+
+pub static GROUP: Lazy<Mutex<Group<TestClientConfig>>> = Lazy::new(|| Mutex::new(create_group()));
+
+pub fn create_group() -> Group<TestClientConfig> {
+ let cipher_suite = CipherSuite::CURVE25519_AES128;
+ let alice = make_client(cipher_suite, "alice");
+ let bob = make_client(cipher_suite, "bob");
+
+ let mut alice = alice.create_group(ExtensionList::new()).unwrap();
+
+ alice
+ .commit_builder()
+ .add_member(bob.generate_key_package_message().unwrap())
+ .unwrap()
+ .build()
+ .unwrap();
+
+ alice.apply_pending_commit().unwrap();
+
+ alice
+}
+
+pub fn create_fuzz_commit_message(
+ group_id: Vec<u8>,
+ epoch: u64,
+ authenticated_data: Vec<u8>,
+) -> Result<MlsMessage, MlsError> {
+ let mut group = GROUP.lock().unwrap();
+
+ let mut context = group.context().clone();
+ context.group_id = group_id;
+ context.epoch = epoch;
+
+ #[cfg(feature = "private_message")]
+ let wire_format = WireFormat::PrivateMessage;
+
+ #[cfg(not(feature = "private_message"))]
+ let wire_format = WireFormat::PublicMessage;
+
+ let auth_content = AuthenticatedContent::new_signed(
+ group.cipher_suite_provider(),
+ &context,
+ Sender::Member(0),
+ Content::Commit(alloc::boxed::Box::new(Commit {
+ proposals: Vec::new(),
+ path: None,
+ })),
+ &group.signer,
+ wire_format,
+ authenticated_data,
+ )?;
+
+ group.format_for_wire(auth_content)
+}
+
+fn make_client(cipher_suite: CipherSuite, name: &str) -> Client<TestClientConfig> {
+ let (secret, signing_identity) = make_identity(cipher_suite, name);
+
+ // TODO : consider fuzzing on encrypted controls (doesn't seem very useful)
+ Client::builder()
+ .identity_provider(BasicIdentityProvider)
+ .crypto_provider(MlsCryptoProvider::default())
+ .signing_identity(signing_identity, secret, cipher_suite)
+ .build()
+}
+
+fn make_identity(cipher_suite: CipherSuite, name: &str) -> (SignatureSecretKey, SigningIdentity) {
+ let cipher_suite = MlsCryptoProvider::new()
+ .cipher_suite_provider(cipher_suite)
+ .unwrap();
+
+ let (secret, public) = cipher_suite.signature_key_generate().unwrap();
+ let basic_identity = BasicCredential::new(name.as_bytes().to_vec());
+ let signing_identity = SigningIdentity::new(basic_identity.into_credential(), public);
+
+ (secret, signing_identity)
+}
diff --git a/src/test_utils/mod.rs b/src/test_utils/mod.rs
new file mode 100644
index 0000000..d7c238b
--- /dev/null
+++ b/src/test_utils/mod.rs
@@ -0,0 +1,184 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+#[cfg(all(feature = "benchmark_util", not(mls_build_async)))]
+pub mod benchmarks;
+
+#[cfg(all(feature = "fuzz_util", not(mls_build_async)))]
+pub mod fuzz_tests;
+
+use mls_rs_core::{
+ crypto::{CipherSuite, CipherSuiteProvider, CryptoProvider},
+ identity::{BasicCredential, Credential, SigningIdentity},
+ protocol_version::ProtocolVersion,
+ psk::ExternalPskId,
+};
+
+use crate::{
+ client_builder::{ClientBuilder, MlsConfig},
+ identity::basic::BasicIdentityProvider,
+ mls_rules::{CommitOptions, DefaultMlsRules},
+ tree_kem::Lifetime,
+ Client, Group, MlsMessage,
+};
+
+#[cfg(feature = "private_message")]
+use crate::group::{mls_rules::EncryptionOptions, padding::PaddingMode};
+
+use alloc::{vec, vec::Vec};
+
+#[cfg_attr(coverage_nightly, coverage(off))]
+pub fn get_test_basic_credential(identity: Vec<u8>) -> Credential {
+ BasicCredential::new(identity).into_credential()
+}
+
+pub const TEST_EXT_PSK_ID: &[u8] = b"external psk";
+
+#[cfg_attr(coverage_nightly, coverage(off))]
+pub fn make_test_ext_psk() -> Vec<u8> {
+ b"secret psk key".to_vec()
+}
+
+pub fn is_edwards(cs: u16) -> bool {
+ [
+ CipherSuite::CURVE25519_AES128,
+ CipherSuite::CURVE25519_CHACHA,
+ CipherSuite::CURVE448_AES256,
+ CipherSuite::CURVE448_CHACHA,
+ ]
+ .contains(&cs.into())
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+pub async fn generate_basic_client<C: CryptoProvider + Clone>(
+ cipher_suite: CipherSuite,
+ protocol_version: ProtocolVersion,
+ id: usize,
+ commit_options: Option<CommitOptions>,
+ #[cfg(feature = "private_message")] encrypt_controls: bool,
+ #[cfg(not(feature = "private_message"))] _encrypt_controls: bool,
+ crypto: &C,
+ lifetime: Option<Lifetime>,
+) -> Client<impl MlsConfig> {
+ let cs = crypto.cipher_suite_provider(cipher_suite).unwrap();
+
+ let (secret_key, public_key) = cs.signature_key_generate().await.unwrap();
+ let credential = get_test_basic_credential(alloc::format!("{id}").into_bytes());
+
+ let identity = SigningIdentity::new(credential, public_key);
+
+ let mls_rules =
+ DefaultMlsRules::default().with_commit_options(commit_options.unwrap_or_default());
+
+ #[cfg(feature = "private_message")]
+ let mls_rules = if encrypt_controls {
+ mls_rules.with_encryption_options(EncryptionOptions::new(true, PaddingMode::None))
+ } else {
+ mls_rules
+ };
+
+ let mut builder = ClientBuilder::new()
+ .crypto_provider(crypto.clone())
+ .identity_provider(BasicIdentityProvider::new())
+ .mls_rules(mls_rules)
+ .psk(
+ ExternalPskId::new(TEST_EXT_PSK_ID.to_vec()),
+ make_test_ext_psk().into(),
+ )
+ .used_protocol_version(protocol_version)
+ .signing_identity(identity, secret_key, cipher_suite);
+
+ if let Some(lifetime) = lifetime {
+ builder = builder
+ .key_package_lifetime(lifetime.not_after - lifetime.not_before)
+ .key_package_not_before(lifetime.not_before);
+ }
+
+ builder.build()
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+pub async fn get_test_groups<C: CryptoProvider + Clone>(
+ version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+ num_participants: usize,
+ commit_options: Option<CommitOptions>,
+ encrypt_controls: bool,
+ crypto: &C,
+) -> Vec<Group<impl MlsConfig>> {
+ // Create the group with Alice as the group initiator
+ let creator = generate_basic_client(
+ cipher_suite,
+ version,
+ 0,
+ commit_options,
+ encrypt_controls,
+ crypto,
+ None,
+ )
+ .await;
+
+ let mut creator_group = creator.create_group(Default::default()).await.unwrap();
+
+ let mut receiver_clients = Vec::new();
+ let mut commit_builder = creator_group.commit_builder();
+
+ for i in 1..num_participants {
+ let client = generate_basic_client(
+ cipher_suite,
+ version,
+ i,
+ commit_options,
+ encrypt_controls,
+ crypto,
+ None,
+ )
+ .await;
+ let kp = client.generate_key_package_message().await.unwrap();
+
+ receiver_clients.push(client);
+ commit_builder = commit_builder.add_member(kp.clone()).unwrap();
+ }
+
+ let welcome = commit_builder.build().await.unwrap().welcome_messages;
+
+ creator_group.apply_pending_commit().await.unwrap();
+
+ let tree_data = creator_group.export_tree().into_owned();
+
+ let mut groups = vec![creator_group];
+
+ for client in &receiver_clients {
+ let (test_client, _info) = client
+ .join_group(Some(tree_data.clone()), &welcome[0])
+ .await
+ .unwrap();
+
+ groups.push(test_client);
+ }
+
+ groups
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+pub async fn all_process_message<C: MlsConfig>(
+ groups: &mut [Group<C>],
+ message: &MlsMessage,
+ sender: usize,
+ is_commit: bool,
+) {
+ for group in groups {
+ if sender != group.current_member_index() as usize {
+ group
+ .process_incoming_message(message.clone())
+ .await
+ .unwrap();
+ } else if is_commit {
+ group.apply_pending_commit().await.unwrap();
+ }
+ }
+}
diff --git a/src/tree_kem/capabilities.rs b/src/tree_kem/capabilities.rs
new file mode 100644
index 0000000..6fc498d
--- /dev/null
+++ b/src/tree_kem/capabilities.rs
@@ -0,0 +1,5 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+pub use mls_rs_core::group::Capabilities;
diff --git a/src/tree_kem/hpke_encryption.rs b/src/tree_kem/hpke_encryption.rs
new file mode 100644
index 0000000..77a598a
--- /dev/null
+++ b/src/tree_kem/hpke_encryption.rs
@@ -0,0 +1,172 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsEncode, MlsSize};
+use mls_rs_core::{
+ crypto::{CipherSuiteProvider, HpkeCiphertext, HpkePublicKey, HpkeSecretKey},
+ error::IntoAnyError,
+};
+use zeroize::Zeroizing;
+
+use crate::client::MlsError;
+
+#[derive(Clone, MlsSize, MlsEncode)]
+struct EncryptContext<'a> {
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ label: Vec<u8>,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ context: &'a [u8],
+}
+
+impl Debug for EncryptContext<'_> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("EncryptContext")
+ .field("label", &mls_rs_core::debug::pretty_bytes(&self.label))
+ .field("context", &mls_rs_core::debug::pretty_bytes(self.context))
+ .finish()
+ }
+}
+
+impl<'a> EncryptContext<'a> {
+ pub fn new(label: &str, context: &'a [u8]) -> Self {
+ Self {
+ label: [b"MLS 1.0 ", label.as_bytes()].concat(),
+ context,
+ }
+ }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(all(target_arch = "wasm32", mls_build_async), maybe_async::must_be_async(?Send))]
+#[cfg_attr(
+ all(not(target_arch = "wasm32"), mls_build_async),
+ maybe_async::must_be_async
+)]
+
+pub(crate) trait HpkeEncryptable: Sized {
+ const ENCRYPT_LABEL: &'static str;
+
+ async fn encrypt<P: CipherSuiteProvider>(
+ &self,
+ cipher_suite_provider: &P,
+ public_key: &HpkePublicKey,
+ context: &[u8],
+ ) -> Result<HpkeCiphertext, MlsError> {
+ let context = EncryptContext::new(Self::ENCRYPT_LABEL, context)
+ .mls_encode_to_vec()
+ .map(Zeroizing::new)?;
+
+ let content = self.get_bytes().map(Zeroizing::new)?;
+
+ cipher_suite_provider
+ .hpke_seal(public_key, &context, None, &content)
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+ }
+
+ async fn decrypt<P: CipherSuiteProvider>(
+ cipher_suite_provider: &P,
+ secret_key: &HpkeSecretKey,
+ public_key: &HpkePublicKey,
+ context: &[u8],
+ ciphertext: &HpkeCiphertext,
+ ) -> Result<Self, MlsError> {
+ let context = EncryptContext::new(Self::ENCRYPT_LABEL, context).mls_encode_to_vec()?;
+
+ let plaintext = cipher_suite_provider
+ .hpke_open(ciphertext, secret_key, public_key, &context, None)
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+ Self::from_bytes(plaintext.to_vec())
+ }
+
+ fn from_bytes(bytes: Vec<u8>) -> Result<Self, MlsError>;
+ fn get_bytes(&self) -> Result<Vec<u8>, MlsError>;
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+ use alloc::{string::String, vec::Vec};
+ use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+ use mls_rs_core::crypto::{CipherSuiteProvider, HpkeCiphertext};
+
+ use crate::{client::MlsError, crypto::test_utils::try_test_cipher_suite_provider};
+
+ use super::HpkeEncryptable;
+
+ #[derive(Debug, serde::Serialize, serde::Deserialize)]
+ pub struct HpkeInteropTestCase {
+ #[serde(with = "hex::serde", rename = "priv")]
+ secret: Vec<u8>,
+ #[serde(with = "hex::serde", rename = "pub")]
+ public: Vec<u8>,
+ label: String,
+ #[serde(with = "hex::serde")]
+ context: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ plaintext: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ kem_output: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ ciphertext: Vec<u8>,
+ }
+
+ #[derive(Debug, serde::Serialize, serde::Deserialize)]
+ pub struct InteropTestCase {
+ cipher_suite: u16,
+ encrypt_with_label: HpkeInteropTestCase,
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_basic_crypto_test_vectors() {
+ // The test vector can be found here https://github.com/mlswg/mls-implementations/blob/main/test-vectors/crypto-basics.json
+ let test_cases: Vec<InteropTestCase> =
+ load_test_case_json!(basic_crypto, Vec::<InteropTestCase>::new());
+
+ for test_case in test_cases {
+ if let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) {
+ test_case.encrypt_with_label.verify(&cs).await
+ }
+ }
+ }
+
+ #[derive(Clone, Debug, MlsSize, MlsEncode, MlsDecode)]
+ struct TestEncryptable(#[mls_codec(with = "mls_rs_codec::byte_vec")] Vec<u8>);
+
+ impl HpkeEncryptable for TestEncryptable {
+ const ENCRYPT_LABEL: &'static str = "EncryptWithLabel";
+
+ fn from_bytes(bytes: Vec<u8>) -> Result<Self, MlsError> {
+ Ok(Self(bytes))
+ }
+
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn get_bytes(&self) -> Result<Vec<u8>, MlsError> {
+ Ok(self.0.clone())
+ }
+ }
+
+ impl HpkeInteropTestCase {
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn verify<P: CipherSuiteProvider>(&self, cs: &P) {
+ let secret = self.secret.clone().into();
+ let public = self.public.clone().into();
+
+ let ciphertext = HpkeCiphertext {
+ kem_output: self.kem_output.clone(),
+ ciphertext: self.ciphertext.clone(),
+ };
+
+ let computed_plaintext =
+ TestEncryptable::decrypt(cs, &secret, &public, &self.context, &ciphertext)
+ .await
+ .unwrap();
+
+ assert_eq!(&computed_plaintext.0, &self.plaintext)
+ }
+ }
+}
diff --git a/src/tree_kem/interop_test_vectors.rs b/src/tree_kem/interop_test_vectors.rs
new file mode 100644
index 0000000..50e0077
--- /dev/null
+++ b/src/tree_kem/interop_test_vectors.rs
@@ -0,0 +1,199 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec;
+use alloc::vec::Vec;
+use mls_rs_codec::{MlsDecode, MlsEncode};
+use mls_rs_core::crypto::{CipherSuite, CipherSuiteProvider};
+
+use itertools::Itertools;
+
+use crate::{
+ crypto::test_utils::try_test_cipher_suite_provider, identity::basic::BasicIdentityProvider,
+};
+
+use super::{
+ node::NodeVec, test_utils::TreeWithSigners, tree_validator::TreeValidator, TreeKemPublic,
+};
+
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+struct ValidationTestCase {
+ pub cipher_suite: u16,
+
+ #[serde(with = "hex::serde")]
+ pub tree: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub group_id: Vec<u8>,
+ pub tree_hashes: Vec<TreeHash>,
+ pub resolutions: Vec<Vec<u32>>,
+}
+
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+struct TreeHash(#[serde(with = "hex::serde")] pub Vec<u8>);
+
+impl From<crate::tree_kem::tree_hash::TreeHash> for TreeHash {
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn from(value: crate::tree_kem::tree_hash::TreeHash) -> Self {
+ TreeHash(value.to_vec())
+ }
+}
+
+impl ValidationTestCase {
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn new<P: CipherSuiteProvider>(tree: TreeKemPublic, group_id: &[u8], cs: &P) -> Self {
+ let tree_size = tree.total_leaf_count() * 2 - 1;
+
+ assert!(
+ tree.tree_hashes.current.len() == tree_size as usize,
+ "hashes not initialized"
+ );
+
+ let resolutions = (0..tree_size)
+ .map(
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ |i| tree.nodes.get_resolution_index(i).unwrap(),
+ )
+ .collect();
+
+ Self {
+ cipher_suite: cs.cipher_suite().into(),
+ tree: tree.nodes.mls_encode_to_vec().unwrap(),
+ tree_hashes: tree
+ .tree_hashes
+ .current
+ .into_iter()
+ .map(TreeHash::from)
+ .collect(),
+ group_id: group_id.to_vec(),
+ resolutions,
+ }
+ }
+}
+
+#[cfg(feature = "rfc_compliant")]
+#[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+#[cfg_attr(coverage_nightly, coverage(off))]
+async fn validation() {
+ use crate::group::test_utils::get_test_group_context;
+
+ #[cfg(mls_build_async)]
+ let test_cases: Vec<ValidationTestCase> = load_test_case_json!(
+ interop_tree_validation,
+ generate_validation_test_vector().await
+ );
+
+ #[cfg(not(mls_build_async))]
+ let test_cases: Vec<ValidationTestCase> =
+ load_test_case_json!(interop_tree_validation, generate_validation_test_vector());
+
+ for test_case in test_cases.into_iter() {
+ let Some(cs) = try_test_cipher_suite_provider(test_case.cipher_suite) else {
+ continue;
+ };
+
+ let mut tree = TreeKemPublic::import_node_data(
+ NodeVec::mls_decode(&mut &*test_case.tree).unwrap(),
+ &BasicIdentityProvider,
+ &Default::default(),
+ )
+ .await
+ .unwrap();
+
+ let tree_hash = tree.tree_hash(&cs).await.unwrap();
+
+ tree.tree_hashes
+ .current
+ .iter()
+ .zip_eq(test_case.tree_hashes.iter())
+ .for_each(|(l, r)| assert_eq!(**l, *r.0));
+
+ test_case
+ .resolutions
+ .iter()
+ .enumerate()
+ .for_each(|(i, res)| {
+ assert_eq!(&tree.nodes.get_resolution_index(i as u32).unwrap(), res)
+ });
+
+ let mut context = get_test_group_context(1, test_case.cipher_suite.into()).await;
+ context.tree_hash = tree_hash;
+ context.group_id = test_case.group_id;
+
+ TreeValidator::new(&cs, &context, &BasicIdentityProvider)
+ .validate(&mut tree)
+ .await
+ .unwrap();
+ }
+}
+
+#[cfg(feature = "rfc_compliant")]
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+#[cfg_attr(coverage_nightly, coverage(off))]
+async fn generate_validation_test_vector() -> Vec<ValidationTestCase> {
+ let mut test_cases = vec![];
+
+ for cs in CipherSuite::all() {
+ let Some(cs) = try_test_cipher_suite_provider(*cs) else {
+ continue;
+ };
+
+ let mut trees = vec![];
+
+ // Generate trees with increasing complexity. Start: full complete trees
+ for n_leaves in [2, 4, 8, 32] {
+ trees.push(TreeWithSigners::make_full_tree(n_leaves, &cs).await);
+ }
+
+ // Internal blanks, no skipping : 8 leaves, 0 commits removing 2, 3 and adding new member
+ let mut tree = TreeWithSigners::make_full_tree(8, &cs).await;
+ tree.remove_member(2);
+ tree.remove_member(3);
+ tree.add_member("Bob", &cs).await;
+ tree.update_committer_path(0, &cs).await;
+ trees.push(tree);
+
+ // Blanks at the end, no skipping
+ for n_leaves in [3, 5, 7, 33] {
+ trees.push(TreeWithSigners::make_full_tree(n_leaves, &cs).await);
+ }
+
+ // Internal blanks, with skipping : 8 leaves, 0 commits removing 1, 2, 3
+ let mut tree = TreeWithSigners::make_full_tree(8, &cs).await;
+ [1, 2, 3].into_iter().for_each(
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ |i| tree.remove_member(i),
+ );
+ tree.update_committer_path(0, &cs).await;
+ trees.push(tree);
+
+ // Blanks at the end, with skipping
+ for n_leaves in [6, 34] {
+ trees.push(TreeWithSigners::make_full_tree(n_leaves, &cs).await);
+ }
+
+ // Unmerged leaves, no skipping : 7 leaves; 0 commits adding a member
+ let mut tree = TreeWithSigners::make_full_tree(7, &cs).await;
+ tree.add_member("Bob", &cs).await;
+ tree.update_committer_path(0, &cs).await;
+ trees.push(tree);
+
+ // Unmerged leaves, with skipping : figure 20 in the RFC
+ let mut tree = TreeWithSigners::make_full_tree(7, &cs).await;
+ tree.remove_member(5);
+ tree.update_committer_path(0, &cs).await;
+ tree.update_committer_path(4, &cs).await;
+ tree.add_member("Bob", &cs).await;
+ tree.tree.tree_hashes.current = vec![];
+ tree.tree.tree_hash(&cs).await.unwrap();
+ trees.push(tree);
+
+ // Generate tests
+ trees.into_iter().for_each(
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ |tree| test_cases.push(ValidationTestCase::new(tree.tree, &tree.group_id, &cs)),
+ );
+ }
+
+ test_cases
+}
diff --git a/src/tree_kem/kem.rs b/src/tree_kem/kem.rs
new file mode 100644
index 0000000..cedeb0e
--- /dev/null
+++ b/src/tree_kem/kem.rs
@@ -0,0 +1,699 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::client::MlsError;
+use crate::crypto::{CipherSuiteProvider, SignatureSecretKey};
+use crate::group::GroupContext;
+use crate::identity::SigningIdentity;
+use crate::iter::wrap_iter;
+use crate::tree_kem::math as tree_math;
+use alloc::vec;
+use alloc::vec::Vec;
+use itertools::Itertools;
+use mls_rs_codec::MlsEncode;
+use tree_math::{CopathNode, TreeIndex};
+
+#[cfg(all(not(mls_build_async), feature = "rayon"))]
+use {crate::iter::ParallelIteratorExt, rayon::prelude::*};
+
+#[cfg(mls_build_async)]
+use futures::{StreamExt, TryStreamExt};
+
+#[cfg(feature = "std")]
+use std::collections::HashSet;
+
+use super::hpke_encryption::HpkeEncryptable;
+use super::leaf_node::ConfigProperties;
+use super::node::NodeTypeResolver;
+use super::{
+ node::{LeafIndex, NodeIndex},
+ path_secret::{PathSecret, PathSecretGenerator},
+ TreeKemPrivate, TreeKemPublic, UpdatePath, UpdatePathNode, ValidatedUpdatePath,
+};
+
+#[cfg(test)]
+use crate::{group::CommitModifiers, signer::Signable};
+
+pub struct TreeKem<'a> {
+ tree_kem_public: &'a mut TreeKemPublic,
+ private_key: &'a mut TreeKemPrivate,
+}
+
+pub struct EncapGeneration {
+ pub update_path: UpdatePath,
+ pub path_secrets: Vec<Option<PathSecret>>,
+ pub commit_secret: PathSecret,
+}
+
+impl<'a> TreeKem<'a> {
+ pub fn new(
+ tree_kem_public: &'a mut TreeKemPublic,
+ private_key: &'a mut TreeKemPrivate,
+ ) -> Self {
+ TreeKem {
+ tree_kem_public,
+ private_key,
+ }
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn encap<P>(
+ self,
+ context: &mut GroupContext,
+ excluding: &[LeafIndex],
+ signer: &SignatureSecretKey,
+ update_leaf_properties: ConfigProperties,
+ signing_identity: Option<SigningIdentity>,
+ cipher_suite_provider: &P,
+ #[cfg(test)] commit_modifiers: &CommitModifiers,
+ ) -> Result<EncapGeneration, MlsError>
+ where
+ P: CipherSuiteProvider + Send + Sync,
+ {
+ let self_index = self.private_key.self_index;
+ let path = self.tree_kem_public.nodes.direct_copath(self_index);
+ let filtered = self.tree_kem_public.nodes.filtered(self_index)?;
+
+ self.private_key.secret_keys.resize(path.len() + 1, None);
+
+ let mut secret_generator = PathSecretGenerator::new(cipher_suite_provider);
+ let mut path_secrets = vec![];
+
+ for (i, (node, f)) in path.iter().zip(&filtered).enumerate() {
+ if !f {
+ let secret = secret_generator.next_secret().await?;
+
+ let (secret_key, public_key) =
+ secret.to_hpke_key_pair(cipher_suite_provider).await?;
+
+ self.private_key.secret_keys[i + 1] = Some(secret_key);
+ self.tree_kem_public.update_node(public_key, node.path)?;
+ path_secrets.push(Some(secret));
+ } else {
+ self.private_key.secret_keys[i + 1] = None;
+ path_secrets.push(None);
+ }
+ }
+
+ #[cfg(test)]
+ (commit_modifiers.modify_tree)(self.tree_kem_public);
+
+ self.tree_kem_public
+ .update_parent_hashes(self_index, false, cipher_suite_provider)
+ .await?;
+
+ let update_path_leaf = {
+ let own_leaf = self.tree_kem_public.nodes.borrow_as_leaf_mut(self_index)?;
+
+ self.private_key.secret_keys[0] = Some(
+ own_leaf
+ .commit(
+ cipher_suite_provider,
+ &context.group_id,
+ *self_index,
+ update_leaf_properties,
+ signing_identity,
+ signer,
+ )
+ .await?,
+ );
+
+ #[cfg(test)]
+ if let Some(signer) = (commit_modifiers.modify_leaf)(own_leaf, signer) {
+ let context = &(context.group_id.as_slice(), *self_index).into();
+
+ own_leaf
+ .sign(cipher_suite_provider, &signer, context)
+ .await
+ .unwrap();
+ }
+
+ own_leaf.clone()
+ };
+
+ // Tree modifications are all done so we can update the tree hash and encrypt with the new context
+ self.tree_kem_public
+ .update_hashes(&[self_index], cipher_suite_provider)
+ .await?;
+
+ context.tree_hash = self
+ .tree_kem_public
+ .tree_hash(cipher_suite_provider)
+ .await?;
+
+ let context_bytes = context.mls_encode_to_vec()?;
+
+ let node_updates = self
+ .encrypt_path_secrets(
+ path,
+ &path_secrets,
+ &context_bytes,
+ cipher_suite_provider,
+ excluding,
+ )
+ .await?;
+
+ #[cfg(test)]
+ let node_updates = (commit_modifiers.modify_path)(node_updates);
+
+ // Create an update path with the new node and parent node updates
+ let update_path = UpdatePath {
+ leaf_node: update_path_leaf,
+ nodes: node_updates,
+ };
+
+ Ok(EncapGeneration {
+ update_path,
+ path_secrets,
+ commit_secret: secret_generator.next_secret().await?,
+ })
+ }
+
+ #[cfg(any(mls_build_async, not(feature = "rayon")))]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn encrypt_path_secrets<P: CipherSuiteProvider>(
+ &self,
+ path: Vec<CopathNode<NodeIndex>>,
+ path_secrets: &[Option<PathSecret>],
+ context_bytes: &[u8],
+ cipher_suite: &P,
+ excluding: &[LeafIndex],
+ ) -> Result<Vec<UpdatePathNode>, MlsError> {
+ let excluding = excluding.iter().copied().map(NodeIndex::from);
+
+ #[cfg(feature = "std")]
+ let excluding = excluding.collect::<HashSet<NodeIndex>>();
+ #[cfg(not(feature = "std"))]
+ let excluding = excluding.collect::<Vec<NodeIndex>>();
+
+ let mut node_updates = Vec::new();
+
+ for (index, path_secret) in path.into_iter().zip(path_secrets.iter()) {
+ if let Some(path_secret) = path_secret {
+ node_updates.push(
+ self.encrypt_copath_node_resolution(
+ cipher_suite,
+ path_secret,
+ index.copath,
+ context_bytes,
+ &excluding,
+ )
+ .await?,
+ );
+ }
+ }
+
+ Ok(node_updates)
+ }
+
+ #[cfg(all(not(mls_build_async), feature = "rayon"))]
+ fn encrypt_path_secrets<P: CipherSuiteProvider>(
+ &self,
+ path: Vec<CopathNode<NodeIndex>>,
+ path_secrets: &[Option<PathSecret>],
+ context_bytes: &[u8],
+ cipher_suite: &P,
+ excluding: &[LeafIndex],
+ ) -> Result<Vec<UpdatePathNode>, MlsError> {
+ let excluding = excluding.iter().copied().map(NodeIndex::from);
+
+ #[cfg(feature = "std")]
+ let excluding = excluding.collect::<HashSet<NodeIndex>>();
+ #[cfg(not(feature = "std"))]
+ let excluding = excluding.collect::<Vec<NodeIndex>>();
+
+ path.into_par_iter()
+ .zip(path_secrets.par_iter())
+ .filter_map(|(node, path_secret)| {
+ path_secret.as_ref().map(|path_secret| {
+ self.encrypt_copath_node_resolution(
+ cipher_suite,
+ path_secret,
+ node.copath,
+ context_bytes,
+ &excluding,
+ )
+ })
+ })
+ .collect()
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn decap<CP>(
+ self,
+ sender_index: LeafIndex,
+ update_path: &ValidatedUpdatePath,
+ added_leaves: &[LeafIndex],
+ context_bytes: &[u8],
+ cipher_suite_provider: &CP,
+ ) -> Result<PathSecret, MlsError>
+ where
+ CP: CipherSuiteProvider,
+ {
+ let self_index = self.private_key.self_index;
+
+ let lca_index =
+ tree_math::leaf_lca_level(self_index.into(), sender_index.into()) as usize - 2;
+
+ let mut path = self.tree_kem_public.nodes.direct_copath(self_index);
+ let leaf = CopathNode::new(self_index.into(), 0);
+ path.insert(0, leaf);
+ let resolved_pos = self.find_resolved_pos(&path, lca_index)?;
+
+ let ct_pos =
+ self.find_ciphertext_pos(path[lca_index].path, path[resolved_pos].path, added_leaves)?;
+
+ let lca_node = update_path.nodes[lca_index]
+ .as_ref()
+ .ok_or(MlsError::LcaNotFoundInDirectPath)?;
+
+ let ct = lca_node
+ .encrypted_path_secret
+ .get(ct_pos)
+ .ok_or(MlsError::LcaNotFoundInDirectPath)?;
+
+ let secret = self.private_key.secret_keys[resolved_pos]
+ .as_ref()
+ .ok_or(MlsError::UpdateErrorNoSecretKey)?;
+
+ let public = self
+ .tree_kem_public
+ .nodes
+ .borrow_node(path[resolved_pos].path)?
+ .as_ref()
+ .ok_or(MlsError::UpdateErrorNoSecretKey)?
+ .public_key();
+
+ let lca_path_secret =
+ PathSecret::decrypt(cipher_suite_provider, secret, public, context_bytes, ct).await?;
+
+ // Derive the rest of the secrets for the tree and assign to the proper nodes
+ let mut node_secret_gen =
+ PathSecretGenerator::starting_with(cipher_suite_provider, lca_path_secret);
+
+ // Update secrets based on the decrypted path secret in the update
+ self.private_key.secret_keys.resize(path.len() + 1, None);
+
+ for (i, update) in update_path.nodes.iter().enumerate().skip(lca_index) {
+ if let Some(update) = update {
+ let secret = node_secret_gen.next_secret().await?;
+
+ // Verify the private key we calculated properly matches the public key we inserted into the tree. This guarantees
+ // that we will be able to decrypt later.
+ let (hpke_private, hpke_public) =
+ secret.to_hpke_key_pair(cipher_suite_provider).await?;
+
+ if hpke_public != update.public_key {
+ return Err(MlsError::PubKeyMismatch);
+ }
+
+ self.private_key.secret_keys[i + 1] = Some(hpke_private);
+ } else {
+ self.private_key.secret_keys[i + 1] = None;
+ }
+ }
+
+ node_secret_gen.next_secret().await
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn encrypt_copath_node_resolution<P: CipherSuiteProvider>(
+ &self,
+ cipher_suite_provider: &P,
+ path_secret: &PathSecret,
+ copath_index: NodeIndex,
+ context: &[u8],
+ #[cfg(feature = "std")] excluding: &HashSet<NodeIndex>,
+ #[cfg(not(feature = "std"))] excluding: &[NodeIndex],
+ ) -> Result<UpdatePathNode, MlsError> {
+ let reso = self
+ .tree_kem_public
+ .nodes
+ .get_resolution_index(copath_index)?;
+
+ let make_ctxt = |idx| async move {
+ let node = self
+ .tree_kem_public
+ .nodes
+ .borrow_node(idx)?
+ .as_non_empty()?;
+
+ path_secret
+ .encrypt(cipher_suite_provider, node.public_key(), context)
+ .await
+ };
+
+ let ctxts = wrap_iter(reso).filter(|&idx| async move { !excluding.contains(&idx) });
+
+ #[cfg(not(mls_build_async))]
+ let ctxts = ctxts.map(make_ctxt);
+
+ #[cfg(mls_build_async)]
+ let ctxts = ctxts.then(make_ctxt);
+
+ let ctxts = ctxts.try_collect().await?;
+
+ let path_index = copath_index
+ .parent_sibling(&self.tree_kem_public.total_leaf_count())
+ .ok_or(MlsError::ExpectedNode)?
+ .parent;
+
+ Ok(UpdatePathNode {
+ public_key: self
+ .tree_kem_public
+ .nodes
+ .borrow_as_parent(path_index)?
+ .public_key
+ .clone(),
+ encrypted_path_secret: ctxts,
+ })
+ }
+
+ #[inline]
+ fn find_resolved_pos(
+ &self,
+ path: &[CopathNode<NodeIndex>],
+ mut lca_index: usize,
+ ) -> Result<usize, MlsError> {
+ while self.tree_kem_public.nodes.is_blank(path[lca_index].path)? {
+ lca_index -= 1;
+ }
+
+ // If we don't have the key, we should be an unmerged leaf at the resolved node. (If
+ // we're not, an error will be thrown later.)
+ if self.private_key.secret_keys[lca_index].is_none() {
+ lca_index = 0;
+ }
+
+ Ok(lca_index)
+ }
+
+ #[inline]
+ fn find_ciphertext_pos(
+ &self,
+ lca: NodeIndex,
+ resolved: NodeIndex,
+ excluding: &[LeafIndex],
+ ) -> Result<usize, MlsError> {
+ let reso = self.tree_kem_public.nodes.get_resolution_index(lca)?;
+
+ let (ct_pos, _) = reso
+ .iter()
+ .filter(|idx| **idx % 2 == 1 || !excluding.contains(&LeafIndex(**idx / 2)))
+ .find_position(|idx| idx == &&resolved)
+ .ok_or(MlsError::UpdateErrorNoSecretKey)?;
+
+ Ok(ct_pos)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::{tree_math, TreeKem};
+ use crate::{
+ cipher_suite::CipherSuite,
+ client::test_utils::TEST_CIPHER_SUITE,
+ crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider},
+ extension::test_utils::TestExtension,
+ group::test_utils::{get_test_group_context, random_bytes},
+ identity::basic::BasicIdentityProvider,
+ tree_kem::{
+ leaf_node::{
+ test_utils::{get_basic_test_node_sig_key, get_test_capabilities},
+ ConfigProperties,
+ },
+ node::LeafIndex,
+ Capabilities, TreeKemPrivate, TreeKemPublic, UpdatePath, ValidatedUpdatePath,
+ },
+ ExtensionList,
+ };
+ use alloc::{format, vec, vec::Vec};
+ use mls_rs_codec::MlsEncode;
+ use mls_rs_core::crypto::CipherSuiteProvider;
+ use tree_math::TreeIndex;
+
+ // Verify that the tree is in the correct state after generating an update path
+ fn verify_tree_update_path(
+ tree: &TreeKemPublic,
+ update_path: &UpdatePath,
+ index: LeafIndex,
+ capabilities: Option<Capabilities>,
+ extensions: Option<ExtensionList>,
+ ) {
+ // Make sure the update path is based on the direct path of the sender
+ let direct_path = tree.nodes.direct_copath(index);
+
+ for (i, n) in direct_path.iter().enumerate() {
+ assert_eq!(
+ *tree
+ .nodes
+ .borrow_node(n.path)
+ .unwrap()
+ .as_ref()
+ .unwrap()
+ .public_key(),
+ update_path.nodes[i].public_key
+ );
+ }
+
+ // Verify that the leaf from the update path has been installed
+ assert_eq!(
+ tree.nodes.borrow_as_leaf(index).unwrap(),
+ &update_path.leaf_node
+ );
+
+ // Verify that updated capabilities were installed
+ if let Some(capabilities) = capabilities {
+ assert_eq!(update_path.leaf_node.capabilities, capabilities);
+ }
+
+ // Verify that update extensions were installed
+ if let Some(extensions) = extensions {
+ assert_eq!(update_path.leaf_node.extensions, extensions);
+ }
+
+ // Verify that we have a public keys up to the root
+ let root = tree.total_leaf_count().root();
+ assert!(tree.nodes.borrow_node(root).unwrap().is_some());
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn verify_tree_private_path(
+ cipher_suite: &CipherSuite,
+ public_tree: &TreeKemPublic,
+ private_tree: &TreeKemPrivate,
+ index: LeafIndex,
+ ) {
+ let provider = test_cipher_suite_provider(*cipher_suite);
+
+ assert_eq!(private_tree.self_index, index);
+
+ // Make sure we have private values along the direct path, and the public keys match
+ let path_iter = public_tree
+ .nodes
+ .direct_copath(index)
+ .into_iter()
+ .enumerate();
+
+ for (i, n) in path_iter {
+ let secret_key = private_tree.secret_keys[i + 1].as_ref().unwrap();
+
+ let public_key = public_tree
+ .nodes
+ .borrow_node(n.path)
+ .unwrap()
+ .as_ref()
+ .unwrap()
+ .public_key();
+
+ let test_data = random_bytes(32);
+
+ let sealed = provider
+ .hpke_seal(public_key, &[], None, &test_data)
+ .await
+ .unwrap();
+
+ let opened = provider
+ .hpke_open(&sealed, secret_key, public_key, &[], None)
+ .await
+ .unwrap();
+
+ assert_eq!(test_data, opened);
+ }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn encap_decap(
+ cipher_suite: CipherSuite,
+ size: usize,
+ capabilities: Option<Capabilities>,
+ extensions: Option<ExtensionList>,
+ ) {
+ let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);
+
+ // Generate signing keys and key package generations, and private keys for multiple
+ // participants in order to set up state
+
+ let mut leaf_nodes = Vec::new();
+ let mut private_keys = Vec::new();
+
+ for index in 1..size {
+ let (leaf_node, hpke_secret, _) =
+ get_basic_test_node_sig_key(cipher_suite, &format!("{index}")).await;
+
+ let private_key = TreeKemPrivate::new_self_leaf(LeafIndex(index as u32), hpke_secret);
+
+ leaf_nodes.push(leaf_node);
+ private_keys.push(private_key);
+ }
+
+ let (encap_node, encap_hpke_secret, encap_signer) =
+ get_basic_test_node_sig_key(cipher_suite, "encap").await;
+
+ // Build a test tree we can clone for all leaf nodes
+ let (mut test_tree, mut encap_private_key) = TreeKemPublic::derive(
+ encap_node,
+ encap_hpke_secret,
+ &BasicIdentityProvider,
+ &Default::default(),
+ )
+ .await
+ .unwrap();
+
+ test_tree
+ .add_leaves(leaf_nodes, &BasicIdentityProvider, &cipher_suite_provider)
+ .await
+ .unwrap();
+
+ // Clone the tree for the first leaf, generate a new key package for that leaf
+ let mut encap_tree = test_tree.clone();
+
+ let update_leaf_properties = ConfigProperties {
+ capabilities: capabilities.clone().unwrap_or_else(get_test_capabilities),
+ extensions: extensions.clone().unwrap_or_default(),
+ };
+
+ // Perform the encap function
+ let encap_gen = TreeKem::new(&mut encap_tree, &mut encap_private_key)
+ .encap(
+ &mut get_test_group_context(42, cipher_suite).await,
+ &[],
+ &encap_signer,
+ update_leaf_properties,
+ None,
+ &cipher_suite_provider,
+ #[cfg(test)]
+ &Default::default(),
+ )
+ .await
+ .unwrap();
+
+ // Verify that the state of the tree matches the produced update path
+ verify_tree_update_path(
+ &encap_tree,
+ &encap_gen.update_path,
+ LeafIndex(0),
+ capabilities,
+ extensions,
+ );
+
+ // Verify that the private key matches the data in the public key
+ verify_tree_private_path(&cipher_suite, &encap_tree, &encap_private_key, LeafIndex(0))
+ .await;
+
+ let filtered = test_tree.nodes.filtered(LeafIndex(0)).unwrap();
+ let mut unfiltered_nodes = vec![None; filtered.len()];
+ filtered
+ .into_iter()
+ .enumerate()
+ .filter(|(_, f)| !*f)
+ .zip(encap_gen.update_path.nodes.iter())
+ .for_each(|((i, _), node)| {
+ unfiltered_nodes[i] = Some(node.clone());
+ });
+
+ // Apply the update path to the rest of the leaf nodes using the decap function
+ let validated_update_path = ValidatedUpdatePath {
+ leaf_node: encap_gen.update_path.leaf_node,
+ nodes: unfiltered_nodes,
+ };
+
+ encap_tree
+ .update_hashes(&[LeafIndex(0)], &cipher_suite_provider)
+ .await
+ .unwrap();
+
+ let mut receiver_trees: Vec<TreeKemPublic> = (1..size).map(|_| test_tree.clone()).collect();
+
+ for (i, tree) in receiver_trees.iter_mut().enumerate() {
+ tree.apply_update_path(
+ LeafIndex(0),
+ &validated_update_path,
+ &Default::default(),
+ BasicIdentityProvider,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap();
+
+ let mut context = get_test_group_context(42, cipher_suite).await;
+ context.tree_hash = tree.tree_hash(&cipher_suite_provider).await.unwrap();
+
+ TreeKem::new(tree, &mut private_keys[i])
+ .decap(
+ LeafIndex(0),
+ &validated_update_path,
+ &[],
+ &context.mls_encode_to_vec().unwrap(),
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap();
+
+ tree.update_hashes(&[LeafIndex(0)], &cipher_suite_provider)
+ .await
+ .unwrap();
+
+ assert_eq!(tree, &encap_tree);
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_encap_decap() {
+ for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+ encap_decap(cipher_suite, 10, None, None).await;
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_encap_capabilities() {
+ let cipher_suite = TEST_CIPHER_SUITE;
+ let mut capabilities = get_test_capabilities();
+ capabilities.extensions.push(42.into());
+
+ encap_decap(cipher_suite, 10, Some(capabilities.clone()), None).await;
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_encap_extensions() {
+ let cipher_suite = TEST_CIPHER_SUITE;
+ let mut extensions = ExtensionList::default();
+ extensions.set_from(TestExtension { foo: 10 }).unwrap();
+
+ encap_decap(cipher_suite, 10, None, Some(extensions)).await;
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_encap_capabilities_extensions() {
+ let cipher_suite = TEST_CIPHER_SUITE;
+ let mut capabilities = get_test_capabilities();
+ capabilities.extensions.push(42.into());
+
+ let mut extensions = ExtensionList::default();
+ extensions.set_from(TestExtension { foo: 10 }).unwrap();
+
+ encap_decap(cipher_suite, 10, Some(capabilities), Some(extensions)).await;
+ }
+}
diff --git a/src/tree_kem/leaf_node.rs b/src/tree_kem/leaf_node.rs
new file mode 100644
index 0000000..c59ed78
--- /dev/null
+++ b/src/tree_kem/leaf_node.rs
@@ -0,0 +1,688 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use super::{parent_hash::ParentHash, Capabilities, Lifetime};
+use crate::client::MlsError;
+use crate::crypto::{CipherSuiteProvider, HpkePublicKey, HpkeSecretKey, SignatureSecretKey};
+use crate::{identity::SigningIdentity, signer::Signable, ExtensionList};
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::error::IntoAnyError;
+
+#[derive(Debug, Clone, MlsSize, MlsEncode, MlsDecode, PartialEq, Eq)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[repr(u8)]
+pub enum LeafNodeSource {
+ KeyPackage(Lifetime) = 1u8,
+ Update = 2u8,
+ Commit(ParentHash) = 3u8,
+}
+
+#[derive(Clone, MlsSize, MlsEncode, MlsDecode, PartialEq)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[non_exhaustive]
+pub struct LeafNode {
+ pub public_key: HpkePublicKey,
+ pub signing_identity: SigningIdentity,
+ pub capabilities: Capabilities,
+ pub leaf_node_source: LeafNodeSource,
+ pub extensions: ExtensionList,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ pub signature: Vec<u8>,
+}
+
+impl Debug for LeafNode {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_struct("LeafNode")
+ .field("public_key", &self.public_key)
+ .field("signing_identity", &self.signing_identity)
+ .field("capabilities", &self.capabilities)
+ .field("leaf_node_source", &self.leaf_node_source)
+ .field("extensions", &self.extensions)
+ .field(
+ "signature",
+ &mls_rs_core::debug::pretty_bytes(&self.signature),
+ )
+ .finish()
+ }
+}
+
+#[derive(Clone, Debug)]
+pub struct ConfigProperties {
+ pub capabilities: Capabilities,
+ pub extensions: ExtensionList,
+}
+
+impl LeafNode {
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn generate<CSP>(
+ cipher_suite_provider: &CSP,
+ properties: ConfigProperties,
+ signing_identity: SigningIdentity,
+ signer: &SignatureSecretKey,
+ lifetime: Lifetime,
+ ) -> Result<(Self, HpkeSecretKey), MlsError>
+ where
+ CSP: CipherSuiteProvider,
+ {
+ let (secret_key, public_key) = cipher_suite_provider
+ .kem_generate()
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+ let mut leaf_node = LeafNode {
+ public_key,
+ signing_identity,
+ capabilities: properties.capabilities,
+ leaf_node_source: LeafNodeSource::KeyPackage(lifetime),
+ extensions: properties.extensions,
+ signature: Default::default(),
+ };
+
+ leaf_node.grease(cipher_suite_provider)?;
+
+ leaf_node
+ .sign(
+ cipher_suite_provider,
+ signer,
+ &LeafNodeSigningContext::default(),
+ )
+ .await?;
+
+ Ok((leaf_node, secret_key))
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn update<P: CipherSuiteProvider>(
+ &mut self,
+ cipher_suite_provider: &P,
+ group_id: &[u8],
+ leaf_index: u32,
+ new_properties: ConfigProperties,
+ signing_identity: Option<SigningIdentity>,
+ signer: &SignatureSecretKey,
+ ) -> Result<HpkeSecretKey, MlsError> {
+ let (secret, public) = cipher_suite_provider
+ .kem_generate()
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+ self.public_key = public;
+ self.capabilities = new_properties.capabilities;
+ self.extensions = new_properties.extensions;
+ self.leaf_node_source = LeafNodeSource::Update;
+
+ self.grease(cipher_suite_provider)?;
+
+ if let Some(signing_identity) = signing_identity {
+ self.signing_identity = signing_identity;
+ }
+
+ self.sign(
+ cipher_suite_provider,
+ signer,
+ &(group_id, leaf_index).into(),
+ )
+ .await?;
+
+ Ok(secret)
+ }
+
+ #[allow(clippy::too_many_arguments)]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn commit<P: CipherSuiteProvider>(
+ &mut self,
+ cipher_suite_provider: &P,
+ group_id: &[u8],
+ leaf_index: u32,
+ new_properties: ConfigProperties,
+ new_signing_identity: Option<SigningIdentity>,
+ signer: &SignatureSecretKey,
+ ) -> Result<HpkeSecretKey, MlsError> {
+ let (secret, public) = cipher_suite_provider
+ .kem_generate()
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+ self.public_key = public;
+ self.capabilities = new_properties.capabilities;
+ self.extensions = new_properties.extensions;
+
+ if let Some(new_signing_identity) = new_signing_identity {
+ self.signing_identity = new_signing_identity;
+ }
+
+ self.sign(
+ cipher_suite_provider,
+ signer,
+ &(group_id, leaf_index).into(),
+ )
+ .await?;
+
+ Ok(secret)
+ }
+}
+
+#[derive(Debug)]
+struct LeafNodeTBS<'a> {
+ public_key: &'a HpkePublicKey,
+ signing_identity: &'a SigningIdentity,
+ capabilities: &'a Capabilities,
+ leaf_node_source: &'a LeafNodeSource,
+ extensions: &'a ExtensionList,
+ group_id: Option<&'a [u8]>,
+ leaf_index: Option<u32>,
+}
+
+impl<'a> MlsSize for LeafNodeTBS<'a> {
+ fn mls_encoded_len(&self) -> usize {
+ self.public_key.mls_encoded_len()
+ + self.signing_identity.mls_encoded_len()
+ + self.capabilities.mls_encoded_len()
+ + self.leaf_node_source.mls_encoded_len()
+ + self.extensions.mls_encoded_len()
+ + self
+ .group_id
+ .as_ref()
+ .map_or(0, mls_rs_codec::byte_vec::mls_encoded_len)
+ + self.leaf_index.map_or(0, |i| i.mls_encoded_len())
+ }
+}
+
+impl<'a> MlsEncode for LeafNodeTBS<'a> {
+ fn mls_encode(&self, writer: &mut Vec<u8>) -> Result<(), mls_rs_codec::Error> {
+ self.public_key.mls_encode(writer)?;
+ self.signing_identity.mls_encode(writer)?;
+ self.capabilities.mls_encode(writer)?;
+ self.leaf_node_source.mls_encode(writer)?;
+ self.extensions.mls_encode(writer)?;
+
+ if let Some(ref group_id) = self.group_id {
+ mls_rs_codec::byte_vec::mls_encode(group_id, writer)?;
+ }
+
+ if let Some(leaf_index) = self.leaf_index {
+ leaf_index.mls_encode(writer)?;
+ }
+
+ Ok(())
+ }
+}
+
+#[derive(Clone, Debug, Default)]
+pub(crate) struct LeafNodeSigningContext<'a> {
+ pub group_id: Option<&'a [u8]>,
+ pub leaf_index: Option<u32>,
+}
+
+impl<'a> From<(&'a [u8], u32)> for LeafNodeSigningContext<'a> {
+ fn from((group_id, leaf_index): (&'a [u8], u32)) -> Self {
+ Self {
+ group_id: Some(group_id),
+ leaf_index: Some(leaf_index),
+ }
+ }
+}
+
+impl<'a> Signable<'a> for LeafNode {
+ const SIGN_LABEL: &'static str = "LeafNodeTBS";
+
+ type SigningContext = LeafNodeSigningContext<'a>;
+
+ fn signature(&self) -> &[u8] {
+ &self.signature
+ }
+
+ fn signable_content(
+ &self,
+ context: &Self::SigningContext,
+ ) -> Result<Vec<u8>, mls_rs_codec::Error> {
+ LeafNodeTBS {
+ public_key: &self.public_key,
+ signing_identity: &self.signing_identity,
+ capabilities: &self.capabilities,
+ leaf_node_source: &self.leaf_node_source,
+ extensions: &self.extensions,
+ group_id: context.group_id,
+ leaf_index: context.leaf_index,
+ }
+ .mls_encode_to_vec()
+ }
+
+ fn write_signature(&mut self, signature: Vec<u8>) {
+ self.signature = signature
+ }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+ use alloc::vec;
+ use mls_rs_core::identity::{BasicCredential, CredentialType};
+
+ use crate::{
+ cipher_suite::CipherSuite,
+ crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider},
+ identity::test_utils::{get_test_signing_identity, BasicWithCustomProvider},
+ };
+
+ use crate::extension::ApplicationIdExt;
+
+ use super::*;
+
+ #[allow(unused)]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn get_test_node(
+ cipher_suite: CipherSuite,
+ signing_identity: SigningIdentity,
+ secret: &SignatureSecretKey,
+ capabilities: Option<Capabilities>,
+ extensions: Option<ExtensionList>,
+ ) -> (LeafNode, HpkeSecretKey) {
+ get_test_node_with_lifetime(
+ cipher_suite,
+ signing_identity,
+ secret,
+ capabilities.unwrap_or_else(get_test_capabilities),
+ extensions.unwrap_or_default(),
+ Lifetime::years(1).unwrap(),
+ )
+ .await
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn get_test_node_with_lifetime(
+ cipher_suite: CipherSuite,
+ signing_identity: SigningIdentity,
+ secret: &SignatureSecretKey,
+ capabilities: Capabilities,
+ extensions: ExtensionList,
+ lifetime: Lifetime,
+ ) -> (LeafNode, HpkeSecretKey) {
+ let properties = ConfigProperties {
+ capabilities,
+ extensions,
+ };
+
+ LeafNode::generate(
+ &test_cipher_suite_provider(cipher_suite),
+ properties,
+ signing_identity,
+ secret,
+ lifetime,
+ )
+ .await
+ .unwrap()
+ }
+
+ #[allow(unused)]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn get_basic_test_node(cipher_suite: CipherSuite, id: &str) -> LeafNode {
+ get_basic_test_node_sig_key(cipher_suite, id).await.0
+ }
+
+ #[allow(unused)]
+ pub fn default_properties() -> ConfigProperties {
+ ConfigProperties {
+ capabilities: get_test_capabilities(),
+ extensions: Default::default(),
+ }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn get_basic_test_node_capabilities(
+ cipher_suite: CipherSuite,
+ id: &str,
+ capabilities: Capabilities,
+ ) -> (LeafNode, HpkeSecretKey, SignatureSecretKey) {
+ let (signing_identity, signature_key) =
+ get_test_signing_identity(cipher_suite, id.as_bytes()).await;
+
+ LeafNode::generate(
+ &test_cipher_suite_provider(cipher_suite),
+ ConfigProperties {
+ capabilities,
+ extensions: Default::default(),
+ },
+ signing_identity,
+ &signature_key,
+ Lifetime::years(1).unwrap(),
+ )
+ .await
+ .map(|(leaf, hpke_secret_key)| (leaf, hpke_secret_key, signature_key))
+ .unwrap()
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn get_basic_test_node_sig_key(
+ cipher_suite: CipherSuite,
+ id: &str,
+ ) -> (LeafNode, HpkeSecretKey, SignatureSecretKey) {
+ get_basic_test_node_capabilities(cipher_suite, id, get_test_capabilities()).await
+ }
+
+ #[allow(unused)]
+ pub fn get_test_extensions() -> ExtensionList {
+ let mut extension_list = ExtensionList::new();
+
+ extension_list
+ .set_from(ApplicationIdExt {
+ identifier: b"identifier".to_vec(),
+ })
+ .unwrap();
+
+ extension_list
+ }
+
+ pub fn get_test_capabilities() -> Capabilities {
+ Capabilities {
+ credentials: vec![
+ BasicCredential::credential_type(),
+ CredentialType::from(BasicWithCustomProvider::CUSTOM_CREDENTIAL_TYPE),
+ ],
+ cipher_suites: TestCryptoProvider::all_supported_cipher_suites(),
+ ..Default::default()
+ }
+ }
+
+ #[allow(unused)]
+ pub fn get_test_client_identity(leaf: &LeafNode) -> Vec<u8> {
+ leaf.signing_identity
+ .credential
+ .mls_encode_to_vec()
+ .unwrap()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::test_utils::*;
+ use super::*;
+
+ use crate::client::test_utils::TEST_CIPHER_SUITE;
+ use crate::crypto::test_utils::test_cipher_suite_provider;
+ use crate::crypto::test_utils::TestCryptoProvider;
+ use crate::group::test_utils::random_bytes;
+ use crate::identity::test_utils::get_test_signing_identity;
+ use assert_matches::assert_matches;
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_node_generation() {
+ let capabilities = get_test_capabilities();
+ let extensions = get_test_extensions();
+ let lifetime = Lifetime::years(1).unwrap();
+
+ for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+ let (signing_identity, secret) = get_test_signing_identity(cipher_suite, b"foo").await;
+
+ let (leaf_node, secret_key) = get_test_node_with_lifetime(
+ cipher_suite,
+ signing_identity.clone(),
+ &secret,
+ capabilities.clone(),
+ extensions.clone(),
+ lifetime.clone(),
+ )
+ .await;
+
+ assert_eq!(leaf_node.ungreased_capabilities(), capabilities);
+ assert_eq!(leaf_node.ungreased_extensions(), extensions);
+ assert_eq!(leaf_node.signing_identity, signing_identity);
+
+ assert_matches!(
+ &leaf_node.leaf_node_source,
+ LeafNodeSource::KeyPackage(lt) if lt == &lifetime,
+ "Expected {:?}, got {:?}", LeafNodeSource::KeyPackage(lifetime),
+ leaf_node.leaf_node_source
+ );
+
+ let provider = test_cipher_suite_provider(cipher_suite);
+
+ // Verify that the hpke key pair generated will work
+ let test_data = random_bytes(32);
+
+ let sealed = provider
+ .hpke_seal(&leaf_node.public_key, &[], None, &test_data)
+ .await
+ .unwrap();
+
+ let opened = provider
+ .hpke_open(&sealed, &secret_key, &leaf_node.public_key, &[], None)
+ .await
+ .unwrap();
+
+ assert_eq!(opened, test_data);
+
+ leaf_node
+ .verify(
+ &test_cipher_suite_provider(cipher_suite),
+ &signing_identity.signature_key,
+ &LeafNodeSigningContext::default(),
+ )
+ .await
+ .unwrap();
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_node_generation_randomness() {
+ let cipher_suite = TEST_CIPHER_SUITE;
+
+ let (signing_identity, secret) = get_test_signing_identity(cipher_suite, b"foo").await;
+
+ let (first_leaf, first_secret) =
+ get_test_node(cipher_suite, signing_identity.clone(), &secret, None, None).await;
+
+ for _ in 0..100 {
+ let (next_leaf, next_secret) =
+ get_test_node(cipher_suite, signing_identity.clone(), &secret, None, None).await;
+
+ assert_ne!(first_secret, next_secret);
+ assert_ne!(first_leaf.public_key, next_leaf.public_key);
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_node_update_no_meta_changes() {
+ for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+ let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);
+
+ let (signing_identity, secret) = get_test_signing_identity(cipher_suite, b"foo").await;
+
+ let (mut leaf, leaf_secret) =
+ get_test_node(cipher_suite, signing_identity.clone(), &secret, None, None).await;
+
+ let original_leaf = leaf.clone();
+
+ let new_secret = leaf
+ .update(
+ &cipher_suite_provider,
+ b"group",
+ 0,
+ default_properties(),
+ None,
+ &secret,
+ )
+ .await
+ .unwrap();
+
+ assert_ne!(new_secret, leaf_secret);
+ assert_ne!(original_leaf.public_key, leaf.public_key);
+
+ assert_eq!(
+ leaf.ungreased_capabilities(),
+ original_leaf.ungreased_capabilities()
+ );
+
+ assert_eq!(
+ leaf.ungreased_extensions(),
+ original_leaf.ungreased_extensions()
+ );
+
+ assert_eq!(leaf.signing_identity, original_leaf.signing_identity);
+ assert_matches!(&leaf.leaf_node_source, LeafNodeSource::Update);
+
+ leaf.verify(
+ &cipher_suite_provider,
+ &signing_identity.signature_key,
+ &(b"group".as_slice(), 0).into(),
+ )
+ .await
+ .unwrap();
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_node_update_meta_changes() {
+ let cipher_suite = TEST_CIPHER_SUITE;
+
+ let (signing_identity, secret) = get_test_signing_identity(cipher_suite, b"foo").await;
+
+ let new_properties = ConfigProperties {
+ capabilities: get_test_capabilities(),
+ extensions: get_test_extensions(),
+ };
+
+ let (mut leaf, _) =
+ get_test_node(cipher_suite, signing_identity, &secret, None, None).await;
+
+ leaf.update(
+ &test_cipher_suite_provider(cipher_suite),
+ b"group",
+ 0,
+ new_properties.clone(),
+ None,
+ &secret,
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(leaf.ungreased_capabilities(), new_properties.capabilities);
+ assert_eq!(leaf.ungreased_extensions(), new_properties.extensions);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_node_commit_no_meta_changes() {
+ for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+ let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);
+
+ let (signing_identity, secret) = get_test_signing_identity(cipher_suite, b"foo").await;
+
+ let (mut leaf, leaf_secret) =
+ get_test_node(cipher_suite, signing_identity.clone(), &secret, None, None).await;
+
+ let original_leaf = leaf.clone();
+
+ let new_secret = leaf
+ .commit(
+ &cipher_suite_provider,
+ b"group",
+ 0,
+ default_properties(),
+ None,
+ &secret,
+ )
+ .await
+ .unwrap();
+
+ assert_ne!(new_secret, leaf_secret);
+ assert_ne!(original_leaf.public_key, leaf.public_key);
+
+ assert_eq!(
+ leaf.ungreased_capabilities(),
+ original_leaf.ungreased_capabilities()
+ );
+
+ assert_eq!(
+ leaf.ungreased_extensions(),
+ original_leaf.ungreased_extensions()
+ );
+
+ assert_eq!(leaf.signing_identity, original_leaf.signing_identity);
+
+ leaf.verify(
+ &cipher_suite_provider,
+ &signing_identity.signature_key,
+ &(b"group".as_slice(), 0).into(),
+ )
+ .await
+ .unwrap();
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_node_commit_meta_changes() {
+ let cipher_suite = TEST_CIPHER_SUITE;
+
+ let (signing_identity, secret) = get_test_signing_identity(cipher_suite, b"foo").await;
+ let (mut leaf, _) =
+ get_test_node(cipher_suite, signing_identity, &secret, None, None).await;
+
+ let new_properties = ConfigProperties {
+ capabilities: get_test_capabilities(),
+ extensions: get_test_extensions(),
+ };
+
+ // The new identity has a fresh public key
+ let new_signing_identity = get_test_signing_identity(cipher_suite, b"foo").await.0;
+
+ leaf.commit(
+ &test_cipher_suite_provider(cipher_suite),
+ b"group",
+ 0,
+ new_properties.clone(),
+ Some(new_signing_identity.clone()),
+ &secret,
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(leaf.capabilities, new_properties.capabilities);
+ assert_eq!(leaf.extensions, new_properties.extensions);
+ assert_eq!(leaf.signing_identity, new_signing_identity);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn context_is_signed() {
+ let provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ let (signing_identity, secret) = get_test_signing_identity(TEST_CIPHER_SUITE, b"foo").await;
+
+ let (mut leaf, _) = get_test_node(
+ TEST_CIPHER_SUITE,
+ signing_identity.clone(),
+ &secret,
+ None,
+ None,
+ )
+ .await;
+
+ leaf.sign(&provider, &secret, &(b"foo".as_slice(), 0).into())
+ .await
+ .unwrap();
+
+ let res = leaf
+ .verify(
+ &provider,
+ &signing_identity.signature_key,
+ &(b"foo".as_slice(), 1).into(),
+ )
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidSignature));
+
+ let res = leaf
+ .verify(
+ &provider,
+ &signing_identity.signature_key,
+ &(b"bar".as_slice(), 0).into(),
+ )
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidSignature));
+ }
+}
diff --git a/src/tree_kem/leaf_node_validator.rs b/src/tree_kem/leaf_node_validator.rs
new file mode 100644
index 0000000..17742ec
--- /dev/null
+++ b/src/tree_kem/leaf_node_validator.rs
@@ -0,0 +1,708 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use super::leaf_node::{LeafNode, LeafNodeSigningContext, LeafNodeSource};
+use crate::client::MlsError;
+use crate::CipherSuiteProvider;
+use crate::{signer::Signable, time::MlsTime};
+use mls_rs_core::{error::IntoAnyError, extension::ExtensionList, identity::IdentityProvider};
+
+use crate::extension::RequiredCapabilitiesExt;
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::extension::ExternalSendersExt;
+
+pub enum ValidationContext<'a> {
+ Add(Option<MlsTime>),
+ Update((&'a [u8], u32, Option<MlsTime>)),
+ Commit((&'a [u8], u32, Option<MlsTime>)),
+}
+
+impl<'a> ValidationContext<'a> {
+ fn signing_context(&self) -> LeafNodeSigningContext {
+ match *self {
+ ValidationContext::Add(_) => Default::default(),
+ ValidationContext::Update((group_id, leaf_index, _)) => (group_id, leaf_index).into(),
+ ValidationContext::Commit((group_id, leaf_index, _)) => (group_id, leaf_index).into(),
+ }
+ }
+
+ fn generation_time(&self) -> Option<MlsTime> {
+ match *self {
+ ValidationContext::Add(t) => t,
+ ValidationContext::Update((_, _, t)) => t,
+ ValidationContext::Commit((_, _, t)) => t,
+ }
+ }
+}
+
+#[derive(Clone, Debug)]
+pub struct LeafNodeValidator<'a, C, CP>
+where
+ C: IdentityProvider,
+ CP: CipherSuiteProvider,
+{
+ cipher_suite_provider: &'a CP,
+ identity_provider: &'a C,
+ group_context_extensions: Option<&'a ExtensionList>,
+}
+
+impl<'a, C: IdentityProvider, CP: CipherSuiteProvider> LeafNodeValidator<'a, C, CP> {
+ pub fn new(
+ cipher_suite_provider: &'a CP,
+ identity_provider: &'a C,
+ group_context_extensions: Option<&'a ExtensionList>,
+ ) -> Self {
+ Self {
+ cipher_suite_provider,
+ identity_provider,
+ group_context_extensions,
+ }
+ }
+
+ fn check_context(
+ &self,
+ leaf_node: &LeafNode,
+ context: &ValidationContext,
+ ) -> Result<(), MlsError> {
+ // Context specific checks
+ match context {
+ ValidationContext::Add(time) => {
+ // If the context is add, and we specified a time to check for lifetime, verify it
+ if let LeafNodeSource::KeyPackage(lifetime) = &leaf_node.leaf_node_source {
+ if let Some(current_time) = time {
+ if !lifetime.within_lifetime(*current_time) {
+ return Err(MlsError::InvalidLifetime);
+ }
+ }
+ } else {
+ // If the leaf_node_source is anything other than Add it is invalid
+ return Err(MlsError::InvalidLeafNodeSource);
+ }
+ }
+ ValidationContext::Update(_) => {
+ // If the leaf_node_source is anything other than Update it is invalid
+ if !matches!(leaf_node.leaf_node_source, LeafNodeSource::Update) {
+ return Err(MlsError::InvalidLeafNodeSource);
+ }
+ }
+ ValidationContext::Commit(_) => {
+ // If the leaf_node_source is anything other than Commit it is invalid
+ if !matches!(leaf_node.leaf_node_source, LeafNodeSource::Commit(_)) {
+ return Err(MlsError::InvalidLeafNodeSource);
+ }
+ }
+ }
+
+ Ok(())
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn revalidate(
+ &self,
+ leaf_node: &LeafNode,
+ group_id: &[u8],
+ leaf_index: u32,
+ ) -> Result<(), MlsError> {
+ let context = match leaf_node.leaf_node_source {
+ LeafNodeSource::KeyPackage(_) => ValidationContext::Add(None),
+ LeafNodeSource::Update => ValidationContext::Update((group_id, leaf_index, None)),
+ LeafNodeSource::Commit(_) => ValidationContext::Commit((group_id, leaf_index, None)),
+ };
+
+ self.check_if_valid(leaf_node, context).await
+ }
+
+ pub fn validate_required_capabilities(&self, leaf_node: &LeafNode) -> Result<(), MlsError> {
+ let Some(required_capabilities) = self
+ .group_context_extensions
+ .and_then(|exts| exts.get_as::<RequiredCapabilitiesExt>().transpose())
+ .transpose()?
+ else {
+ return Ok(());
+ };
+
+ for extension in &required_capabilities.extensions {
+ if !leaf_node.capabilities.extensions.contains(extension) {
+ return Err(MlsError::RequiredExtensionNotFound(*extension));
+ }
+ }
+
+ for proposal in &required_capabilities.proposals {
+ if !leaf_node.capabilities.proposals.contains(proposal) {
+ return Err(MlsError::RequiredProposalNotFound(*proposal));
+ }
+ }
+
+ for credential in &required_capabilities.credentials {
+ if !leaf_node.capabilities.credentials.contains(credential) {
+ return Err(MlsError::RequiredCredentialNotFound(*credential));
+ }
+ }
+
+ Ok(())
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ pub fn validate_external_senders_ext_credentials(
+ &self,
+ leaf_node: &LeafNode,
+ ) -> Result<(), MlsError> {
+ let Some(ext) = self
+ .group_context_extensions
+ .and_then(|exts| exts.get_as::<ExternalSendersExt>().transpose())
+ .transpose()?
+ else {
+ return Ok(());
+ };
+
+ ext.allowed_senders.iter().try_for_each(|sender| {
+ let cred_type = sender.credential.credential_type();
+ leaf_node
+ .capabilities
+ .credentials
+ .contains(&cred_type)
+ .then_some(())
+ .ok_or(MlsError::RequiredCredentialNotFound(cred_type))
+ })
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn check_if_valid(
+ &self,
+ leaf_node: &LeafNode,
+ context: ValidationContext<'_>,
+ ) -> Result<(), MlsError> {
+ // Check that we are validating within the proper context
+ self.check_context(leaf_node, &context)?;
+
+ // Verify the credential
+ self.identity_provider
+ .validate_member(
+ &leaf_node.signing_identity,
+ context.generation_time(),
+ self.group_context_extensions,
+ )
+ .await
+ .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;
+
+ // Verify that the credential signed the leaf node
+ leaf_node
+ .verify(
+ self.cipher_suite_provider,
+ &leaf_node.signing_identity.signature_key,
+ &context.signing_context(),
+ )
+ .await?;
+
+ // If required capabilities are specified, verify the leaf node meets the requirements
+ self.validate_required_capabilities(leaf_node)?;
+
+ // If there are extensions, make sure they are referenced in the capabilities field
+ for one_ext in &*leaf_node.extensions {
+ if !leaf_node
+ .capabilities
+ .extensions
+ .contains(&one_ext.extension_type)
+ {
+ return Err(MlsError::ExtensionNotInCapabilities(one_ext.extension_type));
+ }
+ }
+
+ // Verify that group extensions are supported by the leaf
+ self.group_context_extensions
+ .into_iter()
+ .flat_map(|exts| &**exts)
+ .map(|ext| ext.extension_type)
+ .find(|ext_type| {
+ !ext_type.is_default() && !leaf_node.capabilities.extensions.contains(ext_type)
+ })
+ .map(MlsError::UnsupportedGroupExtension)
+ .map_or(Ok(()), Err)?;
+
+ #[cfg(feature = "by_ref_proposal")]
+ self.validate_external_senders_ext_credentials(leaf_node)?;
+
+ Ok(())
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::crypto::test_utils::try_test_cipher_suite_provider;
+ use crate::extension::MlsExtension;
+ use alloc::vec;
+ use assert_matches::assert_matches;
+ #[cfg(feature = "std")]
+ use core::time::Duration;
+ use mls_rs_core::crypto::CipherSuite;
+ use mls_rs_core::group::ProposalType;
+
+ use super::*;
+
+ use crate::client::test_utils::TEST_CIPHER_SUITE;
+ use crate::crypto::test_utils::test_cipher_suite_provider;
+ use crate::crypto::test_utils::TestCryptoProvider;
+ use crate::crypto::SignatureSecretKey;
+ use crate::extension::test_utils::TestExtension;
+ use crate::group::test_utils::random_bytes;
+ use crate::identity::basic::BasicCredential;
+ use crate::identity::basic::BasicIdentityProvider;
+ use crate::identity::test_utils::get_test_signing_identity;
+ use crate::tree_kem::leaf_node::test_utils::*;
+ use crate::tree_kem::leaf_node_validator::test_utils::FailureIdentityProvider;
+ use crate::tree_kem::Capabilities;
+ use crate::ExtensionList;
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn get_test_add_node() -> (LeafNode, SignatureSecretKey) {
+ let (signing_identity, secret) = get_test_signing_identity(TEST_CIPHER_SUITE, b"foo").await;
+
+ let (leaf_node, _) =
+ get_test_node(TEST_CIPHER_SUITE, signing_identity, &secret, None, None).await;
+
+ (leaf_node, secret)
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_basic_add_validation() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ let (leaf_node, _) = get_test_add_node().await;
+
+ let test_validator =
+ LeafNodeValidator::new(&cipher_suite_provider, &BasicIdentityProvider, None);
+
+ let res = test_validator
+ .check_if_valid(&leaf_node, ValidationContext::Add(None))
+ .await;
+
+ assert_matches!(res, Ok(_));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_failed_validation() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+ let (leaf_node, _) = get_test_add_node().await;
+
+ let fail_test_validator =
+ LeafNodeValidator::new(&cipher_suite_provider, &FailureIdentityProvider, None);
+
+ let res = fail_test_validator
+ .check_if_valid(&leaf_node, ValidationContext::Add(None))
+ .await;
+
+ assert_matches!(res, Err(MlsError::IdentityProviderError(_)));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_basic_update_validation() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+ let group_id = b"group_id";
+
+ let (mut leaf_node, secret) = get_test_add_node().await;
+
+ leaf_node
+ .update(
+ &cipher_suite_provider,
+ group_id,
+ 0,
+ // TODO remove identity from input
+ default_properties(),
+ None,
+ &secret,
+ )
+ .await
+ .unwrap();
+
+ let test_validator =
+ LeafNodeValidator::new(&cipher_suite_provider, &BasicIdentityProvider, None);
+
+ let res = test_validator
+ .check_if_valid(&leaf_node, ValidationContext::Update((group_id, 0, None)))
+ .await;
+
+ assert_matches!(res, Ok(_));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_basic_commit_validation() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+ let group_id = b"group_id";
+
+ let (mut leaf_node, secret) = get_test_add_node().await;
+
+ leaf_node.leaf_node_source = LeafNodeSource::Commit(hex!("f00d").into());
+
+ leaf_node
+ .commit(
+ &cipher_suite_provider,
+ group_id,
+ 0,
+ default_properties(),
+ None,
+ &secret,
+ )
+ .await
+ .unwrap();
+
+ let test_validator =
+ LeafNodeValidator::new(&cipher_suite_provider, &BasicIdentityProvider, None);
+
+ let res = test_validator
+ .check_if_valid(&leaf_node, ValidationContext::Commit((group_id, 0, None)))
+ .await;
+
+ assert_matches!(res, Ok(_));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_incorrect_context() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ let test_validator =
+ LeafNodeValidator::new(&cipher_suite_provider, &BasicIdentityProvider, None);
+
+ let (mut leaf_node, secret) = get_test_add_node().await;
+
+ let res = test_validator
+ .check_if_valid(&leaf_node, ValidationContext::Update((b"foo", 0, None)))
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidLeafNodeSource));
+
+ let res = test_validator
+ .check_if_valid(&leaf_node, ValidationContext::Commit((b"foo", 0, None)))
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidLeafNodeSource));
+
+ leaf_node
+ .update(
+ &cipher_suite_provider,
+ b"foo",
+ 0,
+ default_properties(),
+ None,
+ &secret,
+ )
+ .await
+ .unwrap();
+
+ let res = test_validator
+ .check_if_valid(&leaf_node, ValidationContext::Add(None))
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidLeafNodeSource));
+
+ let res = test_validator
+ .check_if_valid(&leaf_node, ValidationContext::Commit((b"foo", 0, None)))
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidLeafNodeSource));
+
+ leaf_node.leaf_node_source = LeafNodeSource::Commit(hex!("f00d").into());
+
+ leaf_node
+ .commit(
+ &cipher_suite_provider,
+ b"foo",
+ 0,
+ default_properties(),
+ None,
+ &secret,
+ )
+ .await
+ .unwrap();
+
+ let res = test_validator
+ .check_if_valid(&leaf_node, ValidationContext::Add(None))
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidLeafNodeSource));
+
+ let res = test_validator
+ .check_if_valid(&leaf_node, ValidationContext::Update((b"foo", 0, None)))
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidLeafNodeSource));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_bad_signature() {
+ for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+ let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);
+
+ let (signing_identity, secret) = get_test_signing_identity(cipher_suite, b"foo").await;
+
+ let (mut leaf_node, _) =
+ get_test_node(cipher_suite, signing_identity, &secret, None, None).await;
+
+ leaf_node.signature = random_bytes(leaf_node.signature.len());
+
+ let test_validator =
+ LeafNodeValidator::new(&cipher_suite_provider, &BasicIdentityProvider, None);
+
+ let res = test_validator
+ .check_if_valid(&leaf_node, ValidationContext::Add(None))
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidSignature));
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_capabilities_mismatch() {
+ let (signing_identity, secret) = get_test_signing_identity(TEST_CIPHER_SUITE, b"foo").await;
+
+ let mut extensions = ExtensionList::new();
+
+ extensions.set_from(TestExtension::from(0)).unwrap();
+
+ let capabilities = Capabilities {
+ credentials: vec![BasicCredential::credential_type()],
+ ..Default::default()
+ };
+
+ let (leaf_node, _) = get_test_node(
+ TEST_CIPHER_SUITE,
+ signing_identity,
+ &secret,
+ Some(capabilities),
+ Some(extensions),
+ )
+ .await;
+
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ let test_validator =
+ LeafNodeValidator::new(&cipher_suite_provider, &BasicIdentityProvider, None);
+
+ let res = test_validator
+ .check_if_valid(&leaf_node, ValidationContext::Add(None))
+ .await;
+
+ assert_matches!(res,
+ Err(MlsError::ExtensionNotInCapabilities(ext)) if ext == 42.into());
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_cipher_suite_mismatch() {
+ for another_cipher_suite in CipherSuite::all().filter(|cs| cs != &TEST_CIPHER_SUITE) {
+ if let Some(cs) = try_test_cipher_suite_provider(*another_cipher_suite) {
+ let (leaf_node, _) = get_test_add_node().await;
+
+ let test_validator = LeafNodeValidator::new(&cs, &BasicIdentityProvider, None);
+
+ let res = test_validator
+ .check_if_valid(&leaf_node, ValidationContext::Add(None))
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidSignature));
+ }
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_required_extension() {
+ let required_capabilities = RequiredCapabilitiesExt {
+ extensions: vec![43.into()],
+ ..Default::default()
+ };
+
+ let (leaf_node, _) = get_test_add_node().await;
+
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ let group_context_extensions =
+ core::iter::once(required_capabilities.into_extension().unwrap()).collect();
+
+ let test_validator = LeafNodeValidator::new(
+ &cipher_suite_provider,
+ &BasicIdentityProvider,
+ Some(&group_context_extensions),
+ );
+
+ let res = test_validator
+ .check_if_valid(&leaf_node, ValidationContext::Add(None))
+ .await;
+
+ assert_matches!(
+ res,
+ Err(MlsError::RequiredExtensionNotFound(v)) if v == 43.into()
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_required_proposal() {
+ let required_capabilities = RequiredCapabilitiesExt {
+ proposals: vec![42.into()],
+ ..Default::default()
+ };
+
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ let (leaf_node, _) = get_test_add_node().await;
+
+ let group_context_extensions =
+ core::iter::once(required_capabilities.into_extension().unwrap()).collect();
+
+ let test_validator = LeafNodeValidator::new(
+ &cipher_suite_provider,
+ &BasicIdentityProvider,
+ Some(&group_context_extensions),
+ );
+
+ let res = test_validator
+ .check_if_valid(&leaf_node, ValidationContext::Add(None))
+ .await;
+
+ assert_matches!(
+ res,
+ Err(MlsError::RequiredProposalNotFound(p)) if p == ProposalType::new(42)
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_required_credential() {
+ let required_capabilities = RequiredCapabilitiesExt {
+ credentials: vec![0.into()],
+ ..Default::default()
+ };
+
+ let (leaf_node, _) = get_test_add_node().await;
+
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ let group_context_extensions =
+ core::iter::once(required_capabilities.into_extension().unwrap()).collect();
+
+ let test_validator = LeafNodeValidator::new(
+ &cipher_suite_provider,
+ &BasicIdentityProvider,
+ Some(&group_context_extensions),
+ );
+
+ let res = test_validator
+ .check_if_valid(&leaf_node, ValidationContext::Add(None))
+ .await;
+
+ assert_matches!(res,
+ Err(MlsError::RequiredCredentialNotFound(ext)) if ext == 0.into()
+ );
+ }
+
+ #[cfg(feature = "std")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_add_lifetime() {
+ let (leaf_node, _) = get_test_add_node().await;
+
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ let test_validator =
+ LeafNodeValidator::new(&cipher_suite_provider, &BasicIdentityProvider, None);
+
+ let good_lifetime = MlsTime::now();
+
+ let over_one_year = good_lifetime.seconds_since_epoch() + (86400 * 366);
+
+ let bad_lifetime = MlsTime::from_duration_since_epoch(Duration::from_secs(over_one_year));
+
+ let res = test_validator
+ .check_if_valid(&leaf_node, ValidationContext::Add(Some(good_lifetime)))
+ .await;
+
+ assert_matches!(res, Ok(()));
+
+ let res = test_validator
+ .check_if_valid(&leaf_node, ValidationContext::Add(Some(bad_lifetime)))
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidLifetime));
+ }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+ use alloc::vec;
+ use alloc::{boxed::Box, vec::Vec};
+ use mls_rs_codec::MlsEncode;
+ use mls_rs_core::{
+ error::IntoAnyError,
+ extension::ExtensionList,
+ identity::{BasicCredential, IdentityProvider},
+ };
+
+ use crate::{identity::SigningIdentity, time::MlsTime};
+
+ #[derive(Clone, Debug, Default)]
+ pub struct FailureIdentityProvider;
+
+ #[cfg(feature = "by_ref_proposal")]
+ impl FailureIdentityProvider {
+ pub fn new() -> Self {
+ Self
+ }
+ }
+
+ #[derive(Debug)]
+ #[cfg_attr(feature = "std", derive(thiserror::Error))]
+ #[cfg_attr(feature = "std", error("test error"))]
+ pub struct TestFailureError;
+
+ impl IntoAnyError for TestFailureError {
+ #[cfg(feature = "std")]
+ fn into_dyn_error(self) -> Result<Box<dyn std::error::Error + Send + Sync>, Self> {
+ Ok(self.into())
+ }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[cfg_attr(mls_build_async, maybe_async::must_be_async)]
+ impl IdentityProvider for FailureIdentityProvider {
+ type Error = TestFailureError;
+
+ async fn validate_member(
+ &self,
+ _signing_identity: &SigningIdentity,
+ _timestamp: Option<MlsTime>,
+ _extensions: Option<&ExtensionList>,
+ ) -> Result<(), Self::Error> {
+ Err(TestFailureError)
+ }
+
+ async fn validate_external_sender(
+ &self,
+ _signing_identity: &SigningIdentity,
+ _timestamp: Option<MlsTime>,
+ _extensions: Option<&ExtensionList>,
+ ) -> Result<(), Self::Error> {
+ Err(TestFailureError)
+ }
+
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ async fn identity(
+ &self,
+ signing_id: &SigningIdentity,
+ _extensions: &ExtensionList,
+ ) -> Result<Vec<u8>, Self::Error> {
+ Ok(signing_id.credential.mls_encode_to_vec().unwrap())
+ }
+
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ async fn valid_successor(
+ &self,
+ _predecessor: &SigningIdentity,
+ _successor: &SigningIdentity,
+ _extensions: &ExtensionList,
+ ) -> Result<bool, Self::Error> {
+ Err(TestFailureError)
+ }
+
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn supported_types(&self) -> Vec<crate::identity::CredentialType> {
+ vec![BasicCredential::credential_type()]
+ }
+ }
+}
diff --git a/src/tree_kem/lifetime.rs b/src/tree_kem/lifetime.rs
new file mode 100644
index 0000000..d508ad6
--- /dev/null
+++ b/src/tree_kem/lifetime.rs
@@ -0,0 +1,119 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::{client::MlsError, time::MlsTime};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+
+#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode, Default)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[non_exhaustive]
+pub struct Lifetime {
+ pub not_before: u64,
+ pub not_after: u64,
+}
+
+impl Lifetime {
+ pub fn new(not_before: u64, not_after: u64) -> Lifetime {
+ Lifetime {
+ not_before,
+ not_after,
+ }
+ }
+
+ pub fn seconds(s: u64) -> Result<Self, MlsError> {
+ #[cfg(feature = "std")]
+ let not_before = MlsTime::now().seconds_since_epoch();
+ #[cfg(not(feature = "std"))]
+ // There is no clock on no_std, this is here just so that we can run tests.
+ let not_before = 3600u64;
+
+ let not_after = not_before.checked_add(s).ok_or(MlsError::TimeOverflow)?;
+
+ Ok(Lifetime {
+ // Subtract 1 hour to address time difference between machines
+ not_before: not_before - 3600,
+ not_after,
+ })
+ }
+
+ pub fn days(d: u32) -> Result<Self, MlsError> {
+ Self::seconds((d * 86400) as u64)
+ }
+
+ pub fn years(y: u8) -> Result<Self, MlsError> {
+ Self::days(365 * y as u32)
+ }
+
+ pub(crate) fn within_lifetime(&self, time: MlsTime) -> bool {
+ let since_epoch = time.seconds_since_epoch();
+ since_epoch >= self.not_before && since_epoch <= self.not_after
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use core::time::Duration;
+
+ use super::*;
+ use assert_matches::assert_matches;
+
+ #[test]
+ fn test_lifetime_overflow() {
+ let res = Lifetime::seconds(u64::MAX);
+ assert_matches!(res, Err(MlsError::TimeOverflow))
+ }
+
+ #[test]
+ fn test_seconds() {
+ let seconds = 10;
+ let lifetime = Lifetime::seconds(seconds).unwrap();
+ assert_eq!(lifetime.not_after - lifetime.not_before, 3610);
+ }
+
+ #[test]
+ fn test_days() {
+ let days = 2;
+ let lifetime = Lifetime::days(days).unwrap();
+
+ assert_eq!(
+ lifetime.not_after - lifetime.not_before,
+ 86400u64 * days as u64 + 3600
+ );
+ }
+
+ #[test]
+ fn test_years() {
+ let years = 2;
+ let lifetime = Lifetime::years(years).unwrap();
+
+ assert_eq!(
+ lifetime.not_after - lifetime.not_before,
+ 86400 * 365 * years as u64 + 3600
+ );
+ }
+
+ #[test]
+ fn test_bounds() {
+ let test_lifetime = Lifetime {
+ not_before: 5,
+ not_after: 10,
+ };
+
+ assert!(!test_lifetime
+ .within_lifetime(MlsTime::from_duration_since_epoch(Duration::from_secs(4))));
+
+ assert!(!test_lifetime
+ .within_lifetime(MlsTime::from_duration_since_epoch(Duration::from_secs(11))));
+
+ assert!(test_lifetime
+ .within_lifetime(MlsTime::from_duration_since_epoch(Duration::from_secs(5))));
+
+ assert!(test_lifetime
+ .within_lifetime(MlsTime::from_duration_since_epoch(Duration::from_secs(10))));
+
+ assert!(test_lifetime
+ .within_lifetime(MlsTime::from_duration_since_epoch(Duration::from_secs(6))));
+ }
+}
diff --git a/src/tree_kem/math.rs b/src/tree_kem/math.rs
new file mode 100644
index 0000000..51f82ed
--- /dev/null
+++ b/src/tree_kem/math.rs
@@ -0,0 +1,383 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec::Vec;
+use core::{fmt::Debug, hash::Hash};
+use mls_rs_codec::{MlsDecode, MlsEncode};
+
+use super::node::LeafIndex;
+
+pub trait TreeIndex:
+ Send + Sync + Eq + Clone + Debug + Default + MlsEncode + MlsDecode + Hash + Ord
+{
+ fn root(&self) -> Self;
+
+ fn left_unchecked(&self) -> Self;
+ fn right_unchecked(&self) -> Self;
+
+ fn parent_sibling(&self, leaf_count: &Self) -> Option<ParentSibling<Self>>;
+ fn is_leaf(&self) -> bool;
+ fn is_in_tree(&self, root: &Self) -> bool;
+
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ fn zero() -> Self;
+
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message", test))]
+ fn left(&self) -> Option<Self> {
+ (!self.is_leaf()).then(|| self.left_unchecked())
+ }
+
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message", test))]
+ fn right(&self) -> Option<Self> {
+ (!self.is_leaf()).then(|| self.right_unchecked())
+ }
+
+ fn direct_copath(&self, leaf_count: &Self) -> Vec<CopathNode<Self>> {
+ let root = leaf_count.root();
+
+ if !self.is_in_tree(&root) {
+ return Vec::new();
+ }
+
+ let mut path = Vec::new();
+ let mut parent = self.clone();
+
+ while let Some(ps) = parent.parent_sibling(leaf_count) {
+ path.push(CopathNode::new(ps.parent.clone(), ps.sibling));
+ parent = ps.parent;
+ }
+
+ path
+ }
+}
+
+#[derive(Clone, PartialEq, Eq, Debug)]
+pub struct CopathNode<T> {
+ pub path: T,
+ pub copath: T,
+}
+
+impl<T: Clone + PartialEq + Eq + core::fmt::Debug> CopathNode<T> {
+ pub fn new(path: T, copath: T) -> CopathNode<T> {
+ CopathNode { path, copath }
+ }
+}
+
+#[derive(Clone, PartialEq, Eq, Debug)]
+pub struct ParentSibling<T> {
+ pub parent: T,
+ pub sibling: T,
+}
+
+impl<T: Clone + PartialEq + Eq + core::fmt::Debug> ParentSibling<T> {
+ pub fn new(parent: T, sibling: T) -> ParentSibling<T> {
+ ParentSibling { parent, sibling }
+ }
+}
+
+macro_rules! impl_tree_stdint {
+ ($t:ty) => {
+ impl TreeIndex for $t {
+ fn root(&self) -> $t {
+ *self - 1
+ }
+
+ /// Panicks if `x` is even in debug, overflows in release.
+ fn left_unchecked(&self) -> Self {
+ *self ^ (0x01 << (self.trailing_ones() - 1))
+ }
+
+ /// Panicks if `x` is even in debug, overflows in release.
+ fn right_unchecked(&self) -> Self {
+ *self ^ (0x03 << (self.trailing_ones() - 1))
+ }
+
+ fn parent_sibling(&self, leaf_count: &Self) -> Option<ParentSibling<Self>> {
+ if self == &leaf_count.root() {
+ return None;
+ }
+
+ let lvl = self.trailing_ones();
+ let p = (self & !(1 << (lvl + 1))) | (1 << lvl);
+
+ let s = if *self < p {
+ p.right_unchecked()
+ } else {
+ p.left_unchecked()
+ };
+
+ Some(ParentSibling::new(p, s))
+ }
+
+ fn is_leaf(&self) -> bool {
+ self & 1 == 0
+ }
+
+ fn is_in_tree(&self, root: &Self) -> bool {
+ *self <= 2 * root
+ }
+
+ #[cfg(any(feature = "secret_tree_access", feature = "private_message"))]
+ fn zero() -> Self {
+ 0
+ }
+ }
+ };
+}
+
+impl_tree_stdint!(u32);
+
+#[cfg(test)]
+impl_tree_stdint!(u64);
+
+pub fn leaf_lca_level(x: u32, y: u32) -> u32 {
+ let mut xn = x;
+ let mut yn = y;
+ let mut k = 0;
+
+ while xn != yn {
+ xn >>= 1;
+ yn >>= 1;
+ k += 1;
+ }
+
+ k
+}
+
+pub fn subtree(x: u32) -> (LeafIndex, LeafIndex) {
+ let breadth = 1 << x.trailing_ones();
+ (
+ LeafIndex((x + 1 - breadth) >> 1),
+ LeafIndex(((x + breadth) >> 1) + 1),
+ )
+}
+
+pub struct BfsIterTopDown {
+ level: usize,
+ mask: usize,
+ level_end: usize,
+ ctr: usize,
+}
+
+impl BfsIterTopDown {
+ pub fn new(num_leaves: usize) -> Self {
+ let depth = num_leaves.trailing_zeros() as usize;
+ Self {
+ level: depth + 1,
+ mask: (1 << depth) - 1,
+ level_end: 1,
+ ctr: 0,
+ }
+ }
+}
+
+impl Iterator for BfsIterTopDown {
+ type Item = usize;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ if self.ctr == self.level_end {
+ if self.level == 1 {
+ return None;
+ }
+ self.level_end = (((self.level_end - 1) << 1) | 1) + 1;
+ self.level -= 1;
+ self.ctr = 0;
+ self.mask >>= 1;
+ }
+ let res = Some((self.ctr << self.level) | self.mask);
+ self.ctr += 1;
+ res
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use itertools::Itertools;
+ use serde::{Deserialize, Serialize};
+
+ #[cfg(target_arch = "wasm32")]
+ use wasm_bindgen_test::wasm_bindgen_test as test;
+
+ #[derive(Serialize, Deserialize)]
+ struct TestCase {
+ n_leaves: u32,
+ n_nodes: u32,
+ root: u32,
+ left: Vec<Option<u32>>,
+ right: Vec<Option<u32>>,
+ parent: Vec<Option<u32>>,
+ sibling: Vec<Option<u32>>,
+ }
+
+ pub fn node_width(n: u32) -> u32 {
+ if n == 0 {
+ 0
+ } else {
+ 2 * (n - 1) + 1
+ }
+ }
+
+ #[test]
+ fn test_bfs_iterator() {
+ let expected = [7, 3, 11, 1, 5, 9, 13, 0, 2, 4, 6, 8, 10, 12, 14];
+ let bfs = BfsIterTopDown::new(8);
+ assert_eq!(bfs.collect::<Vec<_>>(), expected);
+ }
+
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn generate_tree_math_test_cases() -> Vec<TestCase> {
+ let mut test_cases = Vec::new();
+
+ for log_n_leaves in 0..8 {
+ let n_leaves = 1 << log_n_leaves;
+ let n_nodes = node_width(n_leaves);
+ let left = (0..n_nodes).map(|x| x.left()).collect::<Vec<_>>();
+ let right = (0..n_nodes).map(|x| x.right()).collect::<Vec<_>>();
+
+ let (parent, sibling) = (0..n_nodes)
+ .map(|x| {
+ x.parent_sibling(&n_leaves)
+ .map(|ps| (ps.parent, ps.sibling))
+ .unzip()
+ })
+ .unzip();
+
+ test_cases.push(TestCase {
+ n_leaves,
+ n_nodes,
+ root: n_leaves.root(),
+ left,
+ right,
+ parent,
+ sibling,
+ })
+ }
+
+ test_cases
+ }
+
+ fn load_test_cases() -> Vec<TestCase> {
+ load_test_case_json!(tree_math, generate_tree_math_test_cases())
+ }
+
+ #[test]
+ fn test_tree_math() {
+ let test_cases = load_test_cases();
+
+ for case in test_cases {
+ assert_eq!(node_width(case.n_leaves), case.n_nodes);
+ assert_eq!(case.n_leaves.root(), case.root);
+
+ for x in 0..case.n_nodes {
+ assert_eq!(x.left(), case.left[x as usize]);
+ assert_eq!(x.right(), case.right[x as usize]);
+
+ let (p, s) = x
+ .parent_sibling(&case.n_leaves)
+ .map(|ps| (ps.parent, ps.sibling))
+ .unzip();
+
+ assert_eq!(p, case.parent[x as usize]);
+ assert_eq!(s, case.sibling[x as usize]);
+ }
+ }
+ }
+
+ #[test]
+ fn test_direct_path() {
+ let expected: Vec<Vec<u32>> = [
+ [0x01, 0x03, 0x07, 0x0f].to_vec(),
+ [0x03, 0x07, 0x0f].to_vec(),
+ [0x01, 0x03, 0x07, 0x0f].to_vec(),
+ [0x07, 0x0f].to_vec(),
+ [0x05, 0x03, 0x07, 0x0f].to_vec(),
+ [0x03, 0x07, 0x0f].to_vec(),
+ [0x05, 0x03, 0x07, 0x0f].to_vec(),
+ [0x0f].to_vec(),
+ [0x09, 0x0b, 0x07, 0x0f].to_vec(),
+ [0x0b, 0x07, 0x0f].to_vec(),
+ [0x09, 0x0b, 0x07, 0x0f].to_vec(),
+ [0x07, 0x0f].to_vec(),
+ [0x0d, 0x0b, 0x07, 0x0f].to_vec(),
+ [0x0b, 0x07, 0x0f].to_vec(),
+ [0x0d, 0x0b, 0x07, 0x0f].to_vec(),
+ [].to_vec(),
+ [0x11, 0x13, 0x17, 0x0f].to_vec(),
+ [0x13, 0x17, 0x0f].to_vec(),
+ [0x11, 0x13, 0x17, 0x0f].to_vec(),
+ [0x17, 0x0f].to_vec(),
+ [0x15, 0x13, 0x17, 0x0f].to_vec(),
+ [0x13, 0x17, 0x0f].to_vec(),
+ [0x15, 0x13, 0x17, 0x0f].to_vec(),
+ [0x0f].to_vec(),
+ [0x19, 0x1b, 0x17, 0x0f].to_vec(),
+ [0x1b, 0x17, 0x0f].to_vec(),
+ [0x19, 0x1b, 0x17, 0x0f].to_vec(),
+ [0x17, 0x0f].to_vec(),
+ [0x1d, 0x1b, 0x17, 0x0f].to_vec(),
+ [0x1b, 0x17, 0x0f].to_vec(),
+ [0x1d, 0x1b, 0x17, 0x0f].to_vec(),
+ ]
+ .to_vec();
+
+ for (i, item) in expected.iter().enumerate() {
+ let path = (i as u32)
+ .direct_copath(&16)
+ .into_iter()
+ .map(|cp| cp.path)
+ .collect_vec();
+
+ assert_eq!(item, &path)
+ }
+ }
+
+ #[test]
+ fn test_copath_path() {
+ let expected: Vec<Vec<u32>> = [
+ [0x02, 0x05, 0x0b, 0x17].to_vec(),
+ [0x05, 0x0b, 0x17].to_vec(),
+ [0x00, 0x05, 0x0b, 0x17].to_vec(),
+ [0x0b, 0x17].to_vec(),
+ [0x06, 0x01, 0x0b, 0x17].to_vec(),
+ [0x01, 0x0b, 0x17].to_vec(),
+ [0x04, 0x01, 0x0b, 0x17].to_vec(),
+ [0x17].to_vec(),
+ [0x0a, 0x0d, 0x03, 0x17].to_vec(),
+ [0x0d, 0x03, 0x17].to_vec(),
+ [0x08, 0x0d, 0x03, 0x17].to_vec(),
+ [0x03, 0x17].to_vec(),
+ [0x0e, 0x09, 0x03, 0x17].to_vec(),
+ [0x09, 0x03, 0x17].to_vec(),
+ [0x0c, 0x09, 0x03, 0x17].to_vec(),
+ [].to_vec(),
+ [0x12, 0x15, 0x1b, 0x07].to_vec(),
+ [0x15, 0x1b, 0x07].to_vec(),
+ [0x10, 0x15, 0x1b, 0x07].to_vec(),
+ [0x1b, 0x07].to_vec(),
+ [0x16, 0x11, 0x1b, 0x07].to_vec(),
+ [0x11, 0x1b, 0x07].to_vec(),
+ [0x14, 0x11, 0x1b, 0x07].to_vec(),
+ [0x07].to_vec(),
+ [0x1a, 0x1d, 0x13, 0x07].to_vec(),
+ [0x1d, 0x13, 0x07].to_vec(),
+ [0x18, 0x1d, 0x13, 0x07].to_vec(),
+ [0x13, 0x07].to_vec(),
+ [0x1e, 0x19, 0x13, 0x07].to_vec(),
+ [0x19, 0x13, 0x07].to_vec(),
+ [0x1c, 0x19, 0x13, 0x07].to_vec(),
+ ]
+ .to_vec();
+
+ for (i, item) in expected.iter().enumerate() {
+ let copath = (i as u32)
+ .direct_copath(&16)
+ .into_iter()
+ .map(|cp| cp.copath)
+ .collect_vec();
+
+ assert_eq!(item, &copath)
+ }
+ }
+}
diff --git a/src/tree_kem/mod.rs b/src/tree_kem/mod.rs
new file mode 100644
index 0000000..430ee16
--- /dev/null
+++ b/src/tree_kem/mod.rs
@@ -0,0 +1,1490 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::vec;
+use alloc::vec::Vec;
+#[cfg(feature = "std")]
+use core::fmt::Display;
+use itertools::Itertools;
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::extension::ExtensionList;
+
+use mls_rs_core::{error::IntoAnyError, identity::IdentityProvider};
+
+#[cfg(feature = "tree_index")]
+use mls_rs_core::identity::SigningIdentity;
+
+use math as tree_math;
+use node::{LeafIndex, NodeIndex, NodeVec};
+
+use self::leaf_node::LeafNode;
+
+use crate::client::MlsError;
+use crate::crypto::{self, CipherSuiteProvider, HpkeSecretKey};
+
+#[cfg(feature = "by_ref_proposal")]
+use crate::group::proposal::{AddProposal, UpdateProposal};
+
+#[cfg(any(test, feature = "by_ref_proposal"))]
+use crate::group::proposal::RemoveProposal;
+
+use crate::group::proposal_filter::ProposalBundle;
+use crate::tree_kem::tree_hash::TreeHashes;
+
+mod capabilities;
+pub(crate) mod hpke_encryption;
+mod lifetime;
+pub(crate) mod math;
+pub mod node;
+pub mod parent_hash;
+pub mod path_secret;
+mod private;
+mod tree_hash;
+pub mod tree_validator;
+pub mod update_path;
+
+pub use capabilities::*;
+pub use lifetime::*;
+pub(crate) use private::*;
+pub use update_path::*;
+
+use tree_index::*;
+
+pub mod kem;
+pub mod leaf_node;
+pub mod leaf_node_validator;
+mod tree_index;
+
+#[cfg(feature = "std")]
+pub(crate) mod tree_utils;
+
+#[cfg(test)]
+mod interop_test_vectors;
+
+#[cfg(feature = "custom_proposal")]
+use crate::group::proposal::ProposalType;
+
+#[derive(Clone, Debug, MlsEncode, MlsDecode, MlsSize, Default)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct TreeKemPublic {
+ #[cfg(feature = "tree_index")]
+ #[cfg_attr(feature = "serde", serde(skip))]
+ index: TreeIndex,
+ pub(crate) nodes: NodeVec,
+ tree_hashes: TreeHashes,
+}
+
+impl PartialEq for TreeKemPublic {
+ fn eq(&self, other: &Self) -> bool {
+ self.nodes == other.nodes
+ }
+}
+
+impl TreeKemPublic {
+ pub fn new() -> TreeKemPublic {
+ Default::default()
+ }
+
+ #[cfg_attr(not(feature = "tree_index"), allow(unused))]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn import_node_data<IP>(
+ nodes: NodeVec,
+ identity_provider: &IP,
+ extensions: &ExtensionList,
+ ) -> Result<TreeKemPublic, MlsError>
+ where
+ IP: IdentityProvider,
+ {
+ let mut tree = TreeKemPublic {
+ nodes,
+ ..Default::default()
+ };
+
+ #[cfg(feature = "tree_index")]
+ tree.initialize_index_if_necessary(identity_provider, extensions)
+ .await?;
+
+ Ok(tree)
+ }
+
+ #[cfg(feature = "tree_index")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn initialize_index_if_necessary<IP: IdentityProvider>(
+ &mut self,
+ identity_provider: &IP,
+ extensions: &ExtensionList,
+ ) -> Result<(), MlsError> {
+ if !self.index.is_initialized() {
+ self.index = TreeIndex::new();
+
+ for (leaf_index, leaf) in self.nodes.non_empty_leaves() {
+ index_insert(
+ &mut self.index,
+ leaf,
+ leaf_index,
+ identity_provider,
+ extensions,
+ )
+ .await?;
+ }
+ }
+
+ Ok(())
+ }
+
+ #[cfg(feature = "tree_index")]
+ pub(crate) fn get_leaf_node_with_identity(&self, identity: &[u8]) -> Option<LeafIndex> {
+ self.index.get_leaf_index_with_identity(identity)
+ }
+
+ #[cfg(not(feature = "tree_index"))]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn get_leaf_node_with_identity<I: IdentityProvider>(
+ &self,
+ identity: &[u8],
+ id_provider: &I,
+ extensions: &ExtensionList,
+ ) -> Result<Option<LeafIndex>, MlsError> {
+ for (i, leaf) in self.nodes.non_empty_leaves() {
+ let leaf_id = id_provider
+ .identity(&leaf.signing_identity, extensions)
+ .await
+ .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;
+
+ if leaf_id == identity {
+ return Ok(Some(i));
+ }
+ }
+
+ Ok(None)
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn derive<I: IdentityProvider>(
+ leaf_node: LeafNode,
+ secret_key: HpkeSecretKey,
+ identity_provider: &I,
+ extensions: &ExtensionList,
+ ) -> Result<(TreeKemPublic, TreeKemPrivate), MlsError> {
+ let mut public_tree = TreeKemPublic::new();
+
+ public_tree
+ .add_leaf(leaf_node, identity_provider, extensions, None)
+ .await?;
+
+ let private_tree = TreeKemPrivate::new_self_leaf(LeafIndex(0), secret_key);
+
+ Ok((public_tree, private_tree))
+ }
+
+ pub fn total_leaf_count(&self) -> u32 {
+ self.nodes.total_leaf_count()
+ }
+
+ #[cfg(any(test, all(feature = "custom_proposal", feature = "tree_index")))]
+ pub fn occupied_leaf_count(&self) -> u32 {
+ self.nodes.occupied_leaf_count()
+ }
+
+ pub fn get_leaf_node(&self, index: LeafIndex) -> Result<&LeafNode, MlsError> {
+ self.nodes.borrow_as_leaf(index)
+ }
+
+ pub fn find_leaf_node(&self, leaf_node: &LeafNode) -> Option<LeafIndex> {
+ self.nodes.non_empty_leaves().find_map(
+ |(index, node)| {
+ if node == leaf_node {
+ Some(index)
+ } else {
+ None
+ }
+ },
+ )
+ }
+
+ #[cfg(feature = "custom_proposal")]
+ pub fn can_support_proposal(&self, proposal_type: ProposalType) -> bool {
+ #[cfg(feature = "tree_index")]
+ return self.index.count_supporting_proposal(proposal_type) == self.occupied_leaf_count();
+
+ #[cfg(not(feature = "tree_index"))]
+ self.nodes
+ .non_empty_leaves()
+ .all(|(_, l)| l.capabilities.proposals.contains(&proposal_type))
+ }
+
+ #[cfg(test)]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn add_leaves<I: IdentityProvider, CP: CipherSuiteProvider>(
+ &mut self,
+ leaf_nodes: Vec<LeafNode>,
+ id_provider: &I,
+ cipher_suite_provider: &CP,
+ ) -> Result<Vec<LeafIndex>, MlsError> {
+ let mut start = LeafIndex(0);
+ let mut added = vec![];
+
+ for leaf in leaf_nodes.into_iter() {
+ start = self
+ .add_leaf(leaf, id_provider, &Default::default(), Some(start))
+ .await?;
+ added.push(start);
+ }
+
+ self.update_hashes(&added, cipher_suite_provider).await?;
+
+ Ok(added)
+ }
+
+ pub fn non_empty_leaves(&self) -> impl Iterator<Item = (LeafIndex, &LeafNode)> + '_ {
+ self.nodes.non_empty_leaves()
+ }
+
+ #[cfg(feature = "prior_epoch")]
+ pub fn leaves(&self) -> impl Iterator<Item = Option<&LeafNode>> + '_ {
+ self.nodes.leaves()
+ }
+
+ pub(crate) fn update_node(
+ &mut self,
+ pub_key: crypto::HpkePublicKey,
+ index: NodeIndex,
+ ) -> Result<(), MlsError> {
+ self.nodes
+ .borrow_or_fill_node_as_parent(index, &pub_key)
+ .map(|p| {
+ p.public_key = pub_key;
+ p.unmerged_leaves = vec![];
+ })
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn apply_update_path<IP, CP>(
+ &mut self,
+ sender: LeafIndex,
+ update_path: &ValidatedUpdatePath,
+ extensions: &ExtensionList,
+ identity_provider: IP,
+ cipher_suite_provider: &CP,
+ ) -> Result<(), MlsError>
+ where
+ IP: IdentityProvider,
+ CP: CipherSuiteProvider,
+ {
+ // Install the new leaf node
+ let existing_leaf = self.nodes.borrow_as_leaf_mut(sender)?;
+
+ #[cfg(feature = "tree_index")]
+ let original_leaf_node = existing_leaf.clone();
+
+ #[cfg(feature = "tree_index")]
+ let original_identity = identity_provider
+ .identity(&original_leaf_node.signing_identity, extensions)
+ .await
+ .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;
+
+ *existing_leaf = update_path.leaf_node.clone();
+
+ // Update the rest of the nodes on the direct path
+ let path = self.nodes.direct_copath(sender);
+
+ for (node, pn) in update_path.nodes.iter().zip(path) {
+ node.as_ref()
+ .map(|n| self.update_node(n.public_key.clone(), pn.path))
+ .transpose()?;
+ }
+
+ #[cfg(feature = "tree_index")]
+ self.index.remove(&original_leaf_node, &original_identity);
+
+ index_insert(
+ #[cfg(feature = "tree_index")]
+ &mut self.index,
+ #[cfg(not(feature = "tree_index"))]
+ &self.nodes,
+ &update_path.leaf_node,
+ sender,
+ &identity_provider,
+ extensions,
+ )
+ .await?;
+
+ // Verify the parent hash of the new sender leaf node and update the parent hash values
+ // in the local tree
+ self.update_parent_hashes(sender, true, cipher_suite_provider)
+ .await?;
+
+ Ok(())
+ }
+
+ fn update_unmerged(&mut self, index: LeafIndex) -> Result<(), MlsError> {
+ // For a given leaf index, find parent nodes and add the leaf to the unmerged leaf
+ self.nodes.direct_copath(index).into_iter().for_each(|i| {
+ if let Ok(p) = self.nodes.borrow_as_parent_mut(i.path) {
+ p.unmerged_leaves.push(index)
+ }
+ });
+
+ Ok(())
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn batch_edit<I, CP>(
+ &mut self,
+ proposal_bundle: &mut ProposalBundle,
+ extensions: &ExtensionList,
+ id_provider: &I,
+ cipher_suite_provider: &CP,
+ filter: bool,
+ ) -> Result<Vec<LeafIndex>, MlsError>
+ where
+ I: IdentityProvider,
+ CP: CipherSuiteProvider,
+ {
+ // Apply removes (they commute with updates because they don't touch the same leaves)
+ for i in (0..proposal_bundle.remove_proposals().len()).rev() {
+ let index = proposal_bundle.remove_proposals()[i].proposal.to_remove;
+ let res = self.nodes.blank_leaf_node(index);
+
+ if res.is_ok() {
+ // This shouldn't fail if `blank_leaf_node` succedded.
+ self.nodes.blank_direct_path(index)?;
+ }
+
+ #[cfg(feature = "tree_index")]
+ if let Ok(old_leaf) = &res {
+ // If this fails, it's not because the proposal is bad.
+ let identity =
+ identity(&old_leaf.signing_identity, id_provider, extensions).await?;
+
+ self.index.remove(old_leaf, &identity);
+ }
+
+ if proposal_bundle.remove_proposals()[i].is_by_value() || !filter {
+ res?;
+ } else if res.is_err() {
+ proposal_bundle.remove::<RemoveProposal>(i);
+ }
+ }
+
+ // Remove from the tree old leaves from updates
+ let mut partial_updates = vec![];
+ let senders = proposal_bundle.update_senders.iter().copied();
+
+ for (i, (p, index)) in proposal_bundle.updates.iter().zip(senders).enumerate() {
+ let new_leaf = p.proposal.leaf_node.clone();
+
+ match self.nodes.blank_leaf_node(index) {
+ Ok(old_leaf) => {
+ #[cfg(feature = "tree_index")]
+ let old_id =
+ identity(&old_leaf.signing_identity, id_provider, extensions).await?;
+
+ #[cfg(feature = "tree_index")]
+ self.index.remove(&old_leaf, &old_id);
+
+ partial_updates.push((index, old_leaf, new_leaf, i));
+ }
+ _ => {
+ if !filter || !p.is_by_reference() {
+ return Err(MlsError::UpdatingNonExistingMember);
+ }
+ }
+ }
+ }
+
+ #[cfg(feature = "tree_index")]
+ let index_clone = self.index.clone();
+
+ let mut removed_leaves = vec![];
+ let mut updated_indices = vec![];
+ let mut bad_indices = vec![];
+
+ // Apply updates one by one. If there's an update which we can't apply or revert, we revert
+ // all updates.
+ for (index, old_leaf, new_leaf, i) in partial_updates.into_iter() {
+ #[cfg(feature = "tree_index")]
+ let res =
+ index_insert(&mut self.index, &new_leaf, index, id_provider, extensions).await;
+
+ #[cfg(not(feature = "tree_index"))]
+ let res = index_insert(&self.nodes, &new_leaf, index, id_provider, extensions).await;
+
+ let err = res.is_err();
+
+ if !filter {
+ res?;
+ }
+
+ if !err {
+ self.nodes.insert_leaf(index, new_leaf);
+ removed_leaves.push(old_leaf);
+ updated_indices.push(index);
+ } else {
+ #[cfg(feature = "tree_index")]
+ let res =
+ index_insert(&mut self.index, &old_leaf, index, id_provider, extensions).await;
+
+ #[cfg(not(feature = "tree_index"))]
+ let res =
+ index_insert(&self.nodes, &old_leaf, index, id_provider, extensions).await;
+
+ if res.is_ok() {
+ self.nodes.insert_leaf(index, old_leaf);
+ bad_indices.push(i);
+ } else {
+ // Revert all updates and stop. We're already in the "filter" case, so we don't throw an error.
+ #[cfg(feature = "tree_index")]
+ {
+ self.index = index_clone;
+ }
+
+ removed_leaves
+ .into_iter()
+ .zip(updated_indices.iter())
+ .for_each(|(leaf, index)| self.nodes.insert_leaf(*index, leaf));
+
+ updated_indices = vec![];
+ break;
+ }
+ }
+ }
+
+ // If we managed to update something, blank direct paths
+ updated_indices
+ .iter()
+ .try_for_each(|index| self.nodes.blank_direct_path(*index).map(|_| ()))?;
+
+ // Remove rejected updates from applied proposals
+ if updated_indices.is_empty() {
+ // This takes care of the "revert all" scenario
+ proposal_bundle.updates = vec![];
+ } else {
+ for i in bad_indices.into_iter().rev() {
+ proposal_bundle.remove::<UpdateProposal>(i);
+ proposal_bundle.update_senders.remove(i);
+ }
+ }
+
+ // Apply adds
+ let mut start = LeafIndex(0);
+ let mut added = vec![];
+ let mut bad_indexes = vec![];
+
+ for i in 0..proposal_bundle.additions.len() {
+ let leaf = proposal_bundle.additions[i]
+ .proposal
+ .key_package
+ .leaf_node
+ .clone();
+
+ let res = self
+ .add_leaf(leaf, id_provider, extensions, Some(start))
+ .await;
+
+ if let Ok(index) = res {
+ start = index;
+ added.push(start);
+ } else if proposal_bundle.additions[i].is_by_value() || !filter {
+ res?;
+ } else {
+ bad_indexes.push(i);
+ }
+ }
+
+ for i in bad_indexes.into_iter().rev() {
+ proposal_bundle.remove::<AddProposal>(i);
+ }
+
+ self.nodes.trim();
+
+ let updated_leaves = proposal_bundle
+ .remove_proposals()
+ .iter()
+ .map(|p| p.proposal.to_remove)
+ .chain(updated_indices)
+ .chain(added.iter().copied())
+ .collect_vec();
+
+ self.update_hashes(&updated_leaves, cipher_suite_provider)
+ .await?;
+
+ Ok(added)
+ }
+
+ #[cfg(not(feature = "by_ref_proposal"))]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn batch_edit_lite<I, CP>(
+ &mut self,
+ proposal_bundle: &ProposalBundle,
+ extensions: &ExtensionList,
+ id_provider: &I,
+ cipher_suite_provider: &CP,
+ ) -> Result<Vec<LeafIndex>, MlsError>
+ where
+ I: IdentityProvider,
+ CP: CipherSuiteProvider,
+ {
+ // Apply removes
+ for p in &proposal_bundle.removals {
+ let index = p.proposal.to_remove;
+
+ #[cfg(feature = "tree_index")]
+ {
+ // If this fails, it's not because the proposal is bad.
+ let old_leaf = self.nodes.blank_leaf_node(index)?;
+
+ let identity =
+ identity(&old_leaf.signing_identity, id_provider, extensions).await?;
+
+ self.index.remove(&old_leaf, &identity);
+ }
+
+ #[cfg(not(feature = "tree_index"))]
+ self.nodes.blank_leaf_node(index)?;
+
+ self.nodes.blank_direct_path(index)?;
+ }
+
+ // Apply adds
+ let mut start = LeafIndex(0);
+ let mut added = vec![];
+
+ for p in &proposal_bundle.additions {
+ let leaf = p.proposal.key_package.leaf_node.clone();
+ start = self
+ .add_leaf(leaf, id_provider, extensions, Some(start))
+ .await?;
+ added.push(start);
+ }
+
+ self.nodes.trim();
+
+ let updated_leaves = proposal_bundle
+ .remove_proposals()
+ .iter()
+ .map(|p| p.proposal.to_remove)
+ .chain(added.iter().copied())
+ .collect_vec();
+
+ self.update_hashes(&updated_leaves, cipher_suite_provider)
+ .await?;
+
+ Ok(added)
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn add_leaf<I: IdentityProvider>(
+ &mut self,
+ leaf: LeafNode,
+ id_provider: &I,
+ extensions: &ExtensionList,
+ start: Option<LeafIndex>,
+ ) -> Result<LeafIndex, MlsError> {
+ let index = self.nodes.next_empty_leaf(start.unwrap_or(LeafIndex(0)));
+
+ #[cfg(feature = "tree_index")]
+ index_insert(&mut self.index, &leaf, index, id_provider, extensions).await?;
+
+ #[cfg(not(feature = "tree_index"))]
+ index_insert(&self.nodes, &leaf, index, id_provider, extensions).await?;
+
+ self.nodes.insert_leaf(index, leaf);
+ self.update_unmerged(index)?;
+
+ Ok(index)
+ }
+}
+
+#[cfg(feature = "tree_index")]
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn identity<I: IdentityProvider>(
+ signing_id: &SigningIdentity,
+ provider: &I,
+ extensions: &ExtensionList,
+) -> Result<Vec<u8>, MlsError> {
+ provider
+ .identity(signing_id, extensions)
+ .await
+ .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))
+}
+
+#[cfg(feature = "std")]
+impl Display for TreeKemPublic {
+ fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+ write!(f, "{}", tree_utils::build_ascii_tree(&self.nodes))
+ }
+}
+
+#[cfg(test)]
+use crate::group::{proposal::Proposal, proposal_filter::ProposalSource, Sender};
+
+#[cfg(test)]
+impl TreeKemPublic {
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn update_leaf<I, CP>(
+ &mut self,
+ leaf_index: u32,
+ leaf_node: LeafNode,
+ identity_provider: &I,
+ cipher_suite_provider: &CP,
+ ) -> Result<(), MlsError>
+ where
+ I: IdentityProvider,
+ CP: CipherSuiteProvider,
+ {
+ let p = Proposal::Update(UpdateProposal { leaf_node });
+
+ let mut bundle = ProposalBundle::default();
+ bundle.add(p, Sender::Member(leaf_index), ProposalSource::ByValue);
+ bundle.update_senders = vec![LeafIndex(leaf_index)];
+
+ self.batch_edit(
+ &mut bundle,
+ &Default::default(),
+ identity_provider,
+ cipher_suite_provider,
+ true,
+ )
+ .await?;
+
+ Ok(())
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn remove_leaves<I, CP>(
+ &mut self,
+ indexes: Vec<LeafIndex>,
+ identity_provider: &I,
+ cipher_suite_provider: &CP,
+ ) -> Result<Vec<(LeafIndex, LeafNode)>, MlsError>
+ where
+ I: IdentityProvider,
+ CP: CipherSuiteProvider,
+ {
+ let old_tree = self.clone();
+
+ let proposals = indexes
+ .iter()
+ .copied()
+ .map(|to_remove| Proposal::Remove(RemoveProposal { to_remove }));
+
+ let mut bundle = ProposalBundle::default();
+
+ for p in proposals {
+ bundle.add(p, Sender::Member(0), ProposalSource::ByValue);
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ self.batch_edit(
+ &mut bundle,
+ &Default::default(),
+ identity_provider,
+ cipher_suite_provider,
+ true,
+ )
+ .await?;
+
+ #[cfg(not(feature = "by_ref_proposal"))]
+ self.batch_edit_lite(
+ &bundle,
+ &Default::default(),
+ identity_provider,
+ cipher_suite_provider,
+ )
+ .await?;
+
+ bundle
+ .removals
+ .iter()
+ .map(|p| {
+ let index = p.proposal.to_remove;
+ let leaf = old_tree.get_leaf_node(index)?.clone();
+ Ok((index, leaf))
+ })
+ .collect()
+ }
+
+ pub fn get_leaf_nodes(&self) -> Vec<&LeafNode> {
+ self.nodes.non_empty_leaves().map(|(_, l)| l).collect()
+ }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+ use crate::crypto::test_utils::TestCryptoProvider;
+ use crate::signer::Signable;
+ use alloc::vec::Vec;
+ use alloc::{format, vec};
+ use mls_rs_core::crypto::CipherSuiteProvider;
+ use mls_rs_core::group::Capabilities;
+ use mls_rs_core::identity::BasicCredential;
+
+ use crate::identity::test_utils::get_test_signing_identity;
+ use crate::{
+ cipher_suite::CipherSuite,
+ crypto::{HpkeSecretKey, SignatureSecretKey},
+ identity::basic::BasicIdentityProvider,
+ tree_kem::leaf_node::test_utils::get_basic_test_node_sig_key,
+ };
+
+ use super::leaf_node::{ConfigProperties, LeafNodeSigningContext};
+ use super::node::LeafIndex;
+ use super::Lifetime;
+ use super::{
+ leaf_node::{test_utils::get_basic_test_node, LeafNode},
+ TreeKemPrivate, TreeKemPublic,
+ };
+
+ #[derive(Debug)]
+ pub(crate) struct TestTree {
+ pub public: TreeKemPublic,
+ pub private: TreeKemPrivate,
+ pub creator_leaf: LeafNode,
+ pub creator_signing_key: SignatureSecretKey,
+ pub creator_hpke_secret: HpkeSecretKey,
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn get_test_tree(cipher_suite: CipherSuite) -> TestTree {
+ let (creator_leaf, creator_hpke_secret, creator_signing_key) =
+ get_basic_test_node_sig_key(cipher_suite, "creator").await;
+
+ let (test_public, test_private) = TreeKemPublic::derive(
+ creator_leaf.clone(),
+ creator_hpke_secret.clone(),
+ &BasicIdentityProvider,
+ &Default::default(),
+ )
+ .await
+ .unwrap();
+
+ TestTree {
+ public: test_public,
+ private: test_private,
+ creator_leaf,
+ creator_signing_key,
+ creator_hpke_secret,
+ }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn get_test_leaf_nodes(cipher_suite: CipherSuite) -> Vec<LeafNode> {
+ [
+ get_basic_test_node(cipher_suite, "A").await,
+ get_basic_test_node(cipher_suite, "B").await,
+ get_basic_test_node(cipher_suite, "C").await,
+ ]
+ .to_vec()
+ }
+
+ impl TreeKemPublic {
+ #[cfg(feature = "tree_index")]
+ pub fn equal_internals(&self, other: &TreeKemPublic) -> bool {
+ self.tree_hashes == other.tree_hashes && self.index == other.index
+ }
+ }
+
+ #[derive(Debug, Clone)]
+ pub struct TreeWithSigners {
+ pub tree: TreeKemPublic,
+ pub signers: Vec<Option<SignatureSecretKey>>,
+ pub group_id: Vec<u8>,
+ }
+
+ impl TreeWithSigners {
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn make_full_tree<P: CipherSuiteProvider>(
+ n_leaves: u32,
+ cs: &P,
+ ) -> TreeWithSigners {
+ let mut tree = TreeWithSigners {
+ tree: TreeKemPublic::new(),
+ signers: vec![],
+ group_id: cs.random_bytes_vec(cs.kdf_extract_size()).unwrap(),
+ };
+
+ tree.add_member("Alice", cs).await;
+
+ // A adds B, B adds C, C adds D etc.
+ for i in 1..n_leaves {
+ tree.add_member(&format!("Alice{i}"), cs).await;
+ tree.update_committer_path(i - 1, cs).await;
+ }
+
+ tree
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn add_member<P: CipherSuiteProvider>(&mut self, name: &str, cs: &P) {
+ let (leaf, signer) = make_leaf(name, cs).await;
+ let index = self.tree.nodes.next_empty_leaf(LeafIndex(0));
+ self.tree.nodes.insert_leaf(index, leaf);
+ self.tree.update_unmerged(index).unwrap();
+ let index = *index as usize;
+
+ match self.signers.len() {
+ l if l == index => self.signers.push(Some(signer)),
+ l if l > index => self.signers[index] = Some(signer),
+ _ => panic!("signer tree size mismatch"),
+ }
+ }
+
+ #[cfg(feature = "rfc_compliant")]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ pub fn remove_member(&mut self, member: u32) {
+ self.tree
+ .nodes
+ .blank_direct_path(LeafIndex(member))
+ .unwrap();
+
+ self.tree.nodes.blank_leaf_node(LeafIndex(member)).unwrap();
+
+ *self
+ .signers
+ .get_mut(member as usize)
+ .expect("signer tree size mismatch") = None;
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn update_committer_path<P: CipherSuiteProvider>(
+ &mut self,
+ committer: u32,
+ cs: &P,
+ ) {
+ let committer = LeafIndex(committer);
+
+ let path = self.tree.nodes.direct_copath(committer);
+ let filtered = self.tree.nodes.filtered(committer).unwrap();
+
+ for (n, f) in path.into_iter().zip(filtered) {
+ if !f {
+ self.tree
+ .update_node(cs.kem_generate().await.unwrap().1, n.path)
+ .unwrap();
+ }
+ }
+
+ self.tree.tree_hashes.current = vec![];
+ self.tree.tree_hash(cs).await.unwrap();
+
+ self.tree
+ .update_parent_hashes(committer, false, cs)
+ .await
+ .unwrap();
+
+ self.tree.tree_hashes.current = vec![];
+ self.tree.tree_hash(cs).await.unwrap();
+
+ let context = LeafNodeSigningContext {
+ group_id: Some(&self.group_id),
+ leaf_index: Some(*committer),
+ };
+
+ let signer = self.signers[*committer as usize].as_ref().unwrap();
+
+ self.tree
+ .nodes
+ .borrow_as_leaf_mut(committer)
+ .unwrap()
+ .sign(cs, signer, &context)
+ .await
+ .unwrap();
+
+ self.tree.tree_hashes.current = vec![];
+ self.tree.tree_hash(cs).await.unwrap();
+ }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn make_leaf<P: CipherSuiteProvider>(
+ name: &str,
+ cs: &P,
+ ) -> (LeafNode, SignatureSecretKey) {
+ let (signing_identity, signature_key) =
+ get_test_signing_identity(cs.cipher_suite(), name.as_bytes()).await;
+
+ let capabilities = Capabilities {
+ credentials: vec![BasicCredential::credential_type()],
+ cipher_suites: TestCryptoProvider::all_supported_cipher_suites(),
+ ..Default::default()
+ };
+
+ let properties = ConfigProperties {
+ capabilities,
+ extensions: Default::default(),
+ };
+
+ let (leaf, _) = LeafNode::generate(
+ cs,
+ properties,
+ signing_identity,
+ &signature_key,
+ Lifetime::years(1).unwrap(),
+ )
+ .await
+ .unwrap();
+
+ (leaf, signature_key)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::client::test_utils::TEST_CIPHER_SUITE;
+ use crate::crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider};
+
+ #[cfg(feature = "custom_proposal")]
+ use crate::group::proposal::ProposalType;
+
+ use crate::identity::basic::BasicIdentityProvider;
+ use crate::tree_kem::leaf_node::LeafNode;
+ use crate::tree_kem::node::{LeafIndex, Node, NodeIndex, NodeTypeResolver, Parent};
+ use crate::tree_kem::parent_hash::ParentHash;
+ use crate::tree_kem::test_utils::{get_test_leaf_nodes, get_test_tree};
+ use crate::tree_kem::{MlsError, TreeKemPublic};
+ use alloc::borrow::ToOwned;
+ use alloc::vec;
+ use alloc::vec::Vec;
+ use assert_matches::assert_matches;
+
+ #[cfg(feature = "by_ref_proposal")]
+ use alloc::boxed::Box;
+
+ #[cfg(feature = "by_ref_proposal")]
+ use crate::{
+ client::test_utils::TEST_PROTOCOL_VERSION,
+ group::{
+ proposal::{Proposal, RemoveProposal, UpdateProposal},
+ proposal_filter::{ProposalBundle, ProposalSource},
+ proposal_ref::ProposalRef,
+ Sender,
+ },
+ key_package::test_utils::test_key_package,
+ };
+
+ #[cfg(any(feature = "by_ref_proposal", feature = "custo_proposal"))]
+ use crate::tree_kem::leaf_node::test_utils::get_basic_test_node;
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_derive() {
+ for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+ let test_tree = get_test_tree(cipher_suite).await;
+
+ assert_eq!(
+ test_tree.public.nodes[0],
+ Some(Node::Leaf(test_tree.creator_leaf.clone()))
+ );
+
+ assert_eq!(test_tree.private.self_index, LeafIndex(0));
+
+ assert_eq!(
+ test_tree.private.secret_keys[0],
+ Some(test_tree.creator_hpke_secret)
+ );
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_import_export() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+ let mut test_tree = get_test_tree(TEST_CIPHER_SUITE).await;
+
+ let additional_key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+ test_tree
+ .public
+ .add_leaves(
+ additional_key_packages,
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap();
+
+ let imported = TreeKemPublic::import_node_data(
+ test_tree.public.nodes.clone(),
+ &BasicIdentityProvider,
+ &Default::default(),
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(test_tree.public.nodes, imported.nodes);
+
+ #[cfg(feature = "tree_index")]
+ assert_eq!(test_tree.public.index, imported.index);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_add_leaf() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+ let mut tree = TreeKemPublic::new();
+
+ let leaf_nodes = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+ let res = tree
+ .add_leaves(
+ leaf_nodes.clone(),
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap();
+
+ // The leaf count should be equal to the number of packages we added
+ assert_eq!(res.len(), leaf_nodes.len());
+ assert_eq!(tree.occupied_leaf_count(), leaf_nodes.len() as u32);
+
+ // Each added package should be at the proper index and searchable in the tree
+ res.into_iter().zip(leaf_nodes.clone()).for_each(|(r, kp)| {
+ assert_eq!(tree.get_leaf_node(r).unwrap(), &kp);
+ });
+
+ // Verify the underlying state
+ #[cfg(feature = "tree_index")]
+ assert_eq!(tree.index.len(), tree.occupied_leaf_count() as usize);
+
+ assert_eq!(tree.nodes.len(), 5);
+ assert_eq!(tree.nodes[0], leaf_nodes[0].clone().into());
+ assert_eq!(tree.nodes[1], None);
+ assert_eq!(tree.nodes[2], leaf_nodes[1].clone().into());
+ assert_eq!(tree.nodes[3], None);
+ assert_eq!(tree.nodes[4], leaf_nodes[2].clone().into());
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_get_key_packages() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+ let mut tree = TreeKemPublic::new();
+
+ let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+ tree.add_leaves(key_packages, &BasicIdentityProvider, &cipher_suite_provider)
+ .await
+ .unwrap();
+
+ let key_packages = tree.get_leaf_nodes();
+ assert_eq!(key_packages, key_packages.to_owned());
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_add_leaf_duplicate() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+ let mut tree = TreeKemPublic::new();
+
+ let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+ tree.add_leaves(
+ key_packages.clone(),
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap();
+
+ let res = tree
+ .add_leaves(key_packages, &BasicIdentityProvider, &cipher_suite_provider)
+ .await;
+
+ assert_matches!(res, Err(MlsError::DuplicateLeafData(_)));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_add_leaf_empty_leaf() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+ let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
+ let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+ tree.add_leaves(
+ [key_packages[0].clone()].to_vec(),
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap();
+
+ tree.nodes[0] = None; // Set the original first node to none
+ //
+ tree.add_leaves(
+ [key_packages[1].clone()].to_vec(),
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(tree.nodes[0], key_packages[1].clone().into());
+ assert_eq!(tree.nodes[1], None);
+ assert_eq!(tree.nodes[2], key_packages[0].clone().into());
+ assert_eq!(tree.nodes.len(), 3)
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_add_leaf_unmerged() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+ let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
+ let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+ tree.add_leaves(
+ [key_packages[0].clone(), key_packages[1].clone()].to_vec(),
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap();
+
+ tree.nodes[3] = Parent {
+ public_key: vec![].into(),
+ parent_hash: ParentHash::empty(),
+ unmerged_leaves: vec![],
+ }
+ .into();
+
+ tree.add_leaves(
+ [key_packages[2].clone()].to_vec(),
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(
+ tree.nodes[3].as_parent().unwrap().unmerged_leaves,
+ vec![LeafIndex(3)]
+ )
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_update_leaf() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+ // Create a tree
+ let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
+
+ let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+ tree.add_leaves(key_packages, &BasicIdentityProvider, &cipher_suite_provider)
+ .await
+ .unwrap();
+
+ // Add in parent nodes so we can detect them clearing after update
+ tree.nodes.direct_copath(LeafIndex(0)).iter().for_each(|n| {
+ tree.nodes
+ .borrow_or_fill_node_as_parent(n.path, &b"pub_key".to_vec().into())
+ .unwrap();
+ });
+
+ let original_size = tree.occupied_leaf_count();
+ let original_leaf_index = LeafIndex(1);
+
+ let updated_leaf = get_basic_test_node(TEST_CIPHER_SUITE, "A").await;
+
+ tree.update_leaf(
+ *original_leaf_index,
+ updated_leaf.clone(),
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap();
+
+ // The tree should not have grown due to an update
+ assert_eq!(tree.occupied_leaf_count(), original_size);
+
+ // The cache of tree package indexes should not have grown
+ #[cfg(feature = "tree_index")]
+ assert_eq!(tree.index.len() as u32, tree.occupied_leaf_count());
+
+ // The key package should be updated in the tree
+ assert_eq!(
+ tree.get_leaf_node(original_leaf_index).unwrap(),
+ &updated_leaf
+ );
+
+ // Verify that the direct path has been cleared
+ tree.nodes.direct_copath(LeafIndex(0)).iter().for_each(|n| {
+ assert!(tree.nodes[n.path as usize].is_none());
+ });
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_update_leaf_not_found() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ // Create a tree
+ let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
+
+ let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+ tree.add_leaves(key_packages, &BasicIdentityProvider, &cipher_suite_provider)
+ .await
+ .unwrap();
+
+ let new_key_package = get_basic_test_node(TEST_CIPHER_SUITE, "new").await;
+
+ let res = tree
+ .update_leaf(
+ 128,
+ new_key_package,
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ )
+ .await;
+
+ assert_matches!(res, Err(MlsError::UpdatingNonExistingMember));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_remove_leaf() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ // Create a tree
+ let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
+ let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+ let indexes = tree
+ .add_leaves(
+ key_packages.clone(),
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap();
+
+ let original_leaf_count = tree.occupied_leaf_count();
+
+ // Remove two leaves from the tree
+ let expected_result: Vec<(LeafIndex, LeafNode)> =
+ indexes.clone().into_iter().zip(key_packages).collect();
+
+ let res = tree
+ .remove_leaves(
+ indexes.clone(),
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap();
+
+ // The order may change
+ assert!(res.iter().all(|x| expected_result.contains(x)));
+ assert!(expected_result.iter().all(|x| res.contains(x)));
+
+ // The leaves should be removed from the tree
+ assert_eq!(
+ tree.occupied_leaf_count(),
+ original_leaf_count - indexes.len() as u32
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_remove_leaf_middle() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ // Create a tree
+ let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
+ let leaf_nodes = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+ let to_remove = tree
+ .add_leaves(
+ leaf_nodes.clone(),
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap()[0];
+
+ let original_leaf_count = tree.occupied_leaf_count();
+
+ let res = tree
+ .remove_leaves(
+ vec![to_remove],
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(res, vec![(to_remove, leaf_nodes[0].clone())]);
+
+ // The leaf count should have been reduced by 1
+ assert_eq!(tree.occupied_leaf_count(), original_leaf_count - 1);
+
+ // There should be a blank in the tree
+ assert_eq!(
+ tree.nodes.get(NodeIndex::from(to_remove) as usize).unwrap(),
+ &None
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_create_blanks() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ // Create a tree
+ let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
+
+ let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+ tree.add_leaves(key_packages, &BasicIdentityProvider, &cipher_suite_provider)
+ .await
+ .unwrap();
+
+ let original_leaf_count = tree.occupied_leaf_count();
+
+ let to_remove = vec![LeafIndex(2)];
+
+ // Remove the leaf from the tree
+ tree.remove_leaves(to_remove, &BasicIdentityProvider, &cipher_suite_provider)
+ .await
+ .unwrap();
+
+ // The occupied leaf count should have been reduced by 1
+ assert_eq!(tree.occupied_leaf_count(), original_leaf_count - 1);
+
+ // The total leaf count should remain unchanged
+ assert_eq!(tree.total_leaf_count(), original_leaf_count);
+
+ // The location of key_packages[1] should now be blank
+ let removed_location = tree
+ .nodes
+ .get(NodeIndex::from(LeafIndex(2)) as usize)
+ .unwrap();
+
+ assert_eq!(removed_location, &None);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_remove_leaf_failure() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ // Create a tree
+ let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
+
+ let res = tree
+ .remove_leaves(
+ vec![LeafIndex(128)],
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ )
+ .await;
+
+ assert_matches!(res, Err(MlsError::InvalidNodeIndex(256)));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_find_leaf_node() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+ // Create a tree
+ let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
+
+ let leaf_nodes = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+ tree.add_leaves(
+ leaf_nodes.clone(),
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap();
+
+ // Find each node
+ for (i, leaf_node) in leaf_nodes.iter().enumerate() {
+ let expected_index = LeafIndex(i as u32 + 1);
+ assert_eq!(tree.find_leaf_node(leaf_node), Some(expected_index));
+ }
+ }
+
+ // TODO add test for the lite version
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn batch_edit_works() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
+ let leaf_nodes = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+ tree.add_leaves(leaf_nodes, &BasicIdentityProvider, &cipher_suite_provider)
+ .await
+ .unwrap();
+
+ let mut bundle = ProposalBundle::default();
+
+ let kp = test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "D").await;
+ let add = Proposal::Add(Box::new(kp.into()));
+
+ bundle.add(add, Sender::Member(0), ProposalSource::ByValue);
+
+ let update = UpdateProposal {
+ leaf_node: get_basic_test_node(TEST_CIPHER_SUITE, "A").await,
+ };
+
+ let update = Proposal::Update(update);
+ let pref = ProposalRef::new_fake(vec![1, 2, 3]);
+
+ bundle.add(update, Sender::Member(1), ProposalSource::ByReference(pref));
+
+ bundle.update_senders = vec![LeafIndex(1)];
+
+ let remove = RemoveProposal {
+ to_remove: LeafIndex(2),
+ };
+
+ let remove = Proposal::Remove(remove);
+
+ bundle.add(remove, Sender::Member(0), ProposalSource::ByValue);
+
+ tree.batch_edit(
+ &mut bundle,
+ &Default::default(),
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ true,
+ )
+ .await
+ .unwrap();
+
+ assert_eq!(bundle.add_proposals().len(), 1);
+ assert_eq!(bundle.remove_proposals().len(), 1);
+ assert_eq!(bundle.update_proposals().len(), 1);
+ }
+
+ #[cfg(feature = "custom_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn custom_proposal_support() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+ let mut tree = TreeKemPublic::new();
+
+ let test_proposal_type = ProposalType::from(42);
+
+ let mut leaf_nodes = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+ leaf_nodes
+ .iter_mut()
+ .for_each(|n| n.capabilities.proposals.push(test_proposal_type));
+
+ tree.add_leaves(leaf_nodes, &BasicIdentityProvider, &cipher_suite_provider)
+ .await
+ .unwrap();
+
+ assert!(tree.can_support_proposal(test_proposal_type));
+ assert!(!tree.can_support_proposal(ProposalType::from(43)));
+
+ let test_node = get_basic_test_node(TEST_CIPHER_SUITE, "another").await;
+
+ tree.add_leaves(
+ vec![test_node],
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap();
+
+ assert!(!tree.can_support_proposal(test_proposal_type));
+ }
+}
diff --git a/src/tree_kem/node.rs b/src/tree_kem/node.rs
new file mode 100644
index 0000000..8b7372f
--- /dev/null
+++ b/src/tree_kem/node.rs
@@ -0,0 +1,577 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use super::leaf_node::LeafNode;
+use crate::client::MlsError;
+use crate::crypto::HpkePublicKey;
+use crate::tree_kem::math as tree_math;
+use crate::tree_kem::parent_hash::ParentHash;
+use alloc::vec;
+use alloc::vec::Vec;
+use core::hash::Hash;
+use core::ops::{Deref, DerefMut};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use tree_math::{CopathNode, TreeIndex};
+
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct Parent {
+ pub public_key: HpkePublicKey,
+ pub parent_hash: ParentHash,
+ pub unmerged_leaves: Vec<LeafIndex>,
+}
+
+#[derive(
+ Clone, Copy, Debug, Ord, PartialEq, PartialOrd, Hash, Eq, MlsSize, MlsEncode, MlsDecode,
+)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct LeafIndex(pub(crate) u32);
+
+impl LeafIndex {
+ pub fn new(i: u32) -> Self {
+ Self(i)
+ }
+}
+
+impl Deref for LeafIndex {
+ type Target = u32;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl From<&LeafIndex> for NodeIndex {
+ fn from(leaf_index: &LeafIndex) -> Self {
+ leaf_index.0 * 2
+ }
+}
+
+impl From<LeafIndex> for NodeIndex {
+ fn from(leaf_index: LeafIndex) -> Self {
+ leaf_index.0 * 2
+ }
+}
+
+pub(crate) type NodeIndex = u32;
+
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[allow(clippy::large_enum_variant)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[repr(u8)]
+//TODO: Research if this should actually be a Box<Leaf> for memory / performance reasons
+pub(crate) enum Node {
+ Leaf(LeafNode) = 1u8,
+ Parent(Parent) = 2u8,
+}
+
+impl Node {
+ pub fn public_key(&self) -> &HpkePublicKey {
+ match self {
+ Node::Parent(p) => &p.public_key,
+ Node::Leaf(l) => &l.public_key,
+ }
+ }
+}
+
+impl From<Parent> for Option<Node> {
+ fn from(p: Parent) -> Self {
+ Node::from(p).into()
+ }
+}
+
+impl From<LeafNode> for Option<Node> {
+ fn from(l: LeafNode) -> Self {
+ Node::from(l).into()
+ }
+}
+
+impl From<Parent> for Node {
+ fn from(p: Parent) -> Self {
+ Node::Parent(p)
+ }
+}
+
+impl From<LeafNode> for Node {
+ fn from(l: LeafNode) -> Self {
+ Node::Leaf(l)
+ }
+}
+
+pub(crate) trait NodeTypeResolver {
+ fn as_parent(&self) -> Result<&Parent, MlsError>;
+ fn as_parent_mut(&mut self) -> Result<&mut Parent, MlsError>;
+ fn as_leaf(&self) -> Result<&LeafNode, MlsError>;
+ fn as_leaf_mut(&mut self) -> Result<&mut LeafNode, MlsError>;
+ fn as_non_empty(&self) -> Result<&Node, MlsError>;
+}
+
+impl NodeTypeResolver for Option<Node> {
+ fn as_parent(&self) -> Result<&Parent, MlsError> {
+ self.as_ref()
+ .and_then(|n| match n {
+ Node::Parent(p) => Some(p),
+ Node::Leaf(_) => None,
+ })
+ .ok_or(MlsError::ExpectedNode)
+ }
+
+ fn as_parent_mut(&mut self) -> Result<&mut Parent, MlsError> {
+ self.as_mut()
+ .and_then(|n| match n {
+ Node::Parent(p) => Some(p),
+ Node::Leaf(_) => None,
+ })
+ .ok_or(MlsError::ExpectedNode)
+ }
+
+ fn as_leaf(&self) -> Result<&LeafNode, MlsError> {
+ self.as_ref()
+ .and_then(|n| match n {
+ Node::Parent(_) => None,
+ Node::Leaf(l) => Some(l),
+ })
+ .ok_or(MlsError::ExpectedNode)
+ }
+
+ fn as_leaf_mut(&mut self) -> Result<&mut LeafNode, MlsError> {
+ self.as_mut()
+ .and_then(|n| match n {
+ Node::Parent(_) => None,
+ Node::Leaf(l) => Some(l),
+ })
+ .ok_or(MlsError::ExpectedNode)
+ }
+
+ fn as_non_empty(&self) -> Result<&Node, MlsError> {
+ self.as_ref().ok_or(MlsError::UnexpectedEmptyNode)
+ }
+}
+
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode, Default)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct NodeVec(Vec<Option<Node>>);
+
+impl From<Vec<Option<Node>>> for NodeVec {
+ fn from(x: Vec<Option<Node>>) -> Self {
+ NodeVec(x)
+ }
+}
+
+impl Deref for NodeVec {
+ type Target = Vec<Option<Node>>;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl DerefMut for NodeVec {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut self.0
+ }
+}
+
+impl NodeVec {
+ #[cfg(any(test, all(feature = "custom_proposal", feature = "tree_index")))]
+ pub fn occupied_leaf_count(&self) -> u32 {
+ self.non_empty_leaves().count() as u32
+ }
+
+ pub fn total_leaf_count(&self) -> u32 {
+ (self.len() as u32 / 2 + 1).next_power_of_two()
+ }
+
+ #[inline]
+ pub fn borrow_node(&self, index: NodeIndex) -> Result<&Option<Node>, MlsError> {
+ Ok(self.get(self.validate_index(index)?).unwrap_or(&None))
+ }
+
+ fn validate_index(&self, index: NodeIndex) -> Result<usize, MlsError> {
+ if (index as usize) >= self.len().next_power_of_two() {
+ Err(MlsError::InvalidNodeIndex(index))
+ } else {
+ Ok(index as usize)
+ }
+ }
+
+ #[cfg(test)]
+ fn empty_leaves(&mut self) -> impl Iterator<Item = (LeafIndex, &mut Option<Node>)> {
+ self.iter_mut()
+ .step_by(2)
+ .enumerate()
+ .filter(|(_, n)| n.is_none())
+ .map(|(i, n)| (LeafIndex(i as u32), n))
+ }
+
+ pub fn non_empty_leaves(&self) -> impl Iterator<Item = (LeafIndex, &LeafNode)> + '_ {
+ self.leaves()
+ .enumerate()
+ .filter_map(|(i, l)| l.map(|l| (LeafIndex(i as u32), l)))
+ }
+
+ pub fn non_empty_parents(&self) -> impl Iterator<Item = (NodeIndex, &Parent)> + '_ {
+ self.iter()
+ .enumerate()
+ .skip(1)
+ .step_by(2)
+ .map(|(i, n)| (i as NodeIndex, n))
+ .filter_map(|(i, n)| n.as_parent().ok().map(|p| (i, p)))
+ }
+
+ pub fn leaves(&self) -> impl Iterator<Item = Option<&LeafNode>> + '_ {
+ self.iter().step_by(2).map(|n| n.as_leaf().ok())
+ }
+
+ pub fn direct_copath(&self, index: LeafIndex) -> Vec<CopathNode<NodeIndex>> {
+ NodeIndex::from(index).direct_copath(&self.total_leaf_count())
+ }
+
+ // Section 8.4
+ // The filtered direct path of a node is obtained from the node's direct path by removing
+ // all nodes whose child on the nodes's copath has an empty resolution
+ pub fn filtered(&self, index: LeafIndex) -> Result<Vec<bool>, MlsError> {
+ Ok(NodeIndex::from(index)
+ .direct_copath(&self.total_leaf_count())
+ .into_iter()
+ .map(|cp| self.is_resolution_empty(cp.copath))
+ .collect())
+ }
+
+ #[inline]
+ pub fn is_blank(&self, index: NodeIndex) -> Result<bool, MlsError> {
+ self.borrow_node(index).map(|n| n.is_none())
+ }
+
+ #[inline]
+ pub fn is_leaf(&self, index: NodeIndex) -> bool {
+ index % 2 == 0
+ }
+
+ // Blank a previously filled leaf node, and return the existing leaf
+ pub fn blank_leaf_node(&mut self, leaf_index: LeafIndex) -> Result<LeafNode, MlsError> {
+ let node_index = self.validate_index(leaf_index.into())?;
+
+ match self.get_mut(node_index).and_then(Option::take) {
+ Some(Node::Leaf(l)) => Ok(l),
+ _ => Err(MlsError::RemovingNonExistingMember),
+ }
+ }
+
+ pub fn blank_direct_path(&mut self, leaf: LeafIndex) -> Result<(), MlsError> {
+ for i in self.direct_copath(leaf) {
+ if let Some(n) = self.get_mut(i.path as usize) {
+ *n = None
+ }
+ }
+
+ Ok(())
+ }
+
+ // Remove elements until the last node is non-blank
+ pub fn trim(&mut self) {
+ while self.last() == Some(&None) {
+ self.pop();
+ }
+ }
+
+ pub fn borrow_as_parent(&self, node_index: NodeIndex) -> Result<&Parent, MlsError> {
+ self.borrow_node(node_index).and_then(|n| n.as_parent())
+ }
+
+ pub fn borrow_as_parent_mut(&mut self, node_index: NodeIndex) -> Result<&mut Parent, MlsError> {
+ let index = self.validate_index(node_index)?;
+
+ self.get_mut(index)
+ .ok_or(MlsError::InvalidNodeIndex(node_index))?
+ .as_parent_mut()
+ }
+
+ pub fn borrow_as_leaf_mut(&mut self, index: LeafIndex) -> Result<&mut LeafNode, MlsError> {
+ let node_index = NodeIndex::from(index);
+ let index = self.validate_index(node_index)?;
+
+ self.get_mut(index)
+ .ok_or(MlsError::InvalidNodeIndex(node_index))?
+ .as_leaf_mut()
+ }
+
+ pub fn borrow_as_leaf(&self, index: LeafIndex) -> Result<&LeafNode, MlsError> {
+ let node_index = NodeIndex::from(index);
+ self.borrow_node(node_index).and_then(|n| n.as_leaf())
+ }
+
+ pub fn borrow_or_fill_node_as_parent(
+ &mut self,
+ node_index: NodeIndex,
+ public_key: &HpkePublicKey,
+ ) -> Result<&mut Parent, MlsError> {
+ let index = self.validate_index(node_index)?;
+
+ while self.len() <= index {
+ self.push(None);
+ }
+
+ self.get_mut(index)
+ .ok_or(MlsError::InvalidNodeIndex(node_index))
+ .and_then(|n| {
+ if n.is_none() {
+ *n = Parent {
+ public_key: public_key.clone(),
+ parent_hash: ParentHash::empty(),
+ unmerged_leaves: vec![],
+ }
+ .into();
+ }
+ n.as_parent_mut()
+ })
+ }
+
+ pub fn get_resolution_index(&self, index: NodeIndex) -> Result<Vec<NodeIndex>, MlsError> {
+ let mut indexes = vec![index];
+ let mut resolution = vec![];
+
+ while let Some(index) = indexes.pop() {
+ if let Some(Some(node)) = self.get(index as usize) {
+ resolution.push(index);
+
+ if let Node::Parent(p) = node {
+ resolution.extend(p.unmerged_leaves.iter().map(NodeIndex::from));
+ }
+ } else if !index.is_leaf() {
+ indexes.push(index.right_unchecked());
+ indexes.push(index.left_unchecked());
+ }
+ }
+
+ Ok(resolution)
+ }
+
+ pub fn find_in_resolution(
+ &self,
+ index: NodeIndex,
+ to_find: Option<NodeIndex>,
+ ) -> Option<usize> {
+ let mut indexes = vec![index];
+ let mut resolution_len = 0;
+
+ while let Some(index) = indexes.pop() {
+ if let Some(Some(node)) = self.get(index as usize) {
+ if Some(index) == to_find || to_find.is_none() {
+ return Some(resolution_len);
+ }
+
+ resolution_len += 1;
+
+ if let Node::Parent(p) = node {
+ indexes.extend(p.unmerged_leaves.iter().map(NodeIndex::from));
+ }
+ } else if !index.is_leaf() {
+ indexes.push(index.right_unchecked());
+ indexes.push(index.left_unchecked());
+ }
+ }
+
+ None
+ }
+
+ pub fn is_resolution_empty(&self, index: NodeIndex) -> bool {
+ self.find_in_resolution(index, None).is_none()
+ }
+
+ pub(crate) fn next_empty_leaf(&self, start: LeafIndex) -> LeafIndex {
+ let mut n = NodeIndex::from(start) as usize;
+
+ while n < self.len() {
+ if self.0[n].is_none() {
+ return LeafIndex((n as u32) >> 1);
+ }
+
+ n += 2;
+ }
+
+ LeafIndex((self.len() as u32 + 1) >> 1)
+ }
+
+ /// If `index` fits in the current tree, inserts `leaf` at `index`. Else, inserts `leaf` as the
+ /// last leaf
+ pub fn insert_leaf(&mut self, index: LeafIndex, leaf: LeafNode) {
+ let node_index = (*index as usize) << 1;
+
+ if node_index > self.len() {
+ self.push(None);
+ self.push(None);
+ } else if self.is_empty() {
+ self.push(None);
+ }
+
+ self.0[node_index] = Some(leaf.into());
+ }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+ use super::*;
+ use crate::{
+ client::test_utils::TEST_CIPHER_SUITE, tree_kem::leaf_node::test_utils::get_basic_test_node,
+ };
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn get_test_node_vec() -> NodeVec {
+ let mut nodes = vec![None; 7];
+
+ nodes[0] = get_basic_test_node(TEST_CIPHER_SUITE, "A").await.into();
+ nodes[4] = get_basic_test_node(TEST_CIPHER_SUITE, "C").await.into();
+
+ nodes[5] = Parent {
+ public_key: b"CD".to_vec().into(),
+ parent_hash: ParentHash::empty(),
+ unmerged_leaves: vec![LeafIndex(2)],
+ }
+ .into();
+
+ nodes[6] = get_basic_test_node(TEST_CIPHER_SUITE, "D").await.into();
+
+ NodeVec::from(nodes)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::{
+ client::test_utils::TEST_CIPHER_SUITE,
+ tree_kem::{
+ leaf_node::test_utils::get_basic_test_node, node::test_utils::get_test_node_vec,
+ },
+ };
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn node_key_getters() {
+ let test_node_parent: Node = Parent {
+ public_key: b"pub".to_vec().into(),
+ parent_hash: ParentHash::empty(),
+ unmerged_leaves: vec![],
+ }
+ .into();
+
+ let test_leaf = get_basic_test_node(TEST_CIPHER_SUITE, "B").await;
+ let test_node_leaf: Node = test_leaf.clone().into();
+
+ assert_eq!(test_node_parent.public_key().as_ref(), b"pub");
+ assert_eq!(test_node_leaf.public_key(), &test_leaf.public_key);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_empty_leaves() {
+ let mut test_vec = get_test_node_vec().await;
+ let mut test_vec_clone = get_test_node_vec().await;
+ let empty_leaves: Vec<(LeafIndex, &mut Option<Node>)> = test_vec.empty_leaves().collect();
+ assert_eq!(
+ [(LeafIndex(1), &mut test_vec_clone[2])].as_ref(),
+ empty_leaves.as_slice()
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_direct_path() {
+ let test_vec = get_test_node_vec().await;
+ // Tree math is already tested in that module, just ensure equality
+ let expected = 0.direct_copath(&4);
+ let actual = test_vec.direct_copath(LeafIndex(0));
+ assert_eq!(actual, expected);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_filtered_direct_path_co_path() {
+ let test_vec = get_test_node_vec().await;
+ let expected = [true, false];
+ let actual = test_vec.filtered(LeafIndex(0)).unwrap();
+ assert_eq!(actual, expected);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_get_parent_node() {
+ let mut test_vec = get_test_node_vec().await;
+
+ // If the node is a leaf it should fail
+ assert!(test_vec.borrow_as_parent_mut(0).is_err());
+
+ // If the node index is out of range it should fail
+ assert!(test_vec
+ .borrow_as_parent_mut(test_vec.len() as u32)
+ .is_err());
+
+ // Otherwise it should succeed
+ let mut expected = Parent {
+ public_key: b"CD".to_vec().into(),
+ parent_hash: ParentHash::empty(),
+ unmerged_leaves: vec![LeafIndex(2)],
+ };
+
+ assert_eq!(test_vec.borrow_as_parent_mut(5).unwrap(), &mut expected);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_get_resolution() {
+ let test_vec = get_test_node_vec().await;
+
+ let resolution_node_5 = test_vec.get_resolution_index(5).unwrap();
+ let resolution_node_2 = test_vec.get_resolution_index(2).unwrap();
+ let resolution_node_3 = test_vec.get_resolution_index(3).unwrap();
+
+ assert_eq!(&resolution_node_5, &[5, 4]);
+ assert!(resolution_node_2.is_empty());
+ assert_eq!(&resolution_node_3, &[0, 5, 4]);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_get_or_fill_existing() {
+ let mut test_vec = get_test_node_vec().await;
+ let mut test_vec2 = test_vec.clone();
+
+ let expected = test_vec[5].as_parent_mut().unwrap();
+ let actual = test_vec2
+ .borrow_or_fill_node_as_parent(5, &Vec::new().into())
+ .unwrap();
+
+ assert_eq!(actual, expected);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_get_or_fill_empty() {
+ let mut test_vec = get_test_node_vec().await;
+
+ let mut expected = Parent {
+ public_key: vec![0u8; 4].into(),
+ parent_hash: ParentHash::empty(),
+ unmerged_leaves: vec![],
+ };
+
+ let actual = test_vec
+ .borrow_or_fill_node_as_parent(1, &vec![0u8; 4].into())
+ .unwrap();
+
+ assert_eq!(actual, &mut expected);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_leaf_count() {
+ let test_vec = get_test_node_vec().await;
+ assert_eq!(test_vec.len(), 7);
+ assert_eq!(test_vec.occupied_leaf_count(), 3);
+ assert_eq!(
+ test_vec.non_empty_leaves().count(),
+ test_vec.occupied_leaf_count() as usize
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_total_leaf_count() {
+ let test_vec = get_test_node_vec().await;
+ assert_eq!(test_vec.occupied_leaf_count(), 3);
+ assert_eq!(test_vec.total_leaf_count(), 4);
+ }
+}
diff --git a/src/tree_kem/parent_hash.rs b/src/tree_kem/parent_hash.rs
new file mode 100644
index 0000000..f04157a
--- /dev/null
+++ b/src/tree_kem/parent_hash.rs
@@ -0,0 +1,431 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::client::MlsError;
+use crate::crypto::{CipherSuiteProvider, HpkePublicKey};
+use crate::tree_kem::math as tree_math;
+use crate::tree_kem::node::{LeafIndex, Node, NodeIndex};
+use crate::tree_kem::TreeKemPublic;
+use alloc::vec::Vec;
+use core::{
+ fmt::{self, Debug},
+ ops::Deref,
+};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::error::IntoAnyError;
+use tree_math::TreeIndex;
+
+use super::leaf_node::LeafNodeSource;
+
+#[cfg(feature = "std")]
+use std::collections::HashSet;
+
+#[cfg(not(feature = "std"))]
+use alloc::collections::BTreeSet;
+
+#[derive(Clone, Debug, MlsSize, MlsEncode)]
+struct ParentHashInput<'a> {
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ public_key: &'a HpkePublicKey,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ parent_hash: &'a [u8],
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ original_sibling_tree_hash: &'a [u8],
+}
+
+#[derive(Clone, MlsSize, MlsEncode, MlsDecode, PartialEq, Eq)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct ParentHash(
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ Vec<u8>,
+);
+
+impl Debug for ParentHash {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ mls_rs_core::debug::pretty_bytes(&self.0)
+ .named("ParentHash")
+ .fmt(f)
+ }
+}
+
+impl From<Vec<u8>> for ParentHash {
+ fn from(v: Vec<u8>) -> Self {
+ Self(v)
+ }
+}
+
+impl Deref for ParentHash {
+ type Target = Vec<u8>;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl ParentHash {
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn new<P: CipherSuiteProvider>(
+ cipher_suite_provider: &P,
+ public_key: &HpkePublicKey,
+ parent_hash: &ParentHash,
+ original_sibling_tree_hash: &[u8],
+ ) -> Result<Self, MlsError> {
+ let input = ParentHashInput {
+ public_key,
+ parent_hash,
+ original_sibling_tree_hash,
+ };
+
+ let input_bytes = input.mls_encode_to_vec()?;
+
+ let hash = cipher_suite_provider
+ .hash(&input_bytes)
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
+
+ Ok(Self(hash))
+ }
+
+ pub fn empty() -> Self {
+ ParentHash(Vec::new())
+ }
+
+ pub fn matches(&self, hash: &ParentHash) -> bool {
+ //TODO: Constant time equals
+ hash == self
+ }
+}
+
+impl Node {
+ fn get_parent_hash(&self) -> Option<ParentHash> {
+ match self {
+ Node::Parent(p) => Some(p.parent_hash.clone()),
+ Node::Leaf(l) => match &l.leaf_node_source {
+ LeafNodeSource::Commit(parent_hash) => Some(parent_hash.clone()),
+ _ => None,
+ },
+ }
+ }
+}
+
+impl TreeKemPublic {
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn parent_hash_for_leaf<P: CipherSuiteProvider>(
+ &mut self,
+ cipher_suite_provider: &P,
+ index: LeafIndex,
+ ) -> Result<ParentHash, MlsError> {
+ let mut hash = ParentHash::empty();
+
+ for node in self.nodes.direct_copath(index).into_iter().rev() {
+ if self.nodes.is_resolution_empty(node.copath) {
+ continue;
+ }
+
+ let parent = self.nodes.borrow_as_parent_mut(node.path)?;
+
+ let calculated = ParentHash::new(
+ cipher_suite_provider,
+ &parent.public_key,
+ &hash,
+ &self.tree_hashes.current[node.copath as usize],
+ )
+ .await?;
+
+ (parent.parent_hash, hash) = (hash, calculated);
+ }
+
+ Ok(hash)
+ }
+
+ // Updates all of the required parent hash values, and returns the calculated parent hash value for the leaf node
+ // If an update path is provided, additionally verify that the calculated parent hash matches
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn update_parent_hashes<P: CipherSuiteProvider>(
+ &mut self,
+ index: LeafIndex,
+ verify_leaf_hash: bool,
+ cipher_suite_provider: &P,
+ ) -> Result<(), MlsError> {
+ // First update the relevant original hashes used for parent hash computation.
+ self.update_hashes(&[index], cipher_suite_provider).await?;
+
+ let leaf_hash = self
+ .parent_hash_for_leaf(cipher_suite_provider, index)
+ .await?;
+
+ let leaf = self.nodes.borrow_as_leaf_mut(index)?;
+
+ if verify_leaf_hash {
+ // Verify the parent hash of the new sender leaf node and update the parent hash values
+ // in the local tree
+ if let LeafNodeSource::Commit(parent_hash) = &leaf.leaf_node_source {
+ if !leaf_hash.matches(parent_hash) {
+ return Err(MlsError::ParentHashMismatch);
+ }
+ } else {
+ return Err(MlsError::InvalidLeafNodeSource);
+ }
+ } else {
+ leaf.leaf_node_source = LeafNodeSource::Commit(leaf_hash);
+ }
+
+ // Update hashes after changes to the tree.
+ self.update_hashes(&[index], cipher_suite_provider).await
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(super) async fn validate_parent_hashes<P: CipherSuiteProvider>(
+ &self,
+ cipher_suite_provider: &P,
+ ) -> Result<(), MlsError> {
+ let original_hashes = self.compute_original_hashes(cipher_suite_provider).await?;
+
+ let nodes_to_validate = self
+ .nodes
+ .non_empty_parents()
+ .map(|(node_index, _)| node_index);
+
+ #[cfg(feature = "std")]
+ let mut nodes_to_validate = nodes_to_validate.collect::<HashSet<_>>();
+ #[cfg(not(feature = "std"))]
+ let mut nodes_to_validate = nodes_to_validate.collect::<BTreeSet<_>>();
+
+ let num_leaves = self.total_leaf_count();
+
+ // For each leaf l, validate all non-blank nodes on the chain from l up the tree.
+ for (leaf_index, _) in self.nodes.non_empty_leaves() {
+ let mut n = NodeIndex::from(leaf_index);
+
+ while let Some(mut ps) = n.parent_sibling(&num_leaves) {
+ // Find the first non-blank ancestor p of n and p's co-path child s.
+ while self.nodes.is_blank(ps.parent)? {
+ // If we reached the root, we're done with this chain.
+ let Some(ps_parent) = ps.parent.parent_sibling(&num_leaves) else {
+ return Ok(());
+ };
+
+ ps = ps_parent;
+ }
+
+ // Check is n's parent_hash field matches the parent hash of p with co-path child s.
+ let p_parent = self.nodes.borrow_as_parent(ps.parent)?;
+
+ let n_node = self
+ .nodes
+ .borrow_node(n)?
+ .as_ref()
+ .ok_or(MlsError::ExpectedNode)?;
+
+ let calculated = ParentHash::new(
+ cipher_suite_provider,
+ &p_parent.public_key,
+ &p_parent.parent_hash,
+ &original_hashes[ps.sibling as usize],
+ )
+ .await?;
+
+ if n_node.get_parent_hash() == Some(calculated) {
+ // Check that "n is in the resolution of c, and the intersection of p's unmerged_leaves with the subtree
+ // under c is equal to the resolution of c with n removed".
+ let Some(cp) = ps.sibling.parent_sibling(&num_leaves) else {
+ return Err(MlsError::ParentHashMismatch);
+ };
+
+ let c = cp.sibling;
+ let c_resolution = self.nodes.get_resolution_index(c)?.into_iter();
+
+ #[cfg(feature = "std")]
+ let mut c_resolution = c_resolution.collect::<HashSet<_>>();
+ #[cfg(not(feature = "std"))]
+ let mut c_resolution = c_resolution.collect::<BTreeSet<_>>();
+
+ let p_unmerged_in_c_subtree = self
+ .unmerged_in_subtree(ps.parent, c)?
+ .iter()
+ .copied()
+ .map(|x| *x * 2);
+
+ #[cfg(feature = "std")]
+ let p_unmerged_in_c_subtree = p_unmerged_in_c_subtree.collect::<HashSet<_>>();
+ #[cfg(not(feature = "std"))]
+ let p_unmerged_in_c_subtree = p_unmerged_in_c_subtree.collect::<BTreeSet<_>>();
+
+ if c_resolution.remove(&n)
+ && c_resolution == p_unmerged_in_c_subtree
+ && nodes_to_validate.remove(&ps.parent)
+ {
+ // If n's parent_hash field matches and p has not been validated yet, mark p as validated and continue.
+ n = ps.parent;
+ } else {
+ // If p is validated for the second time, the check fails ("all non-blank parent nodes are covered by exactly one such chain").
+ return Err(MlsError::ParentHashMismatch);
+ }
+ } else {
+ // If n's parent_hash field doesn't match, we're done with this chain.
+ break;
+ }
+ }
+ }
+
+ // The check passes iff all non-blank nodes are validated.
+ if nodes_to_validate.is_empty() {
+ Ok(())
+ } else {
+ Err(MlsError::ParentHashMismatch)
+ }
+ }
+}
+
+#[cfg(test)]
+pub(crate) mod test_utils {
+
+ use super::*;
+ use crate::{
+ cipher_suite::CipherSuite,
+ crypto::test_utils::test_cipher_suite_provider,
+ identity::basic::BasicIdentityProvider,
+ tree_kem::{leaf_node::test_utils::get_basic_test_node, node::Parent},
+ };
+
+ use alloc::vec;
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn test_parent(
+ cipher_suite: CipherSuite,
+ unmerged_leaves: Vec<LeafIndex>,
+ ) -> Parent {
+ let (_, public_key) = test_cipher_suite_provider(cipher_suite)
+ .kem_generate()
+ .await
+ .unwrap();
+
+ Parent {
+ public_key,
+ parent_hash: ParentHash::empty(),
+ unmerged_leaves,
+ }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn test_parent_node(
+ cipher_suite: CipherSuite,
+ unmerged_leaves: Vec<LeafIndex>,
+ ) -> Node {
+ Node::Parent(test_parent(cipher_suite, unmerged_leaves).await)
+ }
+
+ // Create figure 12 from MLS RFC
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn get_test_tree_fig_12(cipher_suite: CipherSuite) -> TreeKemPublic {
+ let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);
+
+ let mut tree = TreeKemPublic::new();
+
+ let mut leaves = Vec::new();
+
+ for l in ["A", "B", "C", "D", "E", "F", "G"] {
+ leaves.push(get_basic_test_node(cipher_suite, l).await);
+ }
+
+ tree.add_leaves(leaves, &BasicIdentityProvider, &cipher_suite_provider)
+ .await
+ .unwrap();
+
+ tree.nodes[1] = Some(test_parent_node(cipher_suite, vec![]).await);
+ tree.nodes[3] = Some(test_parent_node(cipher_suite, vec![LeafIndex(3)]).await);
+
+ tree.nodes[7] =
+ Some(test_parent_node(cipher_suite, vec![LeafIndex(3), LeafIndex(6)]).await);
+
+ tree.nodes[9] = Some(test_parent_node(cipher_suite, vec![LeafIndex(5)]).await);
+
+ tree.nodes[11] =
+ Some(test_parent_node(cipher_suite, vec![LeafIndex(5), LeafIndex(6)]).await);
+
+ tree.update_parent_hashes(LeafIndex(0), false, &cipher_suite_provider)
+ .await
+ .unwrap();
+
+ tree.update_parent_hashes(LeafIndex(4), false, &cipher_suite_provider)
+ .await
+ .unwrap();
+
+ tree
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::client::test_utils::TEST_CIPHER_SUITE;
+ use crate::crypto::test_utils::test_cipher_suite_provider;
+ use crate::tree_kem::leaf_node::test_utils::get_basic_test_node;
+ use crate::tree_kem::leaf_node::LeafNodeSource;
+ use crate::tree_kem::test_utils::TreeWithSigners;
+ use crate::tree_kem::MlsError;
+ use assert_matches::assert_matches;
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_missing_parent_hash() {
+ let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+ let mut test_tree = TreeWithSigners::make_full_tree(8, &cs).await.tree;
+
+ *test_tree.nodes.borrow_as_leaf_mut(LeafIndex(0)).unwrap() =
+ get_basic_test_node(TEST_CIPHER_SUITE, "foo").await;
+
+ let missing_parent_hash_res = test_tree
+ .update_parent_hashes(
+ LeafIndex(0),
+ true,
+ &test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .await;
+
+ assert_matches!(
+ missing_parent_hash_res,
+ Err(MlsError::InvalidLeafNodeSource)
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_parent_hash_mismatch() {
+ let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+ let mut test_tree = TreeWithSigners::make_full_tree(8, &cs).await.tree;
+
+ let unexpected_parent_hash = ParentHash::from(hex!("f00d"));
+
+ test_tree
+ .nodes
+ .borrow_as_leaf_mut(LeafIndex(0))
+ .unwrap()
+ .leaf_node_source = LeafNodeSource::Commit(unexpected_parent_hash);
+
+ let invalid_parent_hash_res = test_tree
+ .update_parent_hashes(
+ LeafIndex(0),
+ true,
+ &test_cipher_suite_provider(TEST_CIPHER_SUITE),
+ )
+ .await;
+
+ assert_matches!(invalid_parent_hash_res, Err(MlsError::ParentHashMismatch));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_parent_hash_invalid() {
+ let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+ let mut test_tree = TreeWithSigners::make_full_tree(8, &cs).await.tree;
+
+ test_tree.nodes[2] = None;
+
+ let res = test_tree
+ .validate_parent_hashes(&test_cipher_suite_provider(TEST_CIPHER_SUITE))
+ .await;
+
+ assert_matches!(res, Err(MlsError::ParentHashMismatch));
+ }
+}
diff --git a/src/tree_kem/path_secret.rs b/src/tree_kem/path_secret.rs
new file mode 100644
index 0000000..c9fce76
--- /dev/null
+++ b/src/tree_kem/path_secret.rs
@@ -0,0 +1,265 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use crate::client::MlsError;
+use crate::crypto::{CipherSuiteProvider, HpkePublicKey, HpkeSecretKey};
+use crate::group::key_schedule::kdf_derive_secret;
+use alloc::vec;
+use alloc::vec::Vec;
+use core::{
+ fmt::{self, Debug},
+ ops::Deref,
+};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::error::IntoAnyError;
+use zeroize::Zeroizing;
+
+use super::hpke_encryption::HpkeEncryptable;
+
+#[derive(Clone, Eq, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct PathSecret(
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::zeroizing_serde"))]
+ Zeroizing<Vec<u8>>,
+);
+
+impl Debug for PathSecret {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ mls_rs_core::debug::pretty_bytes(&self.0)
+ .named("PathSecret")
+ .fmt(f)
+ }
+}
+
+impl Deref for PathSecret {
+ type Target = Vec<u8>;
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+impl From<Vec<u8>> for PathSecret {
+ fn from(data: Vec<u8>) -> Self {
+ PathSecret(Zeroizing::new(data))
+ }
+}
+
+impl From<Zeroizing<Vec<u8>>> for PathSecret {
+ fn from(data: Zeroizing<Vec<u8>>) -> Self {
+ PathSecret(data)
+ }
+}
+
+impl PathSecret {
+ pub fn random<P: CipherSuiteProvider>(
+ cipher_suite_provider: &P,
+ ) -> Result<PathSecret, MlsError> {
+ cipher_suite_provider
+ .random_bytes_vec(cipher_suite_provider.kdf_extract_size())
+ .map(Into::into)
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+ }
+
+ pub fn empty<P: CipherSuiteProvider>(cipher_suite_provider: &P) -> Self {
+ // Define commit_secret as the all-zero vector of the same length as a path_secret
+ PathSecret::from(vec![0u8; cipher_suite_provider.kdf_extract_size()])
+ }
+}
+
+impl HpkeEncryptable for PathSecret {
+ const ENCRYPT_LABEL: &'static str = "UpdatePathNode";
+
+ fn from_bytes(bytes: Vec<u8>) -> Result<Self, MlsError> {
+ Ok(Self(Zeroizing::new(bytes)))
+ }
+
+ fn get_bytes(&self) -> Result<Vec<u8>, MlsError> {
+ Ok(self.to_vec())
+ }
+}
+
+impl PathSecret {
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn to_hpke_key_pair<P: CipherSuiteProvider>(
+ &self,
+ cs: &P,
+ ) -> Result<(HpkeSecretKey, HpkePublicKey), MlsError> {
+ let node_secret = Zeroizing::new(kdf_derive_secret(cs, self, b"node").await?);
+
+ cs.kem_derive(&node_secret)
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+ }
+}
+
+#[derive(Clone, Debug)]
+pub struct PathSecretGenerator<'a, P> {
+ cipher_suite_provider: &'a P,
+ last: Option<PathSecret>,
+ starting_with: Option<PathSecret>,
+}
+
+impl<'a, P: CipherSuiteProvider> PathSecretGenerator<'a, P> {
+ pub fn new(cipher_suite_provider: &'a P) -> Self {
+ Self {
+ cipher_suite_provider,
+ last: None,
+ starting_with: None,
+ }
+ }
+
+ pub fn starting_with(cipher_suite_provider: &'a P, secret: PathSecret) -> Self {
+ Self {
+ starting_with: Some(secret),
+ ..Self::new(cipher_suite_provider)
+ }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn next_secret(&mut self) -> Result<PathSecret, MlsError> {
+ let secret = if let Some(starting_with) = self.starting_with.take() {
+ Ok(starting_with)
+ } else if let Some(last) = self.last.take() {
+ kdf_derive_secret(self.cipher_suite_provider, &last, b"path")
+ .await
+ .map(PathSecret::from)
+ } else {
+ PathSecret::random(self.cipher_suite_provider)
+ }?;
+
+ self.last = Some(secret.clone());
+
+ Ok(secret)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use crate::{
+ cipher_suite::CipherSuite,
+ client::test_utils::TEST_CIPHER_SUITE,
+ crypto::test_utils::{
+ test_cipher_suite_provider, try_test_cipher_suite_provider, TestCryptoProvider,
+ },
+ };
+
+ use super::*;
+
+ use alloc::string::String;
+
+ #[cfg(target_arch = "wasm32")]
+ use wasm_bindgen_test::wasm_bindgen_test as test;
+
+ #[derive(serde::Deserialize, serde::Serialize)]
+ struct TestCase {
+ cipher_suite: u16,
+ generations: Vec<String>,
+ }
+
+ impl TestCase {
+ #[cfg(not(mls_build_async))]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ fn generate() -> Vec<TestCase> {
+ CipherSuite::all()
+ .map(
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ |cipher_suite| {
+ let cs_provider = test_cipher_suite_provider(cipher_suite);
+ let mut generator = PathSecretGenerator::new(&cs_provider);
+
+ let generations = (0..10)
+ .map(|_| hex::encode(&*generator.next_secret().unwrap()))
+ .collect();
+
+ TestCase {
+ cipher_suite: cipher_suite.into(),
+ generations,
+ }
+ },
+ )
+ .collect()
+ }
+
+ #[cfg(mls_build_async)]
+ fn generate() -> Vec<TestCase> {
+ panic!("Tests cannot be generated in async mode");
+ }
+ }
+
+ fn load_test_cases() -> Vec<TestCase> {
+ load_test_case_json!(path_secret, TestCase::generate())
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_path_secret_generation() {
+ let cases = load_test_cases();
+
+ for test_case in cases {
+ let Some(cs_provider) = try_test_cipher_suite_provider(test_case.cipher_suite) else {
+ continue;
+ };
+
+ let first_secret = PathSecret::from(hex::decode(&test_case.generations[0]).unwrap());
+ let mut generator = PathSecretGenerator::starting_with(&cs_provider, first_secret);
+
+ for expected in &test_case.generations {
+ let generated = hex::encode(&*generator.next_secret().await.unwrap());
+ assert_eq!(expected, &generated);
+ }
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_first_path_is_random() {
+ let cs_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ let mut generator = PathSecretGenerator::new(&cs_provider);
+ let first_secret = generator.next_secret().await.unwrap();
+
+ for _ in 0..100 {
+ let mut next_generator = PathSecretGenerator::new(&cs_provider);
+ let next_secret = next_generator.next_secret().await.unwrap();
+ assert_ne!(first_secret, next_secret);
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_starting_with() {
+ let cs_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+ let secret = PathSecret::random(&cs_provider).unwrap();
+
+ let mut generator = PathSecretGenerator::starting_with(&cs_provider, secret.clone());
+
+ let first_secret = generator.next_secret().await.unwrap();
+ let second_secret = generator.next_secret().await.unwrap();
+
+ assert_eq!(secret, first_secret);
+ assert_ne!(first_secret, second_secret);
+ }
+
+ #[test]
+ fn test_empty_path_secret() {
+ for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+ let cs_provider = test_cipher_suite_provider(cipher_suite);
+ let empty = PathSecret::empty(&cs_provider);
+ assert_eq!(
+ empty,
+ PathSecret::from(vec![0u8; cs_provider.kdf_extract_size()])
+ )
+ }
+ }
+
+ #[test]
+ fn test_random_path_secret() {
+ let cs_provider = test_cipher_suite_provider(CipherSuite::P256_AES128);
+ let initial = PathSecret::random(&cs_provider).unwrap();
+
+ for _ in 0..100 {
+ let next = PathSecret::random(&cs_provider).unwrap();
+ assert_ne!(next, initial);
+ }
+ }
+}
diff --git a/src/tree_kem/private.rs b/src/tree_kem/private.rs
new file mode 100644
index 0000000..1cc72ee
--- /dev/null
+++ b/src/tree_kem/private.rs
@@ -0,0 +1,310 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+use alloc::{vec, vec::Vec};
+
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::crypto::HpkeSecretKey;
+
+use crate::{client::MlsError, crypto::CipherSuiteProvider};
+
+use super::{
+ math::leaf_lca_level,
+ node::LeafIndex,
+ path_secret::{PathSecret, PathSecretGenerator},
+ TreeKemPublic,
+};
+
+#[derive(Clone, Debug, MlsEncode, MlsDecode, MlsSize, Eq, PartialEq)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+#[non_exhaustive]
+pub struct TreeKemPrivate {
+ pub self_index: LeafIndex,
+ pub secret_keys: Vec<Option<HpkeSecretKey>>,
+}
+
+impl TreeKemPrivate {
+ pub fn new_self_leaf(self_index: LeafIndex, leaf_secret: HpkeSecretKey) -> Self {
+ TreeKemPrivate {
+ self_index,
+ secret_keys: vec![Some(leaf_secret)],
+ }
+ }
+
+ pub fn new_for_external() -> Self {
+ TreeKemPrivate {
+ self_index: LeafIndex(0),
+ secret_keys: Default::default(),
+ }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn update_secrets<P: CipherSuiteProvider>(
+ &mut self,
+ cipher_suite_provider: &P,
+ signer_index: LeafIndex,
+ path_secret: PathSecret,
+ public_tree: &TreeKemPublic,
+ ) -> Result<(), MlsError> {
+ // Identify the lowest common
+ // ancestor of the leaves at index and at GroupInfo.signer_index. Set the private key
+ // for this node to the private key derived from the path_secret.
+ let lca_index = leaf_lca_level(self.self_index.into(), signer_index.into()) as usize - 2;
+
+ // For each parent of the common ancestor, up to the root of the tree, derive a new
+ // path secret and set the private key for the node to the private key derived from the
+ // path secret. The private key MUST be the private key that corresponds to the public
+ // key in the node.
+
+ let mut node_secret_gen =
+ PathSecretGenerator::starting_with(cipher_suite_provider, path_secret);
+
+ let path = public_tree.nodes.direct_copath(self.self_index);
+ let filtered = &public_tree.nodes.filtered(self.self_index)?;
+ self.secret_keys.resize(path.len() + 1, None);
+
+ for (i, (n, f)) in path.iter().zip(filtered).enumerate().skip(lca_index) {
+ if *f {
+ continue;
+ }
+
+ let secret = node_secret_gen.next_secret().await?;
+
+ let expected_pub_key = public_tree
+ .nodes
+ .borrow_node(n.path)?
+ .as_ref()
+ .map(|n| n.public_key())
+ .ok_or(MlsError::PubKeyMismatch)?;
+
+ let (secret_key, public_key) = secret.to_hpke_key_pair(cipher_suite_provider).await?;
+
+ if expected_pub_key != &public_key {
+ return Err(MlsError::PubKeyMismatch);
+ }
+
+ // It's ok to use index directly because of the resize above
+ self.secret_keys[i + 1] = Some(secret_key);
+ }
+
+ Ok(())
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ pub fn update_leaf(&mut self, new_leaf: HpkeSecretKey) {
+ self.secret_keys = vec![None; self.secret_keys.len()];
+ self.secret_keys[0] = Some(new_leaf);
+ }
+}
+
+#[cfg(test)]
+impl TreeKemPrivate {
+ pub fn new(self_index: LeafIndex) -> Self {
+ TreeKemPrivate {
+ self_index,
+ secret_keys: Default::default(),
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use assert_matches::assert_matches;
+
+ use crate::{
+ cipher_suite::CipherSuite,
+ client::test_utils::TEST_CIPHER_SUITE,
+ crypto::test_utils::test_cipher_suite_provider,
+ group::test_utils::{get_test_group_context, random_bytes},
+ identity::basic::BasicIdentityProvider,
+ tree_kem::{
+ kem::TreeKem,
+ leaf_node::test_utils::{
+ default_properties, get_basic_test_node, get_basic_test_node_sig_key,
+ },
+ math::TreeIndex,
+ node::LeafIndex,
+ },
+ };
+
+ use super::*;
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn random_hpke_secret_key() -> HpkeSecretKey {
+ let (secret, _) = test_cipher_suite_provider(TEST_CIPHER_SUITE)
+ .kem_derive(&random_bytes(32))
+ .await
+ .unwrap();
+
+ secret
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_create_self_leaf() {
+ let secret = random_hpke_secret_key().await;
+
+ let self_index = LeafIndex(42);
+
+ let private_key = TreeKemPrivate::new_self_leaf(self_index, secret.clone());
+
+ assert_eq!(private_key.self_index, self_index);
+ assert_eq!(private_key.secret_keys.len(), 1);
+ assert_eq!(private_key.secret_keys[0].as_ref().unwrap(), &secret)
+ }
+
+ // Create a ratchet tree for Alice, Bob and Charlie. Alice generates an update path for
+ // Charlie. Return (Public Tree, Charlie's private key, update path, path secret)
+ // The ratchet tree returned has leaf indexes as [alice, bob, charlie]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn update_secrets_setup(
+ cipher_suite: CipherSuite,
+ ) -> (TreeKemPublic, TreeKemPrivate, TreeKemPrivate, PathSecret) {
+ let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);
+
+ let (alice_leaf, alice_hpke_secret, alice_signing) =
+ get_basic_test_node_sig_key(cipher_suite, "alice").await;
+
+ let bob_leaf = get_basic_test_node(cipher_suite, "bob").await;
+
+ let (charlie_leaf, charlie_hpke_secret, _charlie_signing) =
+ get_basic_test_node_sig_key(cipher_suite, "charlie").await;
+
+ // Create a new public tree with Alice
+ let (mut public_tree, mut alice_private) = TreeKemPublic::derive(
+ alice_leaf,
+ alice_hpke_secret,
+ &BasicIdentityProvider,
+ &Default::default(),
+ )
+ .await
+ .unwrap();
+
+ // Add bob and charlie to the tree
+ public_tree
+ .add_leaves(
+ vec![bob_leaf, charlie_leaf],
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap();
+
+ // Alice's secret key is longer now
+ alice_private.secret_keys.resize(3, None);
+
+ // Generate an update path for Alice
+ let encap_gen = TreeKem::new(&mut public_tree, &mut alice_private)
+ .encap(
+ &mut get_test_group_context(42, cipher_suite).await,
+ &[],
+ &alice_signing,
+ default_properties(),
+ None,
+ &cipher_suite_provider,
+ #[cfg(test)]
+ &Default::default(),
+ )
+ .await
+ .unwrap();
+
+ // Get a path secret from Alice for Charlie
+ let path_secret = encap_gen.path_secrets[1].clone().unwrap();
+
+ // Private key for Charlie
+ let charlie_private = TreeKemPrivate::new_self_leaf(LeafIndex(2), charlie_hpke_secret);
+
+ (public_tree, charlie_private, alice_private, path_secret)
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_update_secrets() {
+ let cipher_suite = TEST_CIPHER_SUITE;
+
+ let (public_tree, mut charlie_private, alice_private, path_secret) =
+ update_secrets_setup(cipher_suite).await;
+
+ let existing_private = charlie_private.secret_keys.first().cloned().unwrap();
+
+ // Add the secrets for Charlie to his private key
+ charlie_private
+ .update_secrets(
+ &test_cipher_suite_provider(cipher_suite),
+ LeafIndex(0),
+ path_secret,
+ &public_tree,
+ )
+ .await
+ .unwrap();
+
+ // Make sure that Charlie's private key didn't lose keys
+ assert_eq!(charlie_private.secret_keys.len(), 3);
+
+ // Check that the intersection of the secret keys of Alice and Charlie matches.
+ // The intersection contains only the root.
+ assert_eq!(alice_private.secret_keys[2], charlie_private.secret_keys[2]);
+
+ assert_eq!(
+ charlie_private.secret_keys[0].as_ref(),
+ existing_private.as_ref()
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_update_secrets_key_mismatch() {
+ let cipher_suite = TEST_CIPHER_SUITE;
+
+ let (mut public_tree, mut charlie_private, _, path_secret) =
+ update_secrets_setup(cipher_suite).await;
+
+ // Sabotage the public tree
+ public_tree
+ .nodes
+ .borrow_as_parent_mut(public_tree.total_leaf_count().root())
+ .unwrap()
+ .public_key = random_bytes(32).into();
+
+ // Add the secrets for Charlie to his private key
+ let res = charlie_private
+ .update_secrets(
+ &test_cipher_suite_provider(cipher_suite),
+ LeafIndex(0),
+ path_secret,
+ &public_tree,
+ )
+ .await;
+
+ assert_matches!(res, Err(MlsError::PubKeyMismatch));
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn setup_direct_path(self_index: LeafIndex, leaf_count: u32) -> TreeKemPrivate {
+ let secret = random_hpke_secret_key().await;
+
+ let mut private_key = TreeKemPrivate::new_self_leaf(self_index, secret.clone());
+
+ private_key.secret_keys = (0..0.direct_copath(&leaf_count).len() + 1)
+ .map(|_| Some(secret.clone()))
+ .collect();
+
+ private_key
+ }
+
+ #[cfg(feature = "by_ref_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_update_leaf() {
+ let self_leaf = LeafIndex(42);
+ let mut private_key = setup_direct_path(self_leaf, 128).await;
+
+ let new_secret = random_hpke_secret_key().await;
+
+ private_key.update_leaf(new_secret.clone());
+
+ // The update operation should have removed all the other keys in our direct path we
+ // previously added
+ assert!(private_key.secret_keys.iter().skip(1).all(|n| n.is_none()));
+
+ // The secret key for our leaf should have been updated accordingly
+ assert_eq!(private_key.secret_keys.first().unwrap(), &Some(new_secret));
+ }
+}
diff --git a/src/tree_kem/tree_hash.rs b/src/tree_kem/tree_hash.rs
new file mode 100644
index 0000000..d9115e3
--- /dev/null
+++ b/src/tree_kem/tree_hash.rs
@@ -0,0 +1,432 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use super::leaf_node::LeafNode;
+use super::node::{LeafIndex, NodeVec};
+use super::tree_math::BfsIterTopDown;
+use crate::client::MlsError;
+use crate::crypto::CipherSuiteProvider;
+use crate::tree_kem::math as tree_math;
+use crate::tree_kem::node::Parent;
+use crate::tree_kem::TreeKemPublic;
+use alloc::collections::VecDeque;
+use alloc::vec;
+use alloc::vec::Vec;
+use core::fmt::{self, Debug};
+use itertools::Itertools;
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::error::IntoAnyError;
+use tree_math::TreeIndex;
+
+use core::ops::Deref;
+
+#[derive(Clone, Default, MlsSize, MlsEncode, MlsDecode, PartialEq)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct TreeHash(
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
+ Vec<u8>,
+);
+
+impl Debug for TreeHash {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ mls_rs_core::debug::pretty_bytes(&self.0)
+ .named("TreeHash")
+ .fmt(f)
+ }
+}
+
+impl Deref for TreeHash {
+ type Target = [u8];
+
+ fn deref(&self) -> &Self::Target {
+ &self.0
+ }
+}
+
+#[derive(Clone, Debug, Default, MlsSize, MlsEncode, MlsDecode, PartialEq)]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub(crate) struct TreeHashes {
+ pub current: Vec<TreeHash>,
+}
+
+#[derive(Debug, MlsSize, MlsEncode)]
+struct LeafNodeHashInput<'a> {
+ leaf_index: LeafIndex,
+ leaf_node: Option<&'a LeafNode>,
+}
+
+#[derive(Debug, MlsSize, MlsEncode)]
+struct ParentNodeTreeHashInput<'a> {
+ parent_node: Option<&'a Parent>,
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ left_hash: &'a [u8],
+ #[mls_codec(with = "mls_rs_codec::byte_vec")]
+ right_hash: &'a [u8],
+}
+
+#[derive(Debug, MlsSize, MlsEncode)]
+#[repr(u8)]
+enum TreeHashInput<'a> {
+ Leaf(LeafNodeHashInput<'a>) = 1u8,
+ Parent(ParentNodeTreeHashInput<'a>) = 2u8,
+}
+
+impl TreeKemPublic {
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[inline(never)]
+ pub async fn tree_hash<P: CipherSuiteProvider>(
+ &mut self,
+ cipher_suite_provider: &P,
+ ) -> Result<Vec<u8>, MlsError> {
+ self.initialize_hashes(cipher_suite_provider).await?;
+ let root = self.total_leaf_count().root();
+ Ok(self.tree_hashes.current[root as usize].to_vec())
+ }
+
+ // Update hashes after `committer` makes changes to the tree. `path_blank` is the
+ // list of leaves whose paths were blanked, i.e. updates and removes.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn update_hashes<P: CipherSuiteProvider>(
+ &mut self,
+ updated_leaves: &[LeafIndex],
+ cipher_suite_provider: &P,
+ ) -> Result<(), MlsError> {
+ let num_leaves = self.total_leaf_count();
+
+ let trailing_blanks = (0..num_leaves)
+ .rev()
+ .map_while(|l| {
+ self.tree_hashes
+ .current
+ .get(2 * l as usize)
+ .is_none()
+ .then_some(LeafIndex(l))
+ })
+ .collect::<Vec<_>>();
+
+ // Update the current hashes for direct paths of all modified leaves.
+ tree_hash(
+ &mut self.tree_hashes.current,
+ &self.nodes,
+ Some([updated_leaves, &trailing_blanks].concat()),
+ &[],
+ num_leaves,
+ cipher_suite_provider,
+ )
+ .await?;
+
+ Ok(())
+ }
+
+ // Initialize all hashes after creating / importing a tree.
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn initialize_hashes<P>(&mut self, cipher_suite_provider: &P) -> Result<(), MlsError>
+ where
+ P: CipherSuiteProvider,
+ {
+ if self.tree_hashes.current.is_empty() {
+ let num_leaves = self.total_leaf_count();
+
+ tree_hash(
+ &mut self.tree_hashes.current,
+ &self.nodes,
+ None,
+ &[],
+ num_leaves,
+ cipher_suite_provider,
+ )
+ .await?;
+ }
+
+ Ok(())
+ }
+
+ pub(crate) fn unmerged_in_subtree(
+ &self,
+ node_unmerged: u32,
+ subtree_root: u32,
+ ) -> Result<&[LeafIndex], MlsError> {
+ let unmerged = &self.nodes.borrow_as_parent(node_unmerged)?.unmerged_leaves;
+ let (left, right) = tree_math::subtree(subtree_root);
+ let mut start = 0;
+ while start < unmerged.len() && unmerged[start] < left {
+ start += 1;
+ }
+ let mut end = start;
+ while end < unmerged.len() && unmerged[end] < right {
+ end += 1;
+ }
+ Ok(&unmerged[start..end])
+ }
+
+ fn different_unmerged(&self, ancestor: u32, descendant: u32) -> Result<bool, MlsError> {
+ Ok(!self.nodes.is_blank(ancestor)?
+ && !self.nodes.is_blank(descendant)?
+ && self.unmerged_in_subtree(ancestor, descendant)?
+ != self.nodes.borrow_as_parent(descendant)?.unmerged_leaves)
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub(crate) async fn compute_original_hashes<P: CipherSuiteProvider>(
+ &self,
+ cipher_suite: &P,
+ ) -> Result<Vec<TreeHash>, MlsError> {
+ let num_leaves = self.nodes.total_leaf_count() as usize;
+ let root = (num_leaves as u32).root();
+
+ // The value `filtered_sets[n]` is a list of all ancestors `a` of `n` s.t. we have to compute
+ // the tree hash of `n` with the unmerged leaves of `a` filtered out.
+ let mut filtered_sets = vec![vec![]; num_leaves * 2 - 1];
+ filtered_sets[root as usize].push(root);
+ let mut tree_hashes = vec![vec![]; num_leaves * 2 - 1];
+
+ let bfs_iter = BfsIterTopDown::new(num_leaves).skip(1);
+
+ for n in bfs_iter {
+ let Some(ps) = (n as u32).parent_sibling(&(num_leaves as u32)) else {
+ break;
+ };
+
+ let p = ps.parent;
+ filtered_sets[n] = filtered_sets[p as usize].clone();
+
+ if self.different_unmerged(*filtered_sets[p as usize].last().unwrap(), p)? {
+ filtered_sets[n].push(p);
+
+ // Compute tree hash of `n` without unmerged leaves of `p`. This also computes the tree hash
+ // for any descendants of `n` added to `filtered_sets` later via `clone`.
+ let (start_leaf, end_leaf) = tree_math::subtree(n as u32);
+
+ tree_hash(
+ &mut tree_hashes[p as usize],
+ &self.nodes,
+ Some((*start_leaf..*end_leaf).map(LeafIndex).collect_vec()),
+ &self.nodes.borrow_as_parent(p)?.unmerged_leaves,
+ num_leaves as u32,
+ cipher_suite,
+ )
+ .await?;
+ }
+ }
+
+ // Set the `original_hashes` based on the computed `hashes`.
+ let mut original_hashes = vec![TreeHash::default(); num_leaves * 2 - 1];
+
+ // If root has unmerged leaves, we recompute it's original hash. Else, we can use the current hash.
+ let root_original = if !self.nodes.is_blank(root)? && !self.nodes.is_leaf(root) {
+ let root_unmerged = &self.nodes.borrow_as_parent(root)?.unmerged_leaves;
+
+ if !root_unmerged.is_empty() {
+ let mut hashes = vec![];
+
+ tree_hash(
+ &mut hashes,
+ &self.nodes,
+ None,
+ root_unmerged,
+ num_leaves as u32,
+ cipher_suite,
+ )
+ .await?;
+
+ Some(hashes)
+ } else {
+ None
+ }
+ } else {
+ None
+ };
+
+ for (i, hash) in original_hashes.iter_mut().enumerate() {
+ let a = filtered_sets[i].last().unwrap();
+ *hash = if self.nodes.is_blank(*a)? || a == &root {
+ if let Some(root_original) = &root_original {
+ root_original[i].clone()
+ } else {
+ self.tree_hashes.current[i].clone()
+ }
+ } else {
+ tree_hashes[*a as usize][i].clone()
+ }
+ }
+
+ Ok(original_hashes)
+ }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn tree_hash<P: CipherSuiteProvider>(
+ hashes: &mut Vec<TreeHash>,
+ nodes: &NodeVec,
+ leaves_to_update: Option<Vec<LeafIndex>>,
+ filtered_leaves: &[LeafIndex],
+ num_leaves: u32,
+ cipher_suite_provider: &P,
+) -> Result<(), MlsError> {
+ let leaves_to_update =
+ leaves_to_update.unwrap_or_else(|| (0..num_leaves).map(LeafIndex).collect::<Vec<_>>());
+
+ // Resize the array in case the tree was extended or truncated
+ hashes.resize(num_leaves as usize * 2 - 1, TreeHash::default());
+
+ let mut node_queue = VecDeque::with_capacity(leaves_to_update.len());
+
+ for l in leaves_to_update.iter().filter(|l| ***l < num_leaves) {
+ let leaf = (!filtered_leaves.contains(l))
+ .then_some(nodes.borrow_as_leaf(*l).ok())
+ .flatten();
+
+ hashes[2 * **l as usize] = TreeHash(hash_for_leaf(*l, leaf, cipher_suite_provider).await?);
+
+ if let Some(ps) = (2 * **l).parent_sibling(&num_leaves) {
+ node_queue.push_back(ps.parent);
+ }
+ }
+
+ while let Some(n) = node_queue.pop_front() {
+ let hash = TreeHash(
+ hash_for_parent(
+ nodes.borrow_as_parent(n).ok(),
+ cipher_suite_provider,
+ filtered_leaves,
+ &hashes[n.left_unchecked() as usize],
+ &hashes[n.right_unchecked() as usize],
+ )
+ .await?,
+ );
+
+ hashes[n as usize] = hash;
+
+ if let Some(ps) = n.parent_sibling(&num_leaves) {
+ node_queue.push_back(ps.parent);
+ }
+ }
+
+ Ok(())
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn hash_for_leaf<P: CipherSuiteProvider>(
+ leaf_index: LeafIndex,
+ leaf_node: Option<&LeafNode>,
+ cipher_suite_provider: &P,
+) -> Result<Vec<u8>, MlsError> {
+ let input = TreeHashInput::Leaf(LeafNodeHashInput {
+ leaf_index,
+ leaf_node,
+ });
+
+ cipher_suite_provider
+ .hash(&input.mls_encode_to_vec()?)
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn hash_for_parent<P: CipherSuiteProvider>(
+ parent_node: Option<&Parent>,
+ cipher_suite_provider: &P,
+ filtered: &[LeafIndex],
+ left_hash: &[u8],
+ right_hash: &[u8],
+) -> Result<Vec<u8>, MlsError> {
+ let mut parent_node = parent_node.cloned();
+
+ if let Some(ref mut parent_node) = parent_node {
+ parent_node
+ .unmerged_leaves
+ .retain(|unmerged_index| !filtered.contains(unmerged_index));
+ }
+
+ let input = TreeHashInput::Parent(ParentNodeTreeHashInput {
+ parent_node: parent_node.as_ref(),
+ left_hash,
+ right_hash,
+ });
+
+ cipher_suite_provider
+ .hash(&input.mls_encode_to_vec()?)
+ .await
+ .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
+}
+
+#[cfg(test)]
+mod tests {
+ use mls_rs_codec::MlsDecode;
+
+ use crate::{
+ cipher_suite::CipherSuite,
+ crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider},
+ identity::basic::BasicIdentityProvider,
+ tree_kem::{node::NodeVec, parent_hash::test_utils::get_test_tree_fig_12},
+ };
+
+ use super::*;
+
+ #[derive(serde::Deserialize, serde::Serialize)]
+ struct TestCase {
+ cipher_suite: u16,
+ #[serde(with = "hex::serde")]
+ tree_data: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ tree_hash: Vec<u8>,
+ }
+
+ impl TestCase {
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ #[cfg_attr(coverage_nightly, coverage(off))]
+ async fn generate() -> Vec<TestCase> {
+ let mut test_cases = Vec::new();
+
+ for cipher_suite in CipherSuite::all() {
+ let mut tree = get_test_tree_fig_12(cipher_suite).await;
+
+ test_cases.push(TestCase {
+ cipher_suite: cipher_suite.into(),
+ tree_data: tree.nodes.mls_encode_to_vec().unwrap(),
+ tree_hash: tree
+ .tree_hash(&test_cipher_suite_provider(cipher_suite))
+ .await
+ .unwrap(),
+ })
+ }
+
+ test_cases
+ }
+ }
+
+ #[cfg(mls_build_async)]
+ async fn load_test_cases() -> Vec<TestCase> {
+ load_test_case_json!(tree_hash, TestCase::generate().await)
+ }
+
+ #[cfg(not(mls_build_async))]
+ fn load_test_cases() -> Vec<TestCase> {
+ load_test_case_json!(tree_hash, TestCase::generate())
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_tree_hash() {
+ let cases = load_test_cases().await;
+
+ for one_case in cases {
+ let Some(cs_provider) = try_test_cipher_suite_provider(one_case.cipher_suite) else {
+ continue;
+ };
+
+ let mut tree = TreeKemPublic::import_node_data(
+ NodeVec::mls_decode(&mut &*one_case.tree_data).unwrap(),
+ &BasicIdentityProvider,
+ &Default::default(),
+ )
+ .await
+ .unwrap();
+
+ let calculated_hash = tree.tree_hash(&cs_provider).await.unwrap();
+
+ assert_eq!(calculated_hash, one_case.tree_hash);
+ }
+ }
+}
diff --git a/src/tree_kem/tree_index.rs b/src/tree_kem/tree_index.rs
new file mode 100644
index 0000000..4e6731a
--- /dev/null
+++ b/src/tree_kem/tree_index.rs
@@ -0,0 +1,505 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use super::*;
+#[cfg(feature = "tree_index")]
+use core::fmt::{self, Debug};
+
+#[cfg(all(feature = "tree_index", feature = "custom_proposal"))]
+use crate::group::proposal::ProposalType;
+
+#[cfg(feature = "tree_index")]
+use crate::identity::CredentialType;
+
+#[cfg(feature = "tree_index")]
+use mls_rs_core::crypto::SignaturePublicKey;
+
+#[cfg(all(feature = "tree_index", feature = "std"))]
+use itertools::Itertools;
+
+#[cfg(all(feature = "tree_index", not(feature = "std")))]
+use alloc::collections::{btree_map::Entry, BTreeMap};
+
+#[cfg(all(feature = "tree_index", feature = "std"))]
+use std::collections::{hash_map::Entry, HashMap};
+
+#[cfg(all(feature = "tree_index", not(feature = "std")))]
+use alloc::collections::BTreeSet;
+
+#[cfg(feature = "tree_index")]
+use mls_rs_core::crypto::HpkePublicKey;
+
+#[cfg(feature = "tree_index")]
+#[derive(Clone, Default, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode, Hash, PartialOrd, Ord)]
+pub struct Identifier(#[mls_codec(with = "mls_rs_codec::byte_vec")] Vec<u8>);
+
+#[cfg(feature = "tree_index")]
+impl Debug for Identifier {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ mls_rs_core::debug::pretty_bytes(&self.0)
+ .named("Identifier")
+ .fmt(f)
+ }
+}
+
+#[cfg(all(feature = "tree_index", feature = "std"))]
+#[derive(Clone, Debug, Default, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+pub struct TreeIndex {
+ credential_signature_key: HashMap<SignaturePublicKey, LeafIndex>,
+ hpke_key: HashMap<HpkePublicKey, LeafIndex>,
+ identities: HashMap<Identifier, LeafIndex>,
+ credential_type_counters: HashMap<CredentialType, TypeCounter>,
+ #[cfg(feature = "custom_proposal")]
+ proposal_type_counter: HashMap<ProposalType, u32>,
+}
+
+#[cfg(all(feature = "tree_index", not(feature = "std")))]
+#[derive(Clone, Debug, Default, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+pub struct TreeIndex {
+ credential_signature_key: BTreeMap<SignaturePublicKey, LeafIndex>,
+ hpke_key: BTreeMap<HpkePublicKey, LeafIndex>,
+ identities: BTreeMap<Identifier, LeafIndex>,
+ credential_type_counters: BTreeMap<CredentialType, TypeCounter>,
+ #[cfg(feature = "custom_proposal")]
+ proposal_type_counter: BTreeMap<ProposalType, u32>,
+}
+
+#[cfg(feature = "tree_index")]
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(super) async fn index_insert<I: IdentityProvider>(
+ tree_index: &mut TreeIndex,
+ new_leaf: &LeafNode,
+ new_leaf_idx: LeafIndex,
+ id_provider: &I,
+ extensions: &ExtensionList,
+) -> Result<(), MlsError> {
+ let new_id = id_provider
+ .identity(&new_leaf.signing_identity, extensions)
+ .await
+ .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;
+
+ tree_index.insert(new_leaf_idx, new_leaf, new_id)
+}
+
+#[cfg(not(feature = "tree_index"))]
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(super) async fn index_insert<I: IdentityProvider>(
+ nodes: &NodeVec,
+ new_leaf: &LeafNode,
+ new_leaf_idx: LeafIndex,
+ id_provider: &I,
+ extensions: &ExtensionList,
+) -> Result<(), MlsError> {
+ let new_id = id_provider
+ .identity(&new_leaf.signing_identity, extensions)
+ .await
+ .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;
+
+ for (i, leaf) in nodes.non_empty_leaves().filter(|(i, _)| i != &new_leaf_idx) {
+ (new_leaf.public_key != leaf.public_key)
+ .then_some(())
+ .ok_or(MlsError::DuplicateLeafData(*i))?;
+
+ (new_leaf.signing_identity.signature_key != leaf.signing_identity.signature_key)
+ .then_some(())
+ .ok_or(MlsError::DuplicateLeafData(*i))?;
+
+ let id = id_provider
+ .identity(&leaf.signing_identity, extensions)
+ .await
+ .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?;
+
+ (new_id != id)
+ .then_some(())
+ .ok_or(MlsError::DuplicateLeafData(*i))?;
+
+ let cred_type = leaf.signing_identity.credential.credential_type();
+
+ new_leaf
+ .capabilities
+ .credentials
+ .contains(&cred_type)
+ .then_some(())
+ .ok_or(MlsError::InUseCredentialTypeUnsupportedByNewLeaf)?;
+
+ let new_cred_type = new_leaf.signing_identity.credential.credential_type();
+
+ leaf.capabilities
+ .credentials
+ .contains(&new_cred_type)
+ .then_some(())
+ .ok_or(MlsError::CredentialTypeOfNewLeafIsUnsupported)?;
+ }
+
+ Ok(())
+}
+
+#[cfg(feature = "tree_index")]
+impl TreeIndex {
+ pub fn new() -> Self {
+ Default::default()
+ }
+
+ pub fn is_initialized(&self) -> bool {
+ !self.identities.is_empty()
+ }
+
+ fn insert(
+ &mut self,
+ index: LeafIndex,
+ leaf_node: &LeafNode,
+ identity: Vec<u8>,
+ ) -> Result<(), MlsError> {
+ let old_leaf_count = self.credential_signature_key.len();
+
+ let pub_key = leaf_node.signing_identity.signature_key.clone();
+ let credential_entry = self.credential_signature_key.entry(pub_key);
+
+ if let Entry::Occupied(entry) = credential_entry {
+ return Err(MlsError::DuplicateLeafData(**entry.get()));
+ }
+
+ let hpke_entry = self.hpke_key.entry(leaf_node.public_key.clone());
+
+ if let Entry::Occupied(entry) = hpke_entry {
+ return Err(MlsError::DuplicateLeafData(**entry.get()));
+ }
+
+ let identity_entry = self.identities.entry(Identifier(identity));
+ if let Entry::Occupied(entry) = identity_entry {
+ return Err(MlsError::DuplicateLeafData(**entry.get()));
+ }
+
+ let in_use_cred_type_unsupported_by_new_leaf = self
+ .credential_type_counters
+ .iter()
+ .filter_map(|(cred_type, counters)| Some(*cred_type).filter(|_| counters.used > 0))
+ .find(|cred_type| !leaf_node.capabilities.credentials.contains(cred_type));
+
+ if in_use_cred_type_unsupported_by_new_leaf.is_some() {
+ return Err(MlsError::InUseCredentialTypeUnsupportedByNewLeaf);
+ }
+
+ let new_leaf_cred_type = leaf_node.signing_identity.credential.credential_type();
+
+ let cred_type_counters = self
+ .credential_type_counters
+ .entry(new_leaf_cred_type)
+ .or_default();
+
+ if cred_type_counters.supported != old_leaf_count as u32 {
+ return Err(MlsError::CredentialTypeOfNewLeafIsUnsupported);
+ }
+
+ cred_type_counters.used += 1;
+
+ let credential_type_iter = leaf_node.capabilities.credentials.iter().copied();
+
+ #[cfg(feature = "std")]
+ let credential_type_iter = credential_type_iter.unique();
+
+ #[cfg(not(feature = "std"))]
+ let credential_type_iter = credential_type_iter.collect::<BTreeSet<_>>().into_iter();
+
+ // Credential type counter updates
+ credential_type_iter.for_each(|cred_type| {
+ self.credential_type_counters
+ .entry(cred_type)
+ .or_default()
+ .supported += 1;
+ });
+
+ #[cfg(feature = "custom_proposal")]
+ {
+ let proposal_type_iter = leaf_node.capabilities.proposals.iter().copied();
+
+ #[cfg(feature = "std")]
+ let proposal_type_iter = proposal_type_iter.unique();
+
+ #[cfg(not(feature = "std"))]
+ let proposal_type_iter = proposal_type_iter.collect::<BTreeSet<_>>().into_iter();
+
+ // Proposal type counter update
+ proposal_type_iter.for_each(|proposal_type| {
+ *self.proposal_type_counter.entry(proposal_type).or_default() += 1;
+ });
+ }
+
+ identity_entry.or_insert(index);
+ credential_entry.or_insert(index);
+ hpke_entry.or_insert(index);
+
+ Ok(())
+ }
+
+ pub(crate) fn get_leaf_index_with_identity(&self, identity: &[u8]) -> Option<LeafIndex> {
+ self.identities.get(&Identifier(identity.to_vec())).copied()
+ }
+
+ pub fn remove(&mut self, leaf_node: &LeafNode, identity: &[u8]) {
+ let existed = self
+ .identities
+ .remove(&Identifier(identity.to_vec()))
+ .is_some();
+
+ self.credential_signature_key
+ .remove(&leaf_node.signing_identity.signature_key);
+
+ self.hpke_key.remove(&leaf_node.public_key);
+
+ if !existed {
+ return;
+ }
+
+ // Decrement credential type counters
+ let leaf_cred_type = leaf_node.signing_identity.credential.credential_type();
+
+ if let Some(counters) = self.credential_type_counters.get_mut(&leaf_cred_type) {
+ counters.used -= 1;
+ }
+
+ let credential_type_iter = leaf_node.capabilities.credentials.iter();
+
+ #[cfg(feature = "std")]
+ let credential_type_iter = credential_type_iter.unique();
+
+ #[cfg(not(feature = "std"))]
+ let credential_type_iter = credential_type_iter.collect::<BTreeSet<_>>().into_iter();
+
+ credential_type_iter.for_each(|cred_type| {
+ if let Some(counters) = self.credential_type_counters.get_mut(cred_type) {
+ counters.supported -= 1;
+ }
+ });
+
+ #[cfg(feature = "custom_proposal")]
+ {
+ let proposal_type_iter = leaf_node.capabilities.proposals.iter();
+
+ #[cfg(feature = "std")]
+ let proposal_type_iter = proposal_type_iter.unique();
+
+ #[cfg(not(feature = "std"))]
+ let proposal_type_iter = proposal_type_iter.collect::<BTreeSet<_>>().into_iter();
+
+ // Decrement proposal type counters
+ proposal_type_iter.for_each(|proposal_type| {
+ if let Some(supported) = self.proposal_type_counter.get_mut(proposal_type) {
+ *supported -= 1;
+ }
+ })
+ }
+ }
+
+ #[cfg(feature = "custom_proposal")]
+ pub fn count_supporting_proposal(&self, proposal_type: ProposalType) -> u32 {
+ self.proposal_type_counter
+ .get(&proposal_type)
+ .copied()
+ .unwrap_or_default()
+ }
+
+ #[cfg(test)]
+ pub fn len(&self) -> usize {
+ self.credential_signature_key.len()
+ }
+}
+
+#[cfg(feature = "tree_index")]
+#[derive(Clone, Debug, Default, PartialEq, MlsEncode, MlsDecode, MlsSize)]
+struct TypeCounter {
+ supported: u32,
+ used: u32,
+}
+
+#[cfg(feature = "tree_index")]
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::{
+ client::test_utils::TEST_CIPHER_SUITE,
+ tree_kem::leaf_node::test_utils::{get_basic_test_node, get_test_client_identity},
+ };
+ use alloc::format;
+ use assert_matches::assert_matches;
+
+ #[derive(Clone, Debug)]
+ struct TestData {
+ pub leaf_node: LeafNode,
+ pub index: LeafIndex,
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn get_test_data(index: LeafIndex) -> TestData {
+ let cipher_suite = TEST_CIPHER_SUITE;
+ let leaf_node = get_basic_test_node(cipher_suite, &format!("foo{}", index.0)).await;
+
+ TestData { leaf_node, index }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_setup() -> (Vec<TestData>, TreeIndex) {
+ let mut test_data = Vec::new();
+
+ for i in 0..10 {
+ test_data.push(get_test_data(LeafIndex(i)).await);
+ }
+
+ let mut test_index = TreeIndex::new();
+
+ test_data.clone().into_iter().for_each(|d| {
+ test_index
+ .insert(
+ d.index,
+ &d.leaf_node,
+ get_test_client_identity(&d.leaf_node),
+ )
+ .unwrap()
+ });
+
+ (test_data, test_index)
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_insert() {
+ let (test_data, test_index) = test_setup().await;
+
+ assert_eq!(test_index.credential_signature_key.len(), test_data.len());
+ assert_eq!(test_index.hpke_key.len(), test_data.len());
+
+ test_data.into_iter().enumerate().for_each(|(i, d)| {
+ let pub_key = d.leaf_node.signing_identity.signature_key;
+
+ assert_eq!(
+ test_index.credential_signature_key.get(&pub_key),
+ Some(&LeafIndex(i as u32))
+ );
+
+ assert_eq!(
+ test_index.hpke_key.get(&d.leaf_node.public_key),
+ Some(&LeafIndex(i as u32))
+ );
+ })
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_insert_duplicate_credential_key() {
+ let (test_data, mut test_index) = test_setup().await;
+
+ let before_error = test_index.clone();
+
+ let mut new_key_package = get_basic_test_node(TEST_CIPHER_SUITE, "foo").await;
+ new_key_package.signing_identity = test_data[1].leaf_node.signing_identity.clone();
+
+ let res = test_index.insert(
+ test_data[1].index,
+ &new_key_package,
+ get_test_client_identity(&new_key_package),
+ );
+
+ assert_matches!(res, Err(MlsError::DuplicateLeafData(index))
+ if index == *test_data[1].index);
+
+ assert_eq!(before_error, test_index);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_insert_duplicate_hpke_key() {
+ let cipher_suite = TEST_CIPHER_SUITE;
+ let (test_data, mut test_index) = test_setup().await;
+ let before_error = test_index.clone();
+
+ let mut new_leaf_node = get_basic_test_node(cipher_suite, "foo").await;
+ new_leaf_node.public_key = test_data[1].leaf_node.public_key.clone();
+
+ let res = test_index.insert(
+ test_data[1].index,
+ &new_leaf_node,
+ get_test_client_identity(&new_leaf_node),
+ );
+
+ assert_matches!(res, Err(MlsError::DuplicateLeafData(index))
+ if index == *test_data[1].index);
+
+ assert_eq!(before_error, test_index);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_remove() {
+ let (test_data, mut test_index) = test_setup().await;
+
+ test_index.remove(
+ &test_data[1].leaf_node,
+ &get_test_client_identity(&test_data[1].leaf_node),
+ );
+
+ assert_eq!(
+ test_index.credential_signature_key.len(),
+ test_data.len() - 1
+ );
+
+ assert_eq!(test_index.hpke_key.len(), test_data.len() - 1);
+
+ assert_eq!(
+ test_index
+ .credential_signature_key
+ .get(&test_data[1].leaf_node.signing_identity.signature_key),
+ None
+ );
+
+ assert_eq!(
+ test_index.hpke_key.get(&test_data[1].leaf_node.public_key),
+ None
+ );
+ }
+
+ #[cfg(feature = "custom_proposal")]
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn custom_proposals() {
+ let test_proposal_id = ProposalType::new(42);
+ let other_proposal_id = ProposalType::new(45);
+
+ let mut test_data_1 = get_test_data(LeafIndex(0)).await;
+
+ test_data_1
+ .leaf_node
+ .capabilities
+ .proposals
+ .push(test_proposal_id);
+
+ let mut test_data_2 = get_test_data(LeafIndex(1)).await;
+
+ test_data_2
+ .leaf_node
+ .capabilities
+ .proposals
+ .push(test_proposal_id);
+
+ test_data_2
+ .leaf_node
+ .capabilities
+ .proposals
+ .push(other_proposal_id);
+
+ let mut test_index = TreeIndex::new();
+
+ test_index
+ .insert(test_data_1.index, &test_data_1.leaf_node, vec![0])
+ .unwrap();
+
+ assert_eq!(test_index.count_supporting_proposal(test_proposal_id), 1);
+
+ test_index
+ .insert(test_data_2.index, &test_data_2.leaf_node, vec![1])
+ .unwrap();
+
+ assert_eq!(test_index.count_supporting_proposal(test_proposal_id), 2);
+ assert_eq!(test_index.count_supporting_proposal(other_proposal_id), 1);
+
+ test_index.remove(&test_data_2.leaf_node, &[1]);
+
+ assert_eq!(test_index.count_supporting_proposal(test_proposal_id), 1);
+ assert_eq!(test_index.count_supporting_proposal(other_proposal_id), 0);
+ }
+}
diff --git a/src/tree_kem/tree_utils.rs b/src/tree_kem/tree_utils.rs
new file mode 100644
index 0000000..e7cdeb1
--- /dev/null
+++ b/src/tree_kem/tree_utils.rs
@@ -0,0 +1,191 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::string::String;
+use alloc::{format, vec};
+use core::borrow::BorrowMut;
+
+use debug_tree::TreeBuilder;
+
+use super::node::{NodeIndex, NodeVec};
+use crate::{client::MlsError, tree_kem::math::TreeIndex};
+
+pub(crate) fn build_tree(
+ tree: &mut TreeBuilder,
+ nodes: &NodeVec,
+ idx: NodeIndex,
+) -> Result<(), MlsError> {
+ let blank_tag = if nodes.is_blank(idx)? { "Blank " } else { "" };
+
+ // Leaf Node
+ if nodes.is_leaf(idx) {
+ let leaf_tag = format!("{blank_tag}Leaf ({idx})");
+ tree.add_leaf(&leaf_tag);
+ return Ok(());
+ }
+
+ // Parent Leaf
+ let mut parent_tag = format!("{blank_tag}Parent ({idx})");
+
+ if nodes.total_leaf_count().root() == idx {
+ parent_tag = format!("{blank_tag}Root ({idx})");
+ }
+
+ // Add unmerged leaves indexes
+ let unmerged_leaves_idxs = match nodes.borrow_as_parent(idx) {
+ Ok(parent) => parent
+ .unmerged_leaves
+ .iter()
+ .map(|leaf_idx| format!("{}", leaf_idx.0))
+ .collect(),
+ Err(_) => {
+ // Empty parent nodes throw `NotParent` error when borrow as Parent
+ vec![]
+ }
+ };
+
+ if !unmerged_leaves_idxs.is_empty() {
+ let unmerged_leaves_tag =
+ format!(" unmerged leaves idxs: {}", unmerged_leaves_idxs.join(","));
+ parent_tag.push_str(&unmerged_leaves_tag);
+ }
+
+ let mut branch = tree.add_branch(&parent_tag);
+
+ //This cannot panic, as we already checked that idx is not a leaf
+ build_tree(tree, nodes, idx.left_unchecked())?;
+ build_tree(tree, nodes, idx.right_unchecked())?;
+
+ branch.release();
+
+ Ok(())
+}
+
+pub(crate) fn build_ascii_tree(nodes: &NodeVec) -> String {
+ let leaves_count: u32 = nodes.total_leaf_count();
+ let mut tree = TreeBuilder::new();
+ build_tree(tree.borrow_mut(), nodes, leaves_count.root()).unwrap();
+ tree.string()
+}
+
+#[cfg(test)]
+mod tests {
+ use alloc::vec;
+
+ use crate::{
+ client::test_utils::TEST_CIPHER_SUITE,
+ crypto::test_utils::test_cipher_suite_provider,
+ identity::basic::BasicIdentityProvider,
+ tree_kem::{
+ node::Parent,
+ parent_hash::ParentHash,
+ test_utils::{get_test_leaf_nodes, get_test_tree},
+ },
+ };
+
+ use super::build_ascii_tree;
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn print_fully_populated_tree() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ // Create a tree
+ let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
+ let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+ tree.add_leaves(key_packages, &BasicIdentityProvider, &cipher_suite_provider)
+ .await
+ .unwrap();
+
+ let tree_str = concat!(
+ "Blank Root (3)\n",
+ "├╼ Blank Parent (1)\n",
+ "│ ├╼ Leaf (0)\n",
+ "│ └╼ Leaf (2)\n",
+ "└╼ Blank Parent (5)\n",
+ " ├╼ Leaf (4)\n",
+ " └╼ Leaf (6)",
+ );
+
+ assert_eq!(tree_str, build_ascii_tree(&tree.nodes));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn print_tree_blank_leaves() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ // Create a tree
+ let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
+ let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+ let to_remove = tree
+ .add_leaves(key_packages, &BasicIdentityProvider, &cipher_suite_provider)
+ .await
+ .unwrap()[0];
+
+ tree.remove_leaves(
+ vec![to_remove],
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap();
+
+ let tree_str = concat!(
+ "Blank Root (3)\n",
+ "├╼ Blank Parent (1)\n",
+ "│ ├╼ Leaf (0)\n",
+ "│ └╼ Blank Leaf (2)\n",
+ "└╼ Blank Parent (5)\n",
+ " ├╼ Leaf (4)\n",
+ " └╼ Leaf (6)",
+ );
+
+ assert_eq!(tree_str, build_ascii_tree(&tree.nodes));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn print_tree_unmerged_leaves_on_parent() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+
+ // Create a tree
+ let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
+ let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;
+
+ tree.add_leaves(
+ [key_packages[0].clone(), key_packages[1].clone()].to_vec(),
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap();
+
+ tree.nodes[3] = Parent {
+ public_key: vec![].into(),
+ parent_hash: ParentHash::empty(),
+ unmerged_leaves: vec![],
+ }
+ .into();
+
+ tree.add_leaves(
+ [key_packages[2].clone()].to_vec(),
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap();
+
+ let tree_str = concat!(
+ "Root (3) unmerged leaves idxs: 3\n",
+ "├╼ Blank Parent (1)\n",
+ "│ ├╼ Leaf (0)\n",
+ "│ └╼ Leaf (2)\n",
+ "└╼ Blank Parent (5)\n",
+ " ├╼ Leaf (4)\n",
+ " └╼ Leaf (6)",
+ );
+
+ assert_eq!(tree_str, build_ascii_tree(&tree.nodes));
+ }
+}
diff --git a/src/tree_kem/tree_validator.rs b/src/tree_kem/tree_validator.rs
new file mode 100644
index 0000000..26d4baf
--- /dev/null
+++ b/src/tree_kem/tree_validator.rs
@@ -0,0 +1,356 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+#[cfg(feature = "std")]
+use std::collections::HashSet;
+
+#[cfg(not(feature = "std"))]
+use alloc::{vec, vec::Vec};
+use tree_math::TreeIndex;
+
+use super::node::{Node, NodeIndex};
+use crate::client::MlsError;
+use crate::crypto::CipherSuiteProvider;
+use crate::group::GroupContext;
+use crate::iter::wrap_impl_iter;
+use crate::tree_kem::math as tree_math;
+use crate::tree_kem::{leaf_node_validator::LeafNodeValidator, TreeKemPublic};
+use mls_rs_core::identity::IdentityProvider;
+
+#[cfg(all(not(mls_build_async), feature = "rayon"))]
+use rayon::prelude::*;
+
+#[cfg(mls_build_async)]
+use futures::{StreamExt, TryStreamExt};
+
+pub(crate) struct TreeValidator<'a, C, CSP>
+where
+ C: IdentityProvider,
+ CSP: CipherSuiteProvider,
+{
+ expected_tree_hash: &'a [u8],
+ leaf_node_validator: LeafNodeValidator<'a, C, CSP>,
+ group_id: &'a [u8],
+ cipher_suite_provider: &'a CSP,
+}
+
+impl<'a, C: IdentityProvider, CSP: CipherSuiteProvider> TreeValidator<'a, C, CSP> {
+ pub fn new(
+ cipher_suite_provider: &'a CSP,
+ context: &'a GroupContext,
+ identity_provider: &'a C,
+ ) -> Self {
+ TreeValidator {
+ expected_tree_hash: &context.tree_hash,
+ leaf_node_validator: LeafNodeValidator::new(
+ cipher_suite_provider,
+ identity_provider,
+ Some(&context.extensions),
+ ),
+ group_id: &context.group_id,
+ cipher_suite_provider,
+ }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ pub async fn validate(&self, tree: &mut TreeKemPublic) -> Result<(), MlsError> {
+ self.validate_tree_hash(tree).await?;
+
+ tree.validate_parent_hashes(self.cipher_suite_provider)
+ .await?;
+
+ self.validate_no_trailing_blanks(tree)?;
+ self.validate_leaves(tree).await?;
+ validate_unmerged(tree)
+ }
+
+ fn validate_no_trailing_blanks(&self, tree: &TreeKemPublic) -> Result<(), MlsError> {
+ tree.nodes
+ .last()
+ .ok_or(MlsError::UnexpectedEmptyTree)?
+ .is_some()
+ .then_some(())
+ .ok_or(MlsError::UnexpectedTrailingBlanks)
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn validate_tree_hash(&self, tree: &mut TreeKemPublic) -> Result<(), MlsError> {
+ //Verify that the tree hash of the ratchet tree matches the tree_hash field in the GroupInfo.
+ let tree_hash = tree.tree_hash(self.cipher_suite_provider).await?;
+
+ if tree_hash != self.expected_tree_hash {
+ return Err(MlsError::TreeHashMismatch);
+ }
+
+ Ok(())
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn validate_leaves(&self, tree: &TreeKemPublic) -> Result<(), MlsError> {
+ let leaves = wrap_impl_iter(tree.nodes.non_empty_leaves());
+
+ #[cfg(mls_build_async)]
+ let leaves = leaves.map(Ok);
+
+ { leaves }
+ .try_for_each(|(index, leaf_node)| async move {
+ self.leaf_node_validator
+ .revalidate(leaf_node, self.group_id, *index)
+ .await
+ })
+ .await
+ }
+}
+
+fn validate_unmerged(tree: &TreeKemPublic) -> Result<(), MlsError> {
+ let unmerged_sets = tree.nodes.iter().map(|n| {
+ #[cfg(feature = "std")]
+ if let Some(Node::Parent(p)) = n {
+ HashSet::from_iter(p.unmerged_leaves.iter().cloned())
+ } else {
+ HashSet::new()
+ }
+
+ #[cfg(not(feature = "std"))]
+ if let Some(Node::Parent(p)) = n {
+ p.unmerged_leaves.clone()
+ } else {
+ vec![]
+ }
+ });
+
+ let mut unmerged_sets = unmerged_sets.collect::<Vec<_>>();
+
+ // For each leaf L, we search for the longest prefix P[1], P[2], ..., P[k] of the direct path of L
+ // such that for each i=1..k, either L is in the unmerged leaves of P[i], or P[i] is blank. We will
+ // then check that L is unmerged at each P[1], ..., P[k] and no other node.
+ let leaf_count = tree.total_leaf_count();
+
+ for (index, _) in tree.nodes.non_empty_leaves() {
+ let mut n = NodeIndex::from(index);
+
+ while let Some(ps) = n.parent_sibling(&leaf_count) {
+ if tree.nodes.is_blank(ps.parent)? {
+ n = ps.parent;
+ continue;
+ }
+
+ let parent_node = tree.nodes.borrow_as_parent(ps.parent)?;
+
+ if parent_node.unmerged_leaves.contains(&index) {
+ unmerged_sets[ps.parent as usize].retain(|i| i != &index);
+
+ n = ps.parent;
+ } else {
+ break;
+ }
+ }
+ }
+
+ let unmerged_sets = unmerged_sets.iter().all(|set| set.is_empty());
+
+ unmerged_sets
+ .then_some(())
+ .ok_or(MlsError::UnmergedLeavesMismatch)
+}
+
+#[cfg(test)]
+mod tests {
+ use alloc::vec;
+ use assert_matches::assert_matches;
+
+ use super::*;
+ use crate::{
+ cipher_suite::CipherSuite,
+ client::test_utils::TEST_CIPHER_SUITE,
+ crypto::test_utils::test_cipher_suite_provider,
+ crypto::test_utils::TestCryptoProvider,
+ group::test_utils::{get_test_group_context, random_bytes},
+ identity::basic::BasicIdentityProvider,
+ tree_kem::{
+ kem::TreeKem,
+ leaf_node::test_utils::{default_properties, get_basic_test_node},
+ node::{LeafIndex, Node, Parent},
+ parent_hash::{test_utils::get_test_tree_fig_12, ParentHash},
+ test_utils::get_test_tree,
+ },
+ };
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_parent_node(cipher_suite: CipherSuite) -> Parent {
+ let (_, public_key) = test_cipher_suite_provider(cipher_suite)
+ .kem_generate()
+ .await
+ .unwrap();
+
+ Parent {
+ public_key,
+ parent_hash: ParentHash::empty(),
+ unmerged_leaves: vec![],
+ }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn get_valid_tree(cipher_suite: CipherSuite) -> TreeKemPublic {
+ let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);
+
+ let mut test_tree = get_test_tree(cipher_suite).await;
+
+ let leaf1 = get_basic_test_node(cipher_suite, "leaf1").await;
+ let leaf2 = get_basic_test_node(cipher_suite, "leaf2").await;
+
+ test_tree
+ .public
+ .add_leaves(
+ vec![leaf1, leaf2],
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ )
+ .await
+ .unwrap();
+
+ test_tree.public.nodes[1] = Some(Node::Parent(test_parent_node(cipher_suite).await));
+ test_tree.public.nodes[3] = Some(Node::Parent(test_parent_node(cipher_suite).await));
+
+ TreeKem::new(&mut test_tree.public, &mut test_tree.private)
+ .encap(
+ &mut get_test_group_context(42, cipher_suite).await,
+ &[LeafIndex(1), LeafIndex(2)],
+ &test_tree.creator_signing_key,
+ default_properties(),
+ None,
+ &cipher_suite_provider,
+ #[cfg(test)]
+ &Default::default(),
+ )
+ .await
+ .unwrap();
+
+ test_tree.public
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_valid_tree() {
+ for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+ let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);
+
+ let mut test_tree = get_valid_tree(cipher_suite).await;
+
+ let mut context = get_test_group_context(1, cipher_suite).await;
+ context.tree_hash = test_tree.tree_hash(&cipher_suite_provider).await.unwrap();
+
+ let validator =
+ TreeValidator::new(&cipher_suite_provider, &context, &BasicIdentityProvider);
+
+ validator.validate(&mut test_tree).await.unwrap();
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_tree_hash_mismatch() {
+ for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+ let mut test_tree = get_valid_tree(cipher_suite).await;
+
+ let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);
+ let context = get_test_group_context(1, cipher_suite).await;
+
+ let validator =
+ TreeValidator::new(&cipher_suite_provider, &context, &BasicIdentityProvider);
+
+ let res = validator.validate(&mut test_tree).await;
+
+ assert_matches!(res, Err(MlsError::TreeHashMismatch));
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_parent_hash_mismatch() {
+ for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+ let mut test_tree = get_valid_tree(cipher_suite).await;
+
+ let parent_node = test_tree.nodes.borrow_as_parent_mut(1).unwrap();
+ parent_node.parent_hash = ParentHash::from(random_bytes(32));
+
+ let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);
+ let mut context = get_test_group_context(1, cipher_suite).await;
+ context.tree_hash = test_tree.tree_hash(&cipher_suite_provider).await.unwrap();
+
+ let validator =
+ TreeValidator::new(&cipher_suite_provider, &context, &BasicIdentityProvider);
+
+ let res = validator.validate(&mut test_tree).await;
+
+ assert_matches!(res, Err(MlsError::ParentHashMismatch));
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_key_package_validation_failure() {
+ for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
+ let mut test_tree = get_valid_tree(cipher_suite).await;
+
+ test_tree
+ .nodes
+ .borrow_as_leaf_mut(LeafIndex(0))
+ .unwrap()
+ .signature = random_bytes(32);
+
+ let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);
+ let mut context = get_test_group_context(1, cipher_suite).await;
+ context.tree_hash = test_tree.tree_hash(&cipher_suite_provider).await.unwrap();
+
+ let validator =
+ TreeValidator::new(&cipher_suite_provider, &context, &BasicIdentityProvider);
+
+ let res = validator.validate(&mut test_tree).await;
+
+ assert_matches!(res, Err(MlsError::InvalidSignature));
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn verify_unmerged_with_correct_tree() {
+ let tree = get_test_tree_fig_12(TEST_CIPHER_SUITE).await;
+ validate_unmerged(&tree).unwrap();
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn verify_unmerged_with_blank_leaf() {
+ let mut tree = get_test_tree_fig_12(TEST_CIPHER_SUITE).await;
+
+ // Blank leaf D unmerged at nodes 3, 7
+ tree.nodes[6] = None;
+
+ assert_matches!(
+ validate_unmerged(&tree),
+ Err(MlsError::UnmergedLeavesMismatch)
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn verify_unmerged_with_broken_path() {
+ let mut tree = get_test_tree_fig_12(TEST_CIPHER_SUITE).await;
+
+ // Make D with direct path [3, 7] unmerged at 7 but not 3
+ tree.nodes.borrow_as_parent_mut(3).unwrap().unmerged_leaves = vec![];
+
+ assert_matches!(
+ validate_unmerged(&tree),
+ Err(MlsError::UnmergedLeavesMismatch)
+ );
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn verify_unmerged_with_leaf_outside_tree() {
+ let mut tree = get_test_tree_fig_12(TEST_CIPHER_SUITE).await;
+
+ // Add leaf E from the right subtree of the root to unmerged leaves of node 1 on the left
+ tree.nodes.borrow_as_parent_mut(1).unwrap().unmerged_leaves = vec![LeafIndex(4)];
+
+ assert_matches!(
+ validate_unmerged(&tree),
+ Err(MlsError::UnmergedLeavesMismatch)
+ );
+ }
+}
diff --git a/src/tree_kem/update_path.rs b/src/tree_kem/update_path.rs
new file mode 100644
index 0000000..654c21f
--- /dev/null
+++ b/src/tree_kem/update_path.rs
@@ -0,0 +1,274 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use alloc::{vec, vec::Vec};
+use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
+use mls_rs_core::{error::IntoAnyError, identity::IdentityProvider};
+
+use super::{
+ leaf_node::LeafNode,
+ leaf_node_validator::{LeafNodeValidator, ValidationContext},
+ node::LeafIndex,
+};
+use crate::{
+ client::MlsError,
+ crypto::{CipherSuiteProvider, HpkeCiphertext, HpkePublicKey},
+};
+use crate::{group::message_processor::ProvisionalState, time::MlsTime};
+
+#[derive(Clone, Debug, PartialEq, Eq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct UpdatePathNode {
+ pub public_key: HpkePublicKey,
+ pub encrypted_path_secret: Vec<HpkeCiphertext>,
+}
+
+#[derive(Clone, Debug, PartialEq, MlsSize, MlsEncode, MlsDecode)]
+#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
+#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
+pub struct UpdatePath {
+ pub leaf_node: LeafNode,
+ pub nodes: Vec<UpdatePathNode>,
+}
+
+#[derive(Clone, Debug, PartialEq)]
+pub struct ValidatedUpdatePath {
+ pub leaf_node: LeafNode,
+ pub nodes: Vec<Option<UpdatePathNode>>,
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub(crate) async fn validate_update_path<C: IdentityProvider, CSP: CipherSuiteProvider>(
+ identity_provider: &C,
+ cipher_suite_provider: &CSP,
+ path: UpdatePath,
+ state: &ProvisionalState,
+ sender: LeafIndex,
+ commit_time: Option<MlsTime>,
+) -> Result<ValidatedUpdatePath, MlsError> {
+ let group_context_extensions = &state.group_context.extensions;
+
+ let leaf_validator = LeafNodeValidator::new(
+ cipher_suite_provider,
+ identity_provider,
+ Some(group_context_extensions),
+ );
+
+ leaf_validator
+ .check_if_valid(
+ &path.leaf_node,
+ ValidationContext::Commit((&state.group_context.group_id, *sender, commit_time)),
+ )
+ .await?;
+
+ let check_identity_eq = state.applied_proposals.external_initializations.is_empty();
+
+ if check_identity_eq {
+ let existing_leaf = state.public_tree.nodes.borrow_as_leaf(sender)?;
+ let original_leaf_node = existing_leaf.clone();
+
+ identity_provider
+ .valid_successor(
+ &original_leaf_node.signing_identity,
+ &path.leaf_node.signing_identity,
+ group_context_extensions,
+ )
+ .await
+ .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?
+ .then_some(())
+ .ok_or(MlsError::InvalidSuccessor)?;
+
+ (existing_leaf.public_key != path.leaf_node.public_key)
+ .then_some(())
+ .ok_or(MlsError::SameHpkeKey(*sender))?;
+ }
+
+ // Unfilter the update path
+ let filtered = state.public_tree.nodes.filtered(sender)?;
+ let mut unfiltered_nodes = vec![];
+ let mut i = 0;
+
+ for n in path.nodes {
+ while *filtered.get(i).ok_or(MlsError::WrongPathLen)? {
+ unfiltered_nodes.push(None);
+ i += 1;
+ }
+
+ unfiltered_nodes.push(Some(n));
+ i += 1;
+ }
+
+ Ok(ValidatedUpdatePath {
+ leaf_node: path.leaf_node,
+ nodes: unfiltered_nodes,
+ })
+}
+
+#[cfg(test)]
+mod tests {
+ use alloc::vec;
+ use assert_matches::assert_matches;
+
+ use crate::client::test_utils::TEST_CIPHER_SUITE;
+ use crate::crypto::test_utils::test_cipher_suite_provider;
+ use crate::crypto::HpkeCiphertext;
+ use crate::group::message_processor::ProvisionalState;
+ use crate::group::test_utils::{get_test_group_context, random_bytes, TEST_GROUP};
+ use crate::identity::basic::BasicIdentityProvider;
+ use crate::tree_kem::leaf_node::test_utils::default_properties;
+ use crate::tree_kem::leaf_node::test_utils::get_basic_test_node_sig_key;
+ use crate::tree_kem::leaf_node::LeafNodeSource;
+ use crate::tree_kem::node::LeafIndex;
+ use crate::tree_kem::parent_hash::ParentHash;
+ use crate::tree_kem::test_utils::{get_test_leaf_nodes, get_test_tree};
+ use crate::tree_kem::validate_update_path;
+
+ use super::{UpdatePath, UpdatePathNode};
+ use crate::{cipher_suite::CipherSuite, tree_kem::MlsError};
+
+ use alloc::vec::Vec;
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_update_path(cipher_suite: CipherSuite, cred: &str) -> UpdatePath {
+ let (mut leaf_node, _, signer) = get_basic_test_node_sig_key(cipher_suite, cred).await;
+
+ leaf_node.leaf_node_source = LeafNodeSource::Commit(ParentHash::from(hex!("beef")));
+
+ leaf_node
+ .commit(
+ &test_cipher_suite_provider(cipher_suite),
+ TEST_GROUP,
+ 0,
+ default_properties(),
+ None,
+ &signer,
+ )
+ .await
+ .unwrap();
+
+ let node = UpdatePathNode {
+ public_key: random_bytes(32).into(),
+ encrypted_path_secret: vec![HpkeCiphertext {
+ kem_output: random_bytes(32),
+ ciphertext: random_bytes(32),
+ }],
+ };
+
+ UpdatePath {
+ leaf_node,
+ nodes: vec![node.clone(), node],
+ }
+ }
+
+ #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+ async fn test_provisional_state(cipher_suite: CipherSuite) -> ProvisionalState {
+ let mut tree = get_test_tree(cipher_suite).await.public;
+ let leaf_nodes = get_test_leaf_nodes(cipher_suite).await;
+
+ tree.add_leaves(
+ leaf_nodes,
+ &BasicIdentityProvider,
+ &test_cipher_suite_provider(cipher_suite),
+ )
+ .await
+ .unwrap();
+
+ ProvisionalState {
+ public_tree: tree,
+ applied_proposals: Default::default(),
+ group_context: get_test_group_context(1, cipher_suite).await,
+ indexes_of_added_kpkgs: vec![],
+ external_init_index: None,
+ #[cfg(feature = "state_update")]
+ unused_proposals: vec![],
+ }
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_valid_leaf_node() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+ let update_path = test_update_path(TEST_CIPHER_SUITE, "creator").await;
+
+ let validated = validate_update_path(
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ update_path.clone(),
+ &test_provisional_state(TEST_CIPHER_SUITE).await,
+ LeafIndex(0),
+ None,
+ )
+ .await
+ .unwrap();
+
+ let expected = update_path.nodes.into_iter().map(Some).collect::<Vec<_>>();
+
+ assert_eq!(validated.nodes, expected);
+ assert_eq!(validated.leaf_node, update_path.leaf_node);
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn test_invalid_key_package() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+ let mut update_path = test_update_path(TEST_CIPHER_SUITE, "creator").await;
+ update_path.leaf_node.signature = random_bytes(32);
+
+ let validated = validate_update_path(
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ update_path,
+ &test_provisional_state(TEST_CIPHER_SUITE).await,
+ LeafIndex(0),
+ None,
+ )
+ .await;
+
+ assert_matches!(validated, Err(MlsError::InvalidSignature));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn validating_path_fails_with_different_identity() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+ let cipher_suite = TEST_CIPHER_SUITE;
+ let update_path = test_update_path(cipher_suite, "foobar").await;
+
+ let validated = validate_update_path(
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ update_path,
+ &test_provisional_state(cipher_suite).await,
+ LeafIndex(0),
+ None,
+ )
+ .await;
+
+ assert_matches!(validated, Err(MlsError::InvalidSuccessor));
+ }
+
+ #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
+ async fn validating_path_fails_with_same_hpke_key() {
+ let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
+ let update_path = test_update_path(TEST_CIPHER_SUITE, "creator").await;
+ let mut state = test_provisional_state(TEST_CIPHER_SUITE).await;
+
+ state
+ .public_tree
+ .nodes
+ .borrow_as_leaf_mut(LeafIndex(0))
+ .unwrap()
+ .public_key = update_path.leaf_node.public_key.clone();
+
+ let validated = validate_update_path(
+ &BasicIdentityProvider,
+ &cipher_suite_provider,
+ update_path,
+ &state,
+ LeafIndex(0),
+ None,
+ )
+ .await;
+
+ assert_matches!(validated, Err(MlsError::SameHpkeKey(_)));
+ }
+}
diff --git a/test_utils/src/scenario_utils.rs b/test_utils/src/scenario_utils.rs
new file mode 100644
index 0000000..153bc8f
--- /dev/null
+++ b/test_utils/src/scenario_utils.rs
@@ -0,0 +1,338 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use mls_rs::client_builder::Preferences;
+use mls_rs::group::{ReceivedMessage, StateUpdate};
+use mls_rs::{CipherSuite, ExtensionList, Group, MlsMessage, ProtocolVersion};
+
+use crate::test_client::{generate_client, TestClientConfig};
+
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+pub struct TestCase {
+ pub cipher_suite: u16,
+
+ pub external_psks: Vec<TestExternalPsk>,
+ #[serde(with = "hex::serde")]
+ pub key_package: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub signature_priv: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub encryption_priv: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub init_priv: Vec<u8>,
+
+ #[serde(with = "hex::serde")]
+ pub welcome: Vec<u8>,
+ pub ratchet_tree: Option<TestRatchetTree>,
+ #[serde(with = "hex::serde")]
+ pub initial_epoch_authenticator: Vec<u8>,
+
+ pub epochs: Vec<TestEpoch>,
+}
+
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+pub struct TestExternalPsk {
+ #[serde(with = "hex::serde")]
+ pub psk_id: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub psk: Vec<u8>,
+}
+
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+pub struct TestEpoch {
+ pub proposals: Vec<TestMlsMessage>,
+ #[serde(with = "hex::serde")]
+ pub commit: Vec<u8>,
+ #[serde(with = "hex::serde")]
+ pub epoch_authenticator: Vec<u8>,
+}
+
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+pub struct TestMlsMessage(#[serde(with = "hex::serde")] pub Vec<u8>);
+
+#[derive(serde::Serialize, serde::Deserialize, Debug, Default, Clone)]
+pub struct TestRatchetTree(#[serde(with = "hex::serde")] pub Vec<u8>);
+
+impl TestEpoch {
+ pub fn new(
+ proposals: Vec<MlsMessage>,
+ commit: &MlsMessage,
+ epoch_authenticator: Vec<u8>,
+ ) -> Self {
+ let proposals = proposals
+ .into_iter()
+ .map(|p| TestMlsMessage(p.to_bytes().unwrap()))
+ .collect();
+
+ Self {
+ proposals,
+ commit: commit.to_bytes().unwrap(),
+ epoch_authenticator,
+ }
+ }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub async fn get_test_groups(
+ protocol_version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+ num_participants: usize,
+ preferences: Preferences,
+) -> Vec<Group<TestClientConfig>> {
+ // Create the group with Alice as the group initiator
+ let creator = generate_client(cipher_suite, b"alice".to_vec(), preferences.clone());
+
+ let mut creator_group = creator
+ .client
+ .create_group_with_id(
+ protocol_version,
+ cipher_suite,
+ b"group".to_vec(),
+ creator.identity,
+ ExtensionList::default(),
+ )
+ .await
+ .unwrap();
+
+ // Generate random clients that will be members of the group
+ let receiver_clients = (0..num_participants - 1)
+ .map(|i| {
+ generate_client(
+ cipher_suite,
+ format!("bob{i}").into_bytes(),
+ preferences.clone(),
+ )
+ })
+ .collect::<Vec<_>>();
+
+ let mut receiver_keys = Vec::new();
+
+ for client in &receiver_clients {
+ let keys = client
+ .client
+ .generate_key_package_message(protocol_version, cipher_suite, client.identity.clone())
+ .await
+ .unwrap();
+
+ receiver_keys.push(keys);
+ }
+
+ // Add the generated clients to the group the creator made
+ let mut commit_builder = creator_group.commit_builder();
+
+ for key in &receiver_keys {
+ commit_builder = commit_builder.add_member(key.clone()).unwrap();
+ }
+
+ let welcome = commit_builder.build().await.unwrap().welcome_message;
+
+ // Creator can confirm the commit was processed by the server
+ #[cfg(feature = "state_update")]
+ {
+ let commit_description = creator_group.apply_pending_commit().await.unwrap();
+
+ assert!(commit_description.state_update.is_active());
+ assert_eq!(commit_description.state_update.new_epoch(), 1);
+ }
+
+ #[cfg(not(feature = "state_update"))]
+ creator_group.apply_pending_commit().await.unwrap();
+
+ for client in &receiver_clients {
+ let res = creator_group
+ .member_with_identity(client.identity.credential.as_basic().unwrap().identifier())
+ .await;
+
+ assert!(res.is_ok());
+ }
+
+ #[cfg(feature = "state_update")]
+ assert!(commit_description
+ .state_update
+ .roster_update()
+ .removed()
+ .is_empty());
+
+ // Export the tree for receivers
+ let tree_data = creator_group.export_tree().unwrap();
+
+ // All the receivers will be able to join the group
+ let mut receiver_groups = Vec::new();
+
+ for client in &receiver_clients {
+ let test_client = client
+ .client
+ .join_group(Some(&tree_data), welcome.clone().unwrap())
+ .await
+ .unwrap()
+ .0;
+
+ receiver_groups.push(test_client);
+ }
+
+ for one_receiver in &receiver_groups {
+ assert!(Group::equal_group_state(&creator_group, one_receiver));
+ }
+
+ receiver_groups.insert(0, creator_group);
+
+ receiver_groups
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub async fn all_process_commit_with_update(
+ groups: &mut [Group<TestClientConfig>],
+ commit: &MlsMessage,
+ sender: usize,
+) -> Vec<StateUpdate> {
+ let mut state_updates = Vec::new();
+
+ for g in groups {
+ let state_update = if sender != g.current_member_index() as usize {
+ let processed_msg = g.process_incoming_message(commit.clone()).await.unwrap();
+
+ match processed_msg {
+ ReceivedMessage::Commit(update) => update.state_update,
+ _ => panic!("Expected commit, got {processed_msg:?}"),
+ }
+ } else {
+ g.apply_pending_commit().await.unwrap().state_update
+ };
+
+ state_updates.push(state_update);
+ }
+
+ state_updates
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub async fn all_process_message(
+ groups: &mut [Group<TestClientConfig>],
+ message: &MlsMessage,
+ sender: usize,
+ is_commit: bool,
+) {
+ for group in groups {
+ if sender != group.current_member_index() as usize {
+ group
+ .process_incoming_message(message.clone())
+ .await
+ .unwrap();
+ } else if is_commit {
+ group.apply_pending_commit().await.unwrap();
+ }
+ }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub async fn add_random_members(
+ first_id: usize,
+ num_added: usize,
+ committer: usize,
+ groups: &mut Vec<Group<TestClientConfig>>,
+ test_case: Option<&mut TestCase>,
+) {
+ let cipher_suite = groups[committer].cipher_suite();
+ let committer_index = groups[committer].current_member_index() as usize;
+
+ let mut key_packages = Vec::new();
+ let mut new_clients = Vec::new();
+
+ for i in 0..num_added {
+ let id = first_id + i;
+ let new_client = generate_client(
+ cipher_suite,
+ format!("dave-{id}").into(),
+ Preferences::default(),
+ );
+
+ let key_package = new_client
+ .client
+ .generate_key_package_message(
+ ProtocolVersion::MLS_10,
+ cipher_suite,
+ new_client.identity.clone(),
+ )
+ .await
+ .unwrap();
+
+ key_packages.push(key_package);
+ new_clients.push(new_client);
+ }
+
+ let committer_group = &mut groups[committer];
+ let mut commit = committer_group.commit_builder();
+
+ for key_package in key_packages {
+ commit = commit.add_member(key_package).unwrap();
+ }
+
+ let commit_output = commit.build().await.unwrap();
+
+ all_process_message(groups, &commit_output.commit_message, committer_index, true).await;
+
+ let auth = groups[committer].epoch_authenticator().unwrap().to_vec();
+ let epoch = TestEpoch::new(vec![], &commit_output.commit_message, auth);
+
+ if let Some(tc) = test_case {
+ tc.epochs.push(epoch)
+ };
+
+ let tree_data = groups[committer].export_tree().unwrap();
+
+ let mut new_groups = Vec::new();
+
+ for client in &new_clients {
+ let tree_data = tree_data.clone();
+ let commit = commit_output.welcome_message.clone().unwrap();
+
+ let client = client
+ .client
+ .join_group(Some(&tree_data.clone()), commit)
+ .await
+ .unwrap()
+ .0;
+
+ new_groups.push(client);
+ }
+
+ groups.append(&mut new_groups);
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub async fn remove_members(
+ removed_members: Vec<usize>,
+ committer: usize,
+ groups: &mut Vec<Group<TestClientConfig>>,
+ test_case: Option<&mut TestCase>,
+) {
+ let remove_indexes = removed_members
+ .iter()
+ .map(|removed| groups[*removed].current_member_index())
+ .collect::<Vec<u32>>();
+
+ let mut commit_builder = groups[committer].commit_builder();
+
+ for index in remove_indexes {
+ commit_builder = commit_builder.remove_member(index).unwrap();
+ }
+
+ let commit = commit_builder.build().await.unwrap().commit_message;
+ let committer_index = groups[committer].current_member_index() as usize;
+ all_process_message(groups, &commit, committer_index, true).await;
+
+ let auth = groups[committer].epoch_authenticator().unwrap().to_vec();
+ let epoch = TestEpoch::new(vec![], &commit, auth);
+
+ if let Some(tc) = test_case {
+ tc.epochs.push(epoch)
+ };
+
+ let mut index = 0;
+
+ groups.retain(|_| {
+ index += 1;
+ !(removed_members.contains(&(index - 1)))
+ });
+}
diff --git a/tests/client_tests.rs b/tests/client_tests.rs
new file mode 100644
index 0000000..61e46ce
--- /dev/null
+++ b/tests/client_tests.rs
@@ -0,0 +1,847 @@
+// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+// Copyright by contributors to this project.
+// SPDX-License-Identifier: (Apache-2.0 OR MIT)
+
+use assert_matches::assert_matches;
+use cfg_if::cfg_if;
+use mls_rs::client_builder::MlsConfig;
+use mls_rs::error::MlsError;
+use mls_rs::group::proposal::Proposal;
+use mls_rs::group::ReceivedMessage;
+use mls_rs::identity::SigningIdentity;
+use mls_rs::mls_rules::CommitOptions;
+use mls_rs::ExtensionList;
+use mls_rs::MlsMessage;
+use mls_rs::ProtocolVersion;
+use mls_rs::{CipherSuite, Group};
+use mls_rs::{Client, CryptoProvider};
+use mls_rs_core::crypto::CipherSuiteProvider;
+use rand::prelude::SliceRandom;
+use rand::RngCore;
+
+use mls_rs::test_utils::{all_process_message, get_test_basic_credential};
+
+#[cfg(mls_build_async)]
+use futures::Future;
+
+cfg_if! {
+ if #[cfg(target_arch = "wasm32")] {
+ use mls_rs_crypto_webcrypto::WebCryptoProvider as TestCryptoProvider;
+ } else {
+ use mls_rs_crypto_openssl::OpensslCryptoProvider as TestCryptoProvider;
+ }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn generate_client(
+ cipher_suite: CipherSuite,
+ protocol_version: ProtocolVersion,
+ id: usize,
+ encrypt_controls: bool,
+) -> Client<impl MlsConfig> {
+ mls_rs::test_utils::generate_basic_client(
+ cipher_suite,
+ protocol_version,
+ id,
+ None,
+ encrypt_controls,
+ &TestCryptoProvider::default(),
+ None,
+ )
+ .await
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+pub async fn get_test_groups(
+ version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+ num_participants: usize,
+ encrypt_controls: bool,
+) -> Vec<Group<impl MlsConfig>> {
+ mls_rs::test_utils::get_test_groups(
+ version,
+ cipher_suite,
+ num_participants,
+ None,
+ encrypt_controls,
+ &TestCryptoProvider::default(),
+ )
+ .await
+}
+
+use rand::seq::IteratorRandom;
+
+#[cfg(target_arch = "wasm32")]
+wasm_bindgen_test::wasm_bindgen_test_configure!(run_in_browser);
+
+#[cfg(target_arch = "wasm32")]
+use wasm_bindgen_test::wasm_bindgen_test as futures_test;
+
+#[cfg(all(mls_build_async, not(target_arch = "wasm32")))]
+use futures_test::test as futures_test;
+
+#[cfg(feature = "private_message")]
+#[cfg(mls_build_async)]
+async fn test_on_all_params<F, Fut>(test: F)
+where
+ F: Fn(ProtocolVersion, CipherSuite, usize, bool) -> Fut,
+ Fut: Future<Output = ()>,
+{
+ for version in ProtocolVersion::all() {
+ for cs in TestCryptoProvider::all_supported_cipher_suites() {
+ for encrypt_controls in [true, false] {
+ test(version, cs, 10, encrypt_controls).await;
+ }
+ }
+ }
+}
+
+#[cfg(feature = "private_message")]
+#[cfg(not(mls_build_async))]
+fn test_on_all_params<F>(test: F)
+where
+ F: Fn(ProtocolVersion, CipherSuite, usize, bool),
+{
+ for version in ProtocolVersion::all() {
+ for cs in TestCryptoProvider::all_supported_cipher_suites() {
+ for encrypt_controls in [true, false] {
+ test(version, cs, 10, encrypt_controls);
+ }
+ }
+ }
+}
+
+#[cfg(not(feature = "private_message"))]
+#[cfg(mls_build_async)]
+async fn test_on_all_params<F, Fut>(test: F)
+where
+ F: Fn(ProtocolVersion, CipherSuite, usize, bool) -> Fut,
+ Fut: Future<Output = ()>,
+{
+ test_on_all_params_plaintext(test).await;
+}
+
+#[cfg(not(feature = "private_message"))]
+#[cfg(not(mls_build_async))]
+fn test_on_all_params<F>(test: F)
+where
+ F: Fn(ProtocolVersion, CipherSuite, usize, bool),
+{
+ test_on_all_params_plaintext(test);
+}
+
+#[cfg(mls_build_async)]
+async fn test_on_all_params_plaintext<F, Fut>(test: F)
+where
+ F: Fn(ProtocolVersion, CipherSuite, usize, bool) -> Fut,
+ Fut: Future<Output = ()>,
+{
+ for version in ProtocolVersion::all() {
+ for cs in TestCryptoProvider::all_supported_cipher_suites() {
+ test(version, cs, 10, false).await;
+ }
+ }
+}
+
+#[cfg(not(mls_build_async))]
+fn test_on_all_params_plaintext<F>(test: F)
+where
+ F: Fn(ProtocolVersion, CipherSuite, usize, bool),
+{
+ for version in ProtocolVersion::all() {
+ for cs in TestCryptoProvider::all_supported_cipher_suites() {
+ test(version, cs, 10, false);
+ }
+ }
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn test_create(
+ protocol_version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+ _n_participants: usize,
+ encrypt_controls: bool,
+) {
+ let alice = generate_client(cipher_suite, protocol_version, 0, encrypt_controls).await;
+ let bob = generate_client(cipher_suite, protocol_version, 1, encrypt_controls).await;
+ let bob_key_pkg = bob.generate_key_package_message().await.unwrap();
+
+ // Alice creates a group and adds bob
+ let mut alice_group = alice
+ .create_group_with_id(b"group".to_vec(), ExtensionList::default())
+ .await
+ .unwrap();
+
+ let welcome = &alice_group
+ .commit_builder()
+ .add_member(bob_key_pkg)
+ .unwrap()
+ .build()
+ .await
+ .unwrap()
+ .welcome_messages[0];
+
+ // Upon server confirmation, alice applies the commit to her own state
+ alice_group.apply_pending_commit().await.unwrap();
+
+ // Bob receives the welcome message and joins the group
+ let (bob_group, _) = bob.join_group(None, welcome).await.unwrap();
+
+ assert!(Group::equal_group_state(&alice_group, &bob_group));
+}
+
+#[maybe_async::test(not(mls_build_async), async(mls_build_async, futures_test))]
+async fn test_create_group() {
+ test_on_all_params(test_create).await;
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn test_empty_commits(
+ protocol_version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+ participants: usize,
+ encrypt_controls: bool,
+) {
+ let mut groups = get_test_groups(
+ protocol_version,
+ cipher_suite,
+ participants,
+ encrypt_controls,
+ )
+ .await;
+
+ // Loop through each participant and send a path update
+
+ for i in 0..groups.len() {
+ // Create the commit
+ let commit_output = groups[i].commit(Vec::new()).await.unwrap();
+
+ assert!(commit_output.welcome_messages.is_empty());
+
+ let index = groups[i].current_member_index() as usize;
+ all_process_message(&mut groups, &commit_output.commit_message, index, true).await;
+
+ for other_group in groups.iter() {
+ assert!(Group::equal_group_state(other_group, &groups[i]));
+ }
+ }
+}
+
+#[maybe_async::test(not(mls_build_async), async(mls_build_async, futures_test))]
+async fn test_group_path_updates() {
+ test_on_all_params(test_empty_commits).await;
+}
+
+#[cfg(feature = "by_ref_proposal")]
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn test_update_proposals(
+ protocol_version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+ participants: usize,
+ encrypt_controls: bool,
+) {
+ let mut groups = get_test_groups(
+ protocol_version,
+ cipher_suite,
+ participants,
+ encrypt_controls,
+ )
+ .await;
+
+ // Create an update from the ith member, have the ith + 1 member commit it
+ for i in 0..groups.len() - 1 {
+ let update_proposal_msg = groups[i].propose_update(Vec::new()).await.unwrap();
+
+ let sender = groups[i].current_member_index() as usize;
+ all_process_message(&mut groups, &update_proposal_msg, sender, false).await;
+
+ // Everyone receives the commit
+ let committer_index = i + 1;
+
+ let commit_output = groups[committer_index].commit(Vec::new()).await.unwrap();
+
+ assert!(commit_output.welcome_messages.is_empty());
+
+ let commit = commit_output.commit_message;
+
+ all_process_message(&mut groups, &commit, committer_index, true).await;
+
+ groups
+ .iter()
+ .for_each(|g| assert!(Group::equal_group_state(g, &groups[0])));
+ }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+#[maybe_async::test(not(mls_build_async), async(mls_build_async, futures_test))]
+async fn test_group_update_proposals() {
+ test_on_all_params(test_update_proposals).await;
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn test_remove_proposals(
+ protocol_version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+ participants: usize,
+ encrypt_controls: bool,
+) {
+ let mut groups = get_test_groups(
+ protocol_version,
+ cipher_suite,
+ participants,
+ encrypt_controls,
+ )
+ .await;
+
+ // Remove people from the group one at a time
+ while groups.len() > 1 {
+ let removed_and_committer = (0..groups.len()).choose_multiple(&mut rand::thread_rng(), 2);
+
+ let to_remove = removed_and_committer[0];
+ let committer = removed_and_committer[1];
+ let to_remove_index = groups[to_remove].current_member_index();
+
+ let epoch_before_remove = groups[committer].current_epoch();
+
+ let commit_output = groups[committer]
+ .commit_builder()
+ .remove_member(to_remove_index)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ assert!(commit_output.welcome_messages.is_empty());
+
+ let commit = commit_output.commit_message;
+ let committer_index = groups[committer].current_member_index() as usize;
+ all_process_message(&mut groups, &commit, committer_index, true).await;
+
+ // Check that remove was effective
+ for (i, group) in groups.iter().enumerate() {
+ if i == to_remove {
+ assert_eq!(group.current_epoch(), epoch_before_remove);
+ } else {
+ assert_eq!(group.current_epoch(), epoch_before_remove + 1);
+ assert!(group.roster().member_with_index(to_remove_index).is_err());
+ }
+ }
+
+ groups.retain(|group| group.current_member_index() != to_remove_index);
+
+ for one_group in groups.iter() {
+ assert!(Group::equal_group_state(one_group, &groups[0]))
+ }
+ }
+}
+
+#[maybe_async::test(not(mls_build_async), async(mls_build_async, futures_test))]
+async fn test_group_remove_proposals() {
+ test_on_all_params(test_remove_proposals).await;
+}
+
+#[cfg(feature = "private_message")]
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn test_application_messages(
+ protocol_version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+ participants: usize,
+ encrypt_controls: bool,
+) {
+ let message_count = 20;
+
+ let mut groups = get_test_groups(
+ protocol_version,
+ cipher_suite,
+ participants,
+ encrypt_controls,
+ )
+ .await;
+
+ // Loop through each participant and send application messages
+ for i in 0..groups.len() {
+ let mut test_message = vec![0; 1024];
+ rand::thread_rng().fill_bytes(&mut test_message);
+
+ for _ in 0..message_count {
+ // Encrypt the application message
+ let ciphertext = groups[i]
+ .encrypt_application_message(&test_message, Vec::new())
+ .await
+ .unwrap();
+
+ let sender_index = groups[i].current_member_index();
+
+ for g in groups.iter_mut() {
+ if g.current_member_index() != sender_index {
+ let decrypted = g
+ .process_incoming_message(ciphertext.clone())
+ .await
+ .unwrap();
+
+ assert_matches!(decrypted, ReceivedMessage::ApplicationMessage(m) if m.data() == test_message);
+ }
+ }
+ }
+ }
+}
+
+#[cfg(all(feature = "private_message", feature = "out_of_order"))]
+#[maybe_async::test(not(mls_build_async), async(mls_build_async, futures_test))]
+async fn test_out_of_order_application_messages() {
+ let mut groups =
+ get_test_groups(ProtocolVersion::MLS_10, CipherSuite::P256_AES128, 2, false).await;
+
+ let mut alice_group = groups[0].clone();
+ let bob_group = &mut groups[1];
+
+ let ciphertext = alice_group
+ .encrypt_application_message(&[0], Vec::new())
+ .await
+ .unwrap();
+
+ let mut ciphertexts = vec![ciphertext];
+
+ ciphertexts.push(
+ alice_group
+ .encrypt_application_message(&[1], Vec::new())
+ .await
+ .unwrap(),
+ );
+
+ let commit = alice_group.commit(Vec::new()).await.unwrap().commit_message;
+
+ alice_group.apply_pending_commit().await.unwrap();
+
+ bob_group.process_incoming_message(commit).await.unwrap();
+
+ ciphertexts.push(
+ alice_group
+ .encrypt_application_message(&[2], Vec::new())
+ .await
+ .unwrap(),
+ );
+
+ ciphertexts.push(
+ alice_group
+ .encrypt_application_message(&[3], Vec::new())
+ .await
+ .unwrap(),
+ );
+
+ for i in [3, 2, 1, 0] {
+ let res = bob_group
+ .process_incoming_message(ciphertexts[i].clone())
+ .await
+ .unwrap();
+
+ assert_matches!(
+ res,
+ ReceivedMessage::ApplicationMessage(m) if m.data() == [i as u8]
+ );
+ }
+}
+
+#[cfg(feature = "private_message")]
+#[maybe_async::test(not(mls_build_async), async(mls_build_async, futures_test))]
+async fn test_group_application_messages() {
+ test_on_all_params(test_application_messages).await
+}
+
+#[cfg(feature = "private_message")]
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn processing_message_from_self_returns_error(
+ protocol_version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+ _n_participants: usize,
+ encrypt_controls: bool,
+) {
+ let mut creator_group =
+ get_test_groups(protocol_version, cipher_suite, 1, encrypt_controls).await;
+ let creator_group = &mut creator_group[0];
+
+ let msg = creator_group
+ .encrypt_application_message(b"hello self", vec![])
+ .await
+ .unwrap();
+
+ let error = creator_group
+ .process_incoming_message(msg)
+ .await
+ .unwrap_err();
+
+ assert_matches!(error, MlsError::CantProcessMessageFromSelf);
+}
+
+#[cfg(feature = "private_message")]
+#[maybe_async::test(not(mls_build_async), async(mls_build_async, futures_test))]
+async fn test_processing_message_from_self_returns_error() {
+ test_on_all_params(processing_message_from_self_returns_error).await;
+}
+
+#[cfg(feature = "by_ref_proposal")]
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn external_commits_work(
+ protocol_version: ProtocolVersion,
+ cipher_suite: CipherSuite,
+ _n_participants: usize,
+ _encrypt_controls: bool,
+) {
+ let creator = generate_client(cipher_suite, protocol_version, 0, false).await;
+
+ let creator_group = creator
+ .create_group_with_id(b"group".to_vec(), ExtensionList::default())
+ .await
+ .unwrap();
+
+ const PARTICIPANT_COUNT: usize = 10;
+
+ let mut others = Vec::new();
+
+ for i in 1..PARTICIPANT_COUNT {
+ others.push(generate_client(cipher_suite, protocol_version, i, Default::default()).await)
+ }
+
+ let mut groups = vec![creator_group];
+
+ for client in &others {
+ let existing_group = groups.choose_mut(&mut rand::thread_rng()).unwrap();
+
+ let group_info = existing_group
+ .group_info_message_allowing_ext_commit(true)
+ .await
+ .unwrap();
+
+ let (new_group, commit) = client
+ .external_commit_builder()
+ .unwrap()
+ .build(group_info)
+ .await
+ .unwrap();
+
+ for group in groups.iter_mut() {
+ group
+ .process_incoming_message(commit.clone())
+ .await
+ .unwrap();
+ }
+
+ groups.push(new_group);
+ }
+
+ assert!(groups
+ .iter()
+ .all(|group| group.roster().members_iter().count() == PARTICIPANT_COUNT));
+
+ for i in 0..groups.len() {
+ let message = groups[i].propose_remove(0, Vec::new()).await.unwrap();
+
+ for (_, group) in groups.iter_mut().enumerate().filter(|&(j, _)| i != j) {
+ let processed = group
+ .process_incoming_message(message.clone())
+ .await
+ .unwrap();
+
+ if let ReceivedMessage::Proposal(p) = &processed {
+ if let Proposal::Remove(r) = &p.proposal {
+ if r.to_remove() == 0 {
+ continue;
+ }
+ }
+ }
+
+ panic!("expected a proposal, got {processed:?}");
+ }
+ }
+}
+
+#[cfg(feature = "by_ref_proposal")]
+#[maybe_async::test(not(mls_build_async), async(mls_build_async, futures_test))]
+async fn test_external_commits() {
+ test_on_all_params_plaintext(external_commits_work).await
+}
+
+#[maybe_async::test(not(mls_build_async), async(mls_build_async, futures_test))]
+async fn test_remove_nonexisting_leaf() {
+ let mut groups =
+ get_test_groups(ProtocolVersion::MLS_10, CipherSuite::P256_AES128, 10, false).await;
+
+ groups[0]
+ .commit_builder()
+ .remove_member(5)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+ groups[0].apply_pending_commit().await.unwrap();
+
+ // Leaf index out of bounds
+ assert!(groups[0].commit_builder().remove_member(13).is_err());
+
+ // Removing blank leaf causes error
+ assert!(groups[0].commit_builder().remove_member(5).is_err());
+}
+
+#[cfg(feature = "psk")]
+#[maybe_async::test(not(mls_build_async), async(mls_build_async, futures_test))]
+async fn reinit_works() {
+ let suite1 = CipherSuite::P256_AES128;
+
+ let Some(suite2) = CipherSuite::all()
+ .find(|cs| cs != &suite1 && TestCryptoProvider::all_supported_cipher_suites().contains(cs))
+ else {
+ return;
+ };
+
+ let version = ProtocolVersion::MLS_10;
+
+ let alice1 = generate_client(suite1, version, 1, Default::default()).await;
+ let bob1 = generate_client(suite1, version, 2, Default::default()).await;
+
+ // Create a group with 2 parties
+ let mut alice_group = alice1.create_group(ExtensionList::new()).await.unwrap();
+ let kp = bob1.generate_key_package_message().await.unwrap();
+
+ let welcome = &alice_group
+ .commit_builder()
+ .add_member(kp)
+ .unwrap()
+ .build()
+ .await
+ .unwrap()
+ .welcome_messages[0];
+
+ alice_group.apply_pending_commit().await.unwrap();
+
+ let (mut bob_group, _) = bob1.join_group(None, welcome).await.unwrap();
+
+ // Alice proposes reinit
+ let reinit_proposal_message = alice_group
+ .propose_reinit(
+ None,
+ ProtocolVersion::MLS_10,
+ suite2,
+ ExtensionList::default(),
+ Vec::new(),
+ )
+ .await
+ .unwrap();
+
+ // Bob commits the reinit
+ bob_group
+ .process_incoming_message(reinit_proposal_message)
+ .await
+ .unwrap();
+
+ let commit = bob_group.commit(Vec::new()).await.unwrap().commit_message;
+
+ // Both process Bob's commit
+
+ #[cfg(feature = "state_update")]
+ {
+ let state_update = bob_group.apply_pending_commit().await.unwrap().state_update;
+ assert!(!state_update.is_active() && state_update.is_pending_reinit());
+ }
+
+ #[cfg(not(feature = "state_update"))]
+ bob_group.apply_pending_commit().await.unwrap();
+
+ let message = alice_group.process_incoming_message(commit).await.unwrap();
+
+ #[cfg(feature = "state_update")]
+ if let ReceivedMessage::Commit(commit_description) = message {
+ assert!(
+ !commit_description.state_update.is_active()
+ && commit_description.state_update.is_pending_reinit()
+ );
+ }
+
+ #[cfg(not(feature = "state_update"))]
+ assert_matches!(message, ReceivedMessage::Commit(_));
+
+ // They can't create new epochs anymore
+ let res = alice_group.commit(Vec::new()).await;
+ assert!(res.is_err());
+
+ let res = bob_group.commit(Vec::new()).await;
+ assert!(res.is_err());
+
+ // Get reinit clients for alice and bob
+ let (secret_key, public_key) = TestCryptoProvider::new()
+ .cipher_suite_provider(suite2)
+ .unwrap()
+ .signature_key_generate()
+ .await
+ .unwrap();
+
+ let identity = SigningIdentity::new(get_test_basic_credential(b"bob".to_vec()), public_key);
+
+ let bob2 = bob_group
+ .get_reinit_client(Some(secret_key), Some(identity))
+ .unwrap();
+
+ let (secret_key, public_key) = TestCryptoProvider::new()
+ .cipher_suite_provider(suite2)
+ .unwrap()
+ .signature_key_generate()
+ .await
+ .unwrap();
+
+ let identity = SigningIdentity::new(get_test_basic_credential(b"alice".to_vec()), public_key);
+
+ let alice2 = alice_group
+ .get_reinit_client(Some(secret_key), Some(identity))
+ .unwrap();
+
+ // Bob produces key package, alice commits, bob joins
+ let kp = bob2.generate_key_package().await.unwrap();
+ let (mut alice_group, welcome) = alice2.commit(vec![kp]).await.unwrap();
+ let (mut bob_group, _) = bob2.join(&welcome[0], None).await.unwrap();
+
+ assert!(bob_group.cipher_suite() == suite2);
+
+ // They can talk
+ let carol = generate_client(suite2, version, 3, Default::default()).await;
+
+ let kp = carol.generate_key_package_message().await.unwrap();
+
+ let commit_output = alice_group
+ .commit_builder()
+ .add_member(kp)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ alice_group.apply_pending_commit().await.unwrap();
+
+ bob_group
+ .process_incoming_message(commit_output.commit_message)
+ .await
+ .unwrap();
+
+ carol
+ .join_group(None, &commit_output.welcome_messages[0])
+ .await
+ .unwrap();
+}
+
+#[cfg(feature = "by_ref_proposal")]
+#[maybe_async::test(not(mls_build_async), async(mls_build_async, futures_test))]
+async fn external_joiner_can_process_siblings_update() {
+ let mut groups =
+ get_test_groups(ProtocolVersion::MLS_10, CipherSuite::P256_AES128, 3, false).await;
+
+ // Remove leaf 1 s.t. the external joiner joins in its place
+ let c = groups[0]
+ .commit_builder()
+ .remove_member(1)
+ .unwrap()
+ .build()
+ .await
+ .unwrap();
+
+ all_process_message(&mut groups, &c.commit_message, 0, true).await;
+
+ let info = groups[0]
+ .group_info_message_allowing_ext_commit(true)
+ .await
+ .unwrap();
+
+ // Create the external joiner and join
+ let new_client = generate_client(
+ CipherSuite::P256_AES128,
+ ProtocolVersion::MLS_10,
+ 0xabba,
+ false,
+ )
+ .await;
+
+ let (mut group, commit) = new_client.commit_external(info).await.unwrap();
+
+ all_process_message(&mut groups, &commit, 1, false).await;
+ groups.remove(1);
+
+ // New client's sibling proposes an update to blank their common parent
+ let p = groups[0].propose_update(Vec::new()).await.unwrap();
+ all_process_message(&mut groups, &p, 0, false).await;
+ group.process_incoming_message(p).await.unwrap();
+
+ // Some other member commits
+ let c = groups[1].commit(Vec::new()).await.unwrap().commit_message;
+ all_process_message(&mut groups, &c, 2, true).await;
+ group.process_incoming_message(c).await.unwrap();
+}
+
+#[maybe_async::test(not(mls_build_async), async(mls_build_async, futures_test))]
+async fn weird_tree_scenario() {
+ let mut groups =
+ get_test_groups(ProtocolVersion::MLS_10, CipherSuite::P256_AES128, 17, false).await;
+
+ let to_remove = [0u32, 2, 5, 7, 8, 9, 15];
+
+ let mut builder = groups[14].commit_builder();
+
+ for idx in to_remove.iter() {
+ builder = builder.remove_member(*idx).unwrap();
+ }
+
+ let commit = builder.build().await.unwrap();
+
+ for idx in to_remove.into_iter().rev() {
+ groups.remove(idx as usize);
+ }
+
+ all_process_message(&mut groups, &commit.commit_message, 14, true).await;
+
+ let mut builder = groups.last_mut().unwrap().commit_builder();
+
+ for idx in 0..7 {
+ builder = builder
+ .add_member(fake_key_package(5555555 + idx).await)
+ .unwrap()
+ }
+
+ let commit = builder.remove_member(1).unwrap().build().await.unwrap();
+
+ let idx = groups.last().unwrap().current_member_index() as usize;
+
+ all_process_message(&mut groups, &commit.commit_message, idx, true).await;
+}
+
+#[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
+async fn fake_key_package(id: usize) -> MlsMessage {
+ generate_client(CipherSuite::P256_AES128, ProtocolVersion::MLS_10, id, false)
+ .await
+ .generate_key_package_message()
+ .await
+ .unwrap()
+}
+
+#[maybe_async::test(not(mls_build_async), async(mls_build_async, futures_test))]
+async fn external_info_from_commit_allows_to_join() {
+ let cs = CipherSuite::P256_AES128;
+ let version = ProtocolVersion::MLS_10;
+
+ let mut alice = mls_rs::test_utils::get_test_groups(
+ version,
+ cs,
+ 1,
+ Some(CommitOptions::new().with_allow_external_commit(true)),
+ false,
+ &TestCryptoProvider::default(),
+ )
+ .await
+ .remove(0);
+
+ let commit = alice.commit(vec![]).await.unwrap();
+ alice.apply_pending_commit().await.unwrap();
+ let bob = generate_client(cs, version, 0xdead, false).await;
+
+ let (_bob, commit) = bob
+ .commit_external(commit.external_commit_group_info.unwrap())
+ .await
+ .unwrap();
+
+ alice.process_incoming_message(commit).await.unwrap();
+}
diff --git a/webdriver.json b/webdriver.json
new file mode 100644
index 0000000..368df03
--- /dev/null
+++ b/webdriver.json
@@ -0,0 +1,9 @@
+{
+ "goog:chromeOptions": {
+ "args": [
+ "--disable-timeouts-for-profiling",
+ "--disable-new-content-rendering-timeout",
+ "--timeout 100000"
+ ]
+ }
+}