Browse Source

[1371] introduce rrset_util.py to compare RRsets and create test RRsets
in one function call. extended xfrout tests so we check the resulting
responses in more detail

JINMEI Tatuya 13 years ago
parent
commit
de43982b90

+ 33 - 58
src/bin/xfrout/tests/xfrout_test.py.in

@@ -22,6 +22,7 @@ from isc.testutils.tsigctx_mock import MockTSIGContext
 from isc.cc.session import *
 import isc.config
 from isc.dns import *
+from isc.testutils.rrset_utils import *
 from xfrout import *
 import xfrout
 import isc.log
@@ -40,25 +41,6 @@ IXFR_OK_VERSION = 2011111802
 IXFR_NG_VERSION = 2011112800
 SOA_CURRENT_VERSION = 2011112001
 
-# Shortcut functions to create RRsets commonly used in tests below.
-def create_a(address, ttl=3600):
-    rrset = RRset(Name('a.example.com'), RRClass.IN(), RRType.A(), RRTTL(ttl))
-    rrset.add_rdata(Rdata(RRType.A(), RRClass.IN(), address))
-    return rrset
-
-def create_aaaa(address):
-    rrset = RRset(Name('a.example.com'), RRClass.IN(), RRType.AAAA(),
-                  RRTTL(3600))
-    rrset.add_rdata(Rdata(RRType.AAAA(), RRClass.IN(), address))
-    return rrset
-
-def create_soa(serial):
-    rrset = RRset(TEST_ZONE_NAME, RRClass.IN(), RRType.SOA(), RRTTL(3600))
-    rdata_str = 'master.example.com. admin.example.com. ' + \
-        str(serial) + ' 3600 1800 2419200 7200'
-    rrset.add_rdata(Rdata(RRType.SOA(), RRClass.IN(), rdata_str))
-    return rrset
-
 # our fake socket, where we can read and insert messages
 class MySocket():
     def __init__(self, family, type):
@@ -864,16 +846,22 @@ class TestXfroutSession(TestXfroutSessionBase):
         # RRs won't be skipped.
         self.xfrsess._soa = create_soa(SOA_CURRENT_VERSION)
         self.xfrsess._iterator = [create_soa(IXFR_OK_VERSION),
-                                  create_a('192.0.2.2'),
+                                  create_a(Name('a.example.com'), '192.0.2.2'),
                                   create_soa(SOA_CURRENT_VERSION),
-                                  create_aaaa('2001:db8::1')]
+                                  create_aaaa(Name('a.example.com'),
+                                              '2001:db8::1')]
         self.xfrsess._jnl_reader = self.xfrsess._iterator
         self.xfrsess._reply_xfrout_query(self.getmsg(), self.sock)
         reply_msg = self.sock.read_msg(Message.PRESERVE_ORDER)
-        # The answer section should contain everything in the "fake"
-        # iterator and two SOAs.
-        self.assertEqual(reply_msg.get_rr_count(Message.SECTION_ANSWER),
-                         len(self.xfrsess._iterator) + 2)
+        actual_records = reply_msg.get_section(Message.SECTION_ANSWER)
+
+        expected_records = self.xfrsess._iterator[:]
+        expected_records.insert(0, create_soa(SOA_CURRENT_VERSION))
+        expected_records.append(create_soa(SOA_CURRENT_VERSION))
+
+        self.assertEqual(len(expected_records), len(actual_records))
+        for (expected_rr, actual_rr) in zip(expected_records, actual_records):
+            self.assertTrue(expected_rr, actual_rr)
 
     def test_reply_xfrout_query_ixfr_soa_only(self):
         # Creating an IXFR response that contains only one RR, which is the
@@ -885,9 +873,7 @@ class TestXfroutSession(TestXfroutSessionBase):
         reply_msg = self.sock.read_msg(Message.PRESERVE_ORDER)
         answer = reply_msg.get_section(Message.SECTION_ANSWER)
         self.assertEqual(1, len(answer))
-        self.assertEqual(RRType.SOA(), answer[0].get_type())
-        self.assertEqual(SOA_CURRENT_VERSION,
-                         xfrout.get_soa_serial(answer[0].get_rdata()[0]))
+        self.assertTrue(create_soa(SOA_CURRENT_VERSION), answer[0])
 
 class TestXfroutSessionWithSQLite3(TestXfroutSessionBase):
     '''Tests for XFR-out sessions using an SQLite3 DB.
@@ -903,30 +889,22 @@ class TestXfroutSessionWithSQLite3(TestXfroutSessionBase):
         self.xfrsess._request_data = self.mdata
         self.xfrsess._server.get_db_file = lambda : TESTDATA_SRCDIR + \
             'test.sqlite3'
+        self.ns_name = 'a.dns.example.com'
 
     def check_axfr_stream(self, response):
         '''Common checks for AXFR(-style) response for the test zone.
         '''
         # This zone contains two A RRs for the same name with different TTLs.
         # These TTLs should be preseved in the AXFR stream.
-        # We'll check some important points as a valid AXFR response:
-        # the first and last RR must be SOA, and these should be the only
-        # SOAs in the response.  The total number of response RRs
-        # must be 5 (zone has 4 RRs, SOA is duplicated)
         actual_records = response.get_section(Message.SECTION_ANSWER)
-        self.assertEqual(5, len(actual_records))
-        self.assertEqual(RRType.SOA(), actual_records[0].get_type())
-        self.assertEqual(RRType.SOA(), actual_records[-1].get_type())
-        actual_ttls = []
-        num_soa = 0
-        for rr in actual_records:
-            if rr.get_type() == RRType.SOA():
-                num_soa += 1
-            if rr.get_type() == RRType.A() and \
-                    not rr.get_ttl() in actual_ttls:
-                actual_ttls.append(rr.get_ttl().get_value())
-        self.assertEqual(2, num_soa)
-        self.assertEqual([3600, 7200], sorted(actual_ttls))
+        expected_records = [create_soa(2011112001),
+                            create_ns(self.ns_name),
+                            create_a(Name(self.ns_name), '192.0.2.1', 3600),
+                            create_a(Name(self.ns_name), '192.0.2.2', 7200),
+                            create_soa(2011112001)]
+        self.assertEqual(len(expected_records), len(actual_records))
+        for (expected_rr, actual_rr) in zip(expected_records, actual_records):
+            self.assertTrue(expected_rr, actual_rr)
 
     def test_axfr_normal_session(self):
         XfroutSession._handle(self.xfrsess)
@@ -954,16 +932,15 @@ class TestXfroutSessionWithSQLite3(TestXfroutSessionBase):
         XfroutSession._handle(self.xfrsess)
         response = self.sock.read_msg(Message.PRESERVE_ORDER);
         actual_records = response.get_section(Message.SECTION_ANSWER)
-        self.assertEqual(10, len(actual_records))
-        # The first and last RRs must be SOA whose serial is the latest one.
-        soa = actual_records[0]
-        self.assertEqual(RRType.SOA(), soa.get_type())
-        self.assertEqual(SOA_CURRENT_VERSION,
-                         xfrout.get_soa_serial(soa.get_rdata()[0]))
-        soa = actual_records[-1]
-        self.assertEqual(RRType.SOA(), soa.get_type())
-        self.assertEqual(SOA_CURRENT_VERSION,
-                         xfrout.get_soa_serial(soa.get_rdata()[0]))
+        expected_records = [create_soa(2011112001), create_soa(2011111802),
+                            create_soa(2011111900),
+                            create_a(Name(self.ns_name), '192.0.2.2', 7200),
+                            create_soa(2011111900),
+                            create_a(Name(self.ns_name), '192.0.2.53'),
+                            create_aaaa(Name(self.ns_name), '2001:db8::1'),
+                            create_soa(2011112001),
+                            create_a(Name(self.ns_name), '192.0.2.1'),
+                            create_soa(2011112001)]
 
     def test_ixfr_soa_only(self):
         # The requested SOA serial is the latest one.  The response should
@@ -974,9 +951,7 @@ class TestXfroutSessionWithSQLite3(TestXfroutSessionBase):
         response = self.sock.read_msg(Message.PRESERVE_ORDER);
         answers = response.get_section(Message.SECTION_ANSWER)
         self.assertEqual(1, len(answers))
-        self.assertEqual(RRType.SOA(), answers[0].get_type())
-        self.assertEqual(SOA_CURRENT_VERSION,
-                         xfrout.get_soa_serial(answers[0].get_rdata()[0]))
+        self.assertTrue(create_soa(SOA_CURRENT_VERSION), answers[0])
 
 class MyUnixSockServer(UnixSockServer):
     def __init__(self):

+ 1 - 1
src/lib/python/isc/testutils/Makefile.am

@@ -1,4 +1,4 @@
-EXTRA_DIST = __init__.py parse_args.py tsigctx_mock.py
+EXTRA_DIST = __init__.py parse_args.py tsigctx_mock.py rrset_utils.py
 
 CLEANDIRS = __pycache__
 

+ 61 - 0
src/lib/python/isc/testutils/rrset_utils.py

@@ -0,0 +1,61 @@
+# Copyright (C) 2011  Internet Systems Consortium.
+#
+# Permission to use, copy, modify, and distribute this software for any
+# purpose with or without fee is hereby granted, provided that the above
+# copyright notice and this permission notice appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
+# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
+# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
+# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
+# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
+# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
+# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+'''Utility functions handling DNS RRsets commonly used for tests'''
+
+from isc.dns import *
+
+def rrsets_equal(a, b):
+    '''Compare two RRsets, return True if equal, otherwise False'''
+
+    # no accessor for sigs either (so this only checks name, class, type, ttl,
+    # and rdata)
+    # also, because of the fake data in rrsigs, if the type is rrsig, the
+    # rdata is not checked
+    return a.get_name() == b.get_name() and \
+           a.get_class() == b.get_class() and \
+           a.get_type() == b.get_type() and \
+           a.get_ttl() == b.get_ttl() and \
+           (a.get_type() == RRType.RRSIG() or
+            sorted(a.get_rdata()) == sorted(b.get_rdata()))
+
+# The following are short cut utilities to create an RRset of a specific
+# RR type with one RDATA.  Many of the RR parameters are common in most
+# tests, so we define default values for them for convenience.
+
+def create_a(name, address, ttl=3600):
+    rrset = RRset(name, RRClass.IN(), RRType.A(), RRTTL(ttl))
+    rrset.add_rdata(Rdata(RRType.A(), RRClass.IN(), address))
+    return rrset
+
+def create_aaaa(name, address, ttl=3600):
+    rrset = RRset(name, RRClass.IN(), RRType.AAAA(), RRTTL(ttl))
+    rrset.add_rdata(Rdata(RRType.AAAA(), RRClass.IN(), address))
+    return rrset
+
+def create_ns(nsname, name=Name('example.com'), ttl=3600):
+    '''For convenience we use a default name often used as a zone name'''
+    rrset = RRset(name, RRClass.IN(), RRType.NS(), RRTTL(ttl))
+    rrset.add_rdata(Rdata(RRType.NS(), RRClass.IN(), nsname))
+    return rrset
+
+def create_soa(serial, name=Name('example.com'), ttl=3600):
+    '''For convenience we use a default name often used as a zone name'''
+
+    rrset = RRset(name, RRClass.IN(), RRType.SOA(), RRTTL(ttl))
+    rdata_str = 'master.example.com. admin.example.com. ' + \
+        str(serial) + ' 3600 1800 2419200 7200'
+    rrset.add_rdata(Rdata(RRType.SOA(), RRClass.IN(), rdata_str))
+    return rrset