Browse Source

[2690] Test the run_select

And adjust the main code so it is possible to test it.
Michal 'vorner' Vaner 11 years ago
parent
commit
b177e3f9b5
2 changed files with 49 additions and 4 deletions
  1. 6 4
      src/bin/msgq/msgq.py.in
  2. 43 0
      src/bin/msgq/tests/msgq_test.py

+ 6 - 4
src/bin/msgq/msgq.py.in

@@ -202,6 +202,7 @@ class MsgQ:
         # side.
         # side.
         self.__lock = threading.Lock()
         self.__lock = threading.Lock()
         self._session = None
         self._session = None
+        self.__poller_sock = None
 
 
     def members_notify(self, event, params):
     def members_notify(self, event, params):
         """
         """
@@ -533,7 +534,7 @@ class MsgQ:
             self.sendbuffs[fileno] = (last_sent, buff)
             self.sendbuffs[fileno] = (last_sent, buff)
         return True
         return True
 
 
-    def __process_write(self, fileno):
+    def _process_write(self, fileno):
         # Try to send some data from the buffer
         # Try to send some data from the buffer
         (_, msg) = self.sendbuffs[fileno]
         (_, msg) = self.sendbuffs[fileno]
         sock = self.sockets[fileno]
         sock = self.sockets[fileno]
@@ -661,7 +662,7 @@ class MsgQ:
             reads = list(self.fd_to_lname.keys())
             reads = list(self.fd_to_lname.keys())
             if self.listen_socket.fileno() != -1: # Skip in tests
             if self.listen_socket.fileno() != -1: # Skip in tests
                 reads.append(self.listen_socket.fileno())
                 reads.append(self.listen_socket.fileno())
-            if self.__poller_sock.fileno() != -1:
+            if self.__poller_sock and self.__poller_sock.fileno() != -1:
                 reads.append(self.__poller_sock.fileno())
                 reads.append(self.__poller_sock.fileno())
             writes = list(self.sendbuffs.keys())
             writes = list(self.sendbuffs.keys())
             (read_ready, write_ready) = ([], [])
             (read_ready, write_ready) = ([], [])
@@ -685,14 +686,15 @@ class MsgQ:
                         write_ready.remove(fd)
                         write_ready.remove(fd)
                     if fd == self.listen_socket.fileno():
                     if fd == self.listen_socket.fileno():
                         self.process_accept()
                         self.process_accept()
-                    elif fd == self.__poller_sock.fileno():
+                    elif self.__poller_sock and fd == \
+                        self.__poller_sock.fileno():
                         # The signal socket. We should terminate now.
                         # The signal socket. We should terminate now.
                         self.running = False
                         self.running = False
                         break
                         break
                     else:
                     else:
                         self.process_packet(fd, self.sockets[fd])
                         self.process_packet(fd, self.sockets[fd])
                 for fd in write_ready:
                 for fd in write_ready:
-                    self.__process_write(fd)
+                    self._process_write(fd)
 
 
     def stop(self):
     def stop(self):
         # Signal it should terminate.
         # Signal it should terminate.

+ 43 - 0
src/bin/msgq/tests/msgq_test.py

@@ -990,9 +990,11 @@ class SocketTests(unittest.TestCase):
         self.__killed_socket = None
         self.__killed_socket = None
         self.__logger = self.LoggerWrapper(msgq.logger)
         self.__logger = self.LoggerWrapper(msgq.logger)
         msgq.logger = self.__logger
         msgq.logger = self.__logger
+        self.__orig_select = msgq.select.select
 
 
     def tearDown(self):
     def tearDown(self):
         msgq.logger = self.__logger.orig_logger
         msgq.logger = self.__logger.orig_logger
+        msgq.select.select = self.__orig_select
 
 
     def test_send_data(self):
     def test_send_data(self):
         # Successful case: _send_data() returns the hardcoded value, and
         # Successful case: _send_data() returns the hardcoded value, and
@@ -1073,6 +1075,47 @@ class SocketTests(unittest.TestCase):
             self.assertEqual(expected_errors, self.__logger.error_called)
             self.assertEqual(expected_errors, self.__logger.error_called)
             self.assertEqual(expected_debugs, self.__logger.debug_called)
             self.assertEqual(expected_debugs, self.__logger.debug_called)
 
 
+    def test_do_select(self):
+        """
+        Check the behaviour of the run_select method.
+
+        In particular, check that we skip writing to the sockets we read,
+        because a read may have side effects (like closing the socket) and
+        we want to prevent strange behavior.
+        """
+        self.__read_called = []
+        self.__write_called = []
+        self.__reads = None
+        self.__writes = None
+        def do_read(fd, socket):
+            self.__read_called.append(fd)
+            self.__msgq.running = False
+        def do_write(fd):
+            self.__write_called.append(fd)
+            self.__msgq.running = False
+        self.__msgq.process_packet = do_read
+        self.__msgq._process_write = do_write
+        self.__msgq.fd_to_lname = {42: 'lname', 44: 'other', 45: 'unused'}
+        # The do_select does index it, but just passes the value. So reuse
+        # the dict to safe typing in the test.
+        self.__msgq.sockets = self.__msgq.fd_to_lname
+        self.__msgq.sendbuffs = {42: 'data', 43: 'data'}
+        def my_select(reads, writes, errors):
+            self.__reads = reads
+            self.__writes = writes
+            self.assertEqual([], errors)
+            return ([42, 44], [42, 43], [])
+        msgq.select.select = my_select
+        self.__msgq.listen_socket = DummySocket
+
+        self.__msgq.running = True
+        self.__msgq.run_select()
+
+        self.assertEqual([42, 44], self.__read_called)
+        self.assertEqual([43], self.__write_called)
+        self.assertEqual({42, 44, 45}, set(self.__reads))
+        self.assertEqual({42, 43}, set(self.__writes))
+
 if __name__ == '__main__':
 if __name__ == '__main__':
     isc.log.resetUnitTestRootLogger()
     isc.log.resetUnitTestRootLogger()
     unittest.main()
     unittest.main()