Browse Source

[trac811] more review comments

Jelte Jansen 14 years ago
parent
commit
bcea2b3da4

+ 45 - 29
src/bin/xfrin/tests/xfrin_test.py

@@ -1,4 +1,4 @@
-# Copyright (C) 2011  Internet Systems Consortium.
+# Copyright (C) 2009-2011  Internet Systems Consortium.
 #
 # Permission to use, copy, modify, and distribute this software for any
 # purpose with or without fee is hereby granted, provided that the above
@@ -67,7 +67,7 @@ class MockCC():
         if identifier == "zones/master_port":
             return TEST_MASTER_PORT
         if identifier == "zones/class":
-            return 'IN'
+            return TEST_RRCLASS_STR
 
 class MockXfrin(Xfrin):
     # This is a class attribute of a callable object that specifies a non
@@ -78,7 +78,7 @@ class MockXfrin(Xfrin):
     check_command_hook = None
 
     def _cc_setup(self):
-        self._tsig_key_str = None
+        self._tsig_key = None
         self._module_cc = MockCC()
         pass
 
@@ -91,17 +91,14 @@ class MockXfrin(Xfrin):
             MockXfrin.check_command_hook()
 
     def xfrin_start(self, zone_name, rrclass, db_file, master_addrinfo,
-                    tsig_key_str, check_soa=True):
+                    tsig_key, check_soa=True):
         # store some of the arguments for verification, then call this
         # method in the superclass
-        self.xfrin_started_zone_name = zone_name
-        self.xfrin_started_rrclass = rrclass
         self.xfrin_started_master_addr = master_addrinfo[2][0]
         self.xfrin_started_master_port = master_addrinfo[2][1]
-        self.xfrin_started_tsig_key_str = tsig_key_str
         return Xfrin.xfrin_start(self, zone_name, rrclass, db_file,
-                                 master_addrinfo, tsig_key_str,
-                                 check_soa=True)
+                                 master_addrinfo, tsig_key,
+                                 check_soa)
 
 class MockXfrinConnection(XfrinConnection):
     def __init__(self, sock_map, zone_name, rrclass, db_file, shutdown_event,
@@ -474,6 +471,7 @@ class TestXfrin(unittest.TestCase):
         self.xfr = MockXfrin()
         self.args = {}
         self.args['zone_name'] = TEST_ZONE_NAME_STR
+        self.args['class'] = TEST_RRCLASS_STR
         self.args['port'] = TEST_MASTER_PORT
         self.args['master'] = TEST_MASTER_IPV4_ADDRESS
         self.args['db_file'] = TEST_DB_FILE
@@ -515,7 +513,7 @@ class TestXfrin(unittest.TestCase):
 
     def test_parse_cmd_params_bogusclass(self):
         self.args['zone_class'] = 'XXX'
-        self.assertRaises(XfrinException, self._do_parse_zone_name_class)
+        self.assertRaises(XfrinZoneInfoException, self._do_parse_zone_name_class)
 
     def test_parse_cmd_params_nozone(self):
         # zone name is mandatory.
@@ -567,17 +565,29 @@ class TestXfrin(unittest.TestCase):
                                                   short_args)['result'][0], 1)
 
     def test_command_handler_retransfer_short_command2(self):
-        # try it when only specifying the zone name (of unknown zone)
-        # this should fail because master address is not specified.
+        # try it when only specifying the zone name (of known zone)
         short_args = {}
         short_args['zone_name'] = TEST_ZONE_NAME_STR
+
+        zones = { 'zones': [
+                  { 'name': TEST_ZONE_NAME_STR,
+                    'master_addr': TEST_MASTER_IPV4_ADDRESS,
+                    'master_port': TEST_MASTER_PORT
+                  }
+                ]}
+        self.xfr.config_handler(zones)
         self.assertEqual(self.xfr.command_handler("retransfer",
-                                                  short_args)['result'][0], 1)
+                                                  short_args)['result'][0], 0)
+        self.assertEqual(TEST_MASTER_IPV4_ADDRESS,
+                         self.xfr.xfrin_started_master_addr)
+        self.assertEqual(int(TEST_MASTER_PORT),
+                         self.xfr.xfrin_started_master_port)
 
     def test_command_handler_retransfer_short_command3(self):
         # try it when only specifying the zone name (of known zone)
         short_args = {}
-        short_args['zone_name'] = TEST_ZONE_NAME_STR
+        # test it without the trailing root dot
+        short_args['zone_name'] = TEST_ZONE_NAME_STR[:-1]
 
         zones = { 'zones': [
                   { 'name': TEST_ZONE_NAME_STR,
@@ -599,7 +609,7 @@ class TestXfrin(unittest.TestCase):
                                                   self.args)['result'][0], 1)
 
     def test_command_handler_retransfer_quota(self):
-        self.args['master'] = '127.0.0.1'
+        self.args['master'] = TEST_MASTER_IPV4_ADDRESS
 
         for i in range(self.xfr._max_transfers_in - 1):
             self.xfr.recorder.increment(Name(str(i) + TEST_ZONE_NAME_STR))
@@ -662,7 +672,7 @@ class TestXfrin(unittest.TestCase):
         # the config
         # This is actually NOT the address given in the command, which
         # would at this point not make sense, see the TODO in
-        # xfrin.py.in:542)
+        # xfrin.py.in Xfrin.command_handler())
         self.assertEqual(TEST_MASTER_IPV4_ADDRESS,
                          self.xfr.xfrin_started_master_addr)
         self.assertEqual(int(TEST_MASTER_PORT),
@@ -677,38 +687,44 @@ class TestXfrin(unittest.TestCase):
         self.assertEqual(self.xfr._max_transfers_in, 3)
 
     def _check_zones_config(self, config_given):
+        if 'transfers_in' in config_given:
+            self.assertEqual(config_given['transfers_in'],
+                             self.xfr._max_transfers_in)
         for zone_config in config_given['zones']:
             zone_name = zone_config['name']
             zone_info = self.xfr._get_zone_info(Name(zone_name), RRClass.IN())
-            self.assertEqual(zone_info.master_addr_str, zone_config['master_addr'])
-            self.assertEqual(zone_info.master_port_str, zone_config['master_port'])
+            self.assertEqual(str(zone_info.master_addr), zone_config['master_addr'])
+            self.assertEqual(zone_info.master_port, zone_config['master_port'])
             if 'tsig_key' in zone_config:
-                self.assertEqual(zone_info.tsig_key_str, zone_config['tsig_key'])
+                self.assertEqual(zone_info.tsig_key.to_text(), TSIGKey(zone_config['tsig_key']).to_text())
             else:
-                self.assertIsNone(zone_info.tsig_key_str)
+                self.assertIsNone(zone_info.tsig_key)
 
     def test_command_handler_zones(self):
-        zones1 = { 'zones': [
-                  { 'name': 'test.example.',
+        zones1 = { 'transfers_in': 3,
+                   'zones': [
+                   { 'name': 'test.example.',
                     'master_addr': '192.0.2.1',
                     'master_port': 53
-                  }
-                ]}
+                   }
+                 ]}
         self.assertEqual(self.xfr.config_handler(zones1)['result'][0], 0)
         self._check_zones_config(zones1)
 
-        zones2 = { 'zones': [
-                  { 'name': 'test.example.',
+        zones2 = { 'transfers_in': 4,
+                   'zones': [
+                   { 'name': 'test.example.',
                     'master_addr': '192.0.2.2',
                     'master_port': 53,
                     'tsig_key': "example.com:SFuWd/q99SzF8Yzd1QbB9g=="
-                  }
-                ]}
+                   }
+                 ]}
         self.assertEqual(self.xfr.config_handler(zones2)['result'][0], 0)
         self._check_zones_config(zones2)
 
         # test that configuring the zone multiple times fails
-        zones = { 'zones': [
+        zones = { 'transfers_in': 5,
+                  'zones': [
                   { 'name': 'test.example.',
                     'master_addr': '192.0.2.1',
                     'master_port': 53

+ 124 - 94
src/bin/xfrin/xfrin.py.in

@@ -1,6 +1,6 @@
 #!@PYTHON@
 
-# Copyright (C) 2011  Internet Systems Consortium.
+# Copyright (C) 2009-2011  Internet Systems Consortium.
 #
 # Permission to use, copy, modify, and distribute this software for any
 # purpose with or without fee is hereby granted, provided that the above
@@ -56,7 +56,14 @@ XFROUT_MODULE_NAME = 'Xfrout'
 ZONE_MANAGER_MODULE_NAME = 'Zonemgr'
 REFRESH_FROM_ZONEMGR = 'refresh_from_zonemgr'
 ZONE_XFRIN_FAILED = 'zone_xfrin_failed'
+
+# These two default are currently hard-coded. For config this isn't
+# necessary, but we need these defaults for optional command arguments
+# (TODO: have similar support to get default values for command
+# arguments as we do for config options)
 DEFAULT_MASTER_PORT = 53
+DEFAULT_ZONE_CLASS = RRClass.IN()
+
 __version__ = 'BIND10'
 # define xfrin rcode
 XFRIN_OK = 0
@@ -68,9 +75,10 @@ def log_error(msg):
 class XfrinException(Exception):
     pass
 
-class XfrinConfigException(Exception):
+class XfrinZoneInfoException(Exception):
     """This exception is raised if there is an error in the given
-       configuration (part), for instance when the zone's master
+       configuration (part), or when a command does not have the
+       required or bad arguments, for instance when the zone's master
        address is not a valid IP address, when the zone does not
        have a name, or when multiple settings are given for the same
        zone."""
@@ -83,26 +91,26 @@ def _check_zone_name(zone_name_str):
         return Name(zone_name_str)
     except (EmptyLabel, TooLongLabel, BadLabelType, BadEscape,
             TooLongName, IncompleteName) as ne:
-        raise XfrinConfigException("bad zone name: " + zone_name_str + " (" + str(ne) + ")")
+        raise XfrinZoneInfoException("bad zone name: " + zone_name_str + " (" + str(ne) + ")")
 
 def _check_zone_class(zone_class_str):
     """If the given argument is a string: checks if the given class is
        a valid one, and returns an RRClass object if so.
-       Raises XfrinConfigException if not.
+       Raises XfrinZoneInfoException if not.
        If it is None, this function returns the default RRClass.IN()"""
     if zone_class_str is None:
-        return RRClass.IN()
+        return DEFAULT_ZONE_CLASS
     try:
         return RRClass(zone_class_str)
     except InvalidRRClass as irce:
-        raise XfrinConfigException("bad zone class: " + zone_class_str + " (" + str(irce) + ")")
+        raise XfrinZoneInfoException("bad zone class: " + zone_class_str + " (" + str(irce) + ")")
 
 class XfrinConnection(asyncore.dispatcher):
     '''Do xfrin in this class. '''
 
     def __init__(self,
                  sock_map, zone_name, rrclass, db_file, shutdown_event,
-                 master_addrinfo, tsig_key_str = None, verbose = False,
+                 master_addrinfo, tsig_key = None, verbose = False,
                  idle_timeout = 60):
         ''' idle_timeout: max idle time for read data from socket.
             db_file: specify the data source file.
@@ -122,8 +130,8 @@ class XfrinConnection(asyncore.dispatcher):
         self._verbose = verbose
         self._master_address = master_addrinfo[2]
         self._tsig_ctx = None
-        if tsig_key_str is not None:
-            self._tsig_ctx = TSIGContext(TSIGKey(tsig_key_str))
+        if tsig_key is not None:
+            self._tsig_ctx = TSIGContext(tsig_key)
 
     def connect_to_master(self):
         '''Connect to master in TCP.'''
@@ -360,12 +368,12 @@ class XfrinConnection(asyncore.dispatcher):
 
 def process_xfrin(server, xfrin_recorder, zone_name, rrclass, db_file,
                   shutdown_event, master_addrinfo, check_soa, verbose,
-                  tsig_key_str):
+                  tsig_key):
     xfrin_recorder.increment(zone_name)
     sock_map = {}
     conn = XfrinConnection(sock_map, zone_name, rrclass, db_file,
                            shutdown_event, master_addrinfo,
-                           tsig_key_str, verbose)
+                           tsig_key, verbose)
     ret = XFRIN_FAIL
     if conn.connect_to_master():
         ret = conn.do_xfrin(check_soa)
@@ -406,58 +414,96 @@ class XfrinRecorder:
         return ret
 
 class ZoneInfo:
-    def __init__(self, config_data, module_cc=None):
+    def __init__(self, config_data, module_cc):
         """Creates a zone_info with the config data element as
            specified by the 'zones' list in xfrin.spec. Module_cc is
            needed to get the defaults from the specification"""
-        self.name_str = config_data.get('name')
-        self.class_str = config_data.get('class') or \
-            module_cc.get_default_value("zones/class")
-
-        if self.name_str is None:
-            raise XfrinConfigException("Configuration zones list "
-                                       "element does not contain "
-                                       "'name' attribute")
-
-        self.master_addr_str = config_data.get('master_addr')
-        self.master_port_str = config_data.get('master_port') or \
-            str(module_cc.get_default_value("zones/master_port"))
-
-        try:
-            self.name = Name(self.name_str)
-        except (EmptyLabel, TooLongLabel, BadLabelType, BadEscape,
-                TooLongName, IncompleteName) as ne:
-            errmsg = "bad zone name: " + self.name_str + " (" + str(ne) + ")"
-            log_error(errmsg)
-            raise XfrinConfigException(errmsg)
-
-        try:
-            self.master_addr = isc.net.parse.addr_parse(self.master_addr_str)
-            self.master_port = isc.net.parse.port_parse(self.master_port_str)
-        except ValueError:
-            errmsg = "bad format for zone's master: " + str(config_data)
-            log_error(errmsg)
-            raise XfrinConfigException(errmsg)
-
-        try:
-            self.rrclass = RRClass(self.class_str)
-        except InvalidRRClass:
-            errmsg = "invalid class: " + self.class_str
-            log_error(errmsg)
-            raise XfrinConfigException(errmsg)
-
-        self.tsig_key_str = config_data.get('tsig_key') or None
-        if self.tsig_key_str is not None:
+        self._module_cc = module_cc
+        self.set_name(config_data.get('name'))
+        self.set_master_addr(config_data.get('master_addr'))
+
+        self.set_master_port(config_data.get('master_port'))
+        self.set_zone_class(config_data.get('class'))
+        self.set_tsig_key(config_data.get('tsig_key'))
+
+    def set_name(self, name_str):
+        """Set the name for this zone given a name string.
+           Raises XfrinZoneInfoException if name_str is None or if it
+           cannot be parsed."""
+        #TODO: remove name_str
+        self.name_str = name_str
+        if name_str is None:
+            raise XfrinZoneInfoException("Configuration zones list "
+                                         "element does not contain "
+                                         "'name' attribute")
+        else:
+            self.name = _check_zone_name(name_str)
+
+    def set_master_addr(self, master_addr_str):
+        """Set the master address for this zone given an IP address
+           string. Raises XfrinZoneInfoException if master_addr_str is
+           None or if it cannot be parsed."""
+        if master_addr_str is None:
+            raise XfrinZoneInfoException("master address missing from config data")
+        else:
+            try:
+                self.master_addr = isc.net.parse.addr_parse(master_addr_str)
+            except ValueError:
+                errmsg = "bad format for zone's master: " + master_addr_str
+                log_error(errmsg)
+                raise XfrinZoneInfoException(errmsg)
+
+    def set_master_port(self, master_port_str):
+        """Set the master port given a port number string. If
+           master_port_str is None, the default from the specification
+           for this module will be used. Raises XfrinZoneInfoException if
+           the string contains an invalid port number"""
+        if master_port_str is None:
+            self.master_port = self._module_cc.get_default_value("zones/master_port")
+        else:
+            try:
+                self.master_port = isc.net.parse.port_parse(master_port_str)
+            except ValueError:
+                errmsg = "bad format for zone's master port: " + master_port_str
+                log_error(errmsg)
+                raise XfrinZoneInfoException(errmsg)
+
+    def set_zone_class(self, zone_class_str):
+        """Set the zone class given an RR class str (e.g. "IN"). If
+           zone_class_str is None, it will default to what is specified
+           in the specification file for this module. Raises
+           XfrinZoneInfoException if the string cannot be parsed."""
+        # TODO: remove _str
+        self.class_str = zone_class_str or self._module_cc.get_default_value("zones/class")
+        if zone_class_str == None:
+            #TODO rrclass->zone_class
+            self.rrclass = RRClass(self._module_cc.get_default_value("zones/class"))
+        else:
+            try:
+                self.rrclass = RRClass(zone_class_str)
+            except InvalidRRClass:
+                errmsg = "invalid zone class: " + zone_class_str
+                log_error(errmsg)
+                raise XfrinZoneInfoException(errmsg)
+
+    def set_tsig_key(self, tsig_key_str):
+        """Set the tsig_key for this zone, given a TSIG key string
+           representation. If tsig_key_str is None, no TSIG key will
+           be set. Raises XfrinZoneInfoException if tsig_key_str cannot
+           be parsed."""
+        if tsig_key_str is None:
+            self.tsig_key = None
+        else:
             try:
-                tsig_key = TSIGKey(self.tsig_key_str)
+                self.tsig_key = TSIGKey(tsig_key_str)
             except InvalidParameter as ipe:
-                errmsg = "bad TSIG key string: " + self.tsig_key_str
+                errmsg = "bad TSIG key string: " + tsig_key_str
                 log_error(errmsg)
-                raise XfrinConfigException(errmsg)
+                raise XfrinZoneInfoException(errmsg)
 
     def get_master_addr_info(self):
         return (self.master_addr.family, socket.SOCK_STREAM,
-                (self.master_addr_str, self.master_port))
+                (str(self.master_addr), self.master_port))
 
 class Xfrin:
     def __init__(self, verbose = False):
@@ -493,46 +539,37 @@ class Xfrin:
         """Returns the ZoneInfo object containing the configured data
            for the given zone name. If the zone name did not have any
            data, returns None"""
-        key = (name.to_text(), rrclass.to_text())
-        if key in self._zones:
-            return self._zones[key]
-        else:
-            return None
-
-    def _get_all_zone_info(self):
-        """Returns the structure used to store ZoneInfo objects. This
-           method can be used (together with _set_all_zone_info()) to
-           revert to the previous zone info configuration when one
-           of the new config items turns out to be bad"""
-        return self._zones
+        return self._zones.get((name.to_text(), rrclass.to_text()))
 
     def _add_zone_info(self, zone_info):
-        """Add the zone info. Raises a XfrinConfigException if a zone
+        """Add the zone info. Raises a XfrinZoneInfoException if a zone
            with the same name and class is already configured"""
         key = (zone_info.name.to_text(), zone_info.class_str)
         if key in self._zones:
-            raise XfrinConfigException("zone " + str(key) +
+            raise XfrinZoneInfoException("zone " + str(key) +
                                        " configured multiple times")
         self._zones[key] = zone_info
 
-    def _set_all_zone_info(self, zones):
-        self._zones = zones
-
     def _clear_zone_info(self):
         self._zones = {}
 
     def config_handler(self, new_config):
+        # backup all config data (should there be a problem in the new
+        # data)
+        old_max_transfers_in = self._max_transfers_in
+        old_zones = self._zones
+        
         self._max_transfers_in = new_config.get("transfers_in") or self._max_transfers_in
 
         if 'zones' in new_config:
-            zones_backup = self._get_all_zone_info()
             self._clear_zone_info()
             for zone_config in new_config.get('zones'):
                 try:
                     zone_info = ZoneInfo(zone_config, self._module_cc)
                     self._add_zone_info(zone_info)
-                except XfrinConfigException as xce:
-                    self._set_all_zone_info(zones_backup)
+                except XfrinZoneInfoException as xce:
+                    self._zones = old_zones
+                    self._max_transfers_in = old_max_transfers_in
                     return create_answer(1, str(xce))
 
         return create_answer(0)
@@ -574,7 +611,7 @@ class Xfrin:
                                            rrclass,
                                            self._get_db_file(),
                                            master_addr,
-                                           zone_info.tsig_key_str,
+                                           zone_info.tsig_key,
                                            True)
                     answer = create_answer(ret[0], ret[1])
 
@@ -585,15 +622,15 @@ class Xfrin:
                 (zone_name, rrclass) = self._parse_zone_name_and_class(args)
                 master_addr = self._parse_master_and_port(args)
                 zone_info = self._get_zone_info(zone_name, rrclass)
-                tsig_key_str = None
+                tsig_key = None
                 if zone_info:
-                    tsig_key_str = zone_info.tsig_key_str
+                    tsig_key = zone_info.tsig_key
                 db_file = args.get('db_file') or self._get_db_file()
                 ret = self.xfrin_start(zone_name,
                                        rrclass,
                                        db_file,
                                        master_addr,
-                                       tsig_key_str,
+                                       tsig_key,
                                        (False if command == 'retransfer' else True))
                 answer = create_answer(ret[0], ret[1])
 
@@ -609,28 +646,20 @@ class Xfrin:
         if not zone_name_str:
             raise XfrinException('zone name should be provided')
 
-        rrclass = args.get('zone_class')
-        if not rrclass:
-            rrclass = RRClass.IN()
-        else:
-            try:
-                rrclass = RRClass(rrclass)
-            except InvalidRRClass as e:
-                raise XfrinException('invalid RRClass: ' + rrclass)
-
-        return _check_zone_name(zone_name_str), rrclass
+        return (_check_zone_name(zone_name_str), _check_zone_class(args.get('zone_class')))
 
     def _parse_master_and_port(self, args):
         # check if we have configured info about this zone, in case
         # port or master are not specified
         zone_name = _check_zone_name(args.get('zone_name'))
-        zone_class = _check_zone_class(args.get('class'))
+        zone_class = _check_zone_class(args.get('zone_class'))
         zone_info = self._get_zone_info(zone_name, zone_class)
 
         master = args.get('master')
         if master is None:
             if zone_info is not None:
-                master = zone_info.master_addr_str
+                # TODO [XX]
+                master = str(zone_info.master_addr)
             else:
                 raise XfrinException("Master address not given or "
                                      "configured for " + zone_name.to_text())
@@ -638,7 +667,8 @@ class Xfrin:
         port = args.get('port')
         if port is None:
             if zone_info is not None:
-                port = zone_info.master_port_str
+                # TODO [XX]
+                port = str(zone_info.master_port)
             else:
                 port = DEFAULT_MASTER_PORT
 
@@ -705,7 +735,7 @@ class Xfrin:
         while not self._shutdown_event.is_set():
             self._cc_check_command()
 
-    def xfrin_start(self, zone_name, rrclass, db_file, master_addrinfo, tsig_key_str,
+    def xfrin_start(self, zone_name, rrclass, db_file, master_addrinfo, tsig_key,
                     check_soa = True):
         if "pydnspp" not in sys.modules:
             return (1, "xfrin failed, can't load dns message python library: 'pydnspp'")
@@ -726,7 +756,7 @@ class Xfrin:
                                                 self._shutdown_event,
                                                 master_addrinfo, check_soa,
                                                 self._verbose,
-                                                tsig_key_str))
+                                                tsig_key))
 
         xfrin_thread.start()
         return (0, 'zone xfrin is started')

+ 82 - 3
src/bin/xfrout/tests/xfrout_test.py.in

@@ -1,4 +1,4 @@
-# Copyright (C) 2010  Internet Systems Consortium.
+# Copyright (C) 2010-2011  Internet Systems Consortium.
 #
 # Permission to use, copy, modify, and distribute this software for any
 # purpose with or without fee is hereby granted, provided that the above
@@ -330,6 +330,7 @@ class MyUnixSockServer(UnixSockServer):
         self._transfers_counter = 0
         self._shutdown_event = threading.Event()
         self._max_transfers_out = 10
+        self._zones = {}
         self._cc = MyCCSession()
         self._log = isc.log.NSLogger('xfrout', '', severity = 'critical', log_to_console = False )
 
@@ -346,9 +347,87 @@ class TestUnixSockServer(unittest.TestCase):
         recv_msg = self.unix._receive_query_message(self.read_sock)
         self.assertEqual(recv_msg, send_msg)
 
+    def _check_config(self, config_data):
+        if 'transfers_out' in config_data:
+            self.assertEqual(config_data['transfers_out'],
+                             self.unix._max_transfers_out)
+        if 'zones' in config_data:
+            for zone_config in config_data['zones']:
+                self.assertIn(zone_config['name'], self.unix._zones)
+                zone_info = self.unix._zones[zone_config['name']]
+                if 'tsig_key' in zone_config:
+                    self.assertEqual(TSIGKey(zone_config['tsig_key']).to_text(),
+                                     zone_info.tsig_key.to_text())
+
     def test_updata_config_data(self):
-        self.unix.update_config_data({'transfers_out':10 })
-        self.assertEqual(self.unix._max_transfers_out, 10)
+        good_config1 = { 'transfers_out': 10,
+                        'zones': [
+                        { 'name': 'example.com.',
+                          'tsig_key': 'example.com:SFuWd/q99SzF8Yzd1QbB9g=='
+                        }
+                        ]
+                      }
+        answer = self.unix.update_config_data(good_config1)
+        self.assertEqual(0, parse_answer(answer)[0])
+        self._check_config(good_config1)
+
+        good_config2 = { 'transfers_out': 11,
+                        'zones': [
+                        { 'name': 'example.com.'
+                        },
+                        { 'name': 'example2.com.',
+                          'tsig_key': 'example2.com:SFuWd/q99SzF8Yzd1QbB9g=='
+                        }
+                        ]
+                      }
+        answer = self.unix.update_config_data(good_config2)
+        self.assertEqual(0, parse_answer(answer)[0])
+        self._check_config(good_config2)
+
+        bad_config = { 'transfers_out': 12,
+                        'zones': [
+                        {}
+                        ]
+                      }
+        answer = self.unix.update_config_data(bad_config)
+        self.assertEqual(1, parse_answer(answer)[0])
+        # Should still have the previous config
+        self._check_config(good_config2)
+
+        bad_config = { 'transfers_out': 13,
+                        'zones': [ { 'name': 'example..com.' } ]
+                      }
+        answer = self.unix.update_config_data(bad_config)
+        self.assertEqual(1, parse_answer(answer)[0])
+        # Should still have the previous config
+        self._check_config(good_config2)
+
+        bad_config = { 'transfers_out': 14,
+                        'zones': [
+                        { 'name': 'example.com.',
+                          'tsig_key': '::'
+                        }
+                        ]
+                      }
+        answer = self.unix.update_config_data(bad_config)
+        self.assertEqual(1, parse_answer(answer)[0])
+        # Should still have the previous config
+        self._check_config(good_config2)
+
+        bad_config = { 'transfers_out': 15,
+                        'zones': [
+                        { 'name': 'example.com.',
+                          'tsig_key': 'example.com:SFuWd/q99SzF8Yzd1QbB9g=='
+                        },
+                        { 'name': 'example.com.',
+                          'tsig_key': 'example.com:SFuWd/q99SzF8Yzd1QbB9g=='
+                        }
+                        ]
+                      }
+        answer = self.unix.update_config_data(bad_config)
+        self.assertEqual(1, parse_answer(answer)[0])
+        # Should still have the previous config
+        self._check_config(good_config2)
 
     def test_get_db_file(self):
         self.assertEqual(self.unix.get_db_file(), "initdb.file")

+ 64 - 10
src/bin/xfrout/xfrout.py.in

@@ -1,6 +1,6 @@
 #!@PYTHON@
 
-# Copyright (C) 2010  Internet Systems Consortium.
+# Copyright (C) 2010-2011  Internet Systems Consortium.
 #
 # Permission to use, copy, modify, and distribute this software for any
 # purpose with or without fee is hereby granted, provided that the above
@@ -301,10 +301,37 @@ class XfroutSession():
 
         self._send_message_with_last_soa(msg, sock_fd, rrset_soa, message_upper_len)
 
+class XfroutZoneInfoException(Exception):
+    """This exception if raised if the given information for ZoneInfo
+       contains an error (i.e. if the given name or tsig key data does
+       not parse correctly, or if information for a zone is found
+       multiple times"""
+    pass
+
 class ZoneInfo:
     def __init__(self, zone_config):
-        self.name = zone_config.get('name')
-        self.tsig_key_str = zone_config.get('tsig_key')
+        self.set_zone_name(zone_config.get('name'))
+        self.set_tsig_key(zone_config.get('tsig_key'))
+
+    def set_zone_name(self, name_str):
+        if name_str is None:
+            raise XfroutZoneInfoException("Must have zone name for xfrout zone info")
+        else:
+            try:
+                self.name = Name(name_str)
+            except (EmptyLabel, TooLongLabel, BadLabelType, BadEscape,
+                    TooLongName, IncompleteName) as ne:
+                raise XfroutZoneInfoException("bad zone name: " + name_str
+                                              + " (" + str(ne) + ")")
+
+    def set_tsig_key(self, tsig_key_str):
+        if tsig_key_str is None:
+            self.tsig_key = None
+        else:
+            try:
+                self.tsig_key = TSIGKey(tsig_key_str)
+            except InvalidParameter as ipe:
+                raise XfroutZoneInfoException("bad TSIG key string: " + tsig_key_str)
 
 class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
     '''The unix domain socket server which accept xfr query sent from auth server.'''
@@ -319,6 +346,11 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
         self._shutdown_event = shutdown_event
         self._write_sock, self._read_sock = socket.socketpair()
         self._log = log
+        # these values are directly (re)set by update_config_data,
+        # but the general error recovery there needs something to
+        # be set
+        self._zones = {}
+        self._max_transfers_out = 10
         self.update_config_data(config_data)
         self._cc = cc
 
@@ -451,17 +483,39 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
 
     def update_config_data(self, new_config):
         '''Apply the new config setting of xfrout module. '''
+        old_max_transfers_out = self._max_transfers_out
+        old_zones = self._zones
+        err_msg = None
+
         self._log.log_message('info', 'update config data start.')
         self._lock.acquire()
         self._max_transfers_out = new_config.get('transfers_out')
-        self._log.log_message('info', 'max transfer out : %d', self._max_transfers_out)
         zones = new_config.get('zones')
         if zones is not None:
-            for zone_config in zones:
-                zone_info = ZoneInfo(zone_config)
-                self.zones[zone_info.name] = zone_info
+            self._zones = {}
+            try:
+                for zone_config in zones:
+                    zone_info = ZoneInfo(zone_config)
+                    key = zone_info.name.to_text()
+                    if key in self._zones:
+                        raise XfroutZoneInfoException("zone " + key +
+                                                      " configured multiple times")
+                    self._zones[zone_info.name.to_text()] = zone_info
+            except XfroutZoneInfoException as xzie:
+                err_msg = "Bad zone information: " + str(xzie)
+
+        if err_msg is not None:
+            # restore previous config
+            self._max_transfers_out = old_max_transfers_out
+            self._zones = old_zones
+            answer = create_answer(1, err_msg)
+        else:
+            self._log.log_message('info', 'update config data complete.')
+            self._log.log_message('info', 'max transfer out : %d', self._max_transfers_out)
+            answer = create_answer(0)
+
         self._lock.release()
-        self._log.log_message('info', 'update config data complete.')
+        return answer
 
     def get_db_file(self):
         file, is_default = self._cc.get_remote_config_value("Auth", "database_file")
@@ -523,7 +577,7 @@ class XfroutServer:
         self._notifier.send_notify(zone_name, zone_class)
 
     def config_handler(self, new_config):
-        '''Update config data. TODO. Do error check'''
+        '''Update config data. TODO. Do error check for log_update_config'''
         answer = create_answer(0)
         for key in new_config:
             if key not in self._config_data:
@@ -535,7 +589,7 @@ class XfroutServer:
             self._log.update_config(new_config)
 
         if self._unix_socket_server:
-            self._unix_socket_server.update_config_data(self._config_data)
+            answer = self._unix_socket_server.update_config_data(self._config_data)
 
         return answer