Browse Source

[master] Merge branch 'trac1299'

JINMEI Tatuya 13 years ago
parent
commit
6ff03bb9d6

+ 270 - 47
src/bin/xfrin/tests/xfrin_test.py

@@ -20,6 +20,7 @@ import socket
 import sys
 import io
 from isc.testutils.tsigctx_mock import MockTSIGContext
+from isc.testutils.rrset_utils import *
 from xfrin import *
 import xfrin
 from isc.xfrin.diff import Diff
@@ -42,11 +43,9 @@ TEST_RRCLASS_STR = 'IN'
 TEST_DB_FILE = 'db_file'
 TEST_MASTER_IPV4_ADDRESS = '127.0.0.1'
 TEST_MASTER_IPV4_ADDRINFO = (socket.AF_INET, socket.SOCK_STREAM,
-                             socket.IPPROTO_TCP, '',
                              (TEST_MASTER_IPV4_ADDRESS, 53))
 TEST_MASTER_IPV6_ADDRESS = '::1'
 TEST_MASTER_IPV6_ADDRINFO = (socket.AF_INET6, socket.SOCK_STREAM,
-                             socket.IPPROTO_TCP, '',
                              (TEST_MASTER_IPV6_ADDRESS, 53))
 
 TESTDATA_SRCDIR = os.getenv("TESTDATASRCDIR")
@@ -230,7 +229,7 @@ class MockXfrinConnection(XfrinConnection):
     def __init__(self, sock_map, zone_name, rrclass, datasrc_client,
                  shutdown_event, master_addr, tsig_key=None):
         super().__init__(sock_map, zone_name, rrclass, MockDataSourceClient(),
-                         shutdown_event, master_addr)
+                         shutdown_event, master_addr, TEST_DB_FILE)
         self.query_data = b''
         self.reply_data = b''
         self.force_time_out = False
@@ -280,10 +279,11 @@ class MockXfrinConnection(XfrinConnection):
                 self.response_generator()
         return len(data)
 
-    def create_response_data(self, response=True, bad_qid=False,
+    def create_response_data(self, response=True, auth=True, bad_qid=False,
                              rcode=Rcode.NOERROR(),
                              questions=default_questions,
                              answers=default_answers,
+                             authorities=[],
                              tsig_ctx=None):
         resp = Message(Message.RENDER)
         qid = self.qid
@@ -294,8 +294,11 @@ class MockXfrinConnection(XfrinConnection):
         resp.set_rcode(rcode)
         if response:
             resp.set_header_flag(Message.HEADERFLAG_QR)
+        if auth:
+            resp.set_header_flag(Message.HEADERFLAG_AA)
         [resp.add_question(q) for q in questions]
         [resp.add_rrset(Message.SECTION_ANSWER, a) for a in answers]
+        [resp.add_rrset(Message.SECTION_AUTHORITY, a) for a in authorities]
 
         renderer = MessageRenderer()
         if tsig_ctx is not None:
@@ -348,13 +351,44 @@ class TestXfrinInitialSOA(TestXfrinState):
         self.assertTrue(self.state.handle_rr(self.conn, soa_rrset))
         self.assertEqual(type(XfrinFirstData()),
                          type(self.conn.get_xfrstate()))
-        self.assertEqual(1234, self.conn._end_serial)
+        self.assertEqual(1234, self.conn._end_serial.get_value())
 
     def test_handle_not_soa(self):
         # The given RR is not of SOA
         self.assertRaises(XfrinProtocolError, self.state.handle_rr, self.conn,
                           self.ns_rrset)
 
+    def test_handle_ixfr_uptodate(self):
+        self.conn._request_type = RRType.IXFR()
+        self.conn._request_serial = isc.dns.Serial(1234) # same as soa_rrset
+        self.assertTrue(self.state.handle_rr(self.conn, soa_rrset))
+        self.assertEqual(type(XfrinIXFRUptodate()),
+                         type(self.conn.get_xfrstate()))
+
+    def test_handle_ixfr_uptodate2(self):
+        self.conn._request_type = RRType.IXFR()
+        self.conn._request_serial = isc.dns.Serial(1235) # > soa_rrset
+        self.assertTrue(self.state.handle_rr(self.conn, soa_rrset))
+        self.assertEqual(type(XfrinIXFRUptodate()),
+                         type(self.conn.get_xfrstate()))
+
+    def test_handle_ixfr_uptodate3(self):
+        # Similar to the previous case, but checking serial number arithmetic
+        # comparison
+        self.conn._request_type = RRType.IXFR()
+        self.conn._request_serial = isc.dns.Serial(0xffffffff)
+        self.assertTrue(self.state.handle_rr(self.conn, soa_rrset))
+        self.assertEqual(type(XfrinFirstData()),
+                         type(self.conn.get_xfrstate()))
+
+    def test_handle_axfr_uptodate(self):
+        # "request serial" should matter only for IXFR
+        self.conn._request_type = RRType.AXFR()
+        self.conn._request_serial = isc.dns.Serial(1234) # same as soa_rrset
+        self.assertTrue(self.state.handle_rr(self.conn, soa_rrset))
+        self.assertEqual(type(XfrinFirstData()),
+                         type(self.conn.get_xfrstate()))
+
     def test_finish_message(self):
         self.assertTrue(self.state.finish_message(self.conn))
 
@@ -363,7 +397,8 @@ class TestXfrinFirstData(TestXfrinState):
         super().setUp()
         self.state = XfrinFirstData()
         self.conn._request_type = RRType.IXFR()
-        self.conn._request_serial = 1230 # arbitrary chosen serial < 1234
+        # arbitrary chosen serial < 1234:
+        self.conn._request_serial = isc.dns.Serial(1230)
         self.conn._diff = None           # should be replaced in the AXFR case
 
     def test_handle_ixfr_begin_soa(self):
@@ -443,7 +478,7 @@ class TestXfrinIXFRDelete(TestXfrinState):
         # false.
         self.assertFalse(self.state.handle_rr(self.conn, soa_rrset))
         self.assertEqual([], self.conn._diff.get_buffer())
-        self.assertEqual(1234, self.conn._current_serial)
+        self.assertEqual(1234, self.conn._current_serial.get_value())
         self.assertEqual(type(XfrinIXFRAddSOA()),
                          type(self.conn.get_xfrstate()))
 
@@ -474,7 +509,7 @@ class TestXfrinIXFRAdd(TestXfrinState):
         # We need record the state in 'conn' to check the case where the
         # state doesn't change.
         XfrinIXFRAdd().set_xfrstate(self.conn, XfrinIXFRAdd())
-        self.conn._current_serial = 1230
+        self.conn._current_serial = isc.dns.Serial(1230)
         self.state = self.conn.get_xfrstate()
 
     def test_handle_add_rr(self):
@@ -486,7 +521,7 @@ class TestXfrinIXFRAdd(TestXfrinState):
         self.assertEqual(type(XfrinIXFRAdd()), type(self.conn.get_xfrstate()))
 
     def test_handle_end_soa(self):
-        self.conn._end_serial = 1234
+        self.conn._end_serial = isc.dns.Serial(1234)
         self.conn._diff.add_data(self.ns_rrset) # put some dummy change
         self.assertTrue(self.state.handle_rr(self.conn, soa_rrset))
         self.assertEqual(type(XfrinIXFREnd()), type(self.conn.get_xfrstate()))
@@ -495,7 +530,7 @@ class TestXfrinIXFRAdd(TestXfrinState):
         self.assertEqual([], self.conn._diff.get_buffer())
 
     def test_handle_new_delete(self):
-        self.conn._end_serial = 1234
+        self.conn._end_serial = isc.dns.Serial(1234)
         # SOA RR whose serial is the current one means we are going to a new
         # difference, starting with removing that SOA.
         self.conn._diff.add_data(self.ns_rrset) # put some dummy change
@@ -506,7 +541,7 @@ class TestXfrinIXFRAdd(TestXfrinState):
 
     def test_handle_out_of_sync(self):
         # getting SOA with an inconsistent serial.  This is an error.
-        self.conn._end_serial = 1235
+        self.conn._end_serial = isc.dns.Serial(1235)
         self.assertRaises(XfrinProtocolError, self.state.handle_rr,
                           self.conn, soa_rrset)
 
@@ -525,11 +560,24 @@ class TestXfrinIXFREnd(TestXfrinState):
     def test_finish_message(self):
         self.assertFalse(self.state.finish_message(self.conn))
 
+class TestXfrinIXFREnd(TestXfrinState):
+    def setUp(self):
+        super().setUp()
+        self.state = XfrinIXFRUptodate()
+
+    def test_handle_rr(self):
+        self.assertRaises(XfrinProtocolError, self.state.handle_rr, self.conn,
+                          self.ns_rrset)
+
+    def test_finish_message(self):
+        self.assertRaises(XfrinZoneUptodate, self.state.finish_message,
+                          self.conn)
+
 class TestXfrinAXFR(TestXfrinState):
     def setUp(self):
         super().setUp()
         self.state = XfrinAXFR()
-        self.conn._end_serial = 1234
+        self.conn._end_serial = isc.dns.Serial(1234)
 
     def test_handle_rr(self):
         """
@@ -604,7 +652,10 @@ class TestXfrinConnection(unittest.TestCase):
             'questions': [example_soa_question],
             'bad_qid': False,
             'response': True,
+            'auth': True,
             'rcode': Rcode.NOERROR(),
+            'answers': default_answers,
+            'authorities': [],
             'tsig': False,
             'axfr_after_soa': self._create_normal_response_data
             }
@@ -661,8 +712,11 @@ class TestXfrinConnection(unittest.TestCase):
         self.conn.reply_data = self.conn.create_response_data(
             bad_qid=self.soa_response_params['bad_qid'],
             response=self.soa_response_params['response'],
+            auth=self.soa_response_params['auth'],
             rcode=self.soa_response_params['rcode'],
             questions=self.soa_response_params['questions'],
+            answers=self.soa_response_params['answers'],
+            authorities=self.soa_response_params['authorities'],
             tsig_ctx=verify_ctx)
         if self.soa_response_params['axfr_after_soa'] != None:
             self.conn.response_generator = \
@@ -693,6 +747,15 @@ class TestXfrinConnection(unittest.TestCase):
         rrset.add_rdata(Rdata(RRType.NS(), TEST_RRCLASS, nsname))
         return rrset
 
+    def _set_test_zone(self, zone_name):
+        '''Set the zone name for transfer to the specified one.
+
+        It also make sure that the SOA RR (if exist) is correctly (re)set.
+
+        '''
+        self.conn._zone_name = zone_name
+        self.conn._zone_soa = self.conn._get_zone_soa()
+
 class TestAXFR(TestXfrinConnection):
     def setUp(self):
         super().setUp()
@@ -787,25 +850,26 @@ class TestAXFR(TestXfrinConnection):
         # IXFR query
         msg = self.conn._create_query(RRType.IXFR())
         check_query(RRType.IXFR(), begin_soa_rrset)
-        self.assertEqual(1230, self.conn._request_serial)
+        self.assertEqual(1230, self.conn._request_serial.get_value())
 
     def test_create_ixfr_query_fail(self):
         # In these cases _create_query() will fail to find a valid SOA RR to
         # insert in the IXFR query, and should raise an exception.
 
-        self.conn._zone_name = Name('no-such-zone.example')
+        self._set_test_zone(Name('no-such-zone.example'))
         self.assertRaises(XfrinException, self.conn._create_query,
                           RRType.IXFR())
 
-        self.conn._zone_name = Name('partial-match-zone.example')
+        self._set_test_zone(Name('partial-match-zone.example'))
         self.assertRaises(XfrinException, self.conn._create_query,
                           RRType.IXFR())
 
-        self.conn._zone_name = Name('no-soa.example')
+        self._set_test_zone(Name('no-soa.example'))
         self.assertRaises(XfrinException, self.conn._create_query,
                           RRType.IXFR())
 
-        self.conn._zone_name = Name('dup-soa.example')
+        self._set_test_zone(Name('dup-soa.example'))
+        self.conn._zone_soa = self.conn._get_zone_soa()
         self.assertRaises(XfrinException, self.conn._create_query,
                           RRType.IXFR())
 
@@ -836,8 +900,10 @@ class TestAXFR(TestXfrinConnection):
         self.conn._tsig_key = TSIG_KEY
         # server tsig check fail, return with RCODE 9 (NOTAUTH)
         self.conn._send_query(RRType.SOA())
-        self.conn.reply_data = self.conn.create_response_data(rcode=Rcode.NOTAUTH())
-        self.assertRaises(XfrinException, self.conn._handle_xfrin_responses)
+        self.conn.reply_data = \
+            self.conn.create_response_data(rcode=Rcode.NOTAUTH())
+        self.assertRaises(XfrinProtocolError,
+                          self.conn._handle_xfrin_responses)
 
     def test_response_without_end_soa(self):
         self.conn._send_query(RRType.AXFR())
@@ -850,7 +916,8 @@ class TestAXFR(TestXfrinConnection):
     def test_response_bad_qid(self):
         self.conn._send_query(RRType.AXFR())
         self.conn.reply_data = self.conn.create_response_data(bad_qid=True)
-        self.assertRaises(XfrinException, self.conn._handle_xfrin_responses)
+        self.assertRaises(XfrinProtocolError,
+                          self.conn._handle_xfrin_responses)
 
     def test_response_error_code_bad_sig(self):
         self.conn._tsig_key = TSIG_KEY
@@ -861,7 +928,7 @@ class TestAXFR(TestXfrinConnection):
                 rcode=Rcode.SERVFAIL())
         # xfrin should check TSIG before other part of incoming message
         # validate log message for XfrinException
-        self.__match_exception(XfrinException,
+        self.__match_exception(XfrinProtocolError,
                                "TSIG verify fail: BADSIG",
                                self.conn._handle_xfrin_responses)
 
@@ -873,7 +940,7 @@ class TestAXFR(TestXfrinConnection):
         self.conn.reply_data = self.conn.create_response_data(bad_qid=True)
         # xfrin should check TSIG before other part of incoming message
         # validate log message for XfrinException
-        self.__match_exception(XfrinException,
+        self.__match_exception(XfrinProtocolError,
                                "TSIG verify fail: BADKEY",
                                self.conn._handle_xfrin_responses)
 
@@ -886,18 +953,21 @@ class TestAXFR(TestXfrinConnection):
         self.conn._send_query(RRType.AXFR())
         self.conn.reply_data = self.conn.create_response_data(
             rcode=Rcode.SERVFAIL())
-        self.assertRaises(XfrinException, self.conn._handle_xfrin_responses)
+        self.assertRaises(XfrinProtocolError,
+                          self.conn._handle_xfrin_responses)
 
     def test_response_multi_question(self):
         self.conn._send_query(RRType.AXFR())
         self.conn.reply_data = self.conn.create_response_data(
             questions=[example_axfr_question, example_axfr_question])
-        self.assertRaises(XfrinException, self.conn._handle_xfrin_responses)
+        self.assertRaises(XfrinProtocolError,
+                          self.conn._handle_xfrin_responses)
 
     def test_response_non_response(self):
         self.conn._send_query(RRType.AXFR())
         self.conn.reply_data = self.conn.create_response_data(response = False)
-        self.assertRaises(XfrinException, self.conn._handle_xfrin_responses)
+        self.assertRaises(XfrinProtocolError,
+                          self.conn._handle_xfrin_responses)
 
     def test_soacheck(self):
         # we need to defer the creation until we know the QID, which is
@@ -912,7 +982,7 @@ class TestAXFR(TestXfrinConnection):
     def test_soacheck_badqid(self):
         self.soa_response_params['bad_qid'] = True
         self.conn.response_generator = self._create_soa_response_data
-        self.assertRaises(XfrinException, self.conn._check_soa_serial)
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
 
     def test_soacheck_bad_qid_bad_sig(self):
         self.conn._tsig_key = TSIG_KEY
@@ -922,19 +992,123 @@ class TestAXFR(TestXfrinConnection):
         self.conn.response_generator = self._create_soa_response_data
         # xfrin should check TSIG before other part of incoming message
         # validate log message for XfrinException
-        self.__match_exception(XfrinException,
+        self.__match_exception(XfrinProtocolError,
                                "TSIG verify fail: BADSIG",
                                self.conn._check_soa_serial)
 
     def test_soacheck_non_response(self):
         self.soa_response_params['response'] = False
         self.conn.response_generator = self._create_soa_response_data
-        self.assertRaises(XfrinException, self.conn._check_soa_serial)
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
 
     def test_soacheck_error_code(self):
         self.soa_response_params['rcode'] = Rcode.SERVFAIL()
         self.conn.response_generator = self._create_soa_response_data
-        self.assertRaises(XfrinException, self.conn._check_soa_serial)
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
+
+    def test_soacheck_notauth(self):
+        self.soa_response_params['auth'] = False
+        self.conn.response_generator = self._create_soa_response_data
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
+
+    def test_soacheck_uptodate(self):
+        # Primary's SOA serial is identical the local serial
+        self.soa_response_params['answers'] = [begin_soa_rrset]
+        self.conn.response_generator = self._create_soa_response_data
+        self.assertRaises(XfrinZoneUptodate, self.conn._check_soa_serial)
+
+    def test_soacheck_uptodate2(self):
+        # Primary's SOA serial is "smaller" than the local serial
+        self.soa_response_params['answers'] = [create_soa(1229)]
+        self.conn.response_generator = self._create_soa_response_data
+        self.assertRaises(XfrinZoneUptodate, self.conn._check_soa_serial)
+
+    def test_soacheck_uptodate3(self):
+        # Similar to the previous case, but checking the comparison is based
+        # on the serial number arithmetic.
+        self.soa_response_params['answers'] = [create_soa(0xffffffff)]
+        self.conn.response_generator = self._create_soa_response_data
+        self.assertRaises(XfrinZoneUptodate, self.conn._check_soa_serial)
+
+    def test_soacheck_newzone(self):
+        # Primary's SOA is 'old', but this secondary doesn't know anything
+        # about the zone yet, so it should accept it.
+        def response_generator():
+            # _request_serial is set in _check_soa_serial().  Reset it here.
+            self.conn._request_serial = None
+            self._create_soa_response_data()
+        self.soa_response_params['answers'] = [begin_soa_rrset]
+        self.conn.response_generator = response_generator
+        self.assertEqual(XFRIN_OK, self.conn._check_soa_serial())
+
+    def test_soacheck_question_empty(self):
+        self.conn.response_generator = self._create_soa_response_data
+        self.soa_response_params['questions'] = []
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
+
+    def test_soacheck_question_name_mismatch(self):
+        self.conn.response_generator = self._create_soa_response_data
+        self.soa_response_params['questions'] = [Question(Name('example.org'),
+                                                          TEST_RRCLASS,
+                                                          RRType.SOA())]
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
+
+    def test_soacheck_question_class_mismatch(self):
+        self.conn.response_generator = self._create_soa_response_data
+        self.soa_response_params['questions'] = [Question(TEST_ZONE_NAME,
+                                                          RRClass.CH(),
+                                                          RRType.SOA())]
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
+
+    def test_soacheck_question_type_mismatch(self):
+        self.conn.response_generator = self._create_soa_response_data
+        self.soa_response_params['questions'] = [Question(TEST_ZONE_NAME,
+                                                          TEST_RRCLASS,
+                                                          RRType.AAAA())]
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
+
+    def test_soacheck_no_soa(self):
+        # The response just doesn't contain SOA without any other indication
+        # of errors.
+        self.conn.response_generator = self._create_soa_response_data
+        self.soa_response_params['answers'] = []
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
+
+    def test_soacheck_soa_name_mismatch(self):
+        self.conn.response_generator = self._create_soa_response_data
+        self.soa_response_params['answers'] = [create_soa(1234,
+                                                          Name('example.org'))]
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
+
+    def test_soacheck_soa_class_mismatch(self):
+        self.conn.response_generator = self._create_soa_response_data
+        soa = RRset(TEST_ZONE_NAME, RRClass.CH(), RRType.SOA(), RRTTL(0))
+        soa.add_rdata(Rdata(RRType.SOA(), RRClass.CH(), 'm. r. 1234 0 0 0 0'))
+        self.soa_response_params['answers'] = [soa]
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
+
+    def test_soacheck_multiple_soa(self):
+        self.conn.response_generator = self._create_soa_response_data
+        self.soa_response_params['answers'] = [soa_rrset, soa_rrset]
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
+
+    def test_soacheck_cname_response(self):
+        self.conn.response_generator = self._create_soa_response_data
+        # Add SOA to answer, too, to make sure that it that deceives the parser
+        self.soa_response_params['answers'] = [soa_rrset, create_cname()]
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
+
+    def test_soacheck_referral_response(self):
+        self.conn.response_generator = self._create_soa_response_data
+        self.soa_response_params['answers'] = []
+        self.soa_response_params['authorities'] = [create_ns('ns.example.com')]
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
+
+    def test_soacheck_nodata_response(self):
+        self.conn.response_generator = self._create_soa_response_data
+        self.soa_response_params['answers'] = []
+        self.soa_response_params['authorities'] = [soa_rrset]
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
 
     def test_soacheck_with_tsig(self):
         # Use a mock tsig context emulating a validly signed response
@@ -953,7 +1127,7 @@ class TestAXFR(TestXfrinConnection):
         self.soa_response_params['rcode'] = Rcode.NOTAUTH()
         self.conn.response_generator = self._create_soa_response_data
 
-        self.assertRaises(XfrinException, self.conn._check_soa_serial)
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
 
     def test_soacheck_with_tsig_noerror_badsig(self):
         self.conn._tsig_key = TSIG_KEY
@@ -966,7 +1140,7 @@ class TestAXFR(TestXfrinConnection):
         # treat this as a final failure (just as BIND 9 does).
         self.conn.response_generator = self._create_soa_response_data
 
-        self.assertRaises(XfrinException, self.conn._check_soa_serial)
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
 
     def test_soacheck_with_tsig_unsigned_response(self):
         # we can use a real TSIGContext for this.  the response doesn't
@@ -975,14 +1149,14 @@ class TestAXFR(TestXfrinConnection):
         # it as a fatal transaction failure, too.
         self.conn._tsig_key = TSIG_KEY
         self.conn.response_generator = self._create_soa_response_data
-        self.assertRaises(XfrinException, self.conn._check_soa_serial)
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
 
     def test_soacheck_with_unexpected_tsig_response(self):
         # we reject unexpected TSIG in responses (following BIND 9's
         # behavior)
         self.soa_response_params['tsig'] = True
         self.conn.response_generator = self._create_soa_response_data
-        self.assertRaises(XfrinException, self.conn._check_soa_serial)
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
 
     def test_response_shutdown(self):
         self.conn.response_generator = self._create_normal_response_data
@@ -1244,6 +1418,18 @@ class TestAXFR(TestXfrinConnection):
         self.conn.response_generator = self._create_soa_response_data
         self.assertEqual(self.conn.do_xfrin(True), XFRIN_OK)
 
+    def test_do_soacheck_uptodate(self):
+        self.soa_response_params['answers'] = [begin_soa_rrset]
+        self.conn.response_generator = self._create_soa_response_data
+        self.assertEqual(self.conn.do_xfrin(True), XFRIN_OK)
+
+    def test_do_soacheck_protocol_error(self):
+        # There are several cases, but at this level it's sufficient to check
+        # only one.  We use the case where there's no SOA in the response.
+        self.soa_response_params['answers'] = []
+        self.conn.response_generator = self._create_soa_response_data
+        self.assertEqual(self.conn.do_xfrin(True), XFRIN_FAIL)
+
     def test_do_soacheck_and_xfrin_with_tsig(self):
         # We are going to have a SOA query/response transaction, followed by
         # AXFR, all TSIG signed.  xfrin should use a new TSIG context for
@@ -1276,9 +1462,8 @@ class TestIXFRResponse(TestXfrinConnection):
     def setUp(self):
         super().setUp()
         self.conn._query_id = self.conn.qid = 1035
-        self.conn._request_serial = 1230
+        self.conn._request_serial = isc.dns.Serial(1230)
         self.conn._request_type = RRType.IXFR()
-        self._zone_name = TEST_ZONE_NAME
         self.conn._datasrc_client = MockDataSourceClient()
         XfrinInitialSOA().set_xfrstate(self.conn, XfrinInitialSOA())
 
@@ -1353,6 +1538,16 @@ class TestIXFRResponse(TestXfrinConnection):
                     [[('delete', begin_soa_rrset), ('add', soa_rrset)]],
                     self.conn._datasrc_client.committed_diffs)
 
+    def test_ixfr_response_uptodate(self):
+        '''IXFR response indicates the zone is new enough'''
+        self.conn.reply_data = self.conn.create_response_data(
+            questions=[Question(TEST_ZONE_NAME, TEST_RRCLASS, RRType.IXFR())],
+            answers=[begin_soa_rrset])
+        self.assertRaises(XfrinZoneUptodate, self.conn._handle_xfrin_responses)
+        # no diffs should have been committed
+        check_diffs(self.assertEqual,
+                    [], self.conn._datasrc_client.committed_diffs)
+
     def test_ixfr_response_broken(self):
         '''Test with a broken response.
 
@@ -1385,6 +1580,22 @@ class TestIXFRResponse(TestXfrinConnection):
                     [[('delete', begin_soa_rrset), ('add', soa_rrset)]],
                     self.conn._datasrc_client.committed_diffs)
 
+    def test_ixfr_response_uptodate_extra(self):
+        '''Similar to 'uptodate' test, but with extra bogus data.
+
+        In either case an exception will be raised, but in this case it's
+        considered an error.
+
+        '''
+        self.conn.reply_data = self.conn.create_response_data(
+            questions=[Question(TEST_ZONE_NAME, TEST_RRCLASS, RRType.IXFR())],
+            answers=[begin_soa_rrset, soa_rrset])
+        self.assertRaises(XfrinProtocolError,
+                          self.conn._handle_xfrin_responses)
+        # no diffs should have been committed
+        check_diffs(self.assertEqual,
+                    [], self.conn._datasrc_client.committed_diffs)
+
     def test_ixfr_to_axfr_response(self):
         '''AXFR-style IXFR response.
 
@@ -1488,13 +1699,25 @@ class TestIXFRSession(TestXfrinConnection):
         self.conn.response_generator = create_ixfr_response
         self.assertEqual(XFRIN_FAIL, self.conn.do_xfrin(False, RRType.IXFR()))
 
-    def test_do_xfrin_fail(self):
+    def test_do_xfrin_fail2(self):
         '''IXFR fails due to a bogus DNS message.
 
         '''
         self._create_broken_response_data()
         self.assertEqual(XFRIN_FAIL, self.conn.do_xfrin(False, RRType.IXFR()))
 
+    def test_do_xfrin_uptodate(self):
+        '''IXFR is (gracefully) aborted because serial is not new
+
+        '''
+        def create_response():
+            self.conn.reply_data = self.conn.create_response_data(
+                questions=[Question(TEST_ZONE_NAME, TEST_RRCLASS,
+                                    RRType.IXFR())],
+                answers=[begin_soa_rrset])
+        self.conn.response_generator = create_response
+        self.assertEqual(XFRIN_OK, self.conn.do_xfrin(False, RRType.IXFR()))
+
 class TestXFRSessionWithSQLite3(TestXfrinConnection):
     '''Tests for XFR sessions using an SQLite3 DB.
 
@@ -1549,9 +1772,9 @@ class TestXFRSessionWithSQLite3(TestXfrinConnection):
         self.conn.response_generator = create_ixfr_response
 
         # Confirm xfrin succeeds and SOA is updated
-        self.assertEqual(1230, self.get_zone_serial())
+        self.assertEqual(1230, self.get_zone_serial().get_value())
         self.assertEqual(XFRIN_OK, self.conn.do_xfrin(False, RRType.IXFR()))
-        self.assertEqual(1234, self.get_zone_serial())
+        self.assertEqual(1234, self.get_zone_serial().get_value())
 
         # Also confirm the corresponding diffs are stored in the diffs table
         conn = sqlite3.connect(self.sqlite3db_obj)
@@ -1580,12 +1803,12 @@ class TestXFRSessionWithSQLite3(TestXfrinConnection):
                          self._create_soa('1235')])
         self.conn.response_generator = create_ixfr_response
 
-        self.assertEqual(1230, self.get_zone_serial())
+        self.assertEqual(1230, self.get_zone_serial().get_value())
         self.assertEqual(XFRIN_FAIL, self.conn.do_xfrin(False, RRType.IXFR()))
-        self.assertEqual(1230, self.get_zone_serial())
+        self.assertEqual(1230, self.get_zone_serial().get_value())
 
     def test_do_ixfrin_nozone_sqlite3(self):
-        self.conn._zone_name = Name('nosuchzone.example')
+        self._set_test_zone(Name('nosuchzone.example'))
         self.assertEqual(XFRIN_FAIL, self.conn.do_xfrin(False, RRType.IXFR()))
         # This should fail even before starting state transition
         self.assertEqual(None, self.conn.get_xfrstate())
@@ -1601,11 +1824,11 @@ class TestXFRSessionWithSQLite3(TestXfrinConnection):
         self.conn.response_generator = create_response
 
         # Confirm xfrin succeeds and SOA is updated, A RR is deleted.
-        self.assertEqual(1230, self.get_zone_serial())
+        self.assertEqual(1230, self.get_zone_serial().get_value())
         self.assertTrue(self.record_exist(Name('dns01.example.com'),
                                           RRType.A()))
         self.assertEqual(XFRIN_OK, self.conn.do_xfrin(False, type))
-        self.assertEqual(1234, self.get_zone_serial())
+        self.assertEqual(1234, self.get_zone_serial().get_value())
         self.assertFalse(self.record_exist(Name('dns01.example.com'),
                                            RRType.A()))
 
@@ -1633,11 +1856,11 @@ class TestXFRSessionWithSQLite3(TestXfrinConnection):
                 answers=[soa_rrset, self._create_ns(), soa_rrset, soa_rrset])
         self.conn.response_generator = create_response
 
-        self.assertEqual(1230, self.get_zone_serial())
+        self.assertEqual(1230, self.get_zone_serial().get_value())
         self.assertTrue(self.record_exist(Name('dns01.example.com'),
                                           RRType.A()))
         self.assertEqual(XFRIN_FAIL, self.conn.do_xfrin(False, type))
-        self.assertEqual(1230, self.get_zone_serial())
+        self.assertEqual(1230, self.get_zone_serial().get_value())
         self.assertTrue(self.record_exist(Name('dns01.example.com'),
                                           RRType.A()))
 
@@ -1671,11 +1894,11 @@ class TestXFRSessionWithSQLite3(TestXfrinConnection):
                                     RRType.AXFR())],
                 answers=[soa_rrset, self._create_ns(), soa_rrset])
         self.conn.response_generator = create_response
-        self.conn._zone_name = Name('example.com')
+        self._set_test_zone(Name('example.com'))
         self.assertEqual(XFRIN_OK, self.conn.do_xfrin(False, RRType.AXFR()))
         self.assertEqual(type(XfrinAXFREnd()),
                          type(self.conn.get_xfrstate()))
-        self.assertEqual(1234, self.get_zone_serial())
+        self.assertEqual(1234, self.get_zone_serial().get_value())
         self.assertFalse(self.record_exist(Name('dns01.example.com'),
                                            RRType.A()))
 

+ 228 - 109
src/bin/xfrin/xfrin.py.in

@@ -24,6 +24,7 @@ import struct
 import threading
 import socket
 import random
+from functools import reduce
 from optparse import OptionParser, OptionValueError
 from isc.config.ccsession import *
 from isc.notify import notify_out
@@ -75,9 +76,10 @@ DEFAULT_MASTER_PORT = 53
 DEFAULT_ZONE_CLASS = RRClass.IN()
 
 __version__ = 'BIND10'
-# define xfrin rcode
-XFRIN_OK = 0
-XFRIN_FAIL = 1
+
+# Internal result codes of an xfr session
+XFRIN_OK = 0                    # normal success
+XFRIN_FAIL = 1                  # general failure (internal/external)
 
 class XfrinException(Exception):
     pass
@@ -87,6 +89,11 @@ class XfrinProtocolError(Exception):
     '''
     pass
 
+class XfrinZoneUptodate(Exception):
+    '''TBD
+    '''
+    pass
+
 class XfrinZoneInfoException(Exception):
     """This exception is raised if there is an error in the given
        configuration (part), or when a command does not have a required
@@ -153,7 +160,7 @@ def format_addrinfo(addrinfo):
                         "appear to be consisting of (family, socktype, (addr, port))")
 
 def get_soa_serial(soa_rdata):
-    '''Extract the serial field of an SOA RDATA and returns it as an intger.
+    '''Extract the serial field of SOA RDATA and return it as a Serial object.
 
     We don't have to be very efficient here, so we first dump the entire RDATA
     as a string and convert the first corresponding field.  This should be
@@ -162,7 +169,7 @@ def get_soa_serial(soa_rdata):
     should be a more direct and convenient way to get access to the SOA
     fields.
     '''
-    return int(soa_rdata.to_text().split()[2])
+    return Serial(int(soa_rdata.to_text().split()[2]))
 
 class XfrinState:
     '''
@@ -181,12 +188,12 @@ class XfrinState:
                              (AXFR or
             (recv SOA)        AXFR-style IXFR)  (SOA, add)
     InitialSOA------->FirstData------------->AXFR--------->AXFREnd
-                          |                  |  ^         (post xfr
-                          |                  |  |        checks, then
-                          |                  +--+        commit)
-                          |            (non SOA, add)
-                          |
-                          |                     (non SOA, delete)
+         |                |                  |  ^         (post xfr
+         |(IXFR &&        |                  |  |        checks, then
+         | recv SOA       |                  +--+        commit)
+         | not new)       |            (non SOA, add)
+         V                |
+    IXFRUptodate          |                     (non SOA, delete)
                (pure IXFR,|                           +-------+
             keep handling)|             (Delete SOA)  V       |
                           + ->IXFRDeleteSOA------>IXFRDelete--+
@@ -300,13 +307,14 @@ class XfrinInitialSOA(XfrinState):
                                      + rr.get_type().to_text() + ' received)')
         conn._end_serial = get_soa_serial(rr.get_rdata()[0])
 
-        # FIXME: we need to check the serial is actually greater than ours.
-        # To do so, however, we need to implement serial number arithmetic.
-        # Although it wouldn't be a big task, we'll leave it for a separate
-        # task for now.  (Always performing xfr could be inefficient, but
-        # shouldn't do any harm otherwise)
+        if conn._request_type == RRType.IXFR() and \
+                conn._end_serial <= conn._request_serial:
+            logger.info(XFRIN_IXFR_UPTODATE, conn.zone_str(),
+                        conn._request_serial, conn._end_serial)
+            self.set_xfrstate(conn, XfrinIXFRUptodate())
+        else:
+            self.set_xfrstate(conn, XfrinFirstData())
 
-        self.set_xfrstate(conn, XfrinFirstData())
         return True
 
 class XfrinFirstData(XfrinState):
@@ -430,6 +438,14 @@ class XfrinIXFREnd(XfrinState):
         '''
         return False
 
+class XfrinIXFRUptodate(XfrinState):
+    def handle_rr(self, conn, rr):
+        raise XfrinProtocolError('Extra data after single IXFR response ' +
+                                 rr.to_text())
+
+    def finish_message(self, conn):
+        raise XfrinZoneUptodate
+
 class XfrinAXFR(XfrinState):
     def handle_rr(self, conn, rr):
         """
@@ -473,10 +489,13 @@ class XfrinConnection(asyncore.dispatcher):
 
     def __init__(self,
                  sock_map, zone_name, rrclass, datasrc_client,
-                 shutdown_event, master_addrinfo, tsig_key=None,
+                 shutdown_event, master_addrinfo, db_file, tsig_key=None,
                  idle_timeout=60):
         '''Constructor of the XfirnConnection class.
 
+        db_file: SQLite3 DB file.  Unforutnately we still need this for
+                 temporary workaround in _get_zone_soa().  This should be
+                 removed when we eliminate the need for the workaround.
         idle_timeout: max idle time for read data from socket.
         datasrc_client: the data source client object used for the XFR session.
                         This will eventually replace db_file completely.
@@ -500,7 +519,9 @@ class XfrinConnection(asyncore.dispatcher):
         self._rrclass = rrclass
 
         # Data source handler
+        self._db_file = db_file
         self._datasrc_client = datasrc_client
+        self._zone_soa = self._get_zone_soa()
 
         self._sock_map = sock_map
         self._soa_rr_count = 0
@@ -524,6 +545,55 @@ class XfrinConnection(asyncore.dispatcher):
         self.create_socket(self._master_addrinfo[0], self._master_addrinfo[1])
         self.setblocking(1)
 
+    def _get_zone_soa(self):
+        '''Retrieve the current SOA RR of the zone to be transferred.
+
+        It will be used for various purposes in subsequent xfr protocol
+        processing.   It is validly possible that the zone is currently
+        empty and therefore doesn't have an SOA, so this method doesn't
+        consider it an error and returns None in such a case.  It may or
+        may not result in failure in the actual processing depending on
+        how the SOA is used.
+
+        When the zone has an SOA RR, this method makes sure that it's
+        valid, i.e., it has exactly one RDATA; if it is not the case
+        this method returns None.
+
+        If the underlying data source doesn't even know the zone, this method
+        tries to provide backward compatible behavior where xfrin is
+        responsible for creating zone in the corresponding DB table.
+        For a longer term we should deprecate this behavior by introducing
+        more generic zone management framework, but at the moment we try
+        to not surprise existing users.  (Note also that the part of
+        providing the compatible behavior uses the old data source API.
+        We'll deprecate this API in a near future, too).
+
+        '''
+        # get the zone finder.  this must be SUCCESS (not even
+        # PARTIALMATCH) because we are specifying the zone origin name.
+        result, finder = self._datasrc_client.find_zone(self._zone_name)
+        if result != DataSourceClient.SUCCESS:
+            # The data source doesn't know the zone.  For now, we provide
+            # backward compatibility and creates a new one ourselves.
+            isc.datasrc.sqlite3_ds.load(self._db_file,
+                                        self._zone_name.to_text(),
+                                        lambda : [])
+            logger.warn(XFRIN_ZONE_CREATED, self.zone_str())
+            # try again
+            result, finder = self._datasrc_client.find_zone(self._zone_name)
+        if result != DataSourceClient.SUCCESS:
+            return None
+        result, soa_rrset = finder.find(self._zone_name, RRType.SOA(),
+                                        None, ZoneFinder.FIND_DEFAULT)
+        if result != ZoneFinder.SUCCESS:
+            logger.info(XFRIN_ZONE_NO_SOA, self.zone_str())
+            return None
+        if soa_rrset.get_rdata_count() != 1:
+            logger.warn(XFRIN_ZONE_MULTIPLE_SOA, self.zone_str(),
+                        soa_rrset.get_rdata_count())
+            return None
+        return soa_rrset
+
     def __set_xfrstate(self, new_state):
         self.__state = new_state
 
@@ -545,37 +615,16 @@ class XfrinConnection(asyncore.dispatcher):
                          str(e))
             return False
 
-    def _get_zone_soa(self):
-        result, finder = self._datasrc_client.find_zone(self._zone_name)
-        if result != DataSourceClient.SUCCESS:
-            raise XfrinException('Zone not found in the given data ' +
-                                 'source: ' + self.zone_str())
-        result, soa_rrset = finder.find(self._zone_name, RRType.SOA(),
-                                        None, ZoneFinder.FIND_DEFAULT)
-        if result != ZoneFinder.SUCCESS:
-            raise XfrinException('SOA RR not found in zone: ' +
-                                 self.zone_str())
-        # Especially for database-based zones, a working zone may be in
-        # a broken state where it has more than one SOA RR.  We proactively
-        # check the condition and abort the xfr attempt if we identify it.
-        if soa_rrset.get_rdata_count() != 1:
-            raise XfrinException('Invalid number of SOA RRs for ' +
-                                 self.zone_str() + ': ' +
-                                 str(soa_rrset.get_rdata_count()))
-        return soa_rrset
-
     def _create_query(self, query_type):
         '''Create an XFR-related query message.
 
-        query_type is either SOA, AXFR or IXFR.  For type IXFR, it searches
-        the associated data source for the current SOA record to include
-        it in the query.  If the corresponding zone or the SOA record
-        cannot be found, it raises an XfrinException exception.  Note that
-        this may not necessarily a broken configuration; for the first attempt
-        of transfer the secondary may not have any boot-strap zone
-        information, in which case IXFR simply won't work.  The xfrin
-        should then fall back to AXFR.  _request_serial is recorded for
-        later use.
+        query_type is either SOA, AXFR or IXFR.  An IXFR query needs the
+        zone's current SOA record.  If it's not known, it raises an
+        XfrinException exception.  Note that this may not necessarily a
+        broken configuration; for the first attempt of transfer the secondary
+        may not have any boot-strap zone information, in which case IXFR
+        simply won't work.  The xfrin should then fall back to AXFR.
+        _request_serial is recorded for later use.
 
         '''
         msg = Message(Message.RENDER)
@@ -585,27 +634,19 @@ class XfrinConnection(asyncore.dispatcher):
         msg.set_opcode(Opcode.QUERY())
         msg.set_rcode(Rcode.NOERROR())
         msg.add_question(Question(self._zone_name, self._rrclass, query_type))
+
+        # Remember our serial, if known
+        self._request_serial = get_soa_serial(self._zone_soa.get_rdata()[0]) \
+            if self._zone_soa is not None else None
+
+        # Set the authority section with our SOA for IXFR
         if query_type == RRType.IXFR():
-            # get the zone finder.  this must be SUCCESS (not even
-            # PARTIALMATCH) because we are specifying the zone origin name.
-            zone_soa_rr = self._get_zone_soa()
-            msg.add_rrset(Message.SECTION_AUTHORITY, zone_soa_rr)
-            self._request_serial = get_soa_serial(zone_soa_rr.get_rdata()[0])
-        else:
-            # For AXFR, we temporarily provide backward compatible behavior
-            # where xfrin is responsible for creating zone in the corresponding
-            # DB table.  Note that the code below uses the old data source
-            # API and assumes SQLite3 in an ugly manner.  We'll have to
-            # develop a better way of managing zones in a generic way and
-            # eliminate the code like the one here.
-            try:
-                self._get_zone_soa()
-            except XfrinException:
-                def empty_rr_generator():
-                    return []
-                isc.datasrc.sqlite3_ds.load(self._db_file,
-                                            self._zone_name.to_text(),
-                                            empty_rr_generator)
+            if self._zone_soa is None:
+                # (incremental) IXFR doesn't work without known SOA
+                raise XfrinException('Failed to create IXFR query due to no ' +
+                                     'SOA for ' + self.zone_str())
+            msg.add_rrset(Message.SECTION_AUTHORITY, self._zone_soa)
+
         return msg
 
     def _send_data(self, data):
@@ -659,7 +700,8 @@ class XfrinConnection(asyncore.dispatcher):
         if self._tsig_ctx is not None:
             tsig_error = self._tsig_ctx.verify(tsig_record, response_data)
             if tsig_error != TSIGError.NOERROR:
-                raise XfrinException('TSIG verify fail: %s' % str(tsig_error))
+                raise XfrinProtocolError('TSIG verify fail: %s' %
+                                         str(tsig_error))
         elif tsig_record is not None:
             # If the response includes a TSIG while we didn't sign the query,
             # we treat it as an error.  RFC doesn't say anything about this
@@ -668,13 +710,78 @@ class XfrinConnection(asyncore.dispatcher):
             # implementation would return such a response, and since this is
             # part of security mechanism, it's probably better to be more
             # strict.
-            raise XfrinException('Unexpected TSIG in response')
+            raise XfrinProtocolError('Unexpected TSIG in response')
+
+    def __parse_soa_response(self, msg, response_data):
+        '''Parse a response to SOA query and extract the SOA from answer.
+
+        This is a subroutine of _check_soa_serial().  This method also
+        validates message, and rejects bogus responses with XfrinProtocolError.
+
+        If everything is okay, it returns the SOA RR from the answer section
+        of the response.
+
+        '''
+        # Check TSIG integrity and validate the header.  Unlike AXFR/IXFR,
+        # we should be more strict for SOA queries and check the AA flag, too.
+        self._check_response_tsig(msg, response_data)
+        self._check_response_header(msg)
+        if not msg.get_header_flag(Message.HEADERFLAG_AA):
+            raise XfrinProtocolError('non-authoritative answer to SOA query')
+
+        # Validate the question section
+        n_question = msg.get_rr_count(Message.SECTION_QUESTION)
+        if n_question != 1:
+            raise XfrinProtocolError('Invalid response to SOA query: ' +
+                                     '(' + str(n_question) + ' questions, 1 ' +
+                                     'expected)')
+        resp_question = msg.get_question()[0]
+        if resp_question.get_name() != self._zone_name or \
+                resp_question.get_class() != self._rrclass or \
+                resp_question.get_type() != RRType.SOA():
+            raise XfrinProtocolError('Invalid response to SOA query: '
+                                     'question mismatch: ' +
+                                     str(resp_question))
+
+        # Look into the answer section for SOA
+        soa = None
+        for rr in msg.get_section(Message.SECTION_ANSWER):
+            if rr.get_type() == RRType.SOA():
+                if soa is not None:
+                    raise XfrinProtocolError('SOA response had multiple SOAs')
+                soa = rr
+            # There should not be a CNAME record at top of zone.
+            if rr.get_type() == RRType.CNAME():
+                raise XfrinProtocolError('SOA query resulted in CNAME')
+
+        # If SOA is not found, try to figure out the reason then report it.
+        if soa is None:
+            # See if we have any SOA records in the authority section.
+            for rr in msg.get_section(Message.SECTION_AUTHORITY):
+                if rr.get_type() == RRType.NS():
+                    raise XfrinProtocolError('SOA query resulted in referral')
+                if rr.get_type() == RRType.SOA():
+                    raise XfrinProtocolError('SOA query resulted in NODATA')
+            raise XfrinProtocolError('No SOA record found in response to ' +
+                                     'SOA query')
+
+        # Check if the SOA is really what we asked for
+        if soa.get_name() != self._zone_name or \
+                soa.get_class() != self._rrclass:
+            raise XfrinProtocolError("SOA response doesn't match query: " +
+                                     str(soa))
+
+        # All okay, return it
+        return soa
+
 
     def _check_soa_serial(self):
-        ''' Compare the soa serial, if soa serial in master is less than
-        the soa serial in local, Finish xfrin.
-        False: soa serial in master is less or equal to the local one.
-        True: soa serial in master is bigger
+        '''Send SOA query and compare the local and remote serials.
+
+        If we know our local serial and the remote serial isn't newer
+        than ours, we abort the session with XfrinZoneUptodate.
+        On success it returns XFRIN_OK for testing.  The caller won't use it.
+
         '''
 
         self._send_query(RRType.SOA())
@@ -682,18 +789,23 @@ class XfrinConnection(asyncore.dispatcher):
         msg_len = socket.htons(struct.unpack('H', data_len)[0])
         soa_response = self._get_request_response(msg_len)
         msg = Message(Message.PARSE)
-        msg.from_wire(soa_response)
+        msg.from_wire(soa_response, Message.PRESERVE_ORDER)
+
+        # Validate/parse the rest of the response, and extract the SOA
+        # from the answer section
+        soa = self.__parse_soa_response(msg, soa_response)
+
+        # Compare the two serials.  If ours is 'new', abort with ZoneUptodate.
+        primary_serial = get_soa_serial(soa.get_rdata()[0])
+        if self._request_serial is not None and \
+                self._request_serial >= primary_serial:
+            if self._request_serial != primary_serial:
+                logger.info(XFRIN_ZONE_SERIAL_AHEAD, primary_serial,
+                            self.zone_str(),
+                            format_addrinfo(self._master_addrinfo),
+                            self._request_serial)
+            raise XfrinZoneUptodate
 
-        # TSIG related checks, including an unexpected signed response
-        self._check_response_tsig(msg, soa_response)
-
-        # perform some minimal level validation.  It's an open issue how
-        # strict we should be (see the comment in _check_response_header())
-        self._check_response_header(msg)
-
-        # TODO, need select soa record from data source then compare the two
-        # serial, current just return OK, since this function hasn't been used
-        # now.
         return XFRIN_OK
 
     def do_xfrin(self, check_soa, request_type=RRType.AXFR()):
@@ -704,22 +816,30 @@ class XfrinConnection(asyncore.dispatcher):
             self._request_type = request_type
             # Right now RRType.[IA]XFR().to_text() is 'TYPExxx', so we need
             # to hardcode here.
-            request_str = 'IXFR' if request_type == RRType.IXFR() else 'AXFR'
+            req_str = 'IXFR' if request_type == RRType.IXFR() else 'AXFR'
             if check_soa:
-                ret =  self._check_soa_serial()
-
-            if ret == XFRIN_OK:
-                logger.info(XFRIN_XFR_TRANSFER_STARTED, request_str,
-                            self.zone_str())
-                self._send_query(self._request_type)
-                self.__state = XfrinInitialSOA()
-                self._handle_xfrin_responses()
-                logger.info(XFRIN_XFR_TRANSFER_SUCCESS, request_str,
-                            self.zone_str())
-
-        except (XfrinException, XfrinProtocolError) as e:
-            logger.error(XFRIN_XFR_TRANSFER_FAILURE, request_str,
-                         self.zone_str(), str(e))
+                self._check_soa_serial()
+
+            logger.info(XFRIN_XFR_TRANSFER_STARTED, req_str, self.zone_str())
+            self._send_query(self._request_type)
+            self.__state = XfrinInitialSOA()
+            self._handle_xfrin_responses()
+            logger.info(XFRIN_XFR_TRANSFER_SUCCESS, req_str, self.zone_str())
+
+        except XfrinZoneUptodate:
+            # Eventually we'll probably have to treat this case as a trigger
+            # of trying another primary server, etc, but for now we treat it
+            # as "success".
+            pass
+        except XfrinProtocolError as e:
+            logger.info(XFRIN_XFR_TRANSFER_PROTOCOL_ERROR, req_str,
+                        self.zone_str(),
+                        format_addrinfo(self._master_addrinfo), str(e))
+            ret = XFRIN_FAIL
+        except XfrinException as e:
+            logger.error(XFRIN_XFR_TRANSFER_FAILURE, req_str,
+                         self.zone_str(),
+                         format_addrinfo(self._master_addrinfo), str(e))
             ret = XFRIN_FAIL
         except Exception as e:
             # Catching all possible exceptions like this is generally not a
@@ -730,7 +850,7 @@ class XfrinConnection(asyncore.dispatcher):
             # catch it here, but until then we need broadest coverage so that
             # we won't miss anything.
 
-            logger.error(XFRIN_XFR_OTHER_FAILURE, request_str,
+            logger.error(XFRIN_XFR_OTHER_FAILURE, req_str,
                          self.zone_str(), str(e))
             ret = XFRIN_FAIL
         finally:
@@ -754,13 +874,14 @@ class XfrinConnection(asyncore.dispatcher):
 
         msg_rcode = msg.get_rcode()
         if msg_rcode != Rcode.NOERROR():
-            raise XfrinException('error response: %s' % msg_rcode.to_text())
+            raise XfrinProtocolError('error response: %s' %
+                                     msg_rcode.to_text())
 
         if not msg.get_header_flag(Message.HEADERFLAG_QR):
-            raise XfrinException('response is not a response')
+            raise XfrinProtocolError('response is not a response')
 
         if msg.get_qid() != self._query_id:
-            raise XfrinException('bad query id')
+            raise XfrinProtocolError('bad query id')
 
     def _check_response_status(self, msg):
         '''Check validation of xfr response. '''
@@ -768,7 +889,7 @@ class XfrinConnection(asyncore.dispatcher):
         self._check_response_header(msg)
 
         if msg.get_rr_count(Message.SECTION_QUESTION) > 1:
-            raise XfrinException('query section count greater than 1')
+            raise XfrinProtocolError('query section count greater than 1')
 
     def _handle_xfrin_responses(self):
         read_next_msg = True
@@ -808,8 +929,8 @@ class XfrinConnection(asyncore.dispatcher):
         return False
 
 def __process_xfrin(server, zone_name, rrclass, db_file,
-                  shutdown_event, master_addrinfo, check_soa, tsig_key,
-                  request_type, conn_class):
+                    shutdown_event, master_addrinfo, check_soa, tsig_key,
+                    request_type, conn_class):
     conn = None
     exception = None
     ret = XFRIN_FAIL
@@ -840,11 +961,9 @@ def __process_xfrin(server, zone_name, rrclass, db_file,
         while retry:
             retry = False
             conn = conn_class(sock_map, zone_name, rrclass, datasrc_client,
-                              shutdown_event, master_addrinfo, tsig_key)
+                              shutdown_event, master_addrinfo, db_file,
+                              tsig_key)
             conn.init_socket()
-            # XXX: We still need _db_file for temporary workaround in _create_query().
-            # This should be removed when we eliminate the need for the workaround.
-            conn._db_file = db_file
             ret = XFRIN_FAIL
             if conn.connect_to_master():
                 ret = conn.do_xfrin(check_soa, request_type)

+ 62 - 7
src/bin/xfrin/xfrin_messages.mes

@@ -15,18 +15,63 @@
 # No namespace declaration - these constants go in the global namespace
 # of the xfrin messages python module.
 
+% XFRIN_ZONE_CREATED Zone %1 not found in the given data source, newly created
+On starting an xfrin session, it is identified that the zone to be
+transferred is not found in the data source.  This can happen if a
+secondary DNS server first tries to perform AXFR from a primary server
+without creating the zone image beforehand (e.g. by b10-loadzone).  As
+of this writing the xfrin process provides backward compatible
+behavior to previous versions: creating a new one in the data source
+not to surprise existing users too much.  This is probably not a good
+idea, however, in terms of who should be responsible for managing
+zones at a higher level.  In future it is more likely that a separate
+zone management framework is provided, and the situation where the
+given zone isn't found in xfrout will be treated as an error.
+
+% XFRIN_ZONE_NO_SOA Zone %1 does not have SOA
+On starting an xfrin session, it is identified that the zone to be
+transferred does not have an SOA RR in the data source.  This is not
+necessarily an error; if a secondary DNS server first tries to perform
+transfer from a primary server, the zone can be empty, and therefore
+doesn't have an SOA.  Subsequent AXFR will fill in the zone; if the
+attempt is IXFR it will fail in query creation.
+
+% XFRIN_ZONE_MULTIPLE_SOA Zone %1 has %2 SOA RRs
+On starting an xfrin session, it is identified that the zone to be
+transferred has multiple SOA RRs.  Such a zone is broken, but could be
+accidentally configured especially in a data source using "non
+captive" backend database.  The implementation ignores entire SOA RRs
+and tries to continue processing as if the zone were empty.  This
+means subsequent AXFR can succeed and possibly replace the zone with
+valid content, but an IXFR attempt will fail.
+
+% XFRIN_ZONE_SERIAL_AHEAD Serial number (%1) for %2 received from master %3 < ours (%4)
+The response to an SOA query prior to xfr indicated that the zone's
+SOA serial at the primary server is smaller than that of the xfrin
+client.  This is not necessarily an error especially if that
+particular primary server is another secondary server which hasn't got
+the latest version of the zone.  But if the primary server is known to
+be the real source of the zone, some unexpected inconsistency may have
+happened, and you may want to take a closer look.  In this case xfrin
+doesn't perform subsequent zone transfer.
+
 % XFRIN_XFR_OTHER_FAILURE %1 transfer of zone %2 failed: %3
 The XFR transfer for the given zone has failed due to a problem outside
 of the xfrin module.  Possible reasons are a broken DNS message or failure
 in database connection.  The error is shown in the log message.
 
-% XFRIN_AXFR_DATABASE_FAILURE AXFR transfer of zone %1 failed: %2
-The AXFR transfer for the given zone has failed due to a database problem.
-The error is shown in the log message.  Note: due to the code structure
-this can only happen for AXFR.
-
-% XFRIN_XFR_TRANSFER_FAILURE %1 transfer of zone %2 failed: %3
-The XFR transfer for the given zone has failed due to a protocol error.
+% XFRIN_XFR_TRANSFER_PROTOCOL_ERROR %1 transfer of zone %2 with %3 failed: %4
+The XFR transfer for the given zone has failed due to a protocol
+error, such as an unexpected response from the primary server.  The
+error is shown in the log message.  It may be because the primary
+server implementation is broken or (although less likely) there was
+some attack attempt, but it can also happen due to configuration
+mismatch such as the remote server does not have authority for the
+zone any more but the local configuration hasn't been updated.  So it
+is recommended to check the primary server configuration.
+
+% XFRIN_XFR_TRANSFER_FAILURE %1 transfer of zone %2 with %3 failed: %4
+The XFR transfer for the given zone has failed due to an internal error.
 The error is shown in the log message.
 
 % XFRIN_XFR_TRANSFER_FALLBACK falling back from IXFR to AXFR for %1
@@ -118,6 +163,16 @@ daemon will now shut down.
 An uncaught exception was raised while running the xfrin daemon. The
 exception message is printed in the log message.
 
+% XFRIN_IXFR_UPTODATE IXFR requested serial for %1 is %2, master has %3, not updating
+The first SOA record in an IXFR response indicates the zone's serial
+at the primary server is not newer than the client's.  This is
+basically unexpected event because normally the client first checks
+the SOA serial by an SOA query, but can still happen if the transfer
+is manually invoked or (although unlikely) there is a rapid change at
+the primary server between the SOA and IXFR queries.  The client
+implementation confirms the whole response is this single SOA, and
+aborts the transfer just like a successful case.
+
 % XFRIN_GOT_INCREMENTAL_RESP got incremental response for %1
 In an attempt of IXFR processing, the begenning SOA of the first difference
 (following the initial SOA that specified the final SOA for all the

+ 2 - 0
src/lib/dns/python/tests/serial_python_test.py

@@ -77,9 +77,11 @@ class SerialTest(unittest.TestCase):
         self.assertLessEqual(self.one, self.one)
         self.assertLessEqual(self.one, self.one_2)
         self.assertLess(self.one, self.two)
+        self.assertLessEqual(self.one, self.one)
         self.assertLessEqual(self.one, self.two)
         self.assertGreater(self.two, self.one)
         self.assertGreaterEqual(self.two, self.two)
+        self.assertGreaterEqual(self.two, self.one)
         self.assertLess(self.one, self.number_low)
         self.assertLess(self.number_low, self.number_medium)
         self.assertLess(self.number_medium, self.number_high)

+ 1 - 1
src/lib/dns/serial.cc

@@ -51,7 +51,7 @@ Serial::operator>(const Serial& other) const {
 
 bool
 Serial::operator>=(const Serial& other) const {
-    return (operator==(other) || !operator>(other));
+    return (!operator<(other));
 }
 
 Serial

+ 1 - 1
src/lib/dns/tests/serial_unittest.cc

@@ -49,7 +49,6 @@ TEST_F(SerialTest, get_value) {
 
 TEST_F(SerialTest, equals) {
     EXPECT_EQ(one, one);
-    EXPECT_EQ(one, one);
     EXPECT_EQ(one, one_2);
     EXPECT_NE(one, two);
     EXPECT_NE(two, one);
@@ -65,6 +64,7 @@ TEST_F(SerialTest, comparison) {
     EXPECT_LE(one, two);
     EXPECT_GE(two, two);
     EXPECT_GT(two, one);
+    EXPECT_GE(two, one);
     EXPECT_LT(one, number_low);
     EXPECT_LT(number_low, number_medium);
     EXPECT_LT(number_medium, number_high);

+ 6 - 0
src/lib/python/isc/testutils/rrset_utils.py

@@ -53,6 +53,12 @@ def create_ns(nsname, name=Name('example.com'), ttl=3600):
     rrset.add_rdata(Rdata(RRType.NS(), RRClass.IN(), nsname))
     return rrset
 
+def create_cname(target='target.example.com', name=Name('example.com'),
+                 ttl=3600):
+    rrset = RRset(name, RRClass.IN(), RRType.CNAME(), RRTTL(ttl))
+    rrset.add_rdata(Rdata(RRType.CNAME(), RRClass.IN(), target))
+    return rrset
+
 def create_generic(name, rdlen, type=RRType('TYPE65300'), ttl=3600):
     '''Create an RR of a general type with an arbitrary length of RDATA