summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLorenzo Colitti <lorenzo@google.com>2016-01-06 08:55:54 +0000
committerGerrit Code Review <noreply-gerritcodereview@google.com>2016-01-06 08:55:54 +0000
commit7e9e7f3b8099345b6982e39cea09cb3d438dc1c8 (patch)
treeece874150ebf9733b7299c788eaca25e713174a0
parentc526937fc3ba36afefaa42fbf795591968a26f59 (diff)
parent02c264a0e449bec97c3ad4c7dec31d86dc26d1f6 (diff)
downloadextras-7e9e7f3b8099345b6982e39cea09cb3d438dc1c8.tar.gz
Merge changes I11334802,I769518d1
* changes: More SOCK_DESTROY test work Add code and tests to close sockets via SOCK_DESTROY.
-rw-r--r--tests/net_test/packets.py5
-rwxr-xr-xtests/net_test/run_net_test.sh1
-rwxr-xr-xtests/net_test/sock_diag.py10
-rwxr-xr-xtests/net_test/sock_diag_test.py335
4 files changed, 346 insertions, 5 deletions
diff --git a/tests/net_test/packets.py b/tests/net_test/packets.py
index d92a97e4..c02adc0a 100644
--- a/tests/net_test/packets.py
+++ b/tests/net_test/packets.py
@@ -120,11 +120,12 @@ def ACK(version, srcaddr, dstaddr, packet, payload=""):
def FIN(version, srcaddr, dstaddr, packet):
ip = _GetIpLayer(version)
original = packet.getlayer("TCP")
- was_fin = (original.flags & TCP_FIN) != 0
+ was_syn_or_fin = (original.flags & (TCP_SYN | TCP_FIN)) != 0
+ ack_delta = was_syn_or_fin + len(original.payload)
return ("TCP FIN",
ip(src=srcaddr, dst=dstaddr) /
scapy.TCP(sport=original.dport, dport=original.sport,
- ack=original.seq + was_fin, seq=original.ack,
+ ack=original.seq + ack_delta, seq=original.ack,
flags=TCP_ACK | TCP_FIN, window=TCP_WINDOW))
def GRE(version, srcaddr, dstaddr, proto, packet):
diff --git a/tests/net_test/run_net_test.sh b/tests/net_test/run_net_test.sh
index d745ec31..080aac73 100755
--- a/tests/net_test/run_net_test.sh
+++ b/tests/net_test/run_net_test.sh
@@ -12,6 +12,7 @@ OPTIONS="$OPTIONS IPV6_PRIVACY IPV6_OPTIMISTIC_DAD"
OPTIONS="$OPTIONS CONFIG_NETFILTER_XT_TARGET_NFLOG"
OPTIONS="$OPTIONS CONFIG_NETFILTER_XT_MATCH_QUOTA CONFIG_NETFILTER_XT_MATCH_QUOTA2"
OPTIONS="$OPTIONS CONFIG_NETFILTER_XT_MATCH_QUOTA2_LOG"
+OPTIONS="$OPTIONS CONFIG_INET_UDP_DIAG CONFIG_INET_DIAG_DESTROY"
# For 3.1 kernels, where devtmpfs is not on by default.
OPTIONS="$OPTIONS DEVTMPFS DEVTMPFS_MOUNT"
diff --git a/tests/net_test/sock_diag.py b/tests/net_test/sock_diag.py
index 69785aa6..a9de3458 100755
--- a/tests/net_test/sock_diag.py
+++ b/tests/net_test/sock_diag.py
@@ -235,6 +235,16 @@ class SockDiag(netlink.NetlinkSocket):
"""Constructs a diag_req from a diag_msg the kernel has given us."""
return InetDiagReqV2((d.family, protocol, 0, 1 << d.state, d.id))
+ def CloseSocket(self, req):
+ self._SendNlRequest(SOCK_DESTROY, req.Pack(),
+ netlink.NLM_F_REQUEST | netlink.NLM_F_ACK)
+
+ def CloseSocketFromFd(self, s):
+ diag_msg = self.FindSockDiagFromFd(s)
+ protocol = s.getsockopt(SOL_SOCKET, net_test.SO_PROTOCOL)
+ req = self.DiagReqFromDiagMsg(diag_msg, protocol)
+ return self.CloseSocket(req)
+
if __name__ == "__main__":
n = SockDiag()
diff --git a/tests/net_test/sock_diag_test.py b/tests/net_test/sock_diag_test.py
index 7eff7e40..59759315 100755
--- a/tests/net_test/sock_diag_test.py
+++ b/tests/net_test/sock_diag_test.py
@@ -14,7 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import errno
+from errno import *
+import os
import random
from socket import *
import time
@@ -26,12 +27,15 @@ import multinetwork_base
import net_test
import packets
import sock_diag
+import threading
NUM_SOCKETS = 100
ALL_NON_TIME_WAIT = 0xffffffff & ~(1 << sock_diag.TCP_TIME_WAIT)
+# TODO: Backport SOCK_DESTROY and delete this.
+HAVE_SOCK_DESTROY = net_test.LINUX_VERSION >= (4, 4)
class SockDiagTest(multinetwork_base.MultiNetworkBaseTest):
@@ -48,11 +52,13 @@ class SockDiagTest(multinetwork_base.MultiNetworkBaseTest):
return socketpairs
def setUp(self):
+ super(SockDiagTest, self).setUp()
self.sock_diag = sock_diag.SockDiag()
- self.socketpairs = self._CreateLotsOfSockets()
+ self.socketpairs = {}
def tearDown(self):
[s.close() for socketpair in self.socketpairs.values() for s in socketpair]
+ super(SockDiagTest, self).tearDown()
def testFixupDiagMsg(self):
src = "0a00fa02303030312030312038302031"
@@ -78,6 +84,16 @@ class SockDiagTest(multinetwork_base.MultiNetworkBaseTest):
msg4.id.dst = dst.decode("hex")[:4] + 12 * "\x00"
self.assertEquals(msg4.Pack(), fixed4.Pack())
+ def assertSocketClosed(self, sock):
+ self.assertRaisesErrno(ENOTCONN, sock.getpeername)
+
+ def assertSocketConnected(self, sock):
+ sock.getpeername() # No errors? Socket is alive and connected.
+
+ def assertSocketsClosed(self, socketpair):
+ for sock in socketpair:
+ self.assertSocketClosed(sock)
+
def assertSockDiagMatchesSocket(self, s, diag_msg):
family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
self.assertEqual(diag_msg.family, family)
@@ -93,9 +109,10 @@ class SockDiagTest(multinetwork_base.MultiNetworkBaseTest):
self.assertEqual(diag_msg.id.dst, self.sock_diag.PaddedAddress(dst))
self.assertEqual(diag_msg.id.dport, dport)
else:
- assertRaisesErrno(errno.ENOTCONN, s.getpeername)
+ assertRaisesErrno(ENOTCONN, s.getpeername)
def testFindsAllMySockets(self):
+ self.socketpairs = self._CreateLotsOfSockets()
sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
states=ALL_NON_TIME_WAIT)
self.assertGreaterEqual(len(sockets), NUM_SOCKETS)
@@ -131,6 +148,318 @@ class SockDiagTest(multinetwork_base.MultiNetworkBaseTest):
diag_msg = self.sock_diag.GetSockDiag(req)
self.assertSockDiagMatchesSocket(sock, diag_msg)
+ @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
+ def testClosesSockets(self):
+ self.socketpairs = self._CreateLotsOfSockets()
+ for (addr, _, _), socketpair in self.socketpairs.iteritems():
+ # Close one of the sockets.
+ # This will send a RST that will close the other side as well.
+ s = random.choice(socketpair)
+ if random.randrange(0, 2) == 1:
+ self.sock_diag.CloseSocketFromFd(s)
+ else:
+ diag_msg = self.sock_diag.FindSockDiagFromFd(s)
+ family = AF_INET6 if ":" in addr else AF_INET
+
+ # Get the cookie wrong and ensure that we get an error and the socket
+ # is not closed.
+ real_cookie = diag_msg.id.cookie
+ diag_msg.id.cookie = os.urandom(len(real_cookie))
+ req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
+ self.assertRaisesErrno(ENOENT, self.sock_diag.CloseSocket, req)
+ self.assertSocketConnected(s)
+
+ # Now close it with the correct cookie.
+ req.id.cookie = real_cookie
+ self.sock_diag.CloseSocket(req)
+
+ # Check that both sockets in the pair are closed.
+ self.assertSocketsClosed(socketpair)
+
+ @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
+ def testNonTcpSockets(self):
+ s = socket(AF_INET6, SOCK_DGRAM, 0)
+ s.connect(("::1", 53))
+ diag_msg = self.sock_diag.FindSockDiagFromFd(s)
+ self.assertRaisesErrno(EOPNOTSUPP, self.sock_diag.CloseSocketFromFd, s)
+
+ def testNonSockDiagCommand(self):
+ def DiagDump(code):
+ sock_id = self.sock_diag._EmptyInetDiagSockId()
+ req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, 0xffffffff,
+ sock_id))
+ self.sock_diag._Dump(code, req, sock_diag.InetDiagMsg)
+
+ op = sock_diag.SOCK_DIAG_BY_FAMILY
+ DiagDump(op) # No errors? Good.
+ self.assertRaisesErrno(EINVAL, DiagDump, op + 17)
+
+ # TODO:
+ # Test that killing unix sockets returns EOPNOTSUPP.
+
+
+class SocketExceptionThread(threading.Thread):
+
+ def __init__(self, sock, operation):
+ self.exception = None
+ super(SocketExceptionThread, self).__init__()
+ self.daemon = True
+ self.sock = sock
+ self.operation = operation
+
+ def run(self):
+ try:
+ self.operation(self.sock)
+ except Exception, e:
+ self.exception = e
+
+
+# TODO: Take a tun fd as input, make this a utility class, and reuse at least
+# in forwarding_test.
+class TcpTest(SockDiagTest):
+
+ NOT_YET_ACCEPTED = -1
+
+ def setUp(self):
+ super(TcpTest, self).setUp()
+ self.sock_diag = sock_diag.SockDiag()
+ self.netid = random.choice(self.tuns.keys())
+
+ def OpenListenSocket(self, version):
+ self.port = packets.RandomPort()
+ family = {4: AF_INET, 6: AF_INET6}[version]
+ address = {4: "0.0.0.0", 6: "::"}[version]
+ s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
+ s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
+ s.bind((address, self.port))
+ # We haven't configured inbound iptables marking, so bind explicitly.
+ self.SelectInterface(s, self.netid, "mark")
+ s.listen(100)
+ return s
+
+ def _ReceiveAndExpectResponse(self, netid, packet, reply, msg):
+ pkt = super(TcpTest, self)._ReceiveAndExpectResponse(netid, packet,
+ reply, msg)
+ self.last_packet = pkt
+ return pkt
+
+ def ReceivePacketOn(self, netid, packet):
+ super(TcpTest, self).ReceivePacketOn(netid, packet)
+ self.last_packet = packet
+
+ def RstPacket(self):
+ return packets.RST(self.version, self.myaddr, self.remoteaddr,
+ self.last_packet)
+
+ def IncomingConnection(self, version, end_state, netid):
+ self.version = version
+ self.s = self.OpenListenSocket(version)
+ self.end_state = end_state
+
+ remoteaddr = self.remoteaddr = self.GetRemoteAddress(version)
+ myaddr = self.myaddr = self.MyAddress(version, netid)
+
+ if end_state == sock_diag.TCP_LISTEN:
+ return
+
+ desc, syn = packets.SYN(self.port, version, remoteaddr, myaddr)
+ synack_desc, synack = packets.SYNACK(version, myaddr, remoteaddr, syn)
+ msg = "Received %s, expected to see reply %s" % (desc, synack_desc)
+ reply = self._ReceiveAndExpectResponse(netid, syn, synack, msg)
+ if end_state == sock_diag.TCP_SYN_RECV:
+ return
+
+ establishing_ack = packets.ACK(version, remoteaddr, myaddr, reply)[1]
+ self.ReceivePacketOn(netid, establishing_ack)
+
+ if end_state == self.NOT_YET_ACCEPTED:
+ return
+
+ self.accepted, _ = self.s.accept()
+ if end_state == sock_diag.TCP_ESTABLISHED:
+ return
+
+ desc, data = packets.ACK(version, myaddr, remoteaddr, establishing_ack,
+ payload=net_test.UDP_PAYLOAD)
+ self.accepted.send(net_test.UDP_PAYLOAD)
+ self.ExpectPacketOn(netid, msg + ": expecting %s" % desc, data)
+
+ desc, fin = packets.FIN(version, remoteaddr, myaddr, data)
+ fin = packets._GetIpLayer(version)(str(fin))
+ ack_desc, ack = packets.ACK(version, myaddr, remoteaddr, fin)
+ msg = "Received %s, expected to see reply %s" % (desc, ack_desc)
+
+ # TODO: Why can't we use this?
+ # self._ReceiveAndExpectResponse(netid, fin, ack, msg)
+ self.ReceivePacketOn(netid, fin)
+ time.sleep(0.1)
+ self.ExpectPacketOn(netid, msg + ": expecting %s" % ack_desc, ack)
+ if end_state == sock_diag.TCP_CLOSE_WAIT:
+ return
+
+ raise ValueError("Invalid TCP state %d specified" % end_state)
+
+ def CheckRstOnClose(self, sock, req, expect_reset, msg, do_close=True):
+ """Closes the socket and checks whether a RST is sent or not."""
+ if sock is not None:
+ self.assertIsNone(req, "Must specify sock or req, not both")
+ self.sock_diag.CloseSocketFromFd(sock)
+ self.assertRaisesErrno(EINVAL, sock.accept)
+ else:
+ self.assertIsNone(sock, "Must specify sock or req, not both")
+ self.sock_diag.CloseSocket(req)
+
+ if expect_reset:
+ desc, rst = self.RstPacket()
+ msg = "%s: expecting %s: " % (msg, desc)
+ self.ExpectPacketOn(self.netid, msg, rst)
+ else:
+ msg = "%s: " % msg
+ self.ExpectNoPacketsOn(self.netid, msg)
+
+ if sock is not None and do_close:
+ sock.close()
+
+ def CheckTcpReset(self, state, statename):
+ for version in [4, 6]:
+ msg = "Closing incoming IPv%d %s socket" % (version, statename)
+ self.IncomingConnection(version, state, self.netid)
+ self.CheckRstOnClose(self.s, None, False, msg)
+ if state != sock_diag.TCP_LISTEN:
+ msg = "Closing accepted IPv%d %s socket" % (version, statename)
+ self.CheckRstOnClose(self.accepted, None, True, msg)
+
+ @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
+ def testTcpResets(self):
+ """Checks that closing sockets in appropriate states sends a RST."""
+ self.CheckTcpReset(sock_diag.TCP_LISTEN, "TCP_LISTEN")
+ self.CheckTcpReset(sock_diag.TCP_ESTABLISHED, "TCP_ESTABLISHED")
+ self.CheckTcpReset(sock_diag.TCP_CLOSE_WAIT, "TCP_CLOSE_WAIT")
+
+ def FindChildSockets(self, s):
+ """Finds the SYN_RECV child sockets of a given listening socket."""
+ d = self.sock_diag.FindSockDiagFromFd(self.s)
+ req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
+ req.states = 1 << sock_diag.TCP_SYN_RECV | 1 << sock_diag.TCP_ESTABLISHED
+ req.id.cookie = "\x00" * 8
+ children = self.sock_diag.Dump(req)
+ return [self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
+ for d, _ in children]
+
+ def CheckChildSocket(self, state, statename, parent_first):
+ for version in [4, 6]:
+ self.IncomingConnection(version, state, self.netid)
+
+ d = self.sock_diag.FindSockDiagFromFd(self.s)
+ parent = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
+ children = self.FindChildSockets(self.s)
+ self.assertEquals(1, len(children))
+
+ is_established = (state == self.NOT_YET_ACCEPTED)
+
+ # The new TCP listener code in 4.4 makes SYN_RECV sockets live in the
+ # regular TCP hash tables, and inet_diag_find_one_icsk can find them.
+ # Before 4.4, we can see those sockets in dumps, but we can't fetch
+ # or close them.
+ can_close_children = is_established or net_test.LINUX_VERSION >= (4, 4)
+
+ for child in children:
+ if can_close_children:
+ self.sock_diag.GetSockDiag(child) # No errors? Good, child found.
+ else:
+ self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child)
+
+ def CloseParent(expect_reset):
+ msg = "Closing parent IPv%d %s socket %s child" % (
+ version, statename, "before" if parent_first else "after")
+ self.CheckRstOnClose(self.s, None, expect_reset, msg)
+ self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, parent)
+
+ def CheckChildrenClosed():
+ for child in children:
+ self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child)
+
+ def CloseChildren():
+ for child in children:
+ msg = "Closing child IPv%d %s socket %s parent" % (
+ version, statename, "after" if parent_first else "before")
+ self.sock_diag.GetSockDiag(child)
+ self.CheckRstOnClose(None, child, is_established, msg)
+ self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child)
+ CheckChildrenClosed()
+
+ if parent_first:
+ # Closing the parent will close child sockets, which will send a RST,
+ # iff they are already established.
+ CloseParent(is_established)
+ if is_established:
+ CheckChildrenClosed()
+ elif can_close_children:
+ CloseChildren()
+ CheckChildrenClosed()
+ self.s.close()
+ else:
+ if can_close_children:
+ CloseChildren()
+ CloseParent(False)
+ self.s.close()
+
+ @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
+ def testChildSockets(self):
+ self.CheckChildSocket(sock_diag.TCP_SYN_RECV, "TCP_SYN_RECV", False)
+ self.CheckChildSocket(sock_diag.TCP_SYN_RECV, "TCP_SYN_RECV", True)
+ self.CheckChildSocket(self.NOT_YET_ACCEPTED, "not yet accepted", False)
+ self.CheckChildSocket(self.NOT_YET_ACCEPTED, "not yet accepted", True)
+
+ def CloseDuringBlockingCall(self, sock, call, expected_errno):
+ thread = SocketExceptionThread(sock, call)
+ thread.start()
+ time.sleep(0.1)
+ self.sock_diag.CloseSocketFromFd(sock)
+ thread.join(1)
+ self.assertFalse(thread.is_alive())
+ self.assertIsNotNone(thread.exception)
+ self.assertTrue(isinstance(thread.exception, IOError),
+ "Expected IOError, got %s" % thread.exception)
+ self.assertEqual(expected_errno, thread.exception.errno)
+ self.assertSocketClosed(sock)
+
+ @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
+ def testAcceptInterrupted(self):
+ """Tests that accept() is interrupted by SOCK_DESTROY."""
+ for version in [4, 6]:
+ self.IncomingConnection(version, sock_diag.TCP_LISTEN, self.netid)
+ self.CloseDuringBlockingCall(self.s, lambda sock: sock.accept(), EINVAL)
+ self.assertRaisesErrno(ECONNABORTED, self.s.send, "foo")
+ self.assertRaisesErrno(EINVAL, self.s.accept)
+
+ @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
+ def testReadInterrupted(self):
+ """Tests that read() is interrupted by SOCK_DESTROY."""
+ for version in [4, 6]:
+ self.IncomingConnection(version, sock_diag.TCP_ESTABLISHED, self.netid)
+ self.CloseDuringBlockingCall(self.accepted, lambda sock: sock.recv(4096),
+ ECONNABORTED)
+ self.assertRaisesErrno(EPIPE, self.accepted.send, "foo")
+
+ @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
+ def testConnectInterrupted(self):
+ """Tests that connect() is interrupted by SOCK_DESTROY."""
+ for version in [4, 6]:
+ family = {4: AF_INET, 6: AF_INET6}[version]
+ s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
+ self.SelectInterface(s, self.netid, "mark")
+ remoteaddr = self.GetRemoteAddress(version)
+ s.bind(("", 0))
+ _, sport = s.getsockname()[:2]
+ self.CloseDuringBlockingCall(
+ s, lambda sock: sock.connect((remoteaddr, 53)), ECONNABORTED)
+ desc, syn = packets.SYN(53, version, self.MyAddress(version, self.netid),
+ remoteaddr, sport=sport, seq=None)
+ self.ExpectPacketOn(self.netid, desc, syn)
+ msg = "SOCK_DESTROY of socket in connect, expected no RST"
+ self.ExpectNoPacketsOn(self.netid, msg)
+
if __name__ == "__main__":
unittest.main()