diff options
Diffstat (limited to 'private_join_and_compute/py/ciphers/ec_cipher_test.py')
-rw-r--r-- | private_join_and_compute/py/ciphers/ec_cipher_test.py | 78 |
1 files changed, 78 insertions, 0 deletions
diff --git a/private_join_and_compute/py/ciphers/ec_cipher_test.py b/private_join_and_compute/py/ciphers/ec_cipher_test.py new file mode 100644 index 0000000..5bcf082 --- /dev/null +++ b/private_join_and_compute/py/ciphers/ec_cipher_test.py @@ -0,0 +1,78 @@ +# Copyright 2019 Google LLC. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test class for EcCommutativeCipher.""" + +import unittest +from private_join_and_compute.py.ciphers import ec_cipher +from private_join_and_compute.py.crypto_util import supported_curves +from private_join_and_compute.py.crypto_util import supported_hashes + + +class EcCommutativeCipherTest(unittest.TestCase): + + def setUp(self): + super(EcCommutativeCipherTest, self).setUp() + self.client_cipher = ec_cipher.EcCipher(713) + self.server_cipher = ec_cipher.EcCipher(713) + + def ReEncryptionSameId(self, cipher1, cipher2): + user_id = b'3274646578436540569872403985702934875092834502' + enc_id1 = cipher1.Encrypt(user_id) + enc_id2 = cipher2.Encrypt(user_id) + result1 = cipher2.ReEncrypt(enc_id1) + result2 = cipher1.ReEncrypt(enc_id2) + self.assertEqual(result1, result2) + + def testReEncryptionSameId(self): + self.ReEncryptionSameId(self.client_cipher, self.server_cipher) + + def testReEncryptionDifferentId(self): + user_id1 = b'3274646578436540569872403985702934875092834502' + user_id2 = b'7402039857096829483572943875209348524958235824' + enc_id1 = self.client_cipher.Encrypt(user_id1) + enc_id2 = self.server_cipher.Encrypt(user_id2) + result1 = self.server_cipher.ReEncrypt(enc_id1) + result2 = self.client_cipher.ReEncrypt(enc_id2) + self.assertNotEqual(result1, result2) + + def testDecode(self): + user_id = b'7402039857096829483572943875209348524958235824' + enc_id1 = self.client_cipher.Encrypt(user_id) + enc_id2 = self.server_cipher.Encrypt(user_id) + result1 = self.server_cipher.ReEncrypt(enc_id1) + actual_enc_id1 = self.client_cipher.DecryptReEncryptedId(result1) + actual_enc_id2 = self.server_cipher.DecryptReEncryptedId(result1) + self.assertEqual(enc_id1, actual_enc_id2) + self.assertEqual(enc_id2, actual_enc_id1) + + def testDifferentHashFunctions(self): + # freshly sampled key + sha256_cipher = ec_cipher.EcCipher( + curve_id=supported_curves.SupportedCurve.SECP256R1.id, + hash_type=supported_hashes.HashType.SHA256, + ) + sha512_cipher = ec_cipher.EcCipher( + curve_id=supported_curves.SupportedCurve.SECP256R1.id, + hash_type=supported_hashes.HashType.SHA512, + private_key_bytes=sha256_cipher.ec_key.priv_key_bytes, + ) + user_id = b'7402039857096829483572943875209348524958235824' + enc_id1 = sha256_cipher.Encrypt(user_id) + enc_id2 = sha512_cipher.Encrypt(user_id) + self.assertNotEqual(enc_id1, enc_id2) + + +if __name__ == '__main__': + unittest.main() |