summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLorenzo Colitti <lorenzo@google.com>2015-10-29 23:51:25 +0900
committerLorenzo Colitti <lorenzo@google.com>2015-12-15 02:16:29 +0900
commit8a50c9c793b4764f214ccccc61a9433a38b2fe68 (patch)
tree4a0e6566e59d3b12d7cafbc9bd21169f2485a242
parent8056a9bdfcaf8b8776eccbae3ec29560679e867e (diff)
downloadextras-8a50c9c793b4764f214ccccc61a9433a38b2fe68.tar.gz
Add support for the sock_diag netlink interface.
Change-Id: Id5b1b3516d0a708bcfd69ae0e182dc39fe225934
-rw-r--r--tests/net_test/netlink.py7
-rwxr-xr-xtests/net_test/sock_diag.py226
-rwxr-xr-xtests/net_test/sock_diag_test.py112
3 files changed, 342 insertions, 3 deletions
diff --git a/tests/net_test/netlink.py b/tests/net_test/netlink.py
index 6b2c60d1..514ad08b 100644
--- a/tests/net_test/netlink.py
+++ b/tests/net_test/netlink.py
@@ -121,9 +121,10 @@ class NetlinkSocket(object):
# If it's an attribute we know about, try to decode it.
nla_name, nla_data = self._Decode(command, msg, nla.nla_type, nla_data)
- # We only support unique attributes for now.
- if nla_name in attributes:
- raise ValueError("Duplicate attribute %d" % nla_name)
+ # We only support unique attributes for now, except for INET_DIAG_NONE,
+ # which can appear more than once but doesn't seem to contain any data.
+ if nla_name in attributes and nla_name != "INET_DIAG_NONE":
+ raise ValueError("Duplicate attribute %s" % nla_name)
attributes[nla_name] = nla_data
self._Debug(" %s" % str((nla_name, nla_data)))
diff --git a/tests/net_test/sock_diag.py b/tests/net_test/sock_diag.py
new file mode 100755
index 00000000..8c70eb33
--- /dev/null
+++ b/tests/net_test/sock_diag.py
@@ -0,0 +1,226 @@
+#!/usr/bin/python
+#
+# Copyright 2015 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Partial Python implementation of sock_diag functionality."""
+
+# pylint: disable=g-bad-todo
+
+import errno
+from socket import * # pylint: disable=wildcard-import
+
+import cstruct
+import net_test
+import netlink
+
+### Base netlink constants. See include/uapi/linux/netlink.h.
+NETLINK_SOCK_DIAG = 4
+
+### sock_diag constants. See include/uapi/linux/sock_diag.h.
+# Message types.
+SOCK_DIAG_BY_FAMILY = 20
+SOCK_DESTROY = 21
+
+### inet_diag_constants. See include/uapi/linux/inet_diag.h
+# Message types.
+TCPDIAG_GETSOCK = 18
+
+# Extensions.
+INET_DIAG_NONE = 0
+INET_DIAG_MEMINFO = 1
+INET_DIAG_INFO = 2
+INET_DIAG_VEGASINFO = 3
+INET_DIAG_CONG = 4
+INET_DIAG_TOS = 5
+INET_DIAG_TCLASS = 6
+INET_DIAG_SKMEMINFO = 7
+INET_DIAG_SHUTDOWN = 8
+INET_DIAG_DCTCPINFO = 9
+
+# Data structure formats.
+# These aren't constants, they're classes. So, pylint: disable=invalid-name
+InetDiagSockId = cstruct.Struct(
+ "InetDiagSockId", "!HH16s16sI8s", "sport dport src dst iface cookie")
+InetDiagReqV2 = cstruct.Struct(
+ "InetDiagReqV2", "=BBBxIS", "family protocol ext states id",
+ [InetDiagSockId])
+InetDiagMsg = cstruct.Struct(
+ "InetDiagMsg", "=BBBBSLLLLL",
+ "family state timer retrans id expires rqueue wqueue uid inode",
+ [InetDiagSockId])
+InetDiagMeminfo = cstruct.Struct(
+ "InetDiagMeminfo", "=IIII", "rmem wmem fmem tmem")
+
+SkMeminfo = cstruct.Struct(
+ "SkMeminfo", "=IIIIIIII",
+ "rmem_alloc rcvbuf wmem_alloc sndbuf fwd_alloc wmem_queued optmem backlog")
+TcpInfo = cstruct.Struct(
+ "TcpInfo", "=BBBBBBBxIIIIIIIIIIIIIIIIIIIIIIII",
+ "state ca_state retransmits probes backoff options wscale "
+ "rto ato snd_mss rcv_mss "
+ "unacked sacked lost retrans fackets "
+ "last_data_sent last_ack_sent last_data_recv last_ack_recv "
+ "pmtu rcv_ssthresh rtt rttvar snd_ssthresh snd_cwnd advmss reordering "
+ "rcv_rtt rcv_space "
+ "total_retrans") # As of linux 3.13, at least.
+
+# TCP states. See include/net/tcp_states.h.
+TCP_ESTABLISHED = 1
+TCP_SYN_SENT = 2
+TCP_SYN_RECV = 3
+TCP_FIN_WAIT1 = 4
+TCP_FIN_WAIT2 = 5
+TCP_TIME_WAIT = 6
+TCP_CLOSE = 7
+TCP_CLOSE_WAIT = 8
+TCP_LAST_ACK = 9
+TCP_LISTEN = 10
+TCP_CLOSING = 11
+TCP_NEW_SYN_RECV = 12
+
+
+class SockDiag(netlink.NetlinkSocket):
+
+ FAMILY = NETLINK_SOCK_DIAG
+ NL_DEBUG = []
+
+ def _Decode(self, command, msg, nla_type, nla_data):
+ """Decodes netlink attributes to Python types."""
+ if msg.family == AF_INET or msg.family == AF_INET6:
+ name = self._GetConstantName(__name__, nla_type, "INET_DIAG")
+ else:
+ # Don't know what this is. Leave it as an integer.
+ name = nla_type
+
+ if name in ["INET_DIAG_SHUTDOWN", "INET_DIAG_TOS", "INET_DIAG_TCLASS"]:
+ data = ord(nla_data)
+ elif name == "INET_DIAG_CONG":
+ data = nla_data.strip("\x00")
+ elif name == "INET_DIAG_MEMINFO":
+ data = InetDiagMeminfo(nla_data)
+ elif name == "INET_DIAG_INFO":
+ # TODO: Catch the exception and try something else if it's not TCP.
+ data = TcpInfo(nla_data)
+ elif name == "INET_DIAG_SKMEMINFO":
+ data = SkMeminfo(nla_data)
+ else:
+ data = nla_data
+
+ return name, data
+
+ def MaybeDebugCommand(self, command, data):
+ name = self._GetConstantName(__name__, command, "SOCK_")
+ if "ALL" not in self.NL_DEBUG and "SOCK" not in self.NL_DEBUG:
+ return
+ parsed = self._ParseNLMsg(data, InetDiagReqV2)
+ print "%s %s" % (name, str(parsed))
+
+ @staticmethod
+ def _EmptyInetDiagSockId():
+ return InetDiagSockId(("\x00" * len(InetDiagSockId)))
+
+ def DumpSockets(self, family, protocol, ext, states, sock_id):
+ """Dumps sockets matching the specified parameters."""
+ if sock_id is None:
+ sock_id = self._EmptyInetDiagSockId()
+
+ diag_req = InetDiagReqV2((family, protocol, ext, states, sock_id))
+ return self._Dump(SOCK_DIAG_BY_FAMILY, diag_req, InetDiagMsg)
+
+ def DumpAllInetSockets(self, protocol, sock_id=None, ext=0, states=0xffffffff):
+ # DumpSockets(AF_UNSPEC) does not result in dumping all inet sockets, it
+ # results in ENOENT.
+ sockets = []
+ for family in [AF_INET, AF_INET6]:
+ sockets += self.DumpSockets(family, protocol, ext, states, None)
+ return sockets
+
+ @staticmethod
+ def GetRawAddress(family, addr):
+ """Fetches the source address from an InetDiagMsg."""
+ addrlen = {AF_INET:4, AF_INET6: 16}[family]
+ return inet_ntop(family, addr[:addrlen])
+
+ @staticmethod
+ def GetSourceAddress(diag_msg):
+ """Fetches the source address from an InetDiagMsg."""
+ return SockDiag.GetRawAddress(diag_msg.family, diag_msg.id.src)
+
+ @staticmethod
+ def GetDestinationAddress(diag_msg):
+ """Fetches the source address from an InetDiagMsg."""
+ return SockDiag.GetRawAddress(diag_msg.family, diag_msg.id.dst)
+
+ @staticmethod
+ def RawAddress(addr):
+ """Converts an IP address string to binary format."""
+ family = AF_INET6 if ":" in addr else AF_INET
+ return inet_pton(family, addr)
+
+ @staticmethod
+ def PaddedAddress(addr):
+ """Converts an IP address string to binary format for InetDiagSockId."""
+ padded = SockDiag.RawAddress(addr)
+ if len(padded) < 16:
+ padded += "\x00" * (16 - len(padded))
+ return padded
+
+ @staticmethod
+ def DiagReqFromSocket(s):
+ """Creates an InetDiagReqV2 that matches the specified socket."""
+ family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
+ protocol = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_PROTOCOL)
+ iface = s.getsockopt(SOL_SOCKET, net_test.SO_BINDTODEVICE)
+ src, sport = s.getsockname()[:2]
+ try:
+ dst, dport = s.getpeername()[:2]
+ except error, e:
+ if e.errno == errno.ENOTCONN:
+ dport = 0
+ dst = "::" if family == AF_INET6 else "0.0.0.0"
+ else:
+ raise e
+ src = SockDiag.PaddedAddress(src)
+ dst = SockDiag.PaddedAddress(dst)
+ sock_id = InetDiagSockId((sport, dport, src, dst, iface, "\x00" * 8))
+ return InetDiagReqV2((family, protocol, 0, 0xffffffff, sock_id))
+
+ def GetSockDiagForFd(self, s):
+ """Gets an InetDiagMsg from the kernel for the specified socket."""
+ req = self.DiagReqFromSocket(s)
+ for diag_msg, attrs in self._Dump(SOCK_DIAG_BY_FAMILY, req, InetDiagMsg):
+ return diag_msg
+ raise ValueError("Dump of %s returned no sockets" % req)
+
+ def GetSockDiag(self, family, protocol, sock_id, ext=0, states=0xffffffff):
+ """Gets an InetDiagMsg from the kernel for the specified parameters."""
+ req = InetDiagReqV2((family, protocol, ext, states, sock_id))
+ self._SendNlRequest(SOCK_DIAG_BY_FAMILY, req.Pack(), netlink.NLM_F_REQUEST)
+ data = self._Recv()
+ return self._ParseNLMsg(data, InetDiagMsg)[0]
+
+
+if __name__ == "__main__":
+ n = SockDiag()
+ n.DEBUG = True
+ sock_id = n._EmptyInetDiagSockId()
+ sock_id.dport = 443
+ family = AF_INET6
+ protocol = IPPROTO_TCP
+ ext = 0
+ states = 0xffffffff
+ ext = 1 << (INET_DIAG_TOS - 1) | 1 << (INET_DIAG_TCLASS - 1)
+ diag_msgs = n.DumpSockets(family, protocol, ext, states, sock_id)
+ print diag_msgs
diff --git a/tests/net_test/sock_diag_test.py b/tests/net_test/sock_diag_test.py
new file mode 100755
index 00000000..2803bd20
--- /dev/null
+++ b/tests/net_test/sock_diag_test.py
@@ -0,0 +1,112 @@
+#!/usr/bin/python
+#
+# Copyright 2015 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import errno
+import random
+from socket import *
+import time
+import unittest
+
+import csocket
+import cstruct
+import multinetwork_base
+import net_test
+import packets
+import sock_diag
+
+
+NUM_SOCKETS = 100
+
+ALL_NON_TIME_WAIT = 0xffffffff & ~(1 << sock_diag.TCP_TIME_WAIT)
+
+
+class SockDiagTest(multinetwork_base.MultiNetworkBaseTest):
+
+ @staticmethod
+ def _CreateLotsOfSockets():
+ # Dict mapping (addr, sport, dport) tuples to socketpairs.
+ socketpairs = {}
+ for i in xrange(NUM_SOCKETS):
+ family, addr = random.choice([(AF_INET, "127.0.0.1"), (AF_INET6, "::1")])
+ socketpair = net_test.CreateSocketPair(family, SOCK_STREAM, addr)
+ sport, dport = (socketpair[0].getsockname()[1],
+ socketpair[1].getsockname()[1])
+ socketpairs[(addr, sport, dport)] = socketpair
+ return socketpairs
+
+ def setUp(self):
+ self.sock_diag = sock_diag.SockDiag()
+ self.socketpairs = self._CreateLotsOfSockets()
+
+ def tearDown(self):
+ [s.close() for socketpair in self.socketpairs.values() for s in socketpair]
+
+ def assertSockDiagMatchesSocket(self, s, diag_msg):
+ family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
+ self.assertEqual(diag_msg.family, family)
+
+ # TODO: The kernel (at least 3.10) seems only to fill in the first 4 bytes
+ # of src and dst in the case of IPv4 addresses. This means we can't just do
+ # something like:
+ # self.assertEqual(diag_msg.id.src, self.sock_diag.PaddedAddress(src))
+ # because the trailing bytes might not match.
+ # This seems like a bug because it might leaks kernel memory contents, but
+ # regardless, work around that here.
+ addrlen = {AF_INET: 4, AF_INET6: 16}[family]
+
+ src, sport = s.getsockname()[0:2]
+ self.assertEqual(diag_msg.id.sport, sport)
+ self.assertEqual(diag_msg.id.src[:addrlen],
+ self.sock_diag.RawAddress(src))
+
+ if self.sock_diag.GetDestinationAddress(diag_msg) not in ["0.0.0.0", "::"]:
+ dst, dport = s.getpeername()[0:2]
+ self.assertEqual(diag_msg.id.dst[:addrlen],
+ self.sock_diag.RawAddress(dst))
+ self.assertEqual(diag_msg.id.dport, dport)
+ else:
+ assertRaisesErrno(errno.ENOTCONN, s.getpeername)
+
+ def testFindsAllMySockets(self):
+ sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
+ states=ALL_NON_TIME_WAIT)
+ self.assertGreaterEqual(len(sockets), NUM_SOCKETS)
+
+ # Find the cookies for all of our sockets.
+ cookies = {}
+ for diag_msg, attrs in sockets:
+ addr = self.sock_diag.GetSourceAddress(diag_msg)
+ sport = diag_msg.id.sport
+ dport = diag_msg.id.dport
+ if (addr, sport, dport) in self.socketpairs:
+ cookies[(addr, sport, dport)] = diag_msg.id.cookie
+ elif (addr, dport, sport) in self.socketpairs:
+ cookies[(addr, sport, dport)] = diag_msg.id.cookie
+
+ # Did we find all the cookies?
+ self.assertEquals(2 * NUM_SOCKETS, len(cookies))
+
+ socketpairs = self.socketpairs.values()
+ random.shuffle(socketpairs)
+ for socketpair in socketpairs:
+ for sock in socketpair:
+ self.assertSockDiagMatchesSocket(
+ sock,
+ self.sock_diag.GetSockDiagForFd(sock))
+
+
+if __name__ == "__main__":
+ unittest.main()