Parcourir la source

[trac985] use try-except around all sqlite-related calls

Jelte Jansen il y a 14 ans
Parent
commit
14153d9340

+ 20 - 20
src/lib/python/isc/datasrc/sqlite3_ds.py

@@ -235,13 +235,13 @@ def load(dbfile, zone, reader):
         zone += '.'
 
     conn, cur = open(dbfile)
-    old_zone_id = get_zoneid(zone, cur)
-
-    temp = str(random.randrange(100000))
-    cur.execute("INSERT INTO zones (name, rdclass) VALUES (?, 'IN')", [temp])
-    new_zone_id = cur.lastrowid
-
     try:
+        old_zone_id = get_zoneid(zone, cur)
+    
+        temp = str(random.randrange(100000))
+        cur.execute("INSERT INTO zones (name, rdclass) VALUES (?, 'IN')", [temp])
+        new_zone_id = cur.lastrowid
+
         for name, ttl, rdclass, rdtype, rdata in reader():
             sigtype = ''
             if rdtype.lower() == 'rrsig':
@@ -266,20 +266,20 @@ def load(dbfile, zone, reader):
                                VALUES (?, ?, ?, ?, ?, ?)""",
                             [new_zone_id, name, reverse_name(name), ttl,
                              rdtype, rdata])
+
+        if old_zone_id:
+            cur.execute("DELETE FROM zones WHERE id=?", [old_zone_id])
+            cur.execute("UPDATE zones SET name=? WHERE id=?", [zone, new_zone_id])
+            conn.commit()
+            cur.execute("DELETE FROM records WHERE zone_id=?", [old_zone_id])
+            cur.execute("DELETE FROM nsec3 WHERE zone_id=?", [old_zone_id])
+            conn.commit()
+        else:
+            cur.execute("UPDATE zones SET name=? WHERE id=?", [zone, new_zone_id])
+            conn.commit()
     except Exception as e:
         fail = "Error while loading " + zone + ": " + e.args[0]
         raise Sqlite3DSError(fail)
-
-    if old_zone_id:
-        cur.execute("DELETE FROM zones WHERE id=?", [old_zone_id])
-        cur.execute("UPDATE zones SET name=? WHERE id=?", [zone, new_zone_id])
-        conn.commit()
-        cur.execute("DELETE FROM records WHERE zone_id=?", [old_zone_id])
-        cur.execute("DELETE FROM nsec3 WHERE zone_id=?", [old_zone_id])
-        conn.commit()
-    else:
-        cur.execute("UPDATE zones SET name=? WHERE id=?", [zone, new_zone_id])
-        conn.commit()
-
-    cur.close()
-    conn.close()
+    finally:
+        cur.close()
+        conn.close()

+ 1 - 0
src/lib/python/isc/datasrc/tests/Makefile.am

@@ -16,5 +16,6 @@ endif
 	echo Running test: $$pytest ; \
 	env PYTHONPATH=$(abs_top_srcdir)/src/lib/python:$(abs_top_builddir)/src/lib/python:$(abs_top_builddir)/src/lib/python/isc/log \
 	TESTDATA_PATH=$(abs_srcdir)/testdata \
+	TESTDATA_WRITE_PATH=$(abs_builddir) \
 	$(PYCOVERAGE_RUN) $(abs_srcdir)/$$pytest || exit ; \
 	done

+ 55 - 3
src/lib/python/isc/datasrc/tests/sqlite3_ds_test.py

@@ -17,8 +17,31 @@ from isc.datasrc import sqlite3_ds
 import os
 import socket
 import unittest
+import sqlite3
 
 TESTDATA_PATH = os.environ['TESTDATA_PATH'] + os.sep
+TESTDATA_WRITE_PATH = os.environ['TESTDATA_WRITE_PATH'] + os.sep
+
+READ_ZONE_DB_FILE = TESTDATA_PATH + "example.com.sqlite3"
+WRITE_ZONE_DB_FILE = TESTDATA_WRITE_PATH + "example.com.out.sqlite3"
+BROKEN_DB_FILE = TESTDATA_PATH + "brokendb.sqlite3"
+
+def example_reader():
+    my_zone = [
+        ("example.com.",    "3600",    "IN",  "SOA", "ns.example.com. admin.example.com. 1234 3600 1800 2419200 7200"),
+        ("example.com.",    "3600",    "IN",  "NS", "ns.example.com."),
+        ("ns.example.com.", "3600",    "IN",  "A", "192.0.2.1")
+    ]
+    for rr in my_zone:
+        yield rr
+
+def example_reader_nested():
+    # this iterator is used in the 'locked' test; it will cause
+    # the load() method to try and write to the same database
+    sqlite3_ds.load(WRITE_ZONE_DB_FILE,
+                    ".",
+                    example_reader)
+    return example_reader()
 
 class TestSqlite3_ds(unittest.TestCase):
     def test_zone_exist(self):
@@ -33,11 +56,40 @@ class TestSqlite3_ds(unittest.TestCase):
         # Open a broken database file
         self.assertRaises(sqlite3_ds.Sqlite3DSError,
                           sqlite3_ds.zone_exist, "example.com",
-                          TESTDATA_PATH + "brokendb.sqlite3")
+                          BROKEN_DB_FILE)
         self.assertTrue(sqlite3_ds.zone_exist("example.com.",
-                            TESTDATA_PATH + "example.com.sqlite3"))
+                        READ_ZONE_DB_FILE))
         self.assertFalse(sqlite3_ds.zone_exist("example.org.",
-                            TESTDATA_PATH + "example.com.sqlite3"))
+                         READ_ZONE_DB_FILE))
+
+    def test_load_db(self):
+        sqlite3_ds.load(WRITE_ZONE_DB_FILE, ".", example_reader)
+
+    def test_locked_db(self):
+        # load it first to make sure it exists
+        sqlite3_ds.load(WRITE_ZONE_DB_FILE, ".", example_reader)
+
+        # and manually create a writing session as well
+        con = sqlite3.connect(WRITE_ZONE_DB_FILE);
+        cur = con.cursor()
+        cur.execute("delete from records")
+        
+        self.assertRaises(sqlite3_ds.Sqlite3DSError,
+                          sqlite3_ds.load, WRITE_ZONE_DB_FILE, ".",
+                          example_reader)
+
+        con.rollback()
+
+        # and make sure lock does not stay
+        sqlite3_ds.load(WRITE_ZONE_DB_FILE, ".", example_reader)
+
+        # force locked db by nested loads
+        self.assertRaises(sqlite3_ds.Sqlite3DSError,
+                          sqlite3_ds.load, WRITE_ZONE_DB_FILE, ".",
+                          example_reader_nested)
+
+        # and make sure lock does not stay
+        sqlite3_ds.load(WRITE_ZONE_DB_FILE, ".", example_reader)
 
 if __name__ == '__main__':
     unittest.main()

BIN
src/lib/python/isc/datasrc/tests/testdata/example.com.sqlite3