Browse Source

[2439] Perform validation of zone received by XFR

There's a lot of mocking within the tests, but it seems hard without it.
Michal 'vorner' Vaner 12 years ago
parent
commit
e382b73b40
2 changed files with 49 additions and 6 deletions
  1. 36 5
      src/bin/xfrin/tests/xfrin_test.py
  2. 13 1
      src/bin/xfrin/xfrin.py.in

+ 36 - 5
src/bin/xfrin/tests/xfrin_test.py

@@ -153,13 +153,19 @@ class MockCC(MockModuleCCSession):
     def remove_remote_config(self, module_name):
         pass
 
+class MockRRsetCollection:
+    '''
+    A mock RRset collection. We don't use it really (we mock the method that
+    it is passed to too), so it's empty.
+    '''
+    pass
+
 class MockDataSourceClient():
     '''A simple mock data source client.
 
     This class provides a minimal set of wrappers related the data source
     API that would be used by Diff objects.  For our testing purposes they
-    only keep truck of the history of the changes.
-
+    only keep track of the history of the changes.
     '''
     def __init__(self):
         self.force_fail = False # if True, raise an exception on commit
@@ -217,6 +223,12 @@ class MockDataSourceClient():
         self._journaling_enabled = journaling
         return self
 
+    def get_rrset_collection(self):
+        '''
+        Pretend to be a zone updater and provide a (dummy) rrset collection.
+        '''
+        return MockRRsetCollection()
+
     def add_rrset(self, rrset):
         self.diffs.append(('add', rrset))
 
@@ -726,11 +738,23 @@ class TestXfrinConnection(unittest.TestCase):
             'tsig_1st': None,
             'tsig_2nd': None
             }
+        self.__orig_check_zone = xfrin.check_zone
+        xfrin.check_zone = self.__check_zone
+        self._check_zone_result = True
+        self._check_zone_params = None
 
     def tearDown(self):
         self.conn.close()
         if os.path.exists(TEST_DB_FILE):
             os.remove(TEST_DB_FILE)
+        xfrin.check_zone = self.__orig_check_zone
+
+    def __check_zone(self, name, rrclass, rrsets, callbacks):
+        '''
+        A mock function used instead of dns.check_zone.
+        '''
+        self._check_zone_params = (name, rrclass, rrsets, callbacks)
+        return self._check_zone_result
 
     def _create_normal_response_data(self):
         # This helper method creates a simple sequence of DNS messages that
@@ -825,6 +849,7 @@ class TestAXFR(TestXfrinConnection):
 
     def tearDown(self):
         time.time = self.orig_time_time
+        super().tearDown()
 
     def __create_mock_tsig(self, key, error, has_last_signature=True):
         # This helper function creates a MockTSIGContext for a given key
@@ -1297,10 +1322,9 @@ class TestAXFR(TestXfrinConnection):
                     [[('add', ns_rr), ('add', a_rr), ('add', soa_rrset)]],
                     self.conn._datasrc_client.committed_diffs)
 
-    def test_axfr_response_missing_ns(self):
+    def test_axfr_response_fail_validation(self):
         """
-        Test with transfering an invalid zone. We are missing a NS record
-        (missing a SOA is hard to do with XFR). It should be rejected.
+        Test we reject a zone transfer if it fails the check_zone validation.
         """
         a_rr = self._create_a('192.0.2.1')
         self.conn._send_query(RRType.AXFR())
@@ -1309,10 +1333,17 @@ class TestAXFR(TestXfrinConnection):
                                 RRType.AXFR())],
             # begin serial=1230, end serial=1234. end will be used.
             answers=[begin_soa_rrset, a_rr, soa_rrset])
+        # Make it fail the validation
+        self._check_zone_result = False
         self.assertRaises(XfrinProtocolError,
                           self.conn._handle_xfrin_responses)
         self.assertEqual(type(XfrinAXFREnd()), type(self.conn.get_xfrstate()))
         self.assertEqual([], self.conn._datasrc_client.committed_diffs)
+        # Check the validation is called with the correct parameters
+        self.assertEqual(TEST_ZONE_NAME, self._check_zone_params[0])
+        self.assertEqual(TEST_RRCLASS, self._check_zone_params[1])
+        self.assertTrue(isinstance(self._check_zone_params[2],
+                                   MockRRsetCollection))
 
     def test_axfr_response_extra(self):
         '''Test with an extra RR after the end of AXFR session.

+ 13 - 1
src/bin/xfrin/xfrin.py.in

@@ -46,7 +46,7 @@ DBG_PROCESS = logger.DBGLVL_TRACE_BASIC
 DBG_COMMANDS = logger.DBGLVL_TRACE_DETAIL
 
 try:
-    from pydnspp import *
+    from isc.dns import *
 except ImportError as e:
     # C++ loadable module may not be installed; even so the xfrin process
     # must keep running, so we warn about it and move forward.
@@ -803,12 +803,24 @@ class XfrinConnection(asyncore.dispatcher):
                 raise XfrinProtocolError('TSIG verify fail: no TSIG on last '+
                                          'message')
 
+    def __validate_error(reason):
+        # TODO: Log
+        pass
+
+    def __validate_warning(reason):
+        # TODO: Log
+        pass
+
     def _finish_transfer(self):
         """
         Perform any necessary checks after a transfer. Then complete the
         transfer by commiting the transaction into the data source.
         """
         self._check_response_tsig_last()
+        if not check_zone(self._zone_name, self._rrclass,
+                          self._diff.get_rrset_collection(),
+                          (self.__validate_error, self.__validate_warning)):
+            raise XfrinProtocolError('Validation of the new zone failed')
         self._diff.commit()
 
     def __parse_soa_response(self, msg, response_data):