aboutsummaryrefslogtreecommitdiff
path: root/bumble/pandora/security.py
diff options
context:
space:
mode:
Diffstat (limited to 'bumble/pandora/security.py')
-rw-r--r--bumble/pandora/security.py120
1 files changed, 76 insertions, 44 deletions
diff --git a/bumble/pandora/security.py b/bumble/pandora/security.py
index fee1b7a..0f31512 100644
--- a/bumble/pandora/security.py
+++ b/bumble/pandora/security.py
@@ -13,6 +13,7 @@
# limitations under the License.
import asyncio
+import contextlib
import grpc
import logging
@@ -27,14 +28,11 @@ from bumble.core import (
)
from bumble.device import Connection as BumbleConnection, Device
from bumble.hci import HCI_Error
+from bumble.utils import EventWatcher
from bumble.pairing import PairingConfig, PairingDelegate as BasePairingDelegate
-from contextlib import suppress
-from google.protobuf import (
- any_pb2,
- empty_pb2,
- wrappers_pb2,
-) # pytype: disable=pyi-error
-from google.protobuf.wrappers_pb2 import BoolValue # pytype: disable=pyi-error
+from google.protobuf import any_pb2 # pytype: disable=pyi-error
+from google.protobuf import empty_pb2 # pytype: disable=pyi-error
+from google.protobuf import wrappers_pb2 # pytype: disable=pyi-error
from pandora.host_pb2 import Connection
from pandora.security_grpc_aio import SecurityServicer, SecurityStorageServicer
from pandora.security_pb2 import (
@@ -102,7 +100,7 @@ class PairingDelegate(BasePairingDelegate):
return ev
async def confirm(self, auto: bool = False) -> bool:
- self.log.info(
+ self.log.debug(
f"Pairing event: `just_works` (io_capability: {self.io_capability})"
)
@@ -117,7 +115,7 @@ class PairingDelegate(BasePairingDelegate):
return answer.confirm
async def compare_numbers(self, number: int, digits: int = 6) -> bool:
- self.log.info(
+ self.log.debug(
f"Pairing event: `numeric_comparison` (io_capability: {self.io_capability})"
)
@@ -132,7 +130,7 @@ class PairingDelegate(BasePairingDelegate):
return answer.confirm
async def get_number(self) -> Optional[int]:
- self.log.info(
+ self.log.debug(
f"Pairing event: `passkey_entry_request` (io_capability: {self.io_capability})"
)
@@ -149,7 +147,7 @@ class PairingDelegate(BasePairingDelegate):
return answer.passkey
async def get_string(self, max_length: int) -> Optional[str]:
- self.log.info(
+ self.log.debug(
f"Pairing event: `pin_code_request` (io_capability: {self.io_capability})"
)
@@ -180,7 +178,7 @@ class PairingDelegate(BasePairingDelegate):
):
return
- self.log.info(
+ self.log.debug(
f"Pairing event: `passkey_entry_notification` (io_capability: {self.io_capability})"
)
@@ -235,6 +233,11 @@ class SecurityService(SecurityServicer):
sc=config.pairing_sc_enable,
mitm=config.pairing_mitm_enable,
bonding=config.pairing_bonding_enable,
+ identity_address_type=(
+ PairingConfig.AddressType.PUBLIC
+ if connection.self_address.is_public
+ else config.identity_address_type
+ ),
delegate=PairingDelegate(
connection,
self,
@@ -250,7 +253,7 @@ class SecurityService(SecurityServicer):
async def OnPairing(
self, request: AsyncIterator[PairingEventAnswer], context: grpc.ServicerContext
) -> AsyncGenerator[PairingEvent, None]:
- self.log.info('OnPairing')
+ self.log.debug('OnPairing')
if self.event_queue is not None:
raise RuntimeError('already streaming pairing events')
@@ -276,7 +279,7 @@ class SecurityService(SecurityServicer):
self, request: SecureRequest, context: grpc.ServicerContext
) -> SecureResponse:
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
- self.log.info(f"Secure: {connection_handle}")
+ self.log.debug(f"Secure: {connection_handle}")
connection = self.device.lookup_connection(connection_handle)
assert connection
@@ -294,25 +297,37 @@ class SecurityService(SecurityServicer):
# trigger pairing if needed
if self.need_pairing(connection, level):
try:
- self.log.info('Pair...')
+ self.log.debug('Pair...')
+
+ security_result = asyncio.get_running_loop().create_future()
+
+ with contextlib.closing(EventWatcher()) as watcher:
- if (
- connection.transport == BT_LE_TRANSPORT
- and connection.role == BT_PERIPHERAL_ROLE
- ):
- wait_for_security: asyncio.Future[
- bool
- ] = asyncio.get_running_loop().create_future()
- connection.on("pairing", lambda *_: wait_for_security.set_result(True)) # type: ignore
- connection.on("pairing_failure", wait_for_security.set_exception)
+ @watcher.on(connection, 'pairing')
+ def on_pairing(*_: Any) -> None:
+ security_result.set_result('success')
- connection.request_pairing()
+ @watcher.on(connection, 'pairing_failure')
+ def on_pairing_failure(*_: Any) -> None:
+ security_result.set_result('pairing_failure')
- await wait_for_security
- else:
- await connection.pair()
+ @watcher.on(connection, 'disconnection')
+ def on_disconnection(*_: Any) -> None:
+ security_result.set_result('connection_died')
- self.log.info('Paired')
+ if (
+ connection.transport == BT_LE_TRANSPORT
+ and connection.role == BT_PERIPHERAL_ROLE
+ ):
+ connection.request_pairing()
+ else:
+ await connection.pair()
+
+ result = await security_result
+
+ self.log.debug(f'Pairing session complete, status={result}')
+ if result != 'success':
+ return SecureResponse(**{result: empty_pb2.Empty()})
except asyncio.CancelledError:
self.log.warning("Connection died during encryption")
return SecureResponse(connection_died=empty_pb2.Empty())
@@ -323,9 +338,9 @@ class SecurityService(SecurityServicer):
# trigger authentication if needed
if self.need_authentication(connection, level):
try:
- self.log.info('Authenticate...')
+ self.log.debug('Authenticate...')
await connection.authenticate()
- self.log.info('Authenticated')
+ self.log.debug('Authenticated')
except asyncio.CancelledError:
self.log.warning("Connection died during authentication")
return SecureResponse(connection_died=empty_pb2.Empty())
@@ -336,9 +351,9 @@ class SecurityService(SecurityServicer):
# trigger encryption if needed
if self.need_encryption(connection, level):
try:
- self.log.info('Encrypt...')
+ self.log.debug('Encrypt...')
await connection.encrypt()
- self.log.info('Encrypted')
+ self.log.debug('Encrypted')
except asyncio.CancelledError:
self.log.warning("Connection died during encryption")
return SecureResponse(connection_died=empty_pb2.Empty())
@@ -356,7 +371,7 @@ class SecurityService(SecurityServicer):
self, request: WaitSecurityRequest, context: grpc.ServicerContext
) -> WaitSecurityResponse:
connection_handle = int.from_bytes(request.connection.cookie.value, 'big')
- self.log.info(f"WaitSecurity: {connection_handle}")
+ self.log.debug(f"WaitSecurity: {connection_handle}")
connection = self.device.lookup_connection(connection_handle)
assert connection
@@ -371,6 +386,7 @@ class SecurityService(SecurityServicer):
str
] = asyncio.get_running_loop().create_future()
authenticate_task: Optional[asyncio.Future[None]] = None
+ pair_task: Optional[asyncio.Future[None]] = None
async def authenticate() -> None:
assert connection
@@ -393,7 +409,7 @@ class SecurityService(SecurityServicer):
def set_failure(name: str) -> Callable[..., None]:
def wrapper(*args: Any) -> None:
- self.log.info(f'Wait for security: error `{name}`: {args}')
+ self.log.debug(f'Wait for security: error `{name}`: {args}')
wait_for_security.set_result(name)
return wrapper
@@ -401,13 +417,13 @@ class SecurityService(SecurityServicer):
def try_set_success(*_: Any) -> None:
assert connection
if self.reached_security_level(connection, level):
- self.log.info('Wait for security: done')
+ self.log.debug('Wait for security: done')
wait_for_security.set_result('success')
def on_encryption_change(*_: Any) -> None:
assert connection
if self.reached_security_level(connection, level):
- self.log.info('Wait for security: done')
+ self.log.debug('Wait for security: done')
wait_for_security.set_result('success')
elif (
connection.transport == BT_BR_EDR_TRANSPORT
@@ -417,6 +433,10 @@ class SecurityService(SecurityServicer):
if authenticate_task is None:
authenticate_task = asyncio.create_task(authenticate())
+ def pair(*_: Any) -> None:
+ if self.need_pairing(connection, level):
+ pair_task = asyncio.create_task(connection.pair())
+
listeners: Dict[str, Callable[..., None]] = {
'disconnection': set_failure('connection_died'),
'pairing_failure': set_failure('pairing_failure'),
@@ -425,6 +445,9 @@ class SecurityService(SecurityServicer):
'pairing': try_set_success,
'connection_authentication': try_set_success,
'connection_encryption_change': on_encryption_change,
+ 'classic_pairing': try_set_success,
+ 'classic_pairing_failure': set_failure('pairing_failure'),
+ 'security_request': pair,
}
# register event handlers
@@ -435,7 +458,7 @@ class SecurityService(SecurityServicer):
if self.reached_security_level(connection, level):
return WaitSecurityResponse(success=empty_pb2.Empty())
- self.log.info('Wait for security...')
+ self.log.debug('Wait for security...')
kwargs = {}
kwargs[await wait_for_security] = empty_pb2.Empty()
@@ -445,12 +468,21 @@ class SecurityService(SecurityServicer):
# wait for `authenticate` to finish if any
if authenticate_task is not None:
- self.log.info('Wait for authentication...')
+ self.log.debug('Wait for authentication...')
try:
await authenticate_task # type: ignore
except:
pass
- self.log.info('Authenticated')
+ self.log.debug('Authenticated')
+
+ # wait for `pair` to finish if any
+ if pair_task is not None:
+ self.log.debug('Wait for authentication...')
+ try:
+ await pair_task # type: ignore
+ except:
+ pass
+ self.log.debug('paired')
return WaitSecurityResponse(**kwargs)
@@ -506,24 +538,24 @@ class SecurityStorageService(SecurityStorageServicer):
self, request: IsBondedRequest, context: grpc.ServicerContext
) -> wrappers_pb2.BoolValue:
address = utils.address_from_request(request, request.WhichOneof("address"))
- self.log.info(f"IsBonded: {address}")
+ self.log.debug(f"IsBonded: {address}")
if self.device.keystore is not None:
is_bonded = await self.device.keystore.get(str(address)) is not None
else:
is_bonded = False
- return BoolValue(value=is_bonded)
+ return wrappers_pb2.BoolValue(value=is_bonded)
@utils.rpc
async def DeleteBond(
self, request: DeleteBondRequest, context: grpc.ServicerContext
) -> empty_pb2.Empty:
address = utils.address_from_request(request, request.WhichOneof("address"))
- self.log.info(f"DeleteBond: {address}")
+ self.log.debug(f"DeleteBond: {address}")
if self.device.keystore is not None:
- with suppress(KeyError):
+ with contextlib.suppress(KeyError):
await self.device.keystore.delete(str(address))
return empty_pb2.Empty()