Browse Source

[1372] updated _check_xfrout_available to support IXFR. many corner cases
are still ignored.

JINMEI Tatuya 13 years ago
parent
commit
5c92f567d9
2 changed files with 118 additions and 25 deletions
  1. 64 7
      src/bin/xfrout/tests/xfrout_test.py.in
  2. 54 18
      src/bin/xfrout/xfrout.py.in

+ 64 - 7
src/bin/xfrout/tests/xfrout_test.py.in

@@ -30,6 +30,20 @@ import isc.acl.dns
 TESTDATA_SRCDIR = os.getenv("TESTDATASRCDIR")
 TESTDATA_SRCDIR = os.getenv("TESTDATASRCDIR")
 TSIG_KEY = TSIGKey("example.com:SFuWd/q99SzF8Yzd1QbB9g==")
 TSIG_KEY = TSIGKey("example.com:SFuWd/q99SzF8Yzd1QbB9g==")
 
 
+#
+# Commonly used (mostly constant) test parameters
+#
+TEST_ZONE_NAME_STR = "example.com."
+TEST_ZONE_NAME = Name(TEST_ZONE_NAME_STR)
+TEST_RRCLASS = RRClass.IN()
+
+# SOA intended to be used for the new SOA as a result of transfer.
+soa_rdata = Rdata(RRType.SOA(), TEST_RRCLASS,
+                  'master.example.com. admin.example.com ' +
+                  '1234 3600 1800 2419200 7200')
+soa_rrset = RRset(TEST_ZONE_NAME, TEST_RRCLASS, RRType.SOA(), RRTTL(3600))
+soa_rrset.add_rdata(soa_rdata)
+
 # our fake socket, where we can read and insert messages
 # our fake socket, where we can read and insert messages
 class MySocket():
 class MySocket():
     def __init__(self, family, type):
     def __init__(self, family, type):
@@ -69,6 +83,30 @@ class MockDataSrcClient:
     def __init__(self, type, config):
     def __init__(self, type, config):
         pass
         pass
 
 
+    def find_zone(self, zone_name):
+        '''Mock version of find_zone().
+
+        It returns itself (subsequently acting as a mock ZoneFinder) for
+        some test zone names.  For some others it returns either NOTFOUND
+        or PARTIALMATCH.
+
+        '''
+        if zone_name == TEST_ZONE_NAME:
+            return (isc.datasrc.DataSourceClient.SUCCESS, self)
+        raise ValueError('Unexpected input to mock client: bug in test case?')
+
+    def find(self, name, rrtype, target, options):
+        '''Mock ZoneFinder.find().
+
+        It returns the predefined SOA RRset to queries for SOA of the common
+        test zone name.  It also emulates some unusual cases for special
+        zone names.
+
+        '''
+        if name == TEST_ZONE_NAME and rrtype == RRType.SOA():
+            return (ZoneFinder.SUCCESS, soa_rrset)
+        raise ValueError('Unexpected input to mock finder: bug in test case?')
+
     def get_iterator(self, zone_name, adjust_ttl=False):
     def get_iterator(self, zone_name, adjust_ttl=False):
         if zone_name == Name('notauth.example.com'):
         if zone_name == Name('notauth.example.com'):
             raise isc.datasrc.Error('no such zone')
             raise isc.datasrc.Error('no such zone')
@@ -91,6 +129,9 @@ class MockDataSrcClient:
                                       '3600 1800 2419200 7200'))
                                       '3600 1800 2419200 7200'))
         return soa_rrset
         return soa_rrset
 
 
+    def get_journal_reader(self, zone_name, begin_serial, end_serial):
+        return isc.datasrc.ZoneJournalReader.SUCCESS, self
+
 class MyCCSession(isc.config.ConfigData):
 class MyCCSession(isc.config.ConfigData):
     def __init__(self):
     def __init__(self):
         module_spec = isc.config.module_spec_from_file(
         module_spec = isc.config.module_spec_from_file(
@@ -195,6 +236,13 @@ class TestXfroutSessionBase(unittest.TestCase):
         request_data = renderer.get_data()
         request_data = renderer.get_data()
         return request_data
         return request_data
 
 
+    def set_request_type(self, type):
+        self.xfrsess._request_type = type
+        if type == RRType.AXFR():
+            self.xfrsess._request_typestr = 'AXFR'
+        else:
+            self.xfrsess._request_typestr = 'IXFR'
+
     def setUp(self):
     def setUp(self):
         self.sock = MySocket(socket.AF_INET,socket.SOCK_STREAM)
         self.sock = MySocket(socket.AF_INET,socket.SOCK_STREAM)
         self.xfrsess = MyXfroutSession(self.sock, None, Dbserver(),
         self.xfrsess = MyXfroutSession(self.sock, None, Dbserver(),
@@ -205,6 +253,7 @@ class TestXfroutSessionBase(unittest.TestCase):
                                        isc.acl.dns.REQUEST_LOADER.load(
                                        isc.acl.dns.REQUEST_LOADER.load(
                                            [{"action": "ACCEPT"}]),
                                            [{"action": "ACCEPT"}]),
                                        {})
                                        {})
+        self.set_request_type(RRType.AXFR()) # test AXFR by default
         self.mdata = self.create_request_data()
         self.mdata = self.create_request_data()
         self.soa_rrset = RRset(Name('example.com'), RRClass.IN(), RRType.SOA(),
         self.soa_rrset = RRset(Name('example.com'), RRClass.IN(), RRType.SOA(),
                                RRTTL(3600))
                                RRTTL(3600))
@@ -612,16 +661,24 @@ class TestXfroutSession(TestXfroutSessionBase):
     def test_get_rrset_len(self):
     def test_get_rrset_len(self):
         self.assertEqual(82, get_rrset_len(self.soa_rrset))
         self.assertEqual(82, get_rrset_len(self.soa_rrset))
 
 
-    def test_check_xfrout_available(self):
+    def test_check_xfrout_axfr_available(self):
         self.xfrsess.ClientClass = MockDataSrcClient
         self.xfrsess.ClientClass = MockDataSrcClient
         self.assertEqual(self.xfrsess._check_xfrout_available(
         self.assertEqual(self.xfrsess._check_xfrout_available(
-                Name('example.com')), Rcode.NOERROR())
+                self.getmsg(), Name('example.com')), Rcode.NOERROR())
+        self.assertEqual(self.xfrsess._check_xfrout_available(
+                self.getmsg(), Name('notauth.example.com')), Rcode.NOTAUTH())
         self.assertEqual(self.xfrsess._check_xfrout_available(
         self.assertEqual(self.xfrsess._check_xfrout_available(
-                Name('notauth.example.com')), Rcode.NOTAUTH())
+                self.getmsg(), Name('nosoa.example.com')), Rcode.SERVFAIL())
         self.assertEqual(self.xfrsess._check_xfrout_available(
         self.assertEqual(self.xfrsess._check_xfrout_available(
-                Name('nosoa.example.com')), Rcode.SERVFAIL())
+                self.getmsg(), Name('multisoa.example.com')), Rcode.SERVFAIL())
+
+    def test_check_xfrout_ixfr_available(self):
+        self.xfrsess.ClientClass = MockDataSrcClient
+        self.set_request_type(RRType.IXFR())
+        self.mdata = self.create_request_data(ixfr=2011111802)
+        request_msg = self.getmsg()
         self.assertEqual(self.xfrsess._check_xfrout_available(
         self.assertEqual(self.xfrsess._check_xfrout_available(
-                Name('multisoa.example.com')), Rcode.SERVFAIL())
+                self.getmsg(), Name('example.com')), Rcode.NOERROR())
 
 
     def test_dns_xfrout_start_formerror(self):
     def test_dns_xfrout_start_formerror(self):
         # formerror
         # formerror
@@ -633,7 +690,7 @@ class TestXfroutSession(TestXfroutSessionBase):
         return "example.com"
         return "example.com"
 
 
     def test_dns_xfrout_start_notauth(self):
     def test_dns_xfrout_start_notauth(self):
-        def notauth(formpara):
+        def notauth(msg, name):
             return Rcode.NOTAUTH()
             return Rcode.NOTAUTH()
         self.xfrsess._check_xfrout_available = notauth
         self.xfrsess._check_xfrout_available = notauth
         self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
         self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
@@ -648,7 +705,7 @@ class TestXfroutSession(TestXfroutSessionBase):
         self.assertEqual(self.sock.read_msg().get_rcode(), Rcode.SERVFAIL())
         self.assertEqual(self.sock.read_msg().get_rcode(), Rcode.SERVFAIL())
 
 
     def test_dns_xfrout_start_noerror(self):
     def test_dns_xfrout_start_noerror(self):
-        def noerror(form):
+        def noerror(msg, name):
             return Rcode.NOERROR()
             return Rcode.NOERROR()
         self.xfrsess._check_xfrout_available = noerror
         self.xfrsess._check_xfrout_available = noerror
 
 

+ 54 - 18
src/bin/xfrout/xfrout.py.in

@@ -22,7 +22,7 @@ import isc.cc
 import threading
 import threading
 import struct
 import struct
 import signal
 import signal
-from isc.datasrc import DataSourceClient
+from isc.datasrc import DataSourceClient, ZoneFinder
 from socketserver import *
 from socketserver import *
 import os
 import os
 from isc.config.ccsession import *
 from isc.config.ccsession import *
@@ -132,6 +132,11 @@ def get_rrset_len(rrset):
     rrset.to_wire(bytes)
     rrset.to_wire(bytes)
     return len(bytes)
     return len(bytes)
 
 
+def get_soa_serial(soa_rdata):
+    '''Extract the serial field of an SOA RDATA and returns it as an intger.
+    (borrowed from xfrin)
+    '''
+    return int(soa_rdata.to_text().split()[2])
 
 
 class XfroutSession():
 class XfroutSession():
     def __init__(self, sock_fd, request_data, server, tsig_key_ring, remote,
     def __init__(self, sock_fd, request_data, server, tsig_key_ring, remote,
@@ -308,7 +313,22 @@ class XfroutSession():
         msg.set_rcode(rcode_)
         msg.set_rcode(rcode_)
         self._send_message(sock_fd, msg, self._tsig_ctx)
         self._send_message(sock_fd, msg, self._tsig_ctx)
 
 
-    def _check_xfrout_available(self, zone_name):
+    def _get_zone_soa(self, zone_name):
+        result, finder = self._datasrc_client.find_zone(zone_name)
+        if result != DataSourceClient.SUCCESS:
+            return None         # XXX
+        result, soa_rrset = finder.find(zone_name, RRType.SOA(),
+                                        None, ZoneFinder.FIND_DEFAULT)
+        if result != ZoneFinder.SUCCESS:
+            return None
+        # Especially for database-based zones, a working zone may be in
+        # a broken state where it has more than one SOA RR.  We proactively
+        # check the condition and abort the xfr attempt if we identify it.
+        if soa_rrset.get_rdata_count() != 1:
+            return None
+        return soa_rrset
+
+    def _check_xfrout_available(self, request_msg, zone_name):
         '''Check if xfr request can be responsed.
         '''Check if xfr request can be responsed.
            TODO, Get zone's configuration from cfgmgr or some other place
            TODO, Get zone's configuration from cfgmgr or some other place
            eg. check allow_transfer setting,
            eg. check allow_transfer setting,
@@ -325,25 +345,41 @@ class XfroutSession():
         datasrc_config = '{ "database_file": "' + \
         datasrc_config = '{ "database_file": "' + \
             self._server.get_db_file() + '"}'
             self._server.get_db_file() + '"}'
         self._datasrc_client = self.ClientClass('sqlite3', datasrc_config)
         self._datasrc_client = self.ClientClass('sqlite3', datasrc_config)
-        try:
-            # Note that we disable 'adjust_ttl'.  In xfr-out we need to
-            # preserve as many things as possible (even if it's half broken)
-            # stored in the zone.
-            self._iterator = self._datasrc_client.get_iterator(zone_name,
-                                                               False)
-        except isc.datasrc.Error:
-            # If the current name server does not have authority for the
-            # zone, xfrout can't serve for it, return rcode NOTAUTH.
-            # Note: this exception can happen for other reasons.  We should
-            # update get_iterator() API so that we can distinguish "no such
-            # zone" and other cases (#1373).  For now we consider all these
-            # cases as NOTAUTH.
-            return Rcode.NOTAUTH()
+
+        if self._request_type == RRType.AXFR():
+            try:
+                # Note that we disable 'adjust_ttl'.  In xfr-out we need to
+                # preserve as many things as possible (even if it's half
+                # broken) stored in the zone.
+                self._iterator = self._datasrc_client.get_iterator(zone_name,
+                                                                   False)
+            except isc.datasrc.Error:
+                # If the current name server does not have authority for the
+                # zone, xfrout can't serve for it, return rcode NOTAUTH.
+                # Note: this exception can happen for other reasons.  We should
+                # update get_iterator() API so that we can distinguish "no such
+                # zone" and other cases (#1373).  For now we consider all these
+                # cases as NOTAUTH.
+                return Rcode.NOTAUTH()
+
+            self._soa = self._iterator.get_soa()
+        else:
+            # TODO: error case handling
+            remote_soa = None
+            for auth_rrset in \
+                    request_msg.get_section(Message.SECTION_AUTHORITY):
+                if auth_rrset.get_type() != RRType.SOA():
+                    continue
+                remote_soa = auth_rrset
+            self._soa = self._get_zone_soa(remote_soa.get_name())
+            code, self._jnl_reader = self._datasrc_client.get_journal_reader(
+                remote_soa.get_name(),
+                get_soa_serial(remote_soa.get_rdata()[0]),
+                get_soa_serial(self._soa.get_rdata()[0]))
 
 
         # If we are an authoritative name server for the zone, but fail
         # If we are an authoritative name server for the zone, but fail
         # to find the zone's SOA record in datasource, xfrout can't
         # to find the zone's SOA record in datasource, xfrout can't
         # provide zone transfer for it.
         # provide zone transfer for it.
-        self._soa = self._iterator.get_soa()
         if self._soa is None or self._soa.get_rdata_count() != 1:
         if self._soa is None or self._soa.get_rdata_count() != 1:
             return Rcode.SERVFAIL()
             return Rcode.SERVFAIL()
 
 
@@ -374,7 +410,7 @@ class XfroutSession():
 
 
         # TODO: we should also include class in the check
         # TODO: we should also include class in the check
         try:
         try:
-            rcode_ = self._check_xfrout_available(zone_name)
+            rcode_ = self._check_xfrout_available(msg, zone_name)
         except Exception as ex:
         except Exception as ex:
             logger.error(XFROUT_XFR_TRANSFER_CHECK_ERROR, self._request_typestr,
             logger.error(XFROUT_XFR_TRANSFER_CHECK_ERROR, self._request_typestr,
                          format_addrinfo(self._remote), zone_str, ex)
                          format_addrinfo(self._remote), zone_str, ex)