Browse Source

Merge trac #335

git-svn-id: svn://bind10.isc.org/svn/bind10/trunk@3273 e5f2f494-b856-4b98-b285-d166d9295462
Michal Vaner 14 years ago
parent
commit
2e733bc5b9

+ 9 - 0
ChangeLog

@@ -1,3 +1,12 @@
+  111.	[bug]*   zhanglikun, Michal Vaner
+	Make sure the xfrin/xfrout/zonemgr/cmdctl processes can be stopped
+	properly when the user enters "ctrl+c" or the 'Boss shutdown' command
+	through bindctl.
+
+	The ZonemgrRefresh.run_timer and NotifyOut.dispatcher spawn
+	a thread themselves.
+	(Trac #335, svn r3273)
+
   110.  [func]      Michal Vaner
   110.  [func]      Michal Vaner
 	Added isc.net.check module to check ip addresses and ports for correctness
 	Added isc.net.check module to check ip addresses and ports for correctness
 	and isc.net.addr to hold IP address. The bind10, xfrin and cmdctl programs
 	and isc.net.addr to hold IP address. The bind10, xfrin and cmdctl programs

+ 9 - 7
src/bin/bind10/bind10.py.in

@@ -454,12 +454,12 @@ class BoB:
     def stop_all_processes(self):
     def stop_all_processes(self):
         """Stop all processes."""
         """Stop all processes."""
         cmd = { "command": ['shutdown']}
         cmd = { "command": ['shutdown']}
-        self.cc_session.group_sendmsg(cmd, 'Boss', 'Cmdctl')
-        self.cc_session.group_sendmsg(cmd, "Boss", "ConfigManager")
-        self.cc_session.group_sendmsg(cmd, "Boss", "Auth")
-        self.cc_session.group_sendmsg(cmd, "Boss", "Xfrout")
-        self.cc_session.group_sendmsg(cmd, "Boss", "Xfrin")
-        self.cc_session.group_sendmsg(cmd, "Boss", "Zonemgr")
+        self.cc_session.group_sendmsg(cmd, 'Cmdctl', 'Cmdctl')
+        self.cc_session.group_sendmsg(cmd, "ConfigManager", "ConfigManager")
+        self.cc_session.group_sendmsg(cmd, "Auth", "Auth")
+        self.cc_session.group_sendmsg(cmd, "Xfrout", "Xfrout")
+        self.cc_session.group_sendmsg(cmd, "Xfrin", "Xfrin")
+        self.cc_session.group_sendmsg(cmd, "Zonemgr", "Zonemgr")
         self.cc_session.group_sendmsg(cmd, "Boss", "Stats")
         self.cc_session.group_sendmsg(cmd, "Boss", "Stats")
 
 
     def stop_process(self, process):
     def stop_process(self, process):
@@ -477,7 +477,9 @@ class BoB:
         except:
         except:
             pass
             pass
         # XXX: some delay probably useful... how much is uncertain
         # XXX: some delay probably useful... how much is uncertain
-        time.sleep(0.5)  
+        # I have changed the delay from 0.5 to 1, but sometimes it's
+        # still not enough.
+        time.sleep(1)  
         self.reap_children()
         self.reap_children()
         # next try sending a SIGTERM
         # next try sending a SIGTERM
         processes_to_stop = list(self.processes.values())
         processes_to_stop = list(self.processes.values())

+ 13 - 3
src/bin/xfrin/xfrin.py.in

@@ -519,11 +519,21 @@ class Xfrin:
         param = {'zone_name': zone_name, 'zone_class': zone_class.to_text()}
         param = {'zone_name': zone_name, 'zone_class': zone_class.to_text()}
         if xfr_result == XFRIN_OK:
         if xfr_result == XFRIN_OK:
             msg = create_command(notify_out.ZONE_NEW_DATA_READY_CMD, param)
             msg = create_command(notify_out.ZONE_NEW_DATA_READY_CMD, param)
-            self._send_cc_session.group_sendmsg(msg, XFROUT_MODULE_NAME)
-            self._send_cc_session.group_sendmsg(msg, ZONE_MANAGER_MODULE_NAME)
+            # catch the exception, in case msgq has been killed.
+            try:
+                self._send_cc_session.group_sendmsg(msg, XFROUT_MODULE_NAME)
+                self._send_cc_session.group_sendmsg(msg, ZONE_MANAGER_MODULE_NAME)
+            except socket.error as err: 
+                log_error("Fail to send message to %s and %s, msgq may has been killed" 
+                          % (XFROUT_MODULE_NAME, ZONE_MANAGER_MODULE_NAME))
         else:
         else:
             msg = create_command(ZONE_XFRIN_FAILED, param)
             msg = create_command(ZONE_XFRIN_FAILED, param)
-            self._send_cc_session.group_sendmsg(msg, ZONE_MANAGER_MODULE_NAME)
+            # catch the exception, in case msgq has been killed.
+            try:
+                self._send_cc_session.group_sendmsg(msg, ZONE_MANAGER_MODULE_NAME)
+            except socket.error as err:
+                log_error("Fail to send message to %s, msgq may has been killed" 
+                          % ZONE_MANAGER_MODULE_NAME)
 
 
     def startup(self):
     def startup(self):
         while not self._shutdown_event.is_set():
         while not self._shutdown_event.is_set():

+ 11 - 9
src/bin/xfrout/xfrout.py.in

@@ -266,7 +266,8 @@ class XfroutSession(BaseRequestHandler):
 
 
         for rr_data in sqlite3_ds.get_zone_datas(zone_name, self.server.get_db_file()):
         for rr_data in sqlite3_ds.get_zone_datas(zone_name, self.server.get_db_file()):
             if  self.server._shutdown_event.is_set(): # Check if xfrout is shutdown
             if  self.server._shutdown_event.is_set(): # Check if xfrout is shutdown
-                self._log.log_message("error", "shutdown!")
+                self._log.log_message("info", "xfrout process is being shutdown")
+                return
 
 
             # TODO: RRType.SOA() ?
             # TODO: RRType.SOA() ?
             if RRType(rr_data[5]) == RRType("SOA"): #ignore soa record
             if RRType(rr_data[5]) == RRType("SOA"): #ignore soa record
@@ -398,8 +399,9 @@ def listen_on_xfr_query(unix_socket_server):
             # normal program flow continue by trying serve_forever()
             # normal program flow continue by trying serve_forever()
             # again.
             # again.
             if err.args[0] != errno.EINTR: raise
             if err.args[0] != errno.EINTR: raise
-
-   
+        else:
+            # serve_forever() loop has been stopped normally.
+            break
 
 
 class XfroutServer:
 class XfroutServer:
     def __init__(self):
     def __init__(self):
@@ -428,9 +430,7 @@ class XfroutServer:
     def _start_notifier(self):
     def _start_notifier(self):
         datasrc = self._unix_socket_server.get_db_file()
         datasrc = self._unix_socket_server.get_db_file()
         self._notifier = notify_out.NotifyOut(datasrc, self._log)
         self._notifier = notify_out.NotifyOut(datasrc, self._log)
-        td = threading.Thread(target = notify_out.dispatcher, args = (self._notifier,))
-        td.daemon = True
-        td.start()
+        self._notifier.dispatcher()
 
 
     def send_notify(self, zone_name, zone_class):
     def send_notify(self, zone_name, zone_class):
         self._notifier.send_notify(zone_name, zone_class)
         self._notifier.send_notify(zone_name, zone_class)
@@ -443,7 +443,7 @@ class XfroutServer:
                 answer = create_answer(1, "Unknown config data: " + str(key))
                 answer = create_answer(1, "Unknown config data: " + str(key))
                 continue
                 continue
             self._config_data[key] = new_config[key]
             self._config_data[key] = new_config[key]
-        
+
         if self._log:
         if self._log:
             self._log.update_config(new_config)
             self._log.update_config(new_config)
 
 
@@ -461,9 +461,11 @@ class XfroutServer:
         global xfrout_server
         global xfrout_server
         xfrout_server = None #Avoid shutdown is called twice
         xfrout_server = None #Avoid shutdown is called twice
         self._shutdown_event.set()
         self._shutdown_event.set()
+        self._notifier.shutdown()
         if self._unix_socket_server:
         if self._unix_socket_server:
             self._unix_socket_server.shutdown()
             self._unix_socket_server.shutdown()
 
 
+        # Wait for all threads to terminate
         main_thread = threading.currentThread()
         main_thread = threading.currentThread()
         for th in threading.enumerate():
         for th in threading.enumerate():
             if th is main_thread:
             if th is main_thread:
@@ -475,7 +477,7 @@ class XfroutServer:
             self._log.log_message("info", "Received shutdown command.")
             self._log.log_message("info", "Received shutdown command.")
             self.shutdown()
             self.shutdown()
             answer = create_answer(0)
             answer = create_answer(0)
-        
+
         elif cmd == notify_out.ZONE_NEW_DATA_READY_CMD:
         elif cmd == notify_out.ZONE_NEW_DATA_READY_CMD:
             zone_name = args.get('zone_name')
             zone_name = args.get('zone_name')
             zone_class = args.get('zone_class')
             zone_class = args.get('zone_class')
@@ -490,7 +492,7 @@ class XfroutServer:
         else: 
         else: 
             answer = create_answer(1, "Unknown command:" + str(cmd))
             answer = create_answer(1, "Unknown command:" + str(cmd))
 
 
-        return answer    
+        return answer
 
 
     def run(self):
     def run(self):
         '''Get and process all commands sent from cfgmgr or other modules. '''
         '''Get and process all commands sent from cfgmgr or other modules. '''

+ 2 - 0
src/bin/zonemgr/tests/Makefile.am

@@ -1,6 +1,8 @@
 PYTESTS = zonemgr_test.py
 PYTESTS = zonemgr_test.py
 EXTRA_DIST = $(PYTESTS)
 EXTRA_DIST = $(PYTESTS)
 
 
+CLEANFILES = initdb.file
+
 # later will have configure option to choose this, like: coverage run --branch
 # later will have configure option to choose this, like: coverage run --branch
 PYCOVERAGE = $(PYTHON)
 PYCOVERAGE = $(PYTHON)
 # test using command-line arguments, so use check-local target instead of TESTS
 # test using command-line arguments, so use check-local target instead of TESTS

+ 40 - 23
src/bin/zonemgr/tests/zonemgr_test.py

@@ -45,13 +45,22 @@ class MySession():
 
 
 class MyZonemgrRefresh(ZonemgrRefresh):
 class MyZonemgrRefresh(ZonemgrRefresh):
     def __init__(self):
     def __init__(self):
-        self._cc = MySession()
-        self._db_file = "initdb.file"
+        class FakeConfig:
+            def get(self, name):
+                if name == 'lowerbound_refresh':
+                    return LOWERBOUND_REFRESH
+                elif name == 'lowerbound_retry':
+                    return LOWERBOUND_RETRY
+                elif name == 'max_transfer_timeout':
+                    return MAX_TRANSFER_TIMEOUT
+                elif name == 'jitter_scope':
+                    return JITTER_SCOPE
+                else:
+                    raise ValueError('Uknown config option')
+        self._master_socket, self._slave_socket = socket.socketpair()
+        ZonemgrRefresh.__init__(self, MySession(), "initdb.file",
+            self._slave_socket, FakeConfig())
         current_time = time.time()
         current_time = time.time()
-        self._max_transfer_timeout = MAX_TRANSFER_TIMEOUT
-        self._lowerbound_refresh = LOWERBOUND_REFRESH
-        self._lowerbound_retry = LOWERBOUND_RETRY
-        self._jitter_scope = JITTER_SCOPE
         self._zonemgr_refresh_info = { 
         self._zonemgr_refresh_info = { 
          ('sd.cn.', 'IN'): {
          ('sd.cn.', 'IN'): {
          'last_refresh_time': current_time,
          'last_refresh_time': current_time,
@@ -67,8 +76,8 @@ class MyZonemgrRefresh(ZonemgrRefresh):
 
 
 class TestZonemgrRefresh(unittest.TestCase):
 class TestZonemgrRefresh(unittest.TestCase):
     def setUp(self):
     def setUp(self):
-        self.stdout_backup = sys.stdout
-        sys.stdout = open(os.devnull, 'w')
+        self.stderr_backup = sys.stderr
+        sys.stderr = open(os.devnull, 'w')
         self.zone_refresh = MyZonemgrRefresh()
         self.zone_refresh = MyZonemgrRefresh()
 
 
     def test_random_jitter(self):
     def test_random_jitter(self):
@@ -101,7 +110,7 @@ class TestZonemgrRefresh(unittest.TestCase):
         time2 = time.time()
         time2 = time.time()
         self.assertTrue((time1 + 7200 * 3 / 4) <= zone_timeout)
         self.assertTrue((time1 + 7200 * 3 / 4) <= zone_timeout)
         self.assertTrue(zone_timeout <= time2 + 7200)
         self.assertTrue(zone_timeout <= time2 + 7200)
-        
+
     def test_set_zone_retry_timer(self):
     def test_set_zone_retry_timer(self):
         time1 = time.time()
         time1 = time.time()
         self.zone_refresh._set_zone_retry_timer(ZONE_NAME_CLASS1_IN)
         self.zone_refresh._set_zone_retry_timer(ZONE_NAME_CLASS1_IN)
@@ -147,6 +156,8 @@ class TestZonemgrRefresh(unittest.TestCase):
          
          
     def test_zonemgr_reload_zone(self):
     def test_zonemgr_reload_zone(self):
         soa_rdata = 'a.dns.cn. root.cnnic.cn. 2009073106 1800 900 2419200 21600'
         soa_rdata = 'a.dns.cn. root.cnnic.cn. 2009073106 1800 900 2419200 21600'
+        # We need to restore this not to harm other tests
+        old_get_zone_soa = sqlite3_ds.get_zone_soa
         def get_zone_soa(zone_name, db_file):
         def get_zone_soa(zone_name, db_file):
             return (1, 2, 'sd.cn.', 'cn.sd.', 21600, 'SOA', None, 
             return (1, 2, 'sd.cn.', 'cn.sd.', 21600, 'SOA', None, 
                     'a.dns.cn. root.cnnic.cn. 2009073106 1800 900 2419200 21600')
                     'a.dns.cn. root.cnnic.cn. 2009073106 1800 900 2419200 21600')
@@ -154,6 +165,7 @@ class TestZonemgrRefresh(unittest.TestCase):
 
 
         self.zone_refresh.zonemgr_reload_zone(ZONE_NAME_CLASS1_IN)
         self.zone_refresh.zonemgr_reload_zone(ZONE_NAME_CLASS1_IN)
         self.assertEqual(soa_rdata, self.zone_refresh._zonemgr_refresh_info[ZONE_NAME_CLASS1_IN]["zone_soa_rdata"])
         self.assertEqual(soa_rdata, self.zone_refresh._zonemgr_refresh_info[ZONE_NAME_CLASS1_IN]["zone_soa_rdata"])
+        sqlite3_ds.get_zone_soa = old_get_zone_soa
 
 
     def test_get_zone_notifier_master(self):
     def test_get_zone_notifier_master(self):
         notify_master = "192.168.1.1"
         notify_master = "192.168.1.1"
@@ -231,6 +243,9 @@ class TestZonemgrRefresh(unittest.TestCase):
 
 
     def test_zonemgr_add_zone(self):
     def test_zonemgr_add_zone(self):
         soa_rdata = 'a.dns.cn. root.cnnic.cn. 2009073106 1800 900 2419200 21600'
         soa_rdata = 'a.dns.cn. root.cnnic.cn. 2009073106 1800 900 2419200 21600'
+        # This needs to be restored. The following test actually failed if we left
+        # this unclean
+        old_get_zone_soa = sqlite3_ds.get_zone_soa
 
 
         def get_zone_soa(zone_name, db_file):
         def get_zone_soa(zone_name, db_file):
             return (1, 2, 'sd.cn.', 'cn.sd.', 21600, 'SOA', None, 
             return (1, 2, 'sd.cn.', 'cn.sd.', 21600, 'SOA', None, 
@@ -251,7 +266,8 @@ class TestZonemgrRefresh(unittest.TestCase):
             return None
             return None
         sqlite3_ds.get_zone_soa = get_zone_soa2
         sqlite3_ds.get_zone_soa = get_zone_soa2
         self.assertRaises(ZonemgrException, self.zone_refresh.zonemgr_add_zone, \
         self.assertRaises(ZonemgrException, self.zone_refresh.zonemgr_add_zone, \
-                                          ZONE_NAME_CLASS1_IN)
+                                         ZONE_NAME_CLASS1_IN)
+        sqlite3_ds.get_zone_soa = old_get_zone_soa
 
 
     def test_build_zonemgr_refresh_info(self):
     def test_build_zonemgr_refresh_info(self):
         soa_rdata = 'a.dns.cn. root.cnnic.cn. 2009073106 1800 900 2419200 21600'
         soa_rdata = 'a.dns.cn. root.cnnic.cn. 2009073106 1800 900 2419200 21600'
@@ -382,7 +398,7 @@ class TestZonemgrRefresh(unittest.TestCase):
         """This case will run timer in daemon thread. 
         """This case will run timer in daemon thread. 
         The zone's next_refresh_time is less than now, so zonemgr will do zone refresh 
         The zone's next_refresh_time is less than now, so zonemgr will do zone refresh 
         immediately. The zone's state will become "refreshing". 
         immediately. The zone's state will become "refreshing". 
-        Then closing the socket ,the timer will stop, and throw a ZonemgrException."""
+        """
         time1 = time.time()
         time1 = time.time()
         self.zone_refresh._zonemgr_refresh_info = {
         self.zone_refresh._zonemgr_refresh_info = {
                 ("sd.cn.", "IN"):{
                 ("sd.cn.", "IN"):{
@@ -391,17 +407,11 @@ class TestZonemgrRefresh(unittest.TestCase):
                     'zone_soa_rdata': 'a.dns.cn. root.cnnic.cn. 2009073105 7200 3600 2419200 21600', 
                     'zone_soa_rdata': 'a.dns.cn. root.cnnic.cn. 2009073105 7200 3600 2419200 21600', 
                     'zone_state': ZONE_OK}
                     'zone_state': ZONE_OK}
                 }
                 }
-        master_socket, slave_socket = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
-        self.zone_refresh._socket = master_socket 
-        master_socket.close()
-        self.assertRaises(ZonemgrException, self.zone_refresh.run_timer)
-
-        self.zone_refresh._socket = slave_socket
-        listener = threading.Thread(target = self.zone_refresh.run_timer, args = ())
-        listener.setDaemon(True)
-        listener.start()
-        time.sleep(1)
-
+        self.zone_refresh._check_sock = self.zone_refresh._master_socket 
+        listener = self.zone_refresh.run_timer(daemon=True)
+        # Shut down the timer thread
+        self.zone_refresh.shutdown()
+        # After running timer, the zone's state should become "refreshing".
         zone_state = self.zone_refresh._zonemgr_refresh_info[ZONE_NAME_CLASS1_IN]["zone_state"]
         zone_state = self.zone_refresh._zonemgr_refresh_info[ZONE_NAME_CLASS1_IN]["zone_state"]
         self.assertTrue("refresh_timeout" in self.zone_refresh._zonemgr_refresh_info[ZONE_NAME_CLASS1_IN].keys())
         self.assertTrue("refresh_timeout" in self.zone_refresh._zonemgr_refresh_info[ZONE_NAME_CLASS1_IN].keys())
         self.assertTrue(zone_state == ZONE_REFRESHING)
         self.assertTrue(zone_state == ZONE_REFRESHING)
@@ -419,9 +429,16 @@ class TestZonemgrRefresh(unittest.TestCase):
         self.assertEqual(19800, self.zone_refresh._max_transfer_timeout)
         self.assertEqual(19800, self.zone_refresh._max_transfer_timeout)
         self.assertEqual(0.25, self.zone_refresh._jitter_scope)
         self.assertEqual(0.25, self.zone_refresh._jitter_scope)
 
 
+    def test_shutdown(self):
+        self.zone_refresh._check_sock = self.zone_refresh._master_socket 
+        listener = self.zone_refresh.run_timer()
+        self.assertTrue(listener.is_alive())
+        # Shut down the timer thread
+        self.zone_refresh.shutdown()
+        self.assertFalse(listener.is_alive())
 
 
     def tearDown(self):
     def tearDown(self):
-        sys.stdout = self.stdout_backup
+        sys.stderr= self.stderr_backup
 
 
 
 
 class MyCCSession():
 class MyCCSession():

+ 80 - 42
src/bin/zonemgr/zonemgr.py.in

@@ -16,7 +16,7 @@
 # NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
 # NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
 # WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 # WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 
 
-"""\
+"""
 This file implements the Secondary Manager program.
 This file implements the Secondary Manager program.
 
 
 The secondary manager is one of the co-operating processes
 The secondary manager is one of the co-operating processes
@@ -91,16 +91,20 @@ class ZonemgrException(Exception):
 class ZonemgrRefresh:
 class ZonemgrRefresh:
     """This class will maintain and manage zone refresh info.
     """This class will maintain and manage zone refresh info.
     It also provides methods to keep track of zone timers and 
     It also provides methods to keep track of zone timers and 
-    do zone refresh.
+    do zone refresh. 
+    Zone timers can be started by calling run_timer(), and it 
+    can be stopped by calling shutdown() in another thread.
+
     """
     """
 
 
     def __init__(self, cc, db_file, slave_socket, config_data):
     def __init__(self, cc, db_file, slave_socket, config_data):
         self._cc = cc
         self._cc = cc
-        self._socket = slave_socket 
+        self._check_sock = slave_socket 
         self._db_file = db_file
         self._db_file = db_file
         self.update_config_data(config_data)
         self.update_config_data(config_data)
         self._zonemgr_refresh_info = {} 
         self._zonemgr_refresh_info = {} 
         self._build_zonemgr_refresh_info()
         self._build_zonemgr_refresh_info()
+        self._running = False
     
     
     def _random_jitter(self, max, jitter):
     def _random_jitter(self, max, jitter):
         """Imposes some random jitters for refresh and
         """Imposes some random jitters for refresh and
@@ -319,40 +323,82 @@ class ZonemgrRefresh:
 
 
         return False
         return False
 
 
-    def run_timer(self):
-        """Keep track of zone timers."""
-        while True:
-            # Zonemgr has no zone.
+    def _run_timer(self):
+        while self._running:
+            # If zonemgr has no zone, set timer timeout to LOWERBOUND_RETRY.
             if self._zone_mgr_is_empty():
             if self._zone_mgr_is_empty():
-                time.sleep(self._lowerbound_retry) # A better time?
-                continue
-
-            zone_need_refresh = self._find_need_do_refresh_zone()
-            # If don't get zone with minimum next refresh time, set timer timeout = lowerbound_retry 
-            if not zone_need_refresh:
                 timeout = self._lowerbound_retry 
                 timeout = self._lowerbound_retry 
             else:
             else:
-                timeout = self._get_zone_next_refresh_time(zone_need_refresh) - self._get_current_time()
-                if (timeout < 0):
-                    self._do_refresh(zone_need_refresh)
-                    continue
+                zone_need_refresh = self._find_need_do_refresh_zone()
+                # If no zone with a minimum next refresh time is found, set the timer timeout to LOWERBOUND_RETRY
+                if not zone_need_refresh:
+                    timeout = LOWERBOUND_RETRY
+                else:
+                    timeout = self._get_zone_next_refresh_time(zone_need_refresh) - self._get_current_time()
+                    if (timeout < 0):
+                        self._do_refresh(zone_need_refresh)
+                        continue
 
 
             """ Wait for the socket notification for a maximum time of timeout 
             """ Wait for the socket notification for a maximum time of timeout 
             in seconds (as float)."""
             in seconds (as float)."""
             try:
             try:
-                (rlist, wlist, xlist) = select.select([self._socket], [], [], timeout)
-                if rlist:
-                    self._socket.recv(32)
-            except ValueError as e:
-                raise ZonemgrException("[b10-zonemgr] Socket has been closed\n")
-                break
+                rlist, wlist, xlist = select.select([self._check_sock, self._read_sock], [], [], timeout)
             except select.error as e:
             except select.error as e:
                 if e.args[0] == errno.EINTR:
                 if e.args[0] == errno.EINTR:
                     (rlist, wlist, xlist) = ([], [], [])
                     (rlist, wlist, xlist) = ([], [], [])
                 else:
                 else:
-                    raise ZonemgrException("[b10-zonemgr] Error with select(): %s\n" % e)
+                    sys.stderr.write("[b10-zonemgr] Error with select(); %s\n" % e)
                     break
                     break
 
 
+            for fd in rlist:
+                if fd == self._read_sock: # awaken by shutdown socket 
+                    # self._running will be False by now, if it is not a false
+                    # alarm
+                    continue
+                if fd == self._check_sock: # awaken by check socket
+                    self._check_sock.recv(32)
+
+    def run_timer(self, daemon=False):
+        """
+        Keep track of zone timers. Spawns and starts a thread. The thread object is returned.
+
+        You can stop it by calling shutdown().
+        """
+        # Small sanity check
+        if self._running:
+            raise RuntimeError("Trying to run the timers twice at the same time")
+
+        # Prepare the launch
+        self._running = True
+        (self._read_sock, self._write_sock) = socket.socketpair()
+
+        # Start the thread
+        self._thread = threading.Thread(target = self._run_timer, args = ())
+        if daemon:
+            self._thread.setDaemon(True)
+        self._thread.start()
+
+        # Return the thread to anyone interested
+        return self._thread
+
+    def shutdown(self):
+        """
+        Stop the run_timer() thread. Block until it has finished. This must be
+        called from a different thread.
+        """
+        if not self._running:
+            raise RuntimeError("Trying to shutdown, but not running")
+
+        # Ask the thread to stop
+        self._running = False
+        self._write_sock.send(b'shutdown') # make self._read_sock readable
+        # Wait for it to actually finish
+        self._thread.join()
+        # Wipe out what we do not need
+        self._thread = None
+        self._read_sock = None
+        self._write_sock = None
+
     def update_config_data(self, new_config):
     def update_config_data(self, new_config):
         """ update ZonemgrRefresh config """
         """ update ZonemgrRefresh config """
         self._lowerbound_refresh = new_config.get('lowerbound_refresh')
         self._lowerbound_refresh = new_config.get('lowerbound_refresh')
@@ -360,7 +406,6 @@ class ZonemgrRefresh:
         self._max_transfer_timeout = new_config.get('max_transfer_timeout')
         self._max_transfer_timeout = new_config.get('max_transfer_timeout')
         self._jitter_scope = new_config.get('jitter_scope')
         self._jitter_scope = new_config.get('jitter_scope')
 
 
-
 class Zonemgr:
 class Zonemgr:
     """Zone manager class."""
     """Zone manager class."""
     def __init__(self):
     def __init__(self):
@@ -370,16 +415,11 @@ class Zonemgr:
         # Create socket pair for communicating between main thread and zonemgr timer thread 
         # Create socket pair for communicating between main thread and zonemgr timer thread 
         self._master_socket, self._slave_socket = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
         self._master_socket, self._slave_socket = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
         self._zone_refresh = ZonemgrRefresh(self._cc, self._db_file, self._slave_socket, self._config_data)
         self._zone_refresh = ZonemgrRefresh(self._cc, self._db_file, self._slave_socket, self._config_data)
-        self._start_zone_refresh_timer()
+        self._zone_refresh.run_timer()
 
 
         self._lock = threading.Lock()
         self._lock = threading.Lock()
         self._shutdown_event = threading.Event()
         self._shutdown_event = threading.Event()
-
-    def _start_zone_refresh_timer(self):
-        """Start a new thread to keep track of zone timers"""
-        listener = threading.Thread(target = self._zone_refresh.run_timer, args = ())
-        listener.setDaemon(True)
-        listener.start()
+        self.running = False
 
 
     def _setup_session(self):
     def _setup_session(self):
         """Setup two sessions for zonemgr, one(self._module_cc) is used for receiving 
         """Setup two sessions for zonemgr, one(self._module_cc) is used for receiving 
@@ -410,15 +450,12 @@ class Zonemgr:
         """Shutdown the zonemgr process. the thread which is keeping track of zone
         """Shutdown the zonemgr process. the thread which is keeping track of zone
         timers should be terminated.
         timers should be terminated.
         """ 
         """ 
+        self._zone_refresh.shutdown()
+
         self._slave_socket.close()
         self._slave_socket.close()
         self._master_socket.close()
         self._master_socket.close()
-
         self._shutdown_event.set()
         self._shutdown_event.set()
-        main_thread = threading.currentThread()
-        for th in threading.enumerate():
-            if th is main_thread:
-                continue
-            th.join()
+        self.running = False
 
 
     def config_handler(self, new_config):
     def config_handler(self, new_config):
         """ Update config data. """
         """ Update config data. """
@@ -472,21 +509,21 @@ class Zonemgr:
             with self._lock:
             with self._lock:
                 self._zone_refresh.zone_handle_notify(zone_name_class, master)
                 self._zone_refresh.zone_handle_notify(zone_name_class, master)
             # Send notification to zonemgr timer thread
             # Send notification to zonemgr timer thread
-            self._master_socket.send(b" ")
+            self._master_socket.send(b" ")# make self._slave_socket readable
 
 
         elif command == ZONE_XFRIN_SUCCESS_COMMAND:
         elif command == ZONE_XFRIN_SUCCESS_COMMAND:
             """ Handle xfrin success command"""
             """ Handle xfrin success command"""
             zone_name_class = self._parse_cmd_params(args, command)
             zone_name_class = self._parse_cmd_params(args, command)
             with self._lock:
             with self._lock:
                 self._zone_refresh.zone_refresh_success(zone_name_class)
                 self._zone_refresh.zone_refresh_success(zone_name_class)
-            self._master_socket.send(b" ")
+            self._master_socket.send(b" ")# make self._slave_socket readable
 
 
         elif command == ZONE_XFRIN_FAILED_COMMAND:
         elif command == ZONE_XFRIN_FAILED_COMMAND:
             """ Handle xfrin fail command"""
             """ Handle xfrin fail command"""
             zone_name_class = self._parse_cmd_params(args, command)
             zone_name_class = self._parse_cmd_params(args, command)
             with self._lock:
             with self._lock:
                 self._zone_refresh.zone_refresh_fail(zone_name_class)
                 self._zone_refresh.zone_refresh_fail(zone_name_class)
-            self._master_socket.send(b" ")
+            self._master_socket.send(b" ")# make self._slave_socket readable
 
 
         elif command == "shutdown":
         elif command == "shutdown":
             self.shutdown()
             self.shutdown()
@@ -497,6 +534,7 @@ class Zonemgr:
         return answer
         return answer
 
 
     def run(self):
     def run(self):
+        self.running = True
         while not self._shutdown_event.is_set():
         while not self._shutdown_event.is_set():
             self._module_cc.check_command(False)
             self._module_cc.check_command(False)
 
 
@@ -536,6 +574,6 @@ if '__main__' == __name__:
     except isc.config.ModuleCCSessionError as e:
     except isc.config.ModuleCCSessionError as e:
         sys.stderr.write("[b10-zonemgr] exit zonemgr process: %s\n" % str(e))
         sys.stderr.write("[b10-zonemgr] exit zonemgr process: %s\n" % str(e))
 
 
-    if zonemgrd:
+    if zonemgrd and zonemgrd.running:
         zonemgrd.shutdown()
         zonemgrd.shutdown()
 
 

+ 107 - 38
src/lib/python/isc/notify/notify_out.py

@@ -19,6 +19,7 @@ import random
 import socket
 import socket
 import threading
 import threading
 import time
 import time
+import errno
 from isc.datasrc import sqlite3_ds
 from isc.datasrc import sqlite3_ds
 import isc
 import isc
 try: 
 try: 
@@ -44,29 +45,11 @@ _BAD_OPCODE = 3
 _BAD_QR = 4
 _BAD_QR = 4
 _BAD_REPLY_PACKET = 5
 _BAD_REPLY_PACKET = 5
 
 
+SOCK_DATA = b's'
 def addr_to_str(addr):
 def addr_to_str(addr):
     return '%s#%s' % (addr[0], addr[1])
     return '%s#%s' % (addr[0], addr[1])
 
 
-def dispatcher(notifier):
-    '''The loop function for handling notify related events.
-    If one zone get the notify reply before timeout, call the
-    handle to process the reply. If one zone can't get the notify
-    before timeout, call the handler to resend notify or notify 
-    next slave.  
-    notifier: one object of class NotifyOut. '''
-    while True:
-        replied_zones, not_replied_zones = notifier._wait_for_notify_reply()
-        if len(replied_zones) == 0 and len(not_replied_zones) == 0:
-            time.sleep(_IDLE_SLEEP_TIME) #TODO set a better time for idle sleep
-            continue
-
-        for name_ in replied_zones:
-            notifier._zone_notify_handler(replied_zones[name_], _EVENT_READ)
-            
-        for name_ in not_replied_zones:
-            if not_replied_zones[name_].notify_timeout <= time.time():
-                notifier._zone_notify_handler(not_replied_zones[name_], _EVENT_TIMEOUT)
- 
+
 class ZoneNotifyInfo:
 class ZoneNotifyInfo:
     '''This class keeps track of notify-out information for one zone.'''
     '''This class keeps track of notify-out information for one zone.'''
 
 
@@ -115,14 +98,17 @@ class ZoneNotifyInfo:
 
 
 class NotifyOut:
 class NotifyOut:
     '''This class is used to handle notify logic for all zones(sending
     '''This class is used to handle notify logic for all zones(sending
-    notify message to its slaves).The only interface provided to 
-    the user is send_notify(). the object of this class should be 
-    used together with function dispatcher(). '''
+    notify message to its slaves). The notify service can be started by
+    calling dispatcher(), and it can be stopped by calling shutdown()
+    in another thread. '''
     def __init__(self, datasrc_file, log=None, verbose=True):
     def __init__(self, datasrc_file, log=None, verbose=True):
         self._notify_infos = {} # key is (zone_name, zone_class)
         self._notify_infos = {} # key is (zone_name, zone_class)
         self._waiting_zones = []
         self._waiting_zones = []
         self._notifying_zones = []
         self._notifying_zones = []
         self._log = log
         self._log = log
+        self._serving = False
+        self._read_sock, self._write_sock = socket.socketpair()
+        self._read_sock.setblocking(False)
         self.notify_num = 0  # the count of in progress notifies
         self.notify_num = 0  # the count of in progress notifies
         self._verbose = verbose
         self._verbose = verbose
         self._lock = threading.Lock()
         self._lock = threading.Lock()
@@ -165,6 +151,70 @@ class NotifyOut:
                 self.notify_num += 1 
                 self.notify_num += 1 
                 self._notifying_zones.append(zone_id)
                 self._notifying_zones.append(zone_id)
 
 
+    def _dispatcher(self, started_event):
+        started_event.set() # Let the master know we are alive already
+        while self._serving:
+            replied_zones, not_replied_zones = self._wait_for_notify_reply()
+
+            for name_ in replied_zones:
+                self._zone_notify_handler(replied_zones[name_], _EVENT_READ)
+
+            for name_ in not_replied_zones:
+                if not_replied_zones[name_].notify_timeout <= time.time():
+                    self._zone_notify_handler(not_replied_zones[name_], _EVENT_TIMEOUT)
+
+    def dispatcher(self, daemon=False):
+        """Spawns a thread that will handle notify related events.
+
+        If a zone gets the notify reply before the timeout, call the
+        handler to process the reply. If a zone doesn't get the notify
+        reply before the timeout, call the handler to resend the notify
+        or to notify the next slave.
+
+        The thread can be stopped by calling shutdown().
+
+        Returns the thread object to anyone interested.
+        """
+
+        if self._serving:
+            raise RuntimeError(
+                'Dispatcher already running, tried to start twice')
+
+        # Prepare for launch
+        self._serving = True
+        started_event = threading.Event()
+
+        # Start
+        self._thread = threading.Thread(target=self._dispatcher,
+            args=[started_event])
+        if daemon:
+            self._thread.daemon = daemon
+        self._thread.start()
+
+        # Wait for it to get started
+        started_event.wait()
+
+        # Return it to anyone listening
+        return self._thread
+
+    def shutdown(self):
+        """Stop the dispatcher() thread. Blocks until the thread stopped."""
+
+        if not self._serving:
+            raise RuntimeError('Tried to stop while not running')
+
+        # Ask it to stop
+        self._serving = False
+        self._write_sock.send(SOCK_DATA) # make self._read_sock be readable.
+
+        # Wait for it
+        self._thread.join()
+
+        # Clean up
+        self._write_sock = None
+        self._read_sock = None
+        self._thread = None
+
     def _get_rdata_data(self, rr):
     def _get_rdata_data(self, rr):
         return rr[7].strip()
         return rr[7].strip()
 
 
@@ -200,49 +250,68 @@ class NotifyOut:
         return addr_list
         return addr_list
 
 
     def _prepare_select_info(self):
     def _prepare_select_info(self):
-        '''Prepare the information for select(), returned 
-        value is one tuple 
+        '''
+        Prepare the information for select(), returned
+        value is one tuple
         (block_timeout, valid_socks, notifying_zones)
         (block_timeout, valid_socks, notifying_zones)
         block_timeout: the timeout for select()
         block_timeout: the timeout for select()
         valid_socks: sockets list for waiting ready reading.
         valid_socks: sockets list for waiting ready reading.
-        notifying_zones: the zones which have been triggered 
-                        for notify. '''
+        notifying_zones: the zones which have been triggered
+                        for notify.
+        '''
         valid_socks = []
         valid_socks = []
         notifying_zones = {}
         notifying_zones = {}
-        min_timeout = None 
+        min_timeout = None
         for info in self._notify_infos:
         for info in self._notify_infos:
             sock = self._notify_infos[info].get_socket()
             sock = self._notify_infos[info].get_socket()
             if sock:
             if sock:
                 valid_socks.append(sock)
                 valid_socks.append(sock)
                 notifying_zones[info] = self._notify_infos[info]
                 notifying_zones[info] = self._notify_infos[info]
                 tmp_timeout = self._notify_infos[info].notify_timeout
                 tmp_timeout = self._notify_infos[info].notify_timeout
-                if min_timeout:
+                if min_timeout is not None:
                     if tmp_timeout < min_timeout:
                     if tmp_timeout < min_timeout:
                         min_timeout = tmp_timeout
                         min_timeout = tmp_timeout
                 else:
                 else:
                     min_timeout = tmp_timeout
                     min_timeout = tmp_timeout
-       
-        block_timeout = 0
-        if min_timeout:
+
+        block_timeout = _IDLE_SLEEP_TIME
+        if min_timeout is not None:
             block_timeout = min_timeout - time.time()
             block_timeout = min_timeout - time.time()
             if block_timeout < 0:
             if block_timeout < 0:
                 block_timeout = 0
                 block_timeout = 0
-        
+
         return (block_timeout, valid_socks, notifying_zones)
         return (block_timeout, valid_socks, notifying_zones)
 
 
     def _wait_for_notify_reply(self):
     def _wait_for_notify_reply(self):
-        '''receive notify replies in specified time. returned value 
-        is one tuple:(replied_zones, not_replied_zones)
+        '''
+        Receive notify replies within the specified time. The returned
+        value is a tuple: (replied_zones, not_replied_zones). ({}, {}) is
+        returned if shutdown() was called.
+
         replied_zones: the zones which receive notify reply.
         replied_zones: the zones which receive notify reply.
         not_replied_zones: the zones which haven't got notify reply.
         not_replied_zones: the zones which haven't got notify reply.
+
         '''
         '''
-        (block_timeout, valid_socks, notifying_zones) = self._prepare_select_info()
+        (block_timeout, valid_socks, notifying_zones) = \
+            self._prepare_select_info()
+        # This is None only during some tests
+        if self._read_sock is not None:
+            valid_socks.append(self._read_sock)
         try:
         try:
             r_fds, w, e = select.select(valid_socks, [], [], block_timeout)
             r_fds, w, e = select.select(valid_socks, [], [], block_timeout)
         except select.error as err:
         except select.error as err:
             if err.args[0] != EINTR:
             if err.args[0] != EINTR:
-                return [], []
-        
+                return {}, {}
+
+        if self._read_sock in r_fds: # user has called shutdown()
+            try:
+                # No one should write anything other than shutdown
+                assert self._read_sock.recv(len(SOCK_DATA)) == SOCK_DATA
+                return {}, {}
+            except socket.error as e: # Workaround for a rare Linux bug
+                if e.errno != errno.EAGAIN and e.errno != errno.EWOULDBLOCK:
+                    raise
+
         not_replied_zones = {}
         not_replied_zones = {}
         replied_zones = {}
         replied_zones = {}
         for info in notifying_zones:
         for info in notifying_zones:

+ 22 - 5
src/lib/python/isc/notify/tests/notify_out_test.py

@@ -20,7 +20,7 @@ import tempfile
 import time
 import time
 import socket
 import socket
 from isc.datasrc import sqlite3_ds
 from isc.datasrc import sqlite3_ds
-from isc.notify import notify_out
+from isc.notify import notify_out, SOCK_DATA
 
 
 class TestZoneNotifyInfo(unittest.TestCase):
 class TestZoneNotifyInfo(unittest.TestCase):
     def setUp(self):
     def setUp(self):
@@ -53,8 +53,6 @@ class TestZoneNotifyInfo(unittest.TestCase):
 
 
 class TestNotifyOut(unittest.TestCase):
 class TestNotifyOut(unittest.TestCase):
     def setUp(self):
     def setUp(self):
-        self.old_stdout = sys.stdout
-        sys.stdout = open(os.devnull, 'w')
         self._db_file = tempfile.NamedTemporaryFile(delete=False)
         self._db_file = tempfile.NamedTemporaryFile(delete=False)
         sqlite3_ds.load(self._db_file.name, 'cn.', self._cn_data_reader)
         sqlite3_ds.load(self._db_file.name, 'cn.', self._cn_data_reader)
         sqlite3_ds.load(self._db_file.name, 'com.', self._com_data_reader)
         sqlite3_ds.load(self._db_file.name, 'com.', self._com_data_reader)
@@ -70,7 +68,6 @@ class TestNotifyOut(unittest.TestCase):
         info.notify_slaves.append(('1.1.1.1', 5353))
         info.notify_slaves.append(('1.1.1.1', 5353))
 
 
     def tearDown(self):
     def tearDown(self):
-        sys.stdout = self.old_stdout
         self._db_file.close()
         self._db_file.close()
         os.unlink(self._db_file.name)
         os.unlink(self._db_file.name)
 
 
@@ -123,6 +120,20 @@ class TestNotifyOut(unittest.TestCase):
         self.assertTrue(('com.', 'IN') in timeout_zones.keys())
         self.assertTrue(('com.', 'IN') in timeout_zones.keys())
         self.assertLess(time.time(), self._notify._notify_infos[('com.', 'IN')].notify_timeout)
         self.assertLess(time.time(), self._notify._notify_infos[('com.', 'IN')].notify_timeout)
     
     
+    def test_wait_for_notify_reply_2(self):
+        # Test the returned value when the read_side socket is readable.
+        self._notify.send_notify('cn.')
+        self._notify.send_notify('com.')
+
+        # Now make one socket be readable
+        self._notify._notify_infos[('cn.', 'IN')].notify_timeout = time.time() + 10
+        self._notify._notify_infos[('com.', 'IN')].notify_timeout = time.time() + 10
+        self._notify._read_sock, self._notify._write_sock = socket.socketpair()
+        self._notify._write_sock.send(SOCK_DATA)
+        replied_zones, timeout_zones = self._notify._wait_for_notify_reply()
+        self.assertEqual(0, len(replied_zones))
+        self.assertEqual(0, len(timeout_zones))
+
     def test_notify_next_target(self):
     def test_notify_next_target(self):
         self._notify.send_notify('cn.')
         self._notify.send_notify('cn.')
         self._notify.send_notify('com.')
         self._notify.send_notify('com.')
@@ -258,7 +269,7 @@ class TestNotifyOut(unittest.TestCase):
         
         
     def test_prepare_select_info(self):
     def test_prepare_select_info(self):
         timeout, valid_fds, notifying_zones = self._notify._prepare_select_info()
         timeout, valid_fds, notifying_zones = self._notify._prepare_select_info()
-        self.assertEqual(0, timeout)
+        self.assertEqual(notify_out._IDLE_SLEEP_TIME, timeout)
         self.assertListEqual([], valid_fds)
         self.assertListEqual([], valid_fds)
 
 
         self._notify._notify_infos[('cn.', 'IN')]._sock = 1
         self._notify._notify_infos[('cn.', 'IN')]._sock = 1
@@ -279,6 +290,12 @@ class TestNotifyOut(unittest.TestCase):
         self.assertEqual(timeout, 0)
         self.assertEqual(timeout, 0)
         self.assertListEqual([2, 1], valid_fds)
         self.assertListEqual([2, 1], valid_fds)
 
 
+    def test_shutdown(self):
+        thread = self._notify.dispatcher()
+        self.assertTrue(thread.is_alive())
+        self._notify.shutdown()
+        self.assertFalse(thread.is_alive())
+
 if __name__== "__main__":
 if __name__== "__main__":
     unittest.main()
     unittest.main()