Browse Source

[trac955] add check for exception error message

chenzhengzhang 14 years ago
parent
commit
8e105c57dd
2 changed files with 68 additions and 7 deletions
  1. 47 0
      src/bin/xfrin/tests/xfrin_test.py
  2. 21 7
      src/bin/xfrin/xfrin.py.in

+ 47 - 0
src/bin/xfrin/tests/xfrin_test.py

@@ -15,6 +15,7 @@
 
 import unittest
 import socket
+import io
 from isc.testutils.tsigctx_mock import MockTSIGContext
 from xfrin import *
 
@@ -293,6 +294,37 @@ class TestXfrinConnection(unittest.TestCase):
         self.conn.reply_data = self.conn.create_response_data(bad_qid = True)
         self.assertRaises(XfrinException, self._handle_xfrin_response)
 
+    def test_response_error_code_bad_sig(self):
+        self.conn._tsig_key = TSIG_KEY
+        self.conn._tsig_ctx_creator = \
+            lambda key: self.__create_mock_tsig(key, TSIGError.BAD_SIG)
+        self.conn._send_query(RRType.AXFR())
+        self.conn.reply_data = self.conn.create_response_data(
+                rcode=Rcode.SERVFAIL())
+        # xfrin should check TSIG before other part of incoming message
+        # validate log message for XfrinException
+        self.conn._verbose = True
+        err_output = io.StringIO()
+        sys.stdout = err_output
+        self.assertRaises(XfrinException, self._handle_xfrin_response)
+        self.assertEqual("[b10-xfrin] TSIG verify fail: BADSIG\n", err_output.getvalue())
+        err_output.close()
+
+    def test_response_bad_qid_bad_key(self):
+        self.conn._tsig_key = TSIG_KEY
+        self.conn._tsig_ctx_creator = \
+            lambda key: self.__create_mock_tsig(key, TSIGError.BAD_KEY)
+        self.conn._send_query(RRType.AXFR())
+        self.conn.reply_data = self.conn.create_response_data(bad_qid = True)
+        # xfrin should check TSIG before other part of incoming message
+        # validate log message for XfrinException
+        self.conn._verbose = True
+        err_output = io.StringIO()
+        sys.stdout = err_output
+        self.assertRaises(XfrinException, self._handle_xfrin_response)
+        self.assertEqual("[b10-xfrin] TSIG verify fail: BADKEY\n", err_output.getvalue())
+        err_output.close()
+
     def test_response_non_response(self):
         self.conn._send_query(RRType.AXFR())
         self.conn.reply_data = self.conn.create_response_data(response = False)
@@ -337,6 +369,21 @@ class TestXfrinConnection(unittest.TestCase):
         self.conn.response_generator = self._create_soa_response_data
         self.assertRaises(XfrinException, self.conn._check_soa_serial)
 
+    def test_soacheck_bad_qid_bad_sig(self):
+        self.conn._tsig_key = TSIG_KEY
+        self.conn._tsig_ctx_creator = \
+            lambda key: self.__create_mock_tsig(key, TSIGError.BAD_SIG)
+        self.soa_response_params['bad_qid'] = True
+        self.conn.response_generator = self._create_soa_response_data
+        # xfrin should check TSIG before other part of incoming message
+        # validate log message for XfrinException
+        self.conn._verbose = True
+        err_output = io.StringIO()
+        sys.stdout = err_output
+        self.assertRaises(XfrinException, self.conn._check_soa_serial)
+        self.assertEqual("[b10-xfrin] TSIG verify fail: BADSIG\n", err_output.getvalue())
+        err_output.close()
+
     def test_soacheck_non_response(self):
         self.soa_response_params['response'] = False
         self.conn.response_generator = self._create_soa_response_data

+ 21 - 7
src/bin/xfrin/xfrin.py.in

@@ -218,7 +218,9 @@ class XfrinConnection(asyncore.dispatcher):
         if self._tsig_ctx is not None:
             tsig_error = self._tsig_ctx.verify(tsig_record, response_data)
             if tsig_error != TSIGError.NOERROR:
-                raise XfrinException('TSIG verify fail: %s' % str(tsig_error))
+                errmsg = 'TSIG verify fail: ' +  str(tsig_error)
+                self.log_msg(errmsg)
+                raise XfrinException(errmsg)
         elif tsig_record is not None:
             # If the response includes a TSIG while we didn't sign the query,
             # we treat it as an error.  RFC doesn't say anything about this
@@ -227,7 +229,9 @@ class XfrinConnection(asyncore.dispatcher):
             # implementation would return such a response, and since this is
             # part of security mechanism, it's probably better to be more
             # strict.
-            raise XfrinException('Unexpected TSIG in response')
+            errmsg = 'Unexpected TSIG in response'
+            self.log_msg(errmsg)
+            raise XfrinException(errmsg)
 
     def _check_soa_serial(self):
         ''' Compare the soa serial, if soa serial in master is less than
@@ -308,13 +312,19 @@ class XfrinConnection(asyncore.dispatcher):
 
         msg_rcode = msg.get_rcode()
         if msg_rcode != Rcode.NOERROR():
-            raise XfrinException('error response: %s' % msg_rcode.to_text())
+            errmsg = 'error response: ' + msg_rcode.to_text()
+            self.log_msg(errmsg)
+            raise XfrinException(errmsg)
 
         if not msg.get_header_flag(Message.HEADERFLAG_QR):
-            raise XfrinException('response is not a response ')
+            errmsg = 'response is not a response'
+            self.log_msg(errmsg)
+            raise XfrinException(errmsg)
 
         if msg.get_qid() != self._query_id:
-            raise XfrinException('bad query id')
+            errmsg = 'bad query id'
+            self.log_msg(errmsg)
+            raise XfrinException(errmsg)
 
     def _check_response_status(self, msg):
         '''Check validation of xfr response. '''
@@ -322,10 +332,14 @@ class XfrinConnection(asyncore.dispatcher):
         self._check_response_header(msg)
 
         if msg.get_rr_count(Message.SECTION_ANSWER) == 0:
-            raise XfrinException('answer section is empty')
+            errmsg = 'answer section is empty'
+            self.log_msg(errmsg)
+            raise XfrinException(errmsg)
 
         if msg.get_rr_count(Message.SECTION_QUESTION) > 1:
-            raise XfrinException('query section count greater than 1')
+            errmsg = 'query section count greater than 1'
+            self.log_msg(errmsg)
+            raise XfrinException(errmsg)
 
     def _handle_answer_section(self, answer_section):
         '''Return a generator for the reponse in one tcp package to a zone transfer.'''