Browse Source

[trac658] Add IOAddress equality and return packet QID checks

Jelte Jansen 14 years ago
parent
commit
a5c11800c3

+ 13 - 0
src/lib/asiolink/io_endpoint.cc

@@ -44,4 +44,17 @@ IOEndpoint::create(const int protocol, const IOAddress& address,
               protocol);
 }
 
+bool
+IOEndpoint::operator==(const IOEndpoint& other) const {
+    return (getProtocol() == other.getProtocol() &&
+            getPort() == other.getPort() &&
+            getFamily() == other.getFamily() &&
+            getAddress() == other.getAddress());
+}
+
+bool
+IOEndpoint::operator!=(const IOEndpoint& other) const {
+    return (!operator==(other));
+}
+
 }

+ 3 - 0
src/lib/asiolink/io_endpoint.h

@@ -89,6 +89,9 @@ public:
     /// \brief Returns the address family of the endpoint.
     virtual short getFamily() const = 0;
 
+    bool operator==(const IOEndpoint& other) const;
+    bool operator!=(const IOEndpoint& other) const;
+
     /// \brief A polymorphic factory of endpoint from address and port.
     ///
     /// This method creates a new instance of (a derived class of)

+ 88 - 64
src/lib/asiolink/io_fetch.cc

@@ -43,6 +43,9 @@
 #include <asiolink/tcp_socket.h>
 #include <asiolink/udp_endpoint.h>
 #include <asiolink/udp_socket.h>
+#include <asiolink/qid_gen.h>
+
+#include <stdint.h>
 
 using namespace asio;
 using namespace isc::dns;
@@ -69,19 +72,20 @@ struct IOFetchData {
     // which is not known until construction of the IOFetch.  Use of a shared
     // pointer here is merely to ensure deletion when the data object is deleted.
     boost::scoped_ptr<IOAsioSocket<IOFetch> > socket;
-                                            ///< Socket to use for I/O
-    boost::scoped_ptr<IOEndpoint> remote;   ///< Where the fetch was sent
-    isc::dns::Question          question;   ///< Question to be asked
-    isc::dns::OutputBufferPtr   msgbuf;     ///< Wire buffer for question
-    isc::dns::OutputBufferPtr   received;   ///< Received data put here
-    IOFetch::Callback*          callback;   ///< Called on I/O Completion
-    asio::deadline_timer        timer;      ///< Timer to measure timeouts
-    IOFetch::Protocol           protocol;   ///< Protocol being used
-    size_t                      cumulative; ///< Cumulative received amount
-    size_t                      expected;   ///< Expected amount of data
-    size_t                      offset;     ///< Offset to receive data
-    bool                        stopped;    ///< Have we stopped running?
-    int                         timeout;    ///< Timeout in ms
+                                             ///< Socket to use for I/O
+    boost::scoped_ptr<IOEndpoint> remote_snd;///< Where the fetch is sent
+    boost::scoped_ptr<IOEndpoint> remote_rcv;///< Where the response came from
+    isc::dns::Question          question;    ///< Question to be asked
+    isc::dns::OutputBufferPtr   msgbuf;      ///< Wire buffer for question
+    isc::dns::OutputBufferPtr   received;    ///< Received data put here
+    IOFetch::Callback*          callback;    ///< Called on I/O Completion
+    asio::deadline_timer        timer;       ///< Timer to measure timeouts
+    IOFetch::Protocol           protocol;    ///< Protocol being used
+    size_t                      cumulative;  ///< Cumulative received amount
+    size_t                      expected;    ///< Expected amount of data
+    size_t                      offset;      ///< Offset to receive data
+    bool                        stopped;     ///< Have we stopped running?
+    int                         timeout;     ///< Timeout in ms
 
     // In case we need to log an error, the origin of the last asynchronous
     // I/O is recorded.  To save time and simplify the code, this is recorded
@@ -91,6 +95,7 @@ struct IOFetchData {
     isc::log::MessageID         origin;     ///< Origin of last asynchronous I/O
     uint8_t                     staging[IOFetch::STAGING_LENGTH];
                                             ///< Temporary array for received data
+    isc::dns::qid_t             qid;         ///< The QID set in the query
 
     /// \brief Constructor
     ///
@@ -121,7 +126,11 @@ struct IOFetchData {
             static_cast<IOAsioSocket<IOFetch>*>(
                 new TCPSocket<IOFetch>(service))
             ),
-        remote((proto == IOFetch::UDP) ?
+        remote_snd((proto == IOFetch::UDP) ?
+            static_cast<IOEndpoint*>(new UDPEndpoint(address, port)) :
+            static_cast<IOEndpoint*>(new TCPEndpoint(address, port))
+            ),
+        remote_rcv((proto == IOFetch::UDP) ?
             static_cast<IOEndpoint*>(new UDPEndpoint(address, port)) :
             static_cast<IOEndpoint*>(new TCPEndpoint(address, port))
             ),
@@ -138,8 +147,21 @@ struct IOFetchData {
         stopped(false),
         timeout(wait),
         origin(ASIO_UNKORIGIN),
-        staging()
+        staging(),
+        qid(QidGenerator::getInstance().generateQid())
     {}
+
+    // Checks if the response we received was ok;
+    // - data contains the buffer we read, as well as the address
+    // we sent to and the address we received from.
+    // length is provided by the operator() in IOFetch.
+    // Addresses must match, number of octets read must be at least
+    // 2, and the first two octets must match the qid of the message
+    // we sent.
+    bool responseOK() {
+        return (*remote_snd == *remote_rcv && cumulative >= 2 &&
+                readUint16(received->getData()) == qid);
+    }
 };
 
 /// IOFetch Constructor - just initialize the private data
@@ -180,7 +202,7 @@ IOFetch::operator()(asio::error_code ec, size_t length) {
         /// declarations.
         {
             Message msg(Message::RENDER);
-            msg.setQid(QidGenerator::getInstance().generateQid());
+            msg.setQid(data_->qid);
             msg.setOpcode(Opcode::QUERY());
             msg.setRcode(Rcode::NOERROR());
             msg.setHeaderFlag(Message::HEADERFLAG_RD);
@@ -202,47 +224,49 @@ IOFetch::operator()(asio::error_code ec, size_t length) {
         // is synchronous (i.e. UDP operation) we bypass the yield.
         data_->origin = ASIO_OPENSOCK;
         if (data_->socket->isOpenSynchronous()) {
-            data_->socket->open(data_->remote.get(), *this);
+            data_->socket->open(data_->remote_snd.get(), *this);
         } else {
-            CORO_YIELD data_->socket->open(data_->remote.get(), *this);
+            CORO_YIELD data_->socket->open(data_->remote_snd.get(), *this);
         }
 
-        // Begin an asynchronous send, and then yield.  When the send completes,
-        // we will resume immediately after this point.
-        data_->origin = ASIO_SENDSOCK;
-        CORO_YIELD data_->socket->asyncSend(data_->msgbuf->getData(),
-            data_->msgbuf->getLength(), data_->remote.get(), *this);
-
-        // Now receive the response.  Since TCP may not receive the entire
-        // message in one operation, we need to loop until we have received
-        // it. (This can't be done within the asyncReceive() method because
-        // each I/O operation will be done asynchronously and between each one
-        // we need to yield ... and we *really* don't want to set up another
-        // coroutine within that method.)  So after each receive (and yield),
-        // we check if the operation is complete and if not, loop to read again.
-        //
-        // Another concession to TCP is that the amount of is contained in the
-        // first two bytes.  This leads to two problems:
-        //
-        // a) We don't want those bytes in the return buffer.
-        // b) They may not both arrive in the first I/O.
-        //
-        // So... we need to loop until we have at least two bytes, then store
-        // the expected amount of data.  Then we need to loop until we have
-        // received all the data before copying it back to the user's buffer.
-        // And we want to minimise the amount of copying...
-
-        data_->origin = ASIO_RECVSOCK;
-        data_->cumulative = 0;          // No data yet received
-        data_->offset = 0;              // First data into start of buffer
         do {
-            CORO_YIELD data_->socket->asyncReceive(data_->staging,
-                                                   static_cast<size_t>(STAGING_LENGTH),
-                                                   data_->offset,
-                                                   data_->remote.get(), *this);
-        } while (!data_->socket->processReceivedData(data_->staging, length,
-                                                     data_->cumulative, data_->offset,
-                                                     data_->expected, data_->received));
+            // Begin an asynchronous send, and then yield.  When the send completes,
+            // we will resume immediately after this point.
+            data_->origin = ASIO_SENDSOCK;
+            CORO_YIELD data_->socket->asyncSend(data_->msgbuf->getData(),
+                data_->msgbuf->getLength(), data_->remote_snd.get(), *this);
+    
+            // Now receive the response.  Since TCP may not receive the entire
+            // message in one operation, we need to loop until we have received
+            // it. (This can't be done within the asyncReceive() method because
+            // each I/O operation will be done asynchronously and between each one
+            // we need to yield ... and we *really* don't want to set up another
+            // coroutine within that method.)  So after each receive (and yield),
+            // we check if the operation is complete and if not, loop to read again.
+            //
+            // Another concession to TCP is that the amount of is contained in the
+            // first two bytes.  This leads to two problems:
+            //
+            // a) We don't want those bytes in the return buffer.
+            // b) They may not both arrive in the first I/O.
+            //
+            // So... we need to loop until we have at least two bytes, then store
+            // the expected amount of data.  Then we need to loop until we have
+            // received all the data before copying it back to the user's buffer.
+            // And we want to minimise the amount of copying...
+    
+            data_->origin = ASIO_RECVSOCK;
+            data_->cumulative = 0;          // No data yet received
+            data_->offset = 0;              // First data into start of buffer
+            do {
+                CORO_YIELD data_->socket->asyncReceive(data_->staging,
+                                                       static_cast<size_t>(STAGING_LENGTH),
+                                                       data_->offset,
+                                                       data_->remote_rcv.get(), *this);
+            } while (!data_->socket->processReceivedData(data_->staging, length,
+                                                         data_->cumulative, data_->offset,
+                                                         data_->expected, data_->received));
+        } while (!data_->responseOK());
 
         // Finished with this socket, so close it.  This will not generate an
         // I/O error, but reset the origin to unknown in case we change this.
@@ -290,16 +314,16 @@ IOFetch::stop(Result result) {
             case TIME_OUT:
                 if (logger.isDebugEnabled(1)) {
                     logger.debug(20, ASIO_RECVTMO,
-                                 data_->remote->getAddress().toText().c_str(),
-                                 static_cast<int>(data_->remote->getPort()));
+                                 data_->remote_snd->getAddress().toText().c_str(),
+                                 static_cast<int>(data_->remote_snd->getPort()));
                 }
                 break;
 
             case SUCCESS:
                 if (logger.isDebugEnabled(50)) {
                     logger.debug(30, ASIO_FETCHCOMP,
-                                 data_->remote->getAddress().toText().c_str(),
-                                 static_cast<int>(data_->remote->getPort()));
+                                 data_->remote_rcv->getAddress().toText().c_str(),
+                                 static_cast<int>(data_->remote_rcv->getPort()));
                 }
                 break;
 
@@ -308,14 +332,14 @@ IOFetch::stop(Result result) {
                 // allowed but as it is unusual it is logged, but with a lower
                 // debug level than a timeout (which is totally normal).
                 logger.debug(1, ASIO_FETCHSTOP,
-                             data_->remote->getAddress().toText().c_str(),
-                             static_cast<int>(data_->remote->getPort()));
+                             data_->remote_snd->getAddress().toText().c_str(),
+                             static_cast<int>(data_->remote_snd->getPort()));
                 break;
 
             default:
                 logger.error(ASIO_UNKRESULT, static_cast<int>(result),
-                             data_->remote->getAddress().toText().c_str(),
-                             static_cast<int>(data_->remote->getPort()));
+                             data_->remote_snd->getAddress().toText().c_str(),
+                             static_cast<int>(data_->remote_snd->getPort()));
         }
 
         // Stop requested, cancel and I/O's on the socket and shut it down,
@@ -345,10 +369,10 @@ void IOFetch::logIOFailure(asio::error_code ec) {
     static const char* PROTOCOL[2] = {"TCP", "UDP"};
     logger.error(data_->origin,
                  ec.value(),
-                 ((data_->remote->getProtocol() == IPPROTO_TCP) ?
+                 ((data_->remote_snd->getProtocol() == IPPROTO_TCP) ?
                      PROTOCOL[0] : PROTOCOL[1]),
-                 data_->remote->getAddress().toText().c_str(),
-                 static_cast<int>(data_->remote->getPort()));
+                 data_->remote_snd->getAddress().toText().c_str(),
+                 static_cast<int>(data_->remote_snd->getPort()));
 }
 
 } // namespace asiolink

+ 56 - 0
src/lib/asiolink/tests/io_endpoint_unittest.cc

@@ -60,6 +60,62 @@ TEST(IOEndpointTest, createTCPv6) {
     EXPECT_EQ(IPPROTO_TCP, ep->getProtocol());
 }
 
+TEST(IOEndpointTest, equality) {
+    std::vector<const IOEndpoint *> epv;
+    epv.push_back(IOEndpoint::create(IPPROTO_TCP, IOAddress("2001:db8::1234"), 5303));
+    epv.push_back(IOEndpoint::create(IPPROTO_UDP, IOAddress("2001:db8::1234"), 5303));
+    epv.push_back(IOEndpoint::create(IPPROTO_TCP, IOAddress("2001:db8::1234"), 5304));
+    epv.push_back(IOEndpoint::create(IPPROTO_UDP, IOAddress("2001:db8::1234"), 5304));
+    epv.push_back(IOEndpoint::create(IPPROTO_TCP, IOAddress("2001:db8::1235"), 5303));
+    epv.push_back(IOEndpoint::create(IPPROTO_UDP, IOAddress("2001:db8::1235"), 5303));
+    epv.push_back(IOEndpoint::create(IPPROTO_TCP, IOAddress("2001:db8::1235"), 5304));
+    epv.push_back(IOEndpoint::create(IPPROTO_UDP, IOAddress("2001:db8::1235"), 5304));
+    epv.push_back(IOEndpoint::create(IPPROTO_TCP, IOAddress("192.0.2.1"), 5303));
+    epv.push_back(IOEndpoint::create(IPPROTO_UDP, IOAddress("192.0.2.1"), 5303));
+    epv.push_back(IOEndpoint::create(IPPROTO_TCP, IOAddress("192.0.2.1"), 5304));
+    epv.push_back(IOEndpoint::create(IPPROTO_UDP, IOAddress("192.0.2.1"), 5304));
+    epv.push_back(IOEndpoint::create(IPPROTO_TCP, IOAddress("192.0.2.2"), 5303));
+    epv.push_back(IOEndpoint::create(IPPROTO_UDP, IOAddress("192.0.2.2"), 5303));
+    epv.push_back(IOEndpoint::create(IPPROTO_TCP, IOAddress("192.0.2.2"), 5304));
+    epv.push_back(IOEndpoint::create(IPPROTO_UDP, IOAddress("192.0.2.2"), 5304));
+
+    for (size_t i = 0; i < epv.size(); ++i) {
+        for (size_t j = 0; j < epv.size(); ++j) {
+            if (i != j) {
+                // We use EXPECT_TRUE/FALSE instead of _EQ here, since
+                // _EQ requires there is an operator<< as well
+                EXPECT_FALSE(*epv[i] == *epv[j]);
+                EXPECT_TRUE(*epv[i] != *epv[j]);
+            }
+        }
+    }
+
+    // Create a second array with exactly the same values. We use create()
+    // again to make sure we get different endpoints
+    std::vector<const IOEndpoint *> epv2;
+    epv2.push_back(IOEndpoint::create(IPPROTO_TCP, IOAddress("2001:db8::1234"), 5303));
+    epv2.push_back(IOEndpoint::create(IPPROTO_UDP, IOAddress("2001:db8::1234"), 5303));
+    epv2.push_back(IOEndpoint::create(IPPROTO_TCP, IOAddress("2001:db8::1234"), 5304));
+    epv2.push_back(IOEndpoint::create(IPPROTO_UDP, IOAddress("2001:db8::1234"), 5304));
+    epv2.push_back(IOEndpoint::create(IPPROTO_TCP, IOAddress("2001:db8::1235"), 5303));
+    epv2.push_back(IOEndpoint::create(IPPROTO_UDP, IOAddress("2001:db8::1235"), 5303));
+    epv2.push_back(IOEndpoint::create(IPPROTO_TCP, IOAddress("2001:db8::1235"), 5304));
+    epv2.push_back(IOEndpoint::create(IPPROTO_UDP, IOAddress("2001:db8::1235"), 5304));
+    epv2.push_back(IOEndpoint::create(IPPROTO_TCP, IOAddress("192.0.2.1"), 5303));
+    epv2.push_back(IOEndpoint::create(IPPROTO_UDP, IOAddress("192.0.2.1"), 5303));
+    epv2.push_back(IOEndpoint::create(IPPROTO_TCP, IOAddress("192.0.2.1"), 5304));
+    epv2.push_back(IOEndpoint::create(IPPROTO_UDP, IOAddress("192.0.2.1"), 5304));
+    epv2.push_back(IOEndpoint::create(IPPROTO_TCP, IOAddress("192.0.2.2"), 5303));
+    epv2.push_back(IOEndpoint::create(IPPROTO_UDP, IOAddress("192.0.2.2"), 5303));
+    epv2.push_back(IOEndpoint::create(IPPROTO_TCP, IOAddress("192.0.2.2"), 5304));
+    epv2.push_back(IOEndpoint::create(IPPROTO_UDP, IOAddress("192.0.2.2"), 5304));
+
+    for (size_t i = 0; i < epv.size(); ++i) {
+        EXPECT_TRUE(*epv[i] == *epv2[i]);
+        EXPECT_FALSE(*epv[i] != *epv2[i]);
+    }
+}
+
 TEST(IOEndpointTest, createIPProto) {
     EXPECT_THROW(IOEndpoint::create(IPPROTO_IP, IOAddress("192.0.2.1"),
                                     53210)->getAddress().toText(),

+ 65 - 17
src/lib/asiolink/tests/io_fetch_unittest.cc

@@ -76,6 +76,7 @@ public:
     // response handler methods in this class) receives the question sent by the
     // fetch object.
     uint8_t         receive_buffer_[MAX_SIZE]; ///< Server receive buffer
+    OutputBufferPtr expected_buffer_;          ///< Data we expect to receive
     vector<uint8_t> send_buffer_;           ///< Server send buffer
     uint16_t        send_cumulative_;       ///< Data sent so far
 
@@ -84,6 +85,8 @@ public:
     string          test_data_;             ///< Large string - here for convenience
     bool            debug_;                 ///< true to enable debug output
     size_t          tcp_send_size_;         ///< Max size of TCP send
+    uint8_t         qid_0;                  ///< First octet of qid
+    uint8_t         qid_1;                  ///< Second octet of qid
 
     /// \brief Constructor
     IOFetchTest() :
@@ -102,12 +105,15 @@ public:
         cumulative_(0),
         timer_(service_.get_io_service()),
         receive_buffer_(),
+        expected_buffer_(new OutputBuffer(512)),
         send_buffer_(),
         send_cumulative_(0),
         return_data_(""),
         test_data_(""),
         debug_(DEBUG),
-        tcp_send_size_(0)
+        tcp_send_size_(0),
+        qid_0(0),
+        qid_1(0)
     {
         // Construct the data buffer for question we expect to receive.
         Message msg(Message::RENDER);
@@ -118,6 +124,8 @@ public:
         msg.addQuestion(question_);
         MessageRenderer renderer(*msgbuf_);
         msg.toWire(renderer);
+        MessageRenderer renderer2(*expected_buffer_);
+        msg.toWire(renderer2);
 
         // Initialize the test data to be returned: tests will return a
         // substring of this data. (It's convenient to have this as a member of
@@ -146,7 +154,8 @@ public:
     ///        by the "server" to receive data.
     /// \param length Amount of data received.
     void udpReceiveHandler(udp::endpoint* remote, udp::socket* socket,
-                    error_code ec = error_code(), size_t length = 0) {
+                    error_code ec = error_code(), size_t length = 0,
+                    bool bad_qid = false, bool second_send = false) {
         if (debug_) {
             cout << "udpReceiveHandler(): error = " << ec.value() <<
                     ", length = " << length << endl;
@@ -155,6 +164,8 @@ public:
         // The QID in the incoming data is random so set it to 0 for the
         // data comparison check. (It is set to 0 in the buffer containing
         // the expected data.)
+        qid_0 = receive_buffer_[0];
+        qid_1 = receive_buffer_[1];
         receive_buffer_[0] = receive_buffer_[1] = 0;
 
         // Check that length of the received data and the expected data are
@@ -164,10 +175,23 @@ public:
         static_cast<const uint8_t*>(msgbuf_->getData())));
 
         // Return a message back to the IOFetch object.
-        socket->send_to(asio::buffer(return_data_.c_str(), return_data_.size()),
-                                     *remote);
+        if (!bad_qid) {
+            expected_buffer_->writeUint8At(qid_0, 0);
+            expected_buffer_->writeUint8At(qid_1, 1);
+        } else {
+            expected_buffer_->writeUint8At(qid_0 + 1, 0);
+            expected_buffer_->writeUint8At(qid_1 + 1, 1);
+        }
+        socket->send_to(asio::buffer(expected_buffer_->getData(), length), *remote);
+
+        if (bad_qid && second_send) {
+            expected_buffer_->writeUint8At(qid_0, 0);
+            expected_buffer_->writeUint8At(qid_1, 1);
+            socket->send_to(asio::buffer(expected_buffer_->getData(),
+                            expected_buffer_->getLength()), *remote);
+        }
         if (debug_) {
-            cout << "udpReceiveHandler(): returned " << return_data_.size() <<
+            cout << "udpReceiveHandler(): returned " << expected_buffer_->getLength() <<
                     " bytes to the client" << endl;
         }
     }
@@ -249,18 +273,25 @@ public:
         // field the QID in the received buffer is in the third and fourth
         // bytes.
         EXPECT_EQ(msgbuf_->getLength() + 2, cumulative_);
+        qid_0 = receive_buffer_[2];
+        qid_1 = receive_buffer_[3];
+
         receive_buffer_[2] = receive_buffer_[3] = 0;
         EXPECT_TRUE(equal((receive_buffer_ + 2), (receive_buffer_ + cumulative_ - 2),
             static_cast<const uint8_t*>(msgbuf_->getData())));
 
         // ... and return a message back.  This has to be preceded by a two-byte
         // count field.
+
         send_buffer_.clear();
         send_buffer_.push_back(0);
         send_buffer_.push_back(0);
         writeUint16(return_data_.size(), &send_buffer_[0]);
         copy(return_data_.begin(), return_data_.end(), back_inserter(send_buffer_));
-
+        if (return_data_.size() >= 2) {
+            send_buffer_[2] = qid_0;
+            send_buffer_[3] = qid_1;
+        }
         // Send the data.  This is done in multiple writes with a delay between
         // each to check that the reassembly of TCP packets from fragments works.
         send_cumulative_ = 0;
@@ -373,10 +404,25 @@ public:
         // when one of the "servers" in this class has sent back return_data_.
         // Check the data is as expected/
         if (expected_ == IOFetch::SUCCESS) {
-            EXPECT_EQ(return_data_.size(), result_buff_->getLength());
-
-            const uint8_t* start = static_cast<const uint8_t*>(result_buff_->getData());
-            EXPECT_TRUE(equal(return_data_.begin(), return_data_.end(), start));
+            // In the case of UDP, we actually send back a real looking packet
+            // in the case of TCP, we send back a 'random' string
+            if (protocol_ == IOFetch::UDP) {
+                EXPECT_EQ(expected_buffer_->getLength(), result_buff_->getLength());
+                //const uint8_t* start = static_cast<const uint8_t*>(result_buff_->getData());
+                //EXPECT_TRUE(equal(return_data_.begin(), return_data_.end(), start));
+                EXPECT_EQ(0, memcmp(expected_buffer_->getData(), result_buff_->getData(),
+                          expected_buffer_->getLength()));
+            } else {
+                EXPECT_EQ(return_data_.size(), result_buff_->getLength());
+                // Overwrite the random qid with our own data for the
+                // comparison to succeed
+                if (result_buff_->getLength() >= 2) {
+                    result_buff_->writeUint8At(return_data_[0], 0);
+                    result_buff_->writeUint8At(return_data_[1], 1);
+                }
+                const uint8_t* start = static_cast<const uint8_t*>(result_buff_->getData());
+                EXPECT_TRUE(equal(return_data_.begin(), return_data_.end(), start));
+            }
         }
 
         // ... and cause the run loop to exit.
@@ -520,7 +566,7 @@ TEST_F(IOFetchTest, UdpSendReceive) {
     socket.async_receive_from(asio::buffer(receive_buffer_, sizeof(receive_buffer_)),
         remote,
         boost::bind(&IOFetchTest::udpReceiveHandler, this, &remote, &socket,
-                    _1, _2));
+                    _1, _2, false, false));
     service_.get_io_service().post(udp_fetch_);
     if (debug_) {
         cout << "udpSendReceive: async_receive_from posted, waiting for callback" <<
@@ -547,18 +593,20 @@ TEST_F(IOFetchTest, TcpTimeout) {
     timeoutTest(IOFetch::TCP, tcp_fetch_);
 }
 
-// Test with values at or near 0, then at or near the chunk size (16 and 32
+// Test with values at or near 2, then at or near the chunk size (16 and 32
 // bytes, the sizes of the first two packets) then up to 65535.  These are done
 // in separate tests because in practice a new IOFetch is created for each
 // query/response exchange and we don't want to confuse matters in the test
 // by running the test with an IOFetch that has already done one exchange.
-
-TEST_F(IOFetchTest, TcpSendReceive0) {
-    tcpSendReturnTest(test_data_.substr(0, 0));
+//
+// Don't do 0 or 1; the server would not accept the packet
+// (since the length is too short to check the qid)
+TEST_F(IOFetchTest, TcpSendReceive2) {
+    tcpSendReturnTest(test_data_.substr(0, 2));
 }
 
-TEST_F(IOFetchTest, TcpSendReceive1) {
-    tcpSendReturnTest(test_data_.substr(0, 1));
+TEST_F(IOFetchTest, TcpSendReceive3) {
+    tcpSendReturnTest(test_data_.substr(0, 3));
 }
 
 TEST_F(IOFetchTest, TcpSendReceive15) {

+ 15 - 0
src/lib/dns/buffer.h

@@ -356,6 +356,21 @@ public:
     /// \param data The 8-bit integer to be written into the buffer.
     void writeUint8(uint8_t data) { data_.push_back(data); }
 
+    /// \brief Write an unsigned 8-bit integer into the buffer.
+    ///
+    /// The position must be lower than the size of the buffer,
+    /// otherwise an exception of class \c isc::dns::InvalidBufferPosition
+    /// will be thrown.
+    ///
+    /// \param data The 8-bit integer to be written into the buffer.
+    /// \param pos The position in the buffer to write the data.
+    void writeUint8At(uint8_t data, size_t pos) {
+        if (pos + sizeof(data) > data_.size()) {
+            isc_throw(InvalidBufferPosition, "write at invalid position");
+        }
+        data_[pos] = data;
+    }
+
     /// \brief Write an unsigned 16-bit integer in host byte order into the
     /// buffer in network byte order.
     ///

+ 11 - 1
src/lib/dns/tests/buffer_unittest.cc

@@ -124,10 +124,16 @@ TEST_F(BufferTest, outputBufferWriteat) {
     obuffer.writeUint32(data32);
     expected_size += sizeof(data32);
 
+    // overwrite 2nd byte
+    obuffer.writeUint8At(4, 1);
+    EXPECT_EQ(expected_size, obuffer.getLength()); // length shouldn't change
+    const uint8_t* cp = static_cast<const uint8_t*>(obuffer.getData());
+    EXPECT_EQ(4, *(cp + 1));
+
     // overwrite 2nd and 3rd bytes
     obuffer.writeUint16At(data16, 1);
     EXPECT_EQ(expected_size, obuffer.getLength()); // length shouldn't change
-    const uint8_t* cp = static_cast<const uint8_t*>(obuffer.getData());
+    cp = static_cast<const uint8_t*>(obuffer.getData());
     EXPECT_EQ(2, *(cp + 1));
     EXPECT_EQ(3, *(cp + 2));
 
@@ -138,6 +144,10 @@ TEST_F(BufferTest, outputBufferWriteat) {
     EXPECT_EQ(2, *(cp + 2));
     EXPECT_EQ(3, *(cp + 3));
 
+    EXPECT_THROW(obuffer.writeUint8At(data16, 5),
+                 isc::dns::InvalidBufferPosition);
+    EXPECT_THROW(obuffer.writeUint8At(data16, 4),
+                 isc::dns::InvalidBufferPosition);
     EXPECT_THROW(obuffer.writeUint16At(data16, 3),
                  isc::dns::InvalidBufferPosition);
     EXPECT_THROW(obuffer.writeUint16At(data16, 4),

+ 6 - 3
src/lib/resolve/tests/recursive_query_unittest_2.cc

@@ -165,7 +165,7 @@ public:
     /// Sets up the common bits of a response message returned by the handlers.
     ///
     /// \param msg Message buffer in RENDER mode.
-    /// \param qid QIT to set the message to
+    /// \param qid QID to set the message to
     void setCommonMessage(isc::dns::Message& msg, uint16_t qid = 0) {
         msg.setQid(qid);
         msg.setHeaderFlag(Message::HEADERFLAG_QR);
@@ -439,7 +439,7 @@ public:
         // should result in another query over UDP.  Note the setting of the
         // QID in the returned message with what was in the received message.
         Message msg(Message::RENDER);
-        setCommonMessage(msg, readUint16(tcp_receive_buffer_));
+        setCommonMessage(msg, readUint16(tcp_receive_buffer_ + 2));
         setReferralExampleOrg(msg);
 
         // Convert to wire format
@@ -502,7 +502,8 @@ public:
     ///        the case of UDP data, and an offset into the buffer past the
     ///        count field for TCP data.
     /// \param length Length of data.
-    void checkReceivedPacket(uint8_t* data, size_t length) {
+    /// \return The QID of the message
+    qid_t checkReceivedPacket(uint8_t* data, size_t length) {
 
         // Decode the received buffer.
         InputBuffer buffer(data, length);
@@ -514,6 +515,8 @@ public:
 
         Question question = **(message.beginQuestion());
         EXPECT_TRUE(question == *question_);
+
+        return message.getQid();
     }
 };