Browse Source

Make sure xfrout can be shutdown, now notify-out thread and transfer-server thread can be terminated when the main thread get shutdown command.

git-svn-id: svn://bind10.isc.org/svn/bind10/branches/trac335@3051 e5f2f494-b856-4b98-b285-d166d9295462
Likun Zhang 14 years ago
parent
commit
76bdd11c1c

+ 8 - 5
src/bin/xfrout/xfrout.py.in

@@ -262,7 +262,8 @@ class XfroutSession(BaseRequestHandler):
 
         for rr_data in sqlite3_ds.get_zone_datas(zone_name, self.server.get_db_file()):
             if  self.server._shutdown_event.is_set(): # Check if xfrout is shutdown
-                self._log.log_message("error", "shutdown!")
+                self._log.log_message("info", "xfrout process is being shutdown")
+                return
 
             # TODO: RRType.SOA() ?
             if RRType(rr_data[5]) == RRType("SOA"): #ignore soa record
@@ -394,8 +395,9 @@ def listen_on_xfr_query(unix_socket_server):
             # normal program flow continue by trying serve_forever()
             # again.
             if err.args[0] != errno.EINTR: raise
-
-   
+        else:
+            # serve_forever() loop has been stoped normally.
+            break
 
 class XfroutServer:
     def __init__(self):
@@ -424,8 +426,7 @@ class XfroutServer:
     def _start_notifier(self):
         datasrc = self._unix_socket_server.get_db_file()
         self._notifier = notify_out.NotifyOut(datasrc, self._log)
-        td = threading.Thread(target = notify_out.dispatcher, args = (self._notifier,))
-        td.daemon = True
+        td = threading.Thread(target=self._notifier.dispatcher)
         td.start()
 
     def send_notify(self, zone_name, zone_class):
@@ -457,10 +458,12 @@ class XfroutServer:
         global xfrout_server
         xfrout_server = None #Avoid shutdown is called twice
         self._shutdown_event.set()
+        self._notifier.shutdown()
         if self._unix_socket_server:
             self._unix_socket_server.shutdown()
 
         main_thread = threading.currentThread()
+        # close the thread which's doing zone transfer.
         for th in threading.enumerate():
             if th is main_thread:
                 continue

+ 52 - 25
src/lib/python/isc/notify/notify_out.py

@@ -47,26 +47,7 @@ _BAD_REPLY_PACKET = 5
 def addr_to_str(addr):
     return '%s#%s' % (addr[0], addr[1])
 
-def dispatcher(notifier):
-    '''The loop function for handling notify related events.
-    If one zone get the notify reply before timeout, call the
-    handle to process the reply. If one zone can't get the notify
-    before timeout, call the handler to resend notify or notify 
-    next slave.  
-    notifier: one object of class NotifyOut. '''
-    while True:
-        replied_zones, not_replied_zones = notifier._wait_for_notify_reply()
-        if len(replied_zones) == 0 and len(not_replied_zones) == 0:
-            time.sleep(_IDLE_SLEEP_TIME) #TODO set a better time for idle sleep
-            continue
-
-        for name_ in replied_zones:
-            notifier._zone_notify_handler(replied_zones[name_], _EVENT_READ)
-            
-        for name_ in not_replied_zones:
-            if not_replied_zones[name_].notify_timeout <= time.time():
-                notifier._zone_notify_handler(not_replied_zones[name_], _EVENT_TIMEOUT)
- 
+
 class ZoneNotifyInfo:
     '''This class keeps track of notify-out information for one zone.'''
 
@@ -115,14 +96,17 @@ class ZoneNotifyInfo:
 
 class NotifyOut:
     '''This class is used to handle notify logic for all zones(sending
-    notify message to its slaves).The only interface provided to 
-    the user is send_notify(). the object of this class should be 
-    used together with function dispatcher(). '''
+    notify message to its slaves). notify service can be started by 
+    calling  dispatcher(), and it can be stoped by calling shutdown()
+    in another thread. ''' 
     def __init__(self, datasrc_file, log=None, verbose=True):
         self._notify_infos = {} # key is (zone_name, zone_class)
         self._waiting_zones = []
         self._notifying_zones = []
         self._log = log
+        self._serving = False
+        self._is_shut_down = threading.Event()
+        self._read_sock, self._write_sock = socket.socketpair()
         self.notify_num = 0  # the count of in progress notifies
         self._verbose = verbose
         self._lock = threading.Lock()
@@ -165,6 +149,42 @@ class NotifyOut:
                 self.notify_num += 1 
                 self._notifying_zones.append(zone_id)
 
+    def dispatcher(self):
+        '''The loop function for handling notify related events.
+        If one zone get the notify reply before timeout, call the
+        handle to process the reply. If one zone can't get the notify
+        before timeout, call the handler to resend notify or notify 
+        next slave.  
+           The loop can be stoped by calling shutdown() in another 
+        thread. '''
+        self._serving = True
+        self._is_shut_down.clear()
+        while self._serving:
+            replied_zones, not_replied_zones = self._wait_for_notify_reply()
+            if replied_zones is None:
+                break
+
+            if len(replied_zones) == 0 and len(not_replied_zones) == 0:
+                time.sleep(_IDLE_SLEEP_TIME) #TODO set a better time for idle sleep
+                continue
+
+            for name_ in replied_zones:
+                self._zone_notify_handler(replied_zones[name_], _EVENT_READ)
+
+            for name_ in not_replied_zones:
+                if not_replied_zones[name_].notify_timeout <= time.time():
+                    self._zone_notify_handler(not_replied_zones[name_], _EVENT_TIMEOUT)
+
+        self._is_shut_down.set()
+
+    def shutdown(self):
+        '''Stop the dispatcher() loop. Blocks until the loop has finished. This
+        must be called when dispatcher() is running in anther thread, or it
+        will deadlock.  '''
+        self._serving = False
+        self._write_sock.send(b'shutdown') # make self._read_sock be readable.
+        self._is_shut_down.wait()
+
     def _get_rdata_data(self, rr):
         return rr[7].strip()
 
@@ -232,16 +252,23 @@ class NotifyOut:
 
     def _wait_for_notify_reply(self):
         '''receive notify replies in specified time. returned value 
-        is one tuple:(replied_zones, not_replied_zones)
+        is one tuple:(replied_zones, not_replied_zones). (None, None)
+        will be returned when self._read_sock is readable, since user
+        has called shutdown().
         replied_zones: the zones which receive notify reply.
         not_replied_zones: the zones which haven't got notify reply.
+
         '''
         (block_timeout, valid_socks, notifying_zones) = self._prepare_select_info()
+        valid_socks.append(self._read_sock)
         try:
             r_fds, w, e = select.select(valid_socks, [], [], block_timeout)
         except select.error as err:
             if err.args[0] != EINTR:
-                return [], []
+                return {}, {}
+        
+        if self._read_sock in r_fds:
+            return None, None # user has called shutdown()
         
         not_replied_zones = {}
         replied_zones = {}

+ 21 - 3
src/lib/python/isc/notify/tests/notify_out_test.py

@@ -53,8 +53,6 @@ class TestZoneNotifyInfo(unittest.TestCase):
 
 class TestNotifyOut(unittest.TestCase):
     def setUp(self):
-        self.old_stdout = sys.stdout
-        sys.stdout = open(os.devnull, 'w')
         self._db_file = tempfile.NamedTemporaryFile(delete=False)
         sqlite3_ds.load(self._db_file.name, 'cn.', self._cn_data_reader)
         sqlite3_ds.load(self._db_file.name, 'com.', self._com_data_reader)
@@ -70,7 +68,6 @@ class TestNotifyOut(unittest.TestCase):
         info.notify_slaves.append(('1.1.1.1', 5353))
 
     def tearDown(self):
-        sys.stdout = self.old_stdout
         self._db_file.close()
         os.unlink(self._db_file.name)
 
@@ -123,6 +120,19 @@ class TestNotifyOut(unittest.TestCase):
         self.assertTrue(('com.', 'IN') in timeout_zones.keys())
         self.assertLess(time.time(), self._notify._notify_infos[('com.', 'IN')].notify_timeout)
     
+    def test_wait_for_notify_reply_2(self):
+        # Test the returned value when the read_side socket is readable.
+        self._notify.send_notify('cn.')
+        self._notify.send_notify('com.')
+
+        # Now make one socket be readable
+        self._notify._notify_infos[('cn.', 'IN')].notify_timeout = time.time() + 10
+        self._notify._notify_infos[('com.', 'IN')].notify_timeout = time.time() + 10
+        self._notify._write_sock.send(b'shutdown')    
+        replied_zones, timeout_zones = self._notify._wait_for_notify_reply()
+        self.assertIsNone(replied_zones) 
+        self.assertIsNone(timeout_zones) 
+
     def test_notify_next_target(self):
         self._notify.send_notify('cn.')
         self._notify.send_notify('com.')
@@ -279,6 +289,14 @@ class TestNotifyOut(unittest.TestCase):
         self.assertEqual(timeout, 0)
         self.assertListEqual([2, 1], valid_fds)
 
+    def test_shutdown(self):
+        import threading
+        td = threading.Thread(target=self._notify.dispatcher)
+        td.start()
+        self.assertTrue(td.is_alive())
+        self._notify.shutdown()
+        self.assertFalse(td.is_alive())
+
 if __name__== "__main__":
     unittest.main()