Browse Source

[2379] loadIncremental wrapper

Jelte Jansen 12 years ago
parent
commit
75effbfc1d

+ 0 - 2
src/lib/python/isc/datasrc/tests/testdata/example.com

@@ -1,5 +1,3 @@
-;; This is the source of a zone stored in test.sqlite3.  It's provided
-;; for reference purposes only.
 example.com.         1000  IN  SOA a.dns.example.com. mail.example.com. 1 1 1 1 1
 example.com.         1000  IN  NS  a.dns.example.com.
 example.com.         1000  IN  NS  b.dns.example.com.

+ 47 - 12
src/lib/python/isc/datasrc/tests/zone_loader_test.py

@@ -18,25 +18,32 @@ import isc.dns
 
 import os
 import unittest
+import shutil
+
+# Constants and common data used in tests
 
 TESTDATA_PATH = os.environ['TESTDATA_PATH']
 TESTDATA_WRITE_PATH = os.environ['TESTDATA_WRITE_PATH']
 
 ZONE_FILE = TESTDATA_PATH + '/example.com'
 
+ORIG_DB_FILE = TESTDATA_PATH + '/example.com.sqlite3'
 DB_FILE = TESTDATA_WRITE_PATH + '/zoneloadertest.sqlite3'
 DB_CLIENT_CONFIG = '{ "database_file": "' + DB_FILE + '"}'
 
+ORIG_SOA_TXT = 'example.com. 3600 IN SOA master.example.com. ' +\
+               'admin.example.com. 1234 3600 1800 2419200 7200\n'
+NEW_SOA_TXT = 'example.com. 1000 IN SOA a.dns.example.com. ' +\
+              'mail.example.com. 1 1 1 1 1\n'
+
+
 class ZoneLoaderTests(unittest.TestCase):
     def setUp(self):
         self.test_name = isc.dns.Name("example.com")
         self.test_file = ZONE_FILE
         self.client = isc.datasrc.DataSourceClient("sqlite3", DB_CLIENT_CONFIG)
-
-    def tearDown(self):
-        # Delete the database after each test
-        if os.path.exists(DB_FILE):
-            os.unlink(DB_FILE)
+        # Make a fresh copy of the database
+        shutil.copy(ORIG_DB_FILE, DB_FILE)
 
     def test_bad_constructor(self):
         self.assertRaises(TypeError, isc.datasrc.ZoneLoader)
@@ -50,22 +57,50 @@ class ZoneLoaderTests(unittest.TestCase):
         self.assertRaises(TypeError, isc.datasrc.ZoneLoader,
                           self.client, self.test_name, self.test_file, 1)
 
+    def check_zone_soa(self, soa_txt):
+        """
+        Check that the given RRset exists and matches the expected string
+        """
+        result, finder = self.client.find_zone(self.test_name)
+        self.assertEqual(self.client.SUCCESS, result)
+        result, rrset, _ = finder.find(self.test_name, isc.dns.RRType.SOA())
+        self.assertEqual(finder.SUCCESS, result)
+        self.assertEqual(soa_txt, rrset.to_text())
+
     def test_load_file(self):
-        #self.assertRaises(TypeError, isc.datasrc.ZoneLoader());
-        result, _ = self.client.find_zone(self.test_name)
-        self.assertEqual(self.client.NOTFOUND, result)
+        self.check_zone_soa(ORIG_SOA_TXT)
 
         # Create loader and load the zone
         loader = isc.datasrc.ZoneLoader(self.client, self.test_name, self.test_file)
         loader.load()
-        # Not really checking content for now, just check the zone exists now
-        result, _ = self.client.find_zone(self.test_name)
-        self.assertEqual(self.client.SUCCESS, result)
+
+        self.check_zone_soa(NEW_SOA_TXT)
+
+    def test_load_incremental(self):
+        self.check_zone_soa(ORIG_SOA_TXT)
+
+        # Create loader and load the zone
+        loader = isc.datasrc.ZoneLoader(self.client, self.test_name, self.test_file)
+
+        # New zone has 8 RRs
+        # After 5, it should return False
+        self.assertFalse(loader.load_incremental(5))
+        # New zone should not have been loaded yet
+        self.check_zone_soa(ORIG_SOA_TXT)
+
+        # After 5 more, it should return True (only having read 3)
+        self.assertTrue(loader.load_incremental(5))
+        # New zone should now be loaded
+        self.check_zone_soa(NEW_SOA_TXT)
+
+        # And after that, it should throw
+        self.assertRaises(isc.datasrc.Error, loader.load_incremental, 5)
 
     def test_bad_file(self):
-        #self.assertRaises(TypeError, isc.datasrc.ZoneLoader());
+        self.check_zone_soa(ORIG_SOA_TXT)
         loader = isc.datasrc.ZoneLoader(self.client, self.test_name, "no such file")
         self.assertRaises(isc.datasrc.MasterFileError, loader.load)
+        self.check_zone_soa(ORIG_SOA_TXT)
 
     def test_exception(self):
         # Just check if masterfileerror is subclass of datasrc.Error

+ 33 - 0
src/lib/python/isc/datasrc/zone_loader_python.cc

@@ -118,6 +118,38 @@ PyObject* ZoneLoader_load(PyObject* po_self, PyObject*) {
     }
 }
 
+PyObject* ZoneLoader_loadIncremental(PyObject* po_self, PyObject* args) {
+    s_ZoneLoader* self = static_cast<s_ZoneLoader*>(po_self);
+
+    int limit;
+    if (!PyArg_ParseTuple(args, "i", &limit)) {
+        return (NULL);
+    }
+    if (limit < 0) {
+        PyErr_SetString(PyExc_ValueError,
+                        "load_incremental argument must be positive");
+        return (NULL);
+    }
+    try {
+        const bool complete = self->cppobj->loadIncremental(limit);
+        if (complete) {
+            Py_RETURN_TRUE;
+        } else {
+            Py_RETURN_FALSE;
+        }
+    } catch (const isc::datasrc::MasterFileError& mfe) {
+        PyErr_SetString(getDataSourceException("MasterFileError"), mfe.what());
+        return (NULL);
+    } catch (const std::exception& exc) {
+        PyErr_SetString(getDataSourceException("Error"), exc.what());
+        return (NULL);
+    } catch (...) {
+        PyErr_SetString(getDataSourceException("Error"),
+                        "Unexpected exception");
+        return (NULL);
+    }
+}
+
 // This list contains the actual set of functions we have in
 // python. Each entry has
 // 1. Python method name
@@ -133,6 +165,7 @@ PyMethodDef ZoneLoader_methods[] = {
     { "find_all", ZoneLoader_find_all, METH_VARARGS, ZoneLoader_findAll_doc },
 */
     { "load", ZoneLoader_load, METH_NOARGS, ZoneLoader_load_doc },
+    { "load_incremental", ZoneLoader_loadIncremental, METH_VARARGS, ZoneLoader_loadIncremental_doc },
     { NULL, NULL, 0, NULL }
 };