Parcourir la source

[master] [1299] validate the qeustion section of SOA response. refactored the
code by extracting the validation part into a separate method.

JINMEI Tatuya il y a 13 ans
Parent
commit
d3792aa7fc
2 fichiers modifiés avec 68 ajouts et 12 suppressions
  1. 27 1
      src/bin/xfrin/tests/xfrin_test.py
  2. 41 11
      src/bin/xfrin/xfrin.py.in

+ 27 - 1
src/bin/xfrin/tests/xfrin_test.py

@@ -949,7 +949,7 @@ class TestAXFR(TestXfrinConnection):
     def test_soacheck_notauth(self):
         self.soa_response_params['auth'] = False
         self.conn.response_generator = self._create_soa_response_data
-        self.assertRaises(XfrinException, self.conn._check_soa_serial)
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
 
     def test_soacheck_uptodate(self):
         # Primary's SOA serial is identical the local serial
@@ -970,6 +970,32 @@ class TestAXFR(TestXfrinConnection):
         self.conn.response_generator = self._create_soa_response_data
         self.assertRaises(XfrinZoneUptodate, self.conn._check_soa_serial)
 
+    def test_soacheck_question_empty(self):
+        self.conn.response_generator = self._create_soa_response_data
+        self.soa_response_params['questions'] = []
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
+
+    def test_soacheck_question_name_mismatch(self):
+        self.conn.response_generator = self._create_soa_response_data
+        self.soa_response_params['questions'] = [Question(Name('example.org'),
+                                                          TEST_RRCLASS,
+                                                          RRType.SOA())]
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
+
+    def test_soacheck_question_class_mismatch(self):
+        self.conn.response_generator = self._create_soa_response_data
+        self.soa_response_params['questions'] = [Question(TEST_ZONE_NAME,
+                                                          RRClass.CH(),
+                                                          RRType.SOA())]
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
+
+    def test_soacheck_question_type_mismatch(self):
+        self.conn.response_generator = self._create_soa_response_data
+        self.soa_response_params['questions'] = [Question(TEST_ZONE_NAME,
+                                                          TEST_RRCLASS,
+                                                          RRType.AAAA())]
+        self.assertRaises(XfrinProtocolError, self.conn._check_soa_serial)
+
     def test_soacheck_with_tsig(self):
         # Use a mock tsig context emulating a validly signed response
         self.conn._tsig_key = TSIG_KEY

+ 41 - 11
src/bin/xfrin/xfrin.py.in

@@ -701,6 +701,44 @@ class XfrinConnection(asyncore.dispatcher):
             # strict.
             raise XfrinException('Unexpected TSIG in response')
 
+    def __parse_soa_response(self, msg, response_data):
+        '''Parse a response to SOA query and extranct the SOA from ansser.
+
+        This is a subroutine of _check_soa_serial().  This method also
+        validates message, and rejects bogus responses with XfrinProtocolError.
+
+        If evenrything is okay, it returns the SOA RR from the answer section
+        of the response.
+
+        '''
+        # Check TSIG integerity and validate the header.  Unlike AXFR/IXFR,
+        # we should be more strict for SOA queries and check the AA flag, too.
+        self._check_response_tsig(msg, response_data)
+        self._check_response_header(msg)
+        if not msg.get_header_flag(Message.HEADERFLAG_AA):
+            raise XfrinProtocolError('non-authoritative answer to SOA query')
+
+        # Validate the question section
+        n_question = msg.get_rr_count(Message.SECTION_QUESTION)
+        if n_question != 1:
+            raise XfrinProtocolError('Invalid number of questions to ' +
+                                     'SOA query (' + str(n_question) + ')')
+        resp_question = msg.get_question()[0]
+        if resp_question.get_name() != self._zone_name or \
+                resp_question.get_class() != self._rrclass or \
+                resp_question.get_type() != RRType.SOA():
+            raise XfrinProtocolError('Questions mismatch to ' +
+                                     'SOA query: ' + str(resp_question))
+
+        # Examine the answer section
+        soa = None
+        for rrset in msg.get_section(Message.SECTION_ANSWER):
+            if rrset.get_type() == RRType.SOA():
+                soa = rrset
+
+        return soa
+
+
     def _check_soa_serial(self):
         ''' Compare the soa serial, if soa serial in master is less than
         the soa serial in local, Finish xfrin.
@@ -715,18 +753,10 @@ class XfrinConnection(asyncore.dispatcher):
         msg = Message(Message.PARSE)
         msg.from_wire(soa_response)
 
-        # Validate the message.  Unlike AXFR/IXFR, we should be more strict
-        # for SOA queries and check the AA flag, too.
-        self._check_response_tsig(msg, soa_response)
-        self._check_response_header(msg)
-        if not msg.get_header_flag(Message.HEADERFLAG_AA):
-            raise XfrinException('non-authoritative answer to SOA query')
+        # Validate/parse the rest of the response, and extract the SOA
+        # from the answer section
+        soa = self.__parse_soa_response(msg, soa_response)
 
-        # Examine the answer section
-        soa = None
-        for rrset in msg.get_section(Message.SECTION_ANSWER):
-            if rrset.get_type() == RRType.SOA():
-                soa = rrset
         primary_serial = get_soa_serial(soa.get_rdata()[0])
         if (self._request_serial is not None) and \
                 self._request_serial >= primary_serial: