Browse Source

[trac914] implemented TSIG verify() support in xfrin

JINMEI Tatuya 14 years ago
parent
commit
58a5aabf65
2 changed files with 170 additions and 14 deletions
  1. 145 10
      src/bin/xfrin/tests/xfrin_test.py
  2. 25 4
      src/bin/xfrin/xfrin.py.in

+ 145 - 10
src/bin/xfrin/tests/xfrin_test.py

@@ -125,10 +125,11 @@ class MockXfrinConnection(XfrinConnection):
                 self.response_generator()
                 self.response_generator()
         return len(data)
         return len(data)
 
 
-    def create_response_data(self, response = True, bad_qid = False,
-                             rcode = Rcode.NOERROR(),
-                             questions = default_questions,
-                             answers = default_answers):
+    def create_response_data(self, response=True, bad_qid=False,
+                             rcode=Rcode.NOERROR(),
+                             questions=default_questions,
+                             answers=default_answers,
+                             tsig=False):
         resp = Message(Message.RENDER)
         resp = Message(Message.RENDER)
         qid = self.qid
         qid = self.qid
         if bad_qid:
         if bad_qid:
@@ -142,7 +143,13 @@ class MockXfrinConnection(XfrinConnection):
         [resp.add_rrset(Message.SECTION_ANSWER, a) for a in answers]
         [resp.add_rrset(Message.SECTION_ANSWER, a) for a in answers]
 
 
         renderer = MessageRenderer()
         renderer = MessageRenderer()
-        resp.to_wire(renderer)
+        if tsig:
+            # for now, we don't need a valid SIG.  We only need to include
+            # TSIG RR.  So how to add it and which key is used don't matter.
+            tsig_ctx = TSIGContext(TSIG_KEY)
+            resp.to_wire(renderer, tsig_ctx)
+        else:
+            resp.to_wire(renderer)
         reply_data = struct.pack('H', socket.htons(renderer.get_length()))
         reply_data = struct.pack('H', socket.htons(renderer.get_length()))
         reply_data += renderer.get_data()
         reply_data += renderer.get_data()
 
 
@@ -157,14 +164,18 @@ class TestXfrinConnection(unittest.TestCase):
                                         TEST_RRCLASS, TEST_DB_FILE,
                                         TEST_RRCLASS, TEST_DB_FILE,
                                         threading.Event(),
                                         threading.Event(),
                                         TEST_MASTER_IPV4_ADDRINFO)
                                         TEST_MASTER_IPV4_ADDRINFO)
-        self.axfr_after_soa = False
         self.soa_response_params = {
         self.soa_response_params = {
             'questions': [example_soa_question],
             'questions': [example_soa_question],
             'bad_qid': False,
             'bad_qid': False,
             'response': True,
             'response': True,
             'rcode': Rcode.NOERROR(),
             'rcode': Rcode.NOERROR(),
+            'tsig': False,
             'axfr_after_soa': self._create_normal_response_data
             'axfr_after_soa': self._create_normal_response_data
             }
             }
+        self.axfr_response_params = {
+            'tsig_1st': False,
+            'tsig_2nd': False
+            }
 
 
     def tearDown(self):
     def tearDown(self):
         self.conn.close()
         self.conn.close()
@@ -240,7 +251,7 @@ class TestXfrinConnection(unittest.TestCase):
         self.conn.reply_data = b'aaaxxxx'
         self.conn.reply_data = b'aaaxxxx'
         self.assertRaises(XfrinTestException, self._handle_xfrin_response)
         self.assertRaises(XfrinTestException, self._handle_xfrin_response)
 
 
-    def test_response_with_tsig(self):
+    def test_response_with_tsigfail(self):
         self.conn._tsig_ctx = MockTSIGContext(TSIG_KEY)
         self.conn._tsig_ctx = MockTSIGContext(TSIG_KEY)
         # server tsig check fail, return with RCODE 9 (NOTAUTH)
         # server tsig check fail, return with RCODE 9 (NOTAUTH)
         self.conn._send_query(RRType.SOA())
         self.conn._send_query(RRType.SOA())
@@ -311,6 +322,52 @@ class TestXfrinConnection(unittest.TestCase):
         self.conn.response_generator = self._create_soa_response_data
         self.conn.response_generator = self._create_soa_response_data
         self.assertRaises(XfrinException, self.conn._check_soa_serial)
         self.assertRaises(XfrinException, self.conn._check_soa_serial)
 
 
+    def test_soacheck_with_tsig(self):
+        self.conn._tsig_ctx = MockTSIGContext(TSIG_KEY)
+        self.conn.response_generator = self._create_soa_response_data
+        # emulate a validly signed response
+        self.conn._tsig_ctx.error = TSIGError.NOERROR
+        self.assertEqual(self.conn._check_soa_serial(), XFRIN_OK)
+        self.assertEqual(self.conn._tsig_ctx.get_error(), TSIGError.NOERROR)
+
+    def test_soacheck_with_tsig_notauth(self):
+        self.conn._tsig_ctx = MockTSIGContext(TSIG_KEY)
+
+        # emulate a valid error response
+        self.soa_response_params['rcode'] = Rcode.NOTAUTH()
+        self.conn.response_generator = self._create_soa_response_data
+        self.conn._tsig_ctx.error = TSIGError.BAD_SIG
+
+        self.assertRaises(XfrinException, self.conn._check_soa_serial)
+
+    def test_soacheck_with_tsig_noerror_badsig(self):
+        self.conn._tsig_ctx = MockTSIGContext(TSIG_KEY)
+
+        # emulate a normal response bad verification failure due to BADSIG.
+        # According RFC2845, in this case we should ignore it and keep
+        # waiting for a valid response until a timeout.  But we immediately
+        # treat this as a final failure (just as BIND 9 does).
+        self.conn.response_generator = self._create_soa_response_data
+        self.conn._tsig_ctx.error = TSIGError.BAD_SIG
+
+        self.assertRaises(XfrinException, self.conn._check_soa_serial)
+
+    def test_soacheck_with_tsig_unsigned_response(self):
+        # we can use a real TSIGContext for this.  the response doesn't
+        # contain a TSIG while we sent a signed query.  RFC2845 states
+        # we should wait for a valid response in this case, but we treat
+        # it as a fatal transaction failure, too.
+        self.conn._tsig_ctx = TSIGContext(TSIG_KEY)
+        self.conn.response_generator = self._create_soa_response_data
+        self.assertRaises(XfrinException, self.conn._check_soa_serial)
+
+    def test_soacheck_with_unexpected_tsig_response(self):
+        # we reject unexpected TSIG in responses (following BIND 9's
+        # behavior)
+        self.soa_response_params['tsig'] = True
+        self.conn.response_generator = self._create_soa_response_data
+        self.assertRaises(XfrinException, self.conn._check_soa_serial)
+
     def test_response_shutdown(self):
     def test_response_shutdown(self):
         self.conn.response_generator = self._create_normal_response_data
         self.conn.response_generator = self._create_normal_response_data
         self.conn._shutdown_event.set()
         self.conn._shutdown_event.set()
@@ -344,6 +401,81 @@ class TestXfrinConnection(unittest.TestCase):
         self.conn.response_generator = self._create_normal_response_data
         self.conn.response_generator = self._create_normal_response_data
         self.assertEqual(self.conn.do_xfrin(False), XFRIN_OK)
         self.assertEqual(self.conn.do_xfrin(False), XFRIN_OK)
 
 
+    def test_do_xfrin_with_tsig(self):
+        # use TSIG with a mock context.  we fake all verify results to
+        # emulate successful verification.
+        self.conn._tsig_ctx = MockTSIGContext(TSIG_KEY)
+        self.conn._tsig_ctx.error = TSIGError.NOERROR
+        self.conn.response_generator = self._create_normal_response_data
+        self.assertEqual(self.conn.do_xfrin(False), XFRIN_OK)
+        # We use two messages in the tests.  The same context should have been
+        # usef for both.
+        self.assertEqual(2, self.conn._tsig_ctx.verify_called)
+
+    def test_do_xfrin_with_tsig_fail(self):
+        # TSIG verify will fail for the first message.  xfrin should fail
+        # immediately.
+        self.conn._tsig_ctx = MockTSIGContext(TSIG_KEY)
+        self.conn._tsig_ctx.error = TSIGError.BAD_SIG
+        self.conn.response_generator = self._create_normal_response_data
+        self.assertEqual(self.conn.do_xfrin(False), XFRIN_FAIL)
+        self.assertEqual(1, self.conn._tsig_ctx.verify_called)
+
+    def test_do_xfrin_with_tsig_fail_for_second_message(self):
+        # Similar to the previous test, but first verify succeeds.  There
+        # should be a second verify attempt, which will fail, which should
+        # make xfrin fail.
+        def fake_tsig_error(ctx):
+            if self.conn._tsig_ctx.verify_called == 1:
+                return TSIGError.NOERROR
+            return TSIGError.BAD_SIG
+        self.conn._tsig_ctx = MockTSIGContext(TSIG_KEY)
+        self.conn._tsig_ctx.error = fake_tsig_error
+        self.conn.response_generator = self._create_normal_response_data
+        self.assertEqual(self.conn.do_xfrin(False), XFRIN_FAIL)
+        self.assertEqual(2, self.conn._tsig_ctx.verify_called)
+
+    def test_do_xfrin_with_missing_tsig(self):
+        # XFR request sent with TSIG, but the response doesn't have TSIG.
+        # xfr should fail.
+        self.conn._tsig_ctx = MockTSIGContext(TSIG_KEY)
+        self.conn.response_generator = self._create_normal_response_data
+        self.assertEqual(self.conn.do_xfrin(False), XFRIN_FAIL)
+        self.assertEqual(1, self.conn._tsig_ctx.verify_called)
+
+    def test_do_xfrin_with_missing_tsig_for_second_message(self):
+        # Similar to the previous test, but firt one contains TSIG and verify
+        # succeeds (due to fake).  The second message lacks TSIG.
+        #
+        # Note: this test case is actually not that trivial:  Skipping
+        # intermediate TSIG is allowed.  In this case, however, the second
+        # message is the last one, which must contain TSIG anyway, so the
+        # expected result is correct.  If/when we support skipping
+        # intermediate TSIGs, we'll need additional test cases.
+        def fake_tsig_error(ctx):
+            if self.conn._tsig_ctx.verify_called == 1:
+                return TSIGError.NOERROR
+            return TSIGError.FORMERR
+        self.conn._tsig_ctx = MockTSIGContext(TSIG_KEY)
+        self.conn._tsig_ctx.error = fake_tsig_error
+        self.conn.response_generator = self._create_normal_response_data
+        self.assertEqual(self.conn.do_xfrin(False), XFRIN_FAIL)
+        self.assertEqual(2, self.conn._tsig_ctx.verify_called)
+
+    def test_do_xfrin_with_unexpected_tsig(self):
+        # XFR request wasn't signed, but response includes TSIG.  Like BIND 9,
+        # we reject that.
+        self.axfr_response_params['tsig_1st'] = True
+        self.conn.response_generator = self._create_normal_response_data
+        self.assertEqual(self.conn.do_xfrin(False), XFRIN_FAIL)
+
+    def test_do_xfrin_with_unexpected_tsig_for_second_message(self):
+        # similar to the previous test, but the first message is normal.
+        # the second one contains an unexpected TSIG.  should be rejected.
+        self.axfr_response_params['tsig_2nd'] = True
+        self.conn.response_generator = self._create_normal_response_data
+        self.assertEqual(self.conn.do_xfrin(False), XFRIN_FAIL)
+
     def test_do_xfrin_empty_response(self):
     def test_do_xfrin_empty_response(self):
         # skipping the creation of response data, so the transfer will fail.
         # skipping the creation of response data, so the transfer will fail.
         self.assertEqual(self.conn.do_xfrin(False), XFRIN_FAIL)
         self.assertEqual(self.conn.do_xfrin(False), XFRIN_FAIL)
@@ -389,8 +521,10 @@ class TestXfrinConnection(unittest.TestCase):
         # This helper method creates a simple sequence of DNS messages that
         # This helper method creates a simple sequence of DNS messages that
         # forms a valid XFR transaction.  It consists of two messages, each
         # forms a valid XFR transaction.  It consists of two messages, each
         # containing just a single SOA RR.
         # containing just a single SOA RR.
-        self.conn.reply_data = self.conn.create_response_data()
-        self.conn.reply_data += self.conn.create_response_data()
+        tsig_1st = self.axfr_response_params['tsig_1st']
+        tsig_2nd = self.axfr_response_params['tsig_2nd']
+        self.conn.reply_data = self.conn.create_response_data(tsig=tsig_1st)
+        self.conn.reply_data += self.conn.create_response_data(tsig=tsig_2nd)
 
 
     def _create_soa_response_data(self):
     def _create_soa_response_data(self):
         # This helper method creates a DNS message that is supposed to be
         # This helper method creates a DNS message that is supposed to be
@@ -401,7 +535,8 @@ class TestXfrinConnection(unittest.TestCase):
             bad_qid=self.soa_response_params['bad_qid'],
             bad_qid=self.soa_response_params['bad_qid'],
             response=self.soa_response_params['response'],
             response=self.soa_response_params['response'],
             rcode=self.soa_response_params['rcode'],
             rcode=self.soa_response_params['rcode'],
-            questions=self.soa_response_params['questions'])
+            questions=self.soa_response_params['questions'],
+            tsig=self.soa_response_params['tsig'])
         if self.soa_response_params['axfr_after_soa'] != None:
         if self.soa_response_params['axfr_after_soa'] != None:
             self.conn.response_generator = self.soa_response_params['axfr_after_soa']
             self.conn.response_generator = self.soa_response_params['axfr_after_soa']
 
 

+ 25 - 4
src/bin/xfrin/xfrin.py.in

@@ -170,6 +170,22 @@ class XfrinConnection(asyncore.dispatcher):
 
 
         return data
         return data
 
 
+    def _check_response_tsig(self, msg, response_data):
+        tsig_record = msg.get_tsig_record()
+        if self._tsig_ctx is not None:
+            tsig_error = self._tsig_ctx.verify(tsig_record, response_data)
+            if tsig_error != TSIGError.NOERROR:
+                raise XfrinException('TSIG verify fail: %s' % str(tsig_error))
+        elif tsig_record is not None:
+            # If the response includes a TSIG while we didn't sign the query,
+            # we treat it as an error.  RFC doesn't say anything about this
+            # case, but it clearly states the server must not sign a response
+            # to an unsigned request.  Although we could be flexible, no sane
+            # implementation would return such a response, and since this is
+            # part of security mechanism, it's probably better to be more
+            # strict.
+            raise XfrinException('Unexpected TSIG in response')
+
     def _check_soa_serial(self):
     def _check_soa_serial(self):
         ''' Compare the soa serial, if soa serial in master is less than
         ''' Compare the soa serial, if soa serial in master is less than
         the soa serial in local, Finish xfrin.
         the soa serial in local, Finish xfrin.
@@ -177,7 +193,7 @@ class XfrinConnection(asyncore.dispatcher):
         True: soa serial in master is bigger
         True: soa serial in master is bigger
         '''
         '''
 
 
-        self._send_query(RRType("SOA"))
+        self._send_query(RRType.SOA())
         data_len = self._get_request_response(2)
         data_len = self._get_request_response(2)
         msg_len = socket.htons(struct.unpack('H', data_len)[0])
         msg_len = socket.htons(struct.unpack('H', data_len)[0])
         soa_response = self._get_request_response(msg_len)
         soa_response = self._get_request_response(msg_len)
@@ -188,6 +204,9 @@ class XfrinConnection(asyncore.dispatcher):
         # strict we should be (see the comment in _check_response_header())
         # strict we should be (see the comment in _check_response_header())
         self._check_response_header(msg)
         self._check_response_header(msg)
 
 
+        # TSIG related checks, including an expected signed response
+        self._check_response_tsig(msg, soa_response)
+
         # TODO, need select soa record from data source then compare the two
         # TODO, need select soa record from data source then compare the two
         # serial, current just return OK, since this function hasn't been used
         # serial, current just return OK, since this function hasn't been used
         # now.
         # now.
@@ -205,8 +224,7 @@ class XfrinConnection(asyncore.dispatcher):
             logstr = 'transfer of \'%s\': AXFR ' % self._zone_name
             logstr = 'transfer of \'%s\': AXFR ' % self._zone_name
             if ret == XFRIN_OK:
             if ret == XFRIN_OK:
                 self.log_msg(logstr + 'started')
                 self.log_msg(logstr + 'started')
-                # TODO: .AXFR() RRType.AXFR()
-                self._send_query(RRType(252))
+                self._send_query(RRType.AXFR())
                 isc.datasrc.sqlite3_ds.load(self._db_file, self._zone_name,
                 isc.datasrc.sqlite3_ds.load(self._db_file, self._zone_name,
                                             self._handle_xfrin_response)
                                             self._handle_xfrin_response)
 
 
@@ -277,7 +295,7 @@ class XfrinConnection(asyncore.dispatcher):
 
 
             for rdata in rrset.get_rdata():
             for rdata in rrset.get_rdata():
                 # Count the soa record count
                 # Count the soa record count
-                if rrset.get_type() == RRType("SOA"):
+                if rrset.get_type() == RRType.SOA():
                     self._soa_rr_count += 1
                     self._soa_rr_count += 1
 
 
                     # XXX: the current DNS message parser can't preserve the
                     # XXX: the current DNS message parser can't preserve the
@@ -303,6 +321,9 @@ class XfrinConnection(asyncore.dispatcher):
             msg.from_wire(recvdata)
             msg.from_wire(recvdata)
             self._check_response_status(msg)
             self._check_response_status(msg)
 
 
+            # TSIG related checks, including an expected signed response
+            self._check_response_tsig(msg, recvdata)
+
             answer_section = msg.get_section(Message.SECTION_ANSWER)
             answer_section = msg.get_section(Message.SECTION_ANSWER)
             for rr in self._handle_answer_section(answer_section):
             for rr in self._handle_answer_section(answer_section):
                 yield rr
                 yield rr