Browse Source

update according to review comments

chenzhengzhang 14 years ago
parent
commit
3f797298c0

+ 10 - 8
src/bin/xfrout/tests/xfrout_test.py

@@ -215,36 +215,38 @@ class TestXfroutSession(unittest.TestCase):
         rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
         self.assertEqual(82, get_rrset_len(rrset_soa))
 
-    def test_zone_is_empty(self):
+    def test_zone_has_soa(self):
         global sqlite3_ds
         def mydb1(zone, file):
             return True
         sqlite3_ds.get_zone_soa = mydb1
-        self.assertEqual(self.xfrsess._zone_is_empty(""), False)
+        self.assertTrue(self.xfrsess._zone_has_soa(""))
         def mydb2(zone, file):
             return False
         sqlite3_ds.get_zone_soa = mydb2
-        self.assertEqual(self.xfrsess._zone_is_empty(""), True)
+        self.assertFalse(self.xfrsess._zone_has_soa(""))
 
     def test_zone_exist(self):
         global sqlite3_ds
         def zone_exist(zone, file):
             return zone
         sqlite3_ds.zone_exist = zone_exist
-        self.assertEqual(self.xfrsess._zone_exist(True), True)
-        self.assertEqual(self.xfrsess._zone_exist(False), False)
+        self.assertTrue(self.xfrsess._zone_exist(True))
+        self.assertFalse(self.xfrsess._zone_exist(False))
 
     def test_check_xfrout_available(self):
         def zone_exist(zone):
             return zone
+        def zone_has_soa(zone):
+            return (not zone)
         self.xfrsess._zone_exist = zone_exist
-        self.xfrsess._zone_is_empty = zone_exist
+        self.xfrsess._zone_has_soa = zone_has_soa
         self.assertEqual(self.xfrsess._check_xfrout_available(False).to_text(), "NOTAUTH")
         self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(), "SERVFAIL")
 
         def zone_empty(zone):
-            return not zone
-        self.xfrsess._zone_is_empty = zone_empty
+            return zone
+        self.xfrsess._zone_has_soa = zone_empty
         def false_func():
             return False
         self.xfrsess.server.increase_transfers_counter = false_func

+ 16 - 9
src/bin/xfrout/xfrout.py.in

@@ -193,17 +193,22 @@ class XfroutSession(BaseRequestHandler):
         self._send_message(sock_fd, msg)
 
 
-    def _zone_is_empty(self, zone):
-        '''Judge if the zone has data.'''
+    def _zone_has_soa(self, zone):
+        '''Judge if the zone has soa records.'''
+        # In some sense, the soa defines a zone.
+        # If the current NS is the authoritative NS for the specific zone,
+        # we need to judge if the zone has soa records, if not, we consider
+        # the zone has incomplete data, so xfrout can't serve for it.
         if sqlite3_ds.get_zone_soa(zone, self.server.get_db_file()):
-            return False
+            return True
 
-        return True
+        return False
 
     def _zone_exist(self, zonename):
         '''Judge if the zone is configured by config manager.'''
         # Currently, if we find the zone in datasource successfully, we
-        # consider the zone is configured.
+        # consider the zone is configured, and the current NS are the
+        # authoritative NS for the specific zone.
         # TODO: should get zone's configuration from cfgmgr or other place
         # in future.
         return sqlite3_ds.zone_exist(zonename, self.server.get_db_file())
@@ -213,13 +218,15 @@ class XfroutSession(BaseRequestHandler):
            TODO, Get zone's configuration from cfgmgr or some other place
            eg. check allow_transfer setting,
         '''
-        # The zone isn't configured by config manager, so we are not the
-        # authoritative name server for it.
+        # If the current NS isn't the authoritative name server for the
+        # zone, xfrout can't serve for it, return rcode NOTAUTH.
         if not self._zone_exist(zone_name):
             return Rcode.NOTAUTH()
 
-        # The zone is configured but zone data is empty.
-        if self._zone_is_empty(zone_name):
+        # If we are the authoritative name server for the zone, but fail
+        # to find the zone's soa record in datasource, xfrout can't
+        # provide zone transfer for it.
+        if not self._zone_has_soa(zone_name):
             return Rcode.SERVFAIL()
 
         #TODO, check allow_transfer

+ 27 - 15
src/lib/python/isc/datasrc/sqlite3_ds.py

@@ -38,12 +38,12 @@ def create(cur):
     """Create new zone database"""
     cur.execute("CREATE TABLE schema_version (version INTEGER NOT NULL)")
     cur.execute("INSERT INTO schema_version VALUES (1)")
-    cur.execute("""CREATE TABLE zones (id INTEGER PRIMARY KEY, 
+    cur.execute("""CREATE TABLE zones (id INTEGER PRIMARY KEY,
                    name STRING NOT NULL COLLATE NOCASE,
-                   rdclass STRING NOT NULL COLLATE NOCASE DEFAULT 'IN', 
+                   rdclass STRING NOT NULL COLLATE NOCASE DEFAULT 'IN',
                    dnssec BOOLEAN NOT NULL DEFAULT 0)""")
     cur.execute("CREATE INDEX zones_byname ON zones (name)")
-    cur.execute("""CREATE TABLE records (id INTEGER PRIMARY KEY, 
+    cur.execute("""CREATE TABLE records (id INTEGER PRIMARY KEY,
                    zone_id INTEGER NOT NULL,
                    name STRING NOT NULL COLLATE NOCASE,
                    rname STRING NOT NULL COLLATE NOCASE,
@@ -53,7 +53,7 @@ def create(cur):
                    rdata STRING NOT NULL)""")
     cur.execute("CREATE INDEX records_byname ON records (name)")
     cur.execute("CREATE INDEX records_byrname ON records (rname)")
-    cur.execute("""CREATE TABLE nsec3 (id INTEGER PRIMARY KEY, 
+    cur.execute("""CREATE TABLE nsec3 (id INTEGER PRIMARY KEY,
                    zone_id INTEGER NOT NULL,
                    hash STRING NOT NULL COLLATE NOCASE,
                    owner STRING NOT NULL COLLATE NOCASE,
@@ -63,7 +63,7 @@ def create(cur):
     cur.execute("CREATE INDEX nsec3_byhash ON nsec3 (hash)")
 
 #########################################################################
-# open: open a database.  if the database is not yet set up, 
+# open: open a database.  if the database is not yet set up,
 # call create to do so.
 # input:
 #   dbfile - the filename for the sqlite3 database
@@ -72,7 +72,7 @@ def create(cur):
 #########################################################################
 def open(dbfile):
     """Open the database file.  If necessary, set it up"""
-    try: 
+    try:
         conn = sqlite3.connect(dbfile)
         cur = conn.cursor()
     except Exception as e:
@@ -93,9 +93,10 @@ def open(dbfile):
 
     return conn, cur
 
+
 #########################################################################
 # get_zone_datas
-#   a generator function producing an iterable set of 
+#   a generator function producing an iterable set of
 #   the records in the zone with the given zone name.
 #########################################################################
 def get_zone_datas(zonename, dbfile):
@@ -114,8 +115,8 @@ def get_zone_datas(zonename, dbfile):
 
 #########################################################################
 # get_zone_soa
-#   returns the soa record of the zone with the given zone name. 
-#   If the zone doesn't exist, return None. 
+#   returns the soa record of the zone with the given zone name.
+#   If the zone doesn't exist, return None.
 #########################################################################
 def get_zone_soa(zonename, dbfile):
     conn, cur = open(dbfile)
@@ -130,14 +131,14 @@ def get_zone_soa(zonename, dbfile):
 
 #########################################################################
 # get_zone_rrset
-#   returns the rrset of the zone with the given zone name, rrset name 
-#   and given rd type. 
-#   If the zone doesn't exist or rd type doesn't exist, return an empty list. 
+#   returns the rrset of the zone with the given zone name, rrset name
+#   and given rd type.
+#   If the zone doesn't exist or rd type doesn't exist, return an empty list.
 #########################################################################
 def get_zone_rrset(zonename, rr_name, rdtype, dbfile):
     conn, cur = open(dbfile)
     id = get_zoneid(zonename, cur)
-    cur.execute("SELECT * FROM records WHERE name = ? and zone_id = ? and rdtype = ?", 
+    cur.execute("SELECT * FROM records WHERE name = ? and zone_id = ? and rdtype = ?",
                 [rr_name, id, rdtype])
     datas = cur.fetchall()
     cur.close()
@@ -160,6 +161,7 @@ def get_zones_info(db_file):
     cur.close()
     conn.close()
 
+
 #########################################################################
 # get_zoneid:
 #   returns the zone_id for a given zone name, or an empty
@@ -173,12 +175,20 @@ def get_zoneid(zone, cur):
     else:
         return ''
 
+
 #########################################################################
 # zone_exist:
-#   returns True if the zone is found, otherwise False
+#   Search for the zone with the name zonename in databse. This method
+#   may throw a exception because its underlying methods open() may
+#   throw exceptions.
+# input:
+#   zonename: the zone's origin name.
+#   dbfile: the filename for the sqlite3 database.
+# returns:
+#   returns True if the zone is found, otherwise False.
 #########################################################################
 def zone_exist(zonename, dbfile):
-    conn, cur = open(db_file)
+    conn, cur = open(dbfile)
     zoneid = get_zoneid(zonename, cur)
     cur.close()
     conn.close()
@@ -186,6 +196,7 @@ def zone_exist(zonename, dbfile):
         return True
     return False
 
+
 #########################################################################
 # reverse_name:
 #   reverse the labels of a DNS name.  (for example,
@@ -201,6 +212,7 @@ def reverse_name(name):
         new.pop(0)
     return '.'.join(new)+'.'
 
+
 #########################################################################
 # load:
 #   load a zone into the SQL database.

+ 1 - 1
src/lib/python/isc/datasrc/tests/Makefile.am

@@ -1,5 +1,5 @@
 PYCOVERAGE_RUN = @PYCOVERAGE_RUN@
-PYTESTS = master_test.py
+PYTESTS = master_test.py sqlite3_ds_test.py
 EXTRA_DIST = $(PYTESTS)
 
 # test using command-line arguments, so use check-local target instead of TESTS

+ 33 - 0
src/lib/python/isc/datasrc/tests/sqlite3_ds_test.py

@@ -0,0 +1,33 @@
+# Copyright (C) 2010  Internet Systems Consortium.
+#
+# Permission to use, copy, modify, and distribute this software for any
+# purpose with or without fee is hereby granted, provided that the above
+# copyright notice and this permission notice appear in all copies.
+#
+# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
+# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
+# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
+# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
+# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
+# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
+# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+
+from isc.datasrc import sqlite3_ds
+import socket
+import unittest
+
+class TestSqlite3_ds(unittest.TestCase):
+    def test_zone_exist(self):
+        def open(db_file):
+            conn, cur = socket.socketpair()
+            return conn, cur
+        def get_zoneid(zone_name, cur):
+            return zone_name
+        sqlite3_ds.open = open
+        sqlite3_ds.get_zoneid = get_zoneid
+        self.assertTrue(sqlite3_ds.zone_exist("example.com", "sqlite3_db"))
+        self.assertFalse(sqlite3_ds.zone_exist("", "sqlite3_db"))
+
+if __name__ == '__main__':
+    unittest.main()