Browse Source

Fix the potential race condition problem and change the code according jinmei's review result.

git-svn-id: svn://bind10.isc.org/svn/bind10/branches/trac352@3270 e5f2f494-b856-4b98-b285-d166d9295462
Likun Zhang 14 years ago
parent
commit
55618d4439

+ 36 - 23
src/lib/python/isc/utils/serve_mixin.py

@@ -20,7 +20,10 @@ import select
 SOCK_DATA = b'somedata'
 SOCK_DATA = b'somedata'
 class ServeMixIn:
 class ServeMixIn:
     '''Mix-In class to override the function serve_forever()
     '''Mix-In class to override the function serve_forever()
-    and shutdown() in class socketserver.TCPServer.
+    and shutdown() in class socketserver::TCPServer.
+      serve_forever() in socketserver::TCPServer use polling
+    for checking request, which reduces the responsiveness to
+    the shutdown request and wastes cpu at all other times.
       ServeMixIn should be used together with socketserver.TCPServer
       ServeMixIn should be used together with socketserver.TCPServer
     or some derived classes of it, and ServeMixIn must be the first
     or some derived classes of it, and ServeMixIn must be the first
     base class in multiple inheritance, eg. MyClass(ServeMixIn,
     base class in multiple inheritance, eg. MyClass(ServeMixIn,
@@ -29,42 +32,52 @@ class ServeMixIn:
     '''
     '''
     def __init__(self):
     def __init__(self):
         self.__serving = False
         self.__serving = False
-        self.__is_shut_down = threading.Event()
         self.__read_sock, self.__write_sock = socket.socketpair()
         self.__read_sock, self.__write_sock = socket.socketpair()
+        self.__serve_thread = None
 
 
-    def serve_forever(self, poll_interval=0.5):
-        ''' Override the serve_forever([poll_interval]) in class
-        socketserver.TCPServer. use one socket pair to wake up
-        the select when shutdown() is called in anther thread.
+    def serve_forever(self, poll_interval=None):
+        ''' Override the serve_forever([poll_interval]) in class 
+        socketserver.TCPServer by using the socketpair to wake up
+        instead of pulling.
           Note, parameter 'poll_interval' is just used to keep the
           Note, parameter 'poll_interval' is just used to keep the
         interface, it's never used in this function.
         interface, it's never used in this function.
         '''        
         '''        
         self.__serving = True
         self.__serving = True
-        self.__is_shut_down.clear()
+        started_event = threading.Event()
+
+        self.__serve_thread = threading.Thread(target=self.__serve_forever, \
+                                               args=(started_event,))
+        self.__serve_thread.start()
+        
+        started_event.wait() # wait until the thread has started
+        return self.__serve_thread
+
+    def __serve_forever(self, syn_event):
+        '''Use one socket pair to wake up the select when shutdown() 
+        is called in anther thread.
+        '''        
+        self.__serving = True
+        syn_event.set()
+
         while self.__serving:
         while self.__serving:
             # block until the self.socket or self.__read_sock is readable
             # block until the self.socket or self.__read_sock is readable
             try:
             try:
                 r, w, e = select.select([self, self.__read_sock], [], [])
                 r, w, e = select.select([self, self.__read_sock], [], [])
-            except select.error as err:
-                if err.args[0] != EINTR:
-                    raise
-                else:
+            except select.error:
                     continue
                     continue
-            if r:
-                if (self.__read_sock in r) and \
-                   (self.__read_sock.recv(len(SOCK_DATA)) == SOCK_DATA):
-                    break
-                else:
-                    self._handle_request_noblock()
-
-        self.__is_shut_down.set()
+            
+            if self.__read_sock in r:
+                break
+            else:
+                self._handle_request_noblock()
 
 
     def shutdown(self):
     def shutdown(self):
-        '''Stops the serve_forever loop.
-        Blocks until the loop has finished, the function should be called
-        in another thread when serve_forever is running, or it will block.
+        '''Stops the self.__serve_thread( self.__serve_forever loop).
+        when self.__serve_thread is running, it will block until the 
+        self.__serve_thread terminate.
         '''
         '''
         self.__serving = False
         self.__serving = False
         self.__write_sock.send(SOCK_DATA) # make self.__read_sock readable.
         self.__write_sock.send(SOCK_DATA) # make self.__read_sock readable.
-        self.__is_shut_down.wait()
+        if self.__serve_thread:
+            self.__serve_thread.join() # wait until the serve thread terminate
 
 

+ 1 - 4
src/lib/python/isc/utils/tests/serve_mixin_test.py

@@ -46,9 +46,7 @@ class TestServeMixIn(unittest.TestCase):
         # use port 0 to select an arbitrary unused port.
         # use port 0 to select an arbitrary unused port.
         server = MyServer(('127.0.0.1', 0), MyHandler)
         server = MyServer(('127.0.0.1', 0), MyHandler)
         ip, port = server.server_address
         ip, port = server.server_address
-        server_thread = threading.Thread(target=server.serve_forever)
-        server_thread.setDaemon(True)
-        server_thread.start()
+        server_thread = server.serve_forever()
 
 
         msg = b'senddata'
         msg = b'senddata'
         self.assertEqual(msg, send_and_get_reply(ip, port, msg))
         self.assertEqual(msg, send_and_get_reply(ip, port, msg))
@@ -57,7 +55,6 @@ class TestServeMixIn(unittest.TestCase):
         # Now shutdown the server
         # Now shutdown the server
         server.shutdown()
         server.shutdown()
         # Sleep a while, make sure the thread has finished.
         # Sleep a while, make sure the thread has finished.
-        time.sleep(0.1)
         self.assertFalse(server_thread.is_alive())
         self.assertFalse(server_thread.is_alive())
 
 
 if __name__== "__main__":
 if __name__== "__main__":