diff options
Diffstat (limited to 'bumble/pandora/security.py')
-rw-r--r-- | bumble/pandora/security.py | 120 |
1 files changed, 76 insertions, 44 deletions
diff --git a/bumble/pandora/security.py b/bumble/pandora/security.py index fee1b7a..0f31512 100644 --- a/bumble/pandora/security.py +++ b/bumble/pandora/security.py @@ -13,6 +13,7 @@ # limitations under the License. import asyncio +import contextlib import grpc import logging @@ -27,14 +28,11 @@ from bumble.core import ( ) from bumble.device import Connection as BumbleConnection, Device from bumble.hci import HCI_Error +from bumble.utils import EventWatcher from bumble.pairing import PairingConfig, PairingDelegate as BasePairingDelegate -from contextlib import suppress -from google.protobuf import ( - any_pb2, - empty_pb2, - wrappers_pb2, -) # pytype: disable=pyi-error -from google.protobuf.wrappers_pb2 import BoolValue # pytype: disable=pyi-error +from google.protobuf import any_pb2 # pytype: disable=pyi-error +from google.protobuf import empty_pb2 # pytype: disable=pyi-error +from google.protobuf import wrappers_pb2 # pytype: disable=pyi-error from pandora.host_pb2 import Connection from pandora.security_grpc_aio import SecurityServicer, SecurityStorageServicer from pandora.security_pb2 import ( @@ -102,7 +100,7 @@ class PairingDelegate(BasePairingDelegate): return ev async def confirm(self, auto: bool = False) -> bool: - self.log.info( + self.log.debug( f"Pairing event: `just_works` (io_capability: {self.io_capability})" ) @@ -117,7 +115,7 @@ class PairingDelegate(BasePairingDelegate): return answer.confirm async def compare_numbers(self, number: int, digits: int = 6) -> bool: - self.log.info( + self.log.debug( f"Pairing event: `numeric_comparison` (io_capability: {self.io_capability})" ) @@ -132,7 +130,7 @@ class PairingDelegate(BasePairingDelegate): return answer.confirm async def get_number(self) -> Optional[int]: - self.log.info( + self.log.debug( f"Pairing event: `passkey_entry_request` (io_capability: {self.io_capability})" ) @@ -149,7 +147,7 @@ class PairingDelegate(BasePairingDelegate): return answer.passkey async def get_string(self, max_length: int) -> Optional[str]: - self.log.info( + self.log.debug( f"Pairing event: `pin_code_request` (io_capability: {self.io_capability})" ) @@ -180,7 +178,7 @@ class PairingDelegate(BasePairingDelegate): ): return - self.log.info( + self.log.debug( f"Pairing event: `passkey_entry_notification` (io_capability: {self.io_capability})" ) @@ -235,6 +233,11 @@ class SecurityService(SecurityServicer): sc=config.pairing_sc_enable, mitm=config.pairing_mitm_enable, bonding=config.pairing_bonding_enable, + identity_address_type=( + PairingConfig.AddressType.PUBLIC + if connection.self_address.is_public + else config.identity_address_type + ), delegate=PairingDelegate( connection, self, @@ -250,7 +253,7 @@ class SecurityService(SecurityServicer): async def OnPairing( self, request: AsyncIterator[PairingEventAnswer], context: grpc.ServicerContext ) -> AsyncGenerator[PairingEvent, None]: - self.log.info('OnPairing') + self.log.debug('OnPairing') if self.event_queue is not None: raise RuntimeError('already streaming pairing events') @@ -276,7 +279,7 @@ class SecurityService(SecurityServicer): self, request: SecureRequest, context: grpc.ServicerContext ) -> SecureResponse: connection_handle = int.from_bytes(request.connection.cookie.value, 'big') - self.log.info(f"Secure: {connection_handle}") + self.log.debug(f"Secure: {connection_handle}") connection = self.device.lookup_connection(connection_handle) assert connection @@ -294,25 +297,37 @@ class SecurityService(SecurityServicer): # trigger pairing if needed if self.need_pairing(connection, level): try: - self.log.info('Pair...') + self.log.debug('Pair...') + + security_result = asyncio.get_running_loop().create_future() + + with contextlib.closing(EventWatcher()) as watcher: - if ( - connection.transport == BT_LE_TRANSPORT - and connection.role == BT_PERIPHERAL_ROLE - ): - wait_for_security: asyncio.Future[ - bool - ] = asyncio.get_running_loop().create_future() - connection.on("pairing", lambda *_: wait_for_security.set_result(True)) # type: ignore - connection.on("pairing_failure", wait_for_security.set_exception) + @watcher.on(connection, 'pairing') + def on_pairing(*_: Any) -> None: + security_result.set_result('success') - connection.request_pairing() + @watcher.on(connection, 'pairing_failure') + def on_pairing_failure(*_: Any) -> None: + security_result.set_result('pairing_failure') - await wait_for_security - else: - await connection.pair() + @watcher.on(connection, 'disconnection') + def on_disconnection(*_: Any) -> None: + security_result.set_result('connection_died') - self.log.info('Paired') + if ( + connection.transport == BT_LE_TRANSPORT + and connection.role == BT_PERIPHERAL_ROLE + ): + connection.request_pairing() + else: + await connection.pair() + + result = await security_result + + self.log.debug(f'Pairing session complete, status={result}') + if result != 'success': + return SecureResponse(**{result: empty_pb2.Empty()}) except asyncio.CancelledError: self.log.warning("Connection died during encryption") return SecureResponse(connection_died=empty_pb2.Empty()) @@ -323,9 +338,9 @@ class SecurityService(SecurityServicer): # trigger authentication if needed if self.need_authentication(connection, level): try: - self.log.info('Authenticate...') + self.log.debug('Authenticate...') await connection.authenticate() - self.log.info('Authenticated') + self.log.debug('Authenticated') except asyncio.CancelledError: self.log.warning("Connection died during authentication") return SecureResponse(connection_died=empty_pb2.Empty()) @@ -336,9 +351,9 @@ class SecurityService(SecurityServicer): # trigger encryption if needed if self.need_encryption(connection, level): try: - self.log.info('Encrypt...') + self.log.debug('Encrypt...') await connection.encrypt() - self.log.info('Encrypted') + self.log.debug('Encrypted') except asyncio.CancelledError: self.log.warning("Connection died during encryption") return SecureResponse(connection_died=empty_pb2.Empty()) @@ -356,7 +371,7 @@ class SecurityService(SecurityServicer): self, request: WaitSecurityRequest, context: grpc.ServicerContext ) -> WaitSecurityResponse: connection_handle = int.from_bytes(request.connection.cookie.value, 'big') - self.log.info(f"WaitSecurity: {connection_handle}") + self.log.debug(f"WaitSecurity: {connection_handle}") connection = self.device.lookup_connection(connection_handle) assert connection @@ -371,6 +386,7 @@ class SecurityService(SecurityServicer): str ] = asyncio.get_running_loop().create_future() authenticate_task: Optional[asyncio.Future[None]] = None + pair_task: Optional[asyncio.Future[None]] = None async def authenticate() -> None: assert connection @@ -393,7 +409,7 @@ class SecurityService(SecurityServicer): def set_failure(name: str) -> Callable[..., None]: def wrapper(*args: Any) -> None: - self.log.info(f'Wait for security: error `{name}`: {args}') + self.log.debug(f'Wait for security: error `{name}`: {args}') wait_for_security.set_result(name) return wrapper @@ -401,13 +417,13 @@ class SecurityService(SecurityServicer): def try_set_success(*_: Any) -> None: assert connection if self.reached_security_level(connection, level): - self.log.info('Wait for security: done') + self.log.debug('Wait for security: done') wait_for_security.set_result('success') def on_encryption_change(*_: Any) -> None: assert connection if self.reached_security_level(connection, level): - self.log.info('Wait for security: done') + self.log.debug('Wait for security: done') wait_for_security.set_result('success') elif ( connection.transport == BT_BR_EDR_TRANSPORT @@ -417,6 +433,10 @@ class SecurityService(SecurityServicer): if authenticate_task is None: authenticate_task = asyncio.create_task(authenticate()) + def pair(*_: Any) -> None: + if self.need_pairing(connection, level): + pair_task = asyncio.create_task(connection.pair()) + listeners: Dict[str, Callable[..., None]] = { 'disconnection': set_failure('connection_died'), 'pairing_failure': set_failure('pairing_failure'), @@ -425,6 +445,9 @@ class SecurityService(SecurityServicer): 'pairing': try_set_success, 'connection_authentication': try_set_success, 'connection_encryption_change': on_encryption_change, + 'classic_pairing': try_set_success, + 'classic_pairing_failure': set_failure('pairing_failure'), + 'security_request': pair, } # register event handlers @@ -435,7 +458,7 @@ class SecurityService(SecurityServicer): if self.reached_security_level(connection, level): return WaitSecurityResponse(success=empty_pb2.Empty()) - self.log.info('Wait for security...') + self.log.debug('Wait for security...') kwargs = {} kwargs[await wait_for_security] = empty_pb2.Empty() @@ -445,12 +468,21 @@ class SecurityService(SecurityServicer): # wait for `authenticate` to finish if any if authenticate_task is not None: - self.log.info('Wait for authentication...') + self.log.debug('Wait for authentication...') try: await authenticate_task # type: ignore except: pass - self.log.info('Authenticated') + self.log.debug('Authenticated') + + # wait for `pair` to finish if any + if pair_task is not None: + self.log.debug('Wait for authentication...') + try: + await pair_task # type: ignore + except: + pass + self.log.debug('paired') return WaitSecurityResponse(**kwargs) @@ -506,24 +538,24 @@ class SecurityStorageService(SecurityStorageServicer): self, request: IsBondedRequest, context: grpc.ServicerContext ) -> wrappers_pb2.BoolValue: address = utils.address_from_request(request, request.WhichOneof("address")) - self.log.info(f"IsBonded: {address}") + self.log.debug(f"IsBonded: {address}") if self.device.keystore is not None: is_bonded = await self.device.keystore.get(str(address)) is not None else: is_bonded = False - return BoolValue(value=is_bonded) + return wrappers_pb2.BoolValue(value=is_bonded) @utils.rpc async def DeleteBond( self, request: DeleteBondRequest, context: grpc.ServicerContext ) -> empty_pb2.Empty: address = utils.address_from_request(request, request.WhichOneof("address")) - self.log.info(f"DeleteBond: {address}") + self.log.debug(f"DeleteBond: {address}") if self.device.keystore is not None: - with suppress(KeyError): + with contextlib.suppress(KeyError): await self.device.keystore.delete(str(address)) return empty_pb2.Empty() |