Browse Source

[1165] added _get_transfer_acl to XfroutSession class to retrieve the
best matching ACL for the given zone.

JINMEI Tatuya 13 years ago
parent
commit
e602f86dae
2 changed files with 47 additions and 0 deletions
  1. 28 0
      src/bin/xfrout/tests/xfrout_test.py.in
  2. 19 0
      src/bin/xfrout/xfrout.py.in

+ 28 - 0
src/bin/xfrout/tests/xfrout_test.py.in

@@ -235,6 +235,34 @@ class TestXfroutSession(unittest.TestCase):
             self.xfrsess._acl = acl
         self.check_transfer_acl(acl_setter)
 
+    def test_get_transfer_acl(self):
+        # set the default ACL.  If there's no specific zone ACL, this one
+        # should be used.
+        self.xfrsess._acl = isc.acl.dns.REQUEST_LOADER.load([
+                {"from": "127.0.0.1", "action": "ACCEPT"}])
+        acl = self.xfrsess._get_transfer_acl(Name('example.com'), RRClass.IN())
+        self.assertEqual(acl, self.xfrsess._acl)
+
+        # install a per zone config with transfer ACL for example.com.  Then
+        # that ACL will be used for example.com; for others the default ACL
+        # will still be used.
+        com_acl = isc.acl.dns.REQUEST_LOADER.load([
+                {"from": "127.0.0.1", "action": "REJECT"}])
+        self.xfrsess._zone_config[('example.com.', 'IN')] = {}
+        self.xfrsess._zone_config[('example.com.', 'IN')]['transfer_acl'] = \
+            com_acl
+        self.assertEqual(com_acl,
+                         self.xfrsess._get_transfer_acl(Name('example.com'),
+                                                        RRClass.IN()))
+        self.assertEqual(self.xfrsess._acl,
+                         self.xfrsess._get_transfer_acl(Name('example.org'),
+                                                        RRClass.IN()))
+
+        # Name matching should be case insensitive.
+        self.assertEqual(com_acl,
+                         self.xfrsess._get_transfer_acl(Name('EXAMPLE.COM'),
+                                                        RRClass.IN()))
+
     def test_get_query_zone_name(self):
         msg = self.getmsg()
         self.assertEqual(self.xfrsess._get_query_zone_name(msg), "example.com.")

+ 19 - 0
src/bin/xfrout/xfrout.py.in

@@ -108,6 +108,7 @@ class XfroutSession():
         self._tsig_len = 0
         self._remote = remote
         self._acl = acl
+        self._zone_config = {}
         self.handle()
 
     def create_tsig_ctx(self, tsig_record, tsig_key_ring):
@@ -171,6 +172,24 @@ class XfroutSession():
 
         return rcode, msg
 
+    def _get_transfer_acl(self, zone_name, zone_class):
+        '''Return the ACL that should be applied for a given zone.
+
+        The zone is identified by a tuple of name and RR class.
+        If a per zone configuration for the zone exists and contains
+        transfer_acl, that ACL will be used; otherwise, the default
+        ACL will be used.
+
+        '''
+        # Internally zone names are managed in lower cased label characters,
+        # so we first need to convert the name.
+        zone_name_lower = Name(zone_name.to_text(), True)
+        config_key = (zone_name_lower.to_text(), zone_class.to_text())
+        if config_key in self._zone_config and \
+                'transfer_acl' in self._zone_config[config_key]:
+            return self._zone_config[config_key]['transfer_acl']
+        return self._acl
+
     def _get_query_zone_name(self, msg):
         question = msg.get_question()[0]
         return question.get_name().to_text()