diff options
author | Lorenzo Colitti <lorenzo@google.com> | 2015-12-18 06:02:06 +0000 |
---|---|---|
committer | Gerrit Code Review <noreply-gerritcodereview@google.com> | 2015-12-18 06:02:06 +0000 |
commit | abf5656527ec767fdf306fbf703f73a10e4ae3d6 (patch) | |
tree | 504bf2f5b74690eafc93a93f14a97d87bbe91d0c | |
parent | f2f800951aa3ec920a719556d71c2fbee734209d (diff) | |
parent | f2dffc46c435e1260185eb525f237e862525c2be (diff) | |
download | extras-abf5656527ec767fdf306fbf703f73a10e4ae3d6.tar.gz |
Merge changes Ib0ab1722,Id5b1b351
* changes:
Don't create TIME-WAIT sockets in CreateSocketPair.
Add support for the sock_diag netlink interface.
-rwxr-xr-x | tests/net_test/net_test.py | 7 | ||||
-rw-r--r-- | tests/net_test/netlink.py | 7 | ||||
-rwxr-xr-x | tests/net_test/sock_diag.py | 226 | ||||
-rwxr-xr-x | tests/net_test/sock_diag_test.py | 112 |
4 files changed, 349 insertions, 3 deletions
diff --git a/tests/net_test/net_test.py b/tests/net_test/net_test.py index 5c5a4c5f..f108aa8f 100755 --- a/tests/net_test/net_test.py +++ b/tests/net_test/net_test.py @@ -151,6 +151,10 @@ def RawGRESocket(family): return s +def DisableLinger(sock): + sock.setsockopt(SOL_SOCKET, SO_LINGER, struct.pack("ii", 1, 0)) + + def CreateSocketPair(family, socktype, addr): clientsock = socket(family, socktype, 0) listensock = socket(family, socktype, 0) @@ -159,6 +163,9 @@ def CreateSocketPair(family, socktype, addr): listensock.listen(1) clientsock.connect(addr) acceptedsock, _ = listensock.accept() + DisableLinger(clientsock) + DisableLinger(acceptedsock) + listensock.close() return clientsock, acceptedsock diff --git a/tests/net_test/netlink.py b/tests/net_test/netlink.py index 6b2c60d1..514ad08b 100644 --- a/tests/net_test/netlink.py +++ b/tests/net_test/netlink.py @@ -121,9 +121,10 @@ class NetlinkSocket(object): # If it's an attribute we know about, try to decode it. nla_name, nla_data = self._Decode(command, msg, nla.nla_type, nla_data) - # We only support unique attributes for now. - if nla_name in attributes: - raise ValueError("Duplicate attribute %d" % nla_name) + # We only support unique attributes for now, except for INET_DIAG_NONE, + # which can appear more than once but doesn't seem to contain any data. + if nla_name in attributes and nla_name != "INET_DIAG_NONE": + raise ValueError("Duplicate attribute %s" % nla_name) attributes[nla_name] = nla_data self._Debug(" %s" % str((nla_name, nla_data))) diff --git a/tests/net_test/sock_diag.py b/tests/net_test/sock_diag.py new file mode 100755 index 00000000..8c70eb33 --- /dev/null +++ b/tests/net_test/sock_diag.py @@ -0,0 +1,226 @@ +#!/usr/bin/python +# +# Copyright 2015 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Partial Python implementation of sock_diag functionality.""" + +# pylint: disable=g-bad-todo + +import errno +from socket import * # pylint: disable=wildcard-import + +import cstruct +import net_test +import netlink + +### Base netlink constants. See include/uapi/linux/netlink.h. +NETLINK_SOCK_DIAG = 4 + +### sock_diag constants. See include/uapi/linux/sock_diag.h. +# Message types. +SOCK_DIAG_BY_FAMILY = 20 +SOCK_DESTROY = 21 + +### inet_diag_constants. See include/uapi/linux/inet_diag.h +# Message types. +TCPDIAG_GETSOCK = 18 + +# Extensions. +INET_DIAG_NONE = 0 +INET_DIAG_MEMINFO = 1 +INET_DIAG_INFO = 2 +INET_DIAG_VEGASINFO = 3 +INET_DIAG_CONG = 4 +INET_DIAG_TOS = 5 +INET_DIAG_TCLASS = 6 +INET_DIAG_SKMEMINFO = 7 +INET_DIAG_SHUTDOWN = 8 +INET_DIAG_DCTCPINFO = 9 + +# Data structure formats. +# These aren't constants, they're classes. So, pylint: disable=invalid-name +InetDiagSockId = cstruct.Struct( + "InetDiagSockId", "!HH16s16sI8s", "sport dport src dst iface cookie") +InetDiagReqV2 = cstruct.Struct( + "InetDiagReqV2", "=BBBxIS", "family protocol ext states id", + [InetDiagSockId]) +InetDiagMsg = cstruct.Struct( + "InetDiagMsg", "=BBBBSLLLLL", + "family state timer retrans id expires rqueue wqueue uid inode", + [InetDiagSockId]) +InetDiagMeminfo = cstruct.Struct( + "InetDiagMeminfo", "=IIII", "rmem wmem fmem tmem") + +SkMeminfo = cstruct.Struct( + "SkMeminfo", "=IIIIIIII", + "rmem_alloc rcvbuf wmem_alloc sndbuf fwd_alloc wmem_queued optmem backlog") +TcpInfo = cstruct.Struct( + "TcpInfo", "=BBBBBBBxIIIIIIIIIIIIIIIIIIIIIIII", + "state ca_state retransmits probes backoff options wscale " + "rto ato snd_mss rcv_mss " + "unacked sacked lost retrans fackets " + "last_data_sent last_ack_sent last_data_recv last_ack_recv " + "pmtu rcv_ssthresh rtt rttvar snd_ssthresh snd_cwnd advmss reordering " + "rcv_rtt rcv_space " + "total_retrans") # As of linux 3.13, at least. + +# TCP states. See include/net/tcp_states.h. +TCP_ESTABLISHED = 1 +TCP_SYN_SENT = 2 +TCP_SYN_RECV = 3 +TCP_FIN_WAIT1 = 4 +TCP_FIN_WAIT2 = 5 +TCP_TIME_WAIT = 6 +TCP_CLOSE = 7 +TCP_CLOSE_WAIT = 8 +TCP_LAST_ACK = 9 +TCP_LISTEN = 10 +TCP_CLOSING = 11 +TCP_NEW_SYN_RECV = 12 + + +class SockDiag(netlink.NetlinkSocket): + + FAMILY = NETLINK_SOCK_DIAG + NL_DEBUG = [] + + def _Decode(self, command, msg, nla_type, nla_data): + """Decodes netlink attributes to Python types.""" + if msg.family == AF_INET or msg.family == AF_INET6: + name = self._GetConstantName(__name__, nla_type, "INET_DIAG") + else: + # Don't know what this is. Leave it as an integer. + name = nla_type + + if name in ["INET_DIAG_SHUTDOWN", "INET_DIAG_TOS", "INET_DIAG_TCLASS"]: + data = ord(nla_data) + elif name == "INET_DIAG_CONG": + data = nla_data.strip("\x00") + elif name == "INET_DIAG_MEMINFO": + data = InetDiagMeminfo(nla_data) + elif name == "INET_DIAG_INFO": + # TODO: Catch the exception and try something else if it's not TCP. + data = TcpInfo(nla_data) + elif name == "INET_DIAG_SKMEMINFO": + data = SkMeminfo(nla_data) + else: + data = nla_data + + return name, data + + def MaybeDebugCommand(self, command, data): + name = self._GetConstantName(__name__, command, "SOCK_") + if "ALL" not in self.NL_DEBUG and "SOCK" not in self.NL_DEBUG: + return + parsed = self._ParseNLMsg(data, InetDiagReqV2) + print "%s %s" % (name, str(parsed)) + + @staticmethod + def _EmptyInetDiagSockId(): + return InetDiagSockId(("\x00" * len(InetDiagSockId))) + + def DumpSockets(self, family, protocol, ext, states, sock_id): + """Dumps sockets matching the specified parameters.""" + if sock_id is None: + sock_id = self._EmptyInetDiagSockId() + + diag_req = InetDiagReqV2((family, protocol, ext, states, sock_id)) + return self._Dump(SOCK_DIAG_BY_FAMILY, diag_req, InetDiagMsg) + + def DumpAllInetSockets(self, protocol, sock_id=None, ext=0, states=0xffffffff): + # DumpSockets(AF_UNSPEC) does not result in dumping all inet sockets, it + # results in ENOENT. + sockets = [] + for family in [AF_INET, AF_INET6]: + sockets += self.DumpSockets(family, protocol, ext, states, None) + return sockets + + @staticmethod + def GetRawAddress(family, addr): + """Fetches the source address from an InetDiagMsg.""" + addrlen = {AF_INET:4, AF_INET6: 16}[family] + return inet_ntop(family, addr[:addrlen]) + + @staticmethod + def GetSourceAddress(diag_msg): + """Fetches the source address from an InetDiagMsg.""" + return SockDiag.GetRawAddress(diag_msg.family, diag_msg.id.src) + + @staticmethod + def GetDestinationAddress(diag_msg): + """Fetches the source address from an InetDiagMsg.""" + return SockDiag.GetRawAddress(diag_msg.family, diag_msg.id.dst) + + @staticmethod + def RawAddress(addr): + """Converts an IP address string to binary format.""" + family = AF_INET6 if ":" in addr else AF_INET + return inet_pton(family, addr) + + @staticmethod + def PaddedAddress(addr): + """Converts an IP address string to binary format for InetDiagSockId.""" + padded = SockDiag.RawAddress(addr) + if len(padded) < 16: + padded += "\x00" * (16 - len(padded)) + return padded + + @staticmethod + def DiagReqFromSocket(s): + """Creates an InetDiagReqV2 that matches the specified socket.""" + family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN) + protocol = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_PROTOCOL) + iface = s.getsockopt(SOL_SOCKET, net_test.SO_BINDTODEVICE) + src, sport = s.getsockname()[:2] + try: + dst, dport = s.getpeername()[:2] + except error, e: + if e.errno == errno.ENOTCONN: + dport = 0 + dst = "::" if family == AF_INET6 else "0.0.0.0" + else: + raise e + src = SockDiag.PaddedAddress(src) + dst = SockDiag.PaddedAddress(dst) + sock_id = InetDiagSockId((sport, dport, src, dst, iface, "\x00" * 8)) + return InetDiagReqV2((family, protocol, 0, 0xffffffff, sock_id)) + + def GetSockDiagForFd(self, s): + """Gets an InetDiagMsg from the kernel for the specified socket.""" + req = self.DiagReqFromSocket(s) + for diag_msg, attrs in self._Dump(SOCK_DIAG_BY_FAMILY, req, InetDiagMsg): + return diag_msg + raise ValueError("Dump of %s returned no sockets" % req) + + def GetSockDiag(self, family, protocol, sock_id, ext=0, states=0xffffffff): + """Gets an InetDiagMsg from the kernel for the specified parameters.""" + req = InetDiagReqV2((family, protocol, ext, states, sock_id)) + self._SendNlRequest(SOCK_DIAG_BY_FAMILY, req.Pack(), netlink.NLM_F_REQUEST) + data = self._Recv() + return self._ParseNLMsg(data, InetDiagMsg)[0] + + +if __name__ == "__main__": + n = SockDiag() + n.DEBUG = True + sock_id = n._EmptyInetDiagSockId() + sock_id.dport = 443 + family = AF_INET6 + protocol = IPPROTO_TCP + ext = 0 + states = 0xffffffff + ext = 1 << (INET_DIAG_TOS - 1) | 1 << (INET_DIAG_TCLASS - 1) + diag_msgs = n.DumpSockets(family, protocol, ext, states, sock_id) + print diag_msgs diff --git a/tests/net_test/sock_diag_test.py b/tests/net_test/sock_diag_test.py new file mode 100755 index 00000000..2803bd20 --- /dev/null +++ b/tests/net_test/sock_diag_test.py @@ -0,0 +1,112 @@ +#!/usr/bin/python +# +# Copyright 2015 The Android Open Source Project +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import errno +import random +from socket import * +import time +import unittest + +import csocket +import cstruct +import multinetwork_base +import net_test +import packets +import sock_diag + + +NUM_SOCKETS = 100 + +ALL_NON_TIME_WAIT = 0xffffffff & ~(1 << sock_diag.TCP_TIME_WAIT) + + +class SockDiagTest(multinetwork_base.MultiNetworkBaseTest): + + @staticmethod + def _CreateLotsOfSockets(): + # Dict mapping (addr, sport, dport) tuples to socketpairs. + socketpairs = {} + for i in xrange(NUM_SOCKETS): + family, addr = random.choice([(AF_INET, "127.0.0.1"), (AF_INET6, "::1")]) + socketpair = net_test.CreateSocketPair(family, SOCK_STREAM, addr) + sport, dport = (socketpair[0].getsockname()[1], + socketpair[1].getsockname()[1]) + socketpairs[(addr, sport, dport)] = socketpair + return socketpairs + + def setUp(self): + self.sock_diag = sock_diag.SockDiag() + self.socketpairs = self._CreateLotsOfSockets() + + def tearDown(self): + [s.close() for socketpair in self.socketpairs.values() for s in socketpair] + + def assertSockDiagMatchesSocket(self, s, diag_msg): + family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN) + self.assertEqual(diag_msg.family, family) + + # TODO: The kernel (at least 3.10) seems only to fill in the first 4 bytes + # of src and dst in the case of IPv4 addresses. This means we can't just do + # something like: + # self.assertEqual(diag_msg.id.src, self.sock_diag.PaddedAddress(src)) + # because the trailing bytes might not match. + # This seems like a bug because it might leaks kernel memory contents, but + # regardless, work around that here. + addrlen = {AF_INET: 4, AF_INET6: 16}[family] + + src, sport = s.getsockname()[0:2] + self.assertEqual(diag_msg.id.sport, sport) + self.assertEqual(diag_msg.id.src[:addrlen], + self.sock_diag.RawAddress(src)) + + if self.sock_diag.GetDestinationAddress(diag_msg) not in ["0.0.0.0", "::"]: + dst, dport = s.getpeername()[0:2] + self.assertEqual(diag_msg.id.dst[:addrlen], + self.sock_diag.RawAddress(dst)) + self.assertEqual(diag_msg.id.dport, dport) + else: + assertRaisesErrno(errno.ENOTCONN, s.getpeername) + + def testFindsAllMySockets(self): + sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, + states=ALL_NON_TIME_WAIT) + self.assertGreaterEqual(len(sockets), NUM_SOCKETS) + + # Find the cookies for all of our sockets. + cookies = {} + for diag_msg, attrs in sockets: + addr = self.sock_diag.GetSourceAddress(diag_msg) + sport = diag_msg.id.sport + dport = diag_msg.id.dport + if (addr, sport, dport) in self.socketpairs: + cookies[(addr, sport, dport)] = diag_msg.id.cookie + elif (addr, dport, sport) in self.socketpairs: + cookies[(addr, sport, dport)] = diag_msg.id.cookie + + # Did we find all the cookies? + self.assertEquals(2 * NUM_SOCKETS, len(cookies)) + + socketpairs = self.socketpairs.values() + random.shuffle(socketpairs) + for socketpair in socketpairs: + for sock in socketpair: + self.assertSockDiagMatchesSocket( + sock, + self.sock_diag.GetSockDiagForFd(sock)) + + +if __name__ == "__main__": + unittest.main() |