aboutsummaryrefslogtreecommitdiff
path: root/client/common_lib/cros/dev_server.py
diff options
context:
space:
mode:
Diffstat (limited to 'client/common_lib/cros/dev_server.py')
-rw-r--r--client/common_lib/cros/dev_server.py215
1 files changed, 175 insertions, 40 deletions
diff --git a/client/common_lib/cros/dev_server.py b/client/common_lib/cros/dev_server.py
index b2dbb46015..ffbd1457ee 100644
--- a/client/common_lib/cros/dev_server.py
+++ b/client/common_lib/cros/dev_server.py
@@ -13,6 +13,9 @@ import logging
import multiprocessing
import os
import re
+import shutil
+import subprocess
+from threading import Timer
import six
from six.moves import urllib
import six.moves.html_parser
@@ -28,10 +31,9 @@ from autotest_lib.client.common_lib import seven
from autotest_lib.client.common_lib import utils
from autotest_lib.client.common_lib.cros import retry
-# TODO(cmasone): redo this class using requests module; http://crosbug.com/30107
try:
- from chromite.lib import metrics
+ from autotest_lib.utils.frozen_chromite.lib import metrics
except ImportError:
metrics = utils.metrics_mock
@@ -155,6 +157,7 @@ class MarkupStripper(six.moves.html_parser.HTMLParser):
def __init__(self):
self.reset()
self.fed = []
+ self.convert_charrefs = True
def handle_data(self, d):
@@ -183,7 +186,11 @@ def _strip_http_message(message):
def _get_image_storage_server():
- return CONFIG.get_config_value('CROS', 'image_storage_server', type=str)
+ image_path = CONFIG.get_config_value('CROS',
+ 'image_storage_server',
+ type=str)
+ # see b/203531740; this forces a trailing / if not there yet.
+ return os.path.join(image_path, '')
def _get_canary_channel_server():
@@ -193,7 +200,11 @@ def _get_canary_channel_server():
@return: The url to the canary channel server.
"""
- return CONFIG.get_config_value('CROS', 'canary_channel_server', type=str)
+ image_path = CONFIG.get_config_value('CROS',
+ 'canary_channel_server',
+ type=str)
+ # see b/203531740; this forces a trailing / if not there yet.
+ return os.path.join(image_path, '')
def _get_storage_server_for_artifacts(artifacts=None):
@@ -602,6 +613,23 @@ class DevServer(object):
@classmethod
+ def run_request(cls, call, timeout=None):
+ """Invoke a given devserver call using urllib.open.
+
+ Open the URL with HTTP, and return the text of the response. Exceptions
+ may be raised as for urllib2.urlopen().
+
+ @param call: a url string that calls a method to a devserver.
+ @param timeout: The timeout seconds for this urlopen call.
+
+ @return A HTTPResponse object.
+ """
+ if timeout is None:
+ return urllib.request.urlopen(call)
+ else:
+ return utils.urlopen_socket_timeout(call, timeout=timeout)
+
+ @classmethod
def run_call(cls, call, readline=False, timeout=None):
"""Invoke a given devserver call using urllib.open.
@@ -614,14 +642,11 @@ class DevServer(object):
@return the results of this call.
"""
- if timeout is not None:
- return utils.urlopen_socket_timeout(
- call, timeout=timeout).read()
- elif readline:
- response = urllib.request.urlopen(call)
+ response = cls.run_request(call, timeout=timeout)
+ if readline:
return [line.rstrip() for line in response]
else:
- return urllib.request.urlopen(call).read()
+ return response.read()
@staticmethod
@@ -709,12 +734,17 @@ class DevServer(object):
@classmethod
- def get_available_devservers(cls, hostname=None,
+ def get_available_devservers(cls,
+ hostname=None,
prefer_local_devserver=PREFER_LOCAL_DEVSERVER,
- restricted_subnets=utils.RESTRICTED_SUBNETS):
+ restricted_subnets=utils.ALL_SUBNETS):
"""Get devservers in the same subnet of the given hostname.
@param hostname: Hostname of a DUT to choose devserver for.
+ @param prefer_local_devserver: A boolean indicating using a devserver in
+ the same subnet with the DUT.
+ @param restricted_subnets: A list of restricted subnets or p2p subnet
+ groups.
@return: A tuple of (devservers, can_retry), devservers is a list of
devservers that's available for the given hostname. can_retry
@@ -734,20 +764,33 @@ class DevServer(object):
if not host_ip:
return cls.get_unrestricted_devservers(restricted_subnets), False
- # Go through all restricted subnet settings and check if the DUT is
- # inside a restricted subnet. If so, only return the devservers in the
- # restricted subnet and doesn't allow retry.
- if host_ip and restricted_subnets:
- subnet_ip, mask_bits = _get_subnet_for_host_ip(
- host_ip, restricted_subnets=restricted_subnets)
- if subnet_ip:
- logging.debug('The host %s (%s) is in a restricted subnet. '
- 'Try to locate a devserver inside subnet '
- '%s:%d.', hostname, host_ip, subnet_ip,
- mask_bits)
- devservers = cls.get_devservers_in_same_subnet(
- subnet_ip, mask_bits)
- return devservers, False
+ # For the sake of backward compatibility, we use the argument
+ # 'restricted_subnets' to store both the legacy subnets (a tuple of
+ # (ip, mask)) and p2p subnets group (a list of subnets, i.e. [(ip,
+ # mask), ...]) data. For consistency, we convert all legacy subnets to
+ # a "singleton p2p subnets" and store them in a new list.
+ all_subnets = []
+ for s in restricted_subnets:
+ if isinstance(s, tuple):
+ all_subnets.append([s])
+ else:
+ all_subnets.append(s)
+
+ # Find devservers in the subnets reachable from the DUT.
+ if host_ip and all_subnets:
+ subnet_group = _get_subnet_group_for_host_ip(
+ host_ip, all_subnets=all_subnets)
+ if subnet_group:
+ devservers = set()
+ for ip, mask in subnet_group:
+ logging.debug(
+ 'The host %s (%s) is in a restricted subnet '
+ '(or its peers). '
+ 'Try to locate devservers inside subnet '
+ '%s/%d.', hostname, host_ip, ip, mask)
+ devservers |= set(
+ cls.get_devservers_in_same_subnet(ip, mask))
+ return sorted(devservers), False
# If prefer_local_devserver is set to True and the host is not in
# restricted subnet, pick a devserver in the same subnet if possible.
@@ -987,7 +1030,7 @@ class ImageServerBase(DevServer):
"""
server_name = get_hostname(call)
is_in_restricted_subnet = utils.get_restricted_subnet(
- server_name, utils.RESTRICTED_SUBNETS)
+ server_name, utils.get_all_restricted_subnets())
_EMPTY_SENTINEL_VALUE = object()
def kickoff_call():
"""Invoke a given devserver call using urllib.open or ssh.
@@ -1004,6 +1047,11 @@ class ImageServerBase(DevServer):
else:
response = cls.run_ssh_call(
call, readline=readline, timeout=timeout)
+
+ # six.ensure_str would be nice, but its not in all the envs, so
+ # this is what we are left with for now.
+ if isinstance(response, bytes):
+ response = response.decode()
# Retry if devserver service is temporarily down, e.g. in a
# devserver push.
if ERR_MSG_FOR_DOWN_DEVSERVER in response:
@@ -1036,9 +1084,57 @@ class ImageServerBase(DevServer):
@param local_file: The path of the file saved to local.
@param timeout: The timeout seconds for this call.
"""
- response = cls.run_call(remote_file, timeout=timeout)
- with open(local_file, 'w') as out_log:
- out_log.write(response)
+ server_name = get_hostname(remote_file)
+ is_in_restricted_subnet = utils.get_restricted_subnet(
+ server_name, utils.get_all_restricted_subnets())
+
+ if (not ENABLE_SSH_CONNECTION_FOR_DEVSERVER
+ or not is_in_restricted_subnet):
+ response = super(ImageServerBase, cls).run_request(remote_file,
+ timeout=timeout)
+ with open(local_file, 'wb') as out_log:
+ shutil.copyfileobj(response, out_log)
+ else:
+ timeout_seconds = timeout if timeout else DEVSERVER_SSH_TIMEOUT_MINS * 60
+ # SSH to the dev server and attach the local file as stdout.
+ with open(local_file, 'wb') as out_log:
+ ssh_cmd = [
+ 'ssh', server_name,
+ 'curl -s -S -f "%s"' % utils.sh_escape(remote_file)
+ ]
+ logging.debug("Running command %s", ssh_cmd)
+ with open(os.devnull) as devnull:
+ cmd = subprocess.Popen(
+ ssh_cmd,
+ stdout=out_log,
+ stdin=devnull,
+ stderr=subprocess.PIPE,
+ )
+
+ # Python 2.7 doesn't have Popen.wait(timeout), so start a
+ # timer and kill the ssh process if it takes too long.
+ def stop_process():
+ """Kills the subprocess after the timeout."""
+ cmd.kill()
+ logging.error("ssh call timed out after %s secs",
+ timeout_seconds)
+
+ t = Timer(timeout_seconds, stop_process)
+ try:
+ t.start()
+ cmd.wait()
+ finally:
+ t.cancel()
+ error_output = cmd.stderr.read()
+ if error_output:
+ logging.error("ssh call output: %s", error_output)
+ if cmd.returncode != 0:
+ c = metrics.Counter(
+ 'chromeos/autotest/devserver/ssh_failure')
+ c.increment(fields={'dev_server': server_name})
+ raise DevServerException(
+ "ssh call failed with exit code %s",
+ cmd.returncode)
def _poll_is_staged(self, **kwargs):
@@ -1124,7 +1220,7 @@ class ImageServerBase(DevServer):
'the call: %s' % (self.url(), call))
if expected_response and not response == expected_response:
- raise DevServerException(error_message)
+ raise DevServerException(error_message)
# `os_type` is needed in build a devserver call, but not needed for
# wait_for_artifacts_staged, since that method is implemented by
@@ -1157,7 +1253,8 @@ class ImageServerBase(DevServer):
@raise DevServerException upon any return code that's not HTTP OK.
"""
if not archive_url:
- archive_url = _get_storage_server_for_artifacts(artifacts) + build
+ archive_url = os.path.join(
+ _get_storage_server_for_artifacts(artifacts), build)
artifacts_arg = ','.join(artifacts) if artifacts else ''
files_arg = ','.join(files) if files else ''
@@ -1424,12 +1521,17 @@ class ImageServer(ImageServerBase):
self.nton_payload = nton_payload
- def wait_for_artifacts_staged(self, archive_url, artifacts='', files=''):
+ def wait_for_artifacts_staged(self,
+ archive_url,
+ artifacts='',
+ files='',
+ **kwargs):
"""Polling devserver.is_staged until all artifacts are staged.
@param archive_url: Google Storage URL for the build.
@param artifacts: Comma separated list of artifacts to download.
@param files: Comma separated list of files to download.
+ @param kwargs: keyword arguments to make is_staged devserver call.
@return: True if all artifacts are staged in devserver.
"""
kwargs = {'archive_url': archive_url,
@@ -1439,8 +1541,14 @@ class ImageServer(ImageServerBase):
@remote_devserver_call()
- def call_and_wait(self, call_name, archive_url, artifacts, files,
- error_message, expected_response=SUCCESS):
+ def call_and_wait(self,
+ call_name,
+ archive_url,
+ artifacts,
+ files,
+ error_message,
+ expected_response=SUCCESS,
+ clean=False):
"""Helper method to make a urlopen call, and wait for artifacts staged.
@param call_name: name of devserver rpc call.
@@ -1453,21 +1561,26 @@ class ImageServer(ImageServerBase):
to be good.
@param error_message: Error message to be thrown if response does not
match expected_response.
+ @param clean: Force re-loading artifacts/files from cloud, ignoring
+ cached version.
@return: The response from rpc.
@raise DevServerException upon any return code that's expected_response.
"""
- kwargs = {'archive_url': archive_url,
- 'artifacts': artifacts,
- 'files': files}
+ kwargs = {
+ 'archive_url': archive_url,
+ 'artifacts': artifacts,
+ 'files': files,
+ 'clean': clean
+ }
return self._call_and_wait(call_name, error_message,
expected_response, **kwargs)
@remote_devserver_call()
def stage_artifacts(self, image=None, artifacts=None, files='',
- archive_url=None):
+ archive_url=None, **kwargs):
"""Tell the devserver to download and stage |artifacts| from |image|.
This is the main call point for staging any specific artifacts for a
@@ -1483,13 +1596,15 @@ class ImageServer(ImageServerBase):
@param archive_url: Optional parameter that has the archive_url to stage
this artifact from. Default is specified in autotest config +
image.
+ @param kwargs: keyword arguments that specify the build information, to
+ make stage devserver call.
@raise DevServerException upon any return code that's not HTTP OK.
"""
if not artifacts and not files:
raise DevServerException('Must specify something to stage.')
image = self.translate(image)
- self._stage_artifacts(image, artifacts, files, archive_url)
+ self._stage_artifacts(image, artifacts, files, archive_url, **kwargs)
@remote_devserver_call(timeout_min=DEVSERVER_SSH_TIMEOUT_MINS)
@@ -2011,6 +2126,26 @@ def _get_subnet_for_host_ip(host_ip,
return None, None
+def _get_subnet_group_for_host_ip(host_ip, all_subnets=()):
+ """Get subnet group for a given host IP.
+
+ All subnets in the group are reachable from the input host ip.
+
+ @param host_ip: the IP of a DUT.
+ @param all_subnets: A two level list of subnets including singleton
+ lists of a restricted subnet and p2p subnets.
+
+ @return: a list of (subnet_ip, mask_bits) tuple. If no matched subnets for
+ the host_ip, return [].
+ """
+ for subnet_group in all_subnets:
+ subnet, _ = _get_subnet_for_host_ip(host_ip,
+ restricted_subnets=subnet_group)
+ if subnet:
+ return subnet_group
+ return []
+
+
def get_least_loaded_devserver(devserver_type=ImageServer, hostname=None):
"""Get the devserver with the least load.