Browse Source

[] Merge branch 'trac1165'

JINMEI Tatuya 13 years ago
parent
commit
698176eccd

+ 3 - 0
src/bin/xfrout/tests/Makefile.am

@@ -10,6 +10,8 @@ LIBRARY_PATH_PLACEHOLDER += $(ENV_LIBRARY_PATH)=$(abs_top_builddir)/src/lib/cryp
 endif
 endif
 
 
 # test using command-line arguments, so use check-local target instead of TESTS
 # test using command-line arguments, so use check-local target instead of TESTS
+# We set B10_FROM_BUILD below, so that the test can refer to the in-source
+# spec file.
 check-local:
 check-local:
 if ENABLE_PYTHON_COVERAGE
 if ENABLE_PYTHON_COVERAGE
 	touch $(abs_top_srcdir)/.coverage 
 	touch $(abs_top_srcdir)/.coverage 
@@ -19,6 +21,7 @@ endif
 	for pytest in $(PYTESTS) ; do \
 	for pytest in $(PYTESTS) ; do \
 	echo Running test: $$pytest ; \
 	echo Running test: $$pytest ; \
 	chmod +x $(abs_builddir)/$$pytest ; \
 	chmod +x $(abs_builddir)/$$pytest ; \
+	B10_FROM_BUILD=$(abs_top_builddir) \
 	$(LIBRARY_PATH_PLACEHOLDER) \
 	$(LIBRARY_PATH_PLACEHOLDER) \
 	PYTHONPATH=$(COMMON_PYTHON_PATH):$(abs_top_builddir)/src/bin/xfrout:$(abs_top_builddir)/src/lib/dns/python/.libs:$(abs_top_builddir)/src/lib/util/io/.libs \
 	PYTHONPATH=$(COMMON_PYTHON_PATH):$(abs_top_builddir)/src/bin/xfrout:$(abs_top_builddir)/src/lib/dns/python/.libs:$(abs_top_builddir)/src/lib/util/io/.libs \
 	$(PYCOVERAGE_RUN) $(abs_builddir)/$$pytest || exit ; \
 	$(PYCOVERAGE_RUN) $(abs_builddir)/$$pytest || exit ; \

+ 166 - 31
src/bin/xfrout/tests/xfrout_test.py.in

@@ -20,6 +20,7 @@ import unittest
 import os
 import os
 from isc.testutils.tsigctx_mock import MockTSIGContext
 from isc.testutils.tsigctx_mock import MockTSIGContext
 from isc.cc.session import *
 from isc.cc.session import *
+import isc.config
 from pydnspp import *
 from pydnspp import *
 from xfrout import *
 from xfrout import *
 import xfrout
 import xfrout
@@ -101,20 +102,24 @@ class TestXfroutSession(unittest.TestCase):
     def message_has_tsig(self, msg):
     def message_has_tsig(self, msg):
         return msg.get_tsig_record() is not None
         return msg.get_tsig_record() is not None
 
 
-    def create_request_data_with_tsig(self):
+    def create_request_data(self, with_tsig=False):
         msg = Message(Message.RENDER)
         msg = Message(Message.RENDER)
         query_id = 0x1035
         query_id = 0x1035
         msg.set_qid(query_id)
         msg.set_qid(query_id)
         msg.set_opcode(Opcode.QUERY())
         msg.set_opcode(Opcode.QUERY())
         msg.set_rcode(Rcode.NOERROR())
         msg.set_rcode(Rcode.NOERROR())
-        query_question = Question(Name("example.com."), RRClass.IN(), RRType.AXFR())
+        query_question = Question(Name("example.com"), RRClass.IN(),
+                                  RRType.AXFR())
         msg.add_question(query_question)
         msg.add_question(query_question)
 
 
         renderer = MessageRenderer()
         renderer = MessageRenderer()
-        tsig_ctx = MockTSIGContext(TSIG_KEY)
-        msg.to_wire(renderer, tsig_ctx)
-        reply_data = renderer.get_data()
-        return reply_data
+        if with_tsig:
+            tsig_ctx = MockTSIGContext(TSIG_KEY)
+            msg.to_wire(renderer, tsig_ctx)
+        else:
+            msg.to_wire(renderer)
+        request_data = renderer.get_data()
+        return request_data
 
 
     def setUp(self):
     def setUp(self):
         self.sock = MySocket(socket.AF_INET,socket.SOCK_STREAM)
         self.sock = MySocket(socket.AF_INET,socket.SOCK_STREAM)
@@ -122,8 +127,9 @@ class TestXfroutSession(unittest.TestCase):
                                        TSIGKeyRing(), ('127.0.0.1', 12345),
                                        TSIGKeyRing(), ('127.0.0.1', 12345),
                                        # When not testing ACLs, simply accept
                                        # When not testing ACLs, simply accept
                                        isc.acl.dns.REQUEST_LOADER.load(
                                        isc.acl.dns.REQUEST_LOADER.load(
-                                           [{"action": "ACCEPT"}]))
-        self.mdata = bytes(b'\xd6=\x00\x00\x00\x01\x00\x00\x00\x00\x00\x00\x07example\x03com\x00\x00\xfc\x00\x01')
+                                           [{"action": "ACCEPT"}]),
+                                       {})
+        self.mdata = self.create_request_data(False)
         self.soa_record = (4, 3, 'example.com.', 'com.example.', 3600, 'SOA', None, 'master.example.com. admin.example.com. 1234 3600 1800 2419200 7200')
         self.soa_record = (4, 3, 'example.com.', 'com.example.', 3600, 'SOA', None, 'master.example.com. admin.example.com. 1234 3600 1800 2419200 7200')
 
 
     def test_parse_query_message(self):
     def test_parse_query_message(self):
@@ -131,7 +137,7 @@ class TestXfroutSession(unittest.TestCase):
         self.assertEqual(get_rcode.to_text(), "NOERROR")
         self.assertEqual(get_rcode.to_text(), "NOERROR")
 
 
         # tsig signed query message
         # tsig signed query message
-        request_data = self.create_request_data_with_tsig()
+        request_data = self.create_request_data(True)
         # BADKEY
         # BADKEY
         [rcode, msg] = self.xfrsess._parse_query_message(request_data)
         [rcode, msg] = self.xfrsess._parse_query_message(request_data)
         self.assertEqual(rcode.to_text(), "NOTAUTH")
         self.assertEqual(rcode.to_text(), "NOTAUTH")
@@ -143,8 +149,9 @@ class TestXfroutSession(unittest.TestCase):
         self.assertEqual(rcode.to_text(), "NOERROR")
         self.assertEqual(rcode.to_text(), "NOERROR")
         self.assertTrue(self.xfrsess._tsig_ctx is not None)
         self.assertTrue(self.xfrsess._tsig_ctx is not None)
 
 
+    def check_transfer_acl(self, acl_setter):
         # ACL checks, put some ACL inside
         # ACL checks, put some ACL inside
-        self.xfrsess._acl = isc.acl.dns.REQUEST_LOADER.load([
+        acl_setter(isc.acl.dns.REQUEST_LOADER.load([
             {
             {
                 "from": "127.0.0.1",
                 "from": "127.0.0.1",
                 "action": "ACCEPT"
                 "action": "ACCEPT"
@@ -153,7 +160,7 @@ class TestXfroutSession(unittest.TestCase):
                 "from": "192.0.2.1",
                 "from": "192.0.2.1",
                 "action": "DROP"
                 "action": "DROP"
             }
             }
-        ])
+        ]))
         # Localhost (the default in this test) is accepted
         # Localhost (the default in this test) is accepted
         rcode, msg = self.xfrsess._parse_query_message(self.mdata)
         rcode, msg = self.xfrsess._parse_query_message(self.mdata)
         self.assertEqual(rcode.to_text(), "NOERROR")
         self.assertEqual(rcode.to_text(), "NOERROR")
@@ -165,6 +172,10 @@ class TestXfroutSession(unittest.TestCase):
         self.xfrsess._remote = ('192.0.2.2', 12345)
         self.xfrsess._remote = ('192.0.2.2', 12345)
         rcode, msg = self.xfrsess._parse_query_message(self.mdata)
         rcode, msg = self.xfrsess._parse_query_message(self.mdata)
         self.assertEqual(rcode.to_text(), "REFUSED")
         self.assertEqual(rcode.to_text(), "REFUSED")
+
+        # TSIG signed request
+        request_data = self.create_request_data(True)
+
         # If the TSIG check fails, it should not check ACL
         # If the TSIG check fails, it should not check ACL
         # (If it checked ACL as well, it would just drop the request)
         # (If it checked ACL as well, it would just drop the request)
         self.xfrsess._remote = ('192.0.2.1', 12345)
         self.xfrsess._remote = ('192.0.2.1', 12345)
@@ -174,36 +185,36 @@ class TestXfroutSession(unittest.TestCase):
         self.assertTrue(self.xfrsess._tsig_ctx is not None)
         self.assertTrue(self.xfrsess._tsig_ctx is not None)
 
 
         # ACL using TSIG: successful case
         # ACL using TSIG: successful case
-        self.xfrsess._acl = isc.acl.dns.REQUEST_LOADER.load([
+        acl_setter(isc.acl.dns.REQUEST_LOADER.load([
             {"key": "example.com", "action": "ACCEPT"}, {"action": "REJECT"}
             {"key": "example.com", "action": "ACCEPT"}, {"action": "REJECT"}
-        ])
+        ]))
         self.assertEqual(TSIGKeyRing.SUCCESS,
         self.assertEqual(TSIGKeyRing.SUCCESS,
                          self.xfrsess._tsig_key_ring.add(TSIG_KEY))
                          self.xfrsess._tsig_key_ring.add(TSIG_KEY))
         [rcode, msg] = self.xfrsess._parse_query_message(request_data)
         [rcode, msg] = self.xfrsess._parse_query_message(request_data)
         self.assertEqual(rcode.to_text(), "NOERROR")
         self.assertEqual(rcode.to_text(), "NOERROR")
 
 
         # ACL using TSIG: key name doesn't match; should be rejected
         # ACL using TSIG: key name doesn't match; should be rejected
-        self.xfrsess._acl = isc.acl.dns.REQUEST_LOADER.load([
+        acl_setter(isc.acl.dns.REQUEST_LOADER.load([
             {"key": "example.org", "action": "ACCEPT"}, {"action": "REJECT"}
             {"key": "example.org", "action": "ACCEPT"}, {"action": "REJECT"}
-        ])
+        ]))
         [rcode, msg] = self.xfrsess._parse_query_message(request_data)
         [rcode, msg] = self.xfrsess._parse_query_message(request_data)
         self.assertEqual(rcode.to_text(), "REFUSED")
         self.assertEqual(rcode.to_text(), "REFUSED")
 
 
         # ACL using TSIG: no TSIG; should be rejected
         # ACL using TSIG: no TSIG; should be rejected
-        self.xfrsess._acl = isc.acl.dns.REQUEST_LOADER.load([
+        acl_setter(isc.acl.dns.REQUEST_LOADER.load([
             {"key": "example.org", "action": "ACCEPT"}, {"action": "REJECT"}
             {"key": "example.org", "action": "ACCEPT"}, {"action": "REJECT"}
-        ])
+        ]))
         [rcode, msg] = self.xfrsess._parse_query_message(self.mdata)
         [rcode, msg] = self.xfrsess._parse_query_message(self.mdata)
         self.assertEqual(rcode.to_text(), "REFUSED")
         self.assertEqual(rcode.to_text(), "REFUSED")
 
 
         #
         #
         # ACL using IP + TSIG: both should match
         # ACL using IP + TSIG: both should match
         #
         #
-        self.xfrsess._acl = isc.acl.dns.REQUEST_LOADER.load([
+        acl_setter(isc.acl.dns.REQUEST_LOADER.load([
                 {"ALL": [{"key": "example.com"}, {"from": "192.0.2.1"}],
                 {"ALL": [{"key": "example.com"}, {"from": "192.0.2.1"}],
                  "action": "ACCEPT"},
                  "action": "ACCEPT"},
                 {"action": "REJECT"}
                 {"action": "REJECT"}
-        ])
+        ]))
         # both matches
         # both matches
         self.xfrsess._remote = ('192.0.2.1', 12345)
         self.xfrsess._remote = ('192.0.2.1', 12345)
         [rcode, msg] = self.xfrsess._parse_query_message(request_data)
         [rcode, msg] = self.xfrsess._parse_query_message(request_data)
@@ -221,6 +232,63 @@ class TestXfroutSession(unittest.TestCase):
         [rcode, msg] = self.xfrsess._parse_query_message(self.mdata)
         [rcode, msg] = self.xfrsess._parse_query_message(self.mdata)
         self.assertEqual(rcode.to_text(), "REFUSED")
         self.assertEqual(rcode.to_text(), "REFUSED")
 
 
+    def test_transfer_acl(self):
+        # ACL checks only with the default ACL
+        def acl_setter(acl):
+            self.xfrsess._acl = acl
+        self.check_transfer_acl(acl_setter)
+
+    def test_transfer_zoneacl(self):
+        # ACL check with a per zone ACL + default ACL.  The per zone ACL
+        # should match the queryied zone, so it should be used.
+        def acl_setter(acl):
+            zone_key = ('IN', 'example.com.')
+            self.xfrsess._zone_config[zone_key] = {}
+            self.xfrsess._zone_config[zone_key]['transfer_acl'] = acl
+            self.xfrsess._acl = isc.acl.dns.REQUEST_LOADER.load([
+                    {"from": "127.0.0.1", "action": "DROP"}])
+        self.check_transfer_acl(acl_setter)
+
+    def test_transfer_zoneacl_nomatch(self):
+        # similar to the previous one, but the per zone doesn't match the
+        # query.  The default should be used.
+        def acl_setter(acl):
+            zone_key = ('IN', 'example.org.')
+            self.xfrsess._zone_config[zone_key] = {}
+            self.xfrsess._zone_config[zone_key]['transfer_acl'] = \
+                isc.acl.dns.REQUEST_LOADER.load([
+                    {"from": "127.0.0.1", "action": "DROP"}])
+            self.xfrsess._acl = acl
+        self.check_transfer_acl(acl_setter)
+
+    def test_get_transfer_acl(self):
+        # set the default ACL.  If there's no specific zone ACL, this one
+        # should be used.
+        self.xfrsess._acl = isc.acl.dns.REQUEST_LOADER.load([
+                {"from": "127.0.0.1", "action": "ACCEPT"}])
+        acl = self.xfrsess._get_transfer_acl(Name('example.com'), RRClass.IN())
+        self.assertEqual(acl, self.xfrsess._acl)
+
+        # install a per zone config with transfer ACL for example.com.  Then
+        # that ACL will be used for example.com; for others the default ACL
+        # will still be used.
+        com_acl = isc.acl.dns.REQUEST_LOADER.load([
+                {"from": "127.0.0.1", "action": "REJECT"}])
+        self.xfrsess._zone_config[('IN', 'example.com.')] = {}
+        self.xfrsess._zone_config[('IN', 'example.com.')]['transfer_acl'] = \
+            com_acl
+        self.assertEqual(com_acl,
+                         self.xfrsess._get_transfer_acl(Name('example.com'),
+                                                        RRClass.IN()))
+        self.assertEqual(self.xfrsess._acl,
+                         self.xfrsess._get_transfer_acl(Name('example.org'),
+                                                        RRClass.IN()))
+
+        # Name matching should be case insensitive.
+        self.assertEqual(com_acl,
+                         self.xfrsess._get_transfer_acl(Name('EXAMPLE.COM'),
+                                                        RRClass.IN()))
+
     def test_get_query_zone_name(self):
     def test_get_query_zone_name(self):
         msg = self.getmsg()
         msg = self.getmsg()
         self.assertEqual(self.xfrsess._get_query_zone_name(msg), "example.com.")
         self.assertEqual(self.xfrsess._get_query_zone_name(msg), "example.com.")
@@ -572,9 +640,11 @@ class TestXfroutSession(unittest.TestCase):
         # and it should not have sent anything else
         # and it should not have sent anything else
         self.assertEqual(0, len(self.sock.sendqueue))
         self.assertEqual(0, len(self.sock.sendqueue))
 
 
-class MyCCSession():
+class MyCCSession(isc.config.ConfigData):
     def __init__(self):
     def __init__(self):
-        pass
+        module_spec = isc.config.module_spec_from_file(
+            xfrout.SPECFILE_LOCATION)
+        ConfigData.__init__(self, module_spec)
 
 
     def get_remote_config_value(self, module_name, identifier):
     def get_remote_config_value(self, module_name, identifier):
         if module_name == "Auth" and identifier == "database_file":
         if module_name == "Auth" and identifier == "database_file":
@@ -586,9 +656,9 @@ class MyCCSession():
 class MyUnixSockServer(UnixSockServer):
 class MyUnixSockServer(UnixSockServer):
     def __init__(self):
     def __init__(self):
         self._shutdown_event = threading.Event()
         self._shutdown_event = threading.Event()
-        self._max_transfers_out = 10
-        self._cc = MyCCSession()
         self._common_init()
         self._common_init()
+        self._cc = MyCCSession()
+        self.update_config_data(self._cc.get_full_config())
 
 
 class TestUnixSockServer(unittest.TestCase):
 class TestUnixSockServer(unittest.TestCase):
     def setUp(self):
     def setUp(self):
@@ -636,17 +706,17 @@ class TestUnixSockServer(unittest.TestCase):
                                              socket.AI_NUMERICHOST)[0][4])
                                              socket.AI_NUMERICHOST)[0][4])
         self.assertEqual(isc.acl.acl.ACCEPT, self.unix._acl.execute(context))
         self.assertEqual(isc.acl.acl.ACCEPT, self.unix._acl.execute(context))
 
 
-    def check_loaded_ACL(self):
+    def check_loaded_ACL(self, acl):
         context = isc.acl.dns.RequestContext(socket.getaddrinfo("127.0.0.1",
         context = isc.acl.dns.RequestContext(socket.getaddrinfo("127.0.0.1",
                                              1234, 0, socket.SOCK_DGRAM,
                                              1234, 0, socket.SOCK_DGRAM,
                                              socket.IPPROTO_UDP,
                                              socket.IPPROTO_UDP,
                                              socket.AI_NUMERICHOST)[0][4])
                                              socket.AI_NUMERICHOST)[0][4])
-        self.assertEqual(isc.acl.acl.ACCEPT, self.unix._acl.execute(context))
+        self.assertEqual(isc.acl.acl.ACCEPT, acl.execute(context))
         context = isc.acl.dns.RequestContext(socket.getaddrinfo("192.0.2.1",
         context = isc.acl.dns.RequestContext(socket.getaddrinfo("192.0.2.1",
                                              1234, 0, socket.SOCK_DGRAM,
                                              1234, 0, socket.SOCK_DGRAM,
                                              socket.IPPROTO_UDP,
                                              socket.IPPROTO_UDP,
                                              socket.AI_NUMERICHOST)[0][4])
                                              socket.AI_NUMERICHOST)[0][4])
-        self.assertEqual(isc.acl.acl.REJECT, self.unix._acl.execute(context))
+        self.assertEqual(isc.acl.acl.REJECT, acl.execute(context))
 
 
     def test_update_config_data(self):
     def test_update_config_data(self):
         self.check_default_ACL()
         self.check_default_ACL()
@@ -671,14 +741,79 @@ class TestUnixSockServer(unittest.TestCase):
         self.assertEqual(self.unix.tsig_key_ring.size(), 0)
         self.assertEqual(self.unix.tsig_key_ring.size(), 0)
 
 
         # Load the ACL
         # Load the ACL
-        self.unix.update_config_data({'query_acl': [{'from': '127.0.0.1',
+        self.unix.update_config_data({'transfer_acl': [{'from': '127.0.0.1',
                                                'action': 'ACCEPT'}]})
                                                'action': 'ACCEPT'}]})
-        self.check_loaded_ACL()
+        self.check_loaded_ACL(self.unix._acl)
         # Pass a wrong data there and check it does not replace the old one
         # Pass a wrong data there and check it does not replace the old one
-        self.assertRaises(isc.acl.acl.LoaderError,
+        self.assertRaises(XfroutConfigError,
+                          self.unix.update_config_data,
+                          {'transfer_acl': ['Something bad']})
+        self.check_loaded_ACL(self.unix._acl)
+
+    def test_zone_config_data(self):
+        # By default, there's no specific zone config
+        self.assertEqual({}, self.unix._zone_config)
+
+        # Adding config for a specific zone.  The config is empty unless
+        # explicitly specified.
+        self.unix.update_config_data({'zone_config':
+                                          [{'origin': 'example.com',
+                                            'class': 'IN'}]})
+        self.assertEqual({}, self.unix._zone_config[('IN', 'example.com.')])
+
+        # zone class can be omitted
+        self.unix.update_config_data({'zone_config':
+                                          [{'origin': 'example.com'}]})
+        self.assertEqual({}, self.unix._zone_config[('IN', 'example.com.')])
+
+        # zone class, name are stored in the "normalized" form.  class
+        # strings are upper cased, names are down cased.
+        self.unix.update_config_data({'zone_config':
+                                          [{'origin': 'EXAMPLE.com'}]})
+        self.assertEqual({}, self.unix._zone_config[('IN', 'example.com.')])
+
+        # invalid zone class, name will result in exceptions
+        self.assertRaises(EmptyLabel,
+                          self.unix.update_config_data,
+                          {'zone_config': [{'origin': 'bad..example'}]})
+        self.assertRaises(InvalidRRClass,
+                          self.unix.update_config_data,
+                          {'zone_config': [{'origin': 'example.com',
+                                            'class': 'badclass'}]})
+
+        # Configuring a couple of more zones
+        self.unix.update_config_data({'zone_config':
+                                          [{'origin': 'example.com'},
+                                           {'origin': 'example.com',
+                                            'class': 'CH'},
+                                           {'origin': 'example.org'}]})
+        self.assertEqual({}, self.unix._zone_config[('IN', 'example.com.')])
+        self.assertEqual({}, self.unix._zone_config[('CH', 'example.com.')])
+        self.assertEqual({}, self.unix._zone_config[('IN', 'example.org.')])
+
+        # Duplicate data: should be rejected with an exception
+        self.assertRaises(XfroutConfigError,
+                          self.unix.update_config_data,
+                          {'zone_config': [{'origin': 'example.com'},
+                                           {'origin': 'example.org'},
+                                           {'origin': 'example.com'}]})
+
+    def test_zone_config_data_with_acl(self):
+        # Similar to the previous test, but with transfer_acl config
+        self.unix.update_config_data({'zone_config':
+                                          [{'origin': 'example.com',
+                                            'transfer_acl':
+                                                [{'from': '127.0.0.1',
+                                                  'action': 'ACCEPT'}]}]})
+        acl = self.unix._zone_config[('IN', 'example.com.')]['transfer_acl']
+        self.check_loaded_ACL(acl)
+
+        # invalid ACL syntax will be rejected with exception
+        self.assertRaises(XfroutConfigError,
                           self.unix.update_config_data,
                           self.unix.update_config_data,
-                          {'query_acl': ['Something bad']})
-        self.check_loaded_ACL()
+                          {'zone_config': [{'origin': 'example.com',
+                                            'transfer_acl':
+                                                [{'action': 'BADACTION'}]}]})
 
 
     def test_get_db_file(self):
     def test_get_db_file(self):
         self.assertEqual(self.unix.get_db_file(), "initdb.file")
         self.assertEqual(self.unix.get_db_file(), "initdb.file")

+ 136 - 46
src/bin/xfrout/xfrout.py.in

@@ -48,11 +48,23 @@ except ImportError as e:
     # must keep running, so we warn about it and move forward.
     # must keep running, so we warn about it and move forward.
     log.error(XFROUT_IMPORT, str(e))
     log.error(XFROUT_IMPORT, str(e))
 
 
-from isc.acl.acl import ACCEPT, REJECT, DROP
+from isc.acl.acl import ACCEPT, REJECT, DROP, LoaderError
 from isc.acl.dns import REQUEST_LOADER
 from isc.acl.dns import REQUEST_LOADER
 
 
 isc.util.process.rename()
 isc.util.process.rename()
 
 
+class XfroutConfigError(Exception):
+    """An exception indicating an error in updating xfrout configuration.
+
+    This exception is raised when the xfrout process encouters an error in
+    handling configuration updates.  Not all syntax error can be caught
+    at the module-CC layer, so xfrout needs to (explicitly or implicitly)
+    validate the given configuration data itself.  When it finds an error
+    it raises this exception (either directly or by converting an exception
+    from other modules) as a unified error in configuration.
+    """
+    pass
+
 def init_paths():
 def init_paths():
     global SPECFILE_PATH
     global SPECFILE_PATH
     global AUTH_SPECFILE_PATH
     global AUTH_SPECFILE_PATH
@@ -79,14 +91,12 @@ init_paths()
 
 
 SPECFILE_LOCATION = SPECFILE_PATH + "/xfrout.spec"
 SPECFILE_LOCATION = SPECFILE_PATH + "/xfrout.spec"
 AUTH_SPECFILE_LOCATION = AUTH_SPECFILE_PATH + os.sep + "auth.spec"
 AUTH_SPECFILE_LOCATION = AUTH_SPECFILE_PATH + os.sep + "auth.spec"
-MAX_TRANSFERS_OUT = 10
 VERBOSE_MODE = False
 VERBOSE_MODE = False
 # tsig sign every N axfr packets.
 # tsig sign every N axfr packets.
 TSIG_SIGN_EVERY_NTH = 96
 TSIG_SIGN_EVERY_NTH = 96
 
 
 XFROUT_MAX_MESSAGE_SIZE = 65535
 XFROUT_MAX_MESSAGE_SIZE = 65535
 
 
-
 def get_rrset_len(rrset):
 def get_rrset_len(rrset):
     """Returns the wire length of the given RRset"""
     """Returns the wire length of the given RRset"""
     bytes = bytearray()
     bytes = bytearray()
@@ -96,7 +106,7 @@ def get_rrset_len(rrset):
 
 
 class XfroutSession():
 class XfroutSession():
     def __init__(self, sock_fd, request_data, server, tsig_key_ring, remote,
     def __init__(self, sock_fd, request_data, server, tsig_key_ring, remote,
-                 acl):
+                 default_acl, zone_config):
         self._sock_fd = sock_fd
         self._sock_fd = sock_fd
         self._request_data = request_data
         self._request_data = request_data
         self._server = server
         self._server = server
@@ -104,7 +114,8 @@ class XfroutSession():
         self._tsig_ctx = None
         self._tsig_ctx = None
         self._tsig_len = 0
         self._tsig_len = 0
         self._remote = remote
         self._remote = remote
-        self._acl = acl
+        self._acl = default_acl
+        self._zone_config = zone_config
         self.handle()
         self.handle()
 
 
     def create_tsig_ctx(self, tsig_record, tsig_key_ring):
     def create_tsig_ctx(self, tsig_record, tsig_key_ring):
@@ -140,34 +151,50 @@ class XfroutSession():
         try:
         try:
             msg = Message(Message.PARSE)
             msg = Message(Message.PARSE)
             Message.from_wire(msg, mdata)
             Message.from_wire(msg, mdata)
-
-            # TSIG related checks
-            rcode = self._check_request_tsig(msg, mdata)
-
-            if rcode == Rcode.NOERROR():
-                # ACL checks
-                acl_result = self._acl.execute(
-                    isc.acl.dns.RequestContext(self._remote,
-                                               msg.get_tsig_record()))
-                if acl_result == DROP:
-                    logger.info(XFROUT_QUERY_DROPPED,
-                                self._get_query_zone_name(msg),
-                                self._get_query_zone_class(msg),
-                                self._remote[0], self._remote[1])
-                    return None, None
-                elif acl_result == REJECT:
-                    logger.info(XFROUT_QUERY_REJECTED,
-                                self._get_query_zone_name(msg),
-                                self._get_query_zone_class(msg),
-                                self._remote[0], self._remote[1])
-                    return Rcode.REFUSED(), msg
-
-        except Exception as err:
+        except Exception as err: # Exception is too broad
             logger.error(XFROUT_PARSE_QUERY_ERROR, err)
             logger.error(XFROUT_PARSE_QUERY_ERROR, err)
             return Rcode.FORMERR(), None
             return Rcode.FORMERR(), None
 
 
+        # TSIG related checks
+        rcode = self._check_request_tsig(msg, mdata)
+
+        if rcode == Rcode.NOERROR():
+            # ACL checks
+            zone_name = msg.get_question()[0].get_name()
+            zone_class = msg.get_question()[0].get_class()
+            acl = self._get_transfer_acl(zone_name, zone_class)
+            acl_result = acl.execute(
+                isc.acl.dns.RequestContext(self._remote,
+                                           msg.get_tsig_record()))
+            if acl_result == DROP:
+                logger.info(XFROUT_QUERY_DROPPED, zone_name, zone_class,
+                            self._remote[0], self._remote[1])
+                return None, None
+            elif acl_result == REJECT:
+                logger.info(XFROUT_QUERY_REJECTED, zone_name, zone_class,
+                            self._remote[0], self._remote[1])
+                return Rcode.REFUSED(), msg
+
         return rcode, msg
         return rcode, msg
 
 
+    def _get_transfer_acl(self, zone_name, zone_class):
+        '''Return the ACL that should be applied for a given zone.
+
+        The zone is identified by a tuple of name and RR class.
+        If a per zone configuration for the zone exists and contains
+        transfer_acl, that ACL will be used; otherwise, the default
+        ACL will be used.
+
+        '''
+        # Internally zone names are managed in lower cased label characters,
+        # so we first need to convert the name.
+        zone_name_lower = Name(zone_name.to_text(), True)
+        config_key = (zone_class.to_text(), zone_name_lower.to_text())
+        if config_key in self._zone_config and \
+                'transfer_acl' in self._zone_config[config_key]:
+            return self._zone_config[config_key]['transfer_acl']
+        return self._acl
+
     def _get_query_zone_name(self, msg):
     def _get_query_zone_name(self, msg):
         question = msg.get_question()[0]
         question = msg.get_question()[0]
         return question.get_name().to_text()
         return question.get_name().to_text()
@@ -384,10 +411,12 @@ class XfroutSession():
         self._send_message_with_last_soa(msg, sock_fd, rrset_soa, message_upper_len,
         self._send_message_with_last_soa(msg, sock_fd, rrset_soa, message_upper_len,
                                          count_since_last_tsig_sign)
                                          count_since_last_tsig_sign)
 
 
-class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
+class UnixSockServer(socketserver_mixin.NoPollMixIn,
+                     ThreadingUnixStreamServer):
     '''The unix domain socket server which accept xfr query sent from auth server.'''
     '''The unix domain socket server which accept xfr query sent from auth server.'''
 
 
-    def __init__(self, sock_file, handle_class, shutdown_event, config_data, cc):
+    def __init__(self, sock_file, handle_class, shutdown_event, config_data,
+                 cc):
         self._remove_unused_sock_file(sock_file)
         self._remove_unused_sock_file(sock_file)
         self._sock_file = sock_file
         self._sock_file = sock_file
         socketserver_mixin.NoPollMixIn.__init__(self)
         socketserver_mixin.NoPollMixIn.__init__(self)
@@ -395,16 +424,15 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
         self._shutdown_event = shutdown_event
         self._shutdown_event = shutdown_event
         self._write_sock, self._read_sock = socket.socketpair()
         self._write_sock, self._read_sock = socket.socketpair()
         self._common_init()
         self._common_init()
-        self.update_config_data(config_data)
         self._cc = cc
         self._cc = cc
+        self.update_config_data(config_data)
 
 
     def _common_init(self):
     def _common_init(self):
+        '''Initialization shared with the mock server class used for tests'''
         self._lock = threading.Lock()
         self._lock = threading.Lock()
         self._transfers_counter = 0
         self._transfers_counter = 0
-        # This default value will probably get overwritten by the (same)
-        # default value from the spec file. This is here just to make
-        # sure and to make the default value in tests consistent.
-        self._acl = REQUEST_LOADER.load('[{"action": "ACCEPT"}]')
+        self._zone_config = {}
+        self._acl = None # this will be initialized in update_config_data()
 
 
     def _receive_query_message(self, sock):
     def _receive_query_message(self, sock):
         ''' receive request message from sock'''
         ''' receive request message from sock'''
@@ -482,7 +510,7 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
         if not request_data:
         if not request_data:
             return
             return
 
 
-        t = threading.Thread(target = self.finish_request,
+        t = threading.Thread(target=self.finish_request,
                              args = (sock_fd, request_data))
                              args = (sock_fd, request_data))
         if self.daemon_threads:
         if self.daemon_threads:
             t.daemon = True
             t.daemon = True
@@ -506,10 +534,17 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
         return sock.getpeername()
         return sock.getpeername()
 
 
     def finish_request(self, sock_fd, request_data):
     def finish_request(self, sock_fd, request_data):
-        '''Finish one request by instantiating RequestHandlerClass.'''
+        '''Finish one request by instantiating RequestHandlerClass.
+
+        This method creates a XfroutSession object.
+        '''
+        self._lock.acquire()
+        acl = self._acl
+        zone_config = self._zone_config
+        self._lock.release()
         self.RequestHandlerClass(sock_fd, request_data, self,
         self.RequestHandlerClass(sock_fd, request_data, self,
                                  self.tsig_key_ring,
                                  self.tsig_key_ring,
-                                 self._guess_remote(sock_fd), self._acl)
+                                 self._guess_remote(sock_fd), acl, zone_config)
 
 
     def _remove_unused_sock_file(self, sock_file):
     def _remove_unused_sock_file(self, sock_file):
         '''Try to remove the socket file. If the file is being used
         '''Try to remove the socket file. If the file is being used
@@ -551,16 +586,65 @@ class UnixSockServer(socketserver_mixin.NoPollMixIn, ThreadingUnixStreamServer):
             pass
             pass
 
 
     def update_config_data(self, new_config):
     def update_config_data(self, new_config):
-        '''Apply the new config setting of xfrout module. '''
-        logger.info(XFROUT_NEW_CONFIG)
-        if 'query_acl' in new_config:
-            self._acl = REQUEST_LOADER.load(new_config['query_acl'])
+        '''Apply the new config setting of xfrout module.
+
+        '''
         self._lock.acquire()
         self._lock.acquire()
-        self._max_transfers_out = new_config.get('transfers_out')
-        self.set_tsig_key_ring(new_config.get('tsig_key_ring'))
+        try:
+            logger.info(XFROUT_NEW_CONFIG)
+            new_acl = self._acl
+            if 'transfer_acl' in new_config:
+                try:
+                    new_acl = REQUEST_LOADER.load(new_config['transfer_acl'])
+                except LoaderError as e:
+                    raise XfroutConfigError('Failed to parse transfer_acl: ' +
+                                            str(e))
+
+            new_zone_config = self._zone_config
+            zconfig_data = new_config.get('zone_config')
+            if zconfig_data is not None:
+                new_zone_config = self.__create_zone_config(zconfig_data)
+
+            self._acl = new_acl
+            self._zone_config = new_zone_config
+            self._max_transfers_out = new_config.get('transfers_out')
+            self.set_tsig_key_ring(new_config.get('tsig_key_ring'))
+        except Exception as e:
+            self._lock.release()
+            raise e
         self._lock.release()
         self._lock.release()
         logger.info(XFROUT_NEW_CONFIG_DONE)
         logger.info(XFROUT_NEW_CONFIG_DONE)
 
 
+    def __create_zone_config(self, zone_config_list):
+        new_config = {}
+        for zconf in zone_config_list:
+            # convert the class, origin (name) pair.  First build pydnspp
+            # object to reject invalid input.
+            zclass_str = zconf.get('class')
+            if zclass_str is None:
+                #zclass_str = 'IN' # temporary
+                zclass_str = self._cc.get_default_value('zone_config/class')
+            zclass = RRClass(zclass_str)
+            zorigin = Name(zconf['origin'], True)
+            config_key = (zclass.to_text(), zorigin.to_text())
+
+            # reject duplicate config
+            if config_key in new_config:
+                raise XfroutConfigError('Duplicate zone_config for ' +
+                                        str(zorigin) + '/' + str(zclass))
+
+            # create a new config entry, build any given (and known) config
+            new_config[config_key] = {}
+            if 'transfer_acl' in zconf:
+                try:
+                    new_config[config_key]['transfer_acl'] = \
+                        REQUEST_LOADER.load(zconf['transfer_acl'])
+                except LoaderError as e:
+                    raise XfroutConfigError('Failed to parse transfer_acl ' +
+                                            'for ' + zorigin.to_text() + '/' +
+                                            zclass_str + ': ' + str(e))
+        return new_config
+
     def set_tsig_key_ring(self, key_list):
     def set_tsig_key_ring(self, key_list):
         """Set the tsig_key_ring , given a TSIG key string list representation. """
         """Set the tsig_key_ring , given a TSIG key string list representation. """
 
 
@@ -617,8 +701,10 @@ class XfroutServer:
 
 
     def _start_xfr_query_listener(self):
     def _start_xfr_query_listener(self):
         '''Start a new thread to accept xfr query. '''
         '''Start a new thread to accept xfr query. '''
-        self._unix_socket_server = UnixSockServer(self._listen_sock_file, XfroutSession,
-                                                  self._shutdown_event, self._config_data,
+        self._unix_socket_server = UnixSockServer(self._listen_sock_file,
+                                                  XfroutSession,
+                                                  self._shutdown_event,
+                                                  self._config_data,
                                                   self._cc)
                                                   self._cc)
         listener = threading.Thread(target=self._unix_socket_server.serve_forever)
         listener = threading.Thread(target=self._unix_socket_server.serve_forever)
         listener.start()
         listener.start()
@@ -726,6 +812,10 @@ if '__main__' == __name__:
         logger.INFO(XFROUT_STOPPED_BY_KEYBOARD)
         logger.INFO(XFROUT_STOPPED_BY_KEYBOARD)
     except SessionError as e:
     except SessionError as e:
         logger.error(XFROUT_CC_SESSION_ERROR, str(e))
         logger.error(XFROUT_CC_SESSION_ERROR, str(e))
+    except ModuleCCSessionError as e:
+        logger.error(XFROUT_MODULECC_SESSION_ERROR, str(e))
+    except XfroutConfigError as e:
+        logger.error(XFROUT_CONFIG_ERROR, str(e))
     except SessionTimeout as e:
     except SessionTimeout as e:
         logger.error(XFROUT_CC_SESSION_TIMEOUT_ERROR)
         logger.error(XFROUT_CC_SESSION_TIMEOUT_ERROR)
 
 

+ 40 - 1
src/bin/xfrout/xfrout.spec.pre.in

@@ -51,7 +51,7 @@
          }
          }
        },
        },
        {
        {
-         "item_name": "query_acl",
+         "item_name": "transfer_acl",
          "item_type": "list",
          "item_type": "list",
          "item_optional": false,
          "item_optional": false,
          "item_default": [{"action": "ACCEPT"}],
          "item_default": [{"action": "ACCEPT"}],
@@ -61,6 +61,45 @@
              "item_type": "any",
              "item_type": "any",
              "item_optional": true
              "item_optional": true
          }
          }
+       },
+       {
+         "item_name": "zone_config",
+         "item_type": "list",
+         "item_optional": true,
+         "item_default": [],
+         "list_item_spec":
+         {
+             "item_name": "zone_config_element",
+             "item_type": "map",
+             "item_optional": true,
+             "item_default": { "origin": "" },
+             "map_item_spec": [
+               {
+                   "item_name": "origin",
+                   "item_type": "string",
+                   "item_optional": false,
+                   "item_default": ""
+               },
+               {
+                   "item_name": "class",
+                   "item_type": "string",
+                   "item_optional": false,
+                   "item_default": "IN"
+               },
+               {
+                   "item_name": "transfer_acl",
+                   "item_type": "list",
+                   "item_optional": true,
+                   "item_default": [{"action": "ACCEPT"}],
+                   "list_item_spec":
+                   {
+                       "item_name": "acl_element",
+                       "item_type": "any",
+                       "item_optional": true
+                   }
+               }
+             ]
+         }
        }
        }
       ],
       ],
       "commands": [
       "commands": [

+ 11 - 0
src/bin/xfrout/xfrout_messages.mes

@@ -47,6 +47,17 @@ a valid TSIG key.
 There was a problem reading from the command and control channel. The
 There was a problem reading from the command and control channel. The
 most likely cause is that the msgq daemon is not running.
 most likely cause is that the msgq daemon is not running.
 
 
+% XFROUT_MODULECC_SESSION_ERROR error encountered by configuration/command module: %1
+There was a problem in the lower level module handling configuration and
+control commands.  This could happen for various reasons, but the most likely
+cause is that the configuration database contains a syntax error and xfrout
+failed to start at initialization.  A detailed error message from the module
+will also be displayed.
+
+% XFROUT_CONFIG_ERROR error found in configuration data: %1
+The xfrout process encountered an error when installing the configuration at
+startup time.  Details of the error are included in the log message.
+
 % XFROUT_CC_SESSION_TIMEOUT_ERROR timeout waiting for cc response
 % XFROUT_CC_SESSION_TIMEOUT_ERROR timeout waiting for cc response
 There was a problem reading a response from another module over the
 There was a problem reading a response from another module over the
 command and control channel. The most likely cause is that the
 command and control channel. The most likely cause is that the