Browse Source

[1371] a small cleanup: introduced some convenient functions test-module wide.

JINMEI Tatuya 13 years ago
parent
commit
c59bb2dcd9
1 changed files with 34 additions and 22 deletions
  1. 34 22
      src/bin/xfrout/tests/xfrout_test.py.in

+ 34 - 22
src/bin/xfrout/tests/xfrout_test.py.in

@@ -38,13 +38,26 @@ TEST_ZONE_NAME = Name(TEST_ZONE_NAME_STR)
 TEST_RRCLASS = RRClass.IN()
 TEST_RRCLASS = RRClass.IN()
 IXFR_OK_VERSION = 2011111802
 IXFR_OK_VERSION = 2011111802
 IXFR_NG_VERSION = 2011112800
 IXFR_NG_VERSION = 2011112800
-
-# SOA intended to be used for the new SOA as a result of transfer.
-soa_rdata = Rdata(RRType.SOA(), TEST_RRCLASS,
-                  'master.example.com. admin.example.com ' +
-                  '2011112001 3600 1800 2419200 7200')
-soa_rrset = RRset(TEST_ZONE_NAME, TEST_RRCLASS, RRType.SOA(), RRTTL(3600))
-soa_rrset.add_rdata(soa_rdata)
+SOA_CURRENT_VERSION = 2011111802
+
+# Shortcut functions to create RRsets commonly used in tests below.
+def create_a(address, ttl=3600):
+    rrset = RRset(Name('a.example.com'), RRClass.IN(), RRType.A(), RRTTL(ttl))
+    rrset.add_rdata(Rdata(RRType.A(), RRClass.IN(), address))
+    return rrset
+
+def create_aaaa(address):
+    rrset = RRset(Name('a.example.com'), RRClass.IN(), RRType.AAAA(),
+                  RRTTL(3600))
+    rrset.add_rdata(Rdata(RRType.AAAA(), RRClass.IN(), address))
+    return rrset
+
+def create_soa(serial):
+    rrset = RRset(TEST_ZONE_NAME, RRClass.IN(), RRType.SOA(), RRTTL(3600))
+    rdata_str = 'master.example.com. admin.example.com. ' + \
+        str(serial) + ' 3600 1800 2419200 7200'
+    rrset.add_rdata(Rdata(RRType.SOA(), RRClass.IN(), rdata_str))
+    return rrset
 
 
 # our fake socket, where we can read and insert messages
 # our fake socket, where we can read and insert messages
 class MySocket():
 class MySocket():
@@ -85,12 +98,6 @@ class MockDataSrcClient:
     def __init__(self, type, config):
     def __init__(self, type, config):
         pass
         pass
 
 
-    def __create_soa(self):
-        soa_rrset = RRset(self._zone_name, RRClass.IN(), RRType.SOA(),
-                          RRTTL(3600))
-        soa_rrset.add_rdata(soa_rdata)
-        return soa_rrset
-
     def find_zone(self, zone_name):
     def find_zone(self, zone_name):
         '''Mock version of find_zone().
         '''Mock version of find_zone().
 
 
@@ -113,15 +120,15 @@ class MockDataSrcClient:
 
 
         '''
         '''
         if name == TEST_ZONE_NAME and rrtype == RRType.SOA():
         if name == TEST_ZONE_NAME and rrtype == RRType.SOA():
-            return (ZoneFinder.SUCCESS, self.__create_soa())
+            return (ZoneFinder.SUCCESS, create_soa(SOA_CURRENT_VERSION))
         elif name == Name('nosoa.example.com') and rrtype == RRType.SOA():
         elif name == Name('nosoa.example.com') and rrtype == RRType.SOA():
             return (ZoneFinder.NXDOMAIN, None)
             return (ZoneFinder.NXDOMAIN, None)
         elif name == Name('multisoa.example.com') and rrtype == RRType.SOA():
         elif name == Name('multisoa.example.com') and rrtype == RRType.SOA():
-            soa_rrset = self.__create_soa()
+            soa_rrset = create_soa(SOA_CURRENT_VERSION)
             soa_rrset.add_rdata(soa_rrset.get_rdata()[0])
             soa_rrset.add_rdata(soa_rrset.get_rdata()[0])
             return (ZoneFinder.SUCCESS, soa_rrset)
             return (ZoneFinder.SUCCESS, soa_rrset)
         else:
         else:
-            return (ZoneFinder.SUCCESS, self.__create_soa())
+            return (ZoneFinder.SUCCESS, create_soa(SOA_CURRENT_VERSION))
 
 
     def get_iterator(self, zone_name, adjust_ttl=False):
     def get_iterator(self, zone_name, adjust_ttl=False):
         if zone_name == Name('notauth.example.com'):
         if zone_name == Name('notauth.example.com'):
@@ -132,7 +139,7 @@ class MockDataSrcClient:
     def get_soa(self):  # emulate ZoneIterator.get_soa()
     def get_soa(self):  # emulate ZoneIterator.get_soa()
         if self._zone_name == Name('nosoa.example.com'):
         if self._zone_name == Name('nosoa.example.com'):
             return None
             return None
-        soa_rrset = self.__create_soa()
+        soa_rrset = create_soa(SOA_CURRENT_VERSION)
         if self._zone_name == Name('multisoa.example.com'):
         if self._zone_name == Name('multisoa.example.com'):
             soa_rrset.add_rdata(soa_rrset.get_rdata()[0])
             soa_rrset.add_rdata(soa_rrset.get_rdata()[0])
         return soa_rrset
         return soa_rrset
@@ -281,9 +288,7 @@ class TestXfroutSessionBase(unittest.TestCase):
                                        {})
                                        {})
         self.set_request_type(RRType.AXFR()) # test AXFR by default
         self.set_request_type(RRType.AXFR()) # test AXFR by default
         self.mdata = self.create_request_data()
         self.mdata = self.create_request_data()
-        self.soa_rrset = RRset(TEST_ZONE_NAME, RRClass.IN(), RRType.SOA(),
-                               RRTTL(3600))
-        self.soa_rrset.add_rdata(soa_rdata)
+        self.soa_rrset = create_soa(SOA_CURRENT_VERSION)
         # some test replaces a module-wide function.  We should ensure the
         # some test replaces a module-wide function.  We should ensure the
         # original is used elsewhere.
         # original is used elsewhere.
         self.orig_get_rrset_len = xfrout.get_rrset_len
         self.orig_get_rrset_len = xfrout.get_rrset_len
@@ -810,14 +815,14 @@ class TestXfroutSession(TestXfroutSessionBase):
         self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
         self.xfrsess.dns_xfrout_start(self.sock, self.mdata)
         self.assertEqual(self.sock.readsent(), b"success")
         self.assertEqual(self.sock.readsent(), b"success")
 
 
-    def test_reply_xfrout_query_noerror(self):
+    def test_reply_xfrout_query_axfr(self):
         self.xfrsess._soa = self.soa_rrset
         self.xfrsess._soa = self.soa_rrset
         self.xfrsess._iterator = [self.soa_rrset]
         self.xfrsess._iterator = [self.soa_rrset]
         self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock)
         self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock)
         reply_msg = self.sock.read_msg()
         reply_msg = self.sock.read_msg()
         self.assertEqual(reply_msg.get_rr_count(Message.SECTION_ANSWER), 2)
         self.assertEqual(reply_msg.get_rr_count(Message.SECTION_ANSWER), 2)
 
 
-    def test_reply_xfrout_query_noerror_with_tsig(self):
+    def test_reply_xfrout_query_axfr_with_tsig(self):
         rrset = RRset(Name('a.example.com'), RRClass.IN(), RRType.A(),
         rrset = RRset(Name('a.example.com'), RRClass.IN(), RRType.A(),
                       RRTTL(3600))
                       RRTTL(3600))
         rrset.add_rdata(Rdata(RRType.A(), RRClass.IN(), '192.0.2.1'))
         rrset.add_rdata(Rdata(RRType.A(), RRClass.IN(), '192.0.2.1'))
@@ -845,6 +850,13 @@ class TestXfroutSession(TestXfroutSessionBase):
         # and it should not have sent anything else
         # and it should not have sent anything else
         self.assertEqual(0, len(self.sock.sendqueue))
         self.assertEqual(0, len(self.sock.sendqueue))
 
 
+    def test_reply_xfrout_query_ixfr(self):
+        self.xfrsess._soa = self.soa_rrset
+        self.xfrsess._iterator = [self.soa_rrset]
+        self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock)
+        reply_msg = self.sock.read_msg()
+        self.assertEqual(reply_msg.get_rr_count(Message.SECTION_ANSWER), 2)
+
 class TestXfroutSessionWithSQLite3(TestXfroutSessionBase):
 class TestXfroutSessionWithSQLite3(TestXfroutSessionBase):
     '''Tests for XFR-out sessions using an SQLite3 DB.
     '''Tests for XFR-out sessions using an SQLite3 DB.