Parcourir la source

[2379] Add reference count checks to tests

Jelte Jansen il y a 12 ans
Parent
commit
45596b6bd5
1 fichiers modifiés avec 41 ajouts et 16 suppressions
  1. 41 16
      src/lib/python/isc/datasrc/tests/zone_loader_test.py

+ 41 - 16
src/lib/python/isc/datasrc/tests/zone_loader_test.py

@@ -19,6 +19,7 @@ import isc.dns
 import os
 import unittest
 import shutil
+import sys
 
 # Constants and common data used in tests
 
@@ -46,14 +47,37 @@ class ZoneLoaderTests(unittest.TestCase):
         self.client = isc.datasrc.DataSourceClient("sqlite3", DB_CLIENT_CONFIG)
         # Make a fresh copy of the database
         shutil.copy(ORIG_DB_FILE, DB_FILE)
+        # Some tests set source client; if so, check refcount in
+        # tearDown, since most tests don't, set it to None by default.
+        self.source_client = None
+        self.loader = None
+        self.assertEqual(2, sys.getrefcount(self.test_name))
+        self.assertEqual(2, sys.getrefcount(self.client))
 
     def tearDown(self):
         # We can only create 1 loader at a time (it locks the db), and it
         # may not be destroyed immediately if there is an exception in a
         # test. So the tests that do create one should put it in self, and
         # we make sure to invalidate it here.
+
+        # We can also use this to check reference counts; if a loader
+        # exists, the client and source client (if any) should have
+        # an increased reference count (but the name should not, this
+        # is only used in the initializer)
+        if self.loader is not None:
+            self.assertEqual(2, sys.getrefcount(self.test_name))
+            self.assertEqual(3, sys.getrefcount(self.client))
+            if (self.source_client is not None):
+                self.assertEqual(3, sys.getrefcount(self.source_client))
         self.loader = None
 
+        # Now that the loader has been destroyed, the refcounts
+        # of its arguments should be back to their originals
+        self.assertEqual(2, sys.getrefcount(self.test_name))
+        self.assertEqual(2, sys.getrefcount(self.client))
+        if (self.source_client is not None):
+            self.assertEqual(2, sys.getrefcount(self.source_client))
+
     def test_bad_constructor(self):
         self.assertRaises(TypeError, isc.datasrc.ZoneLoader)
         self.assertRaises(TypeError, isc.datasrc.ZoneLoader, 1)
@@ -90,13 +114,13 @@ class ZoneLoaderTests(unittest.TestCase):
         self.check_load()
 
     def test_load_from_client(self):
-        source_client = isc.datasrc.DataSourceClient('sqlite3',
-                                                     DB_SOURCE_CLIENT_CONFIG)
+        self.source_client = isc.datasrc.DataSourceClient('sqlite3',
+                                    DB_SOURCE_CLIENT_CONFIG)
         self.loader = isc.datasrc.ZoneLoader(self.client, self.test_name,
-                                             source_client)
+                                             self.source_client)
         self.check_load()
 
-    def test_load_from_file_checkrefs(self):
+    def Xtest_load_from_file_checkrefs(self):
         # A test to see the refcount is increased properly
         self.loader = isc.datasrc.ZoneLoader(self.client, self.test_name,
                                              self.test_file)
@@ -109,7 +133,7 @@ class ZoneLoaderTests(unittest.TestCase):
         self.test_file = None
         self.loader.load()
 
-    def test_load_from_client_checkrefs(self):
+    def Xtest_load_from_client_checkrefs(self):
         # A test to see the refcount is increased properly
         source_client = isc.datasrc.DataSourceClient('sqlite3',
                                                      DB_SOURCE_CLIENT_CONFIG)
@@ -146,10 +170,10 @@ class ZoneLoaderTests(unittest.TestCase):
         self.check_load_incremental()
 
     def test_load_from_client_incremental(self):
-        source_client = isc.datasrc.DataSourceClient('sqlite3',
-                                                     DB_SOURCE_CLIENT_CONFIG)
+        self.source_client = isc.datasrc.DataSourceClient('sqlite3',
+                                            DB_SOURCE_CLIENT_CONFIG)
         self.loader = isc.datasrc.ZoneLoader(self.client, self.test_name,
-                                             source_client)
+                                             self.source_client)
         self.check_load_incremental()
 
     def test_bad_file(self):
@@ -175,18 +199,18 @@ class ZoneLoaderTests(unittest.TestCase):
     def test_no_such_zone_in_source(self):
         # Reuse a zone that exists in target but not in source
         zone_name = isc.dns.Name("sql1.example.com")
-        source_client = isc.datasrc.DataSourceClient('sqlite3',
-                                                     DB_SOURCE_CLIENT_CONFIG)
+        self.source_client = isc.datasrc.DataSourceClient('sqlite3',
+                                            DB_SOURCE_CLIENT_CONFIG)
 
         # make sure the zone exists in the target
         found, _ = self.client.find_zone(zone_name)
         self.assertEqual(self.client.SUCCESS, found)
         # And that it does not in the source
-        found, _ = source_client.find_zone(zone_name)
-        self.assertNotEqual(source_client.SUCCESS, found)
+        found, _ = self.source_client.find_zone(zone_name)
+        self.assertNotEqual(self.source_client.SUCCESS, found)
 
         self.assertRaises(isc.datasrc.Error, isc.datasrc.ZoneLoader,
-                          self.client, zone_name, source_client)
+                          self.client, zone_name, self.source_client)
 
     def test_no_ds_load_support(self):
         # This may change in the future, but atm, the in-mem ds does
@@ -209,10 +233,11 @@ class ZoneLoaderTests(unittest.TestCase):
         clientlist = isc.datasrc.ConfigurableClientList(isc.dns.RRClass.CH())
         clientlist.configure('[ { "type": "static", "params": "' +
                              STATIC_ZONE_FILE +'" } ]', False)
-        source_client, _, _ = clientlist.find(isc.dns.Name("bind."),
-                                              False, False)
+        self.source_client, _, _ = clientlist.find(isc.dns.Name("bind."),
+                                                   False, False)
         self.assertRaises(isc.dns.InvalidParameter, isc.datasrc.ZoneLoader,
-                          self.client, isc.dns.Name("bind."), source_client)
+                          self.client, isc.dns.Name("bind."),
+                          self.source_client)
 
     def test_exception(self):
         # Just check if masterfileerror is subclass of datasrc.Error