Michal 'vorner' Vaner 12 years ago
parent
commit
7ca65cb9ec

+ 42 - 6
src/bin/xfrin/tests/xfrin_test.py

@@ -564,6 +564,28 @@ class TestXfrinIXFRAdd(TestXfrinState):
         self.assertEqual(type(XfrinIXFRDeleteSOA()),
                          type(self.conn.get_xfrstate()))
 
+    def test_handle_new_delete_missing_sig(self):
+        self.conn._end_serial = isc.dns.Serial(1234)
+        # SOA RR whose serial is the current one means we are going to a new
+        # difference, starting with removing that SOA.
+        self.conn._diff.add_data(self.ns_rrset) # put some dummy change
+        self.conn._tsig_ctx = MockTSIGContext(TSIG_KEY)
+        self.conn._tsig_ctx.last_has_signature = lambda: False
+        # First, push a starting SOA inside. This should be OK, nothing checked
+        # yet.
+        self.state.handle_rr(self.conn, self.begin_soa)
+        end_soa_rdata = Rdata(RRType.SOA(), TEST_RRCLASS,
+                              'm. r. 1234 0 0 0 0')
+        end_soa_rrset = RRset(TEST_ZONE_NAME, TEST_RRCLASS, RRType.SOA(),
+                                RRTTL(3600))
+        end_soa_rrset.add_rdata(end_soa_rdata)
+        # This would try to finish up. But the TSIG pretends not everything is
+        # signed, rejecting it.
+        self.assertRaises(xfrin.XfrinProtocolError, self.state.handle_rr,
+                          self.conn, end_soa_rrset)
+        # No diffs were commited
+        self.assertEqual([], self.conn._datasrc_client.committed_diffs)
+
     def test_handle_out_of_sync(self):
         # getting SOA with an inconsistent serial.  This is an error.
         self.conn._end_serial = isc.dns.Serial(1235)
@@ -792,12 +814,14 @@ class TestAXFR(TestXfrinConnection):
     def tearDown(self):
         time.time = self.orig_time_time
 
-    def __create_mock_tsig(self, key, error):
+    def __create_mock_tsig(self, key, error, has_last_signature=True):
         # This helper function creates a MockTSIGContext for a given key
         # and TSIG error to be used as a result of verify (normally faked
         # one)
         mock_ctx = MockTSIGContext(key)
         mock_ctx.error = error
+        if not has_last_signature:
+            mock_ctx.last_has_signature = lambda: False
         return mock_ctx
 
     def __match_exception(self, expected_exception, expected_msg, expression):
@@ -1379,6 +1403,16 @@ class TestAXFR(TestXfrinConnection):
         self.assertEqual(self.conn.do_xfrin(False), XFRIN_FAIL)
         self.assertEqual(1, self.conn._tsig_ctx.verify_called)
 
+    def test_do_xfrin_without_last_tsig(self):
+        # TSIG verify will succeed, but it will pretend the last message is
+        # not signed.
+        self.conn._tsig_key = TSIG_KEY
+        self.conn._tsig_ctx_creator = \
+            lambda key: self.__create_mock_tsig(key, TSIGError.NOERROR, False)
+        self.conn.response_generator = self._create_normal_response_data
+        self.assertEqual(self.conn.do_xfrin(False), XFRIN_FAIL)
+        self.assertEqual(2, self.conn._tsig_ctx.verify_called)
+
     def test_do_xfrin_with_tsig_fail_for_second_message(self):
         # Similar to the previous test, but first verify succeeds.  There
         # should be a second verify attempt, which will fail, which should
@@ -1553,16 +1587,18 @@ class TestIXFRResponse(TestXfrinConnection):
         self.conn._handle_xfrin_responses()
         self.assertEqual(type(XfrinIXFREnd()), type(self.conn.get_xfrstate()))
         self.assertEqual([], self.conn._datasrc_client.diffs)
+        # Everything is committed as one bunch, currently we commit at the very
+        # end.
         check_diffs(self.assertEqual,
                     [[('delete', begin_soa_rrset),
                       ('delete', self._create_a('192.0.2.1')),
                       ('add', self._create_soa('1231')),
-                      ('add', self._create_a('192.0.2.2'))],
-                     [('delete', self._create_soa('1231')),
+                      ('add', self._create_a('192.0.2.2')),
+                      ('delete', self._create_soa('1231')),
                       ('delete', self._create_a('192.0.2.3')),
                       ('add', self._create_soa('1232')),
-                      ('add', self._create_a('192.0.2.4'))],
-                     [('delete', self._create_soa('1232')),
+                      ('add', self._create_a('192.0.2.4')),
+                      ('delete', self._create_soa('1232')),
                       ('delete', self._create_a('192.0.2.5')),
                       ('add', soa_rrset),
                       ('add', self._create_a('192.0.2.6'))]],
@@ -2924,7 +2960,7 @@ class TestFormatting(unittest.TestCase):
         self.assertEqual("example.org/IN",
                          format_zone_str(isc.dns.Name("example.org"),
                          isc.dns.RRClass("IN")))
-    
+
     def test_format_addrinfo(self):
         # This test may need to be updated if the input type is changed,
         # right now it is a nested tuple:

+ 22 - 3
src/bin/xfrin/xfrin.py.in

@@ -362,6 +362,7 @@ class XfrinFirstData(XfrinState):
                 conn._request_serial == get_soa_serial(rr.get_rdata()[0]):
             logger.debug(DBG_XFRIN_TRACE, XFRIN_GOT_INCREMENTAL_RESP,
                          conn.zone_str())
+            conn._diff = None # Will be created on-demand
             self.set_xfrstate(conn, XfrinIXFRDeleteSOA())
         else:
             logger.debug(DBG_XFRIN_TRACE, XFRIN_GOT_NONINCREMENTAL_RESP,
@@ -380,11 +381,13 @@ class XfrinIXFRDeleteSOA(XfrinState):
             raise XfrinException(rr.get_type().to_text() +
                                  ' RR is given in IXFRDeleteSOA state')
         # This is the beginning state of one difference sequence (changes
-        # for one SOA update).  We need to create a new Diff object now.
+        # for one SOA update).  We may need to create a new Diff object now.
         # Note also that we (unconditionally) enable journaling here.  The
         # Diff constructor may internally disable it, however, if the
         # underlying data source doesn't support journaling.
-        conn._diff = Diff(conn._datasrc_client, conn._zone_name, False, True)
+        if conn._diff is None:
+            conn._diff = Diff(conn._datasrc_client, conn._zone_name, False,
+                              True)
         conn._diff.delete_data(rr)
         self.set_xfrstate(conn, XfrinIXFRDelete())
         conn.get_transfer_stats().ixfr_deletion_count += 1
@@ -420,6 +423,9 @@ class XfrinIXFRAdd(XfrinState):
             conn.get_transfer_stats().ixfr_changeset_count += 1
             soa_serial = get_soa_serial(rr.get_rdata()[0])
             if soa_serial == conn._end_serial:
+                # The final part is there. Check all was signed
+                # and commit it to the database.
+                conn._check_response_tsig_last()
                 conn._diff.commit()
                 self.set_xfrstate(conn, XfrinIXFREnd())
                 return True
@@ -429,7 +435,10 @@ class XfrinIXFRAdd(XfrinState):
                                          str(conn._current_serial) +
                                          ', got ' + str(soa_serial))
             else:
-                conn._diff.commit()
+                # Apply a change to the database. But don't commit it yet,
+                # we can't know if the message is/will be properly signed.
+                # A complete commit will happen after the last bit.
+                conn._diff.apply()
                 self.set_xfrstate(conn, XfrinIXFRDeleteSOA())
                 return False
         conn._diff.add_data(rr)
@@ -494,6 +503,7 @@ class XfrinAXFREnd(XfrinState):
         indicating there will be no more message to receive.
 
         """
+        conn._check_response_tsig_last()
         conn._diff.commit()
         return False
 
@@ -782,6 +792,15 @@ class XfrinConnection(asyncore.dispatcher):
             # strict.
             raise XfrinProtocolError('Unexpected TSIG in response')
 
+    def _check_response_tsig_last(self):
+        """
+        Check there's a signature at the last message.
+        """
+        if self._tsig_ctx is not None:
+            if not self._tsig_ctx.last_has_signature():
+                raise XfrinProtocolError('TSIG verify fail: no TSIG on last '+
+                                         'message')
+
     def __parse_soa_response(self, msg, response_data):
         '''Parse a response to SOA query and extract the SOA from answer.
 

+ 9 - 0
src/lib/dns/message.cc

@@ -254,6 +254,13 @@ MessageImpl::toWire(AbstractMessageRenderer& renderer, TSIGContext* tsig_ctx) {
     const size_t orig_msg_len_limit = renderer.getLengthLimit();
     const AbstractMessageRenderer::CompressMode orig_compress_mode =
         renderer.getCompressMode();
+
+    // We are going to skip soon, so we need to clear the renderer
+    // But we'll leave the length limit  and the compress mode intact
+    // (or shortened in case of TSIG)
+    renderer.clear();
+    renderer.setCompressMode(orig_compress_mode);
+
     if (tsig_len > 0) {
         if (tsig_len > orig_msg_len_limit) {
             isc_throw(InvalidParameter, "Failed to render DNS message: "
@@ -261,6 +268,8 @@ MessageImpl::toWire(AbstractMessageRenderer& renderer, TSIGContext* tsig_ctx) {
                       orig_msg_len_limit << ")");
         }
         renderer.setLengthLimit(orig_msg_len_limit - tsig_len);
+    } else {
+        renderer.setLengthLimit(orig_msg_len_limit);
     }
 
     // reserve room for the header

+ 8 - 0
src/lib/dns/message.h

@@ -556,6 +556,10 @@ public:
     /// \c Rcode must have been set beforehand; otherwise, an exception of
     /// class \c InvalidMessageOperation will be thrown.
     ///
+    /// \note The renderer's internal buffers and data are automatically
+    /// cleared, keeping the length limit and the compression mode intact.
+    /// In case truncation is triggered, the renderer is cleared completely.
+    ///
     /// \param renderer DNS message rendering context that encapsulates the
     /// output buffer and name compression information.
     void toWire(AbstractMessageRenderer& renderer);
@@ -581,6 +585,10 @@ public:
     /// it should mean a bug either in the TSIG context or in the renderer
     /// implementation.
     ///
+    /// \note The renderer's internal buffers and data are automatically
+    /// cleared, keeping the length limit and the compression mode intact.
+    /// In case truncation is triggered, the renderer is cleared completely.
+    ///
     /// \param renderer See the other version
     /// \param tsig_ctx A TSIG context that is to be used for signing the
     /// message

+ 7 - 7
src/lib/dns/python/tests/message_python_test.py

@@ -453,7 +453,7 @@ class MessageTest(unittest.TestCase):
 
     def test_to_text(self):
         message_render = create_message()
-        
+
         msg_str =\
 """;; ->>HEADER<<- opcode: QUERY, status: NOERROR, id: 4149
 ;; flags: qr aa rd; QUERY: 1, ANSWER: 2, AUTHORITY: 0, ADDITIONAL: 0
@@ -484,7 +484,7 @@ test.example.com. 3600 IN A 192.0.2.2
                           Message.from_wire, self.p, bytes())
 
         test_name = Name("test.example.com");
-        
+
         message_parse = Message(0)
         factoryFromFile(message_parse, "message_fromWire1")
         self.assertEqual(0x1035, message_parse.get_qid())
@@ -493,7 +493,7 @@ test.example.com. 3600 IN A 192.0.2.2
         self.assertTrue(message_parse.get_header_flag(Message.HEADERFLAG_QR))
         self.assertTrue(message_parse.get_header_flag(Message.HEADERFLAG_RD))
         self.assertTrue(message_parse.get_header_flag(Message.HEADERFLAG_AA))
-    
+
         #QuestionPtr q = *message_parse.beginQuestion()
         q = message_parse.get_question()[0]
         self.assertEqual(test_name, q.get_name())
@@ -503,7 +503,7 @@ test.example.com. 3600 IN A 192.0.2.2
         self.assertEqual(2, message_parse.get_rr_count(Message.SECTION_ANSWER))
         self.assertEqual(0, message_parse.get_rr_count(Message.SECTION_AUTHORITY))
         self.assertEqual(0, message_parse.get_rr_count(Message.SECTION_ADDITIONAL))
-    
+
         #RRsetPtr rrset = *message_parse.beginSection(Message.SECTION_ANSWER)
         rrset = message_parse.get_section(Message.SECTION_ANSWER)[0]
         self.assertEqual(test_name, rrset.get_name())
@@ -569,12 +569,12 @@ test.example.com. 3600 IN A 192.0.2.2
         message_parse = Message(Message.PARSE)
         factoryFromFile(message_parse, "message_fromWire10.wire")
         self.assertEqual(Rcode.BADVERS(), message_parse.get_rcode())
-    
+
         # Maximum extended Rcode
         message_parse.clear(Message.PARSE)
         factoryFromFile(message_parse, "message_fromWire11.wire")
         self.assertEqual(0xfff, message_parse.get_rcode().get_code())
-    
+
     def test_BadEDNS0(self):
         message_parse = Message(Message.PARSE)
         # OPT RR in the answer section
@@ -596,7 +596,7 @@ test.example.com. 3600 IN A 192.0.2.2
                           factoryFromFile,
                           message_parse,
                           "message_fromWire6")
-                          
+
         # Compressed owner name of OPT RR points to a root name.
         # Not necessarily bogus, but very unusual and mostly pathological.
         # We accept it, but is it okay?

+ 13 - 4
src/lib/dns/python/tests/tsig_python_test.py

@@ -122,15 +122,23 @@ class TSIGContextTest(unittest.TestCase):
         # And there should be no error code.
         self.assertEqual(TSIGError(Rcode.NOERROR()), self.tsig_ctx.get_error())
 
+        # No message signed yet
+        self.assertRaises(TSIGContextError, self.tsig_ctx.last_had_signature)
+
     # Note: intentionally use camelCase so that we can easily copy-paste
     # corresponding C++ tests.
     def commonVerifyChecks(self, ctx, record, data, expected_error,
                            expected_new_state=\
-                               TSIGContext.STATE_VERIFIED_RESPONSE):
+                               TSIGContext.STATE_VERIFIED_RESPONSE,
+                           last_should_throw=False):
         self.assertEqual(expected_error, ctx.verify(record, data))
         self.assertEqual(expected_error, ctx.get_error())
         self.assertEqual(expected_new_state, ctx.get_state())
-
+        if last_should_throw:
+            self.assertRaises(TSIGContextError, ctx.last_had_signature)
+        else:
+            self.assertEqual(record is not None,
+                             ctx.last_had_signature())
     def test_from_keyring(self):
         # Construct a TSIG context with an empty key ring.  Key shouldn't be
         # found, and the BAD_KEY error should be recorded.
@@ -354,7 +362,7 @@ class TSIGContextTest(unittest.TestCase):
 
         tsig = self.createMessageAndSign(self.qid, self.test_name,
                                          self.tsig_ctx, 0, RRType.SOA())
-                                         
+
         fix_current_time(0x4da8b9d6 + 301)
         self.assertEqual(TSIGError.BAD_TIME,
                          self.tsig_verify_ctx.verify(tsig, DUMMY_DATA))
@@ -454,7 +462,8 @@ class TSIGContextTest(unittest.TestCase):
         self.createMessageAndSign(self.qid, self.test_name, self.tsig_ctx)
 
         self.commonVerifyChecks(self.tsig_ctx, None, DUMMY_DATA,
-                           TSIGError.FORMERR, TSIGContext.STATE_SENT_REQUEST)
+                           TSIGError.FORMERR, TSIGContext.STATE_SENT_REQUEST,
+                           True)
 
         self.createMessageFromFile("tsig_verify5.wire")
         self.commonVerifyChecks(self.tsig_ctx, self.message.get_tsig_record(),

+ 24 - 0
src/lib/dns/python/tsig_python.cc

@@ -66,6 +66,7 @@ PyObject* TSIGContext_getState(s_TSIGContext* self);
 PyObject* TSIGContext_getError(s_TSIGContext* self);
 PyObject* TSIGContext_sign(s_TSIGContext* self, PyObject* args);
 PyObject* TSIGContext_verify(s_TSIGContext* self, PyObject* args);
+PyObject* TSIGContext_lastHadSignature(s_TSIGContext* self);
 
 // These are the functions we export
 // For a minimal support, we don't need them.
@@ -89,6 +90,9 @@ PyMethodDef TSIGContext_methods[] = {
     { "verify",
       reinterpret_cast<PyCFunction>(TSIGContext_verify), METH_VARARGS,
       "Verify a DNS message." },
+    { "last_had_signature",
+      reinterpret_cast<PyCFunction>(TSIGContext_lastHadSignature), METH_NOARGS,
+      "Return True if the last verified message contained a signature" },
     { NULL, NULL, 0, NULL }
 };
 
@@ -234,6 +238,26 @@ TSIGContext_verify(s_TSIGContext* self, PyObject* args) {
 
     return (NULL);
 }
+
+PyObject*
+TSIGContext_lastHadSignature(s_TSIGContext* self) {
+    try {
+        long result = self->cppobj->lastHadSignature();
+        return (PyBool_FromLong(result));
+    } catch (const TSIGContextError& ex) {
+        PyErr_SetString(po_TSIGContextError, ex.what());
+    } catch (const exception& ex) {
+        const string ex_what =
+            "Unexpected failure in TSIG lastHadSignature: " +
+            string(ex.what());
+        PyErr_SetString(po_IscException, ex_what.c_str());
+    } catch (...) {
+        PyErr_SetString(PyExc_SystemError,
+                        "Unexpected failure in TSIG lastHadSignature");
+    }
+
+    return (NULL);
+}
 } // end of unnamed namespace
 
 namespace isc {

+ 2 - 2
src/lib/dns/tests/message_unittest.cc

@@ -534,7 +534,7 @@ TEST_F(MessageTest, appendSection) {
         RRClass::IN(), RRType::A()));
     EXPECT_TRUE(target.hasRRset(Message::SECTION_ANSWER, test_name,
         RRClass::IN(), RRType::AAAA()));
-    
+
 }
 
 TEST_F(MessageTest, parseHeader) {
@@ -1091,7 +1091,7 @@ TEST_F(MessageTest, toWireWithoutRcode) {
 TEST_F(MessageTest, toText) {
     // Check toText() output for a typical DNS response with records in
     // all sections
-    
+
     factoryFromFile(message_parse, "message_toText1.wire");
     {
         SCOPED_TRACE("Message toText test (basic case)");

+ 184 - 34
src/lib/dns/tests/tsig_unittest.cc

@@ -66,6 +66,22 @@ testGetTime() {
     return (NOW);
 }
 
+// Thin wrapper around TSIGContext to allow access to the
+// update method.
+class TestTSIGContext : public TSIGContext {
+public:
+    TestTSIGContext(const TSIGKey& key) :
+        TSIGContext(key)
+    {}
+    TestTSIGContext(const Name& key_name, const Name& algorithm_name,
+                    const TSIGKeyRing& keyring) :
+        TSIGContext(key_name, algorithm_name, keyring)
+    {}
+    void update(const void* const data, size_t len) {
+        TSIGContext::update(data, len);
+    }
+};
+
 class TSIGTest : public ::testing::Test {
 protected:
     TSIGTest() :
@@ -83,9 +99,10 @@ protected:
         isc::util::detail::gettimeFunction = NULL;
 
         decodeBase64("SFuWd/q99SzF8Yzd1QbB9g==", secret);
-        tsig_ctx.reset(new TSIGContext(TSIGKey(test_name,
-                                               TSIGKey::HMACMD5_NAME(),
-                                               &secret[0], secret.size())));
+        tsig_ctx.reset(new TestTSIGContext(TSIGKey(test_name,
+                                                   TSIGKey::HMACMD5_NAME(),
+                                                   &secret[0],
+                                                   secret.size())));
         tsig_verify_ctx.reset(new TSIGContext(TSIGKey(test_name,
                                                       TSIGKey::HMACMD5_NAME(),
                                                       &secret[0],
@@ -116,7 +133,7 @@ protected:
     static const unsigned int AA_FLAG = 0x2;
     static const unsigned int RD_FLAG = 0x4;
 
-    boost::scoped_ptr<TSIGContext> tsig_ctx;
+    boost::scoped_ptr<TestTSIGContext> tsig_ctx;
     boost::scoped_ptr<TSIGContext> tsig_verify_ctx;
     TSIGKeyRing keyring;
     const uint16_t qid;
@@ -166,16 +183,20 @@ TSIGTest::createMessageAndSign(uint16_t id, const Name& qname,
         message.addRRset(Message::SECTION_ANSWER, answer_rrset);
     }
     renderer.clear();
-    message.toWire(renderer);
 
     TSIGContext::State expected_new_state =
         (ctx->getState() == TSIGContext::INIT) ?
         TSIGContext::SENT_REQUEST : TSIGContext::SENT_RESPONSE;
-    ConstTSIGRecordPtr tsig = ctx->sign(id, renderer.getData(),
-                                        renderer.getLength());
+
+    message.toWire(renderer, *ctx);
+
+    message.clear(Message::PARSE);
+    InputBuffer buffer(renderer.getData(), renderer.getLength());
+    message.fromWire(buffer);
+
     EXPECT_EQ(expected_new_state, ctx->getState());
 
-    return (tsig);
+    return (ConstTSIGRecordPtr(new TSIGRecord(*message.getTSIGRecord())));
 }
 
 void
@@ -218,11 +239,17 @@ void
 commonVerifyChecks(TSIGContext& ctx, const TSIGRecord* record,
                    const void* data, size_t data_len, TSIGError expected_error,
                    TSIGContext::State expected_new_state =
-                   TSIGContext::VERIFIED_RESPONSE)
+                   TSIGContext::VERIFIED_RESPONSE,
+                   bool last_should_throw = false)
 {
     EXPECT_EQ(expected_error, ctx.verify(record, data, data_len));
     EXPECT_EQ(expected_error, ctx.getError());
     EXPECT_EQ(expected_new_state, ctx.getState());
+    if (last_should_throw) {
+        EXPECT_THROW(ctx.lastHadSignature(), TSIGContextError);
+    } else {
+        EXPECT_EQ(record != NULL, ctx.lastHadSignature());
+    }
 }
 
 TEST_F(TSIGTest, initialState) {
@@ -231,6 +258,9 @@ TEST_F(TSIGTest, initialState) {
 
     // And there should be no error code.
     EXPECT_EQ(TSIGError(Rcode::NOERROR()), tsig_ctx->getError());
+
+    // Nothing verified yet
+    EXPECT_THROW(tsig_ctx->lastHadSignature(), TSIGContextError);
 }
 
 TEST_F(TSIGTest, constructFromKeyRing) {
@@ -354,10 +384,17 @@ TEST_F(TSIGTest, verifyBadData) {
                                   12 + dummy_record.getLength() - 1),
                  InvalidParameter);
 
+    // Still nothing verified
+    EXPECT_THROW(tsig_ctx->lastHadSignature(), TSIGContextError);
+
     // And the data must not be NULL.
     EXPECT_THROW(tsig_ctx->verify(&dummy_record, NULL,
                                   12 + dummy_record.getLength()),
                  InvalidParameter);
+
+    // Still nothing verified
+    EXPECT_THROW(tsig_ctx->lastHadSignature(), TSIGContextError);
+
 }
 
 #ifdef ENABLE_CUSTOM_OPERATOR_NEW
@@ -726,8 +763,8 @@ TEST_F(TSIGTest, badsigResponse) {
 TEST_F(TSIGTest, badkeyResponse) {
     // A similar test as badsigResponse but for BADKEY
     isc::util::detail::gettimeFunction = testGetTime<0x4da8877a>;
-    tsig_ctx.reset(new TSIGContext(badkey_name, TSIGKey::HMACMD5_NAME(),
-                                   keyring));
+    tsig_ctx.reset(new TestTSIGContext(badkey_name, TSIGKey::HMACMD5_NAME(),
+                                       keyring));
     {
         SCOPED_TRACE("Verify resulting in BADKEY");
         commonVerifyChecks(*tsig_ctx, &dummy_record, &dummy_data[0],
@@ -806,7 +843,7 @@ TEST_F(TSIGTest, nosigThenValidate) {
         SCOPED_TRACE("Verify a response without TSIG that should exist");
         commonVerifyChecks(*tsig_ctx, NULL, &dummy_data[0],
                            dummy_data.size(), TSIGError::FORMERR(),
-                           TSIGContext::SENT_REQUEST);
+                           TSIGContext::SENT_REQUEST, true);
     }
 
     createMessageFromFile("tsig_verify5.wire");
@@ -936,45 +973,47 @@ TEST_F(TSIGTest, getTSIGLength) {
     EXPECT_EQ(85, tsig_ctx->getTSIGLength());
 
     // hmac-sha1: n2=11, x=20
-    tsig_ctx.reset(new TSIGContext(TSIGKey(test_name, TSIGKey::HMACSHA1_NAME(),
-                                           &dummy_data[0], 20)));
+    tsig_ctx.reset(new TestTSIGContext(TSIGKey(test_name,
+                                               TSIGKey::HMACSHA1_NAME(),
+                                               &dummy_data[0], 20)));
     EXPECT_EQ(74, tsig_ctx->getTSIGLength());
 
     // hmac-sha256: n2=13, x=32
-    tsig_ctx.reset(new TSIGContext(TSIGKey(test_name,
-                                           TSIGKey::HMACSHA256_NAME(),
-                                           &dummy_data[0], 32)));
+    tsig_ctx.reset(new TestTSIGContext(TSIGKey(test_name,
+                                               TSIGKey::HMACSHA256_NAME(),
+                                               &dummy_data[0], 32)));
     EXPECT_EQ(88, tsig_ctx->getTSIGLength());
 
     // hmac-sha224: n2=13, x=28
-    tsig_ctx.reset(new TSIGContext(TSIGKey(test_name,
-                                           TSIGKey::HMACSHA224_NAME(),
-                                           &dummy_data[0], 28)));
+    tsig_ctx.reset(new TestTSIGContext(TSIGKey(test_name,
+                                               TSIGKey::HMACSHA224_NAME(),
+                                               &dummy_data[0], 28)));
     EXPECT_EQ(84, tsig_ctx->getTSIGLength());
 
     // hmac-sha384: n2=13, x=48
-    tsig_ctx.reset(new TSIGContext(TSIGKey(test_name,
-                                           TSIGKey::HMACSHA384_NAME(),
-                                           &dummy_data[0], 48)));
+    tsig_ctx.reset(new TestTSIGContext(TSIGKey(test_name,
+                                               TSIGKey::HMACSHA384_NAME(),
+                                               &dummy_data[0], 48)));
     EXPECT_EQ(104, tsig_ctx->getTSIGLength());
 
     // hmac-sha512: n2=13, x=64
-    tsig_ctx.reset(new TSIGContext(TSIGKey(test_name,
-                                           TSIGKey::HMACSHA512_NAME(),
-                                           &dummy_data[0], 64)));
+    tsig_ctx.reset(new TestTSIGContext(TSIGKey(test_name,
+                                               TSIGKey::HMACSHA512_NAME(),
+                                               &dummy_data[0], 64)));
     EXPECT_EQ(120, tsig_ctx->getTSIGLength());
 
     // bad key case: n1=len(badkey.example.com)=20, n2=26, x=0
-    tsig_ctx.reset(new TSIGContext(badkey_name, TSIGKey::HMACMD5_NAME(),
-                                   keyring));
+    tsig_ctx.reset(new TestTSIGContext(badkey_name, TSIGKey::HMACMD5_NAME(),
+                                       keyring));
     EXPECT_EQ(72, tsig_ctx->getTSIGLength());
 
     // bad sig case: n1=17, n2=26, x=0
     isc::util::detail::gettimeFunction = testGetTime<0x4da8877a>;
     createMessageFromFile("message_toWire2.wire");
-    tsig_ctx.reset(new TSIGContext(TSIGKey(test_name, TSIGKey::HMACMD5_NAME(),
-                                           &dummy_data[0],
-                                           dummy_data.size())));
+    tsig_ctx.reset(new TestTSIGContext(TSIGKey(test_name,
+                                               TSIGKey::HMACMD5_NAME(),
+                                               &dummy_data[0],
+                                               dummy_data.size())));
     {
         SCOPED_TRACE("Verify resulting in BADSIG");
         commonVerifyChecks(*tsig_ctx, message.getTSIGRecord(),
@@ -985,9 +1024,10 @@ TEST_F(TSIGTest, getTSIGLength) {
 
     // bad time case: n1=17, n2=26, x=16, y=6
     isc::util::detail::gettimeFunction = testGetTime<0x4da8877a - 1000>;
-    tsig_ctx.reset(new TSIGContext(TSIGKey(test_name, TSIGKey::HMACMD5_NAME(),
-                                           &dummy_data[0],
-                                           dummy_data.size())));
+    tsig_ctx.reset(new TestTSIGContext(TSIGKey(test_name,
+                                               TSIGKey::HMACMD5_NAME(),
+                                               &dummy_data[0],
+                                               dummy_data.size())));
     {
         SCOPED_TRACE("Verify resulting in BADTIME");
         commonVerifyChecks(*tsig_ctx, message.getTSIGRecord(),
@@ -998,4 +1038,114 @@ TEST_F(TSIGTest, getTSIGLength) {
     EXPECT_EQ(91, tsig_ctx->getTSIGLength());
 }
 
+// Verify a stream of multiple messages. Some of them have a signature omitted.
+//
+// We have two contexts, one that signs, another that verifies.
+TEST_F(TSIGTest, verifyMulti) {
+    isc::util::detail::gettimeFunction = testGetTime<0x4da8877a>;
+
+    // First, send query from the verify one to the normal one, so
+    // we initialize something like AXFR
+    {
+        SCOPED_TRACE("Query");
+        ConstTSIGRecordPtr tsig = createMessageAndSign(1234, test_name,
+                                                       tsig_verify_ctx.get());
+        commonVerifyChecks(*tsig_ctx, tsig.get(),
+                           renderer.getData(), renderer.getLength(),
+                           TSIGError(Rcode::NOERROR()),
+                           TSIGContext::RECEIVED_REQUEST);
+    }
+
+    {
+        SCOPED_TRACE("First message");
+        ConstTSIGRecordPtr tsig = createMessageAndSign(1234, test_name,
+                                                       tsig_ctx.get());
+        commonVerifyChecks(*tsig_verify_ctx, tsig.get(),
+                           renderer.getData(), renderer.getLength(),
+                           TSIGError(Rcode::NOERROR()),
+                           TSIGContext::VERIFIED_RESPONSE);
+        EXPECT_TRUE(tsig_verify_ctx->lastHadSignature());
+    }
+
+    {
+        SCOPED_TRACE("Second message");
+        ConstTSIGRecordPtr tsig = createMessageAndSign(1234, test_name,
+                                                       tsig_ctx.get());
+        commonVerifyChecks(*tsig_verify_ctx, tsig.get(),
+                           renderer.getData(), renderer.getLength(),
+                           TSIGError(Rcode::NOERROR()),
+                           TSIGContext::VERIFIED_RESPONSE);
+        EXPECT_TRUE(tsig_verify_ctx->lastHadSignature());
+    }
+
+    {
+        SCOPED_TRACE("Third message. Unsigned.");
+        // Another message does not carry the TSIG on it. But it should
+        // be OK, it's in the middle of stream.
+        message.clear(Message::RENDER);
+        message.setQid(1234);
+        message.setOpcode(Opcode::QUERY());
+        message.setRcode(Rcode::NOERROR());
+        RRsetPtr answer_rrset(new RRset(test_name, test_class, RRType::A(),
+                                        test_ttl));
+        answer_rrset->addRdata(createRdata(RRType::A(), test_class,
+                                           "192.0.2.1"));
+        message.addRRset(Message::SECTION_ANSWER, answer_rrset);
+        message.toWire(renderer);
+        // Update the internal state. We abuse the knowledge of
+        // internals here a little bit to generate correct test data
+        tsig_ctx->update(renderer.getData(), renderer.getLength());
+
+        commonVerifyChecks(*tsig_verify_ctx, NULL,
+                           renderer.getData(), renderer.getLength(),
+                           TSIGError(Rcode::NOERROR()),
+                           TSIGContext::VERIFIED_RESPONSE);
+
+        EXPECT_FALSE(tsig_verify_ctx->lastHadSignature());
+    }
+
+    {
+        SCOPED_TRACE("Fourth message. Signed again.");
+        ConstTSIGRecordPtr tsig = createMessageAndSign(1234, test_name,
+                                                       tsig_ctx.get());
+        commonVerifyChecks(*tsig_verify_ctx, tsig.get(),
+                           renderer.getData(), renderer.getLength(),
+                           TSIGError(Rcode::NOERROR()),
+                           TSIGContext::VERIFIED_RESPONSE);
+        EXPECT_TRUE(tsig_verify_ctx->lastHadSignature());
+    }
+
+    {
+        SCOPED_TRACE("Filling in bunch of unsigned messages");
+        for (size_t i = 0; i < 100; ++i) {
+            SCOPED_TRACE(i);
+            // Another message does not carry the TSIG on it. But it should
+            // be OK, it's in the middle of stream.
+            message.clear(Message::RENDER);
+            message.setQid(1234);
+            message.setOpcode(Opcode::QUERY());
+            message.setRcode(Rcode::NOERROR());
+            RRsetPtr answer_rrset(new RRset(test_name, test_class, RRType::A(),
+                                            test_ttl));
+            answer_rrset->addRdata(createRdata(RRType::A(), test_class,
+                                               "192.0.2.1"));
+            message.addRRset(Message::SECTION_ANSWER, answer_rrset);
+            message.toWire(renderer);
+            // Update the internal state. We abuse the knowledge of
+            // internals here a little bit to generate correct test data
+            tsig_ctx->update(renderer.getData(), renderer.getLength());
+
+            // 99 unsigned messages is OK. But the 100th must be signed, according
+            // to the RFC2845, section 4.4
+            commonVerifyChecks(*tsig_verify_ctx, NULL,
+                               renderer.getData(), renderer.getLength(),
+                               i == 99 ? TSIGError::FORMERR() :
+                                   TSIGError(Rcode::NOERROR()),
+                               TSIGContext::VERIFIED_RESPONSE);
+
+            EXPECT_FALSE(tsig_verify_ctx->lastHadSignature());
+        }
+    }
+}
+
 } // end namespace

+ 51 - 7
src/lib/dns/tsig.cc

@@ -61,7 +61,8 @@ struct TSIGContext::TSIGContextImpl {
     TSIGContextImpl(const TSIGKey& key,
                     TSIGError error = TSIGError::NOERROR()) :
         state_(INIT), key_(key), error_(error),
-        previous_timesigned_(0), digest_len_(0)
+        previous_timesigned_(0), digest_len_(0),
+        last_sig_dist_(-1)
     {
         if (error == TSIGError::NOERROR()) {
             // In normal (NOERROR) case, the key should be valid, and we
@@ -137,7 +138,7 @@ struct TSIGContext::TSIGContextImpl {
     // performance bottleneck, we could have this class a buffer as a member
     // variable and reuse it throughout the object's lifetime.  Right now,
     // we prefer keeping the scope for local things as small as possible.
-    void digestPreviousMAC(HMACPtr hmac) const;
+    void digestPreviousMAC(HMACPtr hmac);
     void digestTSIGVariables(HMACPtr hmac, uint16_t rrclass, uint32_t rrttl,
                              uint64_t time_signed, uint16_t fudge,
                              uint16_t error, uint16_t otherlen,
@@ -152,14 +153,25 @@ struct TSIGContext::TSIGContextImpl {
     uint64_t previous_timesigned_; // only meaningful for response with BADTIME
     size_t digest_len_;
     HMACPtr hmac_;
+    // This is the distance from the last verified signed message. Value of 0
+    // means the last message was signed. Special value -1 means there was no
+    // signed message yet.
+    int last_sig_dist_;
 };
 
 void
-TSIGContext::TSIGContextImpl::digestPreviousMAC(HMACPtr hmac) const {
+TSIGContext::TSIGContextImpl::digestPreviousMAC(HMACPtr hmac) {
     // We should have ensured the digest size fits 16 bits within this class
     // implementation.
     assert(previous_digest_.size() <= 0xffff);
 
+    if (previous_digest_.empty()) {
+        // The previous digest was already used. We're in the middle of
+        // TCP stream somewhere and we already pushed some unsigned message
+        // into the HMAC state.
+        return;
+    }
+
     OutputBuffer buffer(sizeof(uint16_t) + previous_digest_.size());
     const uint16_t previous_digest_len(previous_digest_.size());
     buffer.writeUint16(previous_digest_len);
@@ -414,11 +426,21 @@ TSIGContext::verify(const TSIGRecord* const record, const void* const data,
                   "TSIG verify attempt after sending a response");
     }
 
-    // This case happens when we sent a signed request and have received an
-    // unsigned response.  According to RFC2845 Section 4.6 this case should be
-    // considered a "format error" (although the specific error code
-    // wouldn't matter much for the caller).
     if (record == NULL) {
+        if (impl_->last_sig_dist_ >= 0 && impl_->last_sig_dist_ < 99) {
+            // It is not signed, but in the middle of TCP stream. We just
+            // update the HMAC state and consider this message OK.
+            update(data, data_len);
+            // This one is not signed, the last signed is one message further
+            // now.
+            impl_->last_sig_dist_++;
+            // No digest to return now. Just say it's OK.
+            return (impl_->postVerifyUpdate(TSIGError::NOERROR(), NULL, 0));
+        }
+        // This case happens when we sent a signed request and have received an
+        // unsigned response.  According to RFC2845 Section 4.6 this case should be
+        // considered a "format error" (although the specific error code
+        // wouldn't matter much for the caller).
         return (impl_->postVerifyUpdate(TSIGError::FORMERR(), NULL, 0));
     }
 
@@ -433,6 +455,9 @@ TSIGContext::verify(const TSIGRecord* const record, const void* const data,
         isc_throw(InvalidParameter, "TSIG verify: empty data is invalid");
     }
 
+    // This message is signed and we won't throw any more.
+    impl_->last_sig_dist_ = 0;
+
     // Check key: whether we first verify it with a known key or we verify
     // it using the consistent key in the context.  If the check fails we are
     // done with BADKEY.
@@ -520,5 +545,24 @@ TSIGContext::verify(const TSIGRecord* const record, const void* const data,
     return (impl_->postVerifyUpdate(TSIGError::BAD_SIG(), NULL, 0));
 }
 
+bool
+TSIGContext::lastHadSignature() const {
+    if (impl_->last_sig_dist_ == -1) {
+        isc_throw(TSIGContextError, "No message was verified yet");
+    }
+    return (impl_->last_sig_dist_ == 0);
+}
+
+void
+TSIGContext::update(const void* const data, size_t len) {
+    HMACPtr hmac(impl_->createHMAC());
+    // Use the previous digest and never use it again
+    impl_->digestPreviousMAC(hmac);
+    impl_->previous_digest_.clear();
+    // Push the message there
+    hmac->update(data, len);
+    impl_->hmac_ = hmac;
+}
+
 } // namespace dns
 } // namespace isc

+ 26 - 1
src/lib/dns/tsig.h

@@ -339,7 +339,6 @@ public:
     /// returns (without an exception being thrown), the internal state of
     /// the \c TSIGContext won't be modified.
     ///
-    /// \todo Support intermediate TCP DNS messages without TSIG (RFC2845 4.4)
     /// \todo Signature truncation support based on RFC4635
     ///
     /// \exception TSIGContextError Context already signed a response.
@@ -353,6 +352,19 @@ public:
     TSIGError verify(const TSIGRecord* const record, const void* const data,
                      const size_t data_len);
 
+    /// \brief Check whether the last verified message was signed.
+    ///
+    /// RFC2845 allows for some of the messages not to be signed. However,
+    /// the last message must be signed and the class has no knowledge if a
+    /// given message is the last one, therefore it can't check directly.
+    ///
+    /// It is up to the caller to check if the last verified message was signed
+    /// after all are verified by calling this function.
+    ///
+    /// \return If the last message was signed or not.
+    /// \exception TSIGContextError if no message was verified yet.
+    bool lastHadSignature() const;
+
     /// Return the expected length of TSIG RR after \c sign()
     ///
     /// This method returns the length of the TSIG RR that would be
@@ -401,6 +413,19 @@ public:
     static const uint16_t DEFAULT_FUDGE = 300;
     //@}
 
+protected:
+    /// \brief Update internal HMAC state by more data.
+    ///
+    /// This is used mostly internaly, when we need to verify a message without
+    /// TSIG signature in the middle of signed TCP stream. However, it is also
+    /// used in tests, so it's protected instead of private, to allow tests
+    /// in.
+    ///
+    /// It doesn't contain sanity checks, and it is not tested directly. But
+    /// we may want to add these one day to allow generating the skipped TSIG
+    /// messages too. Until then, do not use this method.
+    void update(const void* const data, size_t len);
+
 private:
     struct TSIGContextImpl;
     TSIGContextImpl* impl_;

+ 3 - 0
src/lib/python/isc/testutils/tsigctx_mock.py

@@ -51,3 +51,6 @@ class MockTSIGContext(TSIGContext):
         if hasattr(self.error, '__call__'):
             return self.error(self)
         return self.error
+
+    def last_has_signature(self):
+        return True