Browse Source

[2380] handle signals

JINMEI Tatuya 12 years ago
parent
commit
0785c84f4f
2 changed files with 61 additions and 3 deletions
  1. 19 1
      src/bin/loadzone/loadzone.py.in
  2. 42 2
      src/bin/loadzone/tests/loadzone_test.py

+ 19 - 1
src/bin/loadzone/loadzone.py.in

@@ -18,6 +18,7 @@
 import sys
 import sys
 sys.path.append('@@PYTHONPATH@@')
 sys.path.append('@@PYTHONPATH@@')
 import time
 import time
+import signal
 from optparse import OptionParser
 from optparse import OptionParser
 from isc.dns import *
 from isc.dns import *
 from isc.datasrc import *
 from isc.datasrc import *
@@ -85,6 +86,7 @@ class LoadZoneRunner:
     def __init__(self, command_args):
     def __init__(self, command_args):
         self.__command_args = command_args
         self.__command_args = command_args
         self.__loaded_rrs = 0
         self.__loaded_rrs = 0
+        self.__interrupted = False # will be set to True on receiving signal
 
 
         # system-wide log configuration.  We need to configure logging this
         # system-wide log configuration.  We need to configure logging this
         # way so that the logging policy applies to underlying libraries, too.
         # way so that the logging policy applies to underlying libraries, too.
@@ -199,6 +201,11 @@ class LoadZoneRunner:
                         [self._zone_name.to_text()])
                         [self._zone_name.to_text()])
 
 
     def _report_progress(self, loaded_rrs):
     def _report_progress(self, loaded_rrs):
+        '''Dump the current progress report to stdout.
+
+        This is essentially private, but defined as "protected" for tests.
+
+        '''
         elapsed = time.time() - self.__start_time
         elapsed = time.time() - self.__start_time
         sys.stdout.write("\r" + (80 * " "))
         sys.stdout.write("\r" + (80 * " "))
         sys.stdout.write("\r%d RRs loaded in %.2f seconds" %
         sys.stdout.write("\r%d RRs loaded in %.2f seconds" %
@@ -225,10 +232,13 @@ class LoadZoneRunner:
                 limit = self._load_iteration_limit
                 limit = self._load_iteration_limit
             else:
             else:
                 limit = LOAD_INTERVAL_DEFAULT
                 limit = LOAD_INTERVAL_DEFAULT
-            while not loader.load_incremental(limit):
+            while (not self.__interrupted and
+                   not loader.load_incremental(limit)):
                 self.__loaded_rrs += self._load_iteration_limit
                 self.__loaded_rrs += self._load_iteration_limit
                 if self._load_iteration_limit > 0:
                 if self._load_iteration_limit > 0:
                     self._report_progress(self.__loaded_rrs)
                     self._report_progress(self.__loaded_rrs)
+            if self.__interrupted:
+                raise LoadFailure('loading interrupted by signal')
         except Exception as ex:
         except Exception as ex:
             # release any remaining lock held in the client/loader
             # release any remaining lock held in the client/loader
             loader, datasrc_client = None, None
             loader, datasrc_client = None, None
@@ -260,10 +270,18 @@ class LoadZoneRunner:
         logger.warn(LOADZONE_POSTLOAD_ISSUE, self._zone_name,
         logger.warn(LOADZONE_POSTLOAD_ISSUE, self._zone_name,
                     self._zone_class, msg)
                     self._zone_class, msg)
 
 
+    def _set_signal_handlers(self):
+        signal.signal(signal.SIGINT, self._interrupt_handler)
+        signal.signal(signal.SIGTERM, self._interrupt_handler)
+
+    def _interrupt_handler(self, signal, frame):
+        self.__interrupted = True
+
     def run(self):
     def run(self):
         '''Top-level method, simply calling other helpers'''
         '''Top-level method, simply calling other helpers'''
 
 
         try:
         try:
+            self._set_signal_handlers()
             self._parse_args()
             self._parse_args()
             self._do_load()
             self._do_load()
             logger.info(LOADZONE_DONE, self._zone_name, self._zone_class)
             logger.info(LOADZONE_DONE, self._zone_name, self._zone_class)

+ 42 - 2
src/bin/loadzone/tests/loadzone_test.py

@@ -243,7 +243,7 @@ class TestLoadZoneRunner(unittest.TestCase):
         self.__runner._do_load()
         self.__runner._do_load()
         self.__runner._post_load_checks()
         self.__runner._post_load_checks()
 
 
-    def test_load_fail_create_cancel(self):
+    def test_load_post_check_fail_soa(self):
         '''Load succeeds but warns about missing SOA, should cause warn'''
         '''Load succeeds but warns about missing SOA, should cause warn'''
         self.__common_load_setup()
         self.__common_load_setup()
         self.__common_post_load_setup(LOCAL_TESTDATA_PATH +
         self.__common_post_load_setup(LOCAL_TESTDATA_PATH +
@@ -252,7 +252,7 @@ class TestLoadZoneRunner(unittest.TestCase):
         self.assertEqual(1, len(self.__warnings))
         self.assertEqual(1, len(self.__warnings))
         self.assertEqual('zone has no SOA', self.__warnings[0])
         self.assertEqual('zone has no SOA', self.__warnings[0])
 
 
-    def test_load_fail_create_cancel(self):
+    def test_load_post_check_fail_ns(self):
         '''Load succeeds but warns about missing NS, should cause warn'''
         '''Load succeeds but warns about missing NS, should cause warn'''
         self.__common_load_setup()
         self.__common_load_setup()
         self.__common_post_load_setup(LOCAL_TESTDATA_PATH +
         self.__common_post_load_setup(LOCAL_TESTDATA_PATH +
@@ -261,6 +261,43 @@ class TestLoadZoneRunner(unittest.TestCase):
         self.assertEqual(1, len(self.__warnings))
         self.assertEqual(1, len(self.__warnings))
         self.assertEqual('zone has no NS', self.__warnings[0])
         self.assertEqual('zone has no NS', self.__warnings[0])
 
 
+    def __interrupt_progress(self, loaded_rrs):
+        '''A helper emulating a signal in the middle of loading.
+
+        On the second progress report, it internally invokes the signal
+        handler to see if it stops the loading.
+
+        '''
+        self.__reports.append(loaded_rrs)
+        if len(self.__reports) == 2:
+            self.__runner._interrupt_handler()
+
+    def test_load_interrupted(self):
+        '''Load attempt fails due to signal interruption'''
+        self.__common_load_setup()
+        self.__runner._report_progress = lambda x: self.__interrupt_progress(x)
+        # The interrupting _report_progress() will terminate the loading
+        # in the middle.  the number of reports is smaller, and the zone
+        # won't be changed.
+        self.assertRaises(LoadFailure, self.__runner._do_load)
+        self.assertEqual([1, 2], self.__reports)
+        self.__check_zone_soa(ORIG_SOA_TXT)
+
+    def test_load_interrupted_create_cancel(self):
+        '''Load attempt for a new zone fails due to signal interruption
+
+        It cancels the zone creation.
+
+        '''
+        self.__common_load_setup()
+        self.__runner._report_progress = lambda x: self.__interrupt_progress(x)
+        self.__runner._zone_name = Name('example.com')
+        self.__runner._zone_file = ALT_NEW_ZONE_TXT_FILE
+        self.__check_zone_soa(None, zone_name=Name('example.com'))
+        self.assertRaises(LoadFailure, self.__runner._do_load)
+        self.assertEqual([1, 2], self.__reports)
+        self.__check_zone_soa(None, zone_name=Name('example.com'))
+
     def test_run_success(self):
     def test_run_success(self):
         '''Check for the top-level method.
         '''Check for the top-level method.
 
 
@@ -291,4 +328,7 @@ if __name__== "__main__":
     # Disable the internal logging setup so the test output won't be too
     # Disable the internal logging setup so the test output won't be too
     # verbose by default.
     # verbose by default.
     LoadZoneRunner._config_log = lambda x: None
     LoadZoneRunner._config_log = lambda x: None
+
+    # Cancel signal handlers so we can stop tests when they hang
+    LoadZoneRunner._set_signal_handlers = lambda x: None
     unittest.main()
     unittest.main()