Browse Source

[2439] Perform validation of zone received by XFR

There's a lot of mocking within the tests, but it seems hard without it.
Michal 'vorner' Vaner 12 years ago
parent
commit
e382b73b40
2 changed files with 49 additions and 6 deletions
  1. 36 5
      src/bin/xfrin/tests/xfrin_test.py
  2. 13 1
      src/bin/xfrin/xfrin.py.in

+ 36 - 5
src/bin/xfrin/tests/xfrin_test.py

@@ -153,13 +153,19 @@ class MockCC(MockModuleCCSession):
     def remove_remote_config(self, module_name):
     def remove_remote_config(self, module_name):
         pass
         pass
 
 
+class MockRRsetCollection:
+    '''
+    A mock RRset collection. We don't use it really (we mock the method that
+    it is passed to too), so it's empty.
+    '''
+    pass
+
 class MockDataSourceClient():
 class MockDataSourceClient():
     '''A simple mock data source client.
     '''A simple mock data source client.
 
 
     This class provides a minimal set of wrappers related the data source
     This class provides a minimal set of wrappers related the data source
     API that would be used by Diff objects.  For our testing purposes they
     API that would be used by Diff objects.  For our testing purposes they
-    only keep truck of the history of the changes.
-
+    only keep track of the history of the changes.
     '''
     '''
     def __init__(self):
     def __init__(self):
         self.force_fail = False # if True, raise an exception on commit
         self.force_fail = False # if True, raise an exception on commit
@@ -217,6 +223,12 @@ class MockDataSourceClient():
         self._journaling_enabled = journaling
         self._journaling_enabled = journaling
         return self
         return self
 
 
+    def get_rrset_collection(self):
+        '''
+        Pretend to be a zone updater and provide a (dummy) rrset collection.
+        '''
+        return MockRRsetCollection()
+
     def add_rrset(self, rrset):
     def add_rrset(self, rrset):
         self.diffs.append(('add', rrset))
         self.diffs.append(('add', rrset))
 
 
@@ -726,11 +738,23 @@ class TestXfrinConnection(unittest.TestCase):
             'tsig_1st': None,
             'tsig_1st': None,
             'tsig_2nd': None
             'tsig_2nd': None
             }
             }
+        self.__orig_check_zone = xfrin.check_zone
+        xfrin.check_zone = self.__check_zone
+        self._check_zone_result = True
+        self._check_zone_params = None
 
 
     def tearDown(self):
     def tearDown(self):
         self.conn.close()
         self.conn.close()
         if os.path.exists(TEST_DB_FILE):
         if os.path.exists(TEST_DB_FILE):
             os.remove(TEST_DB_FILE)
             os.remove(TEST_DB_FILE)
+        xfrin.check_zone = self.__orig_check_zone
+
+    def __check_zone(self, name, rrclass, rrsets, callbacks):
+        '''
+        A mock function used instead of dns.check_zone.
+        '''
+        self._check_zone_params = (name, rrclass, rrsets, callbacks)
+        return self._check_zone_result
 
 
     def _create_normal_response_data(self):
     def _create_normal_response_data(self):
         # This helper method creates a simple sequence of DNS messages that
         # This helper method creates a simple sequence of DNS messages that
@@ -825,6 +849,7 @@ class TestAXFR(TestXfrinConnection):
 
 
     def tearDown(self):
     def tearDown(self):
         time.time = self.orig_time_time
         time.time = self.orig_time_time
+        super().tearDown()
 
 
     def __create_mock_tsig(self, key, error, has_last_signature=True):
     def __create_mock_tsig(self, key, error, has_last_signature=True):
         # This helper function creates a MockTSIGContext for a given key
         # This helper function creates a MockTSIGContext for a given key
@@ -1297,10 +1322,9 @@ class TestAXFR(TestXfrinConnection):
                     [[('add', ns_rr), ('add', a_rr), ('add', soa_rrset)]],
                     [[('add', ns_rr), ('add', a_rr), ('add', soa_rrset)]],
                     self.conn._datasrc_client.committed_diffs)
                     self.conn._datasrc_client.committed_diffs)
 
 
-    def test_axfr_response_missing_ns(self):
+    def test_axfr_response_fail_validation(self):
         """
         """
-        Test with transfering an invalid zone. We are missing a NS record
-        (missing a SOA is hard to do with XFR). It should be rejected.
+        Test we reject a zone transfer if it fails the check_zone validation.
         """
         """
         a_rr = self._create_a('192.0.2.1')
         a_rr = self._create_a('192.0.2.1')
         self.conn._send_query(RRType.AXFR())
         self.conn._send_query(RRType.AXFR())
@@ -1309,10 +1333,17 @@ class TestAXFR(TestXfrinConnection):
                                 RRType.AXFR())],
                                 RRType.AXFR())],
             # begin serial=1230, end serial=1234. end will be used.
             # begin serial=1230, end serial=1234. end will be used.
             answers=[begin_soa_rrset, a_rr, soa_rrset])
             answers=[begin_soa_rrset, a_rr, soa_rrset])
+        # Make it fail the validation
+        self._check_zone_result = False
         self.assertRaises(XfrinProtocolError,
         self.assertRaises(XfrinProtocolError,
                           self.conn._handle_xfrin_responses)
                           self.conn._handle_xfrin_responses)
         self.assertEqual(type(XfrinAXFREnd()), type(self.conn.get_xfrstate()))
         self.assertEqual(type(XfrinAXFREnd()), type(self.conn.get_xfrstate()))
         self.assertEqual([], self.conn._datasrc_client.committed_diffs)
         self.assertEqual([], self.conn._datasrc_client.committed_diffs)
+        # Check the validation is called with the correct parameters
+        self.assertEqual(TEST_ZONE_NAME, self._check_zone_params[0])
+        self.assertEqual(TEST_RRCLASS, self._check_zone_params[1])
+        self.assertTrue(isinstance(self._check_zone_params[2],
+                                   MockRRsetCollection))
 
 
     def test_axfr_response_extra(self):
     def test_axfr_response_extra(self):
         '''Test with an extra RR after the end of AXFR session.
         '''Test with an extra RR after the end of AXFR session.

+ 13 - 1
src/bin/xfrin/xfrin.py.in

@@ -46,7 +46,7 @@ DBG_PROCESS = logger.DBGLVL_TRACE_BASIC
 DBG_COMMANDS = logger.DBGLVL_TRACE_DETAIL
 DBG_COMMANDS = logger.DBGLVL_TRACE_DETAIL
 
 
 try:
 try:
-    from pydnspp import *
+    from isc.dns import *
 except ImportError as e:
 except ImportError as e:
     # C++ loadable module may not be installed; even so the xfrin process
     # C++ loadable module may not be installed; even so the xfrin process
     # must keep running, so we warn about it and move forward.
     # must keep running, so we warn about it and move forward.
@@ -803,12 +803,24 @@ class XfrinConnection(asyncore.dispatcher):
                 raise XfrinProtocolError('TSIG verify fail: no TSIG on last '+
                 raise XfrinProtocolError('TSIG verify fail: no TSIG on last '+
                                          'message')
                                          'message')
 
 
+    def __validate_error(reason):
+        # TODO: Log
+        pass
+
+    def __validate_warning(reason):
+        # TODO: Log
+        pass
+
     def _finish_transfer(self):
     def _finish_transfer(self):
         """
         """
         Perform any necessary checks after a transfer. Then complete the
         Perform any necessary checks after a transfer. Then complete the
         transfer by commiting the transaction into the data source.
         transfer by commiting the transaction into the data source.
         """
         """
         self._check_response_tsig_last()
         self._check_response_tsig_last()
+        if not check_zone(self._zone_name, self._rrclass,
+                          self._diff.get_rrset_collection(),
+                          (self.__validate_error, self.__validate_warning)):
+            raise XfrinProtocolError('Validation of the new zone failed')
         self._diff.commit()
         self._diff.commit()
 
 
     def __parse_soa_response(self, msg, response_data):
     def __parse_soa_response(self, msg, response_data):