Browse Source

[master] Merge branch 'trac1299'

JINMEI Tatuya 13 years ago
parent
commit
6ff03bb9d6

+ 270 - 47
src/bin/xfrin/tests/xfrin_test.py

@@ -20,6 +20,7 @@ import socket
 import sys
 import sys
 import io
 import io
 from isc.testutils.tsigctx_mock import MockTSIGContext
 from isc.testutils.tsigctx_mock import MockTSIGContext
+from isc.testutils.rrset_utils import *
 from xfrin import *
 from xfrin import *
 import xfrin
 import xfrin
 from isc.xfrin.diff import Diff
 from isc.xfrin.diff import Diff
@@ -42,11 +43,9 @@ TEST_RRCLASS_STR = 'IN'
 TEST_DB_FILE = 'db_file'
 TEST_DB_FILE = 'db_file'
 TEST_MASTER_IPV4_ADDRESS = '127.0.0.1'
 TEST_MASTER_IPV4_ADDRESS = '127.0.0.1'
 TEST_MASTER_IPV4_ADDRINFO = (socket.AF_INET, socket.SOCK_STREAM,
 TEST_MASTER_IPV4_ADDRINFO = (socket.AF_INET, socket.SOCK_STREAM,
-                             socket.IPPROTO_TCP, '',
                              (TEST_MASTER_IPV4_ADDRESS, 53))
                              (TEST_MASTER_IPV4_ADDRESS, 53))
 TEST_MASTER_IPV6_ADDRESS = '::1'
 TEST_MASTER_IPV6_ADDRESS = '::1'
 TEST_MASTER_IPV6_ADDRINFO = (socket.AF_INET6, socket.SOCK_STREAM,
 TEST_MASTER_IPV6_ADDRINFO = (socket.AF_INET6, socket.SOCK_STREAM,
-                             socket.IPPROTO_TCP, '',
                              (TEST_MASTER_IPV6_ADDRESS, 53))
                              (TEST_MASTER_IPV6_ADDRESS, 53))
 
 
 TESTDATA_SRCDIR = os.getenv("TESTDATASRCDIR")
 TESTDATA_SRCDIR = os.getenv("TESTDATASRCDIR")
@@ -230,7 +229,7 @@ class MockXfrinConnection(XfrinConnection):
     def __init__(self, sock_map, zone_name, rrclass, datasrc_client,
     def __init__(self, sock_map, zone_name, rrclass, datasrc_client,
                  shutdown_event, master_addr, tsig_key=None):
                  shutdown_event, master_addr, tsig_key=None):
         super().__init__(sock_map, zone_name, rrclass, MockDataSourceClient(),
         super().__init__(sock_map, zone_name, rrclass, MockDataSourceClient(),
-                         shutdown_event, master_addr)
+                         shutdown_event, master_addr, TEST_DB_FILE)
         self.query_data = b''
         self.query_data = b''
         self.reply_data = b''
         self.reply_data = b''
         self.force_time_out = False
         self.force_time_out = False
@@ -280,10 +279,11 @@ class MockXfrinConnection(XfrinConnection):
                 self.response_generator()
                 self.response_generator()
         return len(data)
         return len(data)
 
 
-    def create_response_data(self, response=True, bad_qid=False,
+    def create_response_data(self, response=True, auth=True, bad_qid=False,
                              rcode=Rcode.NOERROR(),
                              rcode=Rcode.NOERROR(),
                              questions=default_questions,
                              questions=default_questions,
                              answers=default_answers,
                              answers=default_answers,
+                             authorities=[],
                              tsig_ctx=None):
                              tsig_ctx=None):
         resp = Message(Message.RENDER)
         resp = Message(Message.RENDER)
         qid = self.qid
         qid = self.qid
@@ -294,8 +294,11 @@ class MockXfrinConnection(XfrinConnection):
         resp.set_rcode(rcode)
         resp.set_rcode(rcode)
         if response:
         if response:
             resp.set_header_flag(Message.HEADERFLAG_QR)
             resp.set_header_flag(Message.HEADERFLAG_QR)
+        if auth:
+            resp.set_header_flag(Message.HEADERFLAG_AA)
         [resp.add_question(q) for q in questions]
         [resp.add_question(q) for q in questions]
         [resp.add_rrset(Message.SECTION_ANSWER, a) for a in answers]
         [resp.add_rrset(Message.SECTION_ANSWER, a) for a in answers]
+        [resp.add_rrset(Message.SECTION_AUTHORITY, a) for a in authorities]
 
 
         renderer = MessageRenderer()
         renderer = MessageRenderer()
         if tsig_ctx is not None:
         if tsig_ctx is not None:
@@ -348,13 +351,44 @@ class TestXfrinInitialSOA(TestXfrinState):
         self.assertTrue(self.state.handle_rr(self.conn, soa_rrset))
         self.assertTrue(self.state.handle_rr(self.conn, soa_rrset))
         self.assertEqual(type(XfrinFirstData()),
         self.assertEqual(type(XfrinFirstData()),
                          type(self.conn.get_xfrstate()))
                          type(self.conn.get_xfrstate()))
-        self.assertEqual(1234, self.conn._end_serial)
+        self.assertEqual(1234, self.conn._end_serial.get_value())
 
 
     def test_handle_not_soa(self):
     def test_handle_not_soa(self):
         # The given RR is not of SOA
         # The given RR is not of SOA
         self.assertRaises(XfrinProtocolError, self.state.handle_rr, self.conn,
         self.assertRaises(XfrinProtocolError, self.state.handle_rr, self.conn,
                           self.ns_rrset)
                           self.ns_rrset)
 
 
+    def test_handle_ixfr_uptodate(self):
+        self.conn._request_type = RRType.IXFR()
+        self.conn._request_serial = isc.dns.Serial(1234) # same as soa_rrset
+        self.assertTrue(self.state.handle_rr(self.conn, soa_rrset))
+        self.assertEqual(type(XfrinIXFRUptodate()),
+                         type(self.conn.get_xfrstate()))
+
+    def test_handle_ixfr_uptodate2(self):
+        self.conn._request_type = RRType.IXFR()
+        self.conn._request_serial = isc.dns.Serial(1235) # > soa_rrset
+        self.assertTrue(self.state.handle_rr(self.conn, soa_rrset))
+        self.assertEqual(type(XfrinIXFRUptodate()),
+                         type(self.conn.get_xfrstate()))
+
+    def test_handle_ixfr_uptodate3(self):
+        # Similar to the previous case, but checking serial number arithmetic
+        # comparison
+        self.conn._request_type = RRType.IXFR()
+        self.conn._request_serial = isc.dns.Serial(0xffffffff)
+        self.assertTrue(self.state.handle_rr(self.conn, soa_rrset))
+        self.assertEqual(type(XfrinFirstData()),
+                         type(self.conn.get_xfrstate()))
+
+    def test_handle_axfr_uptodate(self):
+        # "request serial" should matter only for IXFR
+        self.conn._request_type = RRType.AXFR()
+        self.conn._request_serial = isc.dns.Serial(1234) # same as soa_rrset
+        self.assertTrue(self.state.handle_rr(self.conn, soa_rrset))
+        self.assertEqual(type(XfrinFirstData()),
+                         type(self.conn.get_xfrstate()))
+
     def test_finish_message(self):
     def test_finish_message(self):
         self.assertTrue(self.state.finish_message(self.conn))
         self.assertTrue(self.state.finish_message(self.conn))
 
 
@@ -363,7 +397,8 @@ class TestXfrinFirstData(TestXfrinState):
         super().setUp()
         super().setUp()
         self.state = XfrinFirstData()
         self.state = XfrinFirstData()
         self.conn._request_type = RRType.IXFR()
         self.conn._request_type = RRType.IXFR()
-        self.conn._request_serial = 1230 # arbitrary chosen serial < 1234
+        # arbitrary chosen serial < 1234:
+        self.conn._request_serial = isc.dns.Serial(1230)
         self.conn._diff = None           # should be replaced in the AXFR case
         self.conn._diff = None           # should be replaced in the AXFR case
 
 
     def test_handle_ixfr_begin_soa(self):
     def test_handle_ixfr_begin_soa(self):
@@ -443,7 +478,7 @@ class TestXfrinIXFRDelete(TestXfrinState):
         # false.
         # false.
         self.assertFalse(self.state.handle_rr(self.conn, soa_rrset))
         self.assertFalse(self.state.handle_rr(self.conn, soa_rrset))
         self.assertEqual([], self.conn._diff.get_buffer())
         self.assertEqual([], self.conn._diff.get_buffer())
-        self.assertEqual(1234, self.conn._current_serial)
+        self.assertEqual(1234, self.conn._current_serial.get_value())
         self.assertEqual(type(XfrinIXFRAddSOA()),
         self.assertEqual(type(XfrinIXFRAddSOA()),
                          type(self.conn.get_xfrstate()))
                          type(self.conn.get_xfrstate()))
 
 
@@ -474,7 +509,7 @@ class TestXfrinIXFRAdd(TestXfrinState):
         # We need record the state in 'conn' to check the case where the
         # We need record the state in 'conn' to check the case where the
         # state doesn't change.
         # state doesn't change.
         XfrinIXFRAdd().set_xfrstate(self.conn, XfrinIXFRAdd())
         XfrinIXFRAdd().set_xfrstate(self.conn, XfrinIXFRAdd())
-        self.conn._current_serial = 1230
+        self.conn._current_serial = isc.dns.Serial(1230)
         self.state = self.conn.get_xfrstate()
         self.state = self.conn.get_xfrstate()
 
 
     def test_handle_add_rr(self):
     def test_handle_add_rr(self):
@@ -486,7 +521,7 @@ class TestXfrinIXFRAdd(TestXfrinState):
         self.assertEqual(type(XfrinIXFRAdd()), type(self.conn.get_xfrstate()))
         self.assertEqual(type(XfrinIXFRAdd()), type(self.conn.get_xfrstate()))
 
 
     def test_handle_end_soa(self):
     def test_handle_end_soa(self):
-        self.conn._end_serial = 1234
+        self.conn._end_serial = isc.dns.Serial(1234)
         self.conn._diff.add_data(self.ns_rrset) # put some dummy change
         self.conn._diff.add_data(self.ns_rrset) # put some dummy change
         self.assertTrue(self.state.handle_rr(self.conn, soa_rrset))
         self.assertTrue(self.state.handle_rr(self.conn, soa_rrset))
         self.assertEqual(type(XfrinIXFREnd()), type(self.conn.get_xfrstate()))
         self.assertEqual(type(XfrinIXFREnd()), type(self.conn.get_xfrstate()))
@@ -495,7 +530,7 @@ class TestXfrinIXFRAdd(TestXfrinState):
         self.assertEqual([], self.conn._diff.get_buffer())
         self.assertEqual([], self.conn._diff.get_buffer())
 
 
     def test_handle_new_delete(self):
     def test_handle_new_delete(self):
-        self.conn._end_serial = 1234
+        self.conn._end_serial = isc.dns.Serial(1234)
         # SOA RR whose serial is the current one means we are going to a new
         # SOA RR whose serial is the current one means we are going to a new
         # difference, starting with removing that SOA.
         # difference, starting with removing that SOA.
         self.conn._diff.add_data(self.ns_rrset) # put some dummy change
         self.conn._diff.add_data(self.ns_rrset) # put some dummy change
@@ -506,7 +541,7 @@ class TestXfrinIXFRAdd(TestXfrinState):
 
 
     def test_handle_out_of_sync(self):
     def test_handle_out_of_sync(self):
         # getting SOA with an inconsistent serial.  This is an error.
         # getting SOA with an inconsistent serial.  This is an error.
-        self.conn._end_serial = 1235
+        self.conn._end_serial = isc.dns.Serial(1235)
         self.assertRaises(XfrinProtocolError, self.state.handle_rr,
         self.assertRaises(XfrinProtocolError, self.state.handle_rr,
                           self.conn, soa_rrset)
                           self.conn, soa_rrset)
 
 
@@ -525,11 +560,24 @@ class TestXfrinIXFREnd(TestXfrinState):
     def test_finish_message(self):
     def test_finish_message(self):
         self.assertFalse(self.state.finish_message(self.conn))
         self.assertFalse(self.state.finish_message(self.conn))
 
 
+class TestXfrinIXFREnd(TestXfrinState):
+    def setUp(self):
+        super().setUp()
+        self.state = XfrinIXFRUptodate()
+
+    def test_handle_rr(self):
+        self.assertRaises(XfrinProtocolError, self.state.handle_rr, self.conn,
+                          self.ns_rrset)
+
+    def test_finish_message(self):
+        self.assertRaises(XfrinZoneUptodate, self.state.finish_message,
+                          self.conn)
+
 class TestXfrinAXFR(TestXfrinState):
 class TestXfrinAXFR(TestXfrinState):
     def setUp(self):
     def setUp(self):
         super().setUp()
         super().setUp()
         self.state = XfrinAXFR()
         self.state = XfrinAXFR()
-        self.conn._end_serial = 1234
+        self.conn._end_serial = isc.dns.Serial(1234)
 
 
     def test_handle_rr(self):
     def test_handle_rr(self):
         """
         """
@@ -604,7 +652,10 @@ class TestXfrinConnection(unittest.TestCase):
             'questions': [example_soa_question],
             'questions': [example_soa_question],
             'bad_qid': False,
             'bad_qid': False,
             'response': True,
             'response': True,
+            'auth': True,
             'rcode': Rcode.NOERROR(),
             'rcode': Rcode.NOERROR(),
+            'answers': default_answers,
+            'authorities': [],
             'tsig': False,
             'tsig': False,
             'axfr_after_soa': self._create_normal_response_data
             'axfr_after_soa': self._create_normal_response_data
             }
             }
@@ -661,8 +712,11 @@ class TestXfrinConnection(unittest.TestCase):
         self.conn.reply_data = self.conn.create_response_data(
         self.conn.reply_data = self.conn.create_response_data(
             bad_qid=self.soa_response_params['bad_qid'],
             bad_qid=self.soa_response_params['bad_qid'],
             response=self.soa_response_params['response'],
             response=self.soa_response_params['response'],
+            auth=self.soa_response_params['auth'],
             rcode=self.soa_response_params['rcode'],
             rcode=self.soa_response_params['rcode'],
             questions=self.soa_response_params['questions'],
             questions=self.soa_response_params['questions'],
+            answers=self.soa_response_params['answers'],
+            authorities=self.soa_response_params['authorities'],
             tsig_ctx=verify_ctx)
             tsig_ctx=verify_ctx)
         if self.soa_response_params['axfr_after_soa'] != None:
         if self.soa_response_params['axfr_after_soa'] != None:
             self.conn.response_generator = \
             self.conn.response_generator = \
@@ -693,6 +747,15 @@ class TestXfrinConnection(unittest.TestCase):
         rrset.add_rdata(Rdata(RRType.NS(), TEST_RRCLASS, nsname))
         rrset.add_rdata(Rdata(RRType.NS(), TEST_RRCLASS, nsname))
         return rrset
         return rrset
 
 
+    def _set_test_zone(self, zone_name):
+        '''Set the zone name for transfer to the specified one.
+
+        It also make sure that the SOA RR (if exist) is correctly (re)set.
+
+        '''
+        self.conn._zone_name = zone_name
+        self.conn._zone_soa = self.conn._get_zone_soa()
+
 class TestAXFR(TestXfrinConnection):
 class TestAXFR(TestXfrinConnection):
     def setUp(self):
     def setUp(self):
         super().setUp()
         super().setUp()
@@ -787,25 +850,26 @@ class TestAXFR(TestXfrinConnection):
         # IXFR query
         # IXFR query
         msg = self.conn._create_query(RRType.IXFR())
         msg = self.conn._create_query(RRType.IXFR())
         check_query(RRType.IXFR(), begin_soa_rrset)
         check_query(RRType.IXFR(), begin_soa_rrset)
-        self.assertEqual(1230, self.conn._request_serial)
+        self.assertEqual(1230, self.conn._request_serial.get_value())
 
 
     def test_create_ixfr_query_fail(self):
     def test_create_ixfr_query_fail(self):
         # In these cases _create_query() will fail to find a valid SOA RR to
         # In these cases _create_query() will fail to find a valid SOA RR to
         # insert in the IXFR query, and should raise an exception.
         # insert in the IXFR query, and should raise an exception.
 
 
-        self.conn._zone_name = Name('no-such-zone.example')
+        self._set_test_zone(Name('no-such-zone.example'))
         self.assertRaises(XfrinException, self.conn._create_query,
         self.assertRaises(XfrinException, self.conn._create_query,
                           RRType.IXFR())
                           RRType.IXFR())
 
 
-        self.conn._zone_name = Name('partial-match-zone.example')
+        self._set_test_zone(Name('partial-match-zone.example'))
         self.assertRaises(XfrinException, self.conn._create_query,
         self.assertRaises(XfrinException, self.conn._create_query,
                           RRType.IXFR())
                           RRType.IXFR())
 
 
-        self.conn._zone_name = Name('no-soa.example')
+        self._set_test_zone(Name('no-soa.example'))
         self.assertRaises(XfrinException, self.conn._create_query,
         self.assertRaises(XfrinException, self.conn._create_query,
                           RRType.IXFR())
                           RRType.IXFR())
 
 
-        self.conn._zone_name = Name('dup-soa.example')
+        self._set_test_zone(Name('dup-soa.example'))
+        self.conn._zone_soa = self.conn._get_zone_soa()
         self.assertRaises(XfrinException, self.conn._create_query,
         self.assertRaises(XfrinException, self.conn._create_query,
                           RRType.IXFR())
                           RRType.IXFR())
 
 
@@ -836,8 +900,10 @@ class TestAXFR(TestXfrinConnection):
         self.conn._tsig_key = TSIG_KEY
         self.conn._tsig_key = TSIG_KEY
         # server tsig check fail, return with RCODE 9 (NOTAUTH)
         # server tsig check fail, return with RCODE 9 (NOTAUTH)
         self.conn._send_query(RRType.SOA())
         self.conn._send_query(RRType.SOA())
-        self.conn.reply_data = self.conn.create_response_data(rcode=Rcode.NOTAUTH())
+        self.conn.reply_data = \
-        self.assertRaises(XfrinException, self.conn._handle_xfrin_responses)
+            self.conn.create_response_data(rcode=Rcode.NOTAUTH())
+        self.assertRaises(XfrinProtocolError,
+                          self.conn._handle_xfrin_responses)
 
 
     def test_response_without_end_soa(self):
     def test_response_without_end_soa(self):
         self.conn._send_query(RRType.AXFR())
         self.conn._send_query(RRType.AXFR())
@@ -850,7 +916,8 @@ class TestAXFR(TestXfrinConnection):
     def test_response_bad_qid(self):
     def test_response_bad_qid(self):
         self.conn._send_query(RRType.AXFR())
         self.conn._send_query(RRType.AXFR())
         self.conn.reply_data = self.conn.create_response_data(bad_qid=True)
         self.conn.reply_data = self.conn.create_response_data(bad_qid=True)
-        self.assertRaises(XfrinException, self.conn._handle_xfrin_responses)
+        self.assertRaises(XfrinProtocolError,
+                          self.conn._handle_xfrin_responses)
 
 
     def test_response_error_code_bad_sig(self):
     def test_response_error_code_bad_sig(self):
         self.conn._tsig_key = TSIG_KEY
         self.conn._tsig_key = TSIG_KEY
@@ -861,7 +928,7 @@ class TestAXFR(TestXfrinConnection):
                 rcode=Rcode.SERVFAIL())
                 rcode=Rcode.SERVFAIL())
         # xfrin should check TSIG before other part of incoming message
         # xfrin should check TSIG before other part of incoming message
         # validate log message for XfrinException
         # validate log message for XfrinException
-        self.__match_exception(XfrinException,
+        self.__match_exception(XfrinProtocolError,
                                "TSIG verify fail: BADSIG",
                                "TSIG verify fail: BADSIG",
                                self.conn._handle_xfrin_responses)
                                self.conn._handle_xfrin_responses)
 
 
@@ -873,7 +940,7 @@ class TestAXFR(TestXfrinConnection):
         self.conn.reply_data = self.conn.create_response_data(bad_qid=True)
         self.conn.reply_data = self.conn.create_response_data(bad_qid=True)
         # xfrin should check TSIG before other part of incoming message
         # xfrin should check TSIG before other part of incoming message
         # validate log message for XfrinException
         # validate log message for XfrinException
-        self.__match_exception(XfrinException,
+        self.__match_exception(XfrinProtocolError,
                                "TSIG verify fail: BADKEY",
                                "TSIG verify fail: BADKEY",
                                self.conn._handle_xfrin_responses)
                                self.conn._handle_xfrin_responses)
 
 
@@ -886,18 +953,21 @@ class TestAXFR(TestXfrinConnection):
         self.conn._send_query(RRType.AXFR())
         self.conn._send_query(RRType.AXFR())
         self.conn.reply_data = self.conn.create_response_data(
         self.conn.reply_data = self.conn.create_response_data(
             rcode=Rcode.SERVFAIL())
             rcode=Rcode.SERVFAIL())
-        self.assertRaises(XfrinException, self.conn._handle_xfrin_responses)
+        self.assertRaises(XfrinProtocolError,
+                          self.conn._handle_xfrin_responses)
 
 
     def test_response_multi_question(self):
     def test_response_multi_question(self):
         self.conn._send_query(RRType.AXFR())
         self.conn._send_query(RRType.AXFR())
         self.conn.reply_data = self.conn.create_response_data(
         self.conn.reply_data = self.conn.create_response_data(
             questions=[example_axfr_question, example_axfr_question])
             questions=[example_axfr_question, example_axfr_question])
-        self.assertRaises(XfrinException, self.conn._handle_xfrin_responses)
+        self.assertRaises(XfrinProtocolError,
+                          self.conn._handle_xfrin_responses)
 
 
     def test_response_non_response(self):
     def test_response_non_response(self):
         self.conn._send_query(RRType.AXFR())
         self.conn._send_query(RRType.AXFR())
         self.conn.reply_data = self.conn.create_response_data(response = False)
         self.conn.reply_data = self.conn.create_response_data(response = False)
-        self.assertRaises(XfrinException, self.conn._handle_xfrin_responses)
+        self.assertRaises(XfrinProtocolError,
+                          self.conn._handle_xfrin_responses)
 
 
     def test_soacheck(self):
     def test_soacheck(self):
         # we need to defer the creation until we know the QID, which is
         # we need to defer the creation until we know the QID, which is
@@ -912,7 +982,7 @@ class TestAXFR(TestXfrinConnection):
     def test_soacheck_badqid(self):
     def test_soacheck_badqid(self):
         self.soa_response_params['bad_qid'] = True
         self.soa_response_params['bad_qid'] = True
         self.conn.response_generator = self._create_soa_response_data
         self.conn.response_generator = self._create_soa_response_data
-        self.assertRaises(XfrinException, self.conn._check_soa_serial)
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
 
 
     def test_soacheck_bad_qid_bad_sig(self):
     def test_soacheck_bad_qid_bad_sig(self):
         self.conn._tsig_key = TSIG_KEY
         self.conn._tsig_key = TSIG_KEY
@@ -922,19 +992,123 @@ class TestAXFR(TestXfrinConnection):
         self.conn.response_generator = self._create_soa_response_data
         self.conn.response_generator = self._create_soa_response_data
         # xfrin should check TSIG before other part of incoming message
         # xfrin should check TSIG before other part of incoming message
         # validate log message for XfrinException
         # validate log message for XfrinException
-        self.__match_exception(XfrinException,
+        self.__match_exception(XfrinProtocolError,
                                "TSIG verify fail: BADSIG",
                                "TSIG verify fail: BADSIG",
                                self.conn._check_soa_serial)
                                self.conn._check_soa_serial)
 
 
     def test_soacheck_non_response(self):
     def test_soacheck_non_response(self):
         self.soa_response_params['response'] = False
         self.soa_response_params['response'] = False
         self.conn.response_generator = self._create_soa_response_data
         self.conn.response_generator = self._create_soa_response_data
-        self.assertRaises(XfrinException, self.conn._check_soa_serial)
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
 
 
     def test_soacheck_error_code(self):
     def test_soacheck_error_code(self):
         self.soa_response_params['rcode'] = Rcode.SERVFAIL()
         self.soa_response_params['rcode'] = Rcode.SERVFAIL()
         self.conn.response_generator = self._create_soa_response_data
         self.conn.response_generator = self._create_soa_response_data
-        self.assertRaises(XfrinException, self.conn._check_soa_serial)
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
+
+    def test_soacheck_notauth(self):
+        self.soa_response_params['auth'] = False
+        self.conn.response_generator = self._create_soa_response_data
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
+
+    def test_soacheck_uptodate(self):
+        # Primary's SOA serial is identical the local serial
+        self.soa_response_params['answers'] = [begin_soa_rrset]
+        self.conn.response_generator = self._create_soa_response_data
+        self.assertRaises(XfrinZoneUptodate, self.conn._check_soa_serial)
+
+    def test_soacheck_uptodate2(self):
+        # Primary's SOA serial is "smaller" than the local serial
+        self.soa_response_params['answers'] = [create_soa(1229)]
+        self.conn.response_generator = self._create_soa_response_data
+        self.assertRaises(XfrinZoneUptodate, self.conn._check_soa_serial)
+
+    def test_soacheck_uptodate3(self):
+        # Similar to the previous case, but checking the comparison is based
+        # on the serial number arithmetic.
+        self.soa_response_params['answers'] = [create_soa(0xffffffff)]
+        self.conn.response_generator = self._create_soa_response_data
+        self.assertRaises(XfrinZoneUptodate, self.conn._check_soa_serial)
+
+    def test_soacheck_newzone(self):
+        # Primary's SOA is 'old', but this secondary doesn't know anything
+        # about the zone yet, so it should accept it.
+        def response_generator():
+            # _request_serial is set in _check_soa_serial().  Reset it here.
+            self.conn._request_serial = None
+            self._create_soa_response_data()
+        self.soa_response_params['answers'] = [begin_soa_rrset]
+        self.conn.response_generator = response_generator
+        self.assertEqual(XFRIN_OK, self.conn._check_soa_serial())
+
+    def test_soacheck_question_empty(self):
+        self.conn.response_generator = self._create_soa_response_data
+        self.soa_response_params['questions'] = []
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
+
+    def test_soacheck_question_name_mismatch(self):
+        self.conn.response_generator = self._create_soa_response_data
+        self.soa_response_params['questions'] = [Question(Name('example.org'),
+                                                          TEST_RRCLASS,
+                                                          RRType.SOA())]
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
+
+    def test_soacheck_question_class_mismatch(self):
+        self.conn.response_generator = self._create_soa_response_data
+        self.soa_response_params['questions'] = [Question(TEST_ZONE_NAME,
+                                                          RRClass.CH(),
+                                                          RRType.SOA())]
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
+
+    def test_soacheck_question_type_mismatch(self):
+        self.conn.response_generator = self._create_soa_response_data
+        self.soa_response_params['questions'] = [Question(TEST_ZONE_NAME,
+                                                          TEST_RRCLASS,
+                                                          RRType.AAAA())]
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
+
+    def test_soacheck_no_soa(self):
+        # The response just doesn't contain SOA without any other indication
+        # of errors.
+        self.conn.response_generator = self._create_soa_response_data
+        self.soa_response_params['answers'] = []
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
+
+    def test_soacheck_soa_name_mismatch(self):
+        self.conn.response_generator = self._create_soa_response_data
+        self.soa_response_params['answers'] = [create_soa(1234,
+                                                          Name('example.org'))]
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
+
+    def test_soacheck_soa_class_mismatch(self):
+        self.conn.response_generator = self._create_soa_response_data
+        soa = RRset(TEST_ZONE_NAME, RRClass.CH(), RRType.SOA(), RRTTL(0))
+        soa.add_rdata(Rdata(RRType.SOA(), RRClass.CH(), 'm. r. 1234 0 0 0 0'))
+        self.soa_response_params['answers'] = [soa]
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
+
+    def test_soacheck_multiple_soa(self):
+        self.conn.response_generator = self._create_soa_response_data
+        self.soa_response_params['answers'] = [soa_rrset, soa_rrset]
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
+
+    def test_soacheck_cname_response(self):
+        self.conn.response_generator = self._create_soa_response_data
+        # Add SOA to answer, too, to make sure that it that deceives the parser
+        self.soa_response_params['answers'] = [soa_rrset, create_cname()]
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
+
+    def test_soacheck_referral_response(self):
+        self.conn.response_generator = self._create_soa_response_data
+        self.soa_response_params['answers'] = []
+        self.soa_response_params['authorities'] = [create_ns('ns.example.com')]
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
+
+    def test_soacheck_nodata_response(self):
+        self.conn.response_generator = self._create_soa_response_data
+        self.soa_response_params['answers'] = []
+        self.soa_response_params['authorities'] = [soa_rrset]
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
 
 
     def test_soacheck_with_tsig(self):
     def test_soacheck_with_tsig(self):
         # Use a mock tsig context emulating a validly signed response
         # Use a mock tsig context emulating a validly signed response
@@ -953,7 +1127,7 @@ class TestAXFR(TestXfrinConnection):
         self.soa_response_params['rcode'] = Rcode.NOTAUTH()
         self.soa_response_params['rcode'] = Rcode.NOTAUTH()
         self.conn.response_generator = self._create_soa_response_data
         self.conn.response_generator = self._create_soa_response_data
 
 
-        self.assertRaises(XfrinException, self.conn._check_soa_serial)
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
 
 
     def test_soacheck_with_tsig_noerror_badsig(self):
     def test_soacheck_with_tsig_noerror_badsig(self):
         self.conn._tsig_key = TSIG_KEY
         self.conn._tsig_key = TSIG_KEY
@@ -966,7 +1140,7 @@ class TestAXFR(TestXfrinConnection):
         # treat this as a final failure (just as BIND 9 does).
         # treat this as a final failure (just as BIND 9 does).
         self.conn.response_generator = self._create_soa_response_data
         self.conn.response_generator = self._create_soa_response_data
 
 
-        self.assertRaises(XfrinException, self.conn._check_soa_serial)
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
 
 
     def test_soacheck_with_tsig_unsigned_response(self):
     def test_soacheck_with_tsig_unsigned_response(self):
         # we can use a real TSIGContext for this.  the response doesn't
         # we can use a real TSIGContext for this.  the response doesn't
@@ -975,14 +1149,14 @@ class TestAXFR(TestXfrinConnection):
         # it as a fatal transaction failure, too.
         # it as a fatal transaction failure, too.
         self.conn._tsig_key = TSIG_KEY
         self.conn._tsig_key = TSIG_KEY
         self.conn.response_generator = self._create_soa_response_data
         self.conn.response_generator = self._create_soa_response_data
-        self.assertRaises(XfrinException, self.conn._check_soa_serial)
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
 
 
     def test_soacheck_with_unexpected_tsig_response(self):
     def test_soacheck_with_unexpected_tsig_response(self):
         # we reject unexpected TSIG in responses (following BIND 9's
         # we reject unexpected TSIG in responses (following BIND 9's
         # behavior)
         # behavior)
         self.soa_response_params['tsig'] = True
         self.soa_response_params['tsig'] = True
         self.conn.response_generator = self._create_soa_response_data
         self.conn.response_generator = self._create_soa_response_data
-        self.assertRaises(XfrinException, self.conn._check_soa_serial)
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
 
 
     def test_response_shutdown(self):
     def test_response_shutdown(self):
         self.conn.response_generator = self._create_normal_response_data
         self.conn.response_generator = self._create_normal_response_data
@@ -1244,6 +1418,18 @@ class TestAXFR(TestXfrinConnection):
         self.conn.response_generator = self._create_soa_response_data
         self.conn.response_generator = self._create_soa_response_data
         self.assertEqual(self.conn.do_xfrin(True), XFRIN_OK)
         self.assertEqual(self.conn.do_xfrin(True), XFRIN_OK)
 
 
+    def test_do_soacheck_uptodate(self):
+        self.soa_response_params['answers'] = [begin_soa_rrset]
+        self.conn.response_generator = self._create_soa_response_data
+        self.assertEqual(self.conn.do_xfrin(True), XFRIN_OK)
+
+    def test_do_soacheck_protocol_error(self):
+        # There are several cases, but at this level it's sufficient to check
+        # only one.  We use the case where there's no SOA in the response.
+        self.soa_response_params['answers'] = []
+        self.conn.response_generator = self._create_soa_response_data
+        self.assertEqual(self.conn.do_xfrin(True), XFRIN_FAIL)
+
     def test_do_soacheck_and_xfrin_with_tsig(self):
     def test_do_soacheck_and_xfrin_with_tsig(self):
         # We are going to have a SOA query/response transaction, followed by
         # We are going to have a SOA query/response transaction, followed by
         # AXFR, all TSIG signed.  xfrin should use a new TSIG context for
         # AXFR, all TSIG signed.  xfrin should use a new TSIG context for
@@ -1276,9 +1462,8 @@ class TestIXFRResponse(TestXfrinConnection):
     def setUp(self):
     def setUp(self):
         super().setUp()
         super().setUp()
         self.conn._query_id = self.conn.qid = 1035
         self.conn._query_id = self.conn.qid = 1035
-        self.conn._request_serial = 1230
+        self.conn._request_serial = isc.dns.Serial(1230)
         self.conn._request_type = RRType.IXFR()
         self.conn._request_type = RRType.IXFR()
-        self._zone_name = TEST_ZONE_NAME
         self.conn._datasrc_client = MockDataSourceClient()
         self.conn._datasrc_client = MockDataSourceClient()
         XfrinInitialSOA().set_xfrstate(self.conn, XfrinInitialSOA())
         XfrinInitialSOA().set_xfrstate(self.conn, XfrinInitialSOA())
 
 
@@ -1353,6 +1538,16 @@ class TestIXFRResponse(TestXfrinConnection):
                     [[('delete', begin_soa_rrset), ('add', soa_rrset)]],
                     [[('delete', begin_soa_rrset), ('add', soa_rrset)]],
                     self.conn._datasrc_client.committed_diffs)
                     self.conn._datasrc_client.committed_diffs)
 
 
+    def test_ixfr_response_uptodate(self):
+        '''IXFR response indicates the zone is new enough'''
+        self.conn.reply_data = self.conn.create_response_data(
+            questions=[Question(TEST_ZONE_NAME, TEST_RRCLASS, RRType.IXFR())],
+            answers=[begin_soa_rrset])
+        self.assertRaises(XfrinZoneUptodate, self.conn._handle_xfrin_responses)
+        # no diffs should have been committed
+        check_diffs(self.assertEqual,
+                    [], self.conn._datasrc_client.committed_diffs)
+
     def test_ixfr_response_broken(self):
     def test_ixfr_response_broken(self):
         '''Test with a broken response.
         '''Test with a broken response.
 
 
@@ -1385,6 +1580,22 @@ class TestIXFRResponse(TestXfrinConnection):
                     [[('delete', begin_soa_rrset), ('add', soa_rrset)]],
                     [[('delete', begin_soa_rrset), ('add', soa_rrset)]],
                     self.conn._datasrc_client.committed_diffs)
                     self.conn._datasrc_client.committed_diffs)
 
 
+    def test_ixfr_response_uptodate_extra(self):
+        '''Similar to 'uptodate' test, but with extra bogus data.
+
+        In either case an exception will be raised, but in this case it's
+        considered an error.
+
+        '''
+        self.conn.reply_data = self.conn.create_response_data(
+            questions=[Question(TEST_ZONE_NAME, TEST_RRCLASS, RRType.IXFR())],
+            answers=[begin_soa_rrset, soa_rrset])
+        self.assertRaises(XfrinProtocolError,
+                          self.conn._handle_xfrin_responses)
+        # no diffs should have been committed
+        check_diffs(self.assertEqual,
+                    [], self.conn._datasrc_client.committed_diffs)
+
     def test_ixfr_to_axfr_response(self):
     def test_ixfr_to_axfr_response(self):
         '''AXFR-style IXFR response.
         '''AXFR-style IXFR response.
 
 
@@ -1488,13 +1699,25 @@ class TestIXFRSession(TestXfrinConnection):
         self.conn.response_generator = create_ixfr_response
         self.conn.response_generator = create_ixfr_response
         self.assertEqual(XFRIN_FAIL, self.conn.do_xfrin(False, RRType.IXFR()))
         self.assertEqual(XFRIN_FAIL, self.conn.do_xfrin(False, RRType.IXFR()))
 
 
-    def test_do_xfrin_fail(self):
+    def test_do_xfrin_fail2(self):
         '''IXFR fails due to a bogus DNS message.
         '''IXFR fails due to a bogus DNS message.
 
 
         '''
         '''
         self._create_broken_response_data()
         self._create_broken_response_data()
         self.assertEqual(XFRIN_FAIL, self.conn.do_xfrin(False, RRType.IXFR()))
         self.assertEqual(XFRIN_FAIL, self.conn.do_xfrin(False, RRType.IXFR()))
 
 
+    def test_do_xfrin_uptodate(self):
+        '''IXFR is (gracefully) aborted because serial is not new
+
+        '''
+        def create_response():
+            self.conn.reply_data = self.conn.create_response_data(
+                questions=[Question(TEST_ZONE_NAME, TEST_RRCLASS,
+                                    RRType.IXFR())],
+                answers=[begin_soa_rrset])
+        self.conn.response_generator = create_response
+        self.assertEqual(XFRIN_OK, self.conn.do_xfrin(False, RRType.IXFR()))
+
 class TestXFRSessionWithSQLite3(TestXfrinConnection):
 class TestXFRSessionWithSQLite3(TestXfrinConnection):
     '''Tests for XFR sessions using an SQLite3 DB.
     '''Tests for XFR sessions using an SQLite3 DB.
 
 
@@ -1549,9 +1772,9 @@ class TestXFRSessionWithSQLite3(TestXfrinConnection):
         self.conn.response_generator = create_ixfr_response
         self.conn.response_generator = create_ixfr_response
 
 
         # Confirm xfrin succeeds and SOA is updated
         # Confirm xfrin succeeds and SOA is updated
-        self.assertEqual(1230, self.get_zone_serial())
+        self.assertEqual(1230, self.get_zone_serial().get_value())
         self.assertEqual(XFRIN_OK, self.conn.do_xfrin(False, RRType.IXFR()))
         self.assertEqual(XFRIN_OK, self.conn.do_xfrin(False, RRType.IXFR()))
-        self.assertEqual(1234, self.get_zone_serial())
+        self.assertEqual(1234, self.get_zone_serial().get_value())
 
 
         # Also confirm the corresponding diffs are stored in the diffs table
         # Also confirm the corresponding diffs are stored in the diffs table
         conn = sqlite3.connect(self.sqlite3db_obj)
         conn = sqlite3.connect(self.sqlite3db_obj)
@@ -1580,12 +1803,12 @@ class TestXFRSessionWithSQLite3(TestXfrinConnection):
                          self._create_soa('1235')])
                          self._create_soa('1235')])
         self.conn.response_generator = create_ixfr_response
         self.conn.response_generator = create_ixfr_response
 
 
-        self.assertEqual(1230, self.get_zone_serial())
+        self.assertEqual(1230, self.get_zone_serial().get_value())
         self.assertEqual(XFRIN_FAIL, self.conn.do_xfrin(False, RRType.IXFR()))
         self.assertEqual(XFRIN_FAIL, self.conn.do_xfrin(False, RRType.IXFR()))
-        self.assertEqual(1230, self.get_zone_serial())
+        self.assertEqual(1230, self.get_zone_serial().get_value())
 
 
     def test_do_ixfrin_nozone_sqlite3(self):
     def test_do_ixfrin_nozone_sqlite3(self):
-        self.conn._zone_name = Name('nosuchzone.example')
+        self._set_test_zone(Name('nosuchzone.example'))
         self.assertEqual(XFRIN_FAIL, self.conn.do_xfrin(False, RRType.IXFR()))
         self.assertEqual(XFRIN_FAIL, self.conn.do_xfrin(False, RRType.IXFR()))
         # This should fail even before starting state transition
         # This should fail even before starting state transition
         self.assertEqual(None, self.conn.get_xfrstate())
         self.assertEqual(None, self.conn.get_xfrstate())
@@ -1601,11 +1824,11 @@ class TestXFRSessionWithSQLite3(TestXfrinConnection):
         self.conn.response_generator = create_response
         self.conn.response_generator = create_response
 
 
         # Confirm xfrin succeeds and SOA is updated, A RR is deleted.
         # Confirm xfrin succeeds and SOA is updated, A RR is deleted.
-        self.assertEqual(1230, self.get_zone_serial())
+        self.assertEqual(1230, self.get_zone_serial().get_value())
         self.assertTrue(self.record_exist(Name('dns01.example.com'),
         self.assertTrue(self.record_exist(Name('dns01.example.com'),
                                           RRType.A()))
                                           RRType.A()))
         self.assertEqual(XFRIN_OK, self.conn.do_xfrin(False, type))
         self.assertEqual(XFRIN_OK, self.conn.do_xfrin(False, type))
-        self.assertEqual(1234, self.get_zone_serial())
+        self.assertEqual(1234, self.get_zone_serial().get_value())
         self.assertFalse(self.record_exist(Name('dns01.example.com'),
         self.assertFalse(self.record_exist(Name('dns01.example.com'),
                                            RRType.A()))
                                            RRType.A()))
 
 
@@ -1633,11 +1856,11 @@ class TestXFRSessionWithSQLite3(TestXfrinConnection):
                 answers=[soa_rrset, self._create_ns(), soa_rrset, soa_rrset])
                 answers=[soa_rrset, self._create_ns(), soa_rrset, soa_rrset])
         self.conn.response_generator = create_response
         self.conn.response_generator = create_response
 
 
-        self.assertEqual(1230, self.get_zone_serial())
+        self.assertEqual(1230, self.get_zone_serial().get_value())
         self.assertTrue(self.record_exist(Name('dns01.example.com'),
         self.assertTrue(self.record_exist(Name('dns01.example.com'),
                                           RRType.A()))
                                           RRType.A()))
         self.assertEqual(XFRIN_FAIL, self.conn.do_xfrin(False, type))
         self.assertEqual(XFRIN_FAIL, self.conn.do_xfrin(False, type))
-        self.assertEqual(1230, self.get_zone_serial())
+        self.assertEqual(1230, self.get_zone_serial().get_value())
         self.assertTrue(self.record_exist(Name('dns01.example.com'),
         self.assertTrue(self.record_exist(Name('dns01.example.com'),
                                           RRType.A()))
                                           RRType.A()))
 
 
@@ -1671,11 +1894,11 @@ class TestXFRSessionWithSQLite3(TestXfrinConnection):
                                     RRType.AXFR())],
                                     RRType.AXFR())],
                 answers=[soa_rrset, self._create_ns(), soa_rrset])
                 answers=[soa_rrset, self._create_ns(), soa_rrset])
         self.conn.response_generator = create_response
         self.conn.response_generator = create_response
-        self.conn._zone_name = Name('example.com')
+        self._set_test_zone(Name('example.com'))
         self.assertEqual(XFRIN_OK, self.conn.do_xfrin(False, RRType.AXFR()))
         self.assertEqual(XFRIN_OK, self.conn.do_xfrin(False, RRType.AXFR()))
         self.assertEqual(type(XfrinAXFREnd()),
         self.assertEqual(type(XfrinAXFREnd()),
                          type(self.conn.get_xfrstate()))
                          type(self.conn.get_xfrstate()))
-        self.assertEqual(1234, self.get_zone_serial())
+        self.assertEqual(1234, self.get_zone_serial().get_value())
         self.assertFalse(self.record_exist(Name('dns01.example.com'),
         self.assertFalse(self.record_exist(Name('dns01.example.com'),
                                            RRType.A()))
                                            RRType.A()))
 
 

+ 228 - 109
src/bin/xfrin/xfrin.py.in

@@ -24,6 +24,7 @@ import struct
 import threading
 import threading
 import socket
 import socket
 import random
 import random
+from functools import reduce
 from optparse import OptionParser, OptionValueError
 from optparse import OptionParser, OptionValueError
 from isc.config.ccsession import *
 from isc.config.ccsession import *
 from isc.notify import notify_out
 from isc.notify import notify_out
@@ -75,9 +76,10 @@ DEFAULT_MASTER_PORT = 53
 DEFAULT_ZONE_CLASS = RRClass.IN()
 DEFAULT_ZONE_CLASS = RRClass.IN()
 
 
 __version__ = 'BIND10'
 __version__ = 'BIND10'
-# define xfrin rcode
+
-XFRIN_OK = 0
+# Internal result codes of an xfr session
-XFRIN_FAIL = 1
+XFRIN_OK = 0                    # normal success
+XFRIN_FAIL = 1                  # general failure (internal/external)
 
 
 class XfrinException(Exception):
 class XfrinException(Exception):
     pass
     pass
@@ -87,6 +89,11 @@ class XfrinProtocolError(Exception):
     '''
     '''
     pass
     pass
 
 
+class XfrinZoneUptodate(Exception):
+    '''TBD
+    '''
+    pass
+
 class XfrinZoneInfoException(Exception):
 class XfrinZoneInfoException(Exception):
     """This exception is raised if there is an error in the given
     """This exception is raised if there is an error in the given
        configuration (part), or when a command does not have a required
        configuration (part), or when a command does not have a required
@@ -153,7 +160,7 @@ def format_addrinfo(addrinfo):
                         "appear to be consisting of (family, socktype, (addr, port))")
                         "appear to be consisting of (family, socktype, (addr, port))")
 
 
 def get_soa_serial(soa_rdata):
 def get_soa_serial(soa_rdata):
-    '''Extract the serial field of an SOA RDATA and returns it as an intger.
+    '''Extract the serial field of SOA RDATA and return it as a Serial object.
 
 
     We don't have to be very efficient here, so we first dump the entire RDATA
     We don't have to be very efficient here, so we first dump the entire RDATA
     as a string and convert the first corresponding field.  This should be
     as a string and convert the first corresponding field.  This should be
@@ -162,7 +169,7 @@ def get_soa_serial(soa_rdata):
     should be a more direct and convenient way to get access to the SOA
     should be a more direct and convenient way to get access to the SOA
     fields.
     fields.
     '''
     '''
-    return int(soa_rdata.to_text().split()[2])
+    return Serial(int(soa_rdata.to_text().split()[2]))
 
 
 class XfrinState:
 class XfrinState:
     '''
     '''
@@ -181,12 +188,12 @@ class XfrinState:
                              (AXFR or
                              (AXFR or
             (recv SOA)        AXFR-style IXFR)  (SOA, add)
             (recv SOA)        AXFR-style IXFR)  (SOA, add)
     InitialSOA------->FirstData------------->AXFR--------->AXFREnd
     InitialSOA------->FirstData------------->AXFR--------->AXFREnd
-                          |                  |  ^         (post xfr
+         |                |                  |  ^         (post xfr
-                          |                  |  |        checks, then
+         |(IXFR &&        |                  |  |        checks, then
-                          |                  +--+        commit)
+         | recv SOA       |                  +--+        commit)
-                          |            (non SOA, add)
+         | not new)       |            (non SOA, add)
-                          |
+         V                |
-                          |                     (non SOA, delete)
+    IXFRUptodate          |                     (non SOA, delete)
                (pure IXFR,|                           +-------+
                (pure IXFR,|                           +-------+
             keep handling)|             (Delete SOA)  V       |
             keep handling)|             (Delete SOA)  V       |
                           + ->IXFRDeleteSOA------>IXFRDelete--+
                           + ->IXFRDeleteSOA------>IXFRDelete--+
@@ -300,13 +307,14 @@ class XfrinInitialSOA(XfrinState):
                                      + rr.get_type().to_text() + ' received)')
                                      + rr.get_type().to_text() + ' received)')
         conn._end_serial = get_soa_serial(rr.get_rdata()[0])
         conn._end_serial = get_soa_serial(rr.get_rdata()[0])
 
 
-        # FIXME: we need to check the serial is actually greater than ours.
+        if conn._request_type == RRType.IXFR() and \
-        # To do so, however, we need to implement serial number arithmetic.
+                conn._end_serial <= conn._request_serial:
-        # Although it wouldn't be a big task, we'll leave it for a separate
+            logger.info(XFRIN_IXFR_UPTODATE, conn.zone_str(),
-        # task for now.  (Always performing xfr could be inefficient, but
+                        conn._request_serial, conn._end_serial)
-        # shouldn't do any harm otherwise)
+            self.set_xfrstate(conn, XfrinIXFRUptodate())
+        else:
+            self.set_xfrstate(conn, XfrinFirstData())
 
 
-        self.set_xfrstate(conn, XfrinFirstData())
         return True
         return True
 
 
 class XfrinFirstData(XfrinState):
 class XfrinFirstData(XfrinState):
@@ -430,6 +438,14 @@ class XfrinIXFREnd(XfrinState):
         '''
         '''
         return False
         return False
 
 
+class XfrinIXFRUptodate(XfrinState):
+    def handle_rr(self, conn, rr):
+        raise XfrinProtocolError('Extra data after single IXFR response ' +
+                                 rr.to_text())
+
+    def finish_message(self, conn):
+        raise XfrinZoneUptodate
+
 class XfrinAXFR(XfrinState):
 class XfrinAXFR(XfrinState):
     def handle_rr(self, conn, rr):
     def handle_rr(self, conn, rr):
         """
         """
@@ -473,10 +489,13 @@ class XfrinConnection(asyncore.dispatcher):
 
 
     def __init__(self,
     def __init__(self,
                  sock_map, zone_name, rrclass, datasrc_client,
                  sock_map, zone_name, rrclass, datasrc_client,
-                 shutdown_event, master_addrinfo, tsig_key=None,
+                 shutdown_event, master_addrinfo, db_file, tsig_key=None,
                  idle_timeout=60):
                  idle_timeout=60):
         '''Constructor of the XfirnConnection class.
         '''Constructor of the XfirnConnection class.
 
 
+        db_file: SQLite3 DB file.  Unforutnately we still need this for
+                 temporary workaround in _get_zone_soa().  This should be
+                 removed when we eliminate the need for the workaround.
         idle_timeout: max idle time for read data from socket.
         idle_timeout: max idle time for read data from socket.
         datasrc_client: the data source client object used for the XFR session.
         datasrc_client: the data source client object used for the XFR session.
                         This will eventually replace db_file completely.
                         This will eventually replace db_file completely.
@@ -500,7 +519,9 @@ class XfrinConnection(asyncore.dispatcher):
         self._rrclass = rrclass
         self._rrclass = rrclass
 
 
         # Data source handler
         # Data source handler
+        self._db_file = db_file
         self._datasrc_client = datasrc_client
         self._datasrc_client = datasrc_client
+        self._zone_soa = self._get_zone_soa()
 
 
         self._sock_map = sock_map
         self._sock_map = sock_map
         self._soa_rr_count = 0
         self._soa_rr_count = 0
@@ -524,6 +545,55 @@ class XfrinConnection(asyncore.dispatcher):
         self.create_socket(self._master_addrinfo[0], self._master_addrinfo[1])
         self.create_socket(self._master_addrinfo[0], self._master_addrinfo[1])
         self.setblocking(1)
         self.setblocking(1)
 
 
+    def _get_zone_soa(self):
+        '''Retrieve the current SOA RR of the zone to be transferred.
+
+        It will be used for various purposes in subsequent xfr protocol
+        processing.   It is validly possible that the zone is currently
+        empty and therefore doesn't have an SOA, so this method doesn't
+        consider it an error and returns None in such a case.  It may or
+        may not result in failure in the actual processing depending on
+        how the SOA is used.
+
+        When the zone has an SOA RR, this method makes sure that it's
+        valid, i.e., it has exactly one RDATA; if it is not the case
+        this method returns None.
+
+        If the underlying data source doesn't even know the zone, this method
+        tries to provide backward compatible behavior where xfrin is
+        responsible for creating zone in the corresponding DB table.
+        For a longer term we should deprecate this behavior by introducing
+        more generic zone management framework, but at the moment we try
+        to not surprise existing users.  (Note also that the part of
+        providing the compatible behavior uses the old data source API.
+        We'll deprecate this API in a near future, too).
+
+        '''
+        # get the zone finder.  this must be SUCCESS (not even
+        # PARTIALMATCH) because we are specifying the zone origin name.
+        result, finder = self._datasrc_client.find_zone(self._zone_name)
+        if result != DataSourceClient.SUCCESS:
+            # The data source doesn't know the zone.  For now, we provide
+            # backward compatibility and creates a new one ourselves.
+            isc.datasrc.sqlite3_ds.load(self._db_file,
+                                        self._zone_name.to_text(),
+                                        lambda : [])
+            logger.warn(XFRIN_ZONE_CREATED, self.zone_str())
+            # try again
+            result, finder = self._datasrc_client.find_zone(self._zone_name)
+        if result != DataSourceClient.SUCCESS:
+            return None
+        result, soa_rrset = finder.find(self._zone_name, RRType.SOA(),
+                                        None, ZoneFinder.FIND_DEFAULT)
+        if result != ZoneFinder.SUCCESS:
+            logger.info(XFRIN_ZONE_NO_SOA, self.zone_str())
+            return None
+        if soa_rrset.get_rdata_count() != 1:
+            logger.warn(XFRIN_ZONE_MULTIPLE_SOA, self.zone_str(),
+                        soa_rrset.get_rdata_count())
+            return None
+        return soa_rrset
+
     def __set_xfrstate(self, new_state):
     def __set_xfrstate(self, new_state):
         self.__state = new_state
         self.__state = new_state
 
 
@@ -545,37 +615,16 @@ class XfrinConnection(asyncore.dispatcher):
                          str(e))
                          str(e))
             return False
             return False
 
 
-    def _get_zone_soa(self):
-        result, finder = self._datasrc_client.find_zone(self._zone_name)
-        if result != DataSourceClient.SUCCESS:
-            raise XfrinException('Zone not found in the given data ' +
-                                 'source: ' + self.zone_str())
-        result, soa_rrset = finder.find(self._zone_name, RRType.SOA(),
-                                        None, ZoneFinder.FIND_DEFAULT)
-        if result != ZoneFinder.SUCCESS:
-            raise XfrinException('SOA RR not found in zone: ' +
-                                 self.zone_str())
-        # Especially for database-based zones, a working zone may be in
-        # a broken state where it has more than one SOA RR.  We proactively
-        # check the condition and abort the xfr attempt if we identify it.
-        if soa_rrset.get_rdata_count() != 1:
-            raise XfrinException('Invalid number of SOA RRs for ' +
-                                 self.zone_str() + ': ' +
-                                 str(soa_rrset.get_rdata_count()))
-        return soa_rrset
-
     def _create_query(self, query_type):
     def _create_query(self, query_type):
         '''Create an XFR-related query message.
         '''Create an XFR-related query message.
 
 
-        query_type is either SOA, AXFR or IXFR.  For type IXFR, it searches
+        query_type is either SOA, AXFR or IXFR.  An IXFR query needs the
-        the associated data source for the current SOA record to include
+        zone's current SOA record.  If it's not known, it raises an
-        it in the query.  If the corresponding zone or the SOA record
+        XfrinException exception.  Note that this may not necessarily a
-        cannot be found, it raises an XfrinException exception.  Note that
+        broken configuration; for the first attempt of transfer the secondary
-        this may not necessarily a broken configuration; for the first attempt
+        may not have any boot-strap zone information, in which case IXFR
-        of transfer the secondary may not have any boot-strap zone
+        simply won't work.  The xfrin should then fall back to AXFR.
-        information, in which case IXFR simply won't work.  The xfrin
+        _request_serial is recorded for later use.
-        should then fall back to AXFR.  _request_serial is recorded for
-        later use.
 
 
         '''
         '''
         msg = Message(Message.RENDER)
         msg = Message(Message.RENDER)
@@ -585,27 +634,19 @@ class XfrinConnection(asyncore.dispatcher):
         msg.set_opcode(Opcode.QUERY())
         msg.set_opcode(Opcode.QUERY())
         msg.set_rcode(Rcode.NOERROR())
         msg.set_rcode(Rcode.NOERROR())
         msg.add_question(Question(self._zone_name, self._rrclass, query_type))
         msg.add_question(Question(self._zone_name, self._rrclass, query_type))
+
+        # Remember our serial, if known
+        self._request_serial = get_soa_serial(self._zone_soa.get_rdata()[0]) \
+            if self._zone_soa is not None else None
+
+        # Set the authority section with our SOA for IXFR
         if query_type == RRType.IXFR():
         if query_type == RRType.IXFR():
-            # get the zone finder.  this must be SUCCESS (not even
+            if self._zone_soa is None:
-            # PARTIALMATCH) because we are specifying the zone origin name.
+                # (incremental) IXFR doesn't work without known SOA
-            zone_soa_rr = self._get_zone_soa()
+                raise XfrinException('Failed to create IXFR query due to no ' +
-            msg.add_rrset(Message.SECTION_AUTHORITY, zone_soa_rr)
+                                     'SOA for ' + self.zone_str())
-            self._request_serial = get_soa_serial(zone_soa_rr.get_rdata()[0])
+            msg.add_rrset(Message.SECTION_AUTHORITY, self._zone_soa)
-        else:
+
-            # For AXFR, we temporarily provide backward compatible behavior
-            # where xfrin is responsible for creating zone in the corresponding
-            # DB table.  Note that the code below uses the old data source
-            # API and assumes SQLite3 in an ugly manner.  We'll have to
-            # develop a better way of managing zones in a generic way and
-            # eliminate the code like the one here.
-            try:
-                self._get_zone_soa()
-            except XfrinException:
-                def empty_rr_generator():
-                    return []
-                isc.datasrc.sqlite3_ds.load(self._db_file,
-                                            self._zone_name.to_text(),
-                                            empty_rr_generator)
         return msg
         return msg
 
 
     def _send_data(self, data):
     def _send_data(self, data):
@@ -659,7 +700,8 @@ class XfrinConnection(asyncore.dispatcher):
         if self._tsig_ctx is not None:
         if self._tsig_ctx is not None:
             tsig_error = self._tsig_ctx.verify(tsig_record, response_data)
             tsig_error = self._tsig_ctx.verify(tsig_record, response_data)
             if tsig_error != TSIGError.NOERROR:
             if tsig_error != TSIGError.NOERROR:
-                raise XfrinException('TSIG verify fail: %s' % str(tsig_error))
+                raise XfrinProtocolError('TSIG verify fail: %s' %
+                                         str(tsig_error))
         elif tsig_record is not None:
         elif tsig_record is not None:
             # If the response includes a TSIG while we didn't sign the query,
             # If the response includes a TSIG while we didn't sign the query,
             # we treat it as an error.  RFC doesn't say anything about this
             # we treat it as an error.  RFC doesn't say anything about this
@@ -668,13 +710,78 @@ class XfrinConnection(asyncore.dispatcher):
             # implementation would return such a response, and since this is
             # implementation would return such a response, and since this is
             # part of security mechanism, it's probably better to be more
             # part of security mechanism, it's probably better to be more
             # strict.
             # strict.
-            raise XfrinException('Unexpected TSIG in response')
+            raise XfrinProtocolError('Unexpected TSIG in response')
+
+    def __parse_soa_response(self, msg, response_data):
+        '''Parse a response to SOA query and extract the SOA from answer.
+
+        This is a subroutine of _check_soa_serial().  This method also
+        validates message, and rejects bogus responses with XfrinProtocolError.
+
+        If everything is okay, it returns the SOA RR from the answer section
+        of the response.
+
+        '''
+        # Check TSIG integrity and validate the header.  Unlike AXFR/IXFR,
+        # we should be more strict for SOA queries and check the AA flag, too.
+        self._check_response_tsig(msg, response_data)
+        self._check_response_header(msg)
+        if not msg.get_header_flag(Message.HEADERFLAG_AA):
+            raise XfrinProtocolError('non-authoritative answer to SOA query')
+
+        # Validate the question section
+        n_question = msg.get_rr_count(Message.SECTION_QUESTION)
+        if n_question != 1:
+            raise XfrinProtocolError('Invalid response to SOA query: ' +
+                                     '(' + str(n_question) + ' questions, 1 ' +
+                                     'expected)')
+        resp_question = msg.get_question()[0]
+        if resp_question.get_name() != self._zone_name or \
+                resp_question.get_class() != self._rrclass or \
+                resp_question.get_type() != RRType.SOA():
+            raise XfrinProtocolError('Invalid response to SOA query: '
+                                     'question mismatch: ' +
+                                     str(resp_question))
+
+        # Look into the answer section for SOA
+        soa = None
+        for rr in msg.get_section(Message.SECTION_ANSWER):
+            if rr.get_type() == RRType.SOA():
+                if soa is not None:
+                    raise XfrinProtocolError('SOA response had multiple SOAs')
+                soa = rr
+            # There should not be a CNAME record at top of zone.
+            if rr.get_type() == RRType.CNAME():
+                raise XfrinProtocolError('SOA query resulted in CNAME')
+
+        # If SOA is not found, try to figure out the reason then report it.
+        if soa is None:
+            # See if we have any SOA records in the authority section.
+            for rr in msg.get_section(Message.SECTION_AUTHORITY):
+                if rr.get_type() == RRType.NS():
+                    raise XfrinProtocolError('SOA query resulted in referral')
+                if rr.get_type() == RRType.SOA():
+                    raise XfrinProtocolError('SOA query resulted in NODATA')
+            raise XfrinProtocolError('No SOA record found in response to ' +
+                                     'SOA query')
+
+        # Check if the SOA is really what we asked for
+        if soa.get_name() != self._zone_name or \
+                soa.get_class() != self._rrclass:
+            raise XfrinProtocolError("SOA response doesn't match query: " +
+                                     str(soa))
+
+        # All okay, return it
+        return soa
+
 
 
     def _check_soa_serial(self):
     def _check_soa_serial(self):
-        ''' Compare the soa serial, if soa serial in master is less than
+        '''Send SOA query and compare the local and remote serials.
-        the soa serial in local, Finish xfrin.
+
-        False: soa serial in master is less or equal to the local one.
+        If we know our local serial and the remote serial isn't newer
-        True: soa serial in master is bigger
+        than ours, we abort the session with XfrinZoneUptodate.
+        On success it returns XFRIN_OK for testing.  The caller won't use it.
+
         '''
         '''
 
 
         self._send_query(RRType.SOA())
         self._send_query(RRType.SOA())
@@ -682,18 +789,23 @@ class XfrinConnection(asyncore.dispatcher):
         msg_len = socket.htons(struct.unpack('H', data_len)[0])
         msg_len = socket.htons(struct.unpack('H', data_len)[0])
         soa_response = self._get_request_response(msg_len)
         soa_response = self._get_request_response(msg_len)
         msg = Message(Message.PARSE)
         msg = Message(Message.PARSE)
-        msg.from_wire(soa_response)
+        msg.from_wire(soa_response, Message.PRESERVE_ORDER)
+
+        # Validate/parse the rest of the response, and extract the SOA
+        # from the answer section
+        soa = self.__parse_soa_response(msg, soa_response)
+
+        # Compare the two serials.  If ours is 'new', abort with ZoneUptodate.
+        primary_serial = get_soa_serial(soa.get_rdata()[0])
+        if self._request_serial is not None and \
+                self._request_serial >= primary_serial:
+            if self._request_serial != primary_serial:
+                logger.info(XFRIN_ZONE_SERIAL_AHEAD, primary_serial,
+                            self.zone_str(),
+                            format_addrinfo(self._master_addrinfo),
+                            self._request_serial)
+            raise XfrinZoneUptodate
 
 
-        # TSIG related checks, including an unexpected signed response
-        self._check_response_tsig(msg, soa_response)
-
-        # perform some minimal level validation.  It's an open issue how
-        # strict we should be (see the comment in _check_response_header())
-        self._check_response_header(msg)
-
-        # TODO, need select soa record from data source then compare the two
-        # serial, current just return OK, since this function hasn't been used
-        # now.
         return XFRIN_OK
         return XFRIN_OK
 
 
     def do_xfrin(self, check_soa, request_type=RRType.AXFR()):
     def do_xfrin(self, check_soa, request_type=RRType.AXFR()):
@@ -704,22 +816,30 @@ class XfrinConnection(asyncore.dispatcher):
             self._request_type = request_type
             self._request_type = request_type
             # Right now RRType.[IA]XFR().to_text() is 'TYPExxx', so we need
             # Right now RRType.[IA]XFR().to_text() is 'TYPExxx', so we need
             # to hardcode here.
             # to hardcode here.
-            request_str = 'IXFR' if request_type == RRType.IXFR() else 'AXFR'
+            req_str = 'IXFR' if request_type == RRType.IXFR() else 'AXFR'
             if check_soa:
             if check_soa:
-                ret =  self._check_soa_serial()
+                self._check_soa_serial()
-
+
-            if ret == XFRIN_OK:
+            logger.info(XFRIN_XFR_TRANSFER_STARTED, req_str, self.zone_str())
-                logger.info(XFRIN_XFR_TRANSFER_STARTED, request_str,
+            self._send_query(self._request_type)
-                            self.zone_str())
+            self.__state = XfrinInitialSOA()
-                self._send_query(self._request_type)
+            self._handle_xfrin_responses()
-                self.__state = XfrinInitialSOA()
+            logger.info(XFRIN_XFR_TRANSFER_SUCCESS, req_str, self.zone_str())
-                self._handle_xfrin_responses()
+
-                logger.info(XFRIN_XFR_TRANSFER_SUCCESS, request_str,
+        except XfrinZoneUptodate:
-                            self.zone_str())
+            # Eventually we'll probably have to treat this case as a trigger
-
+            # of trying another primary server, etc, but for now we treat it
-        except (XfrinException, XfrinProtocolError) as e:
+            # as "success".
-            logger.error(XFRIN_XFR_TRANSFER_FAILURE, request_str,
+            pass
-                         self.zone_str(), str(e))
+        except XfrinProtocolError as e:
+            logger.info(XFRIN_XFR_TRANSFER_PROTOCOL_ERROR, req_str,
+                        self.zone_str(),
+                        format_addrinfo(self._master_addrinfo), str(e))
+            ret = XFRIN_FAIL
+        except XfrinException as e:
+            logger.error(XFRIN_XFR_TRANSFER_FAILURE, req_str,
+                         self.zone_str(),
+                         format_addrinfo(self._master_addrinfo), str(e))
             ret = XFRIN_FAIL
             ret = XFRIN_FAIL
         except Exception as e:
         except Exception as e:
             # Catching all possible exceptions like this is generally not a
             # Catching all possible exceptions like this is generally not a
@@ -730,7 +850,7 @@ class XfrinConnection(asyncore.dispatcher):
             # catch it here, but until then we need broadest coverage so that
             # catch it here, but until then we need broadest coverage so that
             # we won't miss anything.
             # we won't miss anything.
 
 
-            logger.error(XFRIN_XFR_OTHER_FAILURE, request_str,
+            logger.error(XFRIN_XFR_OTHER_FAILURE, req_str,
                          self.zone_str(), str(e))
                          self.zone_str(), str(e))
             ret = XFRIN_FAIL
             ret = XFRIN_FAIL
         finally:
         finally:
@@ -754,13 +874,14 @@ class XfrinConnection(asyncore.dispatcher):
 
 
         msg_rcode = msg.get_rcode()
         msg_rcode = msg.get_rcode()
         if msg_rcode != Rcode.NOERROR():
         if msg_rcode != Rcode.NOERROR():
-            raise XfrinException('error response: %s' % msg_rcode.to_text())
+            raise XfrinProtocolError('error response: %s' %
+                                     msg_rcode.to_text())
 
 
         if not msg.get_header_flag(Message.HEADERFLAG_QR):
         if not msg.get_header_flag(Message.HEADERFLAG_QR):
-            raise XfrinException('response is not a response')
+            raise XfrinProtocolError('response is not a response')
 
 
         if msg.get_qid() != self._query_id:
         if msg.get_qid() != self._query_id:
-            raise XfrinException('bad query id')
+            raise XfrinProtocolError('bad query id')
 
 
     def _check_response_status(self, msg):
     def _check_response_status(self, msg):
         '''Check validation of xfr response. '''
         '''Check validation of xfr response. '''
@@ -768,7 +889,7 @@ class XfrinConnection(asyncore.dispatcher):
         self._check_response_header(msg)
         self._check_response_header(msg)
 
 
         if msg.get_rr_count(Message.SECTION_QUESTION) > 1:
         if msg.get_rr_count(Message.SECTION_QUESTION) > 1:
-            raise XfrinException('query section count greater than 1')
+            raise XfrinProtocolError('query section count greater than 1')
 
 
     def _handle_xfrin_responses(self):
     def _handle_xfrin_responses(self):
         read_next_msg = True
         read_next_msg = True
@@ -808,8 +929,8 @@ class XfrinConnection(asyncore.dispatcher):
         return False
         return False
 
 
 def __process_xfrin(server, zone_name, rrclass, db_file,
 def __process_xfrin(server, zone_name, rrclass, db_file,
-                  shutdown_event, master_addrinfo, check_soa, tsig_key,
+                    shutdown_event, master_addrinfo, check_soa, tsig_key,
-                  request_type, conn_class):
+                    request_type, conn_class):
     conn = None
     conn = None
     exception = None
     exception = None
     ret = XFRIN_FAIL
     ret = XFRIN_FAIL
@@ -840,11 +961,9 @@ def __process_xfrin(server, zone_name, rrclass, db_file,
         while retry:
         while retry:
             retry = False
             retry = False
             conn = conn_class(sock_map, zone_name, rrclass, datasrc_client,
             conn = conn_class(sock_map, zone_name, rrclass, datasrc_client,
-                              shutdown_event, master_addrinfo, tsig_key)
+                              shutdown_event, master_addrinfo, db_file,
+                              tsig_key)
             conn.init_socket()
             conn.init_socket()
-            # XXX: We still need _db_file for temporary workaround in _create_query().
-            # This should be removed when we eliminate the need for the workaround.
-            conn._db_file = db_file
             ret = XFRIN_FAIL
             ret = XFRIN_FAIL
             if conn.connect_to_master():
             if conn.connect_to_master():
                 ret = conn.do_xfrin(check_soa, request_type)
                 ret = conn.do_xfrin(check_soa, request_type)

+ 62 - 7
src/bin/xfrin/xfrin_messages.mes

@@ -15,18 +15,63 @@
 # No namespace declaration - these constants go in the global namespace
 # No namespace declaration - these constants go in the global namespace
 # of the xfrin messages python module.
 # of the xfrin messages python module.
 
 
+% XFRIN_ZONE_CREATED Zone %1 not found in the given data source, newly created
+On starting an xfrin session, it is identified that the zone to be
+transferred is not found in the data source.  This can happen if a
+secondary DNS server first tries to perform AXFR from a primary server
+without creating the zone image beforehand (e.g. by b10-loadzone).  As
+of this writing the xfrin process provides backward compatible
+behavior to previous versions: creating a new one in the data source
+not to surprise existing users too much.  This is probably not a good
+idea, however, in terms of who should be responsible for managing
+zones at a higher level.  In future it is more likely that a separate
+zone management framework is provided, and the situation where the
+given zone isn't found in xfrout will be treated as an error.
+
+% XFRIN_ZONE_NO_SOA Zone %1 does not have SOA
+On starting an xfrin session, it is identified that the zone to be
+transferred does not have an SOA RR in the data source.  This is not
+necessarily an error; if a secondary DNS server first tries to perform
+transfer from a primary server, the zone can be empty, and therefore
+doesn't have an SOA.  Subsequent AXFR will fill in the zone; if the
+attempt is IXFR it will fail in query creation.
+
+% XFRIN_ZONE_MULTIPLE_SOA Zone %1 has %2 SOA RRs
+On starting an xfrin session, it is identified that the zone to be
+transferred has multiple SOA RRs.  Such a zone is broken, but could be
+accidentally configured especially in a data source using "non
+captive" backend database.  The implementation ignores entire SOA RRs
+and tries to continue processing as if the zone were empty.  This
+means subsequent AXFR can succeed and possibly replace the zone with
+valid content, but an IXFR attempt will fail.
+
+% XFRIN_ZONE_SERIAL_AHEAD Serial number (%1) for %2 received from master %3 < ours (%4)
+The response to an SOA query prior to xfr indicated that the zone's
+SOA serial at the primary server is smaller than that of the xfrin
+client.  This is not necessarily an error especially if that
+particular primary server is another secondary server which hasn't got
+the latest version of the zone.  But if the primary server is known to
+be the real source of the zone, some unexpected inconsistency may have
+happened, and you may want to take a closer look.  In this case xfrin
+doesn't perform subsequent zone transfer.
+
 % XFRIN_XFR_OTHER_FAILURE %1 transfer of zone %2 failed: %3
 % XFRIN_XFR_OTHER_FAILURE %1 transfer of zone %2 failed: %3
 The XFR transfer for the given zone has failed due to a problem outside
 The XFR transfer for the given zone has failed due to a problem outside
 of the xfrin module.  Possible reasons are a broken DNS message or failure
 of the xfrin module.  Possible reasons are a broken DNS message or failure
 in database connection.  The error is shown in the log message.
 in database connection.  The error is shown in the log message.
 
 
-% XFRIN_AXFR_DATABASE_FAILURE AXFR transfer of zone %1 failed: %2
+% XFRIN_XFR_TRANSFER_PROTOCOL_ERROR %1 transfer of zone %2 with %3 failed: %4
-The AXFR transfer for the given zone has failed due to a database problem.
+The XFR transfer for the given zone has failed due to a protocol
-The error is shown in the log message.  Note: due to the code structure
+error, such as an unexpected response from the primary server.  The
-this can only happen for AXFR.
+error is shown in the log message.  It may be because the primary
-
+server implementation is broken or (although less likely) there was
-% XFRIN_XFR_TRANSFER_FAILURE %1 transfer of zone %2 failed: %3
+some attack attempt, but it can also happen due to configuration
-The XFR transfer for the given zone has failed due to a protocol error.
+mismatch such as the remote server does not have authority for the
+zone any more but the local configuration hasn't been updated.  So it
+is recommended to check the primary server configuration.
+
+% XFRIN_XFR_TRANSFER_FAILURE %1 transfer of zone %2 with %3 failed: %4
+The XFR transfer for the given zone has failed due to an internal error.
 The error is shown in the log message.
 The error is shown in the log message.
 
 
 % XFRIN_XFR_TRANSFER_FALLBACK falling back from IXFR to AXFR for %1
 % XFRIN_XFR_TRANSFER_FALLBACK falling back from IXFR to AXFR for %1
@@ -118,6 +163,16 @@ daemon will now shut down.
 An uncaught exception was raised while running the xfrin daemon. The
 An uncaught exception was raised while running the xfrin daemon. The
 exception message is printed in the log message.
 exception message is printed in the log message.
 
 
+% XFRIN_IXFR_UPTODATE IXFR requested serial for %1 is %2, master has %3, not updating
+The first SOA record in an IXFR response indicates the zone's serial
+at the primary server is not newer than the client's.  This is
+basically unexpected event because normally the client first checks
+the SOA serial by an SOA query, but can still happen if the transfer
+is manually invoked or (although unlikely) there is a rapid change at
+the primary server between the SOA and IXFR queries.  The client
+implementation confirms the whole response is this single SOA, and
+aborts the transfer just like a successful case.
+
 % XFRIN_GOT_INCREMENTAL_RESP got incremental response for %1
 % XFRIN_GOT_INCREMENTAL_RESP got incremental response for %1
 In an attempt of IXFR processing, the begenning SOA of the first difference
 In an attempt of IXFR processing, the begenning SOA of the first difference
 (following the initial SOA that specified the final SOA for all the
 (following the initial SOA that specified the final SOA for all the

+ 2 - 0
src/lib/dns/python/tests/serial_python_test.py

@@ -77,9 +77,11 @@ class SerialTest(unittest.TestCase):
         self.assertLessEqual(self.one, self.one)
         self.assertLessEqual(self.one, self.one)
         self.assertLessEqual(self.one, self.one_2)
         self.assertLessEqual(self.one, self.one_2)
         self.assertLess(self.one, self.two)
         self.assertLess(self.one, self.two)
+        self.assertLessEqual(self.one, self.one)
         self.assertLessEqual(self.one, self.two)
         self.assertLessEqual(self.one, self.two)
         self.assertGreater(self.two, self.one)
         self.assertGreater(self.two, self.one)
         self.assertGreaterEqual(self.two, self.two)
         self.assertGreaterEqual(self.two, self.two)
+        self.assertGreaterEqual(self.two, self.one)
         self.assertLess(self.one, self.number_low)
         self.assertLess(self.one, self.number_low)
         self.assertLess(self.number_low, self.number_medium)
         self.assertLess(self.number_low, self.number_medium)
         self.assertLess(self.number_medium, self.number_high)
         self.assertLess(self.number_medium, self.number_high)

+ 1 - 1
src/lib/dns/serial.cc

@@ -51,7 +51,7 @@ Serial::operator>(const Serial& other) const {
 
 
 bool
 bool
 Serial::operator>=(const Serial& other) const {
 Serial::operator>=(const Serial& other) const {
-    return (operator==(other) || !operator>(other));
+    return (!operator<(other));
 }
 }
 
 
 Serial
 Serial

+ 1 - 1
src/lib/dns/tests/serial_unittest.cc

@@ -49,7 +49,6 @@ TEST_F(SerialTest, get_value) {
 
 
 TEST_F(SerialTest, equals) {
 TEST_F(SerialTest, equals) {
     EXPECT_EQ(one, one);
     EXPECT_EQ(one, one);
-    EXPECT_EQ(one, one);
     EXPECT_EQ(one, one_2);
     EXPECT_EQ(one, one_2);
     EXPECT_NE(one, two);
     EXPECT_NE(one, two);
     EXPECT_NE(two, one);
     EXPECT_NE(two, one);
@@ -65,6 +64,7 @@ TEST_F(SerialTest, comparison) {
     EXPECT_LE(one, two);
     EXPECT_LE(one, two);
     EXPECT_GE(two, two);
     EXPECT_GE(two, two);
     EXPECT_GT(two, one);
     EXPECT_GT(two, one);
+    EXPECT_GE(two, one);
     EXPECT_LT(one, number_low);
     EXPECT_LT(one, number_low);
     EXPECT_LT(number_low, number_medium);
     EXPECT_LT(number_low, number_medium);
     EXPECT_LT(number_medium, number_high);
     EXPECT_LT(number_medium, number_high);

+ 6 - 0
src/lib/python/isc/testutils/rrset_utils.py

@@ -53,6 +53,12 @@ def create_ns(nsname, name=Name('example.com'), ttl=3600):
     rrset.add_rdata(Rdata(RRType.NS(), RRClass.IN(), nsname))
     rrset.add_rdata(Rdata(RRType.NS(), RRClass.IN(), nsname))
     return rrset
     return rrset
 
 
+def create_cname(target='target.example.com', name=Name('example.com'),
+                 ttl=3600):
+    rrset = RRset(name, RRClass.IN(), RRType.CNAME(), RRTTL(ttl))
+    rrset.add_rdata(Rdata(RRType.CNAME(), RRClass.IN(), target))
+    return rrset
+
 def create_generic(name, rdlen, type=RRType('TYPE65300'), ttl=3600):
 def create_generic(name, rdlen, type=RRType('TYPE65300'), ttl=3600):
     '''Create an RR of a general type with an arbitrary length of RDATA
     '''Create an RR of a general type with an arbitrary length of RDATA