Browse Source

[1299] pre-work update: use Serial object for SOA serial throughout the code
and tests.

JINMEI Tatuya 13 years ago
parent
commit
6764e7da0d
2 changed files with 22 additions and 21 deletions
  1. 20 19
      src/bin/xfrin/tests/xfrin_test.py
  2. 2 2
      src/bin/xfrin/xfrin.py.in

+ 20 - 19
src/bin/xfrin/tests/xfrin_test.py

@@ -342,7 +342,7 @@ class TestXfrinInitialSOA(TestXfrinState):
         self.assertTrue(self.state.handle_rr(self.conn, soa_rrset))
         self.assertEqual(type(XfrinFirstData()),
                          type(self.conn.get_xfrstate()))
-        self.assertEqual(1234, self.conn._end_serial)
+        self.assertEqual(1234, self.conn._end_serial.get_value())
 
     def test_handle_not_soa(self):
         # The given RR is not of SOA
@@ -357,7 +357,8 @@ class TestXfrinFirstData(TestXfrinState):
         super().setUp()
         self.state = XfrinFirstData()
         self.conn._request_type = RRType.IXFR()
-        self.conn._request_serial = 1230 # arbitrary chosen serial < 1234
+        # arbitrary chosen serial < 1234:
+        self.conn._request_serial = isc.dns.Serial(1230)
         self.conn._diff = None           # should be replaced in the AXFR case
 
     def test_handle_ixfr_begin_soa(self):
@@ -437,7 +438,7 @@ class TestXfrinIXFRDelete(TestXfrinState):
         # false.
         self.assertFalse(self.state.handle_rr(self.conn, soa_rrset))
         self.assertEqual([], self.conn._diff.get_buffer())
-        self.assertEqual(1234, self.conn._current_serial)
+        self.assertEqual(1234, self.conn._current_serial.get_value())
         self.assertEqual(type(XfrinIXFRAddSOA()),
                          type(self.conn.get_xfrstate()))
 
@@ -468,7 +469,7 @@ class TestXfrinIXFRAdd(TestXfrinState):
         # We need record the state in 'conn' to check the case where the
         # state doesn't change.
         XfrinIXFRAdd().set_xfrstate(self.conn, XfrinIXFRAdd())
-        self.conn._current_serial = 1230
+        self.conn._current_serial = isc.dns.Serial(1230)
         self.state = self.conn.get_xfrstate()
 
     def test_handle_add_rr(self):
@@ -480,7 +481,7 @@ class TestXfrinIXFRAdd(TestXfrinState):
         self.assertEqual(type(XfrinIXFRAdd()), type(self.conn.get_xfrstate()))
 
     def test_handle_end_soa(self):
-        self.conn._end_serial = 1234
+        self.conn._end_serial = isc.dns.Serial(1234)
         self.conn._diff.add_data(self.ns_rrset) # put some dummy change
         self.assertTrue(self.state.handle_rr(self.conn, soa_rrset))
         self.assertEqual(type(XfrinIXFREnd()), type(self.conn.get_xfrstate()))
@@ -489,7 +490,7 @@ class TestXfrinIXFRAdd(TestXfrinState):
         self.assertEqual([], self.conn._diff.get_buffer())
 
     def test_handle_new_delete(self):
-        self.conn._end_serial = 1234
+        self.conn._end_serial = isc.dns.Serial(1234)
         # SOA RR whose serial is the current one means we are going to a new
         # difference, starting with removing that SOA.
         self.conn._diff.add_data(self.ns_rrset) # put some dummy change
@@ -500,7 +501,7 @@ class TestXfrinIXFRAdd(TestXfrinState):
 
     def test_handle_out_of_sync(self):
         # getting SOA with an inconsistent serial.  This is an error.
-        self.conn._end_serial = 1235
+        self.conn._end_serial = isc.dns.Serial(1235)
         self.assertRaises(XfrinProtocolError, self.state.handle_rr,
                           self.conn, soa_rrset)
 
@@ -523,7 +524,7 @@ class TestXfrinAXFR(TestXfrinState):
     def setUp(self):
         super().setUp()
         self.state = XfrinAXFR()
-        self.conn._end_serial = 1234
+        self.conn._end_serial = isc.dns.Serial(1234)
 
     def test_handle_rr(self):
         """
@@ -781,7 +782,7 @@ class TestAXFR(TestXfrinConnection):
         # IXFR query
         msg = self.conn._create_query(RRType.IXFR())
         check_query(RRType.IXFR(), begin_soa_rrset)
-        self.assertEqual(1230, self.conn._request_serial)
+        self.assertEqual(1230, self.conn._request_serial.get_value())
 
     def test_create_ixfr_query_fail(self):
         # In these cases _create_query() will fail to find a valid SOA RR to
@@ -1270,7 +1271,7 @@ class TestIXFRResponse(TestXfrinConnection):
     def setUp(self):
         super().setUp()
         self.conn._query_id = self.conn.qid = 1035
-        self.conn._request_serial = 1230
+        self.conn._request_serial = isc.dns.Serial(1230)
         self.conn._request_type = RRType.IXFR()
         self._zone_name = TEST_ZONE_NAME
         self.conn._datasrc_client = MockDataSourceClient()
@@ -1543,9 +1544,9 @@ class TestXFRSessionWithSQLite3(TestXfrinConnection):
         self.conn.response_generator = create_ixfr_response
 
         # Confirm xfrin succeeds and SOA is updated
-        self.assertEqual(1230, self.get_zone_serial())
+        self.assertEqual(1230, self.get_zone_serial().get_value())
         self.assertEqual(XFRIN_OK, self.conn.do_xfrin(False, RRType.IXFR()))
-        self.assertEqual(1234, self.get_zone_serial())
+        self.assertEqual(1234, self.get_zone_serial().get_value())
 
         # Also confirm the corresponding diffs are stored in the diffs table
         conn = sqlite3.connect(self.sqlite3db_obj)
@@ -1574,9 +1575,9 @@ class TestXFRSessionWithSQLite3(TestXfrinConnection):
                          self._create_soa('1235')])
         self.conn.response_generator = create_ixfr_response
 
-        self.assertEqual(1230, self.get_zone_serial())
+        self.assertEqual(1230, self.get_zone_serial().get_value())
         self.assertEqual(XFRIN_FAIL, self.conn.do_xfrin(False, RRType.IXFR()))
-        self.assertEqual(1230, self.get_zone_serial())
+        self.assertEqual(1230, self.get_zone_serial().get_value())
 
     def test_do_ixfrin_nozone_sqlite3(self):
         self.conn._zone_name = Name('nosuchzone.example')
@@ -1595,11 +1596,11 @@ class TestXFRSessionWithSQLite3(TestXfrinConnection):
         self.conn.response_generator = create_response
 
         # Confirm xfrin succeeds and SOA is updated, A RR is deleted.
-        self.assertEqual(1230, self.get_zone_serial())
+        self.assertEqual(1230, self.get_zone_serial().get_value())
         self.assertTrue(self.record_exist(Name('dns01.example.com'),
                                           RRType.A()))
         self.assertEqual(XFRIN_OK, self.conn.do_xfrin(False, type))
-        self.assertEqual(1234, self.get_zone_serial())
+        self.assertEqual(1234, self.get_zone_serial().get_value())
         self.assertFalse(self.record_exist(Name('dns01.example.com'),
                                            RRType.A()))
 
@@ -1627,11 +1628,11 @@ class TestXFRSessionWithSQLite3(TestXfrinConnection):
                 answers=[soa_rrset, self._create_ns(), soa_rrset, soa_rrset])
         self.conn.response_generator = create_response
 
-        self.assertEqual(1230, self.get_zone_serial())
+        self.assertEqual(1230, self.get_zone_serial().get_value())
         self.assertTrue(self.record_exist(Name('dns01.example.com'),
                                           RRType.A()))
         self.assertEqual(XFRIN_FAIL, self.conn.do_xfrin(False, type))
-        self.assertEqual(1230, self.get_zone_serial())
+        self.assertEqual(1230, self.get_zone_serial().get_value())
         self.assertTrue(self.record_exist(Name('dns01.example.com'),
                                           RRType.A()))
 
@@ -1669,7 +1670,7 @@ class TestXFRSessionWithSQLite3(TestXfrinConnection):
         self.assertEqual(XFRIN_OK, self.conn.do_xfrin(False, RRType.AXFR()))
         self.assertEqual(type(XfrinAXFREnd()),
                          type(self.conn.get_xfrstate()))
-        self.assertEqual(1234, self.get_zone_serial())
+        self.assertEqual(1234, self.get_zone_serial().get_value())
         self.assertFalse(self.record_exist(Name('dns01.example.com'),
                                            RRType.A()))
 

+ 2 - 2
src/bin/xfrin/xfrin.py.in

@@ -153,7 +153,7 @@ def format_addrinfo(addrinfo):
                         "appear to be consisting of (family, socktype, (addr, port))")
 
 def get_soa_serial(soa_rdata):
-    '''Extract the serial field of an SOA RDATA and returns it as an intger.
+    '''Extract the serial field of SOA RDATA and return it as a Serial object.
 
     We don't have to be very efficient here, so we first dump the entire RDATA
     as a string and convert the first corresponding field.  This should be
@@ -162,7 +162,7 @@ def get_soa_serial(soa_rdata):
     should be a more direct and convenient way to get access to the SOA
     fields.
     '''
-    return int(soa_rdata.to_text().split()[2])
+    return Serial(int(soa_rdata.to_text().split()[2]))
 
 class XfrinState:
     '''