Browse Source

[trac914] implemented TSIG verify() support in xfrin

JINMEI Tatuya 14 years ago
parent
commit
58a5aabf65
2 changed files with 170 additions and 14 deletions
  1. 145 10
      src/bin/xfrin/tests/xfrin_test.py
  2. 25 4
      src/bin/xfrin/xfrin.py.in

+ 145 - 10
src/bin/xfrin/tests/xfrin_test.py

@@ -125,10 +125,11 @@ class MockXfrinConnection(XfrinConnection):
                 self.response_generator()
         return len(data)
 
-    def create_response_data(self, response = True, bad_qid = False,
-                             rcode = Rcode.NOERROR(),
-                             questions = default_questions,
-                             answers = default_answers):
+    def create_response_data(self, response=True, bad_qid=False,
+                             rcode=Rcode.NOERROR(),
+                             questions=default_questions,
+                             answers=default_answers,
+                             tsig=False):
         resp = Message(Message.RENDER)
         qid = self.qid
         if bad_qid:
@@ -142,7 +143,13 @@ class MockXfrinConnection(XfrinConnection):
         [resp.add_rrset(Message.SECTION_ANSWER, a) for a in answers]
 
         renderer = MessageRenderer()
-        resp.to_wire(renderer)
+        if tsig:
+            # for now, we don't need a valid SIG.  We only need to include
+            # TSIG RR.  So how to add it and which key is used don't matter.
+            tsig_ctx = TSIGContext(TSIG_KEY)
+            resp.to_wire(renderer, tsig_ctx)
+        else:
+            resp.to_wire(renderer)
         reply_data = struct.pack('H', socket.htons(renderer.get_length()))
         reply_data += renderer.get_data()
 
@@ -157,14 +164,18 @@ class TestXfrinConnection(unittest.TestCase):
                                         TEST_RRCLASS, TEST_DB_FILE,
                                         threading.Event(),
                                         TEST_MASTER_IPV4_ADDRINFO)
-        self.axfr_after_soa = False
         self.soa_response_params = {
             'questions': [example_soa_question],
             'bad_qid': False,
             'response': True,
             'rcode': Rcode.NOERROR(),
+            'tsig': False,
             'axfr_after_soa': self._create_normal_response_data
             }
+        self.axfr_response_params = {
+            'tsig_1st': False,
+            'tsig_2nd': False
+            }
 
     def tearDown(self):
         self.conn.close()
@@ -240,7 +251,7 @@ class TestXfrinConnection(unittest.TestCase):
         self.conn.reply_data = b'aaaxxxx'
         self.assertRaises(XfrinTestException, self._handle_xfrin_response)
 
-    def test_response_with_tsig(self):
+    def test_response_with_tsigfail(self):
         self.conn._tsig_ctx = MockTSIGContext(TSIG_KEY)
         # server tsig check fail, return with RCODE 9 (NOTAUTH)
         self.conn._send_query(RRType.SOA())
@@ -311,6 +322,52 @@ class TestXfrinConnection(unittest.TestCase):
         self.conn.response_generator = self._create_soa_response_data
         self.assertRaises(XfrinException, self.conn._check_soa_serial)
 
+    def test_soacheck_with_tsig(self):
+        self.conn._tsig_ctx = MockTSIGContext(TSIG_KEY)
+        self.conn.response_generator = self._create_soa_response_data
+        # emulate a validly signed response
+        self.conn._tsig_ctx.error = TSIGError.NOERROR
+        self.assertEqual(self.conn._check_soa_serial(), XFRIN_OK)
+        self.assertEqual(self.conn._tsig_ctx.get_error(), TSIGError.NOERROR)
+
+    def test_soacheck_with_tsig_notauth(self):
+        self.conn._tsig_ctx = MockTSIGContext(TSIG_KEY)
+
+        # emulate a valid error response
+        self.soa_response_params['rcode'] = Rcode.NOTAUTH()
+        self.conn.response_generator = self._create_soa_response_data
+        self.conn._tsig_ctx.error = TSIGError.BAD_SIG
+
+        self.assertRaises(XfrinException, self.conn._check_soa_serial)
+
+    def test_soacheck_with_tsig_noerror_badsig(self):
+        self.conn._tsig_ctx = MockTSIGContext(TSIG_KEY)
+
+        # emulate a normal response bad verification failure due to BADSIG.
+        # According RFC2845, in this case we should ignore it and keep
+        # waiting for a valid response until a timeout.  But we immediately
+        # treat this as a final failure (just as BIND 9 does).
+        self.conn.response_generator = self._create_soa_response_data
+        self.conn._tsig_ctx.error = TSIGError.BAD_SIG
+
+        self.assertRaises(XfrinException, self.conn._check_soa_serial)
+
+    def test_soacheck_with_tsig_unsigned_response(self):
+        # we can use a real TSIGContext for this.  the response doesn't
+        # contain a TSIG while we sent a signed query.  RFC2845 states
+        # we should wait for a valid response in this case, but we treat
+        # it as a fatal transaction failure, too.
+        self.conn._tsig_ctx = TSIGContext(TSIG_KEY)
+        self.conn.response_generator = self._create_soa_response_data
+        self.assertRaises(XfrinException, self.conn._check_soa_serial)
+
+    def test_soacheck_with_unexpected_tsig_response(self):
+        # we reject unexpected TSIG in responses (following BIND 9's
+        # behavior)
+        self.soa_response_params['tsig'] = True
+        self.conn.response_generator = self._create_soa_response_data
+        self.assertRaises(XfrinException, self.conn._check_soa_serial)
+
     def test_response_shutdown(self):
         self.conn.response_generator = self._create_normal_response_data
         self.conn._shutdown_event.set()
@@ -344,6 +401,81 @@ class TestXfrinConnection(unittest.TestCase):
         self.conn.response_generator = self._create_normal_response_data
         self.assertEqual(self.conn.do_xfrin(False), XFRIN_OK)
 
+    def test_do_xfrin_with_tsig(self):
+        # use TSIG with a mock context.  we fake all verify results to
+        # emulate successful verification.
+        self.conn._tsig_ctx = MockTSIGContext(TSIG_KEY)
+        self.conn._tsig_ctx.error = TSIGError.NOERROR
+        self.conn.response_generator = self._create_normal_response_data
+        self.assertEqual(self.conn.do_xfrin(False), XFRIN_OK)
+        # We use two messages in the tests.  The same context should have been
+        # usef for both.
+        self.assertEqual(2, self.conn._tsig_ctx.verify_called)
+
+    def test_do_xfrin_with_tsig_fail(self):
+        # TSIG verify will fail for the first message.  xfrin should fail
+        # immediately.
+        self.conn._tsig_ctx = MockTSIGContext(TSIG_KEY)
+        self.conn._tsig_ctx.error = TSIGError.BAD_SIG
+        self.conn.response_generator = self._create_normal_response_data
+        self.assertEqual(self.conn.do_xfrin(False), XFRIN_FAIL)
+        self.assertEqual(1, self.conn._tsig_ctx.verify_called)
+
+    def test_do_xfrin_with_tsig_fail_for_second_message(self):
+        # Similar to the previous test, but first verify succeeds.  There
+        # should be a second verify attempt, which will fail, which should
+        # make xfrin fail.
+        def fake_tsig_error(ctx):
+            if self.conn._tsig_ctx.verify_called == 1:
+                return TSIGError.NOERROR
+            return TSIGError.BAD_SIG
+        self.conn._tsig_ctx = MockTSIGContext(TSIG_KEY)
+        self.conn._tsig_ctx.error = fake_tsig_error
+        self.conn.response_generator = self._create_normal_response_data
+        self.assertEqual(self.conn.do_xfrin(False), XFRIN_FAIL)
+        self.assertEqual(2, self.conn._tsig_ctx.verify_called)
+
+    def test_do_xfrin_with_missing_tsig(self):
+        # XFR request sent with TSIG, but the response doesn't have TSIG.
+        # xfr should fail.
+        self.conn._tsig_ctx = MockTSIGContext(TSIG_KEY)
+        self.conn.response_generator = self._create_normal_response_data
+        self.assertEqual(self.conn.do_xfrin(False), XFRIN_FAIL)
+        self.assertEqual(1, self.conn._tsig_ctx.verify_called)
+
+    def test_do_xfrin_with_missing_tsig_for_second_message(self):
+        # Similar to the previous test, but firt one contains TSIG and verify
+        # succeeds (due to fake).  The second message lacks TSIG.
+        #
+        # Note: this test case is actually not that trivial:  Skipping
+        # intermediate TSIG is allowed.  In this case, however, the second
+        # message is the last one, which must contain TSIG anyway, so the
+        # expected result is correct.  If/when we support skipping
+        # intermediate TSIGs, we'll need additional test cases.
+        def fake_tsig_error(ctx):
+            if self.conn._tsig_ctx.verify_called == 1:
+                return TSIGError.NOERROR
+            return TSIGError.FORMERR
+        self.conn._tsig_ctx = MockTSIGContext(TSIG_KEY)
+        self.conn._tsig_ctx.error = fake_tsig_error
+        self.conn.response_generator = self._create_normal_response_data
+        self.assertEqual(self.conn.do_xfrin(False), XFRIN_FAIL)
+        self.assertEqual(2, self.conn._tsig_ctx.verify_called)
+
+    def test_do_xfrin_with_unexpected_tsig(self):
+        # XFR request wasn't signed, but response includes TSIG.  Like BIND 9,
+        # we reject that.
+        self.axfr_response_params['tsig_1st'] = True
+        self.conn.response_generator = self._create_normal_response_data
+        self.assertEqual(self.conn.do_xfrin(False), XFRIN_FAIL)
+
+    def test_do_xfrin_with_unexpected_tsig_for_second_message(self):
+        # similar to the previous test, but the first message is normal.
+        # the second one contains an unexpected TSIG.  should be rejected.
+        self.axfr_response_params['tsig_2nd'] = True
+        self.conn.response_generator = self._create_normal_response_data
+        self.assertEqual(self.conn.do_xfrin(False), XFRIN_FAIL)
+
     def test_do_xfrin_empty_response(self):
         # skipping the creation of response data, so the transfer will fail.
         self.assertEqual(self.conn.do_xfrin(False), XFRIN_FAIL)
@@ -389,8 +521,10 @@ class TestXfrinConnection(unittest.TestCase):
         # This helper method creates a simple sequence of DNS messages that
         # forms a valid XFR transaction.  It consists of two messages, each
         # containing just a single SOA RR.
-        self.conn.reply_data = self.conn.create_response_data()
-        self.conn.reply_data += self.conn.create_response_data()
+        tsig_1st = self.axfr_response_params['tsig_1st']
+        tsig_2nd = self.axfr_response_params['tsig_2nd']
+        self.conn.reply_data = self.conn.create_response_data(tsig=tsig_1st)
+        self.conn.reply_data += self.conn.create_response_data(tsig=tsig_2nd)
 
     def _create_soa_response_data(self):
         # This helper method creates a DNS message that is supposed to be
@@ -401,7 +535,8 @@ class TestXfrinConnection(unittest.TestCase):
             bad_qid=self.soa_response_params['bad_qid'],
             response=self.soa_response_params['response'],
             rcode=self.soa_response_params['rcode'],
-            questions=self.soa_response_params['questions'])
+            questions=self.soa_response_params['questions'],
+            tsig=self.soa_response_params['tsig'])
         if self.soa_response_params['axfr_after_soa'] != None:
             self.conn.response_generator = self.soa_response_params['axfr_after_soa']
 

+ 25 - 4
src/bin/xfrin/xfrin.py.in

@@ -170,6 +170,22 @@ class XfrinConnection(asyncore.dispatcher):
 
         return data
 
+    def _check_response_tsig(self, msg, response_data):
+        tsig_record = msg.get_tsig_record()
+        if self._tsig_ctx is not None:
+            tsig_error = self._tsig_ctx.verify(tsig_record, response_data)
+            if tsig_error != TSIGError.NOERROR:
+                raise XfrinException('TSIG verify fail: %s' % str(tsig_error))
+        elif tsig_record is not None:
+            # If the response includes a TSIG while we didn't sign the query,
+            # we treat it as an error.  RFC doesn't say anything about this
+            # case, but it clearly states the server must not sign a response
+            # to an unsigned request.  Although we could be flexible, no sane
+            # implementation would return such a response, and since this is
+            # part of security mechanism, it's probably better to be more
+            # strict.
+            raise XfrinException('Unexpected TSIG in response')
+
     def _check_soa_serial(self):
         ''' Compare the soa serial, if soa serial in master is less than
         the soa serial in local, Finish xfrin.
@@ -177,7 +193,7 @@ class XfrinConnection(asyncore.dispatcher):
         True: soa serial in master is bigger
         '''
 
-        self._send_query(RRType("SOA"))
+        self._send_query(RRType.SOA())
         data_len = self._get_request_response(2)
         msg_len = socket.htons(struct.unpack('H', data_len)[0])
         soa_response = self._get_request_response(msg_len)
@@ -188,6 +204,9 @@ class XfrinConnection(asyncore.dispatcher):
         # strict we should be (see the comment in _check_response_header())
         self._check_response_header(msg)
 
+        # TSIG related checks, including an expected signed response
+        self._check_response_tsig(msg, soa_response)
+
         # TODO, need select soa record from data source then compare the two
         # serial, current just return OK, since this function hasn't been used
         # now.
@@ -205,8 +224,7 @@ class XfrinConnection(asyncore.dispatcher):
             logstr = 'transfer of \'%s\': AXFR ' % self._zone_name
             if ret == XFRIN_OK:
                 self.log_msg(logstr + 'started')
-                # TODO: .AXFR() RRType.AXFR()
-                self._send_query(RRType(252))
+                self._send_query(RRType.AXFR())
                 isc.datasrc.sqlite3_ds.load(self._db_file, self._zone_name,
                                             self._handle_xfrin_response)
 
@@ -277,7 +295,7 @@ class XfrinConnection(asyncore.dispatcher):
 
             for rdata in rrset.get_rdata():
                 # Count the soa record count
-                if rrset.get_type() == RRType("SOA"):
+                if rrset.get_type() == RRType.SOA():
                     self._soa_rr_count += 1
 
                     # XXX: the current DNS message parser can't preserve the
@@ -303,6 +321,9 @@ class XfrinConnection(asyncore.dispatcher):
             msg.from_wire(recvdata)
             self._check_response_status(msg)
 
+            # TSIG related checks, including an expected signed response
+            self._check_response_tsig(msg, recvdata)
+
             answer_section = msg.get_section(Message.SECTION_ANSWER)
             for rr in self._handle_answer_section(answer_section):
                 yield rr