Browse Source

[1288] switch to new API 1: in pre-handle check create iterator and get SOA
using the new API. _zone_has_soa() and _zone_exit() were not needed anymore,
so were removed, so were corresponding tests.

JINMEI Tatuya 13 years ago
parent
commit
d1773b2ef6
2 changed files with 61 additions and 74 deletions
  1. 29 41
      src/bin/xfrout/tests/xfrout_test.py.in
  2. 32 33
      src/bin/xfrout/xfrout.py.in

+ 29 - 41
src/bin/xfrout/tests/xfrout_test.py.in

@@ -21,7 +21,7 @@ import os
 from isc.testutils.tsigctx_mock import MockTSIGContext
 from isc.cc.session import *
 import isc.config
-from pydnspp import *
+from isc.dns import *
 from xfrout import *
 import xfrout
 import isc.log
@@ -81,7 +81,7 @@ class Dbserver:
     def __init__(self):
         self._shutdown_event = threading.Event()
     def get_db_file(self):
-        return None
+        return 'test.sqlite3'
     def increase_transfers_counter(self):
         return True
     def decrease_transfers_counter(self):
@@ -511,46 +511,34 @@ class TestXfroutSession(unittest.TestCase):
         rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
         self.assertEqual(82, get_rrset_len(rrset_soa))
 
-    def test_zone_has_soa(self):
-        global sqlite3_ds
-        def mydb1(zone, file):
-            return True
-        sqlite3_ds.get_zone_soa = mydb1
-        self.assertTrue(self.xfrsess._zone_has_soa(""))
-        def mydb2(zone, file):
-            return False
-        sqlite3_ds.get_zone_soa = mydb2
-        self.assertFalse(self.xfrsess._zone_has_soa(""))
-
-    def test_zone_exist(self):
-        global sqlite3_ds
-        def zone_exist(zone, file):
-            return zone
-        sqlite3_ds.zone_exist = zone_exist
-        self.assertTrue(self.xfrsess._zone_exist(True))
-        self.assertFalse(self.xfrsess._zone_exist(False))
-
     def test_check_xfrout_available(self):
-        def zone_exist(zone):
-            return zone
-        def zone_has_soa(zone):
-            return (not zone)
-        self.xfrsess._zone_exist = zone_exist
-        self.xfrsess._zone_has_soa = zone_has_soa
-        self.assertEqual(self.xfrsess._check_xfrout_available(False).to_text(), "NOTAUTH")
-        self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(), "SERVFAIL")
-
-        def zone_empty(zone):
-            return zone
-        self.xfrsess._zone_has_soa = zone_empty
-        def false_func():
-            return False
-        self.xfrsess._server.increase_transfers_counter = false_func
-        self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(), "REFUSED")
-        def true_func():
-            return True
-        self.xfrsess._server.increase_transfers_counter = true_func
-        self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(), "NOERROR")
+        class MockDataSrcClient:
+            def __init__(self, type, config): pass
+
+            def get_iterator(self, zone_name):
+                if zone_name == Name('notauth.example.com'):
+                    raise isc.datasrc.Error('no such zone')
+                self._zone_name = zone_name
+                return self
+
+            def get_soa(self):  # emulate ZoneIterator.get_soa()
+                if self._zone_name == Name('nosoa.example.com'):
+                    return None
+                # returning True on success is wrong, but works for this test.
+                return True
+
+        self.xfrsess.ClientClass = MockDataSrcClient
+        self.assertEqual(self.xfrsess._check_xfrout_available(
+                Name('notauth.example.com')), Rcode.NOTAUTH())
+        self.assertEqual(self.xfrsess._check_xfrout_available(
+                Name('nosoa.example.com')), Rcode.SERVFAIL())
+
+        self.xfrsess._server.increase_transfers_counter = lambda : False
+        self.assertEqual(self.xfrsess._check_xfrout_available(
+                Name('example.com')), Rcode.REFUSED())
+        self.xfrsess._server.increase_transfers_counter = lambda : True
+        self.assertEqual(self.xfrsess._check_xfrout_available(
+                Name('example.com')), Rcode.NOERROR())
 
     def test_dns_xfrout_start_formerror(self):
         # formerror

+ 32 - 33
src/bin/xfrout/xfrout.py.in

@@ -22,7 +22,8 @@ import isc.cc
 import threading
 import struct
 import signal
-from isc.datasrc import sqlite3_ds
+from isc.datasrc import sqlite3_ds # should be obsoleted
+from isc.datasrc import DataSourceClient, ZoneFinder
 from socketserver import *
 import os
 from isc.config.ccsession import *
@@ -106,7 +107,7 @@ def get_rrset_len(rrset):
 
 class XfroutSession():
     def __init__(self, sock_fd, request_data, server, tsig_key_ring, remote,
-                 default_acl, zone_config):
+                 default_acl, zone_config, client_class=DataSourceClient):
         self._sock_fd = sock_fd
         self._request_data = request_data
         self._server = server
@@ -116,6 +117,7 @@ class XfroutSession():
         self._remote = remote
         self._acl = default_acl
         self._zone_config = zone_config
+        self.ClientClass = client_class # parameterize this for testing
         self.handle()
 
     def create_tsig_ctx(self, tsig_record, tsig_key_ring):
@@ -238,27 +240,6 @@ class XfroutSession():
         msg.set_rcode(rcode_)
         self._send_message(sock_fd, msg, self._tsig_ctx)
 
-    def _zone_has_soa(self, zone):
-        '''Judge if the zone has an SOA record.'''
-        # In some sense, the SOA defines a zone.
-        # If the current name server has authority for the
-        # specific zone, we need to judge if the zone has an SOA record;
-        # if not, we consider the zone has incomplete data, so xfrout can't
-        # serve for it.
-        if sqlite3_ds.get_zone_soa(zone, self._server.get_db_file()):
-            return True
-
-        return False
-
-    def _zone_exist(self, zonename):
-        '''Judge if the zone is configured by config manager.'''
-        # Currently, if we find the zone in datasource successfully, we
-        # consider the zone is configured, and the current name server has
-        # authority for the specific zone.
-        # TODO: should get zone's configuration from cfgmgr or other place
-        # in future.
-        return sqlite3_ds.zone_exist(zonename, self._server.get_db_file())
-
     def _check_xfrout_available(self, zone_name):
         '''Check if xfr request can be responsed.
            TODO, Get zone's configuration from cfgmgr or some other place
@@ -270,15 +251,32 @@ class XfroutSession():
         if not self._server.increase_transfers_counter():
             return Rcode.REFUSED()
 
-        # If the current name server does not have authority for the
-        # zone, xfrout can't serve for it, return rcode NOTAUTH.
-        if not self._zone_exist(zone_name):
+        # Identify the data source for the requested zone and see if it has
+        # SOA while initializing objects used for request processing later.
+        # We should eventually generalize this so that choose the appropriate
+        # data source from (possible) multiple candidates.  We should
+        # eventually take into account the RR class here.
+        # For now, we  hardcode a particular type (SQLite3-based), and only
+        # consider that one.
+        datasrc_config = '{ \"database_file\": \"' + \
+            self._server.get_db_file() + '\"}'
+        self._datasrc_client = self.ClientClass('sqlite3', datasrc_config)
+        try:
+            self._iterator = self._datasrc_client.get_iterator(zone_name)
+        except isc.datasrc.Error as error:
+            # If the current name server does not have authority for the
+            # zone, xfrout can't serve for it, return rcode NOTAUTH.
+            # Note: this exception can happen for other reasons.  We should
+            # update get_iterator() API so that we can distinguish "no such
+            # zone" and other cases.  For now we consider all these cases
+            # as NOTAUTH.
             return Rcode.NOTAUTH()
 
         # If we are an authoritative name server for the zone, but fail
         # to find the zone's SOA record in datasource, xfrout can't
         # provide zone transfer for it.
-        if not self._zone_has_soa(zone_name):
+        self._soa = self._iterator.get_soa()
+        if self._soa is None:
             return Rcode.SERVFAIL()
 
         #TODO, check allow_transfer
@@ -297,24 +295,25 @@ class XfroutSession():
             return self._reply_query_with_error_rcode(msg, sock_fd,
                                                       Rcode.FORMERR())
 
-        zone_name = self._get_query_zone_name(msg)
+        zone_name = msg.get_question()[0].get_name()
         zone_class_str = self._get_query_zone_class(msg)
         # TODO: should we not also include class in the check?
         rcode_ = self._check_xfrout_available(zone_name)
 
         if rcode_ != Rcode.NOERROR():
-            logger.info(XFROUT_AXFR_TRANSFER_FAILED, zone_name,
+            logger.info(XFROUT_AXFR_TRANSFER_FAILED, zone_name.to_text(),
                         zone_class_str, rcode_.to_text())
             return self._reply_query_with_error_rcode(msg, sock_fd, rcode_)
 
         try:
-            logger.info(XFROUT_AXFR_TRANSFER_STARTED, zone_name, zone_class_str)
-            self._reply_xfrout_query(msg, sock_fd, zone_name)
+            logger.info(XFROUT_AXFR_TRANSFER_STARTED, zone_name.to_text(),
+                        zone_class_str)
+            self._reply_xfrout_query(msg, sock_fd, zone_name.to_text())
         except Exception as err:
-            logger.error(XFROUT_AXFR_TRANSFER_ERROR, zone_name,
+            logger.error(XFROUT_AXFR_TRANSFER_ERROR, zone_name.to_text(),
                          zone_class_str, str(err))
             pass
-        logger.info(XFROUT_AXFR_TRANSFER_DONE, zone_name, zone_class_str)
+        logger.info(XFROUT_AXFR_TRANSFER_DONE, zone_name.to_text(), zone_class_str)
 
         self._server.decrease_transfers_counter()
         return