Browse Source

fix #330: zonemgr exception on SIGINT

git-svn-id: svn://bind10.isc.org/svn/bind10/branches/trac335@3061 e5f2f494-b856-4b98-b285-d166d9295462
Jerry 14 years ago
parent
commit
ddf21e206a
2 changed files with 61 additions and 27 deletions
  1. 24 11
      src/bin/zonemgr/tests/zonemgr_test.py
  2. 37 16
      src/bin/zonemgr/zonemgr.py.in

+ 24 - 11
src/bin/zonemgr/tests/zonemgr_test.py

@@ -42,6 +42,9 @@ class MyZonemgrRefresh(ZonemgrRefresh):
     def __init__(self):
         self._cc = MySession()
         self._db_file = "initdb.file"
+        self._is_shut_down = threading.Event()
+        self._read_sock, self._write_sock = socket.socketpair()
+        self._master_socket, self._slave_socket = socket.socketpair()
         self._zonemgr_refresh_info = { 
          ('sd.cn.', 'IN'): {
          'last_refresh_time': 1280474398.822142,
@@ -54,11 +57,11 @@ class MyZonemgrRefresh(ZonemgrRefresh):
          'zone_soa_rdata': 'a.dns.cn. root.cnnic.cn. 2009073112 7200 3600 2419200 21600', 
          'zone_state': 0}
         } 
-
+        
 class TestZonemgrRefresh(unittest.TestCase):
     def setUp(self):
-        self.stdout_backup = sys.stdout
-        sys.stdout = open(os.devnull, 'w')
+        self.stderr_backup = sys.stderr
+        sys.stderr = open(os.devnull, 'w')
         self.zone_refresh = MyZonemgrRefresh()
 
     def test_random_jitter(self):
@@ -387,24 +390,34 @@ class TestZonemgrRefresh(unittest.TestCase):
                     'zone_soa_rdata': 'a.dns.cn. root.cnnic.cn. 2009073105 7200 3600 2419200 21600', 
                     'zone_state': ZONE_OK}
                 }
-        master_socket, slave_socket = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
-        self.zone_refresh._socket = master_socket 
-        master_socket.close()
-        self.assertRaises(ZonemgrException, self.zone_refresh.run_timer)
-
-        self.zone_refresh._socket = slave_socket
+        self.zone_refresh._check_sock = self.zone_refresh._master_socket 
         listener = threading.Thread(target = self.zone_refresh.run_timer, args = ())
         listener.setDaemon(True)
         listener.start()
         time.sleep(1)
-
+        self.zone_refresh.shutdown()
+        self.assertFalse(listener.is_alive())
+        # After running timer, the zone's state should become "refreshing".
         zone_state = self.zone_refresh._zonemgr_refresh_info[ZONE_NAME_CLASS1_IN]["zone_state"]
         self.assertTrue("refresh_timeout" in self.zone_refresh._zonemgr_refresh_info[ZONE_NAME_CLASS1_IN].keys())
         self.assertTrue(zone_state == ZONE_REFRESHING)
 
+        # test select.error by using bad file descriptor 
+        bad_file_descriptor = self.zone_refresh._master_socket.fileno()
+        self.zone_refresh._check_sock = bad_file_descriptor  
+        self.zone_refresh._master_socket.close()
+        self.assertRaises(None, self.zone_refresh.run_timer()) 
+
+    def test_shutdown(self):
+        self.zone_refresh._check_sock = self.zone_refresh._master_socket 
+        listener = threading.Thread(target=self.zone_refresh.run_timer)
+        listener.start()
+        self.assertTrue(listener.is_alive())
+        self.zone_refresh.shutdown()
+        self.assertFalse(listener.is_alive())
 
     def tearDown(self):
-        sys.stdout = self.stdout_backup
+        sys.stderr= self.stderr_backup
 
 
 class MyCCSession():

+ 37 - 16
src/bin/zonemgr/zonemgr.py.in

@@ -94,12 +94,18 @@ class ZonemgrException(Exception):
 class ZonemgrRefresh:
     """This class will maintain and manage zone refresh info.
     It also provides methods to keep track of zone timers and 
-    do zone refresh.
+    do zone refresh. 
+    Zone timers can be started by calling run_timer(), and it 
+    can be stopped by calling shutdown() in another thread.
+
     """
 
     def __init__(self, cc, db_file, slave_socket):
         self._cc = cc
-        self._socket = slave_socket 
+        self._check_sock = slave_socket 
+        self._runnable = False
+        self._is_shut_down = threading.Event()
+        self._read_sock, self._write_sock = socket.socketpair()
         self._db_file = db_file
         self._zonemgr_refresh_info = {} 
         self._build_zonemgr_refresh_info()
@@ -328,8 +334,12 @@ class ZonemgrRefresh:
         return False
 
     def run_timer(self):
-        """Keep track of zone timers."""
-        while True:
+        """Keep track of zone timers. The loop can be stopped by calling shutdown() in 
+        another thread.
+        """
+        self._runnable = True
+        self._is_shut_down.clear()
+        while self._runnable:
             # Zonemgr has no zone.
             if self._zone_mgr_is_empty():
                 time.sleep(LOWERBOUND_RETRY) # A better time?
@@ -348,19 +358,29 @@ class ZonemgrRefresh:
             """ Wait for the socket notification for a maximum time of timeout 
             in seconds (as float)."""
             try:
-                (rlist, wlist, xlist) = select.select([self._socket], [], [], timeout)
-                if rlist:
-                    self._socket.recv(32)
-            except ValueError as e:
-                raise ZonemgrException("[b10-zonemgr] Socket has been closed\n")
-                break
+                rlist, wlist, xlist = select.select([self._check_sock, self._read_sock], [], [], timeout)
             except select.error as e:
                 if e.args[0] == errno.EINTR:
                     (rlist, wlist, xlist) = ([], [], [])
                 else:
-                    raise ZonemgrException("[b10-zonemgr] Error with select(): %s\n" % e)
+                    sys.stderr.write("[b10-zonemgr] Error with select(); %s\n" % e)
                     break
 
+            if not rlist: # timer timeout 
+                continue
+            if self._read_sock in rlist: # awaken by shutdown socket 
+                break 
+            if self._check_sock in rlist: # awaken by check socket
+                self._check_sock.recv(5)
+
+        self._is_shut_down.set()
+
+    def shutdown(self):
+        """Stop the run_timer() loop. Block until the loop has finished. This must be
+        called when run_timer() is running in another thread, or it will deadlock."""
+        self._runnable = False
+        self._write_sock.send(b'shutdown') # make self._read_sock readble
+        self._is_shut_down.wait()
 
 class Zonemgr:
     """Zone manager class."""
@@ -378,7 +398,6 @@ class Zonemgr:
     def _start_zone_refresh_timer(self):
         """Start a new thread to keep track of zone timers"""
         listener = threading.Thread(target = self._zone_refresh.run_timer, args = ())
-        listener.setDaemon(True)
         listener.start()
 
     def _setup_session(self):
@@ -406,12 +425,14 @@ class Zonemgr:
         """Shutdown the zonemgr process. the thread which is keeping track of zone
         timers should be terminated.
         """ 
+        self._zone_refresh.shutdown()
+
         self._slave_socket.close()
         self._master_socket.close()
-
         self._shutdown_event.set()
         main_thread = threading.currentThread()
         for th in threading.enumerate():
+        # Stop the thread  which is running zone refresh timer
             if th is main_thread:
                 continue
             th.join()
@@ -459,21 +480,21 @@ class Zonemgr:
             with self._lock:
                 self._zone_refresh.zone_handle_notify(zone_name_class, master)
             # Send notification to zonemgr timer thread
-            self._master_socket.send(b" ")
+            self._master_socket.send(b" ")# make self._slave_socket readble
 
         elif command == ZONE_XFRIN_SUCCESS_COMMAND:
             """ Handle xfrin success command"""
             zone_name_class = self._parse_cmd_params(args, command)
             with self._lock:
                 self._zone_refresh.zone_refresh_success(zone_name_class)
-            self._master_socket.send(b" ")
+            self._master_socket.send(b" ")# make self._slave_socket readble
 
         elif command == ZONE_XFRIN_FAILED_COMMAND:
             """ Handle xfrin fail command"""
             zone_name_class = self._parse_cmd_params(args, command)
             with self._lock:
                 self._zone_refresh.zone_refresh_fail(zone_name_class)
-            self._master_socket.send(b" ")
+            self._master_socket.send(b" ")# make self._slave_socket readble
 
         elif command == "shutdown":
             self.shutdown()