Browse Source

[2967] cleanup and slight refactoring

Paul Selkirk 12 years ago
parent
commit
338cfe2874
2 changed files with 50 additions and 92 deletions
  1. 14 20
      src/bin/zonemgr/tests/zonemgr_test.py
  2. 36 72
      src/bin/zonemgr/zonemgr.py.in

+ 14 - 20
src/bin/zonemgr/tests/zonemgr_test.py

@@ -40,8 +40,6 @@ RELOAD_JITTER = 0.75
 rdata_net = 'a.example.net. root.example.net. 2009073106 7200 3600 2419200 21600'
 rdata_org = 'a.example.org. root.example.org. 2009073112 7200 3600 2419200 21600'
 
-TEST_SQLITE3_DBFILE = os.getenv("TESTDATAOBJDIR") + '/initdb.file'
-
 class ZonemgrTestException(Exception):
     pass
 
@@ -50,20 +48,16 @@ class FakeCCSession(isc.config.ConfigData, MockModuleCCSession):
         module_spec = isc.config.module_spec_from_file(SPECFILE_LOCATION)
         ConfigData.__init__(self, module_spec)
         MockModuleCCSession.__init__(self)
+        # For inspection
+        self.added_remote_modules = []
 
     def add_remote_config_by_name(self, name, callback):
-        pass
+        self.added_remote_modules.append((name, callback))
 
     def rpc_call(self, command, module, instance="*", to="*", params=None):
         if module not in ("Auth", "Xfrin"):
             raise ZonemgrTestException("module name not exist")
 
-    def get_remote_config_value(self, module_name, identifier):
-        if module_name == "Auth" and identifier == "database_file":
-            return TEST_SQLITE3_DBFILE, False
-        else:
-            return "unknown", False
-
 class MockDataSourceClient():
     '''A simple mock data source client.'''
     def find_zone(self, zone_name):
@@ -132,19 +126,23 @@ class MyZonemgrRefresh(ZonemgrRefresh):
 
 class TestZonemgrRefresh(unittest.TestCase):
     def setUp(self):
-        if os.path.exists(TEST_SQLITE3_DBFILE):
-            os.unlink(TEST_SQLITE3_DBFILE)
         self.stderr_backup = sys.stderr
         sys.stderr = open(os.devnull, 'w')
         self.zone_refresh = MyZonemgrRefresh()
         self.cc_session = FakeCCSession()
 
     def tearDown(self):
-        if os.path.exists(TEST_SQLITE3_DBFILE):
-            os.unlink(TEST_SQLITE3_DBFILE)
         sys.stderr.close()
         sys.stderr = self.stderr_backup
 
+    def test_init(self):
+        """Check some initial configuration after construction"""
+        # data source "module" should have been registrered as a necessary
+        # remote config
+        self.assertEqual([('data_sources',
+                           self.zone_refresh._datasrc_config_handler)],
+                         self.zone_refresh._module_cc.added_remote_modules)
+
     def test_random_jitter(self):
         max = 100025.120
         jitter = 0
@@ -332,7 +330,7 @@ class TestZonemgrRefresh(unittest.TestCase):
         rdata_net = old_rdata_net
 
         old_get_zone_soa = self.zone_refresh._get_zone_soa
-        def get_zone_soa2(zone_name, db_file):
+        def get_zone_soa2(zone_name_class):
             return None
         self.zone_refresh._get_zone_soa = get_zone_soa2
         self.zone_refresh.zonemgr_add_zone(ZONE_NAME_CLASS2_IN)
@@ -400,7 +398,7 @@ class TestZonemgrRefresh(unittest.TestCase):
         self.assertRaises(ZonemgrException, self.zone_refresh.zone_refresh_fail, ZONE_NAME_CLASS3_IN)
 
         old_get_zone_soa = self.zone_refresh._get_zone_soa
-        def get_zone_soa(zone_name, db_file):
+        def get_zone_soa(zone_name_class):
             return None
         self.zone_refresh._get_zone_soa = get_zone_soa
         self.zone_refresh.zone_refresh_fail(ZONE_NAME_CLASS1_IN)
@@ -654,7 +652,6 @@ class MyZonemgr(Zonemgr):
         def __exit__(self, type, value, traceback): pass
 
     def __init__(self):
-        self._db_file = TEST_SQLITE3_DBFILE
         self._zone_refresh = None
         self._shutdown_event = threading.Event()
         self._module_cc = FakeCCSession()
@@ -675,13 +672,10 @@ class MyZonemgr(Zonemgr):
 class TestZonemgr(unittest.TestCase):
 
     def setUp(self):
-        if os.path.exists(TEST_SQLITE3_DBFILE):
-            os.unlink(TEST_SQLITE3_DBFILE)
         self.zonemgr = MyZonemgr()
 
     def tearDown(self):
-        if os.path.exists(TEST_SQLITE3_DBFILE):
-            os.unlink(TEST_SQLITE3_DBFILE)
+        pass
 
     def test_config_handler(self):
         config_data1 = {

+ 36 - 72
src/bin/zonemgr/zonemgr.py.in

@@ -34,7 +34,6 @@ import threading
 import select
 import socket
 import errno
-from isc.datasrc import sqlite3_ds
 from optparse import OptionParser, OptionValueError
 from isc.config.ccsession import *
 import isc.util.process
@@ -117,8 +116,7 @@ class ZonemgrRefresh:
         self._max_transfer_timeout = None
         self._refresh_jitter = None
         self._reload_jitter = None
-        self.update_config_data(module_cc.get_full_config(),
-                                module_cc)
+        self.update_config_data(module_cc.get_full_config(), module_cc)
         self._running = False
         # This is essentially private, but we allow tests to customize it.
         self._datasrc_clients_mgr = DataSrcClientsMgr()
@@ -218,11 +216,6 @@ class ZonemgrRefresh:
         zone; the Auth module should have rejected the case where it's not
         even authoritative for the zone.
 
-        Note: to be more robust and less independent from other module's
-        behavior, it's probably safer to check the authority condition here,
-        too.  But right now it uses SQLite3 specific API (to be deprecated),
-        so we rather rely on Auth.
-
         Parameters:
         zone_name_class (Name, RRClass): the notified zone name and class.
         master (str): textual address of the NOTIFY sender.
@@ -239,14 +232,14 @@ class ZonemgrRefresh:
     def zonemgr_reload_zone(self, zone_name_class):
         """ Reload a zone."""
         self._zonemgr_refresh_info[zone_name_class]["zone_soa_rdata"] = \
-            self._get_zone_soa(zone_name_class[0], zone_name_class[1])
+            self._get_zone_soa(zone_name_class)
 
     def zonemgr_add_zone(self, zone_name_class):
         """ Add a zone into zone manager."""
         logger.debug(DBG_ZONEMGR_BASIC, ZONEMGR_LOAD_ZONE, zone_name_class[0],
                      zone_name_class[1])
         zone_info = {}
-        zone_soa = self._get_zone_soa(zone_name_class[0], zone_name_class[1])
+        zone_soa = self._get_zone_soa(zone_name_class)
         if zone_soa is None:
             logger.warn(ZONEMGR_NO_SOA, zone_name_class[0], zone_name_class[1])
             zone_info["zone_soa_rdata"] = None
@@ -263,24 +256,48 @@ class ZonemgrRefresh:
         self._set_zone_timer(zone_name_class, zone_reload_time,
                              self._reload_jitter * zone_reload_time)
 
-    def _get_zone_soa(self, zone_name, zone_class):
+    def _get_zone_soa(self, zone_name_class):
         """Retrieve the current SOA RR of the zone to be transferred."""
+
+        def get_zone_soa_rrset(datasrc_client, zone_name, zone_class):
+            """Retrieve the current SOA RR of the zone to be transferred."""
+            def format_zone_str(zone_name, zone_class):
+                """Helper function to format a zone name and class as a string
+                of the form '<name>/<class>'.
+                Parameters:
+                zone_name (isc.dns.Name) name to format
+                zone_class (isc.dns.RRClass) class to format
+                """
+                return zone_name.to_text(True) + '/' + str(zone_class)
+            # get the zone finder.  this must be SUCCESS (not even
+            # PARTIALMATCH) because we are specifying the zone origin name.
+            result, finder = datasrc_client.find_zone(zone_name)
+            if result != DataSourceClient.SUCCESS:
+                # The data source doesn't know the zone.  In the context in
+                # which this function is called, this shouldn't happen.
+                raise ZonemgrException("unexpected result: zone %s doesn't exist" %
+                                       format_zone_str(zone_name, zone_class))
+            result, soa_rrset, _ = finder.find(zone_name, RRType.SOA)
+            if result != ZoneFinder.SUCCESS:
+                logger.info(ZONEMGR_NO_SOA, format_zone_str(zone_name, zone_class))
+                return None
+            return soa_rrset
+
         # Identify the data source to which the zone content is transferred,
         # and get the current zone SOA from the data source (if available).
-        # Note that we do this before spawning the zonemgr session thread.
-        # find() on the client list and use of ZoneFinder (in _get_zone_soa())
-        # should be completed within the same single thread.
         datasrc_client = None
-        clist = self._datasrc_clients_mgr.get_client_list(zone_class)
+        clist = self._datasrc_clients_mgr.get_client_list(zone_name_class[1])
         if clist is None:
             return None
         try:
-            datasrc_client = clist.find(zone_name, True, False)[0]
+            datasrc_client = clist.find(zone_name_class[0], True, False)[0]
             if datasrc_client is None: # can happen, so log it separately.
                 logger.error(ZONEMGR_DATASRC_UNKNOWN,
-                             format_zone_str(zone_name, zone_class))
+                             zone_name_class[0] + '/' + zone_name_class[1])
                 return None
-            zone_soa = _get_zone_soa(datasrc_client, Name(zone_name), RRClass(zone_class))
+            zone_soa = get_zone_soa_rrset(datasrc_client,
+                                          Name(zone_name_class[0]),
+                                          RRClass(zone_name_class[1]))
             if (zone_soa == None):
                 return None
             else:
@@ -289,7 +306,7 @@ class ZonemgrRefresh:
             # rare case error. re-raise as ZonemgrException so it'll be logged
             # in command_handler().
             raise ZonemgrException('unexpected failure in datasrc module: ' +
-                                 str(ex))
+                                   str(ex))
 
     def _zone_is_expired(self, zone_name_class):
         """Judge whether a zone is expired or not."""
@@ -777,59 +794,6 @@ class Zonemgr:
         finally:
             self._module_cc.send_stopping()
 
-# XXX copy from xfrin for now
-def format_zone_str(zone_name, zone_class):
-    """Helper function to format a zone name and class as a string of
-       the form '<name>/<class>'.
-       Parameters:
-       zone_name (isc.dns.Name) name to format
-       zone_class (isc.dns.RRClass) class to format
-    """
-    return zone_name.to_text(True) + '/' + str(zone_class)
-
-# XXX copy from xfrin for now
-def _get_zone_soa(datasrc_client, zone_name, zone_class):
-    """Retrieve the current SOA RR of the zone to be transferred.
-
-    This function is essentially private to the module, but will also
-    be called (or tweaked) from tests; no one else should use this
-    function directly.
-
-    The specified zone is expected to exist in the data source referenced
-    by the given datasrc_client at the point of the call to this function.
-    If this is not met ZonemgrException exception will be raised.
-
-    It will be used for various purposes in subsequent xfr protocol
-    processing.   It is validly possible that the zone is currently
-    empty and therefore doesn't have an SOA, so this method doesn't
-    consider it an error and returns None in such a case.  It may or
-    may not result in failure in the actual processing depending on
-    how the SOA is used.
-
-    When the zone has an SOA RR, this method makes sure that it's
-    valid, i.e., it has exactly one RDATA; if it is not the case
-    this method returns None.
-
-    """
-    # get the zone finder.  this must be SUCCESS (not even
-    # PARTIALMATCH) because we are specifying the zone origin name.
-    result, finder = datasrc_client.find_zone(zone_name)
-    if result != DataSourceClient.SUCCESS:
-        # The data source doesn't know the zone.  In the context of this
-        # function is called, this shouldn't happen.
-        raise ZonemgrException("unexpected result: zone %s doesn't exist" %
-                             format_zone_str(zone_name, zone_class))
-    result, soa_rrset, _ = finder.find(zone_name, RRType.SOA)
-    if result != ZoneFinder.SUCCESS:
-        logger.info(ZONEMGR_NO_SOA, format_zone_str(zone_name, zone_class))
-        return None
-    if soa_rrset.get_rdata_count() != 1:
-        logger.warn(ZONEMGR_MULTIPLE_SOA,
-                    format_zone_str(zone_name, zone_class),
-                    soa_rrset.get_rdata_count())
-        return None
-    return soa_rrset
-
 zonemgrd = None
 
 def signal_handler(signal, frame):