Parcourir la source

Merge branch 'work/secondary'

Michal 'vorner' Vaner il y a 14 ans
Parent
commit
7c1e4d5e1e

+ 85 - 47
src/bin/zonemgr/tests/zonemgr_test.py

@@ -22,10 +22,10 @@ import tempfile
 from zonemgr import *
 
 ZONE_NAME_CLASS1_IN = ("sd.cn.", "IN")
-ZONE_NAME_CLASS2_CH = ("tw.cn", "CH")
+ZONE_NAME_CLASS2_CH = ("tw.cn.", "CH")
 ZONE_NAME_CLASS3_IN = ("example.com", "IN")
 ZONE_NAME_CLASS1_CH = ("sd.cn.", "CH")
-ZONE_NAME_CLASS2_IN = ("tw.cn", "IN")
+ZONE_NAME_CLASS2_IN = ("tw.cn.", "IN")
 
 MAX_TRANSFER_TIMEOUT = 14400
 LOWERBOUND_REFRESH = 10
@@ -46,21 +46,43 @@ class MySession():
     def group_recvmsg(self, nonblock, seq):
         return None, None
 
+class FakeConfig:
+    def __init__(self):
+        self.zone_list = []
+        self.set_zone_list_from_name_classes([ZONE_NAME_CLASS1_IN,
+                                              ZONE_NAME_CLASS2_CH])
+    def set_zone_list_from_name_classes(self, zones):
+        self.zone_list = map(lambda nc: {"name": nc[0], "class": nc[1]}, zones)
+    def get(self, name):
+        if name == 'lowerbound_refresh':
+            return LOWERBOUND_REFRESH
+        elif name == 'lowerbound_retry':
+            return LOWERBOUND_RETRY
+        elif name == 'max_transfer_timeout':
+            return MAX_TRANSFER_TIMEOUT
+        elif name == 'jitter_scope':
+            return JITTER_SCOPE
+        elif name == 'secondary_zones':
+            return self.zone_list
+        else:
+            raise ValueError('Uknown config option')
+
 class MyZonemgrRefresh(ZonemgrRefresh):
     def __init__(self):
-        class FakeConfig:
-            def get(self, name):
-                if name == 'lowerbound_refresh':
-                    return LOWERBOUND_REFRESH
-                elif name == 'lowerbound_retry':
-                    return LOWERBOUND_RETRY
-                elif name == 'max_transfer_timeout':
-                    return MAX_TRANSFER_TIMEOUT
-                elif name == 'jitter_scope':
-                    return JITTER_SCOPE
-                else:
-                    raise ValueError('Uknown config option')
         self._master_socket, self._slave_socket = socket.socketpair()
+        self._zonemgr_refresh_info = {}
+
+        def get_zone_soa(zone_name, db_file):
+            if zone_name == 'sd.cn.':
+                return (1, 2, 'sd.cn.', 'cn.sd.', 21600, 'SOA', None,
+                        'a.dns.cn. root.cnnic.cn. 2009073106 7200 3600 2419200 21600')
+            elif zone_name == 'tw.cn.':
+                return (1, 2, 'tw.cn.', 'cn.sd.', 21600, 'SOA', None,
+                        'a.dns.cn. root.cnnic.cn. 2009073112 7200 3600 2419200 21600')
+            else:
+                return None
+        sqlite3_ds.get_zone_soa = get_zone_soa
+
         ZonemgrRefresh.__init__(self, MySession(), "initdb.file",
             self._slave_socket, FakeConfig())
         current_time = time.time()
@@ -70,7 +92,7 @@ class MyZonemgrRefresh(ZonemgrRefresh):
          'next_refresh_time': current_time + 6500, 
          'zone_soa_rdata': 'a.dns.cn. root.cnnic.cn. 2009073105 7200 3600 2419200 21600', 
          'zone_state': 0},
-         ('tw.cn', 'CH'): {
+         ('tw.cn.', 'CH'): {
          'last_refresh_time': current_time, 
          'next_refresh_time': current_time + 6900, 
          'zone_soa_rdata': 'a.dns.cn. root.cnnic.cn. 2009073112 7200 3600 2419200 21600', 
@@ -272,28 +294,6 @@ class TestZonemgrRefresh(unittest.TestCase):
                                          ZONE_NAME_CLASS1_IN)
         sqlite3_ds.get_zone_soa = old_get_zone_soa
 
-    def test_build_zonemgr_refresh_info(self):
-        soa_rdata = 'a.dns.cn. root.cnnic.cn. 2009073106 1800 900 2419200 21600'
-
-        def get_zones_info(db_file):
-            return [("sd.cn.", "IN")] 
-
-        def get_zone_soa(zone_name, db_file):
-            return (1, 2, 'sd.cn.', 'cn.sd.', 21600, 'SOA', None, 
-                    'a.dns.cn. root.cnnic.cn. 2009073106 1800 900 2419200 21600')
-
-        sqlite3_ds.get_zones_info = get_zones_info
-        sqlite3_ds.get_zone_soa = get_zone_soa
-
-        self.zone_refresh._zonemgr_refresh_info = {}
-        self.zone_refresh._build_zonemgr_refresh_info()
-        self.assertEqual(1, len(self.zone_refresh._zonemgr_refresh_info))
-        zone_soa_rdata = self.zone_refresh._zonemgr_refresh_info[ZONE_NAME_CLASS1_IN]["zone_soa_rdata"]
-        self.assertEqual(soa_rdata, zone_soa_rdata) 
-        self.assertEqual(ZONE_OK, self.zone_refresh._zonemgr_refresh_info[ZONE_NAME_CLASS1_IN]["zone_state"])
-        self.assertTrue("last_refresh_time" in self.zone_refresh._zonemgr_refresh_info[ZONE_NAME_CLASS1_IN].keys())
-        self.assertTrue("next_refresh_time" in self.zone_refresh._zonemgr_refresh_info[ZONE_NAME_CLASS1_IN].keys())
-
     def test_zone_handle_notify(self):
         self.zone_refresh.zone_handle_notify(ZONE_NAME_CLASS1_IN,"127.0.0.1")
         notify_master = self.zone_refresh._zonemgr_refresh_info[ZONE_NAME_CLASS1_IN]["notify_master"]
@@ -356,7 +356,7 @@ class TestZonemgrRefresh(unittest.TestCase):
                     'next_refresh_time': time1 + 7200, 
                     'zone_soa_rdata': 'a.dns.cn. root.cnnic.cn. 2009073105 7200 3600 2419200 21600', 
                     'zone_state': ZONE_OK},
-                ("tw.cn","CH"):{
+                ("tw.cn.","CH"):{
                     'last_refresh_time': time1 - 7200, 
                     'next_refresh_time': time1, 
                     'refresh_timeout': time1 + MAX_TRANSFER_TIMEOUT, 
@@ -424,7 +424,8 @@ class TestZonemgrRefresh(unittest.TestCase):
                     "lowerbound_refresh" : 60,
                     "lowerbound_retry" : 30,
                     "max_transfer_timeout" : 19800,
-                    "jitter_scope" : 0.25
+                    "jitter_scope" : 0.25,
+                    "secondary_zones": []
                 }
         self.zone_refresh.update_config_data(config_data)
         self.assertEqual(60, self.zone_refresh._lowerbound_refresh)
@@ -440,6 +441,31 @@ class TestZonemgrRefresh(unittest.TestCase):
         self.zone_refresh.shutdown()
         self.assertFalse(listener.is_alive())
 
+    def test_secondary_zones(self):
+        """Test that we can modify the list of secondary zones"""
+        config = FakeConfig()
+        config.zone_list = []
+        # First, remove everything
+        self.zone_refresh.update_config_data(config)
+        self.assertEqual(self.zone_refresh._zonemgr_refresh_info, {})
+        # Put something in
+        config.set_zone_list_from_name_classes([ZONE_NAME_CLASS1_IN])
+        self.zone_refresh.update_config_data(config)
+        self.assertTrue(("sd.cn.", "IN") in
+                        self.zone_refresh._zonemgr_refresh_info)
+        # This one does not exist
+        config.set_zone_list_from_name_classes(["example.net", "CH"])
+        self.assertRaises(ZonemgrException,
+                          self.zone_refresh.update_config_data, config)
+        # So it should not affect the old ones
+        self.assertTrue(("sd.cn.", "IN") in
+                        self.zone_refresh._zonemgr_refresh_info)
+        # Make sure it works even when we "accidentally" forget the final dot
+        config.set_zone_list_from_name_classes([("sd.cn", "IN")])
+        self.zone_refresh.update_config_data(config)
+        self.assertTrue(("sd.cn.", "IN") in
+                        self.zone_refresh._zonemgr_refresh_info)
+
     def tearDown(self):
         sys.stderr= self.stderr_backup
 
@@ -464,10 +490,11 @@ class MyZonemgr(Zonemgr):
         self._cc = MySession()
         self._module_cc = MyCCSession()
         self._config_data = {
-                    "lowerbound_refresh" : 10, 
-                    "lowerbound_retry" : 5, 
+                    "lowerbound_refresh" : 10,
+                    "lowerbound_retry" : 5,
                     "max_transfer_timeout" : 14400,
-                    "jitter_scope" : 0.1
+                    "jitter_scope" : 0.1,
+                    "secondary_zones": []
                     }
 
     def _start_zone_refresh_timer(self):
@@ -480,12 +507,14 @@ class TestZonemgr(unittest.TestCase):
 
     def test_config_handler(self):
         config_data1 = {
-                    "lowerbound_refresh" : 60, 
-                    "lowerbound_retry" : 30, 
+                    "lowerbound_refresh" : 60,
+                    "lowerbound_retry" : 30,
                     "max_transfer_timeout" : 14400,
-                    "jitter_scope" : 0.1
+                    "jitter_scope" : 0.1,
+                    "secondary_zones": []
                     }
-        self.zonemgr.config_handler(config_data1)
+        self.assertEqual(self.zonemgr.config_handler(config_data1),
+                         {"result": [0]})
         self.assertEqual(config_data1, self.zonemgr._config_data)
         config_data2 = {"zone_name" : "sd.cn.", "port" : "53", "master" : "192.168.1.1"}
         self.zonemgr.config_handler(config_data2)
@@ -494,10 +523,19 @@ class TestZonemgr(unittest.TestCase):
         config_data3 = {"jitter_scope" : 0.7}
         self.zonemgr.config_handler(config_data3)
         self.assertEqual(0.5, self.zonemgr._config_data.get("jitter_scope"))
+        # The zone doesn't exist in database, it should be rejected
+        self.zonemgr._zone_refresh = ZonemgrRefresh(None, "initdb.file", None,
+                                                    config_data1)
+        config_data1["secondary_zones"] = [{"name": "nonexistent.example",
+                                            "class": "IN"}]
+        self.assertNotEqual(self.zonemgr.config_handler(config_data1),
+                            {"result": [0]})
+        # As it is rejected, the old value should be kept
+        self.assertEqual(0.5, self.zonemgr._config_data.get("jitter_scope"))
 
     def test_get_db_file(self):
         self.assertEqual("initdb.file", self.zonemgr.get_db_file())
-    
+
     def test_parse_cmd_params(self):
         params1 = {"zone_name" : "org.cn", "zone_class" : "CH", "master" : "127.0.0.1"}
         answer1 = (("org.cn", "CH"), "127.0.0.1")

+ 48 - 24
src/bin/zonemgr/zonemgr.py.in

@@ -100,9 +100,8 @@ class ZonemgrRefresh:
         self._cc = cc
         self._check_sock = slave_socket
         self._db_file = db_file
-        self.update_config_data(config_data)
         self._zonemgr_refresh_info = {}
-        self._build_zonemgr_refresh_info()
+        self.update_config_data(config_data)
         self._running = False
 
     def _random_jitter(self, max, jitter):
@@ -148,16 +147,13 @@ class ZonemgrRefresh:
 
     def _zone_not_exist(self, zone_name_class):
         """ Zone doesn't belong to zonemgr"""
-        if zone_name_class in self._zonemgr_refresh_info.keys():
-            return False
-        return True
+        return not zone_name_class in self._zonemgr_refresh_info
 
     def zone_refresh_success(self, zone_name_class):
         """Update zone info after zone refresh success"""
         if (self._zone_not_exist(zone_name_class)):
             raise ZonemgrException("[b10-zonemgr] Zone (%s, %s) doesn't "
                                    "belong to zonemgr" % zone_name_class)
-            return
         self.zonemgr_reload_zone(zone_name_class)
         self._set_zone_refresh_timer(zone_name_class)
         self._set_zone_state(zone_name_class, ZONE_OK)
@@ -168,7 +164,6 @@ class ZonemgrRefresh:
         if (self._zone_not_exist(zone_name_class)):
             raise ZonemgrException("[b10-zonemgr] Zone (%s, %s) doesn't "
                                    "belong to zonemgr" % zone_name_class)
-            return
         # Is zone expired?
         if (self._zone_is_expired(zone_name_class)):
             self._set_zone_state(zone_name_class, ZONE_EXPIRED)
@@ -181,7 +176,6 @@ class ZonemgrRefresh:
         if (self._zone_not_exist(zone_name_class)):
             raise ZonemgrException("[b10-zonemgr] Notified zone (%s, %s) "
                                    "doesn't belong to zonemgr" % zone_name_class)
-            return
         self._set_zone_notifier_master(zone_name_class, master)
         self._set_zone_notify_timer(zone_name_class)
 
@@ -192,6 +186,7 @@ class ZonemgrRefresh:
 
     def zonemgr_add_zone(self, zone_name_class):
         """ Add a zone into zone manager."""
+        log_msg("Loading zone (%s, %s)" % zone_name_class)
         zone_info = {}
         zone_soa = sqlite3_ds.get_zone_soa(str(zone_name_class[0]), self._db_file)
         if not zone_soa:
@@ -203,14 +198,6 @@ class ZonemgrRefresh:
                                          float(zone_soa[7].split(" ")[REFRESH_OFFSET])
         self._zonemgr_refresh_info[zone_name_class] = zone_info
 
-    def _build_zonemgr_refresh_info(self):
-        """ Build zonemgr refresh info map."""
-        log_msg("Start loading zone into zonemgr.")
-        for zone_name, zone_class in sqlite3_ds.get_zones_info(self._db_file):
-            zone_name_class = (zone_name, zone_class)
-            self.zonemgr_add_zone(zone_name_class)
-        log_msg("Finish loading zone into zonemgr.")
-
     def _zone_is_expired(self, zone_name_class):
         """Judge whether a zone is expired or not."""
         zone_expired_time = float(self._get_zone_soa_rdata(zone_name_class).split(" ")[EXPIRED_OFFSET])
@@ -415,6 +402,32 @@ class ZonemgrRefresh:
 
     def update_config_data(self, new_config):
         """ update ZonemgrRefresh config """
+        backup = self._zonemgr_refresh_info.copy()
+        try:
+            required = {}
+            # Add new zones
+            for secondary_zone in new_config.get('secondary_zones'):
+                name = secondary_zone['name']
+                # Be tolerant to sclerotic users who forget the final dot
+                if name[-1] != '.':
+                    name = name + '.'
+                name_class = (name, secondary_zone['class'])
+                required[name_class] = True
+                # Add it only if it isn't there already
+                if not name_class in self._zonemgr_refresh_info:
+                    self.zonemgr_add_zone(name_class)
+            # Drop the zones that are no longer there
+            # Do it in two phases, python doesn't like deleting while iterating
+            to_drop = []
+            for old_zone in self._zonemgr_refresh_info:
+                if not old_zone in required:
+                    to_drop.append(old_zone)
+            for drop in to_drop:
+                del self._zonemgr_refresh_info[drop]
+        # If we are not able to find it in database, restore the original
+        except:
+            self._zonemgr_refresh_info = backup
+            raise
         self._lowerbound_refresh = new_config.get('lowerbound_refresh')
         self._lowerbound_retry = new_config.get('lowerbound_retry')
         self._max_transfer_timeout = new_config.get('max_transfer_timeout')
@@ -471,26 +484,37 @@ class Zonemgr:
     def config_handler(self, new_config):
         """ Update config data. """
         answer = create_answer(0)
+        ok = True
+        complete = self._config_data.copy()
         for key in new_config:
-            if key not in self._config_data:
+            if key not in complete:
                 answer = create_answer(1, "Unknown config data: " + str(key))
+                ok = False
                 continue
-            self._config_data[key] = new_config[key]
+            complete[key] = new_config[key]
 
-        self._config_data_check(self._config_data)
-        if (self._zone_refresh):
-            self._zone_refresh.update_config_data(self._config_data)
+        self._config_data_check(complete)
+        if self._zone_refresh is not None:
+            try:
+                self._zone_refresh.update_config_data(complete)
+            except Exception as e:
+                answer = create_answer(1, str(e))
+                ok = False
+        if ok:
+            self._config_data = complete
 
         return answer
 
     def _config_data_check(self, config_data):
-        """Check whether the new config data is valid or 
-        not. """ 
+        """Check whether the new config data is valid or
+        not. It contains only basic logic, not full check against
+        database."""
         # jitter should not be bigger than half of the original value
         if config_data.get('jitter_scope') > 0.5:
             config_data['jitter_scope'] = 0.5
             log_msg("[b10-zonemgr] jitter_scope is too big, its value will "
-                      "be set to 0.5") 
+                      "be set to 0.5")
+
 
     def _parse_cmd_params(self, args, command):
         zone_name = args.get("zone_name")

+ 26 - 0
src/bin/zonemgr/zonemgr.spec.pre.in

@@ -25,6 +25,32 @@
          "item_type": "real",
          "item_optional": false,
          "item_default": 0.25
+       },
+       {
+         "item_name": "secondary_zones",
+         "item_type": "list",
+         "item_optional": false,
+         "item_default": [],
+         "list_item_spec": {
+           "item_name": "secondary_zone",
+           "item_type": "map",
+           "item_optional": false,
+           "item_default": {},
+           "map_item_spec": [
+             {
+               "item_name": "class",
+               "item_type": "string",
+               "item_optional": false,
+               "item_default": "IN"
+             },
+             {
+               "item_name": "name",
+               "item_type": "string",
+               "item_optional": false,
+               "item_default": ""
+             }
+           ]
+         }
        }
       ],
       "commands": [