Parcourir la source

[1428] Closing the socket

Michal 'vorner' Vaner il y a 13 ans
Parent
commit
3a206ab523
2 fichiers modifiés avec 46 ajouts et 4 suppressions
  1. 9 2
      src/bin/bind10/bind10_src.py.in
  2. 37 2
      src/bin/bind10/tests/bind10_test.py.in

+ 9 - 2
src/bin/bind10/bind10_src.py.in

@@ -873,7 +873,7 @@ class BoB:
         others we care about.
         """
         socket = self._srv_socket.accept()
-        self._unix_sockets[socket.fileno()] = (socket, '')
+        self._unix_sockets[socket.fileno()] = (socket, b'')
 
     def _socket_data(self, socket_fileno):
         """
@@ -881,7 +881,14 @@ class BoB:
         attention. We try to read data from there. If it is closed, we remove
         it.
         """
-        pass
+        (sock, previous) = self._unix_sockets[socket_fileno]
+        while True:
+            data = sock.recv(1, socket.MSG_DONTWAIT)
+            if len(data) == 0: # The socket got to it's end
+                del self._unix_sockets[socket_fileno]
+                self.socket_consumer_dead(sock)
+                sock.close()
+                return
 
     def run(self, wakeup_fd):
         """

+ 37 - 2
src/bin/bind10/tests/bind10_test.py.in

@@ -942,6 +942,7 @@ class SocketSrvTest(unittest.TestCase):
         self.__select_backup = bind10_src.select.select
         self.__select_called = None
         self.__socket_data_called = None
+        self.__consumer_dead_called = None
 
     def tearDown(self):
         """
@@ -956,6 +957,8 @@ class SocketSrvTest(unittest.TestCase):
         def __init__(self, owner, fileno=42):
             self.__owner = owner
             self.__fileno = fileno
+            self.data = None
+            self.closed = False
 
         def fileno(self):
             return self.__fileno
@@ -963,6 +966,17 @@ class SocketSrvTest(unittest.TestCase):
         def accept(self):
             return self.__class__(self.__owner, 13)
 
+        def recv(self, bufsize, flags=0):
+            self.__owner.assertEqual(1, bufsize)
+            self.__owner.assertEqual(socket.MSG_DONTWAIT, flags)
+            if self.data is not None:
+                pass # TODO
+            else:
+                return b''
+
+        def close(self):
+            self.closed = True
+
     class __CCS:
         """
         A mock CCS, just to provide the socket file number.
@@ -1017,7 +1031,7 @@ class SocketSrvTest(unittest.TestCase):
         # The socket is properly stored there
         self.assertIsInstance(socket, self.__FalseSocket)
         # And the buffer (yet empty) is there
-        self.assertEqual({13: (socket, '')}, self.__boss._unix_sockets)
+        self.assertEqual({13: (socket, b'')}, self.__boss._unix_sockets)
 
     def __socket_data(self, socket):
         self.__boss.runnable = False
@@ -1030,13 +1044,34 @@ class SocketSrvTest(unittest.TestCase):
         self.__boss._srv_socket = self.__FalseSocket(self)
         self.__boss._socket_data = self.__socket_data
         self.__boss.ccs = self.__CCS()
-        self.__boss._unix_sockets = {13: (self.__FalseSocket(self, 13), '')}
+        self.__boss._unix_sockets = {13: (self.__FalseSocket(self, 13), b'')}
         self.__boss.runnable = True
         bind10_src.select.select = self.__select_data
         self.__boss.run(2)
         self.assertEqual(13, self.__socket_data_called)
         self.assertEqual(([2, 1, 42, 13], [], [], None), self.__select_called)
 
+    def __prepare_data(self, data):
+        socket = self.__FalseSocket(self, 13)
+        self.__boss._unix_sockets = {13: (socket, b'')}
+        socket.data = data
+        self.__boss.socket_consumer_dead = self.__consumer_dead
+        return socket
+
+    def __consumer_dead(self, socket):
+        self.__consumer_dead_called = socket
+
+    def test_socket_closed(self):
+        """
+        Test that a socket is removed and the socket_consumer_dead is called
+        when it is closed.
+        """
+        socket = self.__prepare_data(None)
+        self.__boss._socket_data(13)
+        self.assertEqual(socket, self.__consumer_dead_called)
+        self.assertEqual({}, self.__boss._unix_sockets)
+        self.assertTrue(socket.closed)
+
 if __name__ == '__main__':
     # store os.environ for test_unchanged_environment
     original_os_environ = copy.deepcopy(os.environ)