Browse Source

[1371] Merge branch 'trac1372' into trac1371 with fixing conflicts.

JINMEI Tatuya 13 years ago
parent
commit
1d4541dfd0

BIN
src/bin/xfrout/tests/testdata/test.sqlite3


+ 245 - 35
src/bin/xfrout/tests/xfrout_test.py.in

@@ -30,6 +30,22 @@ import isc.acl.dns
 TESTDATA_SRCDIR = os.getenv("TESTDATASRCDIR")
 TESTDATA_SRCDIR = os.getenv("TESTDATASRCDIR")
 TSIG_KEY = TSIGKey("example.com:SFuWd/q99SzF8Yzd1QbB9g==")
 TSIG_KEY = TSIGKey("example.com:SFuWd/q99SzF8Yzd1QbB9g==")
 
 
+#
+# Commonly used (mostly constant) test parameters
+#
+TEST_ZONE_NAME_STR = "example.com."
+TEST_ZONE_NAME = Name(TEST_ZONE_NAME_STR)
+TEST_RRCLASS = RRClass.IN()
+IXFR_OK_VERSION = 2011111802
+IXFR_NG_VERSION = 2011112800
+
+# SOA intended to be used for the new SOA as a result of transfer.
+soa_rdata = Rdata(RRType.SOA(), TEST_RRCLASS,
+                  'master.example.com. admin.example.com ' +
+                  '1234 3600 1800 2419200 7200')
+soa_rrset = RRset(TEST_ZONE_NAME, TEST_RRCLASS, RRType.SOA(), RRTTL(3600))
+soa_rrset.add_rdata(soa_rdata)
+
 # our fake socket, where we can read and insert messages
 # our fake socket, where we can read and insert messages
 class MySocket():
 class MySocket():
     def __init__(self, family, type):
     def __init__(self, family, type):
@@ -69,6 +85,47 @@ class MockDataSrcClient:
     def __init__(self, type, config):
     def __init__(self, type, config):
         pass
         pass
 
 
+    def __create_soa(self):
+        soa_rrset = RRset(self._zone_name, RRClass.IN(), RRType.SOA(),
+                          RRTTL(3600))
+        soa_rrset.add_rdata(Rdata(RRType.SOA(), RRClass.IN(),
+                                  'master.example.com. ' +
+                                  'admin.example.com. 1234 ' +
+                                  '3600 1800 2419200 7200'))
+        return soa_rrset
+
+    def find_zone(self, zone_name):
+        '''Mock version of find_zone().
+
+        It returns itself (subsequently acting as a mock ZoneFinder) for
+        some test zone names.  For some others it returns either NOTFOUND
+        or PARTIALMATCH.
+
+        '''
+        self._zone_name = zone_name
+        if zone_name == Name('notauth.example.com'):
+            return (isc.datasrc.DataSourceClient.NOTFOUND, None)
+        return (isc.datasrc.DataSourceClient.SUCCESS, self)
+
+    def find(self, name, rrtype, target, options):
+        '''Mock ZoneFinder.find().
+
+        It returns the predefined SOA RRset to queries for SOA of the common
+        test zone name.  It also emulates some unusual cases for special
+        zone names.
+
+        '''
+        if name == TEST_ZONE_NAME and rrtype == RRType.SOA():
+            return (ZoneFinder.SUCCESS, self.__create_soa())
+        elif name == Name('nosoa.example.com') and rrtype == RRType.SOA():
+            return (ZoneFinder.NXDOMAIN, None)
+        elif name == Name('multisoa.example.com') and rrtype == RRType.SOA():
+            soa_rrset = self.__create_soa()
+            soa_rrset.add_rdata(soa_rrset.get_rdata()[0])
+            return (ZoneFinder.SUCCESS, soa_rrset)
+        else:
+            return (ZoneFinder.SUCCESS, self.__create_soa())
+
     def get_iterator(self, zone_name, adjust_ttl=False):
     def get_iterator(self, zone_name, adjust_ttl=False):
         if zone_name == Name('notauth.example.com'):
         if zone_name == Name('notauth.example.com'):
             raise isc.datasrc.Error('no such zone')
             raise isc.datasrc.Error('no such zone')
@@ -78,19 +135,20 @@ class MockDataSrcClient:
     def get_soa(self):  # emulate ZoneIterator.get_soa()
     def get_soa(self):  # emulate ZoneIterator.get_soa()
         if self._zone_name == Name('nosoa.example.com'):
         if self._zone_name == Name('nosoa.example.com'):
             return None
             return None
-        soa_rrset = RRset(self._zone_name, RRClass.IN(), RRType.SOA(),
-                          RRTTL(3600))
-        soa_rrset.add_rdata(Rdata(RRType.SOA(), RRClass.IN(),
-                                  'master.example.com. ' +
-                                  'admin.example.com. 1234 ' +
-                                  '3600 1800 2419200 7200'))
+        soa_rrset = self.__create_soa()
         if self._zone_name == Name('multisoa.example.com'):
         if self._zone_name == Name('multisoa.example.com'):
-            soa_rrset.add_rdata(Rdata(RRType.SOA(), RRClass.IN(),
-                                      'master.example.com. ' +
-                                      'admin.example.com. 1300 ' +
-                                      '3600 1800 2419200 7200'))
+            soa_rrset.add_rdata(soa_rrset.get_rdata()[0])
         return soa_rrset
         return soa_rrset
 
 
+    def get_journal_reader(self, zone_name, begin_serial, end_serial):
+        if zone_name == Name('notauth2.example.com'):
+            return isc.datasrc.ZoneJournalReader.NO_SUCH_ZONE, None
+        if zone_name == Name('nojournal.example.com'):
+            raise isc.datasrc.NotImplemented('journaling not supported')
+        if begin_serial == IXFR_NG_VERSION:
+            return isc.datasrc.ZoneJournalReader.NO_SUCH_VERSION, None
+        return isc.datasrc.ZoneJournalReader.SUCCESS, self
+
 class MyCCSession(isc.config.ConfigData):
 class MyCCSession(isc.config.ConfigData):
     def __init__(self):
     def __init__(self):
         module_spec = isc.config.module_spec_from_file(
         module_spec = isc.config.module_spec_from_file(
@@ -159,15 +217,44 @@ class TestXfroutSessionBase(unittest.TestCase):
     def message_has_tsig(self, msg):
     def message_has_tsig(self, msg):
         return msg.get_tsig_record() is not None
         return msg.get_tsig_record() is not None
 
 
-    def create_request_data(self, with_question=True, with_tsig=False):
+    def create_request_data(self, with_question=True, with_tsig=False,
+                            ixfr=None, qtype=None, zone_name=TEST_ZONE_NAME,
+                            soa_class=TEST_RRCLASS, num_soa=1):
+        '''Create a commonly used XFR request data.
+
+        By default the request type is AXFR; if 'ixfr' is an integer,
+        the request type will be IXFR and an SOA with the serial being
+        the value of the parameter will be included in the authority
+        section.
+
+        This method has various minor parameters only for creating bad
+        format requests for testing purposes:
+        qtype: the RR type of the question section.  By default automatically
+               determined by the value of ixfr, but could be an invalid type
+               for testing.
+        zone_name: the query (zone) name.  for IXFR, it's also used as
+                   the owner name of the SOA in the authority section.
+        soa_class: IXFR only.  The RR class of the SOA RR in the authority
+                   section.
+        num_soa: IXFR only.  The number of SOA RDATAs  in the authority
+                 section.
+        '''
         msg = Message(Message.RENDER)
         msg = Message(Message.RENDER)
         query_id = 0x1035
         query_id = 0x1035
         msg.set_qid(query_id)
         msg.set_qid(query_id)
         msg.set_opcode(Opcode.QUERY())
         msg.set_opcode(Opcode.QUERY())
         msg.set_rcode(Rcode.NOERROR())
         msg.set_rcode(Rcode.NOERROR())
+        req_type = RRType.AXFR() if ixfr is None else RRType.IXFR()
         if with_question:
         if with_question:
-            msg.add_question(Question(Name("example.com"), RRClass.IN(),
-                                      RRType.AXFR()))
+            msg.add_question(Question(zone_name, RRClass.IN(),
+                                      req_type if qtype is None else qtype))
+        if req_type == RRType.IXFR():
+            soa = RRset(zone_name, soa_class, RRType.SOA(), RRTTL(0))
+            # In the RDATA only the serial matters.
+            for i in range(0, num_soa):
+                soa.add_rdata(Rdata(RRType.SOA(), soa_class,
+                                    'm r ' + str(ixfr) + ' 1 1 1 1'))
+            msg.add_rrset(Message.SECTION_AUTHORITY, soa)
 
 
         renderer = MessageRenderer()
         renderer = MessageRenderer()
         if with_tsig:
         if with_tsig:
@@ -178,6 +265,13 @@ class TestXfroutSessionBase(unittest.TestCase):
         request_data = renderer.get_data()
         request_data = renderer.get_data()
         return request_data
         return request_data
 
 
+    def set_request_type(self, type):
+        self.xfrsess._request_type = type
+        if type == RRType.AXFR():
+            self.xfrsess._request_typestr = 'AXFR'
+        else:
+            self.xfrsess._request_typestr = 'IXFR'
+
     def setUp(self):
     def setUp(self):
         self.sock = MySocket(socket.AF_INET,socket.SOCK_STREAM)
         self.sock = MySocket(socket.AF_INET,socket.SOCK_STREAM)
         self.xfrsess = MyXfroutSession(self.sock, None, Dbserver(),
         self.xfrsess = MyXfroutSession(self.sock, None, Dbserver(),
@@ -188,6 +282,7 @@ class TestXfroutSessionBase(unittest.TestCase):
                                        isc.acl.dns.REQUEST_LOADER.load(
                                        isc.acl.dns.REQUEST_LOADER.load(
                                            [{"action": "ACCEPT"}]),
                                            [{"action": "ACCEPT"}]),
                                        {})
                                        {})
+        self.set_request_type(RRType.AXFR()) # test AXFR by default
         self.mdata = self.create_request_data()
         self.mdata = self.create_request_data()
         self.soa_rrset = RRset(Name('example.com'), RRClass.IN(), RRType.SOA(),
         self.soa_rrset = RRset(Name('example.com'), RRClass.IN(), RRType.SOA(),
                                RRTTL(3600))
                                RRTTL(3600))
@@ -222,7 +317,7 @@ class TestXfroutSession(TestXfroutSessionBase):
         # set up a bogus request, which should result in FORMERR. (it only
         # set up a bogus request, which should result in FORMERR. (it only
         # has to be something that is different from the previous case)
         # has to be something that is different from the previous case)
         self.xfrsess._request_data = \
         self.xfrsess._request_data = \
-            self.create_request_data(with_question=False)
+            self.create_request_data(ixfr=IXFR_OK_VERSION, num_soa=2)
         # Replace the data source client to avoid datasrc related exceptions
         # Replace the data source client to avoid datasrc related exceptions
         self.xfrsess.ClientClass = MockDataSrcClient
         self.xfrsess.ClientClass = MockDataSrcClient
         XfroutSession._handle(self.xfrsess)
         XfroutSession._handle(self.xfrsess)
@@ -241,13 +336,24 @@ class TestXfroutSession(TestXfroutSessionBase):
         XfroutSession._handle(self.xfrsess)
         XfroutSession._handle(self.xfrsess)
 
 
     def test_parse_query_message(self):
     def test_parse_query_message(self):
+        # Valid AXFR
         [get_rcode, get_msg] = self.xfrsess._parse_query_message(self.mdata)
         [get_rcode, get_msg] = self.xfrsess._parse_query_message(self.mdata)
+        self.assertEqual(RRType.AXFR(), self.xfrsess._request_type)
         self.assertEqual(get_rcode.to_text(), "NOERROR")
         self.assertEqual(get_rcode.to_text(), "NOERROR")
 
 
-        # Broken request: no question
-        request_data = self.create_request_data(with_question=False)
+        # Valid IXFR
+        request_data = self.create_request_data(ixfr=2011111801)
         rcode, msg = self.xfrsess._parse_query_message(request_data)
         rcode, msg = self.xfrsess._parse_query_message(request_data)
-        self.assertEqual(Rcode.FORMERR(), rcode)
+        self.assertEqual(RRType.IXFR(), self.xfrsess._request_type)
+        self.assertEqual(Rcode.NOERROR(), rcode)
+
+        # Broken request: no question
+        self.assertRaises(RuntimeError, self.xfrsess._parse_query_message,
+                          self.create_request_data(with_question=False))
+
+        # Broken request: invalid RR type (neither AXFR nor IXFR)
+        self.assertRaises(RuntimeError, self.xfrsess._parse_query_message,
+                          self.create_request_data(qtype=RRType.A()))
 
 
         # tsig signed query message
         # tsig signed query message
         request_data = self.create_request_data(with_tsig=True)
         request_data = self.create_request_data(with_tsig=True)
@@ -587,16 +693,92 @@ class TestXfroutSession(TestXfroutSessionBase):
     def test_get_rrset_len(self):
     def test_get_rrset_len(self):
         self.assertEqual(82, get_rrset_len(self.soa_rrset))
         self.assertEqual(82, get_rrset_len(self.soa_rrset))
 
 
-    def test_check_xfrout_available(self):
+    def test_xfrout_axfr_setup(self):
         self.xfrsess.ClientClass = MockDataSrcClient
         self.xfrsess.ClientClass = MockDataSrcClient
-        self.assertEqual(self.xfrsess._check_xfrout_available(
-                Name('example.com')), Rcode.NOERROR())
-        self.assertEqual(self.xfrsess._check_xfrout_available(
-                Name('notauth.example.com')), Rcode.NOTAUTH())
-        self.assertEqual(self.xfrsess._check_xfrout_available(
-                Name('nosoa.example.com')), Rcode.SERVFAIL())
-        self.assertEqual(self.xfrsess._check_xfrout_available(
-                Name('multisoa.example.com')), Rcode.SERVFAIL())
+        # Successful case.  A zone iterator should be set up.
+        self.assertEqual(self.xfrsess._xfrout_setup(
+                self.getmsg(), TEST_ZONE_NAME, TEST_RRCLASS), Rcode.NOERROR())
+        self.assertNotEqual(None, self.xfrsess._iterator)
+
+        # Failure cases
+        self.assertEqual(self.xfrsess._xfrout_setup(
+                self.getmsg(), Name('notauth.example.com'), TEST_RRCLASS),
+                         Rcode.NOTAUTH())
+        self.assertEqual(self.xfrsess._xfrout_setup(
+                self.getmsg(), Name('nosoa.example.com'), TEST_RRCLASS),
+                         Rcode.SERVFAIL())
+        self.assertEqual(self.xfrsess._xfrout_setup(
+                self.getmsg(), Name('multisoa.example.com'), TEST_RRCLASS),
+                         Rcode.SERVFAIL())
+
+    def test_xfrout_ixfr_setup(self):
+        self.xfrsess.ClientClass = MockDataSrcClient
+        self.set_request_type(RRType.IXFR())
+
+        # Successful case of pure IXFR.  A zone journal reader should be set
+        # up.
+        self.mdata = self.create_request_data(ixfr=IXFR_OK_VERSION)
+        self.assertEqual(self.xfrsess._xfrout_setup(
+                self.getmsg(), TEST_ZONE_NAME, TEST_RRCLASS), Rcode.NOERROR())
+        self.assertNotEqual(None, self.xfrsess._jnl_reader)
+
+        # Successful case, but as a result of falling back to AXFR-style
+        # IXFR.  A zone iterator should be set up instead of a journal reader.
+        self.mdata = self.create_request_data(ixfr=IXFR_NG_VERSION)
+        self.assertEqual(self.xfrsess._xfrout_setup(
+                self.getmsg(), TEST_ZONE_NAME, TEST_RRCLASS), Rcode.NOERROR())
+        self.assertNotEqual(None, self.xfrsess._iterator)
+        self.assertEqual(None, self.xfrsess._jnl_reader)
+
+        # The data source doesn't support journaling.  Should fallback to AXFR.
+        zone_name = Name('nojournal.example.com')
+        self.mdata = self.create_request_data(ixfr=IXFR_OK_VERSION,
+                                              zone_name=zone_name)
+        self.assertEqual(self.xfrsess._xfrout_setup(
+                self.getmsg(), zone_name, TEST_RRCLASS), Rcode.NOERROR())
+        self.assertNotEqual(None, self.xfrsess._iterator)
+
+        # Failure cases
+        zone_name = Name('notauth.example.com')
+        self.mdata = self.create_request_data(ixfr=IXFR_OK_VERSION,
+                                              zone_name=zone_name)
+        self.assertEqual(self.xfrsess._xfrout_setup(
+                self.getmsg(), zone_name, TEST_RRCLASS), Rcode.NOTAUTH())
+        # this is a strange case: zone's SOA will be found but the journal
+        # reader won't be created due to 'no such zone'.
+        zone_name = Name('notauth2.example.com')
+        self.mdata = self.create_request_data(ixfr=IXFR_OK_VERSION,
+                                              zone_name=zone_name)
+        self.assertEqual(self.xfrsess._xfrout_setup(
+                self.getmsg(), zone_name, TEST_RRCLASS), Rcode.NOTAUTH())
+        zone_name = Name('nosoa.example.com')
+        self.mdata = self.create_request_data(ixfr=IXFR_OK_VERSION,
+                                              zone_name=zone_name)
+        self.assertEqual(self.xfrsess._xfrout_setup(
+                self.getmsg(), zone_name, TEST_RRCLASS), Rcode.SERVFAIL())
+        zone_name = Name('multisoa.example.com')
+        self.mdata = self.create_request_data(ixfr=IXFR_OK_VERSION,
+                                              zone_name=zone_name)
+        self.assertEqual(self.xfrsess._xfrout_setup(
+                self.getmsg(), zone_name, TEST_RRCLASS), Rcode.SERVFAIL())
+
+        # query name doesn't match the SOA's owner
+        self.mdata = self.create_request_data(ixfr=IXFR_OK_VERSION)
+        self.assertEqual(self.xfrsess._xfrout_setup(
+                self.getmsg(), zone_name, TEST_RRCLASS), Rcode.FORMERR())
+
+        # query's RR class doesn't match the SOA's class
+        zone_name = TEST_ZONE_NAME # make sure the name matches this time
+        self.mdata = self.create_request_data(ixfr=IXFR_OK_VERSION,
+                                              soa_class=RRClass.CH())
+        self.assertEqual(self.xfrsess._xfrout_setup(
+                self.getmsg(), zone_name, TEST_RRCLASS), Rcode.FORMERR())
+
+        # multiple SOA RRs
+        self.mdata = self.create_request_data(ixfr=IXFR_OK_VERSION,
+                                              num_soa=2)
+        self.assertEqual(self.xfrsess._xfrout_setup(
+                self.getmsg(), zone_name, TEST_RRCLASS), Rcode.FORMERR())
 
 
     def test_dns_xfrout_start_formerror(self):
     def test_dns_xfrout_start_formerror(self):
         # formerror
         # formerror
@@ -608,9 +790,9 @@ class TestXfroutSession(TestXfroutSessionBase):
         return "example.com"
         return "example.com"
 
 
     def test_dns_xfrout_start_notauth(self):
     def test_dns_xfrout_start_notauth(self):
-        def notauth(formpara):
+        def notauth(msg, name, rrclass):
             return Rcode.NOTAUTH()
             return Rcode.NOTAUTH()
-        self.xfrsess._check_xfrout_available = notauth
+        self.xfrsess._xfrout_setup = notauth
         self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
         self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
         get_msg = self.sock.read_msg()
         get_msg = self.sock.read_msg()
         self.assertEqual(get_msg.get_rcode().to_text(), "NOTAUTH")
         self.assertEqual(get_msg.get_rcode().to_text(), "NOTAUTH")
@@ -623,9 +805,9 @@ class TestXfroutSession(TestXfroutSessionBase):
         self.assertEqual(self.sock.read_msg().get_rcode(), Rcode.SERVFAIL())
         self.assertEqual(self.sock.read_msg().get_rcode(), Rcode.SERVFAIL())
 
 
     def test_dns_xfrout_start_noerror(self):
     def test_dns_xfrout_start_noerror(self):
-        def noerror(form):
+        def noerror(msg, name, rrclass):
             return Rcode.NOERROR()
             return Rcode.NOERROR()
-        self.xfrsess._check_xfrout_available = noerror
+        self.xfrsess._xfrout_setup = noerror
 
 
         def myreply(msg, sock):
         def myreply(msg, sock):
             self.sock.send(b"success")
             self.sock.send(b"success")
@@ -684,19 +866,47 @@ class TestXfroutSessionWithSQLite3(TestXfroutSessionBase):
         self.xfrsess._server.get_db_file = lambda : TESTDATA_SRCDIR + \
         self.xfrsess._server.get_db_file = lambda : TESTDATA_SRCDIR + \
             'test.sqlite3'
             'test.sqlite3'
 
 
-    def test_axfr_normal_session(self):
-        XfroutSession._handle(self.xfrsess)
-        response = self.sock.read_msg(Message.PRESERVE_ORDER);
-        self.assertEqual(Rcode.NOERROR(), response.get_rcode())
+    def check_axfr_stream(self, response):
+        '''Common checks for AXFR(-style) response for the test zone.
+        '''
         # This zone contains two A RRs for the same name with different TTLs.
         # This zone contains two A RRs for the same name with different TTLs.
         # These TTLs should be preseved in the AXFR stream.
         # These TTLs should be preseved in the AXFR stream.
+        # We'll check some important points as a valid AXFR response:
+        # the first and last RR must be SOA, and these should be the only
+        # SOAs in the response.  The total number of response RRs
+        # must be 5 (zone has 4 RRs, SOA is duplicated)
+        actual_records = response.get_section(Message.SECTION_ANSWER)
+        self.assertEqual(5, len(actual_records))
+        self.assertEqual(RRType.SOA(), actual_records[0].get_type())
+        self.assertEqual(RRType.SOA(), actual_records[-1].get_type())
         actual_ttls = []
         actual_ttls = []
-        for rr in response.get_section(Message.SECTION_ANSWER):
+        num_soa = 0
+        for rr in actual_records:
+            if rr.get_type() == RRType.SOA():
+                num_soa += 1
             if rr.get_type() == RRType.A() and \
             if rr.get_type() == RRType.A() and \
                     not rr.get_ttl() in actual_ttls:
                     not rr.get_ttl() in actual_ttls:
                 actual_ttls.append(rr.get_ttl().get_value())
                 actual_ttls.append(rr.get_ttl().get_value())
+        self.assertEqual(2, num_soa)
         self.assertEqual([3600, 7200], sorted(actual_ttls))
         self.assertEqual([3600, 7200], sorted(actual_ttls))
 
 
+    def test_axfr_normal_session(self):
+        XfroutSession._handle(self.xfrsess)
+        response = self.sock.read_msg(Message.PRESERVE_ORDER);
+        self.assertEqual(Rcode.NOERROR(), response.get_rcode())
+        self.check_axfr_stream(response)
+
+    def test_ixfr_to_axfr(self):
+        self.xfrsess._request_data = \
+            self.create_request_data(ixfr=IXFR_NG_VERSION)
+        XfroutSession._handle(self.xfrsess)
+        response = self.sock.read_msg(Message.PRESERVE_ORDER);
+        self.assertEqual(Rcode.NOERROR(), response.get_rcode())
+        # This is an AXFR-style IXFR.  So the question section should indicate
+        # that it's an IXFR resposne.
+        self.assertEqual(RRType.IXFR(), response.get_question()[0].get_type())
+        self.check_axfr_stream(response)
+
 class MyUnixSockServer(UnixSockServer):
 class MyUnixSockServer(UnixSockServer):
     def __init__(self):
     def __init__(self):
         self._shutdown_event = threading.Event()
         self._shutdown_event = threading.Event()

+ 148 - 35
src/bin/xfrout/xfrout.py.in

@@ -22,7 +22,7 @@ import isc.cc
 import threading
 import threading
 import struct
 import struct
 import signal
 import signal
-from isc.datasrc import DataSourceClient
+from isc.datasrc import DataSourceClient, ZoneFinder, ZoneJournalReader
 from socketserver import *
 from socketserver import *
 import os
 import os
 from isc.config.ccsession import *
 from isc.config.ccsession import *
@@ -102,7 +102,7 @@ def format_zone_str(zone_name, zone_class):
        zone_name (isc.dns.Name) name to format
        zone_name (isc.dns.Name) name to format
        zone_class (isc.dns.RRClass) class to format
        zone_class (isc.dns.RRClass) class to format
     """
     """
-    return zone_name.to_text() + '/' + str(zone_class)
+    return zone_name.to_text(True) + '/' + str(zone_class)
 
 
 # borrowed from xfrin.py @ #1298.
 # borrowed from xfrin.py @ #1298.
 def format_addrinfo(addrinfo):
 def format_addrinfo(addrinfo):
@@ -132,6 +132,11 @@ def get_rrset_len(rrset):
     rrset.to_wire(bytes)
     rrset.to_wire(bytes)
     return len(bytes)
     return len(bytes)
 
 
+def get_soa_serial(soa_rdata):
+    '''Extract the serial field of an SOA RDATA and returns it as an intger.
+    (borrowed from xfrin)
+    '''
+    return int(soa_rdata.to_text().split()[2])
 
 
 class XfroutSession():
 class XfroutSession():
     def __init__(self, sock_fd, request_data, server, tsig_key_ring, remote,
     def __init__(self, sock_fd, request_data, server, tsig_key_ring, remote,
@@ -143,11 +148,12 @@ class XfroutSession():
         self._tsig_ctx = None
         self._tsig_ctx = None
         self._tsig_len = 0
         self._tsig_len = 0
         self._remote = remote
         self._remote = remote
-        self._request_type = 'AXFR' # could be IXFR when we support it
+        self._request_type = None
+        self._request_typestr = None
         self._acl = default_acl
         self._acl = default_acl
         self._zone_config = zone_config
         self._zone_config = zone_config
         self.ClientClass = client_class # parameterize this for testing
         self.ClientClass = client_class # parameterize this for testing
-        self._soa = None # will be set in _check_xfrout_available or in tests
+        self._soa = None # will be set in _xfrout_setup or in tests
         self._handle()
         self._handle()
 
 
     def create_tsig_ctx(self, tsig_record, tsig_key_ring):
     def create_tsig_ctx(self, tsig_record, tsig_key_ring):
@@ -195,7 +201,8 @@ class XfroutSession():
         tsig_record = msg.get_tsig_record()
         tsig_record = msg.get_tsig_record()
         if tsig_record is not None:
         if tsig_record is not None:
             self._tsig_len = tsig_record.get_length()
             self._tsig_len = tsig_record.get_length()
-            self._tsig_ctx = self.create_tsig_ctx(tsig_record, self._tsig_key_ring)
+            self._tsig_ctx = self.create_tsig_ctx(tsig_record,
+                                                  self._tsig_key_ring)
             tsig_error = self._tsig_ctx.verify(tsig_record, request_data)
             tsig_error = self._tsig_ctx.verify(tsig_record, request_data)
             if tsig_error != TSIGError.NOERROR:
             if tsig_error != TSIGError.NOERROR:
                 return Rcode.NOTAUTH()
                 return Rcode.NOTAUTH()
@@ -218,24 +225,38 @@ class XfroutSession():
             return rcode, msg
             return rcode, msg
 
 
         # Make sure the question is valid.  This should be ensured by
         # Make sure the question is valid.  This should be ensured by
-        # the auth server, but since it's far from our xfrout itself,
-        # we check it by ourselves.
+        # the auth server, but since it's far from xfrout itself, we check
+        # it by ourselves.  A viloation would be an internal bug, so we
+        # raise and stop here rather than returning a FORMERR or SERVFAIL.
         if msg.get_rr_count(Message.SECTION_QUESTION) != 1:
         if msg.get_rr_count(Message.SECTION_QUESTION) != 1:
-            return Rcode.FORMERR(), msg
+            raise RuntimeError('Invalid number of question for XFR: ' +
+                               str(msg.get_rr_count(Message.SECTION_QUESTION)))
+        question = msg.get_question()[0]
+
+        # Identify the request type
+        self._request_type = question.get_type()
+        if self._request_type == RRType.AXFR():
+            self._request_typestr = 'AXFR'
+        elif self._request_type == RRType.IXFR():
+            self._request_typestr = 'IXFR'
+        else:
+            # Likewise, this should be impossible.
+            raise RuntimeError('Unexpected XFR type: ' +
+                               str(self._request_type))
 
 
         # ACL checks
         # ACL checks
-        zone_name = msg.get_question()[0].get_name()
-        zone_class = msg.get_question()[0].get_class()
+        zone_name = question.get_name()
+        zone_class = question.get_class()
         acl = self._get_transfer_acl(zone_name, zone_class)
         acl = self._get_transfer_acl(zone_name, zone_class)
         acl_result = acl.execute(
         acl_result = acl.execute(
             isc.acl.dns.RequestContext(self._remote[2], msg.get_tsig_record()))
             isc.acl.dns.RequestContext(self._remote[2], msg.get_tsig_record()))
         if acl_result == DROP:
         if acl_result == DROP:
-            logger.info(XFROUT_QUERY_DROPPED, self._request_type,
+            logger.info(XFROUT_QUERY_DROPPED, self._request_typestr,
                         format_addrinfo(self._remote),
                         format_addrinfo(self._remote),
                         format_zone_str(zone_name, zone_class))
                         format_zone_str(zone_name, zone_class))
             return None, None
             return None, None
         elif acl_result == REJECT:
         elif acl_result == REJECT:
-            logger.info(XFROUT_QUERY_REJECTED, self._request_type,
+            logger.info(XFROUT_QUERY_REJECTED, self._request_typestr,
                         format_addrinfo(self._remote),
                         format_addrinfo(self._remote),
                         format_zone_str(zone_name, zone_class))
                         format_zone_str(zone_name, zone_class))
             return Rcode.REFUSED(), msg
             return Rcode.REFUSED(), msg
@@ -295,23 +316,33 @@ class XfroutSession():
         msg.set_rcode(rcode_)
         msg.set_rcode(rcode_)
         self._send_message(sock_fd, msg, self._tsig_ctx)
         self._send_message(sock_fd, msg, self._tsig_ctx)
 
 
-    def _check_xfrout_available(self, zone_name):
-        '''Check if xfr request can be responsed.
-           TODO, Get zone's configuration from cfgmgr or some other place
-           eg. check allow_transfer setting,
+    def _get_zone_soa(self, zone_name):
+        '''Retrieve the SOA RR of the given zone.
+
+        It returns a pair of RCODE and the SOA (in the form of RRset).
+        On success RCODE is NOERROR and returned SOA is not None;
+        on failure RCODE indicates the appropriate code in the context of
+        xfr processing, and the returned SOA is None.
 
 
         '''
         '''
+        result, finder = self._datasrc_client.find_zone(zone_name)
+        if result != DataSourceClient.SUCCESS:
+            return (Rcode.NOTAUTH(), None)
+        result, soa_rrset = finder.find(zone_name, RRType.SOA(), None,
+                                        ZoneFinder.FIND_DEFAULT)
+        if result != ZoneFinder.SUCCESS:
+            return (Rcode.SERVFAIL(), None)
+        # Especially for database-based zones, a working zone may be in
+        # a broken state where it has more than one SOA RR.  We proactively
+        # check the condition and abort the xfr attempt if we identify it.
+        if soa_rrset.get_rdata_count() != 1:
+            return (Rcode.SERVFAIL(), None)
+        return (Rcode.NOERROR(), soa_rrset)
+
+    def __axfr_setup(self, zone_name):
+        '''Setup a zone iterator for AXFR or AXFR-style IXFR.
 
 
-        # Identify the data source for the requested zone and see if it has
-        # SOA while initializing objects used for request processing later.
-        # We should eventually generalize this so that we can choose the
-        # appropriate data source from (possible) multiple candidates.
-        # We should eventually take into account the RR class here.
-        # For now, we  hardcode a particular type (SQLite3-based), and only
-        # consider that one.
-        datasrc_config = '{ "database_file": "' + \
-            self._server.get_db_file() + '"}'
-        self._datasrc_client = self.ClientClass('sqlite3', datasrc_config)
+        '''
         try:
         try:
             # Note that we enable 'separate_rrs'.  In xfr-out we need to
             # Note that we enable 'separate_rrs'.  In xfr-out we need to
             # preserve as many things as possible (even if it's half broken)
             # preserve as many things as possible (even if it's half broken)
@@ -336,6 +367,90 @@ class XfroutSession():
 
 
         return Rcode.NOERROR()
         return Rcode.NOERROR()
 
 
+    def __ixfr_setup(self, request_msg, zone_name, zone_class):
+        '''Setup a zone journal reader for IXFR.
+
+        If the underlying data source does not know the requested range
+        of zone differences it automatically falls back to AXFR-style
+        IXFR by setting up a zone iterator instead of a journal reader.
+
+        '''
+        # Check the authority section.  Look for a SOA record with
+        # the same name and class as the question.
+        remote_soa = None
+        for auth_rrset in request_msg.get_section(Message.SECTION_AUTHORITY):
+            # Ignore data whose owner name is not the zone apex, and
+            # ignore non-SOA or different class of records.
+            if auth_rrset.get_name() != zone_name or \
+                    auth_rrset.get_type() != RRType.SOA() or \
+                    auth_rrset.get_class() != zone_class:
+                continue
+            if auth_rrset.get_rdata_count() != 1:
+                logger.info(XFROUT_IXFR_MULTIPLE_SOA,
+                            format_addrinfo(self._remote))
+                return Rcode.FORMERR()
+            remote_soa = auth_rrset
+        if remote_soa is None:
+            logger.info(XFROUT_IXFR_NO_SOA, format_addrinfo(self._remote))
+            return Rcode.FORMERR()
+
+        # Retrieve the local SOA
+        rcode, self._soa = self._get_zone_soa(zone_name)
+        if rcode != Rcode.NOERROR():
+            return rcode
+        try:
+            begin_serial = get_soa_serial(remote_soa.get_rdata()[0])
+            end_serial = get_soa_serial(self._soa.get_rdata()[0])
+            code, self._jnl_reader = self._datasrc_client.get_journal_reader(
+                zone_name, begin_serial, end_serial)
+        except isc.datasrc.NotImplemented as ex:
+            # The underlying data source doesn't support journaling.
+            # Fall back to AXFR-style IXFR.
+            logger.info(XFROUT_IXFR_NO_JOURNAL_SUPPORT,
+                        format_addrinfo(self._remote),
+                        format_zone_str(zone_name, zone_class))
+            return self.__axfr_setup(zone_name)
+        if code == ZoneJournalReader.NO_SUCH_VERSION:
+            logger.info(XFROUT_IXFR_NO_VERSION, format_addrinfo(self._remote),
+                        format_zone_str(zone_name, zone_class),
+                        begin_serial, end_serial)
+            return self.__axfr_setup(zone_name)
+        if code == ZoneJournalReader.NO_SUCH_ZONE:
+            # this is quite unexpected as we know zone's SOA exists.
+            # It might be a bug or the data source is somehow broken,
+            # but it can still happen if someone has removed the zone
+            # between these two operations.  We treat it as NOTAUTH.
+            logger.warn(XFROUT_IXFR_NO_ZONE, format_addrinfo(self._remote),
+                        format_zone_str(zone_name, zone_class))
+            return Rcode.NOTAUTH()
+
+        return Rcode.NOERROR()
+
+    def _xfrout_setup(self, request_msg, zone_name, zone_class):
+        '''Setup a context for xfr responses according to the request type.
+
+        This method identifies the most appropriate data source for the
+        request and set up a zone iterator or journal reader depending on
+        whether the request is AXFR or IXFR.  If it identifies any protocol
+        level error it returns an RCODE other than NOERROR.
+
+        '''
+
+        # Identify the data source for the requested zone and see if it has
+        # SOA while initializing objects used for request processing later.
+        # We should eventually generalize this so that we can choose the
+        # appropriate data source from (possible) multiple candidates.
+        # We should eventually take into account the RR class here.
+        # For now, we hardcode a particular type (SQLite3-based), and only
+        # consider that one.
+        datasrc_config = '{ "database_file": "' + \
+            self._server.get_db_file() + '"}'
+        self._datasrc_client = self.ClientClass('sqlite3', datasrc_config)
+
+        if self._request_type == RRType.AXFR():
+            return self.__axfr_setup(zone_name)
+        else:
+            return self.__ixfr_setup(request_msg, zone_name, zone_class)
 
 
     def dns_xfrout_start(self, sock_fd, msg_query, quota_ok=True):
     def dns_xfrout_start(self, sock_fd, msg_query, quota_ok=True):
         rcode_, msg = self._parse_query_message(msg_query)
         rcode_, msg = self._parse_query_message(msg_query)
@@ -348,7 +463,7 @@ class XfroutSession():
             return self._reply_query_with_error_rcode(msg, sock_fd,
             return self._reply_query_with_error_rcode(msg, sock_fd,
                                                       Rcode.FORMERR())
                                                       Rcode.FORMERR())
         elif not quota_ok:
         elif not quota_ok:
-            logger.warn(XFROUT_QUERY_QUOTA_EXCCEEDED, self._request_type,
+            logger.warn(XFROUT_QUERY_QUOTA_EXCCEEDED, self._request_typestr,
                         format_addrinfo(self._remote),
                         format_addrinfo(self._remote),
                         self._server._max_transfers_out)
                         self._server._max_transfers_out)
             return self._reply_query_with_error_rcode(msg, sock_fd,
             return self._reply_query_with_error_rcode(msg, sock_fd,
@@ -359,27 +474,26 @@ class XfroutSession():
         zone_class = question.get_class()
         zone_class = question.get_class()
         zone_str = format_zone_str(zone_name, zone_class) # for logging
         zone_str = format_zone_str(zone_name, zone_class) # for logging
 
 
-        # TODO: we should also include class in the check
         try:
         try:
-            rcode_ = self._check_xfrout_available(zone_name)
+            rcode_ = self._xfrout_setup(msg, zone_name, zone_class)
         except Exception as ex:
         except Exception as ex:
-            logger.error(XFROUT_XFR_TRANSFER_CHECK_ERROR, self._request_type,
+            logger.error(XFROUT_XFR_TRANSFER_CHECK_ERROR, self._request_typestr,
                          format_addrinfo(self._remote), zone_str, ex)
                          format_addrinfo(self._remote), zone_str, ex)
             rcode_ = Rcode.SERVFAIL()
             rcode_ = Rcode.SERVFAIL()
         if rcode_ != Rcode.NOERROR():
         if rcode_ != Rcode.NOERROR():
-            logger.info(XFROUT_AXFR_TRANSFER_FAILED, self._request_type,
+            logger.info(XFROUT_AXFR_TRANSFER_FAILED, self._request_typestr,
                         format_addrinfo(self._remote), zone_str, rcode_)
                         format_addrinfo(self._remote), zone_str, rcode_)
             return self._reply_query_with_error_rcode(msg, sock_fd, rcode_)
             return self._reply_query_with_error_rcode(msg, sock_fd, rcode_)
 
 
         try:
         try:
-            logger.info(XFROUT_AXFR_TRANSFER_STARTED, self._request_type,
+            logger.info(XFROUT_AXFR_TRANSFER_STARTED, self._request_typestr,
                         format_addrinfo(self._remote), zone_str)
                         format_addrinfo(self._remote), zone_str)
             self._reply_xfrout_query(msg, sock_fd)
             self._reply_xfrout_query(msg, sock_fd)
         except Exception as err:
         except Exception as err:
-            logger.error(XFROUT_AXFR_TRANSFER_ERROR, self._request_type,
+            logger.error(XFROUT_AXFR_TRANSFER_ERROR, self._request_typestr,
                     format_addrinfo(self._remote), zone_str, err)
                     format_addrinfo(self._remote), zone_str, err)
             pass
             pass
-        logger.info(XFROUT_AXFR_TRANSFER_DONE, self._request_type,
+        logger.info(XFROUT_AXFR_TRANSFER_DONE, self._request_typestr,
                     format_addrinfo(self._remote), zone_str)
                     format_addrinfo(self._remote), zone_str)
 
 
     def _clear_message(self, msg):
     def _clear_message(self, msg):
@@ -409,7 +523,6 @@ class XfroutSession():
         msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
         msg.add_rrset(Message.SECTION_ANSWER, rrset_soa)
         self._send_message(sock_fd, msg, self._tsig_ctx)
         self._send_message(sock_fd, msg, self._tsig_ctx)
 
 
-
     def _reply_xfrout_query(self, msg, sock_fd):
     def _reply_xfrout_query(self, msg, sock_fd):
         #TODO, there should be a better way to insert rrset.
         #TODO, there should be a better way to insert rrset.
         msg.make_response()
         msg.make_response()

+ 31 - 0
src/bin/xfrout/xfrout_messages.mes

@@ -178,3 +178,34 @@ on, but the file is in use. The most likely cause is that another
 xfrout daemon process is still running. This xfrout daemon (the one
 xfrout daemon process is still running. This xfrout daemon (the one
 printing this message) will not start.
 printing this message) will not start.
 
 
+% XFROUT_IXFR_MULTIPLE_SOA IXFR client %1: authority section has multiple SOAs
+An IXFR request was received with more than one SOA RRs in the authority
+section.  The xfrout daemon rejects the request with an RCODE of
+FORMERR.
+
+% XFROUT_IXFR_NO_SOA IXFR client %1: missing SOA
+An IXFR request was received with no SOA RR in the authority section.
+The xfrout daemon rejects the request with an RCODE of FORMERR.
+
+% XFROUT_IXFR_NO_JOURNAL_SUPPORT IXFR client %1, %2: journaling not supported in the data source, falling back to AXFR
+An IXFR request was received but the underlying data source did
+not support journaling.  The xfrout daemon fell back to AXFR-style
+IXFR.
+
+% XFROUT_IXFR_NO_VERSION IXFR client %1, %2: version (%3 to %4) not in journal, falling back to AXFR
+An IXFR request was received, but the requested range of differences
+were not found in the data source.  The xfrout daemon fell back to
+AXFR-style IXFR.
+
+% XFROUT_IXFR_NO_ZONE IXFR client %1, %2: zone not found with journal
+The requested zone in IXFR was not found in the data source
+even though the xfrout daemon sucessfully found the SOA RR of the zone
+in the data source.  This can happen if the administrator removed the
+zone from the data source within the small duration between these
+operations, but it's more likely to be a bug or broken data source.
+Unless you know why this message was logged, and especially if it
+happens often, it's advisable to check whether the data source is
+valid for this zone.  The xfrout daemon considers it a possible,
+though unlikely, event, and returns a response with an RCODE of
+NOTAUTH.
+

+ 25 - 12
src/lib/python/isc/datasrc/client_python.cc

@@ -182,19 +182,32 @@ DataSourceClient_getJournalReader(PyObject* po_self, PyObject* args) {
 
 
     if (PyArg_ParseTuple(args, "O!kk", &name_type, &name_obj,
     if (PyArg_ParseTuple(args, "O!kk", &name_type, &name_obj,
                          &begin_obj, &end_obj)) {
                          &begin_obj, &end_obj)) {
-        pair<ZoneJournalReader::Result, ZoneJournalReaderPtr> result =
-            self->cppobj->getInstance().getJournalReader(
-                PyName_ToName(name_obj), static_cast<uint32_t>(begin_obj),
-                static_cast<uint32_t>(end_obj));
-        PyObject* po_reader;
-        if (result.first == ZoneJournalReader::SUCCESS) {
-            po_reader = createZoneJournalReaderObject(result.second, po_self);
-        } else {
-            po_reader = Py_None;
-            Py_INCREF(po_reader); // this will soon be released
+        try {
+            pair<ZoneJournalReader::Result, ZoneJournalReaderPtr> result =
+                self->cppobj->getInstance().getJournalReader(
+                    PyName_ToName(name_obj), static_cast<uint32_t>(begin_obj),
+                    static_cast<uint32_t>(end_obj));
+            PyObject* po_reader;
+            if (result.first == ZoneJournalReader::SUCCESS) {
+                po_reader = createZoneJournalReaderObject(result.second,
+                                                          po_self);
+            } else {
+                po_reader = Py_None;
+                Py_INCREF(po_reader); // this will soon be released
+            }
+            PyObjectContainer container(po_reader);
+            return (Py_BuildValue("(iO)", result.first, container.get()));
+        } catch (const isc::NotImplemented& ex) {
+            PyErr_SetString(getDataSourceException("NotImplemented"),
+                            ex.what());
+        } catch (const DataSourceError& ex) {
+            PyErr_SetString(getDataSourceException("Error"), ex.what());
+        } catch (const std::exception& ex) {
+            PyErr_SetString(getDataSourceException("Error"), ex.what());
+        } catch (...) {
+            PyErr_SetString(getDataSourceException("Error"),
+                            "Unexpected exception");
         }
         }
-        PyObjectContainer container(po_reader);
-        return (Py_BuildValue("(iO)", result.first, container.get()));
     }
     }
     return (NULL);
     return (NULL);
 }
 }

+ 1 - 0
src/lib/python/isc/datasrc/tests/Makefile.am

@@ -6,6 +6,7 @@ EXTRA_DIST = $(PYTESTS)
 
 
 EXTRA_DIST += testdata/brokendb.sqlite3
 EXTRA_DIST += testdata/brokendb.sqlite3
 EXTRA_DIST += testdata/example.com.sqlite3
 EXTRA_DIST += testdata/example.com.sqlite3
+EXTRA_DIST += testdata/test.sqlite3.nodiffs
 CLEANFILES = $(abs_builddir)/rwtest.sqlite3.copied
 CLEANFILES = $(abs_builddir)/rwtest.sqlite3.copied
 
 
 # If necessary (rare cases), explicitly specify paths to dynamic libraries
 # If necessary (rare cases), explicitly specify paths to dynamic libraries

+ 9 - 0
src/lib/python/isc/datasrc/tests/datasrc_test.py

@@ -803,6 +803,15 @@ class JournalRead(unittest.TestCase):
         # ZoneJournalReader can only be constructed via a factory
         # ZoneJournalReader can only be constructed via a factory
         self.assertRaises(TypeError, ZoneJournalReader)
         self.assertRaises(TypeError, ZoneJournalReader)
 
 
+    def test_journal_reader_old_schema(self):
+        # The database doesn't have a "diffs" table.
+        dbfile = TESTDATA_PATH + 'test.sqlite3.nodiffs'
+        client = isc.datasrc.DataSourceClient("sqlite3",
+                                              "{ \"database_file\": \"" + \
+                                                  dbfile + "\" }")
+        self.assertRaises(isc.datasrc.Error, client.get_journal_reader,
+                          self.zname, 0, 1)
+
 if __name__ == "__main__":
 if __name__ == "__main__":
     isc.log.init("bind10")
     isc.log.init("bind10")
     isc.log.resetUnitTestRootLogger()
     isc.log.resetUnitTestRootLogger()

BIN
src/lib/python/isc/datasrc/tests/testdata/test.sqlite3.nodiffs