Parcourir la source

[1372] added a test case for an impossible (buggy) case. also changed
the behavior against impossible # of questions.

JINMEI Tatuya il y a 13 ans
Parent
commit
a24c6579ab
2 fichiers modifiés avec 23 ajouts et 13 suppressions
  1. 13 6
      src/bin/xfrout/tests/xfrout_test.py.in
  2. 10 7
      src/bin/xfrout/xfrout.py.in

+ 13 - 6
src/bin/xfrout/tests/xfrout_test.py.in

@@ -218,7 +218,7 @@ class TestXfroutSessionBase(unittest.TestCase):
         return msg.get_tsig_record() is not None
 
     def create_request_data(self, with_question=True, with_tsig=False,
-                            ixfr=None, zone_name=TEST_ZONE_NAME,
+                            ixfr=None, qtype=None, zone_name=TEST_ZONE_NAME,
                             soa_class=TEST_RRCLASS, num_soa=1):
         '''Create a commonly used XFR request data.
 
@@ -229,6 +229,9 @@ class TestXfroutSessionBase(unittest.TestCase):
 
         This method has various minor parameters only for creating bad
         format requests for testing purposes:
+        qtype: the RR type of the question section.  By default automatically
+               determined by the value of ixfr, but could be an invalid type
+               for testing.
         zone_name: the query (zone) name.  for IXFR, it's also used as
                    the owner name of the SOA in the authority section.
         soa_class: IXFR only.  The RR class of the SOA RR in the authority
@@ -243,7 +246,8 @@ class TestXfroutSessionBase(unittest.TestCase):
         msg.set_rcode(Rcode.NOERROR())
         req_type = RRType.AXFR() if ixfr is None else RRType.IXFR()
         if with_question:
-            msg.add_question(Question(zone_name, RRClass.IN(), req_type))
+            msg.add_question(Question(zone_name, RRClass.IN(),
+                                      req_type if qtype is None else qtype))
         if req_type == RRType.IXFR():
             soa = RRset(zone_name, soa_class, RRType.SOA(), RRTTL(0))
             # In the RDATA only the serial matters.
@@ -313,7 +317,7 @@ class TestXfroutSession(TestXfroutSessionBase):
         # set up a bogus request, which should result in FORMERR. (it only
         # has to be something that is different from the previous case)
         self.xfrsess._request_data = \
-            self.create_request_data(with_question=False)
+            self.create_request_data(ixfr=IXFR_OK_VERSION, num_soa=2)
         # Replace the data source client to avoid datasrc related exceptions
         self.xfrsess.ClientClass = MockDataSrcClient
         XfroutSession._handle(self.xfrsess)
@@ -344,9 +348,12 @@ class TestXfroutSession(TestXfroutSessionBase):
         self.assertEqual(Rcode.NOERROR(), rcode)
 
         # Broken request: no question
-        request_data = self.create_request_data(with_question=False)
-        rcode, msg = self.xfrsess._parse_query_message(request_data)
-        self.assertEqual(Rcode.FORMERR(), rcode)
+        self.assertRaises(RuntimeError, self.xfrsess._parse_query_message,
+                          self.create_request_data(with_question=False))
+
+        # Broken request: invalid RR type (neither AXFR nor IXFR)
+        self.assertRaises(RuntimeError, self.xfrsess._parse_query_message,
+                          self.create_request_data(qtype=RRType.A()))
 
         # tsig signed query message
         request_data = self.create_request_data(with_tsig=True)

+ 10 - 7
src/bin/xfrout/xfrout.py.in

@@ -201,7 +201,8 @@ class XfroutSession():
         tsig_record = msg.get_tsig_record()
         if tsig_record is not None:
             self._tsig_len = tsig_record.get_length()
-            self._tsig_ctx = self.create_tsig_ctx(tsig_record, self._tsig_key_ring)
+            self._tsig_ctx = self.create_tsig_ctx(tsig_record,
+                                                  self._tsig_key_ring)
             tsig_error = self._tsig_ctx.verify(tsig_record, request_data)
             if tsig_error != TSIGError.NOERROR:
                 return Rcode.NOTAUTH()
@@ -224,10 +225,12 @@ class XfroutSession():
             return rcode, msg
 
         # Make sure the question is valid.  This should be ensured by
-        # the auth server, but since it's far from our xfrout itself,
-        # we check it by ourselves.
+        # the auth server, but since it's far from xfrout itself, we check
+        # it by ourselves.  A viloation would be an internal bug, so we
+        # raise and stop here rather than returning a FORMERR or SERVFAIL.
         if msg.get_rr_count(Message.SECTION_QUESTION) != 1:
-            return Rcode.FORMERR(), msg
+            raise RuntimeError('Invalid number of question for XFR: ' +
+                               str(msg.get_rr_count(Message.SECTION_QUESTION)))
         question = msg.get_question()[0]
 
         # Identify the request type
@@ -237,9 +240,9 @@ class XfroutSession():
         elif self._request_type == RRType.IXFR():
             self._request_typestr = 'IXFR'
         else:
-            # Likewise, this should be impossible.  (TBD: to be tested)
-            raise RuntimeError('Unexpected XFR type: ' + \
-                                   str(self._request_type))
+            # Likewise, this should be impossible.
+            raise RuntimeError('Unexpected XFR type: ' +
+                               str(self._request_type))
 
         # ACL checks
         zone_name = question.get_name()