diff options
Diffstat (limited to 'client/common_lib/cros/dev_server.py')
-rw-r--r-- | client/common_lib/cros/dev_server.py | 215 |
1 files changed, 175 insertions, 40 deletions
diff --git a/client/common_lib/cros/dev_server.py b/client/common_lib/cros/dev_server.py index b2dbb46015..ffbd1457ee 100644 --- a/client/common_lib/cros/dev_server.py +++ b/client/common_lib/cros/dev_server.py @@ -13,6 +13,9 @@ import logging import multiprocessing import os import re +import shutil +import subprocess +from threading import Timer import six from six.moves import urllib import six.moves.html_parser @@ -28,10 +31,9 @@ from autotest_lib.client.common_lib import seven from autotest_lib.client.common_lib import utils from autotest_lib.client.common_lib.cros import retry -# TODO(cmasone): redo this class using requests module; http://crosbug.com/30107 try: - from chromite.lib import metrics + from autotest_lib.utils.frozen_chromite.lib import metrics except ImportError: metrics = utils.metrics_mock @@ -155,6 +157,7 @@ class MarkupStripper(six.moves.html_parser.HTMLParser): def __init__(self): self.reset() self.fed = [] + self.convert_charrefs = True def handle_data(self, d): @@ -183,7 +186,11 @@ def _strip_http_message(message): def _get_image_storage_server(): - return CONFIG.get_config_value('CROS', 'image_storage_server', type=str) + image_path = CONFIG.get_config_value('CROS', + 'image_storage_server', + type=str) + # see b/203531740; this forces a trailing / if not there yet. + return os.path.join(image_path, '') def _get_canary_channel_server(): @@ -193,7 +200,11 @@ def _get_canary_channel_server(): @return: The url to the canary channel server. """ - return CONFIG.get_config_value('CROS', 'canary_channel_server', type=str) + image_path = CONFIG.get_config_value('CROS', + 'canary_channel_server', + type=str) + # see b/203531740; this forces a trailing / if not there yet. + return os.path.join(image_path, '') def _get_storage_server_for_artifacts(artifacts=None): @@ -602,6 +613,23 @@ class DevServer(object): @classmethod + def run_request(cls, call, timeout=None): + """Invoke a given devserver call using urllib.open. + + Open the URL with HTTP, and return the text of the response. Exceptions + may be raised as for urllib2.urlopen(). + + @param call: a url string that calls a method to a devserver. + @param timeout: The timeout seconds for this urlopen call. + + @return A HTTPResponse object. + """ + if timeout is None: + return urllib.request.urlopen(call) + else: + return utils.urlopen_socket_timeout(call, timeout=timeout) + + @classmethod def run_call(cls, call, readline=False, timeout=None): """Invoke a given devserver call using urllib.open. @@ -614,14 +642,11 @@ class DevServer(object): @return the results of this call. """ - if timeout is not None: - return utils.urlopen_socket_timeout( - call, timeout=timeout).read() - elif readline: - response = urllib.request.urlopen(call) + response = cls.run_request(call, timeout=timeout) + if readline: return [line.rstrip() for line in response] else: - return urllib.request.urlopen(call).read() + return response.read() @staticmethod @@ -709,12 +734,17 @@ class DevServer(object): @classmethod - def get_available_devservers(cls, hostname=None, + def get_available_devservers(cls, + hostname=None, prefer_local_devserver=PREFER_LOCAL_DEVSERVER, - restricted_subnets=utils.RESTRICTED_SUBNETS): + restricted_subnets=utils.ALL_SUBNETS): """Get devservers in the same subnet of the given hostname. @param hostname: Hostname of a DUT to choose devserver for. + @param prefer_local_devserver: A boolean indicating using a devserver in + the same subnet with the DUT. + @param restricted_subnets: A list of restricted subnets or p2p subnet + groups. @return: A tuple of (devservers, can_retry), devservers is a list of devservers that's available for the given hostname. can_retry @@ -734,20 +764,33 @@ class DevServer(object): if not host_ip: return cls.get_unrestricted_devservers(restricted_subnets), False - # Go through all restricted subnet settings and check if the DUT is - # inside a restricted subnet. If so, only return the devservers in the - # restricted subnet and doesn't allow retry. - if host_ip and restricted_subnets: - subnet_ip, mask_bits = _get_subnet_for_host_ip( - host_ip, restricted_subnets=restricted_subnets) - if subnet_ip: - logging.debug('The host %s (%s) is in a restricted subnet. ' - 'Try to locate a devserver inside subnet ' - '%s:%d.', hostname, host_ip, subnet_ip, - mask_bits) - devservers = cls.get_devservers_in_same_subnet( - subnet_ip, mask_bits) - return devservers, False + # For the sake of backward compatibility, we use the argument + # 'restricted_subnets' to store both the legacy subnets (a tuple of + # (ip, mask)) and p2p subnets group (a list of subnets, i.e. [(ip, + # mask), ...]) data. For consistency, we convert all legacy subnets to + # a "singleton p2p subnets" and store them in a new list. + all_subnets = [] + for s in restricted_subnets: + if isinstance(s, tuple): + all_subnets.append([s]) + else: + all_subnets.append(s) + + # Find devservers in the subnets reachable from the DUT. + if host_ip and all_subnets: + subnet_group = _get_subnet_group_for_host_ip( + host_ip, all_subnets=all_subnets) + if subnet_group: + devservers = set() + for ip, mask in subnet_group: + logging.debug( + 'The host %s (%s) is in a restricted subnet ' + '(or its peers). ' + 'Try to locate devservers inside subnet ' + '%s/%d.', hostname, host_ip, ip, mask) + devservers |= set( + cls.get_devservers_in_same_subnet(ip, mask)) + return sorted(devservers), False # If prefer_local_devserver is set to True and the host is not in # restricted subnet, pick a devserver in the same subnet if possible. @@ -987,7 +1030,7 @@ class ImageServerBase(DevServer): """ server_name = get_hostname(call) is_in_restricted_subnet = utils.get_restricted_subnet( - server_name, utils.RESTRICTED_SUBNETS) + server_name, utils.get_all_restricted_subnets()) _EMPTY_SENTINEL_VALUE = object() def kickoff_call(): """Invoke a given devserver call using urllib.open or ssh. @@ -1004,6 +1047,11 @@ class ImageServerBase(DevServer): else: response = cls.run_ssh_call( call, readline=readline, timeout=timeout) + + # six.ensure_str would be nice, but its not in all the envs, so + # this is what we are left with for now. + if isinstance(response, bytes): + response = response.decode() # Retry if devserver service is temporarily down, e.g. in a # devserver push. if ERR_MSG_FOR_DOWN_DEVSERVER in response: @@ -1036,9 +1084,57 @@ class ImageServerBase(DevServer): @param local_file: The path of the file saved to local. @param timeout: The timeout seconds for this call. """ - response = cls.run_call(remote_file, timeout=timeout) - with open(local_file, 'w') as out_log: - out_log.write(response) + server_name = get_hostname(remote_file) + is_in_restricted_subnet = utils.get_restricted_subnet( + server_name, utils.get_all_restricted_subnets()) + + if (not ENABLE_SSH_CONNECTION_FOR_DEVSERVER + or not is_in_restricted_subnet): + response = super(ImageServerBase, cls).run_request(remote_file, + timeout=timeout) + with open(local_file, 'wb') as out_log: + shutil.copyfileobj(response, out_log) + else: + timeout_seconds = timeout if timeout else DEVSERVER_SSH_TIMEOUT_MINS * 60 + # SSH to the dev server and attach the local file as stdout. + with open(local_file, 'wb') as out_log: + ssh_cmd = [ + 'ssh', server_name, + 'curl -s -S -f "%s"' % utils.sh_escape(remote_file) + ] + logging.debug("Running command %s", ssh_cmd) + with open(os.devnull) as devnull: + cmd = subprocess.Popen( + ssh_cmd, + stdout=out_log, + stdin=devnull, + stderr=subprocess.PIPE, + ) + + # Python 2.7 doesn't have Popen.wait(timeout), so start a + # timer and kill the ssh process if it takes too long. + def stop_process(): + """Kills the subprocess after the timeout.""" + cmd.kill() + logging.error("ssh call timed out after %s secs", + timeout_seconds) + + t = Timer(timeout_seconds, stop_process) + try: + t.start() + cmd.wait() + finally: + t.cancel() + error_output = cmd.stderr.read() + if error_output: + logging.error("ssh call output: %s", error_output) + if cmd.returncode != 0: + c = metrics.Counter( + 'chromeos/autotest/devserver/ssh_failure') + c.increment(fields={'dev_server': server_name}) + raise DevServerException( + "ssh call failed with exit code %s", + cmd.returncode) def _poll_is_staged(self, **kwargs): @@ -1124,7 +1220,7 @@ class ImageServerBase(DevServer): 'the call: %s' % (self.url(), call)) if expected_response and not response == expected_response: - raise DevServerException(error_message) + raise DevServerException(error_message) # `os_type` is needed in build a devserver call, but not needed for # wait_for_artifacts_staged, since that method is implemented by @@ -1157,7 +1253,8 @@ class ImageServerBase(DevServer): @raise DevServerException upon any return code that's not HTTP OK. """ if not archive_url: - archive_url = _get_storage_server_for_artifacts(artifacts) + build + archive_url = os.path.join( + _get_storage_server_for_artifacts(artifacts), build) artifacts_arg = ','.join(artifacts) if artifacts else '' files_arg = ','.join(files) if files else '' @@ -1424,12 +1521,17 @@ class ImageServer(ImageServerBase): self.nton_payload = nton_payload - def wait_for_artifacts_staged(self, archive_url, artifacts='', files=''): + def wait_for_artifacts_staged(self, + archive_url, + artifacts='', + files='', + **kwargs): """Polling devserver.is_staged until all artifacts are staged. @param archive_url: Google Storage URL for the build. @param artifacts: Comma separated list of artifacts to download. @param files: Comma separated list of files to download. + @param kwargs: keyword arguments to make is_staged devserver call. @return: True if all artifacts are staged in devserver. """ kwargs = {'archive_url': archive_url, @@ -1439,8 +1541,14 @@ class ImageServer(ImageServerBase): @remote_devserver_call() - def call_and_wait(self, call_name, archive_url, artifacts, files, - error_message, expected_response=SUCCESS): + def call_and_wait(self, + call_name, + archive_url, + artifacts, + files, + error_message, + expected_response=SUCCESS, + clean=False): """Helper method to make a urlopen call, and wait for artifacts staged. @param call_name: name of devserver rpc call. @@ -1453,21 +1561,26 @@ class ImageServer(ImageServerBase): to be good. @param error_message: Error message to be thrown if response does not match expected_response. + @param clean: Force re-loading artifacts/files from cloud, ignoring + cached version. @return: The response from rpc. @raise DevServerException upon any return code that's expected_response. """ - kwargs = {'archive_url': archive_url, - 'artifacts': artifacts, - 'files': files} + kwargs = { + 'archive_url': archive_url, + 'artifacts': artifacts, + 'files': files, + 'clean': clean + } return self._call_and_wait(call_name, error_message, expected_response, **kwargs) @remote_devserver_call() def stage_artifacts(self, image=None, artifacts=None, files='', - archive_url=None): + archive_url=None, **kwargs): """Tell the devserver to download and stage |artifacts| from |image|. This is the main call point for staging any specific artifacts for a @@ -1483,13 +1596,15 @@ class ImageServer(ImageServerBase): @param archive_url: Optional parameter that has the archive_url to stage this artifact from. Default is specified in autotest config + image. + @param kwargs: keyword arguments that specify the build information, to + make stage devserver call. @raise DevServerException upon any return code that's not HTTP OK. """ if not artifacts and not files: raise DevServerException('Must specify something to stage.') image = self.translate(image) - self._stage_artifacts(image, artifacts, files, archive_url) + self._stage_artifacts(image, artifacts, files, archive_url, **kwargs) @remote_devserver_call(timeout_min=DEVSERVER_SSH_TIMEOUT_MINS) @@ -2011,6 +2126,26 @@ def _get_subnet_for_host_ip(host_ip, return None, None +def _get_subnet_group_for_host_ip(host_ip, all_subnets=()): + """Get subnet group for a given host IP. + + All subnets in the group are reachable from the input host ip. + + @param host_ip: the IP of a DUT. + @param all_subnets: A two level list of subnets including singleton + lists of a restricted subnet and p2p subnets. + + @return: a list of (subnet_ip, mask_bits) tuple. If no matched subnets for + the host_ip, return []. + """ + for subnet_group in all_subnets: + subnet, _ = _get_subnet_for_host_ip(host_ip, + restricted_subnets=subnet_group) + if subnet: + return subnet_group + return [] + + def get_least_loaded_devserver(devserver_type=ImageServer, hostname=None): """Get the devserver with the least load. |