Browse Source

[1484] Fix and enhance tests

* They pass now
* Repeated code is there only once
* Uses rrsets_equal to compare rrsets
* Checks reference counts on the result
Michal 'vorner' Vaner 13 years ago
parent
commit
d907907c7d
1 changed files with 67 additions and 56 deletions
  1. 67 56
      src/lib/python/isc/datasrc/tests/datasrc_test.py

+ 67 - 56
src/lib/python/isc/datasrc/tests/datasrc_test.py

@@ -57,6 +57,61 @@ def create_soa(serial):
                         str(serial) + ' 3600 1800 2419200 7200'))
     return soa
 
+def test_findall_common(self, tested):
+    """
+    Common part of the find_all test. It tests a find_all method on the passed
+    object.
+    """
+    # Some "failure" responses
+    result, rrset = tested.find_all(isc.dns.Name("www.sql1.example.com"),
+                                    ZoneFinder.FIND_DEFAULT)
+    self.assertEqual(ZoneFinder.DELEGATION, result)
+    expected = RRset(Name('sql1.example.com.'), RRClass.IN(), RRType.NS(),
+                     RRTTL(3600))
+    expected.add_rdata(Rdata(RRType.NS(), RRClass.IN(),
+                             'dns01.example.com.'))
+    expected.add_rdata(Rdata(RRType.NS(), RRClass.IN(),
+                             'dns02.example.com.'))
+    expected.add_rdata(Rdata(RRType.NS(), RRClass.IN(),
+                             'dns03.example.com.'))
+    self.assertTrue(rrsets_equal(expected, rrset))
+
+    result, rrset = tested.find_all(isc.dns.Name("nxdomain.example.com"),
+                                     ZoneFinder.FIND_DEFAULT)
+    self.assertEqual(ZoneFinder.NXDOMAIN, result)
+    self.assertIsNone(None, rrset)
+
+    # A success. It should return the list now.
+    # This also tests we can ommit the options parameter
+    result, rrsets = tested.find_all(isc.dns.Name("mix.example.com."))
+    self.assertEqual(ZoneFinder.SUCCESS, result)
+    self.assertEqual(2, len(rrsets))
+    rrsets.sort(key=lambda rrset: rrset.get_type().to_text())
+    expected = [
+        RRset(Name('mix.example.com.'), RRClass.IN(), RRType.A(),
+              RRTTL(3600)),
+        RRset(Name('mix.example.com.'), RRClass.IN(), RRType.AAAA(),
+              RRTTL(3600))
+    ]
+    expected[0].add_rdata(Rdata(RRType.A(), RRClass.IN(), "192.0.2.1"))
+    expected[0].add_rdata(Rdata(RRType.A(), RRClass.IN(), "192.0.2.2"))
+    expected[1].add_rdata(Rdata(RRType.AAAA(), RRClass.IN(),
+                                "2001:db8::1"))
+    expected[1].add_rdata(Rdata(RRType.AAAA(), RRClass.IN(),
+                                "2001:db8::2"))
+    for (rrset, exp) in zip(rrsets, expected):
+        self.assertTrue(rrsets_equal(exp, rrset))
+
+    # Check the reference counts on them. The getrefcount returns one more,
+    # as for the reference in its own parameter - see its docs.
+
+    # Two - one for the variable, one for parameter
+    self.assertEqual(2, sys.getrefcount(rrsets))
+    for rrset in rrsets:
+        # 3 - one as the element of list, one for the rrset variable
+        # and one for the parameter.
+        self.assertEqual(3, sys.getrefcount(rrset))
+
 class DataSrcClient(unittest.TestCase):
 
     def test_(self):
@@ -275,66 +330,12 @@ class DataSrcClient(unittest.TestCase):
         """
         dsc = isc.datasrc.DataSourceClient("sqlite3", READ_ZONE_DB_CONFIG)
         result, finder = dsc.find_zone(isc.dns.Name("example.com"))
+
         self.assertEqual(finder.SUCCESS, result)
         self.assertEqual(isc.dns.RRClass.IN(), finder.get_class())
         self.assertEqual("example.com.", finder.get_origin().to_text())
 
-        # Some "failure" responses
-        result, rrset = finder.find_all(isc.dns.Name("www.sql1.example.com"),
-                                        finder.FIND_DEFAULT)
-        self.assertEqual(finder.DELEGATION, result)
-        self.assertEqual("sql1.example.com. 3600 IN NS dns01.example.com.\n" +
-                         "sql1.example.com. 3600 IN NS dns02.example.com.\n" +
-                         "sql1.example.com. 3600 IN NS dns03.example.com.\n",
-                         rrset.to_text())
-
-        result, rrset = finder.find_all(isc.dns.Name("nxdomain.example.com"),
-                                        finder.FIND_DEFAULT)
-        self.assertEqual(finder.NXDOMAIN, result)
-        self.assertIsNone(None, rrset)
-
-        # A success. It should return the list now.
-        result, rrsets = finder.find_all(isc.dns.Name("mix.example.com."))
-        self.assertEqual(ZoneFinder.SUCCESS, result)
-        self.assertEqual(2, len(rrsets))
-        self.assertEqual(sorted(map(lambda rrset: rrset.get_type().to_text(),
-                                    rrsets)), sorted(["A", "AAAA"]))
-        rdatas = []
-        for rrset in rrsets:
-            rdatas.extend(rrset.get_rdata())
-        self.assertEqual(sorted(map(lambda rdata: rdata.to_text(), rdatas)),
-                         sorted(["192.0.2.1", "192.0.2.2", "2001:db8::1",
-                                 "2001:db8::2"]))
-        # The same, but on an updater
-        dsc = isc.datasrc.DataSourceClient("sqlite3", WRITE_ZONE_DB_CONFIG)
-        updater = dsc.get_updater(isc.dns.Name("example.com"), False)
-
-        # Some "failure" responses
-        result, rrset = updater.find_all(isc.dns.Name("www.sql1.example.com"),
-                                        finder.FIND_DEFAULT)
-        self.assertEqual(finder.DELEGATION, result)
-        self.assertEqual("sql1.example.com. 3600 IN NS dns01.example.com.\n" +
-                         "sql1.example.com. 3600 IN NS dns02.example.com.\n" +
-                         "sql1.example.com. 3600 IN NS dns03.example.com.\n",
-                         rrset.to_text())
-
-        result, rrset = updater.find_all(isc.dns.Name("nxdomain.example.com"),
-                                         finder.FIND_DEFAULT)
-        self.assertEqual(finder.NXDOMAIN, result)
-        self.assertIsNone(None, rrset)
-
-        # A success. It should return the list now.
-        result, rrsets = updater.find_all(isc.dns.Name("mix.example.com."))
-        self.assertEqual(ZoneFinder.SUCCESS, result)
-        self.assertEqual(2, len(rrsets))
-        self.assertEqual(sorted(map(lambda rrset: rrset.get_type().to_text(),
-                                    rrsets)), sorted(["A", "AAAA"]))
-        rdatas = []
-        for rrset in rrsets:
-            rdatas.extend(rrset.get_rdata())
-        self.assertEqual(sorted(map(lambda rdata: rdata.to_text(), rdatas)),
-                         sorted(["192.0.2.1", "192.0.2.2", "2001:db8::1",
-                                 "2001:db8::2"]))
+        test_findall_common(self, finder)
 
     def test_find(self):
         dsc = isc.datasrc.DataSourceClient("sqlite3", READ_ZONE_DB_CONFIG)
@@ -444,6 +445,16 @@ class DataSrcUpdater(unittest.TestCase):
         # Make a fresh copy of the writable database with all original content
         shutil.copyfile(READ_ZONE_DB_FILE, WRITE_ZONE_DB_FILE)
 
+    def test_findall(self):
+        """
+        The same test as DataSrcClient.test_findall, but on an updater
+        instead of a finder.
+        """
+        dsc = isc.datasrc.DataSourceClient("sqlite3", WRITE_ZONE_DB_CONFIG)
+        updater = dsc.get_updater(isc.dns.Name("example.com"), False)
+
+        test_findall_common(self, updater)
+
     def test_construct(self):
         # can't construct directly
         self.assertRaises(TypeError, isc.datasrc.ZoneUpdater)