Browse Source

[1288] switch to new API 1: in pre-handle check create iterator and get SOA
using the new API. _zone_has_soa() and _zone_exit() were not needed anymore,
so were removed, so were corresponding tests.

JINMEI Tatuya 13 years ago
parent
commit
d1773b2ef6
2 changed files with 61 additions and 74 deletions
  1. 29 41
      src/bin/xfrout/tests/xfrout_test.py.in
  2. 32 33
      src/bin/xfrout/xfrout.py.in

+ 29 - 41
src/bin/xfrout/tests/xfrout_test.py.in

@@ -21,7 +21,7 @@ import os
 from isc.testutils.tsigctx_mock import MockTSIGContext
 from isc.testutils.tsigctx_mock import MockTSIGContext
 from isc.cc.session import *
 from isc.cc.session import *
 import isc.config
 import isc.config
-from pydnspp import *
+from isc.dns import *
 from xfrout import *
 from xfrout import *
 import xfrout
 import xfrout
 import isc.log
 import isc.log
@@ -81,7 +81,7 @@ class Dbserver:
     def __init__(self):
     def __init__(self):
         self._shutdown_event = threading.Event()
         self._shutdown_event = threading.Event()
     def get_db_file(self):
     def get_db_file(self):
-        return None
+        return 'test.sqlite3'
     def increase_transfers_counter(self):
     def increase_transfers_counter(self):
         return True
         return True
     def decrease_transfers_counter(self):
     def decrease_transfers_counter(self):
@@ -511,46 +511,34 @@ class TestXfroutSession(unittest.TestCase):
         rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
         rrset_soa = self.xfrsess._create_rrset_from_db_record(self.soa_record)
         self.assertEqual(82, get_rrset_len(rrset_soa))
         self.assertEqual(82, get_rrset_len(rrset_soa))
 
 
-    def test_zone_has_soa(self):
-        global sqlite3_ds
-        def mydb1(zone, file):
-            return True
-        sqlite3_ds.get_zone_soa = mydb1
-        self.assertTrue(self.xfrsess._zone_has_soa(""))
-        def mydb2(zone, file):
-            return False
-        sqlite3_ds.get_zone_soa = mydb2
-        self.assertFalse(self.xfrsess._zone_has_soa(""))
-
-    def test_zone_exist(self):
-        global sqlite3_ds
-        def zone_exist(zone, file):
-            return zone
-        sqlite3_ds.zone_exist = zone_exist
-        self.assertTrue(self.xfrsess._zone_exist(True))
-        self.assertFalse(self.xfrsess._zone_exist(False))
-
     def test_check_xfrout_available(self):
     def test_check_xfrout_available(self):
-        def zone_exist(zone):
-            return zone
-        def zone_has_soa(zone):
-            return (not zone)
-        self.xfrsess._zone_exist = zone_exist
-        self.xfrsess._zone_has_soa = zone_has_soa
-        self.assertEqual(self.xfrsess._check_xfrout_available(False).to_text(), "NOTAUTH")
-        self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(), "SERVFAIL")
-
-        def zone_empty(zone):
-            return zone
-        self.xfrsess._zone_has_soa = zone_empty
-        def false_func():
-            return False
-        self.xfrsess._server.increase_transfers_counter = false_func
-        self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(), "REFUSED")
-        def true_func():
-            return True
-        self.xfrsess._server.increase_transfers_counter = true_func
-        self.assertEqual(self.xfrsess._check_xfrout_available(True).to_text(), "NOERROR")
+        class MockDataSrcClient:
+            def __init__(self, type, config): pass
+
+            def get_iterator(self, zone_name):
+                if zone_name == Name('notauth.example.com'):
+                    raise isc.datasrc.Error('no such zone')
+                self._zone_name = zone_name
+                return self
+
+            def get_soa(self):  # emulate ZoneIterator.get_soa()
+                if self._zone_name == Name('nosoa.example.com'):
+                    return None
+                # returning True on success is wrong, but works for this test.
+                return True
+
+        self.xfrsess.ClientClass = MockDataSrcClient
+        self.assertEqual(self.xfrsess._check_xfrout_available(
+                Name('notauth.example.com')), Rcode.NOTAUTH())
+        self.assertEqual(self.xfrsess._check_xfrout_available(
+                Name('nosoa.example.com')), Rcode.SERVFAIL())
+
+        self.xfrsess._server.increase_transfers_counter = lambda : False
+        self.assertEqual(self.xfrsess._check_xfrout_available(
+                Name('example.com')), Rcode.REFUSED())
+        self.xfrsess._server.increase_transfers_counter = lambda : True
+        self.assertEqual(self.xfrsess._check_xfrout_available(
+                Name('example.com')), Rcode.NOERROR())
 
 
     def test_dns_xfrout_start_formerror(self):
     def test_dns_xfrout_start_formerror(self):
         # formerror
         # formerror

+ 32 - 33
src/bin/xfrout/xfrout.py.in

@@ -22,7 +22,8 @@ import isc.cc
 import threading
 import threading
 import struct
 import struct
 import signal
 import signal
-from isc.datasrc import sqlite3_ds
+from isc.datasrc import sqlite3_ds # should be obsoleted
+from isc.datasrc import DataSourceClient, ZoneFinder
 from socketserver import *
 from socketserver import *
 import os
 import os
 from isc.config.ccsession import *
 from isc.config.ccsession import *
@@ -106,7 +107,7 @@ def get_rrset_len(rrset):
 
 
 class XfroutSession():
 class XfroutSession():
     def __init__(self, sock_fd, request_data, server, tsig_key_ring, remote,
     def __init__(self, sock_fd, request_data, server, tsig_key_ring, remote,
-                 default_acl, zone_config):
+                 default_acl, zone_config, client_class=DataSourceClient):
         self._sock_fd = sock_fd
         self._sock_fd = sock_fd
         self._request_data = request_data
         self._request_data = request_data
         self._server = server
         self._server = server
@@ -116,6 +117,7 @@ class XfroutSession():
         self._remote = remote
         self._remote = remote
         self._acl = default_acl
         self._acl = default_acl
         self._zone_config = zone_config
         self._zone_config = zone_config
+        self.ClientClass = client_class # parameterize this for testing
         self.handle()
         self.handle()
 
 
     def create_tsig_ctx(self, tsig_record, tsig_key_ring):
     def create_tsig_ctx(self, tsig_record, tsig_key_ring):
@@ -238,27 +240,6 @@ class XfroutSession():
         msg.set_rcode(rcode_)
         msg.set_rcode(rcode_)
         self._send_message(sock_fd, msg, self._tsig_ctx)
         self._send_message(sock_fd, msg, self._tsig_ctx)
 
 
-    def _zone_has_soa(self, zone):
-        '''Judge if the zone has an SOA record.'''
-        # In some sense, the SOA defines a zone.
-        # If the current name server has authority for the
-        # specific zone, we need to judge if the zone has an SOA record;
-        # if not, we consider the zone has incomplete data, so xfrout can't
-        # serve for it.
-        if sqlite3_ds.get_zone_soa(zone, self._server.get_db_file()):
-            return True
-
-        return False
-
-    def _zone_exist(self, zonename):
-        '''Judge if the zone is configured by config manager.'''
-        # Currently, if we find the zone in datasource successfully, we
-        # consider the zone is configured, and the current name server has
-        # authority for the specific zone.
-        # TODO: should get zone's configuration from cfgmgr or other place
-        # in future.
-        return sqlite3_ds.zone_exist(zonename, self._server.get_db_file())
-
     def _check_xfrout_available(self, zone_name):
     def _check_xfrout_available(self, zone_name):
         '''Check if xfr request can be responsed.
         '''Check if xfr request can be responsed.
            TODO, Get zone's configuration from cfgmgr or some other place
            TODO, Get zone's configuration from cfgmgr or some other place
@@ -270,15 +251,32 @@ class XfroutSession():
         if not self._server.increase_transfers_counter():
         if not self._server.increase_transfers_counter():
             return Rcode.REFUSED()
             return Rcode.REFUSED()
 
 
-        # If the current name server does not have authority for the
-        # zone, xfrout can't serve for it, return rcode NOTAUTH.
-        if not self._zone_exist(zone_name):
+        # Identify the data source for the requested zone and see if it has
+        # SOA while initializing objects used for request processing later.
+        # We should eventually generalize this so that choose the appropriate
+        # data source from (possible) multiple candidates.  We should
+        # eventually take into account the RR class here.
+        # For now, we  hardcode a particular type (SQLite3-based), and only
+        # consider that one.
+        datasrc_config = '{ \"database_file\": \"' + \
+            self._server.get_db_file() + '\"}'
+        self._datasrc_client = self.ClientClass('sqlite3', datasrc_config)
+        try:
+            self._iterator = self._datasrc_client.get_iterator(zone_name)
+        except isc.datasrc.Error as error:
+            # If the current name server does not have authority for the
+            # zone, xfrout can't serve for it, return rcode NOTAUTH.
+            # Note: this exception can happen for other reasons.  We should
+            # update get_iterator() API so that we can distinguish "no such
+            # zone" and other cases.  For now we consider all these cases
+            # as NOTAUTH.
             return Rcode.NOTAUTH()
             return Rcode.NOTAUTH()
 
 
         # If we are an authoritative name server for the zone, but fail
         # If we are an authoritative name server for the zone, but fail
         # to find the zone's SOA record in datasource, xfrout can't
         # to find the zone's SOA record in datasource, xfrout can't
         # provide zone transfer for it.
         # provide zone transfer for it.
-        if not self._zone_has_soa(zone_name):
+        self._soa = self._iterator.get_soa()
+        if self._soa is None:
             return Rcode.SERVFAIL()
             return Rcode.SERVFAIL()
 
 
         #TODO, check allow_transfer
         #TODO, check allow_transfer
@@ -297,24 +295,25 @@ class XfroutSession():
             return self._reply_query_with_error_rcode(msg, sock_fd,
             return self._reply_query_with_error_rcode(msg, sock_fd,
                                                       Rcode.FORMERR())
                                                       Rcode.FORMERR())
 
 
-        zone_name = self._get_query_zone_name(msg)
+        zone_name = msg.get_question()[0].get_name()
         zone_class_str = self._get_query_zone_class(msg)
         zone_class_str = self._get_query_zone_class(msg)
         # TODO: should we not also include class in the check?
         # TODO: should we not also include class in the check?
         rcode_ = self._check_xfrout_available(zone_name)
         rcode_ = self._check_xfrout_available(zone_name)
 
 
         if rcode_ != Rcode.NOERROR():
         if rcode_ != Rcode.NOERROR():
-            logger.info(XFROUT_AXFR_TRANSFER_FAILED, zone_name,
+            logger.info(XFROUT_AXFR_TRANSFER_FAILED, zone_name.to_text(),
                         zone_class_str, rcode_.to_text())
                         zone_class_str, rcode_.to_text())
             return self._reply_query_with_error_rcode(msg, sock_fd, rcode_)
             return self._reply_query_with_error_rcode(msg, sock_fd, rcode_)
 
 
         try:
         try:
-            logger.info(XFROUT_AXFR_TRANSFER_STARTED, zone_name, zone_class_str)
-            self._reply_xfrout_query(msg, sock_fd, zone_name)
+            logger.info(XFROUT_AXFR_TRANSFER_STARTED, zone_name.to_text(),
+                        zone_class_str)
+            self._reply_xfrout_query(msg, sock_fd, zone_name.to_text())
         except Exception as err:
         except Exception as err:
-            logger.error(XFROUT_AXFR_TRANSFER_ERROR, zone_name,
+            logger.error(XFROUT_AXFR_TRANSFER_ERROR, zone_name.to_text(),
                          zone_class_str, str(err))
                          zone_class_str, str(err))
             pass
             pass
-        logger.info(XFROUT_AXFR_TRANSFER_DONE, zone_name, zone_class_str)
+        logger.info(XFROUT_AXFR_TRANSFER_DONE, zone_name.to_text(), zone_class_str)
 
 
         self._server.decrease_transfers_counter()
         self._server.decrease_transfers_counter()
         return
         return