diff options
author | Lorenzo Colitti <lorenzo@google.com> | 2015-11-06 16:42:06 +0900 |
---|---|---|
committer | Lorenzo Colitti <lorenzo@google.com> | 2016-01-06 17:49:52 +0900 |
commit | 3823b47a664a64024278c1fa0cbc1b632cc60634 (patch) | |
tree | c96a546bba321bebc7614c5048b6ed9af759f435 | |
parent | da3aa4c1e092b6cddc2e3b64ddde24777fedd931 (diff) | |
download | extras-3823b47a664a64024278c1fa0cbc1b632cc60634.tar.gz |
Add code and tests to close sockets via SOCK_DESTROY.
Change-Id: I769518d128fcff8035c58fbf3dc868f02fbd6c9d
-rw-r--r-- | tests/net_test/packets.py | 5 | ||||
-rwxr-xr-x | tests/net_test/run_net_test.sh | 1 | ||||
-rwxr-xr-x | tests/net_test/sock_diag.py | 10 | ||||
-rwxr-xr-x | tests/net_test/sock_diag_test.py | 288 |
4 files changed, 299 insertions, 5 deletions
diff --git a/tests/net_test/packets.py b/tests/net_test/packets.py index d92a97e4..c02adc0a 100644 --- a/tests/net_test/packets.py +++ b/tests/net_test/packets.py @@ -120,11 +120,12 @@ def ACK(version, srcaddr, dstaddr, packet, payload=""): def FIN(version, srcaddr, dstaddr, packet): ip = _GetIpLayer(version) original = packet.getlayer("TCP") - was_fin = (original.flags & TCP_FIN) != 0 + was_syn_or_fin = (original.flags & (TCP_SYN | TCP_FIN)) != 0 + ack_delta = was_syn_or_fin + len(original.payload) return ("TCP FIN", ip(src=srcaddr, dst=dstaddr) / scapy.TCP(sport=original.dport, dport=original.sport, - ack=original.seq + was_fin, seq=original.ack, + ack=original.seq + ack_delta, seq=original.ack, flags=TCP_ACK | TCP_FIN, window=TCP_WINDOW)) def GRE(version, srcaddr, dstaddr, proto, packet): diff --git a/tests/net_test/run_net_test.sh b/tests/net_test/run_net_test.sh index d745ec31..080aac73 100755 --- a/tests/net_test/run_net_test.sh +++ b/tests/net_test/run_net_test.sh @@ -12,6 +12,7 @@ OPTIONS="$OPTIONS IPV6_PRIVACY IPV6_OPTIMISTIC_DAD" OPTIONS="$OPTIONS CONFIG_NETFILTER_XT_TARGET_NFLOG" OPTIONS="$OPTIONS CONFIG_NETFILTER_XT_MATCH_QUOTA CONFIG_NETFILTER_XT_MATCH_QUOTA2" OPTIONS="$OPTIONS CONFIG_NETFILTER_XT_MATCH_QUOTA2_LOG" +OPTIONS="$OPTIONS CONFIG_INET_UDP_DIAG CONFIG_INET_DIAG_DESTROY" # For 3.1 kernels, where devtmpfs is not on by default. OPTIONS="$OPTIONS DEVTMPFS DEVTMPFS_MOUNT" diff --git a/tests/net_test/sock_diag.py b/tests/net_test/sock_diag.py index 69785aa6..a9de3458 100755 --- a/tests/net_test/sock_diag.py +++ b/tests/net_test/sock_diag.py @@ -235,6 +235,16 @@ class SockDiag(netlink.NetlinkSocket): """Constructs a diag_req from a diag_msg the kernel has given us.""" return InetDiagReqV2((d.family, protocol, 0, 1 << d.state, d.id)) + def CloseSocket(self, req): + self._SendNlRequest(SOCK_DESTROY, req.Pack(), + netlink.NLM_F_REQUEST | netlink.NLM_F_ACK) + + def CloseSocketFromFd(self, s): + diag_msg = self.FindSockDiagFromFd(s) + protocol = s.getsockopt(SOL_SOCKET, net_test.SO_PROTOCOL) + req = self.DiagReqFromDiagMsg(diag_msg, protocol) + return self.CloseSocket(req) + if __name__ == "__main__": n = SockDiag() diff --git a/tests/net_test/sock_diag_test.py b/tests/net_test/sock_diag_test.py index 7eff7e40..e122c9cc 100755 --- a/tests/net_test/sock_diag_test.py +++ b/tests/net_test/sock_diag_test.py @@ -14,7 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. -import errno +from errno import * +import os import random from socket import * import time @@ -26,12 +27,15 @@ import multinetwork_base import net_test import packets import sock_diag +import threading NUM_SOCKETS = 100 ALL_NON_TIME_WAIT = 0xffffffff & ~(1 << sock_diag.TCP_TIME_WAIT) +# TODO: Backport SOCK_DESTROY and delete this. +HAVE_SOCK_DESTROY = net_test.LINUX_VERSION >= (4, 4) class SockDiagTest(multinetwork_base.MultiNetworkBaseTest): @@ -48,11 +52,13 @@ class SockDiagTest(multinetwork_base.MultiNetworkBaseTest): return socketpairs def setUp(self): + super(SockDiagTest, self).setUp() self.sock_diag = sock_diag.SockDiag() - self.socketpairs = self._CreateLotsOfSockets() + self.socketpairs = {} def tearDown(self): [s.close() for socketpair in self.socketpairs.values() for s in socketpair] + super(SockDiagTest, self).tearDown() def testFixupDiagMsg(self): src = "0a00fa02303030312030312038302031" @@ -78,6 +84,16 @@ class SockDiagTest(multinetwork_base.MultiNetworkBaseTest): msg4.id.dst = dst.decode("hex")[:4] + 12 * "\x00" self.assertEquals(msg4.Pack(), fixed4.Pack()) + def assertSocketClosed(self, sock): + self.assertRaisesErrno(ENOTCONN, sock.getpeername) + + def assertSocketConnected(self, sock): + sock.getpeername() # No errors? Socket is alive and connected. + + def assertSocketsClosed(self, socketpair): + for sock in socketpair: + self.assertSocketClosed(sock) + def assertSockDiagMatchesSocket(self, s, diag_msg): family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN) self.assertEqual(diag_msg.family, family) @@ -93,9 +109,10 @@ class SockDiagTest(multinetwork_base.MultiNetworkBaseTest): self.assertEqual(diag_msg.id.dst, self.sock_diag.PaddedAddress(dst)) self.assertEqual(diag_msg.id.dport, dport) else: - assertRaisesErrno(errno.ENOTCONN, s.getpeername) + assertRaisesErrno(ENOTCONN, s.getpeername) def testFindsAllMySockets(self): + self.socketpairs = self._CreateLotsOfSockets() sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, states=ALL_NON_TIME_WAIT) self.assertGreaterEqual(len(sockets), NUM_SOCKETS) @@ -131,6 +148,271 @@ class SockDiagTest(multinetwork_base.MultiNetworkBaseTest): diag_msg = self.sock_diag.GetSockDiag(req) self.assertSockDiagMatchesSocket(sock, diag_msg) + @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported") + def testClosesSockets(self): + self.socketpairs = self._CreateLotsOfSockets() + for (addr, _, _), socketpair in self.socketpairs.iteritems(): + # Close one of the sockets. + # This will send a RST that will close the other side as well. + s = random.choice(socketpair) + if random.randrange(0, 2) == 1: + self.sock_diag.CloseSocketFromFd(s) + else: + diag_msg = self.sock_diag.FindSockDiagFromFd(s) + family = AF_INET6 if ":" in addr else AF_INET + + # Get the cookie wrong and ensure that we get an error and the socket + # is not closed. + real_cookie = diag_msg.id.cookie + diag_msg.id.cookie = os.urandom(len(real_cookie)) + req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP) + self.assertRaisesErrno(ENOENT, self.sock_diag.CloseSocket, req) + self.assertSocketConnected(s) + + # Now close it with the correct cookie. + req.id.cookie = real_cookie + self.sock_diag.CloseSocket(req) + + # Check that both sockets in the pair are closed. + self.assertSocketsClosed(socketpair) + + @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported") + def testNonTcpSockets(self): + s = socket(AF_INET6, SOCK_DGRAM, 0) + s.connect(("::1", 53)) + diag_msg = self.sock_diag.FindSockDiagFromFd(s) + self.assertRaisesErrno(EOPNOTSUPP, self.sock_diag.CloseSocketFromFd, s) + + def testNonSockDiagCommand(self): + def DiagDump(code): + sock_id = self.sock_diag._EmptyInetDiagSockId() + req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, 0xffffffff, + sock_id)) + self.sock_diag._Dump(code, req, sock_diag.InetDiagMsg) + + op = sock_diag.SOCK_DIAG_BY_FAMILY + DiagDump(op) # No errors? Good. + self.assertRaisesErrno(EINVAL, DiagDump, op + 17) + + # TODO: + # Test that killing unix sockets returns EOPNOTSUPP. + + +class SocketExceptionThread(threading.Thread): + + def __init__(self, sock, operation): + self.exception = None + super(SocketExceptionThread, self).__init__() + self.daemon = True + self.sock = sock + self.operation = operation + + def run(self): + try: + self.operation(self.sock) + except Exception, e: + self.exception = e + + +# TODO: Take a tun fd as input, make this a utility class, and reuse at least +# in forwarding_test. +class TcpTest(SockDiagTest): + + def setUp(self): + super(TcpTest, self).setUp() + self.sock_diag = sock_diag.SockDiag() + self.netid = random.choice(self.tuns.keys()) + + def OpenListenSocket(self, version): + self.port = packets.RandomPort() + family = {4: AF_INET, 6: AF_INET6}[version] + address = {4: "0.0.0.0", 6: "::"}[version] + s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP) + s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1) + s.bind((address, self.port)) + # We haven't configured inbound iptables marking, so bind explicitly. + self.SelectInterface(s, self.netid, "mark") + s.listen(100) + return s + + def _ReceiveAndExpectResponse(self, netid, packet, reply, msg): + pkt = super(TcpTest, self)._ReceiveAndExpectResponse(netid, packet, + reply, msg) + self.last_packet = pkt + return pkt + + def ReceivePacketOn(self, netid, packet): + super(TcpTest, self).ReceivePacketOn(netid, packet) + self.last_packet = packet + + def RstPacket(self): + return packets.RST(self.version, self.myaddr, self.remoteaddr, + self.last_packet) + + def IncomingConnection(self, version, end_state, netid): + self.version = version + self.s = self.OpenListenSocket(version) + self.end_state = end_state + + remoteaddr = self.remoteaddr = self.GetRemoteAddress(version) + myaddr = self.myaddr = self.MyAddress(version, netid) + + if end_state == sock_diag.TCP_LISTEN: + return + + desc, syn = packets.SYN(self.port, version, remoteaddr, myaddr) + synack_desc, synack = packets.SYNACK(version, myaddr, remoteaddr, syn) + msg = "Received %s, expected to see reply %s" % (desc, synack_desc) + reply = self._ReceiveAndExpectResponse(netid, syn, synack, msg) + if end_state == sock_diag.TCP_SYN_RECV: + return + + establishing_ack = packets.ACK(version, remoteaddr, myaddr, reply)[1] + self.ReceivePacketOn(netid, establishing_ack) + + self.accepted, _ = self.s.accept() + desc, data = packets.ACK(version, myaddr, remoteaddr, establishing_ack, + payload=net_test.UDP_PAYLOAD) + + if end_state == sock_diag.TCP_ESTABLISHED: + return + + self.accepted.send(net_test.UDP_PAYLOAD) + self.ExpectPacketOn(netid, msg + ": expecting %s" % desc, data) + + desc, fin = packets.FIN(version, remoteaddr, myaddr, data) + fin = packets._GetIpLayer(version)(str(fin)) + ack_desc, ack = packets.ACK(version, myaddr, remoteaddr, fin) + msg = "Received %s, expected to see reply %s" % (desc, ack_desc) + + # TODO: Why can't we use this? + # self._ReceiveAndExpectResponse(netid, fin, ack, msg) + self.ReceivePacketOn(netid, fin) + time.sleep(0.1) + self.ExpectPacketOn(netid, msg + ": expecting %s" % ack_desc, ack) + if end_state == sock_diag.TCP_CLOSE_WAIT: + return + + raise ValueError("Invalid TCP state %d specified" % end_state) + + def CheckRstOnClose(self, sock, req, expect_reset, msg, do_close=True): + """Closes the socket and checks whether a RST is sent or not.""" + if sock is not None: + self.assertIsNone(req, "Must specify sock or req, not both") + self.sock_diag.CloseSocketFromFd(sock) + self.assertRaisesErrno(EINVAL, sock.accept) + else: + self.assertIsNone(sock, "Must specify sock or req, not both") + self.sock_diag.CloseSocket(req) + + if expect_reset: + desc, rst = self.RstPacket() + msg = "%s: expecting %s: " % (msg, desc) + self.ExpectPacketOn(self.netid, msg, rst) + else: + msg = "%s: " % msg + self.ExpectNoPacketsOn(self.netid, msg) + + if sock is not None: + sock.close() + + def FindChildSockets(self, s): + """Finds the SYN_RECV child sockets of a given listening socket.""" + d = self.sock_diag.FindSockDiagFromFd(self.s) + req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP) + req.states = 1 << sock_diag.TCP_SYN_RECV + req.id.cookie = "\x00" * 8 + sockets = self.sock_diag.Dump(req) + sockets = [diag_msg for diag_msg, attrs in sockets] + return sockets + + @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported") + def testTcpResets(self): + """Checks that closing sockets in appropriate states sends a RST.""" + for version in [4, 6]: + msg = "Closing incoming IPv%d TCP_LISTEN socket" % version + self.IncomingConnection(version, sock_diag.TCP_LISTEN, self.netid) + self.CheckRstOnClose(self.s, None, False, msg) + + msg = "Closing incoming IPv%d TCP_SYN_RECV socket" % version + self.IncomingConnection(version, sock_diag.TCP_SYN_RECV, self.netid) + children = self.FindChildSockets(self.s) + self.assertEquals(1, len(children)) + for child in children: + req = self.sock_diag.DiagReqFromDiagMsg(child, IPPROTO_TCP) + if net_test.LINUX_VERSION >= (4, 4): + # The new TCP listener code in 4.4 makes request sockets live in the + # regular TCP hash tables, and inet_diag_find_one_icsk can find them. + self.sock_diag.GetSockDiag(req) # No errors? Good, child found. + self.CheckRstOnClose(None, req, False, msg + " child") + self.assertFalse(self.FindChildSockets(self.s)) + else: + # Before 4.4, we can't see or kill SYN_RECV sockets. + self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, req) + + self.CheckRstOnClose(self.s, None, False, msg) + + msg = "Closing incoming IPv%d TCP_ESTABLISHED socket" % version + self.IncomingConnection(version, sock_diag.TCP_ESTABLISHED, self.netid) + self.CheckRstOnClose(self.s, None, False, msg) + msg = "Closing accepted IPv%d TCP_ESTABLISHED socket" % version + self.CheckRstOnClose(self.accepted, None, True, msg) + + msg = "Closing incoming IPv%d TCP_CLOSE_WAIT socket" % version + self.IncomingConnection(version, sock_diag.TCP_CLOSE_WAIT, self.netid) + self.CheckRstOnClose(self.s, None, False, msg) + msg = "Closing accepted IPv%d TCP_ESTABLISHED socket" % version + self.CheckRstOnClose(self.accepted, None, True, msg) + + def CloseDuringBlockingCall(self, sock, call, expected_errno): + thread = SocketExceptionThread(sock, call) + thread.start() + time.sleep(0.1) + self.sock_diag.CloseSocketFromFd(sock) + thread.join(1) + self.assertFalse(thread.is_alive()) + self.assertIsNotNone(thread.exception) + self.assertTrue(isinstance(thread.exception, IOError), + "Expected IOError, got %s" % thread.exception) + self.assertEqual(expected_errno, thread.exception.errno) + self.assertSocketClosed(sock) + + @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported") + def testAcceptInterrupted(self): + """Tests that accept() is interrupted by SOCK_DESTROY.""" + for version in [4, 6]: + self.IncomingConnection(version, sock_diag.TCP_LISTEN, self.netid) + self.CloseDuringBlockingCall(self.s, lambda sock: sock.accept(), EINVAL) + self.assertRaisesErrno(ECONNABORTED, self.s.send, "foo") + self.assertRaisesErrno(EINVAL, self.s.accept) + + @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported") + def testReadInterrupted(self): + """Tests that read() is interrupted by SOCK_DESTROY.""" + for version in [4, 6]: + self.IncomingConnection(version, sock_diag.TCP_ESTABLISHED, self.netid) + self.CloseDuringBlockingCall(self.accepted, lambda sock: sock.recv(4096), + ECONNABORTED) + self.assertRaisesErrno(EPIPE, self.accepted.send, "foo") + + @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported") + def testConnectInterrupted(self): + """Tests that connect() is interrupted by SOCK_DESTROY.""" + for version in [4, 6]: + family = {4: AF_INET, 6: AF_INET6}[version] + s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP) + self.SelectInterface(s, self.netid, "mark") + remoteaddr = self.GetRemoteAddress(version) + s.bind(("", 0)) + _, sport = s.getsockname()[:2] + self.CloseDuringBlockingCall( + s, lambda sock: sock.connect((remoteaddr, 53)), ECONNABORTED) + desc, syn = packets.SYN(53, version, self.MyAddress(version, self.netid), + remoteaddr, sport=sport, seq=None) + self.ExpectPacketOn(self.netid, desc, syn) + msg = "SOCK_DESTROY of socket in connect, expected no RST" + self.ExpectNoPacketsOn(self.netid, msg) + if __name__ == "__main__": unittest.main() |