about summary refs log tree commit diff
path: root/client/cros/cellular/mbim_compliance/mbim_channel.py
blob: a487869c3f78b0b65cbdcfbc0317c436308720a9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
# Copyright 2015 The Chromium OS Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.

import logging
import multiprocessing
import struct
import time

import common
from six.moves.queue import Empty

from autotest_lib.client.bin import utils
from autotest_lib.client.cros.cellular.mbim_compliance import \
    mbim_channel_endpoint
from autotest_lib.client.cros.cellular.mbim_compliance import mbim_errors


class MBIMChannel(object):
    """
    Provide synchronous access to the modem with MBIM command level interaction.

    This object should simplify your interaction over the MBIM channel as
    follows:
    - Use |bidirectional_transaction| to send MBIM packets that are part of a
      transaction. This function will block until the transaction completes and
      return the MBIM packets received in response.
    - |bidirectional_transaction| will filter out packets that do not correspond
      to your transaction. This way, you don't have to worry about unsolicited
      notifications and/or stale packets when interacting with the modem.
    - All filtered out packets can be grabbed using the
      |get_outstanding_packets| function. Use this function to receive error
      notifications, status notifications, etc.
    - Use |unidirectional_transaction| to send MBIM packets for which you don't
      expect a response.
    - Use |flush| to clean out all pipes before starting a new transaction.

    Note that "MBIM packets" here really means MBIM fragments. This object does
    not (de)fragment packets for you. Out of necessity, it does check that
    received fragments are contiguous and in-order.

    So, this object houses the minimum information necessary about the MBIM
    fragments to provide you a comfortable synchronous packet level channel.

    """

    ENDPOINT_JOIN_TIMEOUT_S = 5
    FRAGMENT_TIMEOUT_S = 3
    # TODO(pprabhu) Consider allowing each transaction to specify its own
    # timeout.
    TRANSACTION_TIMEOUT_S = 5
    # |flush| waits |TRANSACTION_TIMEOUT_S| per possibly-pending transaction.
    # Cap the number of transactions waited out so that flushing a long
    # backlog does not sleep for an unbounded amount of time.
    MAX_FLUSH_TRANSACTIONS = 5

    MESSAGE_HEADER_FORMAT = '<LLL'
    FRAGMENT_HEADER_FORMAT = '<LL'
    MBIM_FRAGMENTED_MESSAGES = [
            0x00000003,  # MBIM_COMMAND_MSG
            0x80000003,  # MBIM_COMMAND_DONE
            0x80000007]  # MBIM_INDICATE_STATUS

    def __init__(self,
                 device,
                 interface_number,
                 interrupt_endpoint_address,
                 in_buffer_size,
                 process_class=None):
        """
        @param device: Device handle returned by PyUSB for the modem to test.
        @param interface_number: |bInterfaceNumber| of the MBIM interface.
        @param interrupt_endpoint_address: |bEndpointAddress| for the usb
                INTERRUPT IN endpoint for notifications.
        @param in_buffer_size: The (fixed) buffer size to use for in control
                transfers.
        @param process_class: The class to instantiate to create a subprocess.
                This is used by tests only, to easily mock out the process
                creation.

        """
        self._stop_request_event = multiprocessing.Event()
        self._request_queue = multiprocessing.Queue()
        self._response_queue = multiprocessing.Queue()
        self._outstanding_packets = []
        self._last_response = []
        self._stashed_first_fragment = None
        if process_class is None:
            process_class = multiprocessing.Process
        self._endpoint_process = process_class(
                target=mbim_channel_endpoint.MBIMChannelEndpoint,
                args=(device,
                      interface_number,
                      interrupt_endpoint_address,
                      in_buffer_size,
                      self._request_queue,
                      self._response_queue,
                      self._stop_request_event))
        self._endpoint_process.start()


    def __del__(self):
        """
        The destructor.

        Note that it is not guaranteed that |__del__| is called for objects that
        exist when the interpreter exits. It is recommended to call |close|
        explicitly.

        """
        self.close()


    def close(self):
        """
        Cleanly close the MBIMChannel.

        MBIMChannel forks a subprocess to communicate with the USB device. It is
        recommended that |close| be called explicitly. Calling |close| more
        than once is harmless.

        """
        # |close| is also reached from |__del__|, which runs even when
        # |__init__| raised before |_endpoint_process| was assigned.
        endpoint_process = getattr(self, '_endpoint_process', None)
        if not endpoint_process:
            return

        if endpoint_process.is_alive():
            # Ask the endpoint to stop, give it a grace period, and only then
            # kill it.
            self._stop_request_event.set()
            endpoint_process.join(self.ENDPOINT_JOIN_TIMEOUT_S)
            if endpoint_process.is_alive():
                endpoint_process.terminate()

        self._endpoint_process = None


    def bidirectional_transaction(self, *args):
        """
        Execute a synchronous bidirectional transaction.

        @param *args: Fragments of a single MBIM transaction. An MBIM
                transaction may consist of multiple fragments - each fragment is
                the payload for a USB control message. It should be an
                |array.array| object.  It is your responsibility (and choice) to
                keep the fragments in-order, and to send all the fragments.
                For more details, see "Fragmentation of messages" in the MBIM
                spec.
        @returns: A list of fragments in the same order as received that
                correspond to the given transaction. If we receive less
                fragments than claimed, we will return what we get. If we
                receive non-contiguous / out-of-order fragments, we'll complain.
        @raises: MBIMComplianceChannelError if received fragments are
                out-of-order or non-contiguous.

        """
        self._verify_endpoint_open()
        if not args:
            mbim_errors.log_and_raise(
                    mbim_errors.MBIMComplianceChannelError,
                    'No data given to |bidirectional_transaction|.')

        transaction_id, _, _ = self._fragment_metadata(args[0])
        for fragment in args:
            self._request_queue.put_nowait(fragment)
        return self._get_response_fragments(transaction_id)


    def unidirectional_transaction(self, *args):
        """
        Execute a synchronous unidirectional transaction. No return value.

        @param *args: Fragments of a single MBIM transaction. An MBIM
                transaction may consist of multiple fragments - each fragment is
                the payload for a USB control message. It should be an
                |array.array| object.  It is your responsibility (and choice) to
                keep the fragments in-order, and to send all the fragments.
                For more details, see "Fragmentation of messages" in the MBIM
                spec.

        """
        self._verify_endpoint_open()
        if not args:
            mbim_errors.log_and_raise(
                    mbim_errors.MBIMComplianceChannelError,
                    'No data given to |unidirectional_transaction|.')

        for fragment in args:
            self._request_queue.put_nowait(fragment)


    def flush(self):
        """
        Clean out all queues.

        This waits till all outgoing packets have been sent, and then waits some
        more to give the channel time to settle down.

        @raises: MBIMComplianceChannelError if things don't settle down fast
                enough.
        """
        self._verify_endpoint_open()
        num_remaining_fragments = self._request_queue.qsize()
        try:
            timeout = self.FRAGMENT_TIMEOUT_S * num_remaining_fragments
            utils.poll_for_condition(lambda: self._request_queue.empty(),
                                     timeout=timeout)
        except utils.TimeoutError:
            mbim_errors.log_and_raise(
                    mbim_errors.MBIMComplianceChannelError,
                    'Could not flush request queue.')

        # Now wait for the response queue to settle down.
        # In the worst case, each request fragment that was remaining at the
        # time flush was called belonged to a different transaction, and each of
        # these transactions would serially timeout in |TRANSACTION_TIMEOUT_S|.
        # To avoid sleeping for long times, we cap this value to
        # |MAX_FLUSH_TRANSACTIONS| transactions.
        num_remaining_transactions = min(self.MAX_FLUSH_TRANSACTIONS,
                                         num_remaining_fragments)
        time.sleep(num_remaining_transactions * self.TRANSACTION_TIMEOUT_S)
        extra_packets = self.get_outstanding_packets()
        for packet in extra_packets:
            logging.debug('flush: discarding packet: %s', packet)


    def get_outstanding_packets(self):
        """
        Get all received packets that were not part of an explicit transaction.

        @returns: A list of packets. Each packet is a list of fragments, so you
        perhaps want to do something like:
            for packet in channel.get_outstanding_packets():
                for fragment in packet:
                    # handle fragment.

        """
        self._verify_endpoint_open()
        # Try to get more packets from the response queue.
        # This can block forever if the modem keeps spewing trash at us.
        while True:
            packet = self._get_packet_fragments()
            if not packet:
                break
            self._outstanding_packets.append(packet)

        packets = self._outstanding_packets
        self._outstanding_packets = []
        return packets


    def _get_response_fragments(self, transaction_id):
        """
        Get response for the given |transaction_id|.

        Packets for other transactions received while waiting are kept in
        |_outstanding_packets| for |get_outstanding_packets|.

        @param transaction_id: The MBIM transaction id to wait for.
        @returns: A list of fragments.
        @raises: MBIMComplianceChannelError if response is not received.

        """
        def _poll_response():
            packet = self._get_packet_fragments()
            if not packet:
                return False
            first_fragment = packet[0]
            response_id, _, _ = self._fragment_metadata(first_fragment)
            if response_id == transaction_id:
                self._last_response = packet
                return True
            self._outstanding_packets.append(packet)
            return False

        try:
            utils.poll_for_condition(
                    _poll_response,
                    timeout=self.TRANSACTION_TIMEOUT_S)
        except utils.TimeoutError:
            mbim_errors.log_and_raise(
                    mbim_errors.MBIMComplianceChannelError,
                    'Did not receive timely reply to transaction %d' %
                    transaction_id)
        return self._last_response


    def _get_packet_fragments(self):
        """
        Get all fragments of the next packet from the modem.

        This function is responsible for putting together fragments of one
        packet, and checking that fragments are contiguous and in-order.

        @returns: A list of fragments. The list is empty if nothing arrived
                within |FRAGMENT_TIMEOUT_S|. It may be shorter than the
                advertised total if the remaining fragments never arrived or
                another transaction's packet started.
        @raises: MBIMComplianceChannelError if the fragments are misnumbered
                or report inconsistent totals.

        """
        fragments = []
        if self._stashed_first_fragment is not None:
            # A previous call read the start of this packet while looking for
            # the tail of an earlier one.
            first_fragment = self._stashed_first_fragment
            self._stashed_first_fragment = None
        else:
            try:
                first_fragment = self._response_queue.get(
                        True, self.FRAGMENT_TIMEOUT_S)
            except Empty:
                # *Don't fail* Just return nothing.
                return fragments

        transaction_id, total_fragments, current_fragment = (
                self._fragment_metadata(first_fragment))
        if current_fragment != 0:
            mbim_errors.log_and_raise(
                    mbim_errors.MBIMComplianceChannelError,
                    'First fragment reports fragment number %d' %
                    current_fragment)

        fragments.append(first_fragment)

        # 0-based index of the most recently accepted fragment.
        last_fragment = 0
        while last_fragment < total_fragments - 1:
            try:
                fragment = self._response_queue.get(True,
                                                    self.FRAGMENT_TIMEOUT_S)
            except Empty:
                # *Don't fail* Just return the fragments we got so far.
                logging.warning('Timed out after receiving %d out of %d '
                                'fragments for transaction %d.',
                                len(fragments),
                                total_fragments,
                                transaction_id)
                break

            fragment_id, fragment_total, fragment_current = (
                    self._fragment_metadata(fragment))
            if fragment_id != transaction_id:
                # *Don't fail* Treat a different transaction id as indicating
                # that the next packet has already arrived.
                logging.warning('Received only %d out of %d fragments for '
                                'transaction %d.',
                                len(fragments),
                                total_fragments,
                                transaction_id)
                self._stashed_first_fragment = fragment
                break

            if fragment_total != total_fragments:
                mbim_errors.log_and_raise(
                        mbim_errors.MBIMComplianceChannelError,
                        'Fragment number %d reports incorrect total (%d/%d)' %
                        (last_fragment + 1, fragment_total, total_fragments))

            if fragment_current != last_fragment + 1:
                mbim_errors.log_and_raise(
                        mbim_errors.MBIMComplianceChannelError,
                        'Received reordered fragments. Expected %d, got %d' %
                        (last_fragment + 1, fragment_current))

            last_fragment += 1
            fragments.append(fragment)

        return fragments


    def _fragment_metadata(self, fragment):
        """
        Extract the transaction and fragmentation info from an MBIM fragment.

        This function houses all the MBIM packet knowledge.

        @param fragment: A single MBIM fragment (header included).
        @returns: A tuple (transaction_id, total_fragments, current_fragment).
                Message types not listed in |MBIM_FRAGMENTED_MESSAGES| are
                reported as a single fragment: (transaction_id, 1, 0).
        @raises: MBIMComplianceChannelError if |fragment| is too short to hold
                the headers its message type requires.

        """
        # All packets have a message header.
        if len(fragment) < struct.calcsize(self.MESSAGE_HEADER_FORMAT):
            mbim_errors.log_and_raise(
                    mbim_errors.MBIMComplianceChannelError,
                    'Corrupted fragment |%s| does not have an MBIM header.' %
                    fragment)

        message_type, _, transaction_id = struct.unpack_from(
                self.MESSAGE_HEADER_FORMAT,
                fragment)

        if message_type in self.MBIM_FRAGMENTED_MESSAGES:
            # The fragment header immediately follows the message header.
            fragment = fragment[struct.calcsize(self.MESSAGE_HEADER_FORMAT):]
            if len(fragment) < struct.calcsize(self.FRAGMENT_HEADER_FORMAT):
                mbim_errors.log_and_raise(
                        mbim_errors.MBIMComplianceChannelError,
                        'Corrupted fragment |%s| does not have a fragment '
                        'header. ' %
                        fragment)

            total_fragments, current_fragment = struct.unpack_from(
                    self.FRAGMENT_HEADER_FORMAT,
                    fragment)
        else:
            # For other types, there is only one 'fragment'.
            total_fragments = 1
            current_fragment = 0

        return transaction_id, total_fragments, current_fragment


    def _verify_endpoint_open(self):
        """
        Check that the channel is open and its endpoint subprocess is alive.

        @raises: MBIMComplianceChannelError if |close| has already been called
                or the endpoint subprocess has exited.

        """
        endpoint_process = getattr(self, '_endpoint_process', None)
        if endpoint_process is None:
            mbim_errors.log_and_raise(
                    mbim_errors.MBIMComplianceChannelError,
                    'MBIMChannel has already been closed.')
        elif not endpoint_process.is_alive():
            mbim_errors.log_and_raise(
                    mbim_errors.MBIMComplianceChannelError,
                    'MBIMChannelEndpoint died unexpectedly. '
                    'The actual exception can be found in log entries from the '
                    'subprocess.')