Browse Source

[1333u] supported 'journal' parameter for DataSourceClient.get_updater().
updated the test data schema for the diffs table.

JINMEI Tatuya 13 years ago
parent
commit
b449ad20a4

+ 17 - 5
src/lib/python/isc/datasrc/client_python.cc

@@ -129,14 +129,17 @@ PyObject*
 DataSourceClient_getUpdater(PyObject* po_self, PyObject* args) {
     s_DataSourceClient* const self = static_cast<s_DataSourceClient*>(po_self);
     PyObject *name_obj;
-    PyObject *replace_obj;
-    if (PyArg_ParseTuple(args, "O!O", &name_type, &name_obj, &replace_obj) &&
-        PyBool_Check(replace_obj)) {
-        bool replace = (replace_obj != Py_False);
+    PyObject *replace_obj = NULL;
+    PyObject *journaling_obj = Py_False;
+    if (PyArg_ParseTuple(args, "O!O|O", &name_type, &name_obj,
+                         &replace_obj, &journaling_obj) &&
+        PyBool_Check(replace_obj) && PyBool_Check(journaling_obj)) {
+        const bool replace = (replace_obj != Py_False);
+        const bool journaling = (journaling_obj == Py_True);
         try {
             ZoneUpdaterPtr updater =
                 self->cppobj->getInstance().getUpdater(PyName_ToName(name_obj),
-                                                       replace);
+                                                       replace, journaling);
             if (!updater) {
                 return (Py_None);
             }
@@ -157,6 +160,15 @@ DataSourceClient_getUpdater(PyObject* po_self, PyObject* args) {
             return (NULL);
         }
     } else {
+        // PyBool_Check doesn't set the error, so we have to set it ourselves.
+        if (replace_obj != NULL && !PyBool_Check(replace_obj)) {
+            PyErr_SetString(PyExc_TypeError, "'replace' for "
+                            "DataSourceClient.get_updater must be boolean");
+        }
+        if (!PyBool_Check(journaling_obj)) {
+            PyErr_SetString(PyExc_TypeError, "'journaling' for "
+                            "DataSourceClient.get_updater must be boolean");
+        }
         return (NULL);
     }
 }

+ 121 - 1
src/lib/python/isc/datasrc/tests/datasrc_test.py

@@ -16,8 +16,9 @@
 import isc.log
 import isc.datasrc
 from isc.datasrc import ZoneFinder
-import isc.dns
+from isc.dns import *
 import unittest
+import sqlite3
 import os
 import shutil
 import sys
@@ -565,6 +566,125 @@ class DataSrcUpdater(unittest.TestCase):
         self.assertEqual(None, iterator.get_soa())
         self.assertEqual(None, iterator.get_next_rrset())
 
+class JournalWrite(unittest.TestCase):
+    def setUp(self):
+        # Make a fresh copy of the writable database with all original content
+        shutil.copyfile(READ_ZONE_DB_FILE, WRITE_ZONE_DB_FILE)
+        self.dsc = isc.datasrc.DataSourceClient("sqlite3",
+                                                WRITE_ZONE_DB_CONFIG)
+        self.updater = self.dsc.get_updater(Name("example.com"), False, True)
+
+    def tearDown(self):
+        self.dsc = None
+        self.updater = None
+
+    def check_journal(self, expected_list):
+        # This assumes sqlite3 DB and directly fetches stored data from
+        # the DB file.  It should be generalized using ZoneJournalReader
+        # once it's supported.
+        conn = sqlite3.connect(WRITE_ZONE_DB_FILE)
+        cur = conn.cursor()
+        cur.execute('SELECT name, rrtype, ttl, rdata FROM diffs ORDER BY id')
+        actual_list = cur.fetchall()
+        self.assertEqual(len(expected_list), len(actual_list))
+        for (expected, actual) in zip(expected_list, actual_list):
+            self.assertEqual(expected, actual)
+        conn.close()
+
+    def create_soa(self, serial):
+        soa = RRset(Name('example.org'), RRClass.IN(), RRType.SOA(),
+                    RRTTL(3600))
+        soa.add_rdata(Rdata(RRType.SOA(), RRClass.IN(),
+                            'ns1.example.org. admin.example.org. ' +
+                            str(serial) + ' 3600 1800 2419200 7200'))
+        return soa
+
+    def create_a(self, address):
+        a_rr = RRset(Name('www.example.org'), RRClass.IN(), RRType.A(),
+                     RRTTL(3600))
+        a_rr.add_rdata(Rdata(RRType.A(), RRClass.IN(), address))
+        return (a_rr)
+
+    def test_journal_write(self):
+        # This is a straightforward port of the C++ 'journal' test
+        # Note: we add/delete 'out of zone' data (example.org in the
+        # example.com zone for convenience.
+        self.updater.delete_rrset(self.create_soa(1234))
+        self.updater.delete_rrset(self.create_a('192.0.2.2'))
+        self.updater.add_rrset(self.create_soa(1235))
+        self.updater.add_rrset(self.create_a('192.0.2.2'))
+        self.updater.commit()
+
+        expected = []
+        expected.append(("example.org.", "SOA", 3600,
+                         "ns1.example.org. admin.example.org. " +
+                         "1234 3600 1800 2419200 7200"))
+        expected.append(("www.example.org.", "A", 3600, "192.0.2.2"))
+        expected.append(("example.org.", "SOA", 3600,
+                         "ns1.example.org. admin.example.org. " +
+                         "1235 3600 1800 2419200 7200"))
+        expected.append(("www.example.org.", "A", 3600, "192.0.2.2"))
+        self.check_journal(expected)
+
+    def test_journal_write_multiple(self):
+        # This is a straightforward port of the C++ 'journalMultiple' test
+        expected = []
+        for i in range(1, 100):
+            self.updater.delete_rrset(self.create_soa(1234 + i - 1))
+            expected.append(("example.org.", "SOA", 3600,
+                             "ns1.example.org. admin.example.org. " +
+                             str(1234 + i - 1) + " 3600 1800 2419200 7200"))
+            self.updater.add_rrset(self.create_soa(1234 + i))
+            expected.append(("example.org.", "SOA", 3600,
+                             "ns1.example.org. admin.example.org. " +
+                             str(1234 + i) + " 3600 1800 2419200 7200"))
+        self.updater.commit()
+        self.check_journal(expected)
+
+    def test_journal_write_bad_sequence(self):
+        # This is a straightforward port of the C++ 'journalBadSequence' test
+
+        # Delete A before SOA
+        self.assertRaises(isc.datasrc.Error, self.updater.delete_rrset,
+                          self.create_a('192.0.2.1'))
+        # Add before delete
+        self.updater = self.dsc.get_updater(Name("example.com"), False, True)
+        self.assertRaises(isc.datasrc.Error, self.updater.add_rrset,
+                          self.create_soa(1234))
+        # Add A before SOA
+        self.updater = self.dsc.get_updater(Name("example.com"), False, True)
+        self.updater.delete_rrset(self.create_soa(1234))
+        self.assertRaises(isc.datasrc.Error, self.updater.add_rrset,
+                          self.create_a('192.0.2.1'))
+        # Commit before add
+        self.updater = self.dsc.get_updater(Name("example.com"), False, True)
+        self.updater.delete_rrset(self.create_soa(1234))
+        self.assertRaises(isc.datasrc.Error, self.updater.commit)
+        # Delete two SOAs
+        self.updater = self.dsc.get_updater(Name("example.com"), False, True)
+        self.updater.delete_rrset(self.create_soa(1234))
+        self.assertRaises(isc.datasrc.Error, self.updater.delete_rrset,
+                          self.create_soa(1235))
+        # Add two SOAs
+        self.updater = self.dsc.get_updater(Name("example.com"), False, True)
+        self.updater.delete_rrset(self.create_soa(1234))
+        self.updater.add_rrset(self.create_soa(1235))
+        self.assertRaises(isc.datasrc.Error, self.updater.add_rrset,
+                          self.create_soa(1236))
+
+    def test_journal_write_onerase(self):
+        self.updater = None
+        self.assertRaises(isc.datasrc.Error, self.dsc.get_updater,
+                          Name("example.com"), True, True)
+
+    def test_journal_write_badparam(self):
+        dsc = isc.datasrc.DataSourceClient("sqlite3", WRITE_ZONE_DB_CONFIG)
+        self.assertRaises(TypeError, dsc.get_updater, 0, False, True)
+        self.assertRaises(TypeError, dsc.get_updater, Name('example.com'),
+                          False, 0)
+        self.assertRaises(TypeError, dsc.get_updater, Name("example.com"),
+                          1, True)
+
 if __name__ == "__main__":
     isc.log.init("bind10")
     isc.log.resetUnitTestRootLogger()

BIN
src/lib/python/isc/datasrc/tests/testdata/example.com.sqlite3