Parcourir la source

[trac811] addressed review comments

Jelte Jansen il y a 14 ans
Parent
commit
b6982ea32a
3 fichiers modifiés avec 278 ajouts et 63 suppressions
  1. 159 24
      src/bin/xfrin/tests/xfrin_test.py
  2. 110 39
      src/bin/xfrin/xfrin.py.in
  3. 9 0
      src/lib/python/isc/config/config_data.py

+ 159 - 24
src/bin/xfrin/tests/xfrin_test.py

@@ -1,4 +1,4 @@
-# Copyright (C) 2009  Internet Systems Consortium.
+# Copyright (C) 2011  Internet Systems Consortium.
 #
 #
 # Permission to use, copy, modify, and distribute this software for any
 # Permission to use, copy, modify, and distribute this software for any
 # purpose with or without fee is hereby granted, provided that the above
 # purpose with or without fee is hereby granted, provided that the above
@@ -20,8 +20,10 @@ from xfrin import *
 #
 #
 # Commonly used (mostly constant) test parameters
 # Commonly used (mostly constant) test parameters
 #
 #
-TEST_ZONE_NAME = "example.com"
+TEST_ZONE_NAME_STR = "example.com."
+TEST_ZONE_NAME = Name(TEST_ZONE_NAME_STR)
 TEST_RRCLASS = RRClass.IN()
 TEST_RRCLASS = RRClass.IN()
+TEST_RRCLASS_STR = 'IN'
 TEST_DB_FILE = 'db_file'
 TEST_DB_FILE = 'db_file'
 TEST_MASTER_IPV4_ADDRESS = '127.0.0.1'
 TEST_MASTER_IPV4_ADDRESS = '127.0.0.1'
 TEST_MASTER_IPV4_ADDRINFO = (socket.AF_INET, socket.SOCK_STREAM,
 TEST_MASTER_IPV4_ADDRINFO = (socket.AF_INET, socket.SOCK_STREAM,
@@ -40,12 +42,12 @@ TSIG_KEY = TSIGKey("example.com:SFuWd/q99SzF8Yzd1QbB9g==")
 soa_rdata = Rdata(RRType.SOA(), TEST_RRCLASS,
 soa_rdata = Rdata(RRType.SOA(), TEST_RRCLASS,
                   'master.example.com. admin.example.com ' +
                   'master.example.com. admin.example.com ' +
                   '1234 3600 1800 2419200 7200')
                   '1234 3600 1800 2419200 7200')
-soa_rrset = RRset(Name(TEST_ZONE_NAME), TEST_RRCLASS, RRType.SOA(),
+soa_rrset = RRset(TEST_ZONE_NAME, TEST_RRCLASS, RRType.SOA(),
                   RRTTL(3600))
                   RRTTL(3600))
 soa_rrset.add_rdata(soa_rdata)
 soa_rrset.add_rdata(soa_rdata)
-example_axfr_question = Question(Name(TEST_ZONE_NAME), TEST_RRCLASS,
+example_axfr_question = Question(TEST_ZONE_NAME, TEST_RRCLASS,
                                  RRType.AXFR())
                                  RRType.AXFR())
-example_soa_question = Question(Name(TEST_ZONE_NAME), TEST_RRCLASS,
+example_soa_question = Question(TEST_ZONE_NAME, TEST_RRCLASS,
                                  RRType.SOA())
                                  RRType.SOA())
 default_questions = [example_axfr_question]
 default_questions = [example_axfr_question]
 default_answers = [soa_rrset]
 default_answers = [soa_rrset]
@@ -60,6 +62,13 @@ def strip_mutable_tsig_data(data):
     # Time Signed.
     # Time Signed.
     return data[0:-32] + data[-26:-22] + data[-6:]
     return data[0:-32] + data[-26:-22] + data[-6:]
 
 
+class MockCC():
+    def get_default_value(self, identifier):
+        if identifier == "zones/master_port":
+            return TEST_MASTER_PORT
+        if identifier == "zones/class":
+            return 'IN'
+
 class MockXfrin(Xfrin):
 class MockXfrin(Xfrin):
     # This is a class attribute of a callable object that specifies a non
     # This is a class attribute of a callable object that specifies a non
     # default behavior triggered in _cc_check_command().  Specific test methods
     # default behavior triggered in _cc_check_command().  Specific test methods
@@ -70,6 +79,7 @@ class MockXfrin(Xfrin):
 
 
     def _cc_setup(self):
     def _cc_setup(self):
         self._tsig_key_str = None
         self._tsig_key_str = None
+        self._module_cc = MockCC()
         pass
         pass
 
 
     def _get_db_file(self):
     def _get_db_file(self):
@@ -80,6 +90,19 @@ class MockXfrin(Xfrin):
         if MockXfrin.check_command_hook:
         if MockXfrin.check_command_hook:
             MockXfrin.check_command_hook()
             MockXfrin.check_command_hook()
 
 
+    def xfrin_start(self, zone_name, rrclass, db_file, master_addrinfo, tsig_key_str,
+                    check_soa = True):
+        # store some of the arguments for verification, then call this
+        # method in the superclass
+        self.xfrin_started_zone_name = zone_name
+        self.xfrin_started_rrclass = rrclass
+        self.xfrin_started_master_addr = master_addrinfo[2][0]
+        self.xfrin_started_master_port = master_addrinfo[2][1]
+        self.xfrin_started_tsig_key_str = tsig_key_str
+        return Xfrin.xfrin_start(self, zone_name, rrclass, db_file,
+                                 master_addrinfo, tsig_key_str,
+                                 check_soa = True)
+
 class MockXfrinConnection(XfrinConnection):
 class MockXfrinConnection(XfrinConnection):
     def __init__(self, sock_map, zone_name, rrclass, db_file, shutdown_event,
     def __init__(self, sock_map, zone_name, rrclass, db_file, shutdown_event,
                  master_addr):
                  master_addr):
@@ -450,7 +473,7 @@ class TestXfrin(unittest.TestCase):
         sys.stderr = open(os.devnull, 'w')
         sys.stderr = open(os.devnull, 'w')
         self.xfr = MockXfrin()
         self.xfr = MockXfrin()
         self.args = {}
         self.args = {}
-        self.args['zone_name'] = TEST_ZONE_NAME
+        self.args['zone_name'] = TEST_ZONE_NAME_STR
         self.args['port'] = TEST_MASTER_PORT
         self.args['port'] = TEST_MASTER_PORT
         self.args['master'] = TEST_MASTER_IPV4_ADDRESS
         self.args['master'] = TEST_MASTER_IPV4_ADDRESS
         self.args['db_file'] = TEST_DB_FILE
         self.args['db_file'] = TEST_DB_FILE
@@ -502,8 +525,7 @@ class TestXfrin(unittest.TestCase):
     def test_parse_cmd_params_nomaster(self):
     def test_parse_cmd_params_nomaster(self):
         # master address is mandatory.
         # master address is mandatory.
         del self.args['master']
         del self.args['master']
-        master_addrinfo = self._do_parse_master_port()
-        self.assertEqual(master_addrinfo[2][0], DEFAULT_MASTER)
+        self.assertRaises(XfrinException, self._do_parse_master_port)
 
 
     def test_parse_cmd_params_bad_ip4(self):
     def test_parse_cmd_params_bad_ip4(self):
         self.args['master'] = '3.3.3.3.3'
         self.args['master'] = '3.3.3.3.3'
@@ -533,28 +555,30 @@ class TestXfrin(unittest.TestCase):
     def test_command_handler_retransfer(self):
     def test_command_handler_retransfer(self):
         self.assertEqual(self.xfr.command_handler("retransfer",
         self.assertEqual(self.xfr.command_handler("retransfer",
                                                   self.args)['result'][0], 0)
                                                   self.args)['result'][0], 0)
+        self.assertEqual(self.args['master'], self.xfr.xfrin_started_master_addr)
+        self.assertEqual(int(self.args['port']), self.xfr.xfrin_started_master_port)
 
 
     def test_command_handler_retransfer_short_command1(self):
     def test_command_handler_retransfer_short_command1(self):
         # try it when only specifying the zone name (of unknown zone)
         # try it when only specifying the zone name (of unknown zone)
         short_args = {}
         short_args = {}
-        short_args['zone_name'] = TEST_ZONE_NAME
+        short_args['zone_name'] = TEST_ZONE_NAME_STR
         self.assertEqual(self.xfr.command_handler("retransfer",
         self.assertEqual(self.xfr.command_handler("retransfer",
-                                                  short_args)['result'][0], 0)
+                                                  short_args)['result'][0], 1)
 
 
     def test_command_handler_retransfer_short_command2(self):
     def test_command_handler_retransfer_short_command2(self):
         # try it when only specifying the zone name (of unknown zone)
         # try it when only specifying the zone name (of unknown zone)
         short_args = {}
         short_args = {}
-        short_args['zone_name'] = TEST_ZONE_NAME + "."
+        short_args['zone_name'] = TEST_ZONE_NAME_STR
         self.assertEqual(self.xfr.command_handler("retransfer",
         self.assertEqual(self.xfr.command_handler("retransfer",
-                                                  short_args)['result'][0], 0)
+                                                  short_args)['result'][0], 1)
 
 
     def test_command_handler_retransfer_short_command3(self):
     def test_command_handler_retransfer_short_command3(self):
         # try it when only specifying the zone name (of known zone)
         # try it when only specifying the zone name (of known zone)
         short_args = {}
         short_args = {}
-        short_args['zone_name'] = TEST_ZONE_NAME
+        short_args['zone_name'] = TEST_ZONE_NAME_STR
 
 
         zones = { 'zones': [
         zones = { 'zones': [
-                  { 'name': TEST_ZONE_NAME,
+                  { 'name': TEST_ZONE_NAME_STR,
                     'master_addr': TEST_MASTER_IPV4_ADDRESS,
                     'master_addr': TEST_MASTER_IPV4_ADDRESS,
                     'master_port': TEST_MASTER_PORT
                     'master_port': TEST_MASTER_PORT
                   }
                   }
@@ -562,6 +586,10 @@ class TestXfrin(unittest.TestCase):
         self.xfr.config_handler(zones)
         self.xfr.config_handler(zones)
         self.assertEqual(self.xfr.command_handler("retransfer",
         self.assertEqual(self.xfr.command_handler("retransfer",
                                                   short_args)['result'][0], 0)
                                                   short_args)['result'][0], 0)
+        self.assertEqual(TEST_MASTER_IPV4_ADDRESS,
+                         self.xfr.xfrin_started_master_addr)
+        self.assertEqual(int(TEST_MASTER_PORT),
+                         self.xfr.xfrin_started_master_port)
 
 
     def test_command_handler_retransfer_badcommand(self):
     def test_command_handler_retransfer_badcommand(self):
         self.args['master'] = 'invalid'
         self.args['master'] = 'invalid'
@@ -569,13 +597,15 @@ class TestXfrin(unittest.TestCase):
                                                   self.args)['result'][0], 1)
                                                   self.args)['result'][0], 1)
 
 
     def test_command_handler_retransfer_quota(self):
     def test_command_handler_retransfer_quota(self):
+        self.args['master'] = '127.0.0.1'
+
         for i in range(self.xfr._max_transfers_in - 1):
         for i in range(self.xfr._max_transfers_in - 1):
-            self.xfr.recorder.increment(str(i) + TEST_ZONE_NAME)
+            self.xfr.recorder.increment(Name(str(i) + TEST_ZONE_NAME_STR))
         # there can be one more outstanding transfer.
         # there can be one more outstanding transfer.
         self.assertEqual(self.xfr.command_handler("retransfer",
         self.assertEqual(self.xfr.command_handler("retransfer",
                                                   self.args)['result'][0], 0)
                                                   self.args)['result'][0], 0)
         # make sure the # xfrs would excceed the quota
         # make sure the # xfrs would excceed the quota
-        self.xfr.recorder.increment(str(self.xfr._max_transfers_in) + TEST_ZONE_NAME)
+        self.xfr.recorder.increment(Name(str(self.xfr._max_transfers_in) + TEST_ZONE_NAME_STR))
         # this one should fail
         # this one should fail
         self.assertEqual(self.xfr.command_handler("retransfer",
         self.assertEqual(self.xfr.command_handler("retransfer",
                                                   self.args)['result'][0], 1)
                                                   self.args)['result'][0], 1)
@@ -599,6 +629,10 @@ class TestXfrin(unittest.TestCase):
         self.args['master'] = TEST_MASTER_IPV6_ADDRESS
         self.args['master'] = TEST_MASTER_IPV6_ADDRESS
         self.assertEqual(self.xfr.command_handler("refresh",
         self.assertEqual(self.xfr.command_handler("refresh",
                                                   self.args)['result'][0], 0)
                                                   self.args)['result'][0], 0)
+        self.assertEqual(TEST_MASTER_IPV6_ADDRESS,
+                         self.xfr.xfrin_started_master_addr)
+        self.assertEqual(int(TEST_MASTER_PORT),
+                         self.xfr.xfrin_started_master_port)
 
 
     def test_command_handler_notify(self):
     def test_command_handler_notify(self):
         # at this level, refresh is no different than retransfer.
         # at this level, refresh is no different than retransfer.
@@ -611,8 +645,9 @@ class TestXfrin(unittest.TestCase):
         # try it with a known zone
         # try it with a known zone
         self.args['master'] = TEST_MASTER_IPV6_ADDRESS
         self.args['master'] = TEST_MASTER_IPV6_ADDRESS
 
 
+        # but use a different address in the actual command
         zones = { 'zones': [
         zones = { 'zones': [
-                  { 'name': TEST_ZONE_NAME,
+                  { 'name': TEST_ZONE_NAME_STR,
                     'master_addr': TEST_MASTER_IPV4_ADDRESS,
                     'master_addr': TEST_MASTER_IPV4_ADDRESS,
                     'master_port': TEST_MASTER_PORT
                     'master_port': TEST_MASTER_PORT
                   }
                   }
@@ -621,6 +656,16 @@ class TestXfrin(unittest.TestCase):
         self.assertEqual(self.xfr.command_handler("notify",
         self.assertEqual(self.xfr.command_handler("notify",
                                                   self.args)['result'][0], 0)
                                                   self.args)['result'][0], 0)
 
 
+        # and see if we used the address from the command, and not from
+        # the config
+        # This is actually NOT the address given in the command, which
+        # would at this point not make sense, see the TODO in
+        # xfrin.py.in:542)
+        self.assertEqual(TEST_MASTER_IPV4_ADDRESS,
+                         self.xfr.xfrin_started_master_addr)
+        self.assertEqual(int(TEST_MASTER_PORT),
+                         self.xfr.xfrin_started_master_port)
+
     def test_command_handler_unknown(self):
     def test_command_handler_unknown(self):
         self.assertEqual(self.xfr.command_handler("xxx", None)['result'][0], 1)
         self.assertEqual(self.xfr.command_handler("xxx", None)['result'][0], 1)
 
 
@@ -629,37 +674,127 @@ class TestXfrin(unittest.TestCase):
         self.assertEqual(self.xfr.config_handler({'transfers_in': 3})['result'][0], 0)
         self.assertEqual(self.xfr.config_handler({'transfers_in': 3})['result'][0], 0)
         self.assertEqual(self.xfr._max_transfers_in, 3)
         self.assertEqual(self.xfr._max_transfers_in, 3)
 
 
+    def _check_zones_config(self, config_given):
+        for zone_config in config_given['zones']:
+            zone_name = zone_config['name']
+            zone_info = self.xfr._get_zone_info(Name(zone_name), RRClass.IN())
+            self.assertEqual(zone_info.master_addr_str, zone_config['master_addr'])
+            self.assertEqual(zone_info.master_port_str, zone_config['master_port'])
+            if 'tsig_key' in zone_config:
+                self.assertEqual(zone_info.tsig_key_str, zone_config['tsig_key'])
+            else:
+                self.assertIsNone(zone_info.tsig_key_str)
+
     def test_command_handler_zones(self):
     def test_command_handler_zones(self):
+        zones1 = { 'zones': [
+                  { 'name': 'test.example.',
+                    'master_addr': '192.0.2.1',
+                    'master_port': 53
+                  }
+                ]}
+        self.assertEqual(self.xfr.config_handler(zones1)['result'][0], 0)
+        self._check_zones_config(zones1)
+
+        zones2 = { 'zones': [
+                  { 'name': 'test.example.',
+                    'master_addr': '192.0.2.2',
+                    'master_port': 53,
+                    'tsig_key': "example.com:SFuWd/q99SzF8Yzd1QbB9g=="
+                  }
+                ]}
+        self.assertEqual(self.xfr.config_handler(zones2)['result'][0], 0)
+        self._check_zones_config(zones2)
+
+        # test that configuring the zone multiple times fails
+        zones = { 'zones': [
+                  { 'name': 'test.example.',
+                    'master_addr': '192.0.2.1',
+                    'master_port': 53
+                  },
+                  { 'name': 'test.example.',
+                    'master_addr': '192.0.2.2',
+                    'master_port': 53
+                  }
+                ]}
+        self.assertEqual(self.xfr.config_handler(zones)['result'][0], 1)
+        # since this has failed, we should still have the previous config
+        self._check_zones_config(zones2)
+
+        zones = { 'zones': [
+                  { 'name': 'test.example.',
+                    'master_addr': '192.0.2.3',
+                    'master_port': 53,
+                    'class': 'BADCLASS'
+                  }
+                ]}
+        self.assertEqual(self.xfr.config_handler(zones)['result'][0], 1)
+        self._check_zones_config(zones2)
+
+        zones = { 'zones': [
+                  { 'master_addr': '192.0.2.4',
+                    'master_port': 53
+                  }
+                ]}
+        self.assertEqual(self.xfr.config_handler(zones)['result'][0], 1)
+        # since this has failed, we should still have the previous config
+        self._check_zones_config(zones2)
+
         zones = { 'zones': [
         zones = { 'zones': [
-                  { 'name': 'test.com.',
-                    'master_addr': '1.1.1.1',
+                  { 'name': 'bad..zone.',
+                    'master_addr': '192.0.2.5',
                     'master_port': 53
                     'master_port': 53
                   }
                   }
                 ]}
                 ]}
-        self.assertEqual(self.xfr.config_handler(zones)['result'][0], 0)
+        self.assertEqual(self.xfr.config_handler(zones)['result'][0], 1)
+        # since this has failed, we should still have the previous config
+        self._check_zones_config(zones2)
 
 
         zones = { 'zones': [
         zones = { 'zones': [
-                  { 'master_addr': '1.1.1.1',
+                  { 'name': '',
+                    'master_addr': '192.0.2.6',
                     'master_port': 53
                     'master_port': 53
                   }
                   }
                 ]}
                 ]}
         self.assertEqual(self.xfr.config_handler(zones)['result'][0], 1)
         self.assertEqual(self.xfr.config_handler(zones)['result'][0], 1)
+        # since this has failed, we should still have the previous config
+        self._check_zones_config(zones2)
 
 
         zones = { 'zones': [
         zones = { 'zones': [
-                  { 'name': 'test.com',
+                  { 'name': 'test.example',
                     'master_addr': 'badaddress',
                     'master_addr': 'badaddress',
                     'master_port': 53
                     'master_port': 53
                   }
                   }
                 ]}
                 ]}
         self.assertEqual(self.xfr.config_handler(zones)['result'][0], 1)
         self.assertEqual(self.xfr.config_handler(zones)['result'][0], 1)
+        # since this has failed, we should still have the previous config
+        self._check_zones_config(zones2)
 
 
         zones = { 'zones': [
         zones = { 'zones': [
-                  { 'name': 'test.com',
-                    'master_addr': '1.1.1.1',
+                  { 'name': 'test.example',
+                    'master_addr': '192.0.2.7',
                     'master_port': 'bad_port'
                     'master_port': 'bad_port'
                   }
                   }
                 ]}
                 ]}
         self.assertEqual(self.xfr.config_handler(zones)['result'][0], 1)
         self.assertEqual(self.xfr.config_handler(zones)['result'][0], 1)
+        # since this has failed, we should still have the previous config
+        self._check_zones_config(zones2)
+
+        # let's also add a zone that is correct too, and make sure
+        # that the new config is not partially taken
+        zones = { 'zones': [
+                  { 'name': 'test.example.',
+                    'master_addr': '192.0.2.8',
+                    'master_port': 53
+                  },
+                  { 'name': 'test2.example.',
+                    'master_addr': '192.0.2.9',
+                    'master_port': 53,
+                    'tsig_key': 'badkey'
+                  }
+                ]}
+        self.assertEqual(self.xfr.config_handler(zones)['result'][0], 1)
+        # since this has failed, we should still have the previous config
+        self._check_zones_config(zones2)
 
 
 
 
 def raise_interrupt():
 def raise_interrupt():

+ 110 - 39
src/bin/xfrin/xfrin.py.in

@@ -1,6 +1,6 @@
 #!@PYTHON@
 #!@PYTHON@
 
 
-# Copyright (C) 2010  Internet Systems Consortium.
+# Copyright (C) 2011  Internet Systems Consortium.
 #
 #
 # Permission to use, copy, modify, and distribute this software for any
 # Permission to use, copy, modify, and distribute this software for any
 # purpose with or without fee is hereby granted, provided that the above
 # purpose with or without fee is hereby granted, provided that the above
@@ -56,14 +56,12 @@ XFROUT_MODULE_NAME = 'Xfrout'
 ZONE_MANAGER_MODULE_NAME = 'Zonemgr'
 ZONE_MANAGER_MODULE_NAME = 'Zonemgr'
 REFRESH_FROM_ZONEMGR = 'refresh_from_zonemgr'
 REFRESH_FROM_ZONEMGR = 'refresh_from_zonemgr'
 ZONE_XFRIN_FAILED = 'zone_xfrin_failed'
 ZONE_XFRIN_FAILED = 'zone_xfrin_failed'
+DEFAULT_MASTER_PORT = 53
 __version__ = 'BIND10'
 __version__ = 'BIND10'
 # define xfrin rcode
 # define xfrin rcode
 XFRIN_OK = 0
 XFRIN_OK = 0
 XFRIN_FAIL = 1
 XFRIN_FAIL = 1
 
 
-DEFAULT_MASTER_PORT = '53'
-DEFAULT_MASTER = '127.0.0.1'
-
 def log_error(msg):
 def log_error(msg):
     sys.stderr.write("[b10-xfrin] %s\n" % str(msg))
     sys.stderr.write("[b10-xfrin] %s\n" % str(msg))
 
 
@@ -71,8 +69,34 @@ class XfrinException(Exception):
     pass
     pass
 
 
 class XfrinConfigException(Exception):
 class XfrinConfigException(Exception):
+    """This exception is raised if there is an error in the given
+       configuration (part), for instance when the zone's master
+       address is not a valid IP address, when the zone does not
+       have a name, or when multiple settings are given for the same
+       zone."""
     pass
     pass
 
 
+def _check_zone_name(zone_name_str):
+    """Checks if the given zone name is a valid domain name, and returns it as a Name object.
+       Raises an XfrinException if it is not."""
+    try:
+        return Name(zone_name_str)
+    except (EmptyLabel, TooLongLabel, BadLabelType, BadEscape,
+            TooLongName, IncompleteName) as ne:
+        raise XfrinConfigException("bad zone name: " + zone_name_str + " (" + str(ne) + ")")
+
+def _check_zone_class(zone_class_str):
+    """If the given argument is a string: checks if the given class is
+       a valid one, and returns an RRClass object if so.
+       Raises XfrinConfigException if not.
+       If it is None, this function returns the default RRClass.IN()"""
+    if zone_class_str is None:
+        return RRClass.IN()
+    try:
+        return RRClass(zone_class_str)
+    except InvalidRRClass as irce:
+        raise XfrinConfigException("bad zone class: " + zone_class_str + " (" + str(irce) + ")")
+
 class XfrinConnection(asyncore.dispatcher):
 class XfrinConnection(asyncore.dispatcher):
     '''Do xfrin in this class. '''
     '''Do xfrin in this class. '''
 
 
@@ -382,22 +406,31 @@ class XfrinRecorder:
         return ret
         return ret
 
 
 class ZoneInfo:
 class ZoneInfo:
-    def __init__(self, config_data):
+    def __init__(self, config_data, module_cc = None):
         """Creates a zone_info with the config data element as
         """Creates a zone_info with the config data element as
-           specified by the 'zones' list in xfrin.spec"""
-        self.name = config_data.get('name')
-        self.class_str = config_data.get('class') or 'IN'
+           specified by the 'zones' list in xfrin.spec. Module_cc is
+           needed to get the defaults from the specification"""
+        self.name_str = config_data.get('name')
+        self.class_str = config_data.get('class') or \
+            module_cc.get_default_value("zones/class")
 
 
-        if self.name is None:
+        if self.name_str is None:
             raise XfrinConfigException("Configuration zones list "
             raise XfrinConfigException("Configuration zones list "
                                        "element does not contain "
                                        "element does not contain "
                                        "'name' attribute")
                                        "'name' attribute")
 
 
-        # add the root dot if the user forgot
-        if len(self.name) > 0 and self.name[-1] != '.':
-            self.name += '.'
-        self.master_addr_str = config_data.get('master_addr') or DEFAULT_MASTER
-        self.master_port_str = config_data.get('master_port') or DEFAULT_MASTER_PORT
+        self.master_addr_str = config_data.get('master_addr')
+        self.master_port_str = config_data.get('master_port') or \
+            str(module_cc.get_default_value("zones/master_port"))
+
+        try:
+            self.name = Name(self.name_str)
+        except (EmptyLabel, TooLongLabel, BadLabelType, BadEscape,
+                TooLongName, IncompleteName) as ne:
+            errmsg = "bad zone name: " + self.name_str + " (" + str(ne) + ")"
+            log_error(errmsg)
+            raise XfrinConfigException(errmsg)
+
         try:
         try:
             self.master_addr = isc.net.parse.addr_parse(self.master_addr_str)
             self.master_addr = isc.net.parse.addr_parse(self.master_addr_str)
             self.master_port = isc.net.parse.port_parse(self.master_port_str)
             self.master_port = isc.net.parse.port_parse(self.master_port_str)
@@ -406,7 +439,21 @@ class ZoneInfo:
             log_error(errmsg)
             log_error(errmsg)
             raise XfrinConfigException(errmsg)
             raise XfrinConfigException(errmsg)
 
 
+        try:
+            self.rrclass = RRClass(self.class_str)
+        except InvalidRRClass:
+            errmsg = "invalid class: " + self.class_str
+            log_error(errmsg)
+            raise XfrinConfigException(errmsg)
+
         self.tsig_key_str = config_data.get('tsig_key') or None
         self.tsig_key_str = config_data.get('tsig_key') or None
+        if self.tsig_key_str is not None:
+            try:
+                tsig_key = TSIGKey(self.tsig_key_str)
+            except InvalidParameter as ipe:
+                errmsg = "bad TSIG key string: " + self.tsig_key_str
+                log_error(errmsg)
+                raise XfrinConfigException(errmsg)
 
 
     def get_master_addr_info(self):
     def get_master_addr_info(self):
         return (self.master_addr.family, socket.SOCK_STREAM,
         return (self.master_addr.family, socket.SOCK_STREAM,
@@ -442,33 +489,50 @@ class Xfrin:
         of unit tests.'''
         of unit tests.'''
         self._module_cc.check_command(False)
         self._module_cc.check_command(False)
 
 
-    def _get_zone_info(self, name, class_str = "IN"):
+    def _get_zone_info(self, name, rrclass):
         """Returns the ZoneInfo object containing the configured data
         """Returns the ZoneInfo object containing the configured data
            for the given zone name. If the zone name did not have any
            for the given zone name. If the zone name did not have any
            data, returns None"""
            data, returns None"""
-        # add the root dot if the user forgot
-        if len(name) > 0 and name[-1] != '.':
-            name += '.'
-        if (name, class_str) in self._zones:
-            return self._zones[(name, class_str)]
+        key = (name.to_text(), rrclass.to_text())
+        if key in self._zones:
+            return self._zones[key]
         else:
         else:
             return None
             return None
 
 
-    def _clear_zone_info(self):
-        self._zones = {}
+    def _get_all_zone_info(self):
+        """Returns the structure used to store ZoneInfo objects. This
+           method can be used (together with _set_all_zone_info()) to
+           revert to the previous zone info configuration when one
+           of the new config items turns out to be bad"""
+        return self._zones
 
 
     def _add_zone_info(self, zone_info):
     def _add_zone_info(self, zone_info):
-        self._zones[(zone_info.name, zone_info.class_str)] = zone_info
+        """Add the zone info. Raises a XfrinConfigException if a zone
+           with the same name and class is already configured"""
+        key = (zone_info.name.to_text(), zone_info.class_str)
+        if key in self._zones:
+            raise XfrinConfigException("zone " + str(key) +
+                                       " configured multiple times")
+        self._zones[key] = zone_info
+
+    def _set_all_zone_info(self, zones):
+        self._zones = zones
+
+    def _clear_zone_info(self):
+        self._zones = {}
 
 
     def config_handler(self, new_config):
     def config_handler(self, new_config):
         self._max_transfers_in = new_config.get("transfers_in") or self._max_transfers_in
         self._max_transfers_in = new_config.get("transfers_in") or self._max_transfers_in
+
         if 'zones' in new_config:
         if 'zones' in new_config:
+            zones_backup = self._get_all_zone_info()
             self._clear_zone_info()
             self._clear_zone_info()
             for zone_config in new_config.get('zones'):
             for zone_config in new_config.get('zones'):
                 try:
                 try:
-                    zone_info = ZoneInfo(zone_config)
+                    zone_info = ZoneInfo(zone_config, self._module_cc)
                     self._add_zone_info(zone_info)
                     self._add_zone_info(zone_info)
                 except XfrinConfigException as xce:
                 except XfrinConfigException as xce:
+                    self._set_all_zone_info(zones_backup)
                     return create_answer(1, str(xce))
                     return create_answer(1, str(xce))
 
 
         return create_answer(0)
         return create_answer(0)
@@ -494,11 +558,14 @@ class Xfrin:
                 # notify command maybe has the parameters which
                 # notify command maybe has the parameters which
                 # specify the notifyfrom address and port, according the RFC1996, zone
                 # specify the notifyfrom address and port, according the RFC1996, zone
                 # transfer should starts first from the notifyfrom, but now, let 'TODO' it.
                 # transfer should starts first from the notifyfrom, but now, let 'TODO' it.
+                # (using the value now, while we can only set one master address, would be
+                # a security hole. Once we add the ability to have multiple master addresses,
+                # we should check if it matches one of them, and then use it.)
                 (zone_name, rrclass) = self._parse_zone_name_and_class(args)
                 (zone_name, rrclass) = self._parse_zone_name_and_class(args)
-                zone_info = self._get_zone_info(zone_name)
+                zone_info = self._get_zone_info(zone_name, rrclass)
                 if zone_info is None:
                 if zone_info is None:
                     # TODO what to do? no info known about zone. defaults?
                     # TODO what to do? no info known about zone. defaults?
-                    errmsg = "Got notification to retransfer unknown zone " + zone_name
+                    errmsg = "Got notification to retransfer unknown zone " + zone_name.to_text()
                     log_error(errmsg)
                     log_error(errmsg)
                     answer = create_answer(1, errmsg)
                     answer = create_answer(1, errmsg)
                 else:
                 else:
@@ -517,7 +584,7 @@ class Xfrin:
                 # master address, or else do transfer from the configured masters.
                 # master address, or else do transfer from the configured masters.
                 (zone_name, rrclass) = self._parse_zone_name_and_class(args)
                 (zone_name, rrclass) = self._parse_zone_name_and_class(args)
                 master_addr = self._parse_master_and_port(args)
                 master_addr = self._parse_master_and_port(args)
-                zone_info = self._get_zone_info(zone_name)
+                zone_info = self._get_zone_info(zone_name, rrclass)
                 tsig_key_str = None
                 tsig_key_str = None
                 if zone_info:
                 if zone_info:
                     tsig_key_str = zone_info.tsig_key_str
                     tsig_key_str = zone_info.tsig_key_str
@@ -538,8 +605,8 @@ class Xfrin:
         return answer
         return answer
 
 
     def _parse_zone_name_and_class(self, args):
     def _parse_zone_name_and_class(self, args):
-        zone_name = args.get('zone_name')
-        if not zone_name:
+        zone_name_str = args.get('zone_name')
+        if not zone_name_str:
             raise XfrinException('zone name should be provided')
             raise XfrinException('zone name should be provided')
 
 
         rrclass = args.get('zone_class')
         rrclass = args.get('zone_class')
@@ -551,12 +618,22 @@ class Xfrin:
             except InvalidRRClass as e:
             except InvalidRRClass as e:
                 raise XfrinException('invalid RRClass: ' + rrclass)
                 raise XfrinException('invalid RRClass: ' + rrclass)
 
 
-        return zone_name, rrclass
+        return _check_zone_name(zone_name_str), rrclass
 
 
     def _parse_master_and_port(self, args):
     def _parse_master_and_port(self, args):
         # check if we have configured info about this zone, in case
         # check if we have configured info about this zone, in case
         # port or master are not specified
         # port or master are not specified
-        zone_info = self._get_zone_info(args.get('zone_name'))
+        zone_name = _check_zone_name(args.get('zone_name'))
+        zone_class = _check_zone_class(args.get('class'))
+        zone_info = self._get_zone_info(zone_name, zone_class)
+
+        master = args.get('master')
+        if master is None:
+            if zone_info is not None:
+                master = zone_info.master_addr_str
+            else:
+                raise XfrinException("Master address not given or "
+                                     "configured for " + zone_name.to_text())
 
 
         port = args.get('port')
         port = args.get('port')
         if port is None:
         if port is None:
@@ -565,13 +642,6 @@ class Xfrin:
             else:
             else:
                 port = DEFAULT_MASTER_PORT
                 port = DEFAULT_MASTER_PORT
 
 
-        master = args.get('master')
-        if master is None:
-            if zone_info is not None:
-                master = zone_info.master_addr_str
-            else:
-                master = DEFAULT_MASTER
-
         return build_addr_info(master, port)
         return build_addr_info(master, port)
 
 
     def _get_db_file(self):
     def _get_db_file(self):
@@ -650,7 +720,8 @@ class Xfrin:
         xfrin_thread = threading.Thread(target = process_xfrin,
         xfrin_thread = threading.Thread(target = process_xfrin,
                                         args = (self,
                                         args = (self,
                                                 self.recorder,
                                                 self.recorder,
-                                                zone_name, rrclass,
+                                                zone_name.to_text(),
+                                                rrclass,
                                                 db_file,
                                                 db_file,
                                                 self._shutdown_event,
                                                 self._shutdown_event,
                                                 master_addrinfo, check_soa,
                                                 master_addrinfo, check_soa,

+ 9 - 0
src/lib/python/isc/config/config_data.py

@@ -213,6 +213,15 @@ class ConfigData:
             return spec['item_default'], True
             return spec['item_default'], True
         return None, False
         return None, False
 
 
+    def get_default_value(self, identifier):
+        """Returns the default from the specification, or None if there
+           is no default"""
+        spec = find_spec_part(self.specification.get_config_spec(), identifier)
+        if spec and 'item_default' in spec:
+            return spec['item_default']
+        else:
+            return None
+
     def get_module_spec(self):
     def get_module_spec(self):
         """Returns the ModuleSpec object associated with this ConfigData"""
         """Returns the ModuleSpec object associated with this ConfigData"""
         return self.specification
         return self.specification